T-Headが提案しているRISC-VのMatrix Extensionについて、マニュアルを読みながら理解していこうと思う。
とりあえずマニュアルで、どのようなレジスタが存在しているのかを理解していく。プログラミングモデルとサンプルコードも読み進めていきたい。
命令仕様
行列乗算命令
行列乗算命令は、ms1とms2で指定された行列レジスタから行列A(sizeM×sizeK)と行列B(sizeN×sizeK)を取り出し、mdレジスタから$A[M][K]\times(B[N][K])T$の乗算結果を行列C(sizeM×sizeN)に累積し、出力は累積レジスタを上書きされる。
- 行列Aの形状:M行(sizeM)、K列(sizeK/要素サイズ(バイト))
- 行列Bの形状: N行(sizeN), K列(sizeK/要素サイズ(バイト))
- 行列Cの形状: M行(sizeM), N列(sizeN)
関数の説明:
for(int i=0; i<sizeM; i++) { for(int j=0; j<sizeN; j++) { for(int k=0; k<(sizeK/element size); k++) C[i,j] += A[i,k]*B[j,k]; }}}
ISA仕様では、浮動小数点および整数の行列乗算と演算をサポートするために、さまざまな命令を提供しています。ハードウェア設計は、サポートされるデータ型の柔軟性を持っている。
カテゴリ | 命令 | オペランド型 A,B | 累積 型Type C | Optional Feature |
---|---|---|---|---|
Float | fmmacc.h | fp16/bf16 | fp16 | MATRIX_MULT_F16F16 |
fmmacc.s | fp32 | fp32 | MATRIX_MULT_F32F32 | |
fmmacc.d | fp64 | fp64 | MATRIX_MULT_F64F64 | |
fwmmacc.h | fp16/bf16 | fp32 | MATRIX_MULT_F16F32 | |
fwmmacc.s | fp32 | fp64 | MATRIX_MULT_F32F64 | |
Int | mmaqa.b mmaqu.b mmaqasu.b mmaqaus.b | int8 | int32 | MATRIX_MULT_I8I32 |
mmaqa.h mmaqu.h mmaqasu.h mmaqaus.h | int16 | int64 | MATRIX_MULT_I16I64 | |
pmmaqa.b pmmaqu.b pmmaqasu.b pmmaqaus.b | int4(mx8) | int32(mxm) | MATRIX_MULT_I4I32 |
浮動小数点行列乗算は、浮動小数点制御およびステータス・レジスタ fcsr を使用して、浮動小数点算術演算の動的丸めモードを選択し、発生した例外フラグを保持します。
浮動小数点行列乗算は、frm の動的丸めモードを使用します。frm に無効な値(101-111)が設定されている場合、それ以降に動的丸めモー ドで浮動小数点演算を実行しようとすると、不正な命令例外が発生します。
rounding mode | Mnemonic | Meaning |
---|---|---|
000 | RNE | Round to Nearest, ties to Even |
001 | RTZ | Round towards Zero |
010 | RDN | Round Down (towards -∞) |
011 | RUP | Round Up (towards +∞) |
100 | RMM | Round to Nearest, ties to Max Magnitude |
101 | Invalid. Reserved for future use | |
110 | Invalid. Reserved for future use | |
111 | Invalid in rounding mode register |
浮動小数点ユニット・ステータス・フィールド mstatus.FS
がオフの場合、行列浮動小数点命令 を実行しようとすると不正命令例外が発生する。浮動小数点拡張状態(浮動小数点 CSR や f レジスタなど)を変更する行列浮動小数点命令は、mstatus.FS
を Dirty に設定する必要がある。浮動小数点行列乗算の基本演算は浮動小数点dotで、浮動小数点dot演算は IEEE-754/2008 標準に従う。
標準の行列浮動小数点命令は、要素を IEEE-754/2008 互換値として扱う。行列浮動小数点オペランドの EEW が、サポートされている IEEE 浮動小数点型に対応しない場合、その命令エンコーディングは予約されている。bf16-extension の場合、16 ビット浮動小数点要素は bf16 または fp16 と見なされる。
浮動小数点行列乗算(Non-widen)
Non-widen浮動小数点行列乗算は、ソースオペランドとデスティネーションオペランドのデータ幅が同じであることを示す。
fmmacc.h
: fp16/bf16 浮動小数点、xmisa レジスタの bit[3] が 0 の場合は不正。fmmacc.s
: fp32 浮動小数点、xmisa レジスタの bit[4] が 0 の場合は不正fmmacc.d
: fp64 浮動小数点、xmisa レジスタの bit[5] が 0 の場合は不正
# float matrix multiplication, md = md + ms1*ms2 fmmacc.h md, ms2, ms1 fmmacc.s md, ms2, ms1 fmmacc.d md, ms2, ms1
fmmacc.sの場合、最大行列形状は次のようになる:
- matrixA: M ⇐ RLEN/32, K ⇐ RLEN/16
- matrixB: N ⇐ RLEN/16, K ⇐ RLEN/16
- matrixC: M ⇐ RLEN/32, N ⇐ RLEN/16
行列Bのデータ幅は行列AとCの2倍であるため、行列Bではms2とms2+1で指定される2つの行列レジスタ(レジスタペア)が使用される。奇数番目のms2を指定する命令は予約されている。RLEN=128の場合の動作を以下に示す。
fmmacc.d(64ビット浮動小数点行列乗算・加算命令)の場合、最大行列形状は以下である:
- matrixA: M ⇐ RLEN/32, K ⇐ RLEN/64
- matrixB: N ⇐ RLEN/32, K ⇐ RLEN/64
- matrixC: M ⇐ RLEN/32, N ⇐ RLEN/32
行列Cのデータ幅は行列AとBの2倍であるため、行列Cではmdとmd+1で指定される2つの行列レジスタ(レジスタペア)が使用される。奇数番号のmdを指定する命令は予約されている。RLEN=128の場合の動作を以下に示す。
典型的なRLENにおけるfmmacc命令の最大行列サイズのまとめ:
matrix A | matrix B | matrix C | ||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|
RLEN | M | K | data width | N | K | data width | M | N | data width | Gops/GHz | latency | |
fmacc.s | 128 | 4 | 4 | 512 bits | 4 | 4 | 512 bits | 4 | 4 | 512 bits | 32 | 4 |
256 | 8 | 8 | 2048 bits | 8 | 8 | 2048 bits | 8 | 8 | 2048 bits | 128 | 8 | |
512 | 16 | 16 | 8192 bits | 16 | 16 | 8192 bits | 16 | 16 | 8192 bits | 512 | 16 | |
fmacc.h | 128 | 4 | 8 | 512 bits | 8 | 8 | 1024 bits | 4 | 8 | 512 bits | 64 | 8 |
256 | 8 | 16 | 2048 bits | 16 | 16 | 4096 bits | 8 | 16 | 2048 bits | 256 | 16 | |
512 | 16 | 32 | 8192 bits | 32 | 32 | 16384 bits | 16 | 32 | 8192 bits | 1024 | 32 | |
fmacc.d | 128 | 4 | 2 | 512 bits | 4 | 2 | 512 bits | 4 | 4 | 1024 bits | 16 | |
256 | 8 | 4 | 2048 bits | 8 | 4 | 2048 bits | 8 | 8 | 4096 bits | 64 | ||
512 | 16 | 8 | 8192 bits | 16 | 8 | 8192 bits | 16 | 16 | 16384 bits | 256 |
浮動小数点行列乗算(widen)
widen浮動小数点行列乗算は、デスティネーションオペランドのデータ幅がソースオペランドの2倍であることを示す。ソースオペランドのデータ幅は命令エンコーディング含まれている。
fwmmacc.h
: fp16/bf16 浮動小数点ソースと fp32 結果、xmisa レジスタの bit[8] が 0 の場合は不正。fwmmacc.s
: fp32 浮動小数点ソースと fp64 結果 , xmisa レジスタの bit[9] が 0 の場合は不正
#float matrix multiplication, output widen, md = md + ms1*ms2 fwmmacc.h md, ms2, ms1 fwmmacc.s md, ms2, ms1
fwmmacc.h
、16 ビット浮動小数点 widen 行列乗算および加算命令では、要素は fp16 または bf16 データ型がサポートされている場合は bf16 と見なすことができる。最大行列形状は:
- matrixA: M ⇐ RLEN/32, K ⇐ RLEN/16.
- matrixB: N ⇐ RLEN/32, K ⇐ RLEN/16
- matrixC: M ⇐ RLEN/32, N ⇐ RLEN/32
fwmmacc.s
, 32ビット浮動小数点widen行列乗算および加算命令では、最大行列形状は以下のようになる:
- matrixA: M ⇐ RLEN/32, K ⇐ RLEN/32
- matrixB: N ⇐ RLEN/32, K ⇐ RLEN/32
- matrixC: M ⇐ RLEN/32, N ⇐ RLEN/32
行列Cのデータ幅は行列AとBの2倍であるため、行列Cではmdとmd+1で指定される2つの行列レジスタ(レジスタペア)が使用される。奇数番号の md を指定する命令は予約されている.典型的な RLEN における fwmmacc
命令の最大行列サイズのまとめ.
整数行列の乗算(4x widen)
出力データ幅が入力データ幅の4倍である整数行列乗算。命令エンコーディングでサポートされるソースオペランドのデータ幅は int8 と int16 で、その他のデータ幅は予約されている。符号付き/符号なし両方のバージョンが提供されている。したがって、ソース・オペランドは、両オペランド符号付き/両オペランド符号なし/符号付き-符号なし/符号なし-符号なしの両方が可能であり、乗算の結果は、加算および累積の前に符号拡張されます。オーバーフローは無視され、結果は折り返される。
mmaqa.b
/mmaqau.b
/mmaqaus.b
/mmaqasu.b
: int8 の 4 倍拡大行列乗算,xmisa レジスタの bit[1] が 0 の場合は不正.mmaqa.h
/mmaqau.h
/mmaqaus.h
/mmaqasu.h
: int16 の 4 倍拡大行列乗算,xmisa レジスタの bit[2] が 0 の場合は不正.
#8bit data width #signed matrix multiply mmaqa.b md, ms2, ms1 #unsigned matrix multiply mmaqau.b md, ms2, ms1 #unsigned-signed matrix multiply mmaqaus.b md, ms2, ms1 #signed-unsigned matrix multiply mmaqasu.b md, ms2, ms1 #16bit data width #signed matrix multiply mmaqa.h md, ms2, ms1 #unsigned matrix multiply mmaqau.h md, ms2, ms1 #unsigned-signed matrix multiply mmaqaus.h md, ms2, ms1 #signed-unsigned matrix multiply mmaqasu.h md, ms2, ms1
int8 の4回行列乗算の場合、最大行列形状は以下のようになる:
- matrixA: M ⇐ RLEN/32, K ⇐ RLEN/8
- matrixB: N ⇐ RLEN/32, K ⇐ RLEN/8
- matrixC: M ⇐ RLEN/32, N ⇐ RLEN/32
int16の4倍行列乗算の場合、行列Cのデータ幅は行列AとBの4倍であるため、mdとmd+1で指定される行列Cで2つの行列レジスタ(レジスタペア)が使用される。奇数番目のmdを指定する命令は予約されている:
- matrixA: M ⇐ RLEN/32, K ⇐ RLEN/16
- matrixB: N ⇐ RLEN/32, K ⇐ RLEN/16
- matrixC: M ⇐ RLEN/32, N ⇐ RLEN/32
典型的な RLEN における整数行列の乗算および加算命令の最大行列サイズのまとめ:
matrix A | matrix B | matrix C | ||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|
RLEN | M | K | data width | N | K | data width | M | N | data width | Gops/GHz | latency | |
int8 4x | 128 | 4 | 16 | 512 bits | 4 | 16 | 512 bits | 4 | 4 | 512 bits | 128 | 4 |
256 | 8 | 32 | 2048 bits | 8 | 32 | 2048 bits | 8 | 8 | 2048 bits | 512 | 8 | |
512 | 16 | 64 | 8192 bits | 16 | 64 | 8192 bits | 16 | 16 | 8192 bits | 2048 | 16 | |
int16 4x | 128 | 4 | 8 | 512 bits | 4 | 8 | 512 bits | 4 | 4 | 1024 bits | 64 | 4 |
256 | 8 | 16 | 2048 bits | 8 | 16 | 2048 bits | 8 | 8 | 4096 bits | 256 | 8 | |
512 | 16 | 32 | 8192 bits | 16 | 32 | 8192 bits | 16 | 16 | 16384 bits | 1024 | 16 |
行列ロード/ストア命令
行列ロード命令は行列をメモリから行列レジスタにロードし、行列ストア命令は行列を行列レジスタからメモリにストアする。
要素のデータ幅は、バイト/ハーフワード/ワード/ダブルワードを含む命令エンコーディングで、その他のデータ幅は予約されている。ベースアドレスはrs1、行ストライド(バイト)はrs2、md/ms3は行列ロード先、行列ストア元のレジスタインデックスである。
#matrix load mld<b/h/w/d> md, rs2, (rs1) #matrix store mst<b/h/w/d> ms3, rs2, (rs1) #whole matrix load mld<1/2/4/8>m md, (rs1) #whole matrix store mst<1/2/4/8>m ms3, (rs1)
マトリックス形状(MxK)はマトリックスサイズ構成レジスタにあり、MはsizeM、KはsizeK(バイト単位)で与えられる。M=sizeM ⇐ RLEN/32, K=sizeK/ 要素サイズ(バイト), sizeK ⇐ RLEN/8.sizeM < RLEN/32 または sizeK < RLEN/8 の場合、行インデックス > sizeM または列インデックス > (sizeK/ バイト単位の要素サイズ) の行列レジスタデータは、ロードのためにゼロをセットし、ストアのためにメモリに書き込まない。
(1)ノーマル (2)全レジスタロード/ストア の2種類がある。
全レジスタロード/ストアは、sizeM = RLEN/32、sizeK = RLEN/8 で、最大行列サイズのデータをメモリから/メモリへロード/ストアする。マトリックス・サイズの構成は無視される。
rs2[4:3]は0にセットされ、それ以外は予約されている。rs2[2:0]はnfフィールドで、NFIELDSエンコーディングを使用してロードおよびストアするマトリクス・レジスタの数をエンコードする。
nf[2:0] | レジスタ数 |
---|---|
000 | 1 |
001 | 2 |
011 | 4 |
111 | 8 |
others | reserved |
すべての行列ロード/ストア命令は、ゼロ以外の行開始値を生成し、受け入れることができる。行開始レジスタは、行列命令の実行終了時にゼロにリセットさる。
ZIHINTNTL拡張により、マトリクス・メモリ・アクセス命令は、異なるメモリ階層に適合するストリーム・メモリ・アクセス命令として動作することができる。ストリーム・メモリ・アクセス命令は、通常のマトリクス・ロード/ストア命令と同じ機能を持つが、ハードウェア実装によって最適化できる可能性がある、近い将来データが再利用されない可能性がある点が異なる。