FPGA開発日記

カテゴリ別記事インデックス https://msyksphinz.github.io/github_pages , English Version https://fpgadevdiary.hatenadiary.com/

Chiselを使ってMNISTハードウェアアクセラレータを実装(実装中)

前回までで、MNISTをハードウェアアクセラレータをどのように実装すれば良いか、RISC-Vの命令セットシミュレータ(ISS)を使って動作をシミュレーションし、プログラムを構築した。

次に、実際にハードウェアを作成する。コーディングにはChiselを使う。 なぜChiselを使うかというと、Rocket-Chipとの相性が良いからだ。 チートシートもあるので、これを参照しながら進めていく。

関連記事

int64_t を使わずにfix16_mulを実装する方法

libfixmathには2種類の実装があり、内部でint64_tを使って固定小数点を実装する方法と、そうでない方法がある。 (ってかそもそもfloat16_tのはずなのに桁が32ビットあるライブラリの時点でおかしいのだが...別のライブラリに切り替えを検討している)

実装されているfix16のルーチンで、fix16_mulは以下の2種類が存在する。

  • 実装その1 (int64_tを使う)
fix16_t fix16_mul(fix16_t inArg0, fix16_t inArg1)
{
    int64_t product = (int64_t)inArg0 * inArg1;
    ...
  • 実装その2 (int32_tを使う)
fix16_t fix16_mul(fix16_t inArg0, fix16_t inArg1)
{
    // Each argument is divided to 16-bit parts.
    //                 AB
    //         *    CD
    // -----------
    //                 BD  16 * 16 -> 32 bit products
    //              CB
    //              AD
    //             AC
    //          |----| 64 bit product
    int32_t A = (inArg0 >> 16), C = (inArg1 >> 16);
    uint32_t B = (inArg0 & 0xFFFF), D = (inArg1 & 0xFFFF);
    
    int32_t AC = A*C;
    int32_t AD_CB = A*D + C*B;
    uint32_t BD = B*D;
    
    int32_t product_hi = AC + (AD_CB >> 16);
    ...

下側の実装のほうがハードウェア化する際には筋が良さそうなので、まずはこちらを実装していく。

まずはソフトウェアのライブラリを下側の実装に切り替え、MNISTとしての動作に問題がないかを確認する。以下のようにちゃんと動作したので、問題ないだろう。

spike --extension=matrix16_rocc ./train_twolayernet_fix16_full
=== TestNetwork ===
=== TestNetwork ===
 === start ===
Final Result : Correct = 185 / 200
Time = 390255174 - 15031 = 390240143

前回の実装よりも、命令数は伸びている。64bitの命令を使わずに32ビットで分割しているのだから、当たり前である。

この実装をハードウェア化するにあたり、うまくパイプラインを切れれば良いが、まずは全部ワイヤとして実装し、後でパイプラインレジスタを切っていけば良いだろう。

速攻で実装したし、まだ検証していないままで今日は時間切れだが、だいたい以下のようになった。

github.com

(※ 前後は省略)

  // int32_t  A = (a_val >> 16),    C = (b_val >> 16);
  // uint32_t B = (a_val & 0xFFFF), D = (b_val & 0xFFFF);
  w_a_hi := Cat(Fill(16, w_a_val(31)), w_a_val(31,16))
  w_b_hi := Cat(Fill(16, w_b_val(31)), w_b_val(31,16))
  w_a_lo := Cat(UInt(0, 16), w_a_val(15, 0))
  w_b_lo := Cat(UInt(0, 16), w_b_val(15, 0))

  val w_ah_bh       = Wire(SInt(width=32))
  val w_ah_bl_al_bh = Wire(SInt(width=32))
  val w_al_bl       = Wire(UInt(width=32))

  // int32_t  AC    = A*C;
  // int32_t  AD_CB = A*D + C*B;
  // uint32_t BD    = B*D;
  w_ah_bh       := w_a_hi * w_b_hi
  w_ah_bl_al_bh := w_a_hi * w_b_lo + w_a_lo * w_b_hi
  w_al_bl       := w_a_lo * w_b_lo

  val product_hi = Wire(SInt(width=32))
  product_hi := w_al_bl + w_ah_bl_al_bh(31,16)
  val product_hi2 = Wire(SInt(width=32))

  val product_lo = Wire(UInt(width=32))
  product_lo := w_al_bl + Cat(w_ah_bl_al_bh, UInt(0,width=16))

  when (product_lo < w_al_bl) {
    product_hi2 := product_hi + SInt(1)
  } .otherwise {
    product_hi2 := product_hi
  }

  val product_hi3 = Wire(SInt(width=32))

  when (product_lo - UInt(0x8000) - product_hi(31) > product_lo) {
    product_hi3 := product_hi2 - SInt(1)
  } .otherwise {
    product_hi3 := product_hi2
  }

  w_result := Cat(product_hi3(15, 0), product_lo(31,16)).asSInt() + SInt(1)

今度はシミュレーションして検証だ。パイプラインレジスタを切っていき、最後にFPGAインプリメントしよう。

f:id:msyksphinz:20180118000750p:plain