FPGA開発日記

カテゴリ別記事インデックス https://msyksphinz.github.io/github_pages , English Version https://fpgadevdiary.hatenadiary.com/

RISC-V Matrix Extension Specificationについて読み進める (3. 命令定義)

T-Headが提案しているRISC-VのMatrix Extensionについて、マニュアルを読みながら理解していこうと思う。

とりあえずマニュアルで、どのようなレジスタが存在しているのかを理解していく。プログラミングモデルとサンプルコードも読み進めていきたい。

github.com


命令仕様

行列乗算命令

行列乗算命令は、ms1とms2で指定された行列レジスタから行列A(sizeM×sizeK)と行列B(sizeN×sizeK)を取り出し、mdレジスタから$A[M][K]\times(B[N][K])T$の乗算結果を行列C(sizeM×sizeN)に累積し、出力は累積レジスタを上書きされる。

  • 行列Aの形状:M行(sizeM)、K列(sizeK/要素サイズ(バイト))
  • 行列Bの形状: N行(sizeN), K列(sizeK/要素サイズ(バイト))
  • 行列Cの形状: M行(sizeM), N列(sizeN)

関数の説明:

for(int i=0; i<sizeM; i++) {
  for(int j=0; j<sizeN; j++) {
      for(int k=0; k<(sizeK/element size); k++)
         C[i,j] += A[i,k]*B[j,k];
}}}

ISA仕様では、浮動小数点および整数の行列乗算と演算をサポートするために、さまざまな命令を提供しています。ハードウェア設計は、サポートされるデータ型の柔軟性を持っている。

カテゴリ 命令 オペランド型 A,B 累積 型Type C Optional Feature
Float fmmacc.h fp16/bf16 fp16 MATRIX_MULT_F16F16
fmmacc.s fp32 fp32 MATRIX_MULT_F32F32
fmmacc.d fp64 fp64 MATRIX_MULT_F64F64
fwmmacc.h fp16/bf16 fp32 MATRIX_MULT_F16F32
fwmmacc.s fp32 fp64 MATRIX_MULT_F32F64
Int mmaqa.b mmaqu.b mmaqasu.b mmaqaus.b int8 int32 MATRIX_MULT_I8I32
mmaqa.h mmaqu.h mmaqasu.h mmaqaus.h int16 int64 MATRIX_MULT_I16I64
pmmaqa.b pmmaqu.b pmmaqasu.b pmmaqaus.b int4(mx8) int32(mxm) MATRIX_MULT_I4I32

浮動小数点行列乗算は、浮動小数点制御およびステータス・レジスタ fcsr を使用して、浮動小数点算術演算の動的丸めモードを選択し、発生した例外フラグを保持します。

浮動小数点行列乗算は、frm の動的丸めモードを使用します。frm に無効な値(101-111)が設定されている場合、それ以降に動的丸めモー ドで浮動小数点演算を実行しようとすると、不正な命令例外が発生します。

rounding mode Mnemonic Meaning
000 RNE Round to Nearest, ties to Even
001 RTZ Round towards Zero
010 RDN Round Down (towards -∞)
011 RUP Round Up (towards +∞)
100 RMM Round to Nearest, ties to Max Magnitude
101 Invalid. Reserved for future use
110 Invalid. Reserved for future use
111 Invalid in rounding mode register

浮動小数点ユニット・ステータス・フィールド mstatus.FS がオフの場合、行列浮動小数点命令 を実行しようとすると不正命令例外が発生する。浮動小数点拡張状態(浮動小数点 CSR や f レジスタなど)を変更する行列浮動小数点命令は、mstatus.FS を Dirty に設定する必要がある。浮動小数点行列乗算の基本演算は浮動小数点dotで、浮動小数点dot演算は IEEE-754/2008 標準に従う。

標準の行列浮動小数点命令は、要素を IEEE-754/2008 互換値として扱う。行列浮動小数点オペランドの EEW が、サポートされている IEEE 浮動小数点型に対応しない場合、その命令エンコーディングは予約されている。bf16-extension の場合、16 ビット浮動小数点要素は bf16 または fp16 と見なされる。

浮動小数点行列乗算(Non-widen)

Non-widen浮動小数点行列乗算は、ソースオペランドとデスティネーションオペランドのデータ幅が同じであることを示す。

  • fmmacc.h: fp16/bf16 浮動小数点、xmisa レジスタの bit[3] が 0 の場合は不正。
  • fmmacc.s: fp32 浮動小数点、xmisa レジスタの bit[4] が 0 の場合は不正
  • fmmacc.d: fp64 浮動小数点、xmisa レジスタの bit[5] が 0 の場合は不正
# float matrix multiplication, md = md + ms1*ms2
fmmacc.h md, ms2, ms1
fmmacc.s md, ms2, ms1
fmmacc.d md, ms2, ms1

fmmacc.sの場合、最大行列形状は次のようになる:

  • matrixA: M ⇐ RLEN/32, K ⇐ RLEN/16
  • matrixB: N ⇐ RLEN/16, K ⇐ RLEN/16
  • matrixC: M ⇐ RLEN/32, N ⇐ RLEN/16

行列Bのデータ幅は行列AとCの2倍であるため、行列Bではms2とms2+1で指定される2つの行列レジスタ(レジスタペア)が使用される。奇数番目のms2を指定する命令は予約されている。RLEN=128の場合の動作を以下に示す。

fmmacc.d(64ビット浮動小数点行列乗算・加算命令)の場合、最大行列形状は以下である:

  • matrixA: M ⇐ RLEN/32, K ⇐ RLEN/64
  • matrixB: N ⇐ RLEN/32, K ⇐ RLEN/64
  • matrixC: M ⇐ RLEN/32, N ⇐ RLEN/32

行列Cのデータ幅は行列AとBの2倍であるため、行列Cではmdとmd+1で指定される2つの行列レジスタ(レジスタペア)が使用される。奇数番号のmdを指定する命令は予約されている。RLEN=128の場合の動作を以下に示す。

典型的なRLENにおけるfmmacc命令の最大行列サイズのまとめ:

matrix A matrix B matrix C
RLEN M K data width N K data width M N data width Gops/GHz latency
fmacc.s 128 4 4 512 bits 4 4 512 bits 4 4 512 bits 32 4
256 8 8 2048 bits 8 8 2048 bits 8 8 2048 bits 128 8
512 16 16 8192 bits 16 16 8192 bits 16 16 8192 bits 512 16
fmacc.h 128 4 8 512 bits 8 8 1024 bits 4 8 512 bits 64 8
256 8 16 2048 bits 16 16 4096 bits 8 16 2048 bits 256 16
512 16 32 8192 bits 32 32 16384 bits 16 32 8192 bits 1024 32
fmacc.d 128 4 2 512 bits 4 2 512 bits 4 4 1024 bits 16
256 8 4 2048 bits 8 4 2048 bits 8 8 4096 bits 64
512 16 8 8192 bits 16 8 8192 bits 16 16 16384 bits 256

浮動小数点行列乗算(widen)

widen浮動小数点行列乗算は、デスティネーションオペランドのデータ幅がソースオペランドの2倍であることを示す。ソースオペランドのデータ幅は命令エンコーディング含まれている。

  • fwmmacc.h: fp16/bf16 浮動小数点ソースと fp32 結果、xmisa レジスタの bit[8] が 0 の場合は不正。
  • fwmmacc.s: fp32 浮動小数点ソースと fp64 結果 , xmisa レジスタの bit[9] が 0 の場合は不正
#float matrix multiplication, output widen, md = md + ms1*ms2
fwmmacc.h md, ms2, ms1
fwmmacc.s md, ms2, ms1

fwmmacc.h、16 ビット浮動小数点 widen 行列乗算および加算命令では、要素は fp16 または bf16 データ型がサポートされている場合は bf16 と見なすことができる。最大行列形状は:

  • matrixA: M ⇐ RLEN/32, K ⇐ RLEN/16.
  • matrixB: N ⇐ RLEN/32, K ⇐ RLEN/16
  • matrixC: M ⇐ RLEN/32, N ⇐ RLEN/32

fwmmacc.s, 32ビット浮動小数点widen行列乗算および加算命令では、最大行列形状は以下のようになる:

  • matrixA: M ⇐ RLEN/32, K ⇐ RLEN/32
  • matrixB: N ⇐ RLEN/32, K ⇐ RLEN/32
  • matrixC: M ⇐ RLEN/32, N ⇐ RLEN/32

行列Cのデータ幅は行列AとBの2倍であるため、行列Cではmdとmd+1で指定される2つの行列レジスタ(レジスタペア)が使用される。奇数番号の md を指定する命令は予約されている.典型的な RLEN における fwmmacc 命令の最大行列サイズのまとめ.

整数行列の乗算(4x widen)

出力データ幅が入力データ幅の4倍である整数行列乗算。命令エンコーディングでサポートされるソースオペランドのデータ幅は int8 と int16 で、その他のデータ幅は予約されている。符号付き/符号なし両方のバージョンが提供されている。したがって、ソース・オペランドは、両オペランド符号付き/両オペランド符号なし/符号付き-符号なし/符号なし-符号なしの両方が可能であり、乗算の結果は、加算および累積の前に符号拡張されます。オーバーフローは無視され、結果は折り返される。

  • mmaqa.b/mmaqau.b/mmaqaus.b/mmaqasu.b: int8 の 4 倍拡大行列乗算,xmisa レジスタの bit[1] が 0 の場合は不正.
  • mmaqa.h/mmaqau.h/mmaqaus.h/mmaqasu.h: int16 の 4 倍拡大行列乗算,xmisa レジスタの bit[2] が 0 の場合は不正.
#8bit data width
#signed matrix multiply
mmaqa.b md, ms2, ms1
#unsigned matrix multiply
mmaqau.b md, ms2, ms1
#unsigned-signed matrix multiply
mmaqaus.b md, ms2, ms1
#signed-unsigned matrix multiply
mmaqasu.b md, ms2, ms1

#16bit data width
#signed matrix multiply
mmaqa.h md, ms2, ms1
#unsigned matrix multiply
mmaqau.h md, ms2, ms1
#unsigned-signed matrix multiply
mmaqaus.h md, ms2, ms1
#signed-unsigned matrix multiply
mmaqasu.h md, ms2, ms1

int8 の4回行列乗算の場合、最大行列形状は以下のようになる:

  • matrixA: M ⇐ RLEN/32, K ⇐ RLEN/8
  • matrixB: N ⇐ RLEN/32, K ⇐ RLEN/8
  • matrixC: M ⇐ RLEN/32, N ⇐ RLEN/32

int16の4倍行列乗算の場合、行列Cのデータ幅は行列AとBの4倍であるため、mdとmd+1で指定される行列Cで2つの行列レジスタ(レジスタペア)が使用される。奇数番目のmdを指定する命令は予約されている:

  • matrixA: M ⇐ RLEN/32, K ⇐ RLEN/16
  • matrixB: N ⇐ RLEN/32, K ⇐ RLEN/16
  • matrixC: M ⇐ RLEN/32, N ⇐ RLEN/32

典型的な RLEN における整数行列の乗算および加算命令の最大行列サイズのまとめ:

matrix A matrix B matrix C
RLEN M K data width N K data width M N data width Gops/GHz latency
int8 4x 128 4 16 512 bits 4 16 512 bits 4 4 512 bits 128 4
256 8 32 2048 bits 8 32 2048 bits 8 8 2048 bits 512 8
512 16 64 8192 bits 16 64 8192 bits 16 16 8192 bits 2048 16
int16 4x 128 4 8 512 bits 4 8 512 bits 4 4 1024 bits 64 4
256 8 16 2048 bits 8 16 2048 bits 8 8 4096 bits 256 8
512 16 32 8192 bits 16 32 8192 bits 16 16 16384 bits 1024 16

行列ロード/ストア命令

行列ロード命令は行列をメモリから行列レジスタにロードし、行列ストア命令は行列を行列レジスタからメモリにストアする。

要素のデータ幅は、バイト/ハーフワード/ワード/ダブルワードを含む命令エンコーディングで、その他のデータ幅は予約されている。ベースアドレスはrs1、行ストライド(バイト)はrs2、md/ms3は行列ロード先、行列ストア元のレジスタインデックスである。

#matrix load
mld<b/h/w/d> md, rs2, (rs1)
#matrix store
mst<b/h/w/d>  ms3, rs2, (rs1)
#whole matrix load
mld<1/2/4/8>m md,  (rs1)
#whole matrix store
mst<1/2/4/8>m ms3, (rs1)

マトリックス形状(MxK)はマトリックスサイズ構成レジスタにあり、MはsizeM、KはsizeK(バイト単位)で与えられる。M=sizeM ⇐ RLEN/32, K=sizeK/ 要素サイズ(バイト), sizeK ⇐ RLEN/8.sizeM < RLEN/32 または sizeK < RLEN/8 の場合、行インデックス > sizeM または列インデックス > (sizeK/ バイト単位の要素サイズ) の行列レジスタデータは、ロードのためにゼロをセットし、ストアのためにメモリに書き込まない。

(1)ノーマル (2)全レジスタロード/ストア の2種類がある。

全レジスタロード/ストアは、sizeM = RLEN/32、sizeK = RLEN/8 で、最大行列サイズのデータをメモリから/メモリへロード/ストアする。マトリックス・サイズの構成は無視される。

rs2[4:3]は0にセットされ、それ以外は予約されている。rs2[2:0]はnfフィールドで、NFIELDSエンコーディングを使用してロードおよびストアするマトリクス・レジスタの数をエンコードする。

nf[2:0] レジスタ数
000 1
001 2
011 4
111 8
others reserved

すべての行列ロード/ストア命令は、ゼロ以外の行開始値を生成し、受け入れることができる。行開始レジスタは、行列命令の実行終了時にゼロにリセットさる。

ZIHINTNTL拡張により、マトリクス・メモリ・アクセス命令は、異なるメモリ階層に適合するストリーム・メモリ・アクセス命令として動作することができる。ストリーム・メモリ・アクセス命令は、通常のマトリクス・ロード/ストア命令と同じ機能を持つが、ハードウェア実装によって最適化できる可能性がある、近い将来データが再利用されない可能性がある点が異なる。