FPGA開発日記

カテゴリ別記事インデックス https://msyksphinz.github.io/github_pages , English Version https://fpgadevdiary.hatenadiary.com/

RISC-V Matrix Extension Specificationについて読み進める (4. 行列乗算のサンプルを見てみる)

T-Headが提案しているRISC-VのMatrix Extensionについて、マニュアルを読みながら理解していこうと思う。

とりあえずマニュアルで、どのようなレジスタが存在しているのかを理解していく。プログラミングモデルとサンプルコードも読み進めていきたい。

github.com


github.com

int main()
{
  printf("===== demo: matmul-intrinsic =====\n");
  /* init data */
  int32_t x[N] = {16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1};
  int32_t y[N] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
  int32_t z[N] = {0};

  uint8_t msize_m = 2;
  uint8_t msize_n = 2;
  uint16_t msize_k = 8; // sizeof(int32_t) * 2;
  long stride = 8;      // sizeof(int32_t) * 2;

  /* Configuration matrix size */
  mcfgm(msize_m);
  mcfgn(msize_n);
  mcfgk(msize_k);

  /* init matrix value*/
  mint32_t ma = mld_i32(x, stride);
  mint32_t mb = mld_i32(y, stride);
  mint32_t ans = mld_i32(z, stride);

  print_data("Initial value of matrix", ma, mb, ans);

  ans = mmul_mi32(ma, mb);
  print_data("Results of multiplication", ma, mb, ans);

  return 0;
}

上記のコードは、N=16であり、実行されるのはmmul命令なので要素毎の乗算となればよい。m=2で行のサイズ、n=2なのでおそらく2x2の行列を処理することになる。

まず、下記のコードで、2つの行列の大きさを設定する。

  mcfgm(msize_m);
  mcfgn(msize_n);
  mcfgk(msize_k);

次にデータのロードを行うのが以下のコードだ。

  /* init matrix value*/
  mint32_t ma = mld_i32(x, stride);
  mint32_t mb = mld_i32(y, stride);
  mint32_t ans = mld_i32(z, stride);

stride = 8が設定されているのだが、これはどのように考えればいいんだ?

ちょっとmld命令の意味が理解できないので、qemuの実装をあたってみる。

 static void mmext_mld(void *md, target_ulong rs1, target_ulong s2,
                       mmext_ld_fn *ld_elem, mmext_set_elem *set_elem,
                       CPURISCVState *env, uint8_t esz, uintptr_t ra,
                       bool streaming){
     uint32_t i, k;
     target_ulong addr;
/* ... 途中省略 ... */
     for (i = 0; i < get_mrows(env); i++) {
         for (k = 0; k < (get_mlenb(env) >> esz); k++) {
             addr = rs1 + i * s2 + k * (1 << esz);
             if (i < env->sizem && k < (env->sizek >> esz)) {
                 set_elem(md, i, k, env, ld_elem(env, addr, ra));
             } else {
                 set_elem(md, i, k, env, 0);
             }
         }
     }
/* ... 途中省略 ... */

s2というのは、1行の長さを示すストライドを示しているのか?上記であれば、stride = 8なので、i32で2要素が1行を構成するという理解でいいのか?

で、同じ形でmaとmbをロードするということになる。 最終的には、2x2の行列どうしでmmul命令によって乗算を行い、その結果をansに格納することになる。

... ちゃんとQEMUで動作を確認したほうがよさそうだな。