FPGA開発日記

カテゴリ別記事インデックス https://msyksphinz.github.io/github_pages , English Version https://fpgadevdiary.hatenadiary.com/

TensorFlow+Kerasに入門(3. Keras2cppでCIFAR10のモデルを変換してみる)

f:id:msyksphinz:20180701195704p:plain

FPGAの部屋のmarseeさんの記事を見て、TensorFlow+Kerasに入門してみた。 というかmarseeさんの記事で掲載されているソースコードをほとんどCopy & Pasteして実行してみているだけだが...

TensorFlow+KerasでCifar10を学習するサンプルプログラムを実行して、そこから得られたモデルを使ってKeras2cppでモデルの変換を行ってみたい。

最終的な目標は、Keras2cppを使ってC++のコードを出力し、それをネイティブC++環境で実行することだ。

まずは Kerasのサンプルコードを使用してcifar10のCNNのトレーニングを実行した。VirtualBox上のJupyter Notebook、さらにGPUを使わずに実行したので数時間はかかってしまった。 生成したモデルとトレーニングデータはファイルに保存しておくことにした。

github.com

f:id:msyksphinz:20180701193943p:plain

学習済みのモデルは、モデルファイルはJSON形式、重みパラメータはh5ファイルとして保存した。

from keras.models import load_model

model.save('cifar10_cnn_model.h5') # creates a HDF5 file 'my_model.h5'
with open('cifar10_cnn_model.json', 'w') as fout:
    fout.write(model.to_json())

保存したモデルを使って、keras2cppで動作するためのモデルに変換する。これでdumped.nnetが生成される。

$ python ./dump_to_simple_cpp.py -a ../keras_model/cifar10_cnn_model.json -w ../keras_model/cifar10_cnn_model.h5 -o dumped.nnet
Using TensorFlow backend.
Read architecture from ../keras_model/cifar10_cnn_model.json
Read weights from ../keras_model/cifar10_cnn_model.h5
Writing to dumped.nnet
2018-07-01 19:42:53.619547: I tensorflow/core/platform/cpu_feature_guard.cc:140] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2

MNISTのモデルでもそうだが、keras2cppで推論を実行するためにはテストデータを用意しなければならない。これはcifar10のデータをダウンロードして生成した。

    tr_data10, tr_labels10, te_data10, te_labels10, label_names10 = get_cifar10(datapath)

    print("3 32 32")
    print("%f. " % (te_data10[0][0] / 10.0))
    for dim in range(3):
        for y in range(32):
            print("["),
            for x in range (32):
                print("%f. " % (int(te_data10[0][dim * 32 * 32 + y * 32 + x]) / 256.0)),
            print("]")

以下のようにしてcifar10_test_data.datを生成する。

python input_cifar.py > cifar10_test_data.dat

中身は以下のようになっている。3×32×32の入力データの配列を作成した。

3 32 32
[ 0.617188.  0.621094.  0.644531.  0.648438.  0.625000.  0.609375.  0.632812.  0.621094.  0.617188.  0.621094.  0.628906.  0.625000.  0.628906.  0.648438.  0.660156.  0.664062.  0.652344.  0.632812.  0.625000.  0.625000.  0.609375.  0.582031.  0.585938.  0.578125.  0.582031.  0.558594.  0.546875.  0.550781.  0.558594.  0.535156.  0.492188.  0.453125.  ]
[ 0.593750.  0.589844.  0.621094.  0.648438.  0.632812.  0.625000.  0.640625.  0.632812.  0.636719.  0.609375.  0.605469.  0.621094.  0.636719.  0.664062.  0.667969.  0.667969.  0.660156.  0.625000.  0.601562.  0.589844.  0.566406.  0.542969.  0.546875.  0.550781.  0.582031.  0.574219.  0.566406.  0.554688.  0.558594.  0.531250.  0.488281.  0.464844.  ]
[ 0.589844.  0.589844.  0.617188.  0.652344.  0.625000.  0.636719.  0.644531.  0.644531.  0.636719.  0.632812.  0.617188.  0.613281.  0.628906.  0.648438.  0.652344.  0.660156.  0.664062.  0.621094.  0.566406.  0.472656.  0.429688.  0.382812.  0.394531.  0.445312.  0.468750.  0.523438.  0.558594.  0.546875.  0.554688.  0.542969.  0.507812.  0.468750.  ]
[ 0.605469.  0.605469.  0.625000.  0.679688.  0.652344.  0.652344.  0.660156.  0.660156.  0.644531.  0.644531.  0.652344.  0.746094.  0.691406.  0.613281.  0.632812.  0.640625.  0.617188.  0.582031.  0.406250.  0.402344.  0.382812.  0.359375.  0.312500.  0.289062.  0.335938.  0.324219.  0.441406.  0.515625.  0.546875.  0.546875.  0.531250.  0.496094.  ]
[ 0.605469.  0.609375.  0.628906.  0.664062.  0.660156.  0.636719.  0.660156.  0.648438.  0.640625.  0.640625.  0.675781.  0.960938.  0.761719.  0.589844.  0.570312.  0.554688.  0.433594.  0.304688.  0.332031.  0.441406.  0.437500.  0.414062.  0.378906.  0.363281.  0.289062.  0.328125.  0.332031.  0.410156.  0.500000.  0.539062.  0.519531.  0.503906.  ]
[ 0.578125.  0.519531.  0.507812.  0.574219.  0.628906.  0.644531.  0.652344.  0.652344.  0.636719.  0.644531.  0.636719.  0.703125.  0.613281.  0.500000.  0.378906.  0.257812.  0.269531.  0.257812.  0.347656.  0.460938.  0.476562.  0.464844.  0.445312.  0.367188.  0.386719.  0.355469.  0.226562.  0.261719.  0.421875.  0.546875.  0.539062.  0.523438.  ]
[ 0.496094.  0.425781.  0.183594.  0.343750.  0.597656.  0.664062.  0.656250.  0.664062.  0.660156.  0.648438.  0.640625.  0.574219.  0.503906.  0.496094.  0.390625.  0.265625.  0.304688.  0.281250.  0.324219.  0.515625.  0.570312.  0.484375.  0.410156.  0.417969.  0.449219.  0.332031.  0.246094.  0.179688.  0.308594.  0.515625.  0.550781.  0.523438.  ]
[ 0.511719.  0.386719.  0.164062.  0.273438.  0.558594.  0.652344.  0.644531.  0.656250.  0.667969.  0.628906.  0.546875.  0.468750.  0.507812.  0.562500.  0.453125.  0.343750.  0.355469.  0.332031.  0.300781.  0.484375.  0.636719.  0.531250.  0.398438.  0.414062.  0.390625.  0.332031.  0.210938.  0.191406.  0.222656.  0.417969.  0.539062.  0.531250.  ]
[ 0.664062.  0.402344.  0.210938.  0.484375.  0.597656.  0.628906.  0.636719.  0.648438.  0.644531.  0.679688.  0.441406.  0.488281.  0.613281.  0.609375.  0.472656.  0.335938.  0.320312.  0.328125.  0.312500.  0.316406.  0.539062.  0.570312.  0.441406.  0.339844.  0.324219.  0.335938.  0.277344.  0.218750.  0.156250.  0.289062.  0.519531.  0.535156.  ]
[ 0.703125.  0.523438.  0.367188.  0.601562.  0.679688.  0.617188.  0.609375.  0.597656.  0.808594.  0.925781.  0.808594.  0.609375.  0.679688.  0.578125.  0.488281.  0.363281.  0.335938.  0.289062.  0.230469.  0.296875.  0.535156.  0.558594.  0.519531.  0.414062.  0.335938.  0.339844.  0.328125.  0.292969.  0.195312.  0.156250.  0.371094.  0.515625.  ]
[ 0.714844.  0.421875.  0.554688.  0.644531.  0.691406.  0.605469.  0.621094.  0.476562.  0.832031.  0.925781.  0.859375.  0.640625.  0.714844.  0.609375.  0.488281.  0.468750.  0.304688.  0.312500.  0.175781.  0.355469.  0.683594.  0.613281.  0.605469.  0.417969.  0.339844.  0.402344.  0.343750.  0.304688.  0.230469.  0.160156.  0.230469.  0.406250.  ]
[ 0.734375.  0.390625.  0.527344.  0.664062.  0.730469.  0.648438.  0.675781.  0.523438.  0.457031.  0.757812.  0.777344.  0.664062.  0.722656.  0.738281.  0.523438.  0.457031.  0.398438.  0.328125.  0.148438.  0.488281.  0.820312.  0.625000.  0.570312.  0.363281.  0.324219.  0.367188.  0.406250.  0.332031.  0.285156.  0.214844.  0.242188.  0.296875.  ]
[ 0.738281.  0.351562.  0.496094.  0.683594.  0.679688.  0.648438.  0.695312.  0.621094.  0.378906.  0.656250.  0.656250.  0.535156.  0.726562.  0.843750.  0.625000.  0.480469.  0.468750.  0.449219.  0.195312.  0.585938.  0.757812.  0.605469.  0.480469.  0.355469.  0.328125.  0.328125.  0.371094.  0.335938.  0.328125.  0.285156.  0.308594.  0.285156.  ]
...

これを読み込ませてkeras2cppのサンプルコードをコンパイルし、実行する。実行するとSegmentation Faultが出てしまった。

$ g++ -g -Wall -O0 -std=c++11 keras_model.cc example_main.cc && ./a.out
This is simple example with Keras neural network model loading into C++.
Keras model will be used in C++ for prediction only.
3
Reading model from ./dumped.nnet
Layers 18
Layer 0 Conv2D
Layer is empty, maybe it is not defined? Cannot define network.
DataChunk2D 3x32x32
Segmentation fault (core dumped)

デバッグしてみると、Kerasのモデルにおいて、Convolutional2DとConv2Dの名前の違いがあるらしい?

diff --git a/keras_model.cc b/keras_model.cc
index 6c65550..1159839 100644
--- a/keras_model.cc
+++ b/keras_model.cc
@@ -420,6 +420,8 @@ void keras::KerasModel::load_weights(const string &input_fname) {
     Layer *l = 0L;
     if(layer_type == "Convolution2D") {
       l = new LayerConv2D();
+    } else if(layer_type == "Conv2D") {
+      l = new LayerConv2D();
     } else if(layer_type == "Activation") {
       l = new LayerActivation();
     } else if(layer_type == "MaxPooling2D") {

上記の修正を入れて再度実行してみる。実行結果は、、、メモリ不足?対策を考えなければ。

$ g++ -g -Wall -O0 -std=c++11 keras_model.cc example_main.cc && ./a.out
This is simple example with Keras neural network model loading into C++.
Keras model will be used in C++ for prediction only.
3
Reading model from ./dumped.nnet
Layers 18
Layer 0 Conv2D
Layer 0 Conv2D
terminate called after throwing an instance of 'std::bad_alloc'
  what():  std::bad_alloc
Aborted (core dumped)