プロセッサアーキテクチャについて再度復習その6。分岐予測の基本について学んだので、実際の実装を見てみたいと思う。
RISC-VのアウトオブオーダBOOMの実装を眺めてみることにした。こちらはChiselをベースにしているので読み解くのは少し厄介だが、できない事は無い。
TAGE分岐予測器
ベースとなるクラス。br-predictor.scala
が基底クラスとなっている。TageBrPredictor
はBoomBrPredictor
から派生している。
boom/src/main/scala/bpu/bpd/br-predictor.scala
abstract class BoomBrPredictor( val historyLength: Int)(implicit p: Parameters) extends BoomModule { ...
boom/src/main/scala/bpu/bpd/tage/tage.scala
class TageBrPredictor( numTables: Int, tableSizes: Seq[Int], historyLengths: Seq[Int], tagSizes: Seq[Int], cntrSz: Int, ubitSz: Int )(implicit p: Parameters) extends BoomBrPredictor( historyLength = historyLengths.max) ...
TAGEのパラメータ
TAGEを構成するにあたり、必要なのは以下のパラメータだ。
numTable
: TAGEテーブルの数 (=4)tableSize
: TAGEテーブルのエントリ数。すべてのテーブルで1024。historyLength
: 分岐履歴レジスタのうち、TAGEのハッシュ生成に使うビット長。- 4つのテーブルでそれぞれ、27, 45, 63, 90
tagSize
: エントリのヒット・ミスを示すためのタグビットのサイズ。- 4つのテーブルですべて9ビット
cntrSz
: 各エントリのカウンタのサイズ (=3)。ubitSz
: usefulnessカウンタのサイズ (=1)。
テーブルの構成。TAGEには分岐予測のためのテーブルが必要だ。実際のテーブルのImplementationはTageTable
に記載されているが、デフォルトではこのTAGEテーブルがnumTables(=4)
だけ実装されている。
val tables = for (i <- 0 until numTables) yield { val table = Module(new TageTable( id = i, numEntries = tableSizes(i), ... // check that the user ordered his TAGE tables properly if (i > 0) require(historyLengths(i) > historyLengths(i-1)) table }
TAGETableの構成
TAGETableは以下の入出力ポートを持つ。
bp1_req
: 予測のリクエスト。インデックスとタグを含む。
class TageTableReq(val idxSz: Int, val tagSz: Int) extends Bundle { val index = UInt(idxSz.W) val tag = UInt(tagSz.W) }
bp2_req
: 予測のレスポンス。bp1_req
に対して1サイクル遅れる。以下の情報が返される。
class TageTableResp(val tagSz: Int, val cntrSz: Int, val ubitSz: Int)(implicit p: Parameters) extends BoomBundle { val tag = UInt(tagSz.W) // 入力TAGの値をそのまま返す。 val cntr = UInt(cntrSz.W) // 2-bit counterの値そのもの val cidx = UInt(log2Ceil(fetchWidth).W) // cidx。複数命令Issueの場合は対象命令の位置を返す。 val ubit = UInt(ubitSz.W) // usefulnessビット def predictsTaken = cntr(cntrSz-1) // 2ビットカウンタのビットのうち、最上位ビットが立っていればTaken。 }
write
: TAGEの情報を更新するための信号。予測結果に基づいて情報をアップデートする。
class TageTableWrite(val idxSz: Int, val tagSz: Int, val cntrSz: Int, val ubitSz: Int)(implicit p: Parameters) extends BoomBundle { val index = UInt(idxSz.W) // アップデート対位法のテーブルインデックス val old = new TageTableEntry(tagSz, cntrSz, ubitSz) // What kind of write are we going to perform? val allocate = Bool() // テーブルのアロケートを行う。 val update = Bool() // テーブルのアップデートを行う。 val degrade = Bool() // テーブルの削除を行う。 // What was the outcome of the branch? val mispredict = Bool() // 分岐の結果、予測が外れた。 val taken = Bool() // 分岐の結果が、Takenであった。 }
予測前の初期化。ステートマシンによりテーブルの初期化を行う。
val s_reset :: s_wait :: s_clear :: s_idle :: Nil = Enum(4) val fsm_state = RegInit(s_reset) val nResetLagCycles = 64 val nBanks = 1 val (lag_counter, lag_done) = Counter(fsm_state === s_wait, nResetLagCycles) val (clear_row_addr, clear_done) = Counter(fsm_state === s_clear, numEntries/nBanks) switch (fsm_state) { is (s_reset) { fsm_state := s_wait } is (s_wait) { when (lag_done) { fsm_state := s_clear } } is (s_clear) { when (clear_done) { fsm_state := s_idle } } is (s_idle) { when (io.do_reset) { fsm_state := s_clear } } }
リセット後にステートマシンがs_wait
に代わり、64サイクル待つ(リセットの非同期対応?)。次にs_clear
ステートに移り、テーブルのエントリを1つずつ初期化する。すべてが完了すると、s_idle
状態に遷移し、定常状態に移る。
//------------------------------------------------------------ // Update (Commit) when (io.write.valid || fsm_state === s_clear) ... when (fsm_state === s_clear) { widx := clear_row_addr wentry.tag := 0.U wentry.cntr := 0.U wentry.cidx := 0.U wentry.ubit := 0.U ... ram.write(widx, wentry)
TAGEのテーブルそのものは以下で構成されている。ビットカウンタが3ビットあり、しかも0から始まると言う事は、Strongly Not Takenから始まると言う事だろうか?これでいいの?
class TageTableEntry(val tagSz: Int, val cntrSz: Int, val ubitSz: Int)(implicit p: Parameters) extends BoomBundle { val tag = UInt(tagSz.W) // Tag. val cntr = UInt(cntrSz.W) // Prediction counter. val cidx = UInt(log2Ceil(fetchWidth).W) // Control-flow instruction index. val ubit = UInt(ubitSz.W) // Usefulness counter. }
実体はどのように作ってもかまわないだろうが、Read/Writeをそれぞれ1ポートずつ持つRAMが生成されている。上記のポートは合計16ビットのデータを保持するので、16ビットのSRAMを接続しなければならないと言う事になる。
val ram = SyncReadMema(numEntries, new TageTableEntry(tagSz, cntrSz, ubitSz)) ram.suggestName("TageTableDataArray")
module TageTableDataArray( input [9:0] R0_addr, input R0_en, input R0_clk, output [8:0] R0_data_tag, output [2:0] R0_data_cntr, output [2:0] R0_data_cidx, output R0_data_ubit, input [9:0] W0_addr, input W0_en, input W0_clk, input [8:0] W0_data_tag, input [2:0] W0_data_cntr, input [2:0] W0_data_cidx, input W0_data_ubit ); ...
予測の手順
bp1
から入ってきた信号は、まずはすべてのテーブルに対してRead Requestが発行される。
tables_io.zipWithIndex.map{ case (table, i) => table.InitializeIo // perform tag hash val n = numTables bp1_tags(i) := TagHash(r_f1_fetchpc, bp1_idxs((i+1) % n), bp1_idxs((i+2) % n)) // Send prediction request. --- table.bp1_req.valid := f1_valid table.bp1_req.bits.index := bp1_idxs(i) table.bp1_req.bits.tag := bp1_tags(i) table.do_reset := false.B // TODO }
この時のインデックスとタグの作り方だが、
- テーブルインデックス : フェッチPC、グローバルヒストリ(
r_f1_history
)を使ってインデックスを生成する。
tables_io.zipWithIndex.map{ case (table, i) =>
bp1_idxs(i) := IdxHash(r_f1_fetchpc, r_f1_history, historyLengths(i), log2Ceil(tableSizes(i)))
}
- テーブルタグ : フェッチPC、1つ後ろのテーブルインデックス、2つ後ろのテーブルインデックスを使用している。これはなんでだ?
tables_io.zipWithIndex.map{ case (table, i) => ... bp1_tags(i) := TagHash(r_f1_fetchpc, bp1_idxs((i+1) % n), bp1_idxs((i+2) % n)) ...
各TAGEテーブルでの参照は、RAMに対してReadリクエストを出している。1サイクル後に応答が返ってくる。同時に、タグがヒットするかを確認している。
boom/src/main/scala/bpu/bpd/tage/tage-table.scala
val s2_out = ram.read(s1_idx, s1_valid) val s2_tag_hit = s2_out.tag === RegEnable(s1_tag, s1_valid) // TAGEテーブルの応答 io.bp2_resp.valid := s2_tag_hit && RegNext(fsm_state === s_idle, false.B) io.bp2_resp.bits.tag := s2_out.tag io.bp2_resp.bits.cntr := s2_out.cntr io.bp2_resp.bits.cidx := s2_out.cidx io.bp2_resp.bits.ubit := s2_out.ubit
TAGEテーブルから戻ってきた信号は、ElasticReg(2エントリのキューのようなもの)に格納する。このキューはテーブルの数だけ並んでいる。
boom/src/main/scala/bpu/bpd/tage/tage.scala
// Buffer all of the table responses into a queue. // Match the other ElasticRegs in the FrontEnd. val q_f3_resps = for (i <- 0 until numTables) yield { val q_resp = withReset(reset.toBool || io.fe_clear || io.f4_redirect) {Module(new ElasticReg(Valid(new TageTableResp(tagSizes.max, cntrSz, ubitSz))))} q_resp.io.enq.valid := io.f2_valid q_resp.io.enq.bits := tables_io(i).bp2_resp q_resp.io.deq.ready := io.resp.ready assert (q_resp.io.enq.ready === !io.f2_stall) assert (q_resp.io.deq.valid === q_f3_history.io.deq.valid) q_resp }
それぞれのTAGEテーブルについて、PCがヒットしたか、Taken/Not Takenの情報を取得する。
// get predictions. val f3_tag_hits = q_f3_resps.map( q => q.io.deq.valid && q.io.deq.bits.valid ) val f3_predictions = q_f3_resps.map( q => q.io.deq.bits.bits )
予測テーブルのヒットマトリックスを作成する。ヒットマトリックスというのは、縦軸がTAGEテーブルエントリの番号、横軸が同時にフェッチされた命令ラインのインデックス(4命令同時デコードならば4ビット)となる。これにより、どの命令が度のテーブルでヒットしたかが明らかになる。
val f3_hits_matrix = for (i <- 0 until numTables) yield { // For each table, return a one-hot bit-mask if there's a tag-hit AND cfi-idx hit. io.f3_is_br.asUInt & Mux(f3_tag_hits(i), UIntToOH(f3_predictions(i).cidx), 0.U) }
この中からヒットした値を探し出すのだが、HistoryLengthの長い方が優先されるので、TAGEテーブルをインデックスとは逆純に追いかけていき、最初にヒットしたのがf3_best_hit(w), f3_best_ids(w)
であり、2番目にHistoryLengthが長くてヒットしたものがf3_alt_hit(w), f3_alt_ids(w)
として格納される(この第2候補は実際には使用されない。)
この情報を使用して、対象となる命令のTaken/NotTakenを決定する。
f3_takens(w) :=
Mux(f3_best_hits(w),
VecInit(f3_predictions)(f3_best_ids(w)).predictsTaken,
false.B)
こうして、TAGEにヒットした場合は、実際の命令の位置predicted_cidx
を計算したのちに予測結果をResponce信号に接続して応答する。
val resp_info = Wire(new TageResp( fetchWidth = fetchWidth, numTables = numTables, maxHistoryLength = historyLengths.max, maxIndexSz = log2Ceil(tableSizes.max), maxTagSz = tagSizes.max, cntrSz = cntrSz, ubitSz = ubitSz)) val f3_has_hit = f3_best_hits.reduce(_|_) val f3_pred = VecInit(f3_predictions)(f3_best_ids(predicted_cidx)) assert (!(f3_has_hit && !f3_best_hits(predicted_cidx)), "[tage] was a hit but our cidx is wrong.") io.resp.valid := f3_has_hit || io.f2_bim_resp.valid io.resp.bits.takens := Mux(f3_has_hit, GetPredictionOH(f3_pred.cidx, f3_pred.cntr), io.f2_bim_resp.bits.getTakens)
TAGEの更新
TAGEテーブルの更新は、命令のコミット時に行われる。コミットする命令のインデックスを計算し、アップデートを行うためのTAGの計算を行う。
boom/src/main/scala/bpu/bpd/tage/tage.scala
val r_commit = RegNext(io.commit) val r_info = (r_commit.bits.info).asTypeOf(new TageResp( fetchWidth = fetchWidth, numTables = numTables, maxHistoryLength = historyLengths.max, maxIndexSz = log2Ceil(tableSizes.max), maxTagSz = tagSizes.max, cntrSz = cntrSz, ubitSz = ubitSz )) val com_indexes = historyLengths zip tableSizes map { case (hlen, tsize) => val idx = IdxHash(r_commit.bits.fetch_pc, r_commit.bits.history, hlen, log2Ceil(tsize)) idx } val com_tags = for (i <- 0 until numTables) yield { val n = numTables TagHash(r_commit.bits.fetch_pc, com_indexes((i+1) % n), com_indexes((i+2) % n)) }
最終的に各テーブルに対してWriteEntryを行う。
for (i <- 0 until numTables) { when (table_wens(i)) { // construct old entry so we can overwrite parts without having to do a RMW. val com_entry = Wire(new TageTableEntry(tagSizes.max, cntrSz, ubitSz)) com_entry.tag := Mux(table_allocates(i), com_tags(i), r_info.tags(i)) com_entry.cntr := r_info.cntrs(i) com_entry.cidx := Mux(table_allocates(i), r_commit.bits.miss_cfi_idx, r_info.cidxs(i)) com_entry.ubit := r_info.ubits(i) tables_io(i).WriteEntry( com_indexes(i), com_entry, table_allocates(i), table_updates(i), table_degrades(i), r_commit.bits.mispredict, r_commit.bits.taken) // TODO XXX if "update", only write to u-bit (!mispredict=>ubit) if !alt_agrees. } }