FPGA開発日記

FPGAというより、コンピュータアーキテクチャかもね! カテゴリ別記事インデックス https://sites.google.com/site/fpgadevelopindex/

ニューラルネットワーク C/C++の実装 (4. Python環境からパラメータを引っ張ってくる)

ニューラルネットワークを1からC/C++で記述してMNISTを動作させてみる話、実装のどこが間違っているのか分からないので、Pythonの実装で重み部分を抽出し、パラメータをそのまま当てはめて検算していく。

前回Pythonからパラメータを抽出するプログラムは抽出部分を間違えていた。これじゃgrad(差分)の部分を抽出してしまっている。正しいのはコチラ。

w1_fp = open('w1.h', 'w')
for y_idx in range(network.layers['Affine1'].W.shape[0]):
    for x_idx in range (network.layers['Affine1'].W.shape[1]):
        w1_fp.write ("{0:.10f} ".format(network.layers['Affine1'].W[y_idx][x_idx]))
    w1_fp.write ("\n")
w1_fp.close()

b1_fp = open('b1.h', 'w')
for x_idx in range (network.layers['Affine1'].b.shape[0]):
    b1_fp.write ("{0:.10f} ".format(network.layers['Affine1'].b[x_idx]))
b1_fp.close()

w2_fp = open('w2.h', 'w')
for y_idx in range(network.layers['Affine2'].W.shape[0]):
    for x_idx in range (network.layers['Affine2'].W.shape[1]):
        w2_fp.write ("{0:.10f} ".format(network.layers['Affine2'].W[y_idx][x_idx]))
    w2_fp.write ("\n")
w2_fp.close()

b2_fp = open('b2.h', 'w')
for x_idx in range (grad['b2'].shape[0]):
    b2_fp.write ("{0:.10f} ".format(network.layers['Affine2'].b[x_idx]))
b2_fp.close()

さらに重みの配列の形も間違えていたのでこちらも修正して実行すると、無事にMNISTが認識して、97%の精度で正解が出るようになった!

./twolayernet
...
t = 7, ans_data=7
9990    0.970871
t = 8, ans_data=8
9991    0.970874
t = 9, ans_data=9
9992    0.970877
t = 0, ans_data=0
9993    0.970880
t = 1, ans_data=1
9994    0.970883
t = 2, ans_data=2
9995    0.970885
t = 3, ans_data=3
9996    0.970888
t = 4, ans_data=4
9997    0.970891
t = 5, ans_data=5
9998    0.970894
t = 6, ans_data=6
9999    0.970897

デバッグをしている最中に、やっとニューラルネットの配列の形について理解できるようになった。

今回のネットワークは、入力データ 28\times 28 に対して中間層が50個のニューロンで構成されている。さらに出力は10個のラベルに区分けされる。

ネットワークには10000個のデータを同時に流すようになっており、バッチサイズは10000個となっている。したがって、

  • 入力データ:  (10000, 28\times 28)
  • 中間層Affine1重み :  (784, 50)
  • 中間層Affine2重み :  (50, 10)

となる。図にすると以下のようになる。今回のC/C++実装はバッチサイズ1のため、10000のところが1になったと考えればよい。

f:id:msyksphinz:20170628012908p:plain

バッチサイズの拡張と、学習部分の実装について進めていこう。