FPGA開発日記

カテゴリ別記事インデックス https://msyksphinz.github.io/github_pages , English Version https://fpgadevdiary.hatenadiary.com/

LLVMの新しい中間言語表現 MLIRを試す(2. MLIRに関するコード生成を試す)

MLIRについて基礎を学んだところで、実際に動かしてみたい。以下のページを読みながら少しチュートリアルを触ってみよう。

mlir.llvm.org

mlir.llvm.org


MLIRとのインタフェース

先ほどのtranspose()がどのようにMLIRに変換されたのかを以下に示す。

%t_tensor = "toy.transpose"(%tensor) {inplace = true} : (tensor<2x3xf64>) -> tensor<3x2xf64> loc("example/file/path":12:1)
  • t_tnsortransposeの結果に付けられる名前。SSA値として表現される。
  • "toy.transpose":操作の名前
  • (%tensor):引数のリスト。
  • {inplace = true}:操作に付けられる属性。ここではinplaceというBool型のTrue値を持つ値を定義している。
  • (tensor<2x3xf64>) -> tensor<3x2xf64>:関数形式での型の変換形式を示している。引数の型、およびその戻り値の型を示している。
  • loc("example/file/path":12:1):この操作が発生したソースコードの場所を示している。

MLIRとのインタフェースのためのToy方言

Toy言語に対する方言を定義するために、C++ToyDialectを定義する。

/// This is the definition of the Toy dialect. A dialect inherits from
/// mlir::Dialect and registers custom attributes, operations, and types (in its
/// constructor). It can also override virtual methods to change some general
/// behavior, which will be demonstrated in later chapters of the tutorial.
class ToyDialect : public mlir::Dialect {
 public:
  explicit ToyDialect(mlir::MLIRContext *ctx);

  /// Provide a utility accessor to the dialect namespace. This is used by
  /// several utilities.
  static llvm::StringRef getDialectNamespace() { return "toy"; }
};

これをDialectのためのグローバルレジスタに登録する。

  mlir::registerDialect<ToyDialect>();

Toy Operationの定義

新しいtoy.constantオペレーションを定義する。

 %4 = "toy.constant"() {value = dense<1.0> : tensor<2x3xf64>} : () -> tensor<2x3xf64>

新しいクラスとしてConstantOpを作成する。ConstantOpにいくつかのメソッドを定義してやらなければならない。

class ConstantOp : public mlir::Op<ConstantOp,
                     /// ConstantOpは引数を何も受け取らない。
                     mlir::OpTrait::ZeroOperands,
                     /// ConstantOpは1つの戻り値を返す。
                     mlir::OpTrait::OneResult> {

 public:
  /// コンストラクタはベースのクラスから継承する。
  using Op::Op;

  /// この操作に対するユニークな操作名を定義する。MLIRはこの名前を登録してシステム中で
  /// ユニークな名前として使用する。
  static llvm::StringRef getOperationName() { return "toy.constant"; }

  /// この属性から定数を読み取って値を返す。
  mlir::DenseElementsAttr getValue();

  /// 定義したトレイトを超えた追加の検証を提供することができる。ここでは特定の定数に対する
  /// 普遍量が守られていることを確認する。例えば、結果の値はTensorTypeでなければならない。
  LogicalResult verify();
                         
  /// `value`の属性を持つ戻り値を生成するための定数を構築する。
  static void build(mlir::OpBuilder &builder, mlir::OperationState &state,
                    mlir::Type result, mlir::DenseElementsAttr value);
  /// Build a constant and reuse the type from the given 'value'.
  /// valueの属性を再利用して定数を生成する。
  static void build(mlir::OpBuilder &builder, mlir::OperationState &state,
                    mlir::DenseElementsAttr value);
  /// Build a constant by broadcasting the given 'value'.
  /// `value`をブロードキャストすることで定数を生成する。
  static void build(mlir::OpBuilder &builder, mlir::OperationState &state,
                    double value);
};

オペレーションとオペレーション:MLIRオペレーションの使用

新しいオペレーションを定義するためには2つの主要なクラスについて知る必要がある。

  • Operation :すべてのオペレーションを包括するために使用する。
  • Op:特定の型に対する操作を実装する。

Operation Definition Specification(ODS)フレームワークを使用してOperationを定義する

mlir::Opを使用せずにDSLを使ってTableGen経由でオペレーションを定義することもできる。おそらくこちらの方が推奨されている。

// 'toy' の方言をODSフレームワークに提供し、操作を定義できるようにする。
def Toy_Dialect : Dialect {
  // The namespace of our dialect, this corresponds 1-1 with the string we
  // provided in `ToyDialect::getDialectNamespace`.
  let name = "toy";

  // The C++ namespace that the dialect class definition resides in.
  let cppNamespace = "toy";
}

Opクラスからの継承によってオペレーションを定義する。

// Base class for toy dialect operations. This operation inherits from the base
// `Op` class in OpBase.td, and provides:
//   * The parent dialect of the operation.
//   * The mnemonic for the operation, or the name without the dialect prefix.
//   * A list of traits for the operation.
class Toy_Op<string mnemonic, list<OpTrait> traits = []> :
    Op<Toy_Dialect, mnemonic, traits>;

C++のコードを参照するためには以下のように入力すればよいらしい。

${build_root}/bin/mlir-tblgen -gen-op-defs ${mlir_src_root}/examples/toy/Ch2/include/toy/Ops.td -I ${mlir_src_root}/include/

Ops.tdConstantOpの定義は以下のようになっていた。

// We define a toy operation by inheriting from our base 'Toy_Op' class above.
// Here we provide the mnemonic and a list of traits for the operation. The
// constant operation is marked as 'NoSideEffect' as it is a pure operation
// and may be removed if dead.
def ConstantOp : Toy_Op<"constant", [NoSideEffect]> {
  // Provide a summary and description for this operation. This can be used to
  // auto-generate documentation of the operations within our dialect.
  let summary = "constant";
  let description = [{
    Constant operation turns a literal into an SSA value. The data is attached
    to the operation as an attribute. For example:

    ```mlir
      %0 = "toy.constant"()
         { value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64> }
        : () -> tensor<2x3xf64>
    ```
  }];

  // The constant operation takes an attribute as the only input.
  let arguments = (ins F64ElementsAttr:$value);

  // The constant operation returns a single value of TensorType.
  let results = (outs F64Tensor);

  // Add custom build methods for the constant operation. These method populates
  // the `state` that MLIR uses to create operations, i.e. these are used when
  // using `builder.create<ConstantOp>(...)`.
  let builders = [
    // Build a constant with a given constant tensor value.
    OpBuilder<"Builder *builder, OperationState &state, "
              "DenseElementsAttr value", [{
      build(builder, state, value.getType(), value);
    }]>,

    // Build a constant with a given constant floating-point value.
    OpBuilder<"Builder *builder, OperationState &state, double value">
  ];

  // Invoke a static verify method to verify this constant operation.
  let verifier = [{ return ::verify(*this); }];
}

-gen-op-defsを用いて生成したC++の結果はこちら。

//===----------------------------------------------------------------------===//
// toy::ConstantOp definitions
//===----------------------------------------------------------------------===//

ConstantOpOperandAdaptor::ConstantOpOperandAdaptor(ArrayRef<Value> values) {
  tblgen_operands = values;
}

ArrayRef<Value> ConstantOpOperandAdaptor::getODSOperands(unsigned index) {
  return {std::next(tblgen_operands.begin(), index), std::next(tblgen_operands.begin(), index + 1)};
}

StringRef ConstantOp::getOperationName() {
  return "toy.constant";
}

Operation::operand_range ConstantOp::getODSOperands(unsigned index) {
  return {std::next(getOperation()->operand_begin(), index), std::next(getOperation()->operand_begin(), index + 1)};
}

Operation::result_range ConstantOp::getODSResults(unsigned index) {
  return {std::next(getOperation()->result_begin(), index), std::next(getOperation()->result_begin(), index + 1)};
}

DenseElementsAttr ConstantOp::valueAttr() {
  return this->getAttr("value").cast<DenseElementsAttr>();
}

DenseElementsAttr ConstantOp::value() {
  auto attr = valueAttr();
  return attr;
}

void ConstantOp::build(Builder *builder, OperationState &state, DenseElementsAttr value) {
      build(builder, state, value.getType(), value);

}
...

-get-op-declsを用いて生成したC++の結果はこちら。

//===----------------------------------------------------------------------===//
// toy::ConstantOp declarations
//===----------------------------------------------------------------------===//

class ConstantOpOperandAdaptor {
public:
  ConstantOpOperandAdaptor(ArrayRef<Value> values);
  ArrayRef<Value> getODSOperands(unsigned index);

private:
  ArrayRef<Value> tblgen_operands;
};
class ConstantOp : public Op<ConstantOp, OpTrait::OneResult, OpTrait::HasNoSideEffect, OpTrait::ZeroOperands> {
public:
  using Op::Op;
  using OperandAdaptor = ConstantOpOperandAdaptor;
  static StringRef getOperationName();
  Operation::operand_range getODSOperands(unsigned index);
  Operation::result_range getODSResults(unsigned index);
  DenseElementsAttr valueAttr();
  DenseElementsAttr value();
  static void build(Builder *builder, OperationState &state, DenseElementsAttr value);
  static void build(Builder *builder, OperationState &state, double value);
  static void build(Builder *tblgen_builder, OperationState &tblgen_state, Type resultType0, DenseElementsAttr value);
  static void build(Builder *tblgen_builder, OperationState &tblgen_state, ArrayRef<Type> resultTypes, DenseElementsAttr value);
  static void build(Builder *, OperationState &tblgen_state, ArrayRef<Type> resultTypes, ValueRange operands, ArrayRef<NamedAttribute> attributes);
  LogicalResult verify();
};