FPGA開発日記

カテゴリ別記事インデックス https://msyksphinz.github.io/github_pages , English Version https://fpgadevdiary.hatenadiary.com/

Intel Advanced Matrix Extension(AMX)とは何なのか

遅ればせながらIntel Advanced Matrix Extension(AMX)の資料を読み漁ってみた。参考にしたのは以下の資料。

Intelが発表した次世代の命令拡張Intel Advanced Matrx Extension(AMX)は、2021年の第4世代XeonのベースコアとなるSapphire Rapidsから搭載される予定のAI向け拡張命令セットのこと。IntelはこれまでにDeep Learning向けに複数の新規命令拡張を行ってきた。

  • AVX512_VNNI : Cascade Lakeから導入された。CNNカーネルを高速化するためのMultiplication and Addition処理の高速化
  • AVX512_BF16 : 単精度浮動小数点命令をBFloat16に変換するための新規命令追加

しかし今回導入されたAMXはこれらとは全く異なり、AVXの拡張ではなく全く新しい命令と言うことができる。AMXはAVXをベースとするわけではなく、新しい行列演算用のレジスタを導入し、また命令も全く新規に追加されるからである。

AMXは新たに「タイル(Tile)」と呼ばれる新しいレジスタを導入する(ここで原文には「ランク2 テンソルレジスタ」と書いてあるが私はこの意味が良く分からない。2次元行列向けと言うこと?)。また、演算用にアクセラレータも導入される。これらの拡張命令はAVX/2やAVX512と同様に、基本的なパイプラインには影響を与えず、あくまで拡張命令として動作する。

AMXのレジスタは8本のタイルで構成されており、それぞれTMM0からTMM7と名前が付けられている。各タイルは16行×64バイトのサイズであり、トータルで1kBの容量を持っている。これが8本なのでAMXレジスタ全体としては8kBの容量を持っていることになる。

このタイルのことをパレットと呼び、デフォルトではパレット0、パレット1が定義されている。パレット0のことを「initialized state」と書いてあるが正直良く分からない。パレット1が上記のTMM0からTMM7までの8kBのことを指す。

タイルには制御用のレジスタも付加されており、プログラマはこのレジスタを操作することによってタイルのサイズを変更することができる。実現したいアプリケーションの構成によって、タイルの構成を変更することでアプリケーションに最適な構成を実現できる。

f:id:msyksphinz:20200816151100p:plain

AMXの命令

AMXは現在12個の命令が定義されており、大きく3種類に分類される。

コンフィグレーションは、アプリケーションを実行する際の先頭で行うだけで良い。AVX命令と同様に、画像などのデータをタイルにロードし、タイルに対する操作を行い、結果をストアするという手順で実行する。

現在の仕様では、アクセラレータが1つだけ定義されている。TMUL(Tile Matrix Multiply)では、$TileC[M][N] += TileA[M][K] * TileB[K][N]$に相当する演算(つまり$TileA$と$TileB$の行列積を一発で実行する)命令が定義されている。

12個の命令は大きく分けて3種類に分類されており、

  • AMX-TILE:最もベースとなるタイルコンフィグレーション命令やロード命令など。7命令。
  • AMX-INT8:整数演算向けのタイルオペレーション命令。4命令。
  • AMX-BF16:Bfloat16向けのタイルオペレーション命令。1命令。

先ほどの12命令を分類すると以下のように分類できるらしい。

分類 命令種類
AMX-TILE LDTILECFG, STTILECFG, TILELOADD, TILELOADDT1, TILESTORED, TILERELEASE, TILEZERO
AMX-INT8 TDPBSSD, TDPBSUD, TDPBUSD, TDPBUUD
AMX-BF16 TDPBF16PS

TMULのアーキテクチャ

TMULはその名の通り行列積を実行するためのユニットであるが、実体はFMAを2次元に配置し、行列Aと行列Bから値を流し込むことによって演算を行うスタイルになっている。ここで具体的にどのようなフローで行列積を計算しているのか追いかけようと思ったが良く分からなかった。縦軸をSIMD方向として積の演算を行い、それをすべて加算でまとめ上げると確かに行列積の計算になる気がするが、それではFMA(累積)を使う意味がないのではないか?このあたりのアーキテクチャはもう少し詳細が出てきたら読んでみたい。

f:id:msyksphinz:20200816151122p:plain