FPGA開発日記

カテゴリ別記事インデックス https://msyksphinz.github.io/github_pages , English Version https://fpgadevdiary.hatenadiary.com/

Scalaのvar型を使用したChiselの記述量削減テクニック

Scalaにはval型とvar型という2種類の型の種類が存在する。var型は再割り当て可能, val型は再割り当て不可能な型というもので、要するに、

  • var型 : 処理の中で何度も書き換えができる。
  • val型 : 処理の中で一度しか書き換えることができない。

というものだ。Chiselのチュートリアルやサンプルコードを見ていると、大体の場合はval型が使用されている。これはハードウェアは一度だけassignするのが基本なので自然の流れではあるのだが、var型を使うことで場合によってはコード量を削減できる。このようなケースについて調査する。

例えば、テーブル探索において、IDと値のペアで格納されているハッシュのような構造を考え、その中のハッシュ値と一致するアドレスの値を探索する。 もしテーブルの中に同じIDが複数存在している場合、テーブルのアドレスの大きい側の値を優先して取り出す。このような回路をChiselで設計する場合、以下のようなコードを記述することになる。

f:id:msyksphinz:20191210233029p:plain
SearchTableの実装の概要
  val entry_id = Reg(Vec(8, UInt(8.W)))
  val entry_value = Reg(Vec(8, UInt(32.W)))

  val is_hit = Wire(Vec(8, Bool()))
  val hit_value = Wire(Vec(8, UInt(32.W)))

  for (i <- 0 until 8) {
    is_hit(i) := (entry_id(i) === io.addr)
    hit_value(i) := Mux(is_hit(i), entry_value(i), if(i==0) { 0.U} else { hit_value(i-1) })
  }
  io.id_out := hit_value(7)
  io.id_hit := is_hit(7)

要点としては、

  • is_hit(i)により、各エントリがIDの値と一致しているかをチェックする。
  • hit_value(i)は、もしエントリiがヒットしていればその値を選択し、そうでなければ一つ前のエントリの結果を引き継ぐ。

というわけで、一つ前のエントリを保持するために、hit_valueをエントリの数だけ用意したり、インデックスが0の時だけ特殊な処理を追加したりと若干面倒くさい。そこでどうするかというと、手続き型のように記述できる(つまり上書きが可能な)var型を使用するという訳である。

  val entry_id = Reg(Vec(8, UInt(8.W)))
  val entry_value = Reg(Vec(8, UInt(32.W)))

  var is_hit = false.B
  var hit_value = 0.U
  for (i <- 0 until 8) {
    is_hit = (entry_id(i) === io.addr)
    hit_value = Mux(is_hit, entry_value(i), hit_value)
  }
  io.id_out := hit_value
  io.id_hit := is_hit

まず、is_hithit_valueを複数エントリ分持たなくても良くなる。当然ハードウェアに生成するときは(valを使用した場合と等価なので)ハードウェア量としては変わらないのでが、ソースコード量の削減になる。

生成されるハードウェア量はどちらも一緒が、Chiselの記述量の削減としてこのような記法も可能、ということだ。


付録 : 全ソースコード

class search_table_var extends Module
{
  val io = IO(new Bundle {
    val addr = Input(UInt(8.W))
    val id_out = Output(UInt(32.W))
    val id_hit = Output(Bool())

    val w_en    = Input(Bool())
    val w_addr  = Input(UInt(3.W))
    val w_id    = Input(UInt(8.W))
    val w_value = Input(UInt(32.W))
  })

  val entry_id = Reg(Vec(8, UInt(8.W)))
  val entry_value = Reg(Vec(8, UInt(32.W)))

  var is_hit = false.B
  var hit_id = 0.U
  for (i <- 0 until 8) {
    is_hit = (entry_id(i) === io.addr)
    hit_id = Mux(is_hit, entry_value(i), hit_id)
  }
  io.id_out := hit_id
  io.id_hit := is_hit


  when (io.w_en) {
    entry_id(io.w_addr) := io.w_id
    entry_value(io.w_addr) := io.w_value
  }
}