MLIRについて基礎を学んだところで、実際に動かしてみたい。以下のページを読みながら少しチュートリアルを触ってみよう。
MLIRとのインタフェース
先ほどのtranspose()
がどのようにMLIRに変換されたのかを以下に示す。
%t_tensor = "toy.transpose"(%tensor) {inplace = true} : (tensor<2x3xf64>) -> tensor<3x2xf64> loc("example/file/path":12:1)
t_tnsor
:transpose
の結果に付けられる名前。SSA値として表現される。"toy.transpose"
:操作の名前(%tensor)
:引数のリスト。{inplace = true}
:操作に付けられる属性。ここではinplace
というBool型のTrue値を持つ値を定義している。(tensor<2x3xf64>) -> tensor<3x2xf64>
:関数形式での型の変換形式を示している。引数の型、およびその戻り値の型を示している。loc("example/file/path":12:1)
:この操作が発生したソースコードの場所を示している。
MLIRとのインタフェースのためのToy方言
Toy言語に対する方言を定義するために、C++でToyDialect
を定義する。
/// This is the definition of the Toy dialect. A dialect inherits from /// mlir::Dialect and registers custom attributes, operations, and types (in its /// constructor). It can also override virtual methods to change some general /// behavior, which will be demonstrated in later chapters of the tutorial. class ToyDialect : public mlir::Dialect { public: explicit ToyDialect(mlir::MLIRContext *ctx); /// Provide a utility accessor to the dialect namespace. This is used by /// several utilities. static llvm::StringRef getDialectNamespace() { return "toy"; } };
これをDialectのためのグローバルレジスタに登録する。
mlir::registerDialect<ToyDialect>();
Toy Operationの定義
新しいtoy.constant
オペレーションを定義する。
%4 = "toy.constant"() {value = dense<1.0> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
新しいクラスとしてConstantOp
を作成する。ConstantOp
にいくつかのメソッドを定義してやらなければならない。
class ConstantOp : public mlir::Op<ConstantOp, /// ConstantOpは引数を何も受け取らない。 mlir::OpTrait::ZeroOperands, /// ConstantOpは1つの戻り値を返す。 mlir::OpTrait::OneResult> { public: /// コンストラクタはベースのクラスから継承する。 using Op::Op; /// この操作に対するユニークな操作名を定義する。MLIRはこの名前を登録してシステム中で /// ユニークな名前として使用する。 static llvm::StringRef getOperationName() { return "toy.constant"; } /// この属性から定数を読み取って値を返す。 mlir::DenseElementsAttr getValue(); /// 定義したトレイトを超えた追加の検証を提供することができる。ここでは特定の定数に対する /// 普遍量が守られていることを確認する。例えば、結果の値はTensorTypeでなければならない。 LogicalResult verify(); /// `value`の属性を持つ戻り値を生成するための定数を構築する。 static void build(mlir::OpBuilder &builder, mlir::OperationState &state, mlir::Type result, mlir::DenseElementsAttr value); /// Build a constant and reuse the type from the given 'value'. /// valueの属性を再利用して定数を生成する。 static void build(mlir::OpBuilder &builder, mlir::OperationState &state, mlir::DenseElementsAttr value); /// Build a constant by broadcasting the given 'value'. /// `value`をブロードキャストすることで定数を生成する。 static void build(mlir::OpBuilder &builder, mlir::OperationState &state, double value); };
オペレーションとオペレーション:MLIRオペレーションの使用
新しいオペレーションを定義するためには2つの主要なクラスについて知る必要がある。
Operation
:すべてのオペレーションを包括するために使用する。Op
:特定の型に対する操作を実装する。
Operation Definition Specification(ODS)フレームワークを使用してOperationを定義する
mlir::Op
を使用せずにDSLを使ってTableGen経由でオペレーションを定義することもできる。おそらくこちらの方が推奨されている。
// 'toy' の方言をODSフレームワークに提供し、操作を定義できるようにする。 def Toy_Dialect : Dialect { // The namespace of our dialect, this corresponds 1-1 with the string we // provided in `ToyDialect::getDialectNamespace`. let name = "toy"; // The C++ namespace that the dialect class definition resides in. let cppNamespace = "toy"; }
Op
クラスからの継承によってオペレーションを定義する。
// Base class for toy dialect operations. This operation inherits from the base // `Op` class in OpBase.td, and provides: // * The parent dialect of the operation. // * The mnemonic for the operation, or the name without the dialect prefix. // * A list of traits for the operation. class Toy_Op<string mnemonic, list<OpTrait> traits = []> : Op<Toy_Dialect, mnemonic, traits>;
C++のコードを参照するためには以下のように入力すればよいらしい。
${build_root}/bin/mlir-tblgen -gen-op-defs ${mlir_src_root}/examples/toy/Ch2/include/toy/Ops.td -I ${mlir_src_root}/include/
Ops.td
のConstantOp
の定義は以下のようになっていた。
// We define a toy operation by inheriting from our base 'Toy_Op' class above. // Here we provide the mnemonic and a list of traits for the operation. The // constant operation is marked as 'NoSideEffect' as it is a pure operation // and may be removed if dead. def ConstantOp : Toy_Op<"constant", [NoSideEffect]> { // Provide a summary and description for this operation. This can be used to // auto-generate documentation of the operations within our dialect. let summary = "constant"; let description = [{ Constant operation turns a literal into an SSA value. The data is attached to the operation as an attribute. For example: ```mlir %0 = "toy.constant"() { value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64> } : () -> tensor<2x3xf64> ``` }]; // The constant operation takes an attribute as the only input. let arguments = (ins F64ElementsAttr:$value); // The constant operation returns a single value of TensorType. let results = (outs F64Tensor); // Add custom build methods for the constant operation. These method populates // the `state` that MLIR uses to create operations, i.e. these are used when // using `builder.create<ConstantOp>(...)`. let builders = [ // Build a constant with a given constant tensor value. OpBuilder<"Builder *builder, OperationState &state, " "DenseElementsAttr value", [{ build(builder, state, value.getType(), value); }]>, // Build a constant with a given constant floating-point value. OpBuilder<"Builder *builder, OperationState &state, double value"> ]; // Invoke a static verify method to verify this constant operation. let verifier = [{ return ::verify(*this); }]; }
-gen-op-defs
を用いて生成したC++の結果はこちら。
//===----------------------------------------------------------------------===// // toy::ConstantOp definitions //===----------------------------------------------------------------------===// ConstantOpOperandAdaptor::ConstantOpOperandAdaptor(ArrayRef<Value> values) { tblgen_operands = values; } ArrayRef<Value> ConstantOpOperandAdaptor::getODSOperands(unsigned index) { return {std::next(tblgen_operands.begin(), index), std::next(tblgen_operands.begin(), index + 1)}; } StringRef ConstantOp::getOperationName() { return "toy.constant"; } Operation::operand_range ConstantOp::getODSOperands(unsigned index) { return {std::next(getOperation()->operand_begin(), index), std::next(getOperation()->operand_begin(), index + 1)}; } Operation::result_range ConstantOp::getODSResults(unsigned index) { return {std::next(getOperation()->result_begin(), index), std::next(getOperation()->result_begin(), index + 1)}; } DenseElementsAttr ConstantOp::valueAttr() { return this->getAttr("value").cast<DenseElementsAttr>(); } DenseElementsAttr ConstantOp::value() { auto attr = valueAttr(); return attr; } void ConstantOp::build(Builder *builder, OperationState &state, DenseElementsAttr value) { build(builder, state, value.getType(), value); } ...
-get-op-decls
を用いて生成したC++の結果はこちら。
//===----------------------------------------------------------------------===// // toy::ConstantOp declarations //===----------------------------------------------------------------------===// class ConstantOpOperandAdaptor { public: ConstantOpOperandAdaptor(ArrayRef<Value> values); ArrayRef<Value> getODSOperands(unsigned index); private: ArrayRef<Value> tblgen_operands; }; class ConstantOp : public Op<ConstantOp, OpTrait::OneResult, OpTrait::HasNoSideEffect, OpTrait::ZeroOperands> { public: using Op::Op; using OperandAdaptor = ConstantOpOperandAdaptor; static StringRef getOperationName(); Operation::operand_range getODSOperands(unsigned index); Operation::result_range getODSResults(unsigned index); DenseElementsAttr valueAttr(); DenseElementsAttr value(); static void build(Builder *builder, OperationState &state, DenseElementsAttr value); static void build(Builder *builder, OperationState &state, double value); static void build(Builder *tblgen_builder, OperationState &tblgen_state, Type resultType0, DenseElementsAttr value); static void build(Builder *tblgen_builder, OperationState &tblgen_state, ArrayRef<Type> resultTypes, DenseElementsAttr value); static void build(Builder *, OperationState &tblgen_state, ArrayRef<Type> resultTypes, ValueRange operands, ArrayRef<NamedAttribute> attributes); LogicalResult verify(); };