FPGA開発日記

FPGAというより、コンピュータアーキテクチャかもね! カテゴリ別記事インデックス https://msyksphinz.github.io

ゼロから学ぶ畳み込みニューラルネットワーク 調査中

RISC-V で MNIST を実行できるようになったので、次はCNNを実行してみたい。

多くのCNNのコードはPythonで記述してあるのだが、もう少しバイナリに近い言語で書いてあったほうが解析とRISC-Vの移植がやりやすい。

師匠のブログを読みながら、少しずつ進めてみることにする。

まずはディープラーニングC++実装ということで以下を調査してみた。

C++で学ぶディープラーニング

C++で学ぶディープラーニング

が、これはどうもCUDAが用意されていることが前提で、Virtual Box上でUbuntuを実行している環境ではCUDAを実行することができない。 一応買ってみたものの、あまり使えないなあ。

C/C++フレームワークということで色々調べているのだが、とりあえずx86で動かすならこんなのがお手軽かなあ。

github.com

ゼロから学ぶディープラーニング

第7章あたりを読み直している。

f:id:msyksphinz:20180113222247p:plain

tinyDNN を使って、MNIST を実行する

github.com

MNISTを学習するためのネットワークは、以下のように記述されている。

  • tiny-dnn/examples/mnist/train.cpp
  using fc = tiny_dnn::layers::fc;
  using conv = tiny_dnn::layers::conv;
  using ave_pool = tiny_dnn::layers::ave_pool;
  using tanh = tiny_dnn::activation::tanh;

  using tiny_dnn::core::connection_table;
  using padding = tiny_dnn::padding;

  nn << conv(32, 32, 5, 1, 6,   // C1, 1@32x32-in, 6@28x28-out
             padding::valid, true, 1, 1, backend_type)
     << tanh()
     << ave_pool(28, 28, 6, 2)   // S2, 6@28x28-in, 6@14x14-out
     << tanh()
     << conv(14, 14, 5, 6, 16,   // C3, 6@14x14-in, 16@10x10-out
             connection_table(tbl, 6, 16),
             padding::valid, true, 1, 1, backend_type)
     << tanh()
     << ave_pool(10, 10, 16, 2)  // S4, 16@10x10-in, 16@5x5-out
     << tanh()
     << conv(5, 5, 5, 16, 120,   // C5, 16@5x5-in, 120@1x1-out
             padding::valid, true, 1, 1, backend_type)
     << tanh()
     << fc(120, 10, true, backend_type)  // F6, 120-in, 10-out
     << tanh();

Convolutional Network が構成されている(よね?)ので、これもCNNということができるかな。 プログラムを実行するとLeNetというファイルが生成されたので、おそらくCNNだろう。

とりあえずMNISTのデータセットをダウンロードしてやってみる。 なぜかデータセットのファイル名が一致しなくてファイル名を修正しなければならなかった。

cd mnist/images
wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
wget http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
wget http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
wget http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
mv t10k-images-idx3-ubyte t10k-images.idx3-ubyte
mv t10k-labels-idx1-ubyte t10k-labels.idx1-ubyte
mv train-images-idx3-ubyte train-images.idx3-ubyte
mv train-labels-idx1-ubyte train-labels.idx1-ubyte
cd -

猛烈に時間がかかるので、epoch回数を減らして、トレーニング数を減らして実行してみる。

$ ./example_mnist_train --data_path mnist/images --learning_rate 1 --epochs 3 --minibatch_size 23 --backend_type internal
Running with the following parameters:
Data path: mnist/images
Learning rate: 1
Minibatch size: 23
Number of epochs: 3
Backend type: Internal

load models...
start training

0%   10   20   30   40   50   60   70   80   90   100%
|----|----|----|----|----|----|----|----|----|----|
***************************************************Epoch 1/3 finished. 40.9915s elapsed.
9442/10000

0%   10   20   30   40   50   60   70   80   90   100%
|----|----|----|----|----|----|----|----|----|----|
***************************************************Epoch 2/3 finished. 42.4635s elapsed.
9593/10000

0%   10   20   30   40   50   60   70   80   90   100%
|----|----|----|----|----|----|----|----|----|----|
***************************************************Epoch 3/3 finished. 116.432s elapsed.
9645/10000

0%   10   20   30   40   50   60   70   80   90   100%
|----|----|----|----|----|----|----|----|----|----|
end training.
accuracy:96.45% (9645/10000)
    *     0     1     2     3     4     5     6     7     8     9
    0   969     0     6     1     1     5     8     2     7     4
    1     0  1118     0     0     0     2     4     5     0     8
    2     0     2   986     5     3     0     0    21     1     0
    3     0     4    11   975     0    15     0     4     5     9
    4     0     0     2     1   946     0     1     3     4     9
    5     1     1     0     8     0   848     6     0     4    10
    6     3     2     2     1     9     9   935     0     4     1
    7     1     0     8     7     0     2     0   972     5     4
    8     4     8    16     8     3     6     3     2   941     9
    9     2     0     1     4    20     5     1    19     3   955

example_mnist_test を実行する

早速トレーニングデータを使って実験してみよう。以下の画像データをダウンロードして実行してみる。

f:id:msyksphinz:20180113221956p:plain

$ wget https://github.com/tiny-dnn/tiny-dnn/wiki/4.bmp
$ ./example_mnist_test mnist/4.bmp
4,86.7497
7,84.1264
9,72.9942

CIFAR-10 のトレーニングと実行

同様に、 CIFAR-10 を使ってトレーニングと推論を実行してみる。

$ cd cifar10/images/
$ wget https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz
$ tar xvfz cifar-10-binary.tar.gz
cifar-10-batches-bin/
cifar-10-batches-bin/data_batch_1.bin
cifar-10-batches-bin/batches.meta.txt
cifar-10-batches-bin/data_batch_3.bin
cifar-10-batches-bin/data_batch_4.bin
cifar-10-batches-bin/test_batch.bin
cifar-10-batches-bin/readme.html
cifar-10-batches-bin/data_batch_5.bin
cifar-10-batches-bin/data_batch_2.bin

トレーニングを実行する。非常に時間がかかる。

$ ./example_cifar_train --data_path cifar10/cifar-10-batches-bin --learning_rate 0.01 --epochs 3 --minibatch_size 20 --backend_type internal
Running with the following parameters:
Data path: cifar10/cifar-10-batches-bin
Learning rate: 0.01
Minibatch size: 20
Number of epochs: 3
Backend type: Internal

load models...
start learning

0%   10   20   30   40   50   60   70   80   90   100%
|----|----|----|----|----|----|----|----|----|----|
***************************************************
Epoch 1/3 finished. 473.216s elapsed.
4794/10000

0%   10   20   30   40   50   60   70   80   90   100%
|----|----|----|----|----|----|----|----|----|----|
***************************************************
Epoch 2/3 finished. 454.429s elapsed.
5186/10000

0%   10   20   30   40   50   60   70   80   90   100%
|----|----|----|----|----|----|----|----|----|----|
***************************************************
Epoch 3/3 finished. 469.389s elapsed.
5433/10000

0%   10   20   30   40   50   60   70   80   90   100%
|----|----|----|----|----|----|----|----|----|----|
end training.
accuracy:54.33% (5433/10000)
    *     0     1     2     3     4     5     6     7     8     9
    0   535    15    68    23    39     7     5     9    70    24
    1    66   722    26    23    21     8    20    14   101   190
    2    46     3   301    42    96    47    34    15    13     7
    3    52    25   164   509   120   265   131   101    40    37
    4    13     4    86    28   276    32    36    33     5     3
    5     9     3   106   136    78   436    25    65     5     8
    6    24    17   129   102   209    60   701    45     9    34
    7    31    19    71    88   128   118    20   665    16    34
    8   149    43    25    12    19    11     8     9   679    54
    9    75   149    24    37    14    16    20    44    62   609

以下の画像を入力して実行してみる。画像の抽出は以下のサイトを参考にさせてもらった。

f:id:msyksphinz:20180113230026p:plain ← なんじゃこりゃ?

xiaoxia.exblog.jp

$ ./example_cifar_test cifar10/cifar-10-batches-py/data_batch_1.0_leptodactylus_pentadactylus_s_000004.png.bmp
3,81.7409
2,60.2516
8,55.0314

でも、これ合ってないなあ。。。もうちょっと学習の回数を増やそう。