FPGA開発日記

カテゴリ別記事インデックス https://msyksphinz.github.io/github_pages , English Version https://fpgadevdiary.hatenadiary.com/

MLIRの勉強 (6. 演算を定義する)

MLIRについて勉強している。

独自言語を作成し、その中間表現をMLIRを使って表現してみることに挑戦する。

前回のテストでは、最適化が効きすぎて変数がまとめられてしまい、思い通りのMLIRを出力することができていなかった。次は演算子を定義し、これらをMLIRで出力できるようにしたい。

対象とするのは、以下のMYSVの構文だ。

assign A = 3;
assign B = 4;

assign Hoge1 = A + B;
assign Hoge2 = A * B;

まずはLexerとParserを拡張する。2項演算子をサポートさせる。

  /// expression::= primary binop rhs
  std::unique_ptr<ExprAST> parseExpr() {
    auto lhs = parsePrimary();
    if (!lhs)
      return nullptr;
    return parseBinOpRHS(0, std::move(lhs));
  }

これをMLIRに変換するために、次のMLIRのOperationを定義する。MulOpとAddOpだ。

//===----------------------------------------------------------------------===//
// MulOp
//===----------------------------------------------------------------------===//

def MulOp : MYSV_Op<"mul", [NoSideEffect]> {
  let summary = "element-wise multiplication operation";
  let description = [{
    The "mul" operation performs element-wise multiplication between two
    tensors. The shapes of the tensor operands are expected to match.
  }];

  let arguments = (ins I64:$lhs, I64:$rhs);
  let results = (outs I64);

}

//===----------------------------------------------------------------------===//
// AddOp
//===----------------------------------------------------------------------===//

def AddOp : MYSV_Op<"add", [NoSideEffect]> {
  let summary = "element-wise addition operation";
  let description = [{
    The "add" operation performs element-wise addition between two tensors.
    The shapes of the tensor operands are expected to match.
  }];

  let arguments = (ins I64:$lhs, I64:$rhs);
  let results = (outs I64);
}

それぞれのOperationは、整数64ビットの引数が2つ、整数64ビットの引数が1つであるということになっている。

これに対して、MLIRを生成するときにbuilder.createを使用する。

mlir::Value mlirGen(BinaryExprAST &binop) {
    // First emit the operations for each side of the operation before emitting
    // the operation itself. For example if the expression is `a + foo(a)`
    // 1) First it will visiting the LHS, which will return a reference to the
    //    value holding `a`. This value should have been emitted at declaration
    //    time and registered in the symbol table, so nothing would be
    //    codegen'd. If the value is not in the symbol table, an error has been
    //    emitted and nullptr is returned.
    // 2) Then the RHS is visited (recursively) and a call to `foo` is emitted
    //    and the result value is returned. If an error occurs we get a nullptr
    //    and propagate.
    //
    mlir::Value lhs = mlirGen(*binop.getLHS());
    if (!lhs)
      return nullptr;
    mlir::Value rhs = mlirGen(*binop.getRHS());
    if (!rhs)
      return nullptr;
    auto location = loc(binop.loc());
    mlir::Type elementType = builder.getI64Type();

    // Derive the operation name from the binary operator. At the moment we only
    // support '+' and '*'.
    switch (binop.getOp()) {
    case '+':
      return builder.create<AddOp>(location, elementType, lhs, rhs);
    case '*':
      return builder.create<MulOp>(location, elementType, lhs, rhs);
    }

    emitError(location, "invalid binary operator '") << binop.getOp() << "'";
    return nullptr;
  }

ここまででLLVMをビルドしてテストを流してみる。

./bin/mysv --emit=mlir ../mlir/examples/mysv/test/assign_ops.sv
module {
  %0 = "mysv.constant"() {value = 3 : si64} : () -> i64
  %1 = "mysv.constant"() {value = 4 : si64} : () -> i64
  %2 = "mysv.add"(%0, %1) : (i64, i64) -> i64
  %3 = "mysv.mul"(%0, %1) : (i64, i64) -> i64
}

一応想定どおりのMLIRを生成できた。