FPGA開発日記

カテゴリ別記事インデックス https://msyksphinz.github.io/github_pages , English Version https://fpgadevdiary.hatenadiary.com/

RISC-VプロセッサHiFive1で機械学習コードを動作させる(2. ニューラルネットのパラメータのロード)

f:id:msyksphinz:20170821013230p:plain

MNISTのデータをロードするところまでできるようになった。まずは学習処理ではなく、学習結果のパラメータをロードしてデータを評価できるようにする。

学習済みデータをオブジェクトファイルに変換する

前回と同様、パラメータなどの初期値データはファイルからロードできないので、オブジェクトファイルに変換してバイナリファイルに埋め込んでしまう。

そのために、学習済みデータをx86で実行した結果から抽出して、RISC-Vのオブジェクトに変換してリンクする。

学習済みパラメータをバイナリファイルとして吐き出す。

  • training/machine_learning/bp/
  FILE *wh0_fp = fopen ("wh0.bin", "w");
  for (int x = 0; x < INPUTNO * HIDDENNO; x++) {
    fwrite (&wh0[x], sizeof(wh0[x]), 1, wh0_fp);
  }
  fclose (wh0_fp);

  FILE *wb0_fp = fopen ("wb0.bin", "w");
  for (int x = 0; x < HIDDENNO; x++) {
    fwrite (&wb0[x], sizeof(wb0[x]), 1, wb0_fp);
  }
  fclose (wb0_fp);

  FILE *wh1_fp = fopen ("wh1.bin", "w");
  for (int x = 0; x < HIDDENNO * OUTPUTNO; x++) {
    fwrite (&wh1[x], sizeof(wh1[x]), 1, wh1_fp);
  }
  fclose (wh1_fp);

  FILE *wb1_fp = fopen ("wb1.bin", "w");
  for (int x = 0; x < OUTPUTNO; x++) {
    fwrite (&wb1[x], sizeof(wb1[x]), 1, wb1_fp);
  }
  fclose (wb1_fp);
  wb1_fp = fopen ("wb1.txt", "w");
  for (int x = 0; x < OUTPUTNO; x++) {
    fprintf (wb1_fp, "%04x ", wb1[x]);
  }
  fclose (wb1_fp);

これで、学習済みパラメータを格納したwb0,bin, wb1.bin, wh0.bin, wh1.bin が出力される。一応、hexdump しておこう。

$ hexdump wb0.bin
0000000 01fb 0000 0635 0000 0325 0000 1289 0000
0000010 18b5 0000 f66c ffff 0153 0000 181b 0000
0000020 14fa 0000 f99b ffff 038d 0000 0436 0000
0000030 008e 0000 0ad2 0000 016e 0000 fbe5 ffff
0000040 0c11 0000 0949 0000 f787 ffff f85e ffff
0000050 ed7b ffff ff8a ffff 06e1 0000 1048 0000
0000060 fc9a ffff 1ed6 0000 023f 0000 06a2 0000
0000070 f78f ffff 05ca 0000 0cd1 0000 0a1e 0000
0000080 04ce 0000 0c55 0000 fefd ffff 098a 0000
0000090 1443 0000 00fa 0000 0b5b 0000 1454 0000
00000a0 fa8b ffff 0732 0000 0cf3 0000 f7da ffff
00000b0 0981 0000 0986 0000 0d71 0000 ec68 ffff
00000c0 08f1 0000 0344 0000
00000c8

バイナリデータをRISC-Vオブジェクトに変換し、RISC-Vバイナリにロードする

次に生成したバイナリデータをRISC-Vのオブジェクトファイルに変換し、実行ファイルにリンクする。

これはMakefileに記述しており、以下のようになる。

LDFLAGS += wb0_init.o
LDFLAGS += wb1_init.o
LDFLAGS += wh0_init.o
LDFLAGS += wh1_init.o

LINK_DEPS += wb0_init.o
LINK_DEPS += wb1_init.o
LINK_DEPS += wh0_init.o
LINK_DEPS += wh1_init.o

%_init.o: %.bin
        riscv64-unknown-elf-objcopy -I binary -O elf32-littleriscv -B riscv --rename-section .data=.rodata $< $@

wb0.binからwb0_init.oを生成していく。objcopyでRISC-Vのオブジェクトに変換んし、mnistのバイナリ生成時にリンクさせる。

読み込んだオブジェクトファイルをHiFive1ボードからダンプするプログラムを書く

実際に初期値をロードできたかどうか、ダンププログラムを書いて確認しておく。

extern char _binary_wb0_bin_start[];
extern char _binary_wb0_bin_end[];
extern char _binary_wb1_bin_start[];
extern char _binary_wb1_bin_end[];
extern char _binary_wh0_bin_start[];
extern char _binary_wh0_bin_end[];
extern char _binary_wh1_bin_start[];
extern char _binary_wh1_bin_end[];

...

  const fix16_t *wh0 = (fix16_t *)_binary_wh0_bin_start;  // [INPUTNO * HIDDENNO];
  const fix16_t *wb0 = (fix16_t *)_binary_wb0_bin_start;  // [HIDDENNO];
  const fix16_t *wh1 = (fix16_t *)_binary_wh1_bin_start;  // [HIDDENNO * OUTPUTNO];
  const fix16_t *wb1 = (fix16_t *)_binary_wb1_bin_start;  // [OUTPUTNO];

  for (i = 0; i < HIDDENNO * INPUTNO; i++) {
    fix16_t hex_value = wh0[i];
    for (int j = 7; j >=0; j--) {
      write (STDOUT_FILENO, hex_enum[(hex_value >> (j * 4)) & 0x0f], 2);
    }
    if ((i % HIDDENNO) == (HIDDENNO-1)) { write (STDOUT_FILENO, "\r\n", 2); }
  }

  for (i = 0; i < HIDDENNO; i++) {
    fix16_t hex_value = wb0[i];
    for (int j = 7; j >=0; j--) {
      write (STDOUT_FILENO, hex_enum[(hex_value >> (j * 4)) & 0x0f], 2);
    }
    if ((i % HIDDENNO) == (HIDDENNO-1)) { write (STDOUT_FILENO, "\r\n", 2); }
  }

  for (i = 0; i < HIDDENNO * OUTPUTNO; i++) {
    fix16_t hex_value = wh1[i];
    for (int j = 7; j >=0; j--) {
      write (STDOUT_FILENO, hex_enum[(hex_value >> (j * 4)) & 0x0f], 2);
    }
    if ((i % HIDDENNO) == (HIDDENNO-1)) { write (STDOUT_FILENO, "\r\n", 2); }
  }

  for (i = 0; i < OUTPUTNO; i++) {
    fix16_t hex_value = wb1[i];
    for (int j = 7; j >=0; j--) {
      write (STDOUT_FILENO, hex_enum[(hex_value >> (j * 4)) & 0x0f], 2);
    }
    if ((i % HIDDENNO) == (HIDDENNO-1)) { write (STDOUT_FILENO, "\r\n", 2); }
  }

ロードしたオブジェクトファイルは、それぞれ_binary_wh0_bin_start, _binary_wb1_bin_start, _binary_wb1_bin_start, _binary_wh1_bin_start として参照できる。これらをダンプしてUARTに出力している。

f:id:msyksphinz:20170822011522p:plain

これだけでは何が何だかさっぱりだが、データ的には想定したものがロードできた。

msyksphinz.hatenablog.com