FPGA開発日記

カテゴリ別記事インデックス https://msyksphinz.github.io/github_pages , English Version https://fpgadevdiary.hatenadiary.com/

高性能プロセッサの分岐予測のサーベイ論文を読んで分岐予測について学ぶ (6. RISC-VのアウトオブオーダコアBOOMのTAGE実装を見てみる)

プロセッサアーキテクチャについて再度復習その6。分岐予測の基本について学んだので、実際の実装を見てみたいと思う。

RISC-VのアウトオブオーダBOOMの実装を眺めてみることにした。こちらはChiselをベースにしているので読み解くのは少し厄介だが、できない事は無い。

github.com


TAGE分岐予測器

ベースとなるクラス。br-predictor.scalaが基底クラスとなっている。TageBrPredictorBoomBrPredictorから派生している。

  • boom/src/main/scala/bpu/bpd/br-predictor.scala
abstract class BoomBrPredictor(
   val historyLength: Int)(implicit p: Parameters) extends BoomModule
{
...
  • boom/src/main/scala/bpu/bpd/tage/tage.scala
class TageBrPredictor(
  numTables: Int,
  tableSizes: Seq[Int],
  historyLengths: Seq[Int],
  tagSizes: Seq[Int],
  cntrSz: Int,
  ubitSz: Int
  )(implicit p: Parameters)
  extends BoomBrPredictor(
    historyLength = historyLengths.max)
...

TAGEのパラメータ

TAGEを構成するにあたり、必要なのは以下のパラメータだ。

  • numTable : TAGEテーブルの数 (=4)
  • tableSize : TAGEテーブルのエントリ数。すべてのテーブルで1024。
  • historyLength : 分岐履歴レジスタのうち、TAGEのハッシュ生成に使うビット長。
    • 4つのテーブルでそれぞれ、27, 45, 63, 90
  • tagSize : エントリのヒット・ミスを示すためのタグビットのサイズ。
    • 4つのテーブルですべて9ビット
  • cntrSz : 各エントリのカウンタのサイズ (=3)。
  • ubitSz : usefulnessカウンタのサイズ (=1)。

テーブルの構成。TAGEには分岐予測のためのテーブルが必要だ。実際のテーブルのImplementationはTageTableに記載されているが、デフォルトではこのTAGEテーブルがnumTables(=4)だけ実装されている。

  val tables = for (i <- 0 until numTables) yield {
    val table = Module(new TageTable(
      id                 = i,
      numEntries        = tableSizes(i),
...
    // check that the user ordered his TAGE tables properly
    if (i > 0) require(historyLengths(i) > historyLengths(i-1))

    table
  }

TAGETableの構成

TAGETableは以下の入出力ポートを持つ。

  • bp1_req : 予測のリクエスト。インデックスとタグを含む。
  class TageTableReq(val idxSz: Int, val tagSz: Int) extends Bundle
  {
    val index = UInt(idxSz.W)
    val tag = UInt(tagSz.W)
  }
  • bp2_req : 予測のレスポンス。bp1_reqに対して1サイクル遅れる。以下の情報が返される。
  class TageTableResp(val tagSz: Int, val cntrSz: Int, val ubitSz: Int)(implicit p: Parameters) extends BoomBundle
  {
    val tag  = UInt(tagSz.W)                 // 入力TAGの値をそのまま返す。
    val cntr = UInt(cntrSz.W)                // 2-bit counterの値そのもの
    val cidx = UInt(log2Ceil(fetchWidth).W)  // cidx。複数命令Issueの場合は対象命令の位置を返す。
    val ubit = UInt(ubitSz.W)                // usefulnessビット
  
    def predictsTaken = cntr(cntrSz-1)       // 2ビットカウンタのビットのうち、最上位ビットが立っていればTaken。
  }
  • write : TAGEの情報を更新するための信号。予測結果に基づいて情報をアップデートする。
  class TageTableWrite(val idxSz: Int, val tagSz: Int, val cntrSz: Int, val ubitSz: Int)(implicit p: Parameters)
    extends BoomBundle
  {
    val index = UInt(idxSz.W) // アップデート対位法のテーブルインデックス
    val old = new TageTableEntry(tagSz, cntrSz, ubitSz)
  
    // What kind of write are we going to perform?
    val allocate = Bool()     // テーブルのアロケートを行う。
    val update   = Bool()     // テーブルのアップデートを行う。
    val degrade  = Bool()     // テーブルの削除を行う。
    // What was the outcome of the branch?
    val mispredict = Bool()   // 分岐の結果、予測が外れた。
    val taken = Bool()        // 分岐の結果が、Takenであった。
  }

予測前の初期化。ステートマシンによりテーブルの初期化を行う。

  val s_reset :: s_wait :: s_clear :: s_idle :: Nil = Enum(4)
  val fsm_state = RegInit(s_reset)
  val nResetLagCycles = 64
  val nBanks = 1
  val (lag_counter, lag_done) = Counter(fsm_state === s_wait, nResetLagCycles)
  val (clear_row_addr, clear_done) = Counter(fsm_state === s_clear, numEntries/nBanks)

  switch (fsm_state) {
    is (s_reset) { fsm_state := s_wait }
    is (s_wait)  { when (lag_done) { fsm_state := s_clear } }
    is (s_clear) { when (clear_done) { fsm_state := s_idle } }
    is (s_idle)  { when (io.do_reset) { fsm_state := s_clear } }
  }

リセット後にステートマシンがs_waitに代わり、64サイクル待つ(リセットの非同期対応?)。次にs_clearステートに移り、テーブルのエントリを1つずつ初期化する。すべてが完了すると、s_idle状態に遷移し、定常状態に移る。

  //------------------------------------------------------------
  // Update (Commit)

  when (io.write.valid || fsm_state === s_clear)
...
    when (fsm_state === s_clear) {
      widx := clear_row_addr
      wentry.tag := 0.U
      wentry.cntr := 0.U
      wentry.cidx := 0.U
      wentry.ubit := 0.U
...
    ram.write(widx, wentry)

TAGEのテーブルそのものは以下で構成されている。ビットカウンタが3ビットあり、しかも0から始まると言う事は、Strongly Not Takenから始まると言う事だろうか?これでいいの?

class TageTableEntry(val tagSz: Int, val cntrSz: Int, val ubitSz: Int)(implicit p: Parameters) extends BoomBundle
{
  val tag  = UInt(tagSz.W)                 // Tag.
  val cntr = UInt(cntrSz.W)                // Prediction counter.
  val cidx = UInt(log2Ceil(fetchWidth).W)  // Control-flow instruction index.
  val ubit = UInt(ubitSz.W)                // Usefulness counter.
}

実体はどのように作ってもかまわないだろうが、Read/Writeをそれぞれ1ポートずつ持つRAMが生成されている。上記のポートは合計16ビットのデータを保持するので、16ビットのSRAMを接続しなければならないと言う事になる。

  val ram = SyncReadMema(numEntries, new TageTableEntry(tagSz, cntrSz, ubitSz))
  ram.suggestName("TageTableDataArray")
module TageTableDataArray(
  input  [9:0] R0_addr,
  input        R0_en,
  input        R0_clk,
  output [8:0] R0_data_tag,
  output [2:0] R0_data_cntr,
  output [2:0] R0_data_cidx,
  output       R0_data_ubit,
  input  [9:0] W0_addr,
  input        W0_en,
  input        W0_clk,
  input  [8:0] W0_data_tag,
  input  [2:0] W0_data_cntr,
  input  [2:0] W0_data_cidx,
  input        W0_data_ubit
);
...

予測の手順

bp1から入ってきた信号は、まずはすべてのテーブルに対してRead Requestが発行される。

  tables_io.zipWithIndex.map{ case (table, i) =>
    table.InitializeIo

    // perform tag hash
    val n = numTables
    bp1_tags(i) := TagHash(r_f1_fetchpc, bp1_idxs((i+1) % n), bp1_idxs((i+2) % n))

    // Send prediction request. ---
    table.bp1_req.valid := f1_valid
    table.bp1_req.bits.index := bp1_idxs(i)
    table.bp1_req.bits.tag := bp1_tags(i)

    table.do_reset := false.B // TODO
  }

この時のインデックスとタグの作り方だが、

  • テーブルインデックス : フェッチPC、グローバルヒストリ(r_f1_history)を使ってインデックスを生成する。
    tables_io.zipWithIndex.map{ case (table, i) =>
      bp1_idxs(i) := IdxHash(r_f1_fetchpc, r_f1_history, historyLengths(i), log2Ceil(tableSizes(i)))
    }
  • テーブルタグ : フェッチPC、1つ後ろのテーブルインデックス、2つ後ろのテーブルインデックスを使用している。これはなんでだ?
     tables_io.zipWithIndex.map{ case (table, i) =>
  ...
       bp1_tags(i) := TagHash(r_f1_fetchpc, bp1_idxs((i+1) % n), bp1_idxs((i+2) % n))
  ...

各TAGEテーブルでの参照は、RAMに対してReadリクエストを出している。1サイクル後に応答が返ってくる。同時に、タグがヒットするかを確認している。

  • boom/src/main/scala/bpu/bpd/tage/tage-table.scala
  val s2_out = ram.read(s1_idx, s1_valid)
  val s2_tag_hit = s2_out.tag === RegEnable(s1_tag, s1_valid)
  // TAGEテーブルの応答
  io.bp2_resp.valid     := s2_tag_hit && RegNext(fsm_state === s_idle, false.B)
  io.bp2_resp.bits.tag  := s2_out.tag
  io.bp2_resp.bits.cntr := s2_out.cntr
  io.bp2_resp.bits.cidx := s2_out.cidx
  io.bp2_resp.bits.ubit := s2_out.ubit

TAGEテーブルから戻ってきた信号は、ElasticReg(2エントリのキューのようなもの)に格納する。このキューはテーブルの数だけ並んでいる。

  • boom/src/main/scala/bpu/bpd/tage/tage.scala
  // Buffer all of the table responses into a queue.
  // Match the other ElasticRegs in the FrontEnd.
  val q_f3_resps = for (i <- 0 until numTables) yield {
    val q_resp = withReset(reset.toBool || io.fe_clear || io.f4_redirect)
     {Module(new ElasticReg(Valid(new TageTableResp(tagSizes.max, cntrSz, ubitSz))))}

    q_resp.io.enq.valid := io.f2_valid
    q_resp.io.enq.bits := tables_io(i).bp2_resp
    q_resp.io.deq.ready := io.resp.ready

    assert (q_resp.io.enq.ready === !io.f2_stall)
    assert (q_resp.io.deq.valid === q_f3_history.io.deq.valid)

    q_resp
  }

それぞれのTAGEテーブルについて、PCがヒットしたか、Taken/Not Takenの情報を取得する。

  // get predictions.
  val f3_tag_hits    = q_f3_resps.map( q => q.io.deq.valid && q.io.deq.bits.valid )
  val f3_predictions = q_f3_resps.map( q => q.io.deq.bits.bits )

予測テーブルのヒットマトリックスを作成する。ヒットマトリックスというのは、縦軸がTAGEテーブルエントリの番号、横軸が同時にフェッチされた命令ラインのインデックス(4命令同時デコードならば4ビット)となる。これにより、どの命令が度のテーブルでヒットしたかが明らかになる。

  val f3_hits_matrix = for (i <- 0 until numTables) yield {
    // For each table, return a one-hot bit-mask if there's a tag-hit AND cfi-idx hit.
    io.f3_is_br.asUInt & Mux(f3_tag_hits(i), UIntToOH(f3_predictions(i).cidx), 0.U)
  }

この中からヒットした値を探し出すのだが、HistoryLengthの長い方が優先されるので、TAGEテーブルをインデックスとは逆純に追いかけていき、最初にヒットしたのがf3_best_hit(w), f3_best_ids(w)であり、2番目にHistoryLengthが長くてヒットしたものがf3_alt_hit(w), f3_alt_ids(w)として格納される(この第2候補は実際には使用されない。)

この情報を使用して、対象となる命令のTaken/NotTakenを決定する。

    f3_takens(w) :=
      Mux(f3_best_hits(w),
        VecInit(f3_predictions)(f3_best_ids(w)).predictsTaken,
        false.B)

こうして、TAGEにヒットした場合は、実際の命令の位置predicted_cidxを計算したのちに予測結果をResponce信号に接続して応答する。

  val resp_info = Wire(new TageResp(
    fetchWidth = fetchWidth,
    numTables = numTables,
    maxHistoryLength = historyLengths.max,
    maxIndexSz = log2Ceil(tableSizes.max),
    maxTagSz = tagSizes.max,
    cntrSz = cntrSz,
    ubitSz = ubitSz))

  val f3_has_hit = f3_best_hits.reduce(_|_)
  val f3_pred = VecInit(f3_predictions)(f3_best_ids(predicted_cidx))

  assert (!(f3_has_hit && !f3_best_hits(predicted_cidx)), "[tage] was a hit but our cidx is wrong.")

  io.resp.valid       := f3_has_hit || io.f2_bim_resp.valid
  io.resp.bits.takens := Mux(f3_has_hit,
                             GetPredictionOH(f3_pred.cidx, f3_pred.cntr),
                             io.f2_bim_resp.bits.getTakens)

TAGEの更新

TAGEテーブルの更新は、命令のコミット時に行われる。コミットする命令のインデックスを計算し、アップデートを行うためのTAGの計算を行う。

  • boom/src/main/scala/bpu/bpd/tage/tage.scala
  val r_commit = RegNext(io.commit)
  val r_info = (r_commit.bits.info).asTypeOf(new TageResp(
    fetchWidth = fetchWidth,
    numTables = numTables,
    maxHistoryLength = historyLengths.max,
    maxIndexSz = log2Ceil(tableSizes.max),
    maxTagSz = tagSizes.max,
    cntrSz = cntrSz,
    ubitSz = ubitSz
  ))

  val com_indexes = historyLengths zip tableSizes map { case (hlen, tsize)  =>
    val idx = IdxHash(r_commit.bits.fetch_pc, r_commit.bits.history, hlen, log2Ceil(tsize))
    idx
  }

  val com_tags = for (i <- 0 until numTables) yield {
    val n = numTables
    TagHash(r_commit.bits.fetch_pc, com_indexes((i+1) % n), com_indexes((i+2) % n))
  }

最終的に各テーブルに対してWriteEntryを行う。

  for (i <- 0 until numTables) {
    when (table_wens(i)) {
      // construct old entry so we can overwrite parts without having to do a RMW.
      val com_entry = Wire(new TageTableEntry(tagSizes.max, cntrSz, ubitSz))
      com_entry.tag  := Mux(table_allocates(i), com_tags(i), r_info.tags(i))
      com_entry.cntr := r_info.cntrs(i)
      com_entry.cidx := Mux(table_allocates(i), r_commit.bits.miss_cfi_idx,  r_info.cidxs(i))
      com_entry.ubit := r_info.ubits(i)

      tables_io(i).WriteEntry(
        com_indexes(i),
        com_entry,
        table_allocates(i),
        table_updates(i),
        table_degrades(i),
        r_commit.bits.mispredict,
        r_commit.bits.taken)
      // TODO XXX if "update", only write to u-bit (!mispredict=>ubit) if !alt_agrees.
    }
  }
f:id:msyksphinz:20190712231551p:plain
TAGE分岐予測器の構造