Press "Enter" to skip to content

【从零开始学深度学习编译器】十五,MLIR Toy Tutorials学习笔记之Lowering到LLVM IR

本站内容均来自兴趣收集,如不慎侵害的您的相关权益,请留言告知,我们将尽快删除.谢谢.

0x0. 前言

 

在上一节中,我们将Toy Dialect的部分Operation Lowering到Affine Dialect,MemRef Dialect和Standard Dialect,而toy.print
操作保持不变,所以又被叫作部分Lowering。通过这个Lowering可以将Toy Dialect的Operation更底层的实现逻辑表达出来,以寻求更多的优化机会,得到更好的MLIR表达式。这一节,我们将在上一节得到的混合型MLIR表达式完全Lowering到LLVM Dialect上,然后生成LLVM IR,并且我们可以使用MLIR的JIT编译引擎来运行最终的MLIR表达式并输出计算结果。

 

0x1. IR下降到LLVM Dialect

 

这一小节我们将来介绍如何将上一节结束的MLIR表达式完全Lowering为LLVM Dialect,我们还是回顾一下上一节最终的MLIR表达式:

 

func @main() {
  %cst = arith.constant 1.000000e+00 : f64
  %cst_0 = arith.constant 2.000000e+00 : f64
  %cst_1 = arith.constant 3.000000e+00 : f64
  %cst_2 = arith.constant 4.000000e+00 : f64
  %cst_3 = arith.constant 5.000000e+00 : f64
  %cst_4 = arith.constant 6.000000e+00 : f64
// Allocating buffers for the inputs and outputs.
  %0 = memref.alloc() : memref<3x2xf64>
  %1 = memref.alloc() : memref<2x3xf64>
// Initialize the input buffer with the constant values.
  affine.store %cst, %1[0, 0] : memref<2x3xf64>
  affine.store %cst_0, %1[0, 1] : memref<2x3xf64>
  affine.store %cst_1, %1[0, 2] : memref<2x3xf64>
  affine.store %cst_2, %1[1, 0] : memref<2x3xf64>
  affine.store %cst_3, %1[1, 1] : memref<2x3xf64>
  affine.store %cst_4, %1[1, 2] : memref<2x3xf64>
  affine.for %arg0 = 0 to 3 {
    affine.for %arg1 = 0 to 2 {
// Load the transpose value from the input buffer.
      %2 = affine.load %1[%arg1, %arg0] : memref<2x3xf64>
// Multiply and store into the output buffer.
      %3 = arith.mulf %2, %2 : f64
      affine.store %3, %0[%arg0, %arg1] : memref<3x2xf64>
    }
  }
// Print the value held by the buffer.
  toy.print %0 : memref<3x2xf64>
  memref.dealloc %1 : memref<2x3xf64>
  memref.dealloc %0 : memref<3x2xf64>
return
}

 

我们要将这个三种Dialect混合的MLIR表达式完全Lowering为LLVM Dialect,注意LLVM Dialect是MLIR的一种特殊的Dialect层次的中间表示,「它并不是LLVM IR」
。Lowering为LLVM Dialect的整体过程可以分为如下几步:

 

1. Lowering toy.print Operation

 

之前部分Lowering的时候并没有对toy.print
操作进行Lowering,所以这里优先将toy.print
进行Lowering。我们把toy.print
Lowering到一个非仿射循环嵌套,它为每个元素调用printf
。Dialect转换框架支持传递Lowering,不需要直接Lowering为LLVM Dialect。通过应用传递Lowering可以应用多种模式来使得操作合法化(合法化的意思在这里指的就是完全Lowering到LLVM Dialect)。传递Lowering在这里体现为将toy.print
先Lowering到循环嵌套Dialect里面,而不是直接Lowering为LLVM Dialect。

 

在Lowering过程中,printf
的声明在mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
中,代码如下:

 

 /// Return a symbol reference to the printf function, inserting it into the
/// module if necessary.
static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter,
                                             ModuleOp module) {
auto *context = module.getContext();
if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf"))
return SymbolRefAttr::get(context, "printf");
// Create a function declaration for printf, the signature is:
//   * `i32 (i8*, ...)`
auto llvmI32Ty = IntegerType::get(context, 32);
auto llvmI8PtrTy = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmI8PtrTy,
/*isVarArg=*/true);
// Insert the printf function into the body of the parent module.
PatternRewriter::InsertionGuard insertGuard(rewriter);
    rewriter.setInsertionPointToStart(module.getBody());
    rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), "printf", llvmFnType);
return SymbolRefAttr::get(context, "printf");
  }

 

这部分代码返回了printf
函数的符号引用,必要时将其插入Module。在函数中,为printf
创建了函数声明,然后将printf
函数插入到父Module的主体中。

 

2. 确定Lowering过程需要的所有组件

 

第一个需要确定的是转换目标(ConversionTarget),对于这个Lowering我们除了顶层的Module将所有的内容都Lowering为LLVM Dialect。这里代码表达的信息和官方文档有一些出入,以最新的代码为准。

 

// The first thing to define is the conversion target. This will define the
// final target for this lowering. For this lowering, we are only targeting
// the LLVM dialect.
LLVMConversionTarget target(getContext());
target.addLegalOp<ModuleOp>();

 

然后需要确定类型转换器(Type Converter),我们现存的MLIR表达式还有MemRef
类型,我们需要将其转换为LLVM的类型。为了执行这个转化,我们使用TypeConverter
作为Lowering的一部分。这个转换器指定一种类型如何映射到另外一种类型。由于现存的操作中已经不存在任何Toy Dialect操作,因此使用MLIR默认的转换器就可以满足需求。定义如下:

 

// During this lowering, we will also be lowering the MemRef types, that are
// currently being operated on, to a representation in LLVM. To perform this
// conversion we use a TypeConverter as part of the lowering. This converter
// details how one type maps to another. This is necessary now that we will be
// doing more complicated lowerings, involving loop region arguments.
LLVMTypeConverter typeConverter(&getContext());

 

再然后还需要确定转换模式(Conversion Patterns)。这部分代码为:

 

// Now that the conversion target has been defined, we need to provide the
// patterns used for lowering. At this point of the compilation process, we
// have a combination of `toy`, `affine`, and `std` operations. Luckily, there
// are already exists a set of patterns to transform `affine` and `std`
// dialects. These patterns lowering in multiple stages, relying on transitive
// lowerings. Transitive lowering, or A->B->C lowering, is when multiple
// patterns must be applied to fully transform an illegal operation into a
// set of legal ones.
RewritePatternSet patterns(&getContext());
  populateAffineToStdConversionPatterns(patterns);
  populateLoopToStdConversionPatterns(patterns);
  populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
  populateStdToLLVMConversionPatterns(typeConverter, patterns);
// The only remaining operation to lower from the `toy` dialect, is the
// PrintOp.
  patterns.add<PrintOpLowering>(&getContext());

 

上面这段代码展示了为Affine Dialect,Standard Dialect以及遗留的toy.print
定义匹配重写规则。首先将Affine Dialect下降到Standard Dialect,即populateAffineToStdConversionPatterns
。然后将Loop(针对的是toy.print
操作,它已经Lowering到了循环嵌套Dialect)下降到Standard Dialect,即populateLoopToStdConversionPatterns
。最后,将Standard Dialect转换到LLVM Dialect,即populateMemRefToLLVMConversionPatterns
。以及不要忘了把toy.print
的Lowering模式PrintOpLowering
加到patterns
里面。

 

3. 完全Lowering

 

定义了Lowering过程需要的所有组件之后,就可以执行完全Lowering了。使用applyFullConversion(module, target, std::move(patterns)))
函数可以保证转换的结果只存在合法的操作,上一篇部分Lowering的笔记调用的是mlir::applyPartialConversion(function, target, patterns)
可以对比着看一下。

 

// We want to completely lower to LLVM, so we use a `FullConversion`. This
// ensures that only legal operations will remain after the conversion.
auto module = getOperation();
if (failed(applyFullConversion(module, target, std::move(patterns))))
    signalPassFailure();

 

4. 将上面定义好的完全Lowering的Pass加到Pipline中

 

这段代码在mlir/examples/toy/Ch6/toyc.cpp
中:

 

if (isLoweringToLLVM) {
// Finish lowering the toy IR to the LLVM dialect.
  pm.addPass(mlir::toy::createLowerToLLVMPass());
 }

 

这段代码在优化Pipline中添加了mlir::toy::createLowerToLLVMPass()
这个完全Lowering的Pass,可以把MLIR 表达式下降为LLVM Dialect表达式。我们运行一下示例程序看下结果:

 

执行下面的命令:

 

cd llvm-project/build/bin
./toyc-ch6 ../../mlir/test/Examples/Toy/Ch6/llvm-lowering.mlir -emit=mlir-llvm

 

即获得了完全Lowering之后的MLIR表达式,结果比较长,这里只展示一部分。可以看到目前MLIR表达式已经完全在LLVM Dialect空间下了。

 

llvm.func @free(!llvm<"i8*">)
llvm.func @printf(!llvm<"i8*">, ...) -> i32
llvm.func @malloc(i64) -> !llvm<"i8*">
llvm.func @main() {
  %0 = llvm.mlir.constant(1.000000e+00 : f64) : f64
  %1 = llvm.mlir.constant(2.000000e+00 : f64) : f64
  ...
^bb16:
  %221 = llvm.extractvalue %25[0 : index] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
  %222 = llvm.mlir.constant(0 : index) : i64
  %223 = llvm.mlir.constant(2 : index) : i64
  %224 = llvm.mul %214, %223 : i64
  %225 = llvm.add %222, %224 : i64
  %226 = llvm.mlir.constant(1 : index) : i64
  %227 = llvm.mul %219, %226 : i64
  %228 = llvm.add %225, %227 : i64
  %229 = llvm.getelementptr %221[%228] : (!llvm."double*">, i64) -> !llvm<"f64*">
  %230 = llvm.load %229 : !llvm<"double*">
  %231 = llvm.call @printf(%207, %230) : (!llvm<"i8*">, f64) -> i32
  %232 = llvm.add %219, %218 : i64
  llvm.br ^bb15(%232 : i64)
  ...
^bb18:
  %235 = llvm.extractvalue %65[0 : index] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
  %236 = llvm.bitcast %235 : !llvm<"double*"> to !llvm<"i8*">
  llvm.call @free(%236) : (!llvm<"i8*">) -> ()
  %237 = llvm.extractvalue %45[0 : index] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
  %238 = llvm.bitcast %237 : !llvm<"double*"> to !llvm<"i8*">
  llvm.call @free(%238) : (!llvm<"i8*">) -> ()
  %239 = llvm.extractvalue %25[0 : index] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
  %240 = llvm.bitcast %239 : !llvm<"double*"> to !llvm<"i8*">
  llvm.call @free(%240) : (!llvm<"i8*">) -> ()
  llvm.return
}

 

0x2. 代码生成以及Jit执行

 

我们可以使用JIT编译引擎来运行上面得到的LLVM Dialect IR,获得推理结果。这里我们使用了mlir::ExecutionEngine
基础架构来运行LLVM Dialect IR。程序位于:mlir/examples/toy/Ch6/toyc.cpp

 

int runJit(mlir::ModuleOp module) {
// Initialize LLVM targets.
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();
// Register the translation from MLIR to LLVM IR, which must happen before we
// can JIT-compile.
  mlir::registerLLVMDialectTranslation(*module->getContext());
// An optimization pipeline to use within the execution engine.
auto optPipeline = mlir::makeOptimizingTransformer(
/*optLevel=*/enableOpt ? 3 : 0, /*sizeLevel=*/0,
/*targetMachine=*/nullptr);
// Create an MLIR execution engine. The execution engine eagerly JIT-compiles
// the module.
auto maybeEngine = mlir::ExecutionEngine::create(
module, /*llvmModuleBuilder=*/nullptr, optPipeline);
  assert(maybeEngine && "failed to construct an execution engine");
auto &engine = maybeEngine.get();
// Invoke the JIT-compiled function.
auto invocationResult = engine->invokePacked("main");
if (invocationResult) {
    llvm::errs() << "JIT invocation failed
";
return -1;
  }
return 0;
}

 

 

这里尤其需要注意这行:mlir::registerLLVMDialectTranslation(*module->getContext());。从代码的注释来看这个是将LLVM Dialect表达式翻译成LLVM IR,在JIT编译的时候起到缓存作用,也就是说下次执行的时候不会重复执行上面的各种MLIR表达式变换。

 

 

这里创建一个MLIR执行引擎mlir::ExecutionEngine
来运行表达式中的main
函数。可以使用下面的命令来输出最终的计算结果:

 

cd llvm-project/build/bin
./toyc-ch6 ../../mlir/test/Examples/Toy/Ch6/codegen.toy -emit=jit -opt

 

结果为:

 

1.000000 16.000000 
4.000000 25.000000 
9.000000 36.000000

 

到这里,我们就将原始的MLIR表达式经过一系列Pass进行优化,以及部分Lowering到三种Dialect混合的表达式,和完全Lowering为LLVM Dialect表达式,最后翻译到LLVM IR使用MLIR的Jit执行引擎进行执行,获得了最终结果。

 

另外,mlir/examples/toy/Ch6/toyc.cpp
中还提供了一个dumpLLVMIR
函数,可以将MLIR表达式翻译成LLVM IR表达式。然后再经过LLVM IR的优化处理。使用如下命令可以打印出生成的LLVM IR:

 

$ cd llvm-project/build/bin
$ ./toyc-ch6 ../../mlir/test/Examples/Toy/Ch6/codegen.toy -emit=llvm -opt

 

0x3. 总结

 

这篇文章介绍了如何将部分Lowering之后的MLIR表达式进一步完全Lowering到LLVM Dialect上,然后通过JIT编译引擎来执行代码并获得推理结果,另外还可以输出LLVM Dialect生成的LLVM IR。

Be First to Comment

发表评论

您的电子邮箱地址不会被公开。 必填项已用*标注