FPGA開発日記

カテゴリ別記事インデックス https://msyksphinz.github.io/github_pages , English Version https://fpgadevdiary.hatenadiary.com/

「ゼロから作るDeep Learning」第7章のCNNでCIFAR-10に挑戦してみる (2. C言語でのCNN実行環境を実装する)

「ゼロから作るDeep Learning」の第7章、CNNを勉強したので、PythonではなくてC言語で1から実装してみたい。

何故C言語かというと、RISC-Vで動作させたいし、最終的にはRoCCを使ってアクセラレーションに挑戦してみたい。そのためには、C言語で実装するのが一番わかりやすいのだ。

CIFAR-10 の画像データのダウンロードとダンプツールの開発

CIFAR-10をダウンロードし、解凍してオブジェクトを生成する。念のためCIFAR-10のデータをダンプするツールを開発しておく。

github.com

CIFAR10のデータ構造は、画像データ毎に、

  • 最初の1バイト: 画像のラベル
  • 次の32x32バイト : 画像のR要素
  • 次の32x32バイト : 画像のG要素
  • 次の32x32バイト : 画像のB要素

なので、以下のように実装してR要素だけダンプするようなツールを作って、CIFARの画像を確認できるようにした。

  • cnn_cidar10_x86/cifar10_dump.c
  uint8_t *cifar10_data = _binary_cifar_10_batches_bin_data_batch_1_bin_start;
  while (cifar10_data != _binary_cifar_10_batches_bin_data_batch_1_bin_end) {
    uint8_t cifar10_label = cifar10_data[0];
    cifar10_data++;
    fprintf (dump_fp, "LABEL = %d\n", cifar10_label);

    for (int y = 0; y < 32; y++) {
      for(int x = 0; x < 32; x++) {
        fprintf (dump_fp, "%02x ", cifar10_data[y * 32 + x]);
      }
      fprintf (dump_fp, "\n");
    }
    cifar10_data += 32 * 32 * 3;
  }

CNNのネットワークの学習済みの重みをダンプする

とりあえず、CNN内部のネットワークの重みは他次元配列になっているのだが、これをすべてダンプしてC言語に変換したい。 「ゼロから作るDeep Learning」では、PKL形式で保存されているので、これを読み込んでダンプするプログラムを作っておいた。 重みによって次元が異なるので、再帰を使ってどのような次元の配列でもダンプできるようにしておく。

github.com

  • ch07/dump_params.py
def recurse_dump(array, dim=0):
    if array.ndim == 1:
        for tab in range(dim):
            print(" ", end="")
        print("{", end="")
        for i in range(len(array)):
            print("%10.20f " % array[i], end="")
            if i != len(array)-1:
                print(",", end="")
        for tab in range(dim):
            print(" ", end="")
        print("}", end="")
    else:
        for tab in range(dim):
            print(" ", end="")
        print("{")
        for i in range(len(array)):
            array_elem = array[i]
            recurse_dump(array_elem, dim+1)
            if i != len(array)-1:
                print(",")
            else:
                print("")
        for tab in range(dim):
            print(" ", end="")
        print(" }", end="")

print("const double conv_w0[%d][%d][%d][%d] = " % (network.layers['Conv1'].W.shape[0],
                                                   network.layers['Conv1'].W.shape[1],
                                                   network.layers['Conv1'].W.shape[2],
                                                   network.layers['Conv1'].W.shape[3]), end="")
recurse_dump(network.layers['Conv1'].W)
print(";")
print("const double conv_b0[%d] = " % (network.layers['Affine1'].b.shape[0]), end="")
recurse_dump(network.layers['Affine1'].b)
print(";")

次のように出力される。

const double conv_w0[30][3][5][5] = {
 {
  {
   {-0.05391399321562118790 ,-0.08282846932293289055 ,-0.09093539095966461649 ,-0.05669240093168621819 ,0.02650859637447007186    },
   {-0.04668663031582918899 ,-0.04442664828139031991 ,-0.02699593254157254402 ,-0.06292007092671214608 ,-0.04121446537614883304    },
   {-0.08444171389035450004 ,-0.05161831927383831592 ,-0.04931356823456299610 ,-0.01802056561936658807 ,-0.05474163819510180495    },
   {-0.09596361288506885523 ,-0.05460559915561497002 ,-0.07528187433209208856 ,-0.07109413109025114474 ,-0.06304327973495237047    },
...

   {-0.00523652995230692274 ,0.01657456249978067273 ,0.00229446783835494375 ,-0.13799807309192410609 ,-0.08449199188719012932    },
   {-0.09448965052534342990 ,-0.12164972247881274126 ,-0.11875352537119687746 ,-0.03785788148686732024 ,0.04228460261730874331    },
   {-0.12138666268525750980 ,-0.02712271471645598397 ,0.01609190848564413107 ,-0.00891810275382830774 ,-0.05523533875481138194    }
   }
  }
 };