MLIRについて勉強している。
独自言語を作成し、その中間表現をMLIRを使って表現してみることに挑戦する。
前回のテストでは、最適化が効きすぎて変数がまとめられてしまい、思い通りのMLIRを出力することができていなかった。次は演算子を定義し、これらをMLIRで出力できるようにしたい。
対象とするのは、以下のMYSVの構文だ。
assign A = 3; assign B = 4; assign Hoge1 = A + B; assign Hoge2 = A * B;
まずはLexerとParserを拡張する。2項演算子をサポートさせる。
/// expression::= primary binop rhs std::unique_ptr<ExprAST> parseExpr() { auto lhs = parsePrimary(); if (!lhs) return nullptr; return parseBinOpRHS(0, std::move(lhs)); }
これをMLIRに変換するために、次のMLIRのOperationを定義する。MulOpとAddOpだ。
//===----------------------------------------------------------------------===// // MulOp //===----------------------------------------------------------------------===// def MulOp : MYSV_Op<"mul", [NoSideEffect]> { let summary = "element-wise multiplication operation"; let description = [{ The "mul" operation performs element-wise multiplication between two tensors. The shapes of the tensor operands are expected to match. }]; let arguments = (ins I64:$lhs, I64:$rhs); let results = (outs I64); } //===----------------------------------------------------------------------===// // AddOp //===----------------------------------------------------------------------===// def AddOp : MYSV_Op<"add", [NoSideEffect]> { let summary = "element-wise addition operation"; let description = [{ The "add" operation performs element-wise addition between two tensors. The shapes of the tensor operands are expected to match. }]; let arguments = (ins I64:$lhs, I64:$rhs); let results = (outs I64); }
それぞれのOperationは、整数64ビットの引数が2つ、整数64ビットの引数が1つであるということになっている。
これに対して、MLIRを生成するときにbuilder.createを使用する。
mlir::Value mlirGen(BinaryExprAST &binop) { // First emit the operations for each side of the operation before emitting // the operation itself. For example if the expression is `a + foo(a)` // 1) First it will visiting the LHS, which will return a reference to the // value holding `a`. This value should have been emitted at declaration // time and registered in the symbol table, so nothing would be // codegen'd. If the value is not in the symbol table, an error has been // emitted and nullptr is returned. // 2) Then the RHS is visited (recursively) and a call to `foo` is emitted // and the result value is returned. If an error occurs we get a nullptr // and propagate. // mlir::Value lhs = mlirGen(*binop.getLHS()); if (!lhs) return nullptr; mlir::Value rhs = mlirGen(*binop.getRHS()); if (!rhs) return nullptr; auto location = loc(binop.loc()); mlir::Type elementType = builder.getI64Type(); // Derive the operation name from the binary operator. At the moment we only // support '+' and '*'. switch (binop.getOp()) { case '+': return builder.create<AddOp>(location, elementType, lhs, rhs); case '*': return builder.create<MulOp>(location, elementType, lhs, rhs); } emitError(location, "invalid binary operator '") << binop.getOp() << "'"; return nullptr; }
ここまででLLVMをビルドしてテストを流してみる。
./bin/mysv --emit=mlir ../mlir/examples/mysv/test/assign_ops.sv
module { %0 = "mysv.constant"() {value = 3 : si64} : () -> i64 %1 = "mysv.constant"() {value = 4 : si64} : () -> i64 %2 = "mysv.add"(%0, %1) : (i64, i64) -> i64 %3 = "mysv.mul"(%0, %1) : (i64, i64) -> i64 }
一応想定どおりのMLIRを生成できた。