TVM Framework Explained

  • TVM Framework Explained
    • Chapter 1: Introduction to the TVM Framework
      • 1.1 Models From Frameworks
      • 1.2 Unified IR
        • 1.2.1 IRModule
        • 1.2.2 Classes related to the IRModule constructors
        • 1.2.3 Module
        • 1.2.4 IRModule -> Module
      • 1.3 Multiple Backend And Minimal Runtime
        • 1.3.1 Device definition
        • 1.3.2 Runtime
    • Chapter 2: TVM Demo Examples and Framework Flow Analysis
      • 2.1. Example 1: Building and compiling a compute graph with the TVM te API
      • 2.1.1 Defining Tensors A, B, and C via the TVM te API, where C = A + B
      • 2.1.1.1 The relationship between Tensor and Operation in TVM
      • 2.1.1.2 Example code
      • 2.1.1.3 TVM Tensor and Operation details
      • 2.1.2 Creating a schedule
      • 2.1.2.1 Overview
      • 2.1.2.2 Example code
      • 2.1.2.3 TVM Schedule/ScheduleNode and Stage/StageNode details
      • 2.1.3 tvm.lower
      • 2.1.3.1. Creating and initializing the PassContext and adding passes
      • 2.1.3.2. form_irmodule: converting the Schedule into an IRModule
      • 2.1.3.2.1. sch.normalize()
      • 2.1.3.2.2. schedule.InferBound(sch)
      • 2.1.3.2.3. schedule.ScheduleOps(sch, bounds)
      • 2.1.3.2.4. schedule.VerifyCompactBuffer(stmt)
      • 2.1.3.2.5. get_binds(args, compact, binds)
      • 2.1.3.2.6. schedule.SchedulePostProcRewriteForTensorCore(stmt, sch, binds)
      • 2.1.3.2.7. schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds)
      • 2.1.3.3. Running pass optimizations on the generated IRModule
      • 2.1.4 tvm.build
      • 2.1.4.1. Function parameters
      • 2.1.4.2. Setting and updating target and target_host
      • 2.1.4.3. _build_for_device
      • 2.1.4.4. codegen build_module
      • 2.1.4.4.1. Initializing the LLVMEnv
      • 2.1.4.4.2. GetLLVMTargetMachine(target)
      • 2.1.4.4.3. CodeGenLLVM::Create(tm_.get())
      • 2.1.4.4.4. std::vector funcs;
      • 2.1.4.4.5. CodeGenLLVM init: cg->Init(…)
      • 2.1.4.4.6. AddFunction
      • 2.1.4.4.7. AddMainFunction
      • 2.1.4.4.8. LinkParameters
      • 2.1.4.4.9. CodeGenLLVM Finish: module_ = cg->Finish();
      • 2.1.4.4.10. llvm::Module> module_->addModuleFlag
      • 2.1.4.4.11. codegen build_module summary
      • 2.1.4.5. Generate a unified host module
      • 2.1.4.5.1. CreateCSourceCrtMetadataModule
      • 2.1.4.5.2. CreateLLVMCrtMetadataModule
      • 2.1.4.6. tvm.build output
      • 2.2 Example 2: Parsing and compiling a model with the TVM Python API
      • 2.2.1 Loading the pretrained PyTorch model resnet18 via the Relay API
      • 2.2.2 relay.build
      • 2.2.2.1 BuildModule
      • 2.2.2.2 bld_mod.build
      • 2.2.2.2.1. Optimize
      • 2.2.2.2.2. std::unique_ptr(new GraphCodegen()
      • 2.2.2.2.3. graph_codegen_ Init
      • 2.2.2.2.4. graph_codegen Codegen
      • 2.2.2.2.5. graph_codegen: GetJSON / GetParams / GetIRModule()
      • 2.2.2.2.6. llvm module
      • 2.2.2.2.7. c source module
      • 2.2.2.2.8. tvm::build(lowered_funcs, target_host_)
      • 2.2.2.2.9. graph_codegen_->GetExternalModules()
      • 2.2.2.2.10. tvm::codegen::CreateMetadataModule(…)
      • 2.2.2.3 GraphExecutorFactoryModule
    • Glossary


Chapter 1: Introduction to the TVM Framework

[Figure: TVM architecture diagram]

The TVM architecture diagram above contains three main parts:

Models From Frameworks

Unified IR

Multiple Backend And Minimal Runtime

1.1 Models From Frameworks

The Models From Frameworks module supports the following front ends:

No   Front-end
1    caffe
2    caffe2
3    coreml
4    darknet
5    keras
6    mxnet
7    onnx
8    pytorch
9    tensorflow
10   tflite

1.2 Unified IR

The TVM framework provides a unified IR, which mainly consists of two module types:


tvm.ir.module.IRModule: IRModule that holds functions and type definitions

tvm.runtime.module.Module : Runtime Module

1.2.1 IRModule

IRModule is defined in include/tvm/ir/module.h; the corresponding Python interface is defined in python/tvm/ir/module.py.

IRModule has the attributes functions, type_definitions, import_set, and source_map, shown below:


   Map<GlobalVar, BaseFunc> functions: 
            Functions in the module. A map from ids to all global functions.


   Map<GlobalTypeVar, TypeData> type_definitions: 
            Type definitions in the module. A map from global type vars to ADT type data.


   std::unordered_set<String> import_set: 
            Set of imported files in the module.

   parser::SourceMap source_map: 
            The source map for the module.

There are two ways to create an IRModule; the corresponding functions are shown below:

IRModule(Map<GlobalVar, BaseFunc> functions,
                Map<GlobalTypeVar, TypeData> type_definitions = {},
                std::unordered_set<String> import_set = {}, parser::SourceMap map = {});

IRModule FromExpr(const RelayExpr& expr,
                const Map<GlobalVar, BaseFunc>& global_funcs = {},
                const Map<GlobalTypeVar, TypeData>& type_definitions = {});

Summary of IRModule creation

  Ways to produce an IRModule: 
    A model from a front-end framework is parsed by that framework's parser, producing an IRModule.

    A compute graph is built with TVM's APIs (tvm.te, tvm.ir, tvm.topi), producing an IRModule.

  The main attributes of an IRModule are its functions and types; functions is a map whose values are BaseFunc (analyzed below).
  
  An IRModule can also be created via the FromExpr function, whose arguments include a RelayExpr (analyzed below).

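As a quick illustration of the FromExpr path, here is a minimal Python sketch (assuming a TVM build with the Relay front end enabled; the variable names are made up for the example):

import tvm
from tvm import relay

# Build a trivial Relay expression: f(x) = x + x
x = relay.var("x", shape=(2,), dtype="float32")
f = relay.Function([x], x + x)

# FromExpr is exposed in Python as IRModule.from_expr
mod = tvm.IRModule.from_expr(f)
print(mod)  # one GlobalVar ("main") mapped to the function f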
1.2.2 Classes related to the IRModule constructors

BaseFunc

As mentioned above, the main attribute of an IRModule is its functions, a map of BaseFunc. Let's look at the inheritance relationships of BaseFunc and RelayExpr:
[Figure: BaseExpr / RelayExpr / BaseFunc inheritance diagram]

As the diagram shows, BaseExpr is one of the two fundamental structures of TVM IR: it is the base class of syntax expressions (a later section introduces the other fundamental structure, Stmt, the base class of syntax tree nodes). BaseFunc inherits from RelayExpr and has two subclasses, Function and PrimFunc, which serve the following roles:

Function: Relay Function container

PrimFunc: Primitive functions that contains TIR statements

Function contains Relay IR, while PrimFunc contains TIR.

RelayExpr and PrimExpr both inherit from BaseExpr. They are documented as follows:

RelayExpr: Base node of all non-primitive expressions. 

           RelayExpr supports tensor types, functions and ADT as first class citizens.
           The life-cycle of the corresponding objects are implicitly managed by the language.


PrimExpr: Base node of all primitive expressions. 

          A primitive expression deals with low-level POD data types and 
          handles without doing life-cycle management for objects. 

          PrimExpr is used in the low-level code optimizations and integer analysis.

For example, LetNode has an implementation for each of the two expression families:


/*! \brief A binding of a sub-network. */
class LetNode : public RelayExprNode {
 public:
  /*! \brief The variable we bind to */
  Var var;
  /*! \brief The value we bind var to */
  RelayExpr value;
  /*! \brief The body of the let binding */
  RelayExpr body;


class LetNode : public PrimExprNode {
 public:
  /*! \brief The variable. */
  Var var;
  /*! \brief The value to be binded. */
  PrimExpr value;
  /*! \brief The result expression. */
  PrimExpr body;
1.2.3 Module

Module is defined in include/tvm/runtime/module.h; the corresponding Python interface is defined in python/tvm/runtime/module.py.

The Module class is defined as follows:

/*!
 * \brief Module container of TVM.
 */
class Module : public ObjectRef {
 public:
  Module() {}
  // constructor from container.
  explicit Module(ObjectPtr<Object> n) : ObjectRef(n) {}
  ......
  // refer to the corresponding container.
  using ContainerType = ModuleNode;
  friend class ModuleNode;
};

/*!
 * \brief Base container of module.
 *
 * Please subclass ModuleNode to create a specific runtime module.
 */
class TVM_DLL ModuleNode : public Object {
 public:
  /*! \brief virtual destructor */
  virtual ~ModuleNode() {}
 ......
 protected:
  friend class Module;
  friend class ModuleInternal;
  /*! \brief The modules this module depend on */
  std::vector<Module> imports_;

 private:
  /*! \brief Cache used by GetImport */
  std::unordered_map<std::string, std::shared_ptr<PackedFunc> > import_cache_;
}


The ModuleNode inheritance hierarchy:

[Figure: ModuleNode inheritance diagram]

1.2.4 IRModule -> Module

When is an IRModule converted into a runtime Module?

Here we state the conclusion directly; a later chapter covers the conversion sequence of the two modules in detail. During runtime codegen, the input IRModule is converted into the corresponding runtime Module.
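To make this concrete, here is a minimal sketch (reusing the vector-add example from Chapter 2): tvm.lower produces an IRModule, and tvm.build turns it into a runtime Module whose PackedFuncs can be fetched by name:

import tvm
from tvm import te

A = te.placeholder((2,), name="A")
B = te.placeholder((2,), name="B")
C = te.compute(A.shape, lambda i: A[i] + B[i], name="C")
s = te.create_schedule(C.op)

ir_mod = tvm.lower(s, [A, B, C], name="myadd")  # tvm.ir.module.IRModule
rt_mod = tvm.build(ir_mod, [A, B, C], "llvm")   # tvm.runtime.module.Module
fadd = rt_mod["myadd"]                          # PackedFunc exported by the runtime Module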

1.3 Multiple Backend And Minimal Runtime

1.3.1 The Device definition is shown below:

3rdparty/dlpack/include/dlpack/dlpack.h

/*!
 * \brief The device type in DLDevice.
 */
typedef enum {
  /*! \brief CPU device */
  kDLCPU = 1,
  /*! \brief CUDA GPU device */
  kDLGPU = 2,
  /*!
   * \brief Pinned CUDA GPU device by cudaMallocHost
   * \note kDLCPUPinned = kDLCPU | kDLGPU
   */
  kDLCPUPinned = 3,
  /*! \brief OpenCL devices. */
  kDLOpenCL = 4,
  /*! \brief Vulkan buffer for next generation graphics. */
  kDLVulkan = 7,
  /*! \brief Metal for Apple GPU. */
  kDLMetal = 8,
  /*! \brief Verilog simulator buffer */
  kDLVPI = 9,
  /*! \brief ROCm GPUs for AMD GPUs */
  kDLROCM = 10,
  /*!
   * \brief Reserved extension device type,
   * used for quickly test extension device
   * The semantics can differ depending on the implementation.
   */
  kDLExtDev = 12,
} DLDeviceType;

include/tvm/runtime/c_runtime_api.h

/*! \brief Extension device types in TVM */
typedef enum {
  kDLAOCL = 5,
  kDLSDAccel = 6,
  kOpenGL = 11,
  kDLMicroDev = 13,
  kDLHexagon = 14,
  kDLWebGPU = 15
  // AddExtraTVMType which is not in DLPack here
} TVMDeviceExtType;

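On the Python side these device types surface as device handles; a minimal usage sketch (assuming an LLVM-enabled build):

import numpy as np
import tvm

dev = tvm.cpu(0)                                     # wraps kDLCPU (= 1)
print(dev.device_type)                               # prints 1
a = tvm.nd.array(np.zeros(2, dtype="float32"), dev)  # allocate an NDArray on that device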
1.3.2 Runtime

As shown below, the TVM call stack is split into two main parts:

TVM compiler: compiles and optimizes the model

TVM runtime: runs on the target device (e.g. Raspberry Pi, Android)

[Figure: TVM compiler / TVM runtime split]

Model optimization and compilation

Cross Compilation and RPC

Building the runtime

Cross-compile for aarch64

sudo apt-get update
sudo apt-get install gcc-aarch64-linux-gnu g++-aarch64-linux-gnu


cmake .. \
    -DCMAKE_SYSTEM_NAME=Linux \
    -DCMAKE_SYSTEM_VERSION=1 \
    -DCMAKE_C_COMPILER=/usr/bin/aarch64-linux-gnu-gcc \
    -DCMAKE_CXX_COMPILER=/usr/bin/aarch64-linux-gnu-g++ \
    -DCMAKE_FIND_ROOT_PATH=/usr/aarch64-linux-gnu \
    -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \
    -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \
    -DMACHINE_NAME=aarch64-linux-gnu

make -j$(nproc) runtime

Cross-compile for RISC-V

sudo apt-get update
sudo apt-get install gcc-riscv64-linux-gnu g++-riscv64-linux-gnu

cmake .. \
    -DCMAKE_SYSTEM_NAME=Linux \
    -DCMAKE_SYSTEM_VERSION=1 \
    -DCMAKE_C_COMPILER=/usr/bin/riscv64-linux-gnu-gcc \
    -DCMAKE_CXX_COMPILER=/usr/bin/riscv64-linux-gnu-g++ \
    -DCMAKE_FIND_ROOT_PATH=/usr/riscv64-linux-gnu \
    -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \
    -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \
    -DMACHINE_NAME=riscv64-linux-gnu

make -j$(nproc) runtime
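Once the cross-compiled runtime is running on the board, the host usually talks to it through the RPC mechanism mentioned above. A minimal sketch (the address, port, and library name below are placeholders):

import tvm
from tvm import rpc

# connect to a device running the TVM RPC server
remote = rpc.connect("192.168.1.100", 9090)

# upload a cross-compiled library, load it remotely, and fetch a function
remote.upload("myadd.so")
fadd = remote.load_module("myadd.so")["myadd"]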

Chapter 2: TVM Demo Examples and Framework Flow Analysis

2.1. Example 1: Building and compiling a compute graph with the TVM te API

As the code below shows, the example consists of the following steps:

2.1.1 Define Tensors A, B, and C via the TVM te API, where C = A + B;

2.1.2 Create a schedule

2.1.3 Call tvm.lower, which runs the relevant optimizations and produces an IRModule

2.1.4 Call tvm.build, which runs the relevant optimizations and produces a Module

The full code is shown below:

import tvm
from tvm import te

A = te.placeholder((2,), name="A")
B = te.placeholder((2,), name="B")
C = te.compute(A.shape, lambda i: A[i] + B[i], name="C")

s = te.create_schedule(C.op)

ir_m = tvm.lower(s, [A, B, C], simple_mode=True, name='myadd')

target = tvm.target.Target(target="llvm", host="llvm")
fadd = tvm.build(ir_m, [A, B, C], target, name="myadd")

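To run the compiled function, one can append something like the following minimal sketch (the input values are arbitrary):

import numpy as np

dev = tvm.cpu(0)
a = tvm.nd.array(np.array([1.0, 2.0], dtype="float32"), dev)
b = tvm.nd.array(np.array([3.0, 4.0], dtype="float32"), dev)
c = tvm.nd.array(np.zeros(2, dtype="float32"), dev)
fadd(a, b, c)            # run the compiled myadd kernel
print(c.asnumpy())       # [4. 6.]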

Following the outline above, we split the example into four parts and discuss each in detail:

2.1.1 Defining Tensors A, B, and C via the TVM te API, where C = A + B

This part covers three topics:

2.1.1.1 The relationship between Tensor and Operation in TVM

2.1.1.2 Example code

2.1.1.3 TVM Tensor and Operation details

2.1.1.1 The relationship between Tensor and Operation in TVM

  1. A Tensor object has an Operation object associated with it.

  2. A Tensor is an output of its Operation object.

  3. Each Operation object in turn has an input_tensors() method,
     which returns the list of its input Tensors.

     This way we can keep track of dependencies between Operations (see the sketch below).
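A minimal sketch of this relationship, reusing the A, B, C tensors from the demo at the top of 2.1:

print(C.op)                # the ComputeOp that produces C
print(C.op.num_outputs)    # 1: C is the single output of its op
print(C.op.input_tensors)  # [A, B]: the Tensors this Operation reads
print(A.op)                # the PlaceholderOp behind A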
2.1.1.2 Example code

# Tensor & Operation
A = te.placeholder((2,), name="A")
B = te.placeholder((2,), name="B")
C = te.compute(A.shape, lambda i: A[i] + B[i], name="C")


In the code above, A, B, and C are tensors of shape (2,), while te.placeholder and te.compute each create an Operation.

2.1.1.3 TVM Tensor and Operation details

The Tensor inheritance diagram:

[Figure: Tensor inheritance diagram]

The Operation inheritance diagram:

[Figure: Operation inheritance diagram]

The inheritance diagram of OperationNode, which is associated with Operation:

[Figure: OperationNode inheritance diagram]

Definition of te.placeholder() used in the example:


 python/tvm/te/operation.py

def placeholder(shape, dtype=None, name="placeholder"):
    """Construct an empty tensor object.

    Parameters
    ----------
    shape: Tuple of Expr
        The shape of the tensor

    dtype: str, optional
        The data type of the tensor

    name: str, optional
        The name hint of the tensor

    Returns
    -------
    tensor: Tensor
        The created tensor
    """
    shape = (shape,) if isinstance(shape, tvm.tir.PrimExpr) else shape
    dtype = "float32" if dtype is None else dtype
    return _ffi_api.Placeholder(shape, dtype, name)

 src/te/operation/placeholder_op.cc

TVM_REGISTER_GLOBAL("te.Placeholder")
    .set_body_typed([](Array<PrimExpr> shape, DataType dtype, std::string name) {
      return placeholder(shape, dtype, name);
    });

Tensor placeholder(Array<PrimExpr> shape, DataType dtype, std::string name) {
  return PlaceholderOp(name, shape, dtype).output(0);
}

PlaceholderOp::PlaceholderOp(std::string name, Array<PrimExpr> shape, DataType dtype) {
  auto n = make_object<PlaceholderOpNode>();
  n->name = name;
  n->shape = shape;
  n->dtype = dtype;
  data_ = std::move(n);
}

include/tvm/te/operation.h

/*!
 * \brief Managed reference to PlaceholderOpNode
 * \sa PlaceholderOpNode
 */
class PlaceholderOp : public Operation {
 public:
  TVM_DLL PlaceholderOp(std::string name, Array<PrimExpr> shape, DataType dtype);

  TVM_DEFINE_OBJECT_REF_METHODS(PlaceholderOp, Operation, PlaceholderOpNode);
};


Definition of te.compute() used in the example:

python/tvm/te/operation.py

def compute(shape, fcompute, name="compute", tag="", attrs=None):
    """Construct a new tensor by computing over the shape domain.

    The compute rule is result[axis] = fcompute(axis)

    Parameters
    ----------
    shape: Tuple of Expr
        The shape of the tensor

    fcompute: lambda function of indices-> value
        Specifies the input source expression

    name: str, optional
        The name hint of the tensor

    tag: str, optional
        Additional tag information about the compute.

    attrs: dict, optional
        The additional auxiliary attributes about the compute.

    Returns
    -------
    tensor: Tensor
        The created tensor
    """
    ......

    op_node = _ffi_api.ComputeOp(name, tag, attrs, dim_var, body)

    ......


src/te/operation/compute_op.cc

TVM_REGISTER_GLOBAL("te.ComputeOp")
    .set_body_typed([](std::string name, std::string tag, Map<String, ObjectRef> attrs,
                       Array<IterVar> axis,
                       Array<PrimExpr> body) { return ComputeOp(name, tag, attrs, axis, body); });

ComputeOp::ComputeOp(std::string name, std::string tag, Map<String, ObjectRef> attrs,
                     Array<IterVar> axis, Array<PrimExpr> body) {
  if (!attrs.defined()) {
    attrs = Map<String, ObjectRef>();
  }
  auto n = make_object<ComputeOpNode>();
  n->name = std::move(name);
  n->tag = std::move(tag);
  n->attrs = std::move(attrs);
  n->axis = std::move(axis);
  n->body = std::move(body);
  if (n->body[0]->IsInstance<tir::ReduceNode>()) {
    const tir::ReduceNode* reduce = n->body[0].as<tir::ReduceNode>();
    n->reduce_axis = reduce->axis;
  }
  VerifyComputeOp(n.get());
  data_ = std::move(n);
}

 include/tvm/te/operation.h

/*!
 * \brief Managed reference to ComputeOpNode
 * \sa ComputeOpNode
 */
class ComputeOp : public Operation {
 public:
  TVM_DLL ComputeOp(std::string name, std::string tag, Map<String, ObjectRef> attrs,
                    Array<IterVar> axis, Array<PrimExpr> body);

  TVM_DEFINE_OBJECT_REF_METHODS(ComputeOp, Operation, ComputeOpNode);
};

2.1.2 Creating a schedule

This part covers three topics:

2.1.2.1 Overview

2.1.2.2 Example code

2.1.2.3 TVM Schedule/ScheduleNode and Stage/StageNode details

2.1.2.1 Overview

create_schedule builds a Schedule from the ops created in part one; the Schedule's attributes are held by a ScheduleNode. While the ScheduleNode is created, a Stage is created for each Operation, one-to-one.

Example 1 contains two placeholder ops and one compute op, so the schedule created here contains three stages.

2.1.2.2 Example code

This subsection walks through the call flow of create_schedule to understand the data structures and classes involved. The code is:

# Create a schedule for list of ops
s = te.create_schedule(C.op)

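A quick check of the resulting stages (a minimal sketch using the objects from Example 1):

print(len(s.stages))   # 3: one stage per op (two placeholders + one compute)
print(s[C])            # the Stage for C, looked up through stage_map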
2.1.2.3 TVM Schedule/ScheduleNode and Stage/StageNode details

The definition and call flow of create_schedule:


python/tvm/te/schedule.py

def create_schedule(ops):
    """Create a schedule for list of ops

    Parameters
    ----------
    ops : list of Operations
        The source expression.

    Returns
    -------
    sch : schedule.Schedule
        The created schedule.
    """
    if not isinstance(ops, (list, _container.Array)):
        ops = [ops]
    return _ffi_api.CreateSchedule(ops)

src/te/schedule/schedule_lang.cc

TVM_REGISTER_GLOBAL("te.CreateSchedule").set_body_typed(create_schedule);

include/tvm/te/schedule.h

/*!
 * \brief Create a schedule for array of ops(and their dependencies).
 * \param ops The ops to be scheduled.
 * \return sch The created Schedule.
 */
inline Schedule create_schedule(Array<Operation> ops) { return Schedule(ops); }


/*!
 * \brief Global schedule container
 *  For operations and all the operations they depend on.
 *  The schedule per Operation is named as stage.
 */
class Schedule : public ObjectRef {
 public:
  Schedule() {}
  explicit Schedule(ObjectPtr<Object> n) : ObjectRef(n) {}
  /*!
   * \brief Create a schedule for array of ops(and their dependencies).
   * \param ops The ops to be scheduled.
   * \return sch The created Schedule.
   */
  TVM_DLL explicit Schedule(Array<Operation> ops);

  ......

}

src/te/schedule/schedule_lang.cc

Schedule::Schedule(Array<Operation> ops) {
  auto n = make_object<ScheduleNode>();  // create the ScheduleNode and fill in its attributes
  data_ = n;
  n->outputs = ops;
  auto g = te::CreateReadGraph(n->outputs);
  Array<Operation> post_order = te::PostDFSOrder(n->outputs, g);
  // output set.
  std::unordered_set<Operation> output_set;
  for (Operation x : ops) {
    output_set.insert(x);
  }
  for (Operation op : post_order) {
    Stage stage(op);  // create a Stage for this op
    stage->is_output = output_set.count(op) != 0;
    n->stages.push_back(stage);
    n->stage_map.Set(op, stage);  // record it in stage_map
    // mark scan updates.
    if (const ScanOpNode* scan = op.as<ScanOpNode>()) {
      Array<Tensor> inputs;
      for (Tensor t : scan->state_placeholder) {
        inputs.push_back(t);
      }
      for (Tensor t : scan->inputs) {
        inputs.push_back(t);
      }
      // Create the scan group.
      Stage scan_group = this->create_group(scan->update, inputs, false);
      scan_group->attach_type = kScanUpdate;
      scan_group->attach_stage = stage;

      for (size_t i = 0; i < scan->update.size(); ++i) {
        Stage s = n->stage_map[scan->update[i]->op];
        ICHECK(scan_group.same_as(s->group));
      }
    }
  }
}

ScheduleNode definition:


/*! \brief node container for schedule */
class ScheduleNode : public Object {
 public:
  /*! \brief The output operations in original data flow graph */
  Array<Operation> outputs;
  /*!
   * \brief list of all stages for ops.
   * The stages are sorted in dependency order.
   */
  Array<Stage> stages;
  /*!
   * \brief List of all stage groups.
   */
  Array<Stage> groups;
  /*! \brief map of original operation to the stages */
  Map<Operation, Stage> stage_map;

ScheduleNode contains Stage attributes; next let's look at the definition of Stage:

include/tvm/te/schedule.h

/*! \brief Stage, contains scheduling for a stage of computation. */
class Stage : public ObjectRef {
 public:
  Stage() {}
  explicit Stage(ObjectPtr<Object> n) : ObjectRef(n) {}
  /*!
   * \brief create a new schedule for op.
   * \param op The operator in the schedule
   */
  explicit Stage(Operation op);

  ......

  /*!
   * \brief specify the schedule to be computed at the parent schedule's scope.
   * \param parent The parent schedule.
   * \param scope The iteration point to carry the schedule.
   * \return reference to self.
   */
  TVM_DLL Stage& compute_at(Stage parent, IterVar scope);  // NOLINT(*)
  /*!
   * \brief Compute the function inline.
   * \return reference to self.
   */
  TVM_DLL Stage& compute_inline();  // NOLINT(*)
  /*!
   * \brief Compute the function at group root.
   * \return reference to self.
   */
  TVM_DLL Stage& compute_root();  // NOLINT(*)
  ......

}

src/te/schedule/schedule_lang.cc

Stage::Stage(Operation op) {
  auto n = make_object<StageNode>();  // create the StageNode and initialize its attributes
  n->op = op;
  n->origin_op = op;
  n->all_iter_vars = op->root_iter_vars();
  // remove opaque var from leaf.
  Array<IterVar> clean;
  for (IterVar iv : n->all_iter_vars) {
    if (iv->iter_type != kOpaque) clean.push_back(iv);
  }
  if (clean.size() == n->all_iter_vars.size()) {
    n->leaf_iter_vars = n->all_iter_vars;
  } else {
    n->leaf_iter_vars = clean;
  }
  data_ = std::move(n);
}

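The Stage primitives above (compute_at, compute_inline, compute_root, plus split and friends) are what user code calls from Python; each call records an IterVarRelation on the stage. A minimal sketch (using a larger 1-D compute so the split is meaningful):

import tvm
from tvm import te

n = 1024
X = te.placeholder((n,), name="X")
Y = te.compute((n,), lambda i: X[i] * 2.0, name="Y")
s = te.create_schedule(Y.op)

xo, xi = s[Y].split(Y.op.axis[0], factor=32)   # adds a SplitNode relation to the stage
print(tvm.lower(s, [X, Y], simple_mode=True))  # the lowered loop nest reflects the split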

StageNode definition:


/*!
 * \brief represents a stage.
 *
 *  relations form a Directed acylic hypergraph in bipartite manner.
 *  With each node is represented by a IterVar,
 *  and each hyper-edge is represented by a IterVarRelation.
 *  The relations connects the IterVars in the graph.
 *
 *  Besides typical stage that corresponds to operations.
 *  There is also group stage, which groups stages together.
 *  Each stage's group(given by group) represent an constraint,
 *  the stage can only be attached to stages within the group.
 *
 *  The group stage node can be attached to IterVars as in normal stage.
 */
class StageNode : public Object {
 public:
  /*!
   * \brief The operation of stage, can be different from original op.
   *  If it is null, then this stage is a group stage.
   */
  Operation op;
  /*!
   * \brief The original operator.
   *  The op field can change during schedule to alternate the dataflow,
   *  while origin_op remains fixed.
   */
  Operation origin_op;
  /*! \brief All the nodes in the iter var */
  Array<IterVar> all_iter_vars;
  /*! \brief The current active leaf iter vars in the stage. */
  Array<IterVar> leaf_iter_vars;
  /*!
   * \brief Specify threads to be launched at the stage.
   *  This is only valid for composite ops such as Scan.
   * \note Experimental primitive: used for thread persistence.
   */
  Array<IterVar> env_threads;
  /*!
   * \brief The predicate under which store can happen
   *  Use this when there can be duplicated threads doing the same store.
   * \note Experimental primitive: used by cross thread-reduction.
   */
  PrimExpr store_predicate;
  /*! \brief The relation bwteen of IterVars */
  Array<IterVarRelation> relations;
  /*! \brief additional attributes about iter var. */
  Map<IterVar, IterVarAttr> iter_var_attrs;
  /*! \brief The attachment type of the schedule */
  AttachType attach_type{kGroupRoot};
  /*! \brief The attach point of this schedule. */
  IterVar attach_ivar;
  /*! \brief The stage this node attaches to */
  Stage attach_stage;
  /*! \brief The thread storage scope level of the stage */
  std::string scope;
  /*! \brief Whether this is an output stage */
  bool is_output{false};
  /*! \brief Whether apply double buffer optimization to this stage */
  bool double_buffer{false};
  /*!
   * \brief The parent group of the current stage.
   *  The stage cannot be assigned to stages outside the group.
   */
  Stage group;
  /*! \brief Number of direct child stages, only used for group stage.*/
  int num_child_stages{0};

}

2.1.3 tvm.lower

tvm.lower converts the Schedule built in steps one and two into an IRModule and runs the relevant pass optimizations on it.

tvm.lower consists of three main parts:

2.1.3.1. Creating and initializing the PassContext and adding passes

2.1.3.2. form_irmodule, which converts the Schedule from step one into an IRModule

2.1.3.3. Running pass optimizations on the generated IRModule

First, the definition of tvm.lower:

python/tvm/driver/build_module.py

def lower(sch, args, name="main", binds=None, simple_mode=False):
    """Lowering step before build into target.

    Parameters
    ----------
    sch : tvm.te.schedule.Schedule
        The schedule to be built

    args : list of Buffer or Tensor or Var
        The argument lists to the function.

    name : str, optional
        The name of result function.

    binds : dict of :any:`Tensor` to :any:`Buffer`, optional
        Dictionary that maps the Tensor to Buffer which specified the data layout
        requirement of the function. By default, a new compact buffer is created
        for each tensor in the argument.

    simple_mode : bool, optional
        Whether only output simple and compact statement, this will skip
        LoopPartition, api wrapper generation and Unrolling.

    Returns
    -------
    m : IRModule or Stmt
       The result IRModule, if simple_mode=False
       Then the Stmt before make api is returned.
    """
    # config setup
    pass_ctx = PassContext.current()

    # pass_ctx init ......

    # Phase 0
    if isinstance(sch, schedule.Schedule):
        mod = form_irmodule(sch, args, name, binds)
    else:
        mod = sch
    
    # create pass  pass_list ......

    optimize = tvm.transform.Sequential(pass_list)
    mod = optimize(mod)

    return mod


Below we cover the three parts in detail.

2.1.3.1. Creating and initializing the PassContext and adding passes

PassContext organizes the behavior of passes. As the source below shows, the default opt level set when a PassContext is initialized is 2.

PassContext::Create() creates the companion class PassContextNode, which holds the pass-related information: the opt level, required passes, disabled passes, and so on; see the PassContextNode definition below.

A PassConfigManager manages the global PassContext configuration.

python/tvm/ir/transform.py

def __init__(
    self, opt_level=2, required_pass=None, disabled_pass=None, trace=None, config=None
):  # note: the default opt_level is 2
    self.__init_handle_by_constructor__(
         _ffi_transform_api.PassContext, opt_level, required, disabled, trace, config
    )

TVM_REGISTER_GLOBAL("transform.PassContext")
    .set_body_typed([](int opt_level, Array<String> required, Array<String> disabled,
                       TraceFunc trace_func, Optional<Map<String, ObjectRef>> config) {
      auto pctx = PassContext::Create();
      pctx->opt_level = opt_level;

      pctx->required_pass = std::move(required);
      pctx->disabled_pass = std::move(disabled);
      pctx->trace_func = std::move(trace_func);
      if (config.defined()) {
        pctx->config = config.value();
      }
      PassConfigManager::Global()->Legalize(&(pctx->config));
      return pctx;
    });

/*!
 * \brief PassContextNode contains the information that a pass can rely on,
 * such as analysis results.
 * \sa PassContext
 */
class PassContextNode : public Object {
 public:
  /*! \brief The default optimization level. */
  int opt_level{2};

  /*! \brief The list of required passes. */
  Array<String> required_pass;
  /*! \brief The list of disabled passes. */
  Array<String> disabled_pass;
  /*! \brief The diagnostic context. */
  mutable Optional<DiagnosticContext> diag_ctx;
  /*! \brief Pass specific configurations. */
  Map<String, ObjectRef> config;
  /*! \brief Trace function to be invoked before and after each pass. */
  TraceFunc trace_func;
......
}

PassContext.current() returns the default PassContext associated with the current thread-local state:


@staticmethod
def current():
    """Return the current pass context."""
    return _ffi_transform_api.GetCurrentPassContext()

TVM_REGISTER_GLOBAL("transform.GetCurrentPassContext").set_body_typed(PassContext::Current);

PassContext PassContext::Current() {
  PassContextThreadLocalEntry* entry = RelayPassContextThreadLocalStore::Get();
  if (!entry->context_stack.empty()) {
    return entry->context_stack.top();
  } else {
    return entry->default_context;
  }
}

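For reference, user code normally configures passes through a scoped PassContext; a minimal sketch (the disabled pass name is purely illustrative):

import tvm

with tvm.transform.PassContext(opt_level=3, disabled_pass=["tir.UnrollLoop"]):
    # tvm.lower / tvm.build calls inside this block see opt_level=3
    # and skip the disabled pass
    print(tvm.transform.PassContext.current())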

After the PassContext is initialized, tvm.lower builds the pass list used to optimize the IRModule generated by form_irmodule.

More on TVM passes later in this article. /* todo */

2.1.3.2. form_irmodule: converting the Schedule into an IRModule

Here is a brief summary of the operations inside form_irmodule:

2.1.3.2.1. sch.normalize(): build a normalized schedule from the current schedule

2.1.3.2.2. schedule.InferBound(sch): infer the bounds of all iteration variables related to the schedule

2.1.3.2.3. schedule.ScheduleOps(sch, bounds): schedule the dependent operations; this essentially converts the TE structures into Stmt structures

2.1.3.2.4. schedule.VerifyCompactBuffer(stmt): verify whether any argument is bound to a compact buffer

2.1.3.2.5. get_binds(args, compact, binds): get binds and arg_list given the arguments

2.1.3.2.6. schedule.SchedulePostProcRewriteForTensorCore(stmt, sch, binds): try to modify the AST generated by ScheduleOps to support TensorCore

2.1.3.2.7. schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds): postprocess the Stmt generated by ScheduleOps to create a PrimFunc that can then be used for further TIR optimizations

form_irmodule is implemented as follows:

python/tvm/driver/build_module.py

form_irmodule(sch, args, name, binds)  # in our example, name="main"

def form_irmodule(sch, args, name, binds):
    """According to the given schedule, form a function.

    Parameters
    ----------
    sch : tvm.te.schedule.Schedule
        The given scheduler to form the raw body

    args : list of Buffer or Tensor or Var
        The argument lists to the function.

    name : str
        The name of result function.

    binds : dict of :any:`Tensor` to :any:`Buffer`, optional
        The binds information

    Returns
    -------
    The body formed according to the given schedule
    """
    # normalize schedule first
    pass_ctx = PassContext.current()
    sch = sch.normalize()
    bounds = schedule.InferBound(sch)
    stmt = schedule.ScheduleOps(sch, bounds)

    compact = schedule.VerifyCompactBuffer(stmt)
    binds, arg_list = get_binds(args, compact, binds)

    stmt = schedule.SchedulePostProcRewriteForTensorCore(stmt, sch, binds)
    func = schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds)

    func = func.with_attr("global_symbol", name)

    if pass_ctx.config.get("tir.noalias", True):
        func = func.with_attr("tir.noalias", True)
    return tvm.IRModule({name: func})


Below we describe each of these functions in detail.

2.1.3.2.1. sch.normalize()

Build a normalized schedule from the current schedule


python/tvm/te/schedule.py

def normalize(self):
    """Build a normalized schedule from the current schedule.

    Insert necessary rebase to make certain iter var to start from 0.
    This is needed before bound inference and followup step.

    Returns
    -------
    sch : Schedule
        The normalized schedule.
    """
    return _ffi_api.ScheduleNormalize(self)

src/te/schedule/schedule_lang.cc

TVM_REGISTER_GLOBAL("te.ScheduleNormalize").set_body_method(&Schedule::normalize);

Schedule Schedule::normalize() {
  Schedule sn = copy();
  InjectInline(sn.operator->(), false);
  RebaseNonZeroMinLoop(sn.operator->());
  LegalizeInvalidAttach(sn.operator->());
  return sn;
}

include/tvm/te/schedule.h

  /*!
   * \brief Normalize the schedule.
   *  This is needed before bound inference.
   *  Insert necessary RebaseNode to make sure all leaf_iter_vars
   *  are in form [0, extent)
   *
   * \return A normalized schedule, can be same as current one.
   */
  Schedule normalize();

2.1.3.2.2. schedule.InferBound(sch)

InferBound's main job is to build the bound map, which assigns a Range to every IterVar in the program. These bounds are then passed to ScheduleOps, where they are used to set For-loop extents and to size the allocated buffers (BuildRealize), among other uses.

The output of InferBound is a map from IterVar to Range:

Map<IterVar, Range> InferBound(const Schedule& sch);


Here is the detailed code of InferBound:


src/te/schedule/bound.cc

TVM_REGISTER_GLOBAL("schedule.InferBound").set_body_typed(InferBound);

Map<IterVar, Range> InferBound(const Schedule& sch) {
  // Prepare context
  GraphContext ctx;
  Array<Operation> roots;
  arith::Analyzer analyzer;

  for (Operation op : sch->outputs) {
    roots.push_back(sch->stage_map[op]->op);
  }
  ctx.feed_graph = CreateFeedGraph(CreateReadGraph(roots));

   ......

  ctx.attach_path = CreateAttachPath(sch);

  // Run inference.
  std::unordered_map<IterVar, Range> ret;
  for (size_t i = sch->stages.size(); i != 0; --i) {
    const Stage& stage = sch->stages[i - 1];
    InferRootBound(stage, ctx, &ret);
    ......

    // pass down to get bound of all iter vars.
    PassDownDomain(stage, &ret, &analyzer);
    ......

  for (auto& p : ret) {
    ret[p.first] =
        Range::FromMinExtent(analyzer.Simplify(p.second->min), analyzer.Simplify(p.second->extent));
  }
  return Map<IterVar, Range>(ret.begin(), ret.end());
}

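These bounds can be inspected from Python as well; a minimal sketch using the schedule from Example 1 (mirroring the form_irmodule code quoted earlier):

from tvm.te import schedule

s = s.normalize()
bounds = schedule.InferBound(s)   # Map<IterVar, Range>
for iv, rng in bounds.items():
    print(iv.var, rng)            # e.g. i and range(min=0, ext=2)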

The data structures involved in the InferBound call are Range, IterVar, and GraphContext.

Range


include/tvm/ir/expr.h

class Range : public ObjectRef {
 public:
  /*!
   * \brief constructor by begin and end
   * \param begin The begin of the range.
   * \param end The end of the range.
   * \param span The location of the Range in the source.
   */
  TVM_DLL Range(PrimExpr begin, PrimExpr end, Span span = Span());
  /*!
   * \brief construct a new range with min and extent
   *  The corresponding constructor is removed,
   *  because that is counter convention of tradition meaning
   *  of range(begin, end)
   *
   * \param min The minimum range.
   * \param extent The extent of the range.
   * \param span The location of the Range in the source.


class RangeNode : public Object {
 public:
  /*! \brief beginning of the node */
  PrimExpr min;
  /*! \brief the extend of range */
  PrimExpr extent;
  /*! \brief the location of this range in the source */
  mutable Span span;
  /*! \brief constructor */


IterVar: IterVarRelation & IterVarRelationNode

IterVarRelation subclasses: Fuse, Rebase, Singleton, Split

[Figure: IterVarRelation inheritance diagram]

IterVarRelationNode

[Figure: IterVarRelationNode inheritance diagram]

FuseNode: Fuse two domains into one domain.

class FuseNode : public IterVarRelationNode {
 public:
  /*! \brief The outer domain */
  IterVar outer;
  /*! \brief The inner domain */
  IterVar inner;
  /*! \brief The target domain */
  IterVar fused;


RebaseNode: Rebase the iteration to make min to be 0.

class RebaseNode : public IterVarRelationNode {
 public:
  /*! \brief The parent domain */
  IterVar parent;
  /*! \brief The inner domain */
  IterVar rebased;
  

SingletonNode: Singleton iterator [0, 1)

class SingletonNode : public IterVarRelationNode {
 public:
  /*! \brief The singleton iterator */
  IterVar iter;
  

SplitNode: Split the parent domain into product of outer and iter.

class SplitNode : public IterVarRelationNode {
 public:
  /*! \brief The parent domain */
  IterVar parent;
  /*! \brief The outer domain */
  IterVar outer;
  /*! \brief The inner domain */
  IterVar inner;
  /*! \brief The split factor */
  PrimExpr factor;
  /*! \brief Number of parts, only factor or nparts can be given */
  PrimExpr nparts;
  

GraphContext: The graph context used during bound inference.

src/te/schedule/bound.cc

struct GraphContext {
  /*! \brief The feed graph */
  FeedGraph feed_graph;
  /*! \brief Attachment path */
  AttachPath attach_path;
  /*! \brief The bind map */
  std::unordered_map<IterVar, IterVar> bind_map;
  /*! \brief map from op to stage */
  std::unordered_map<const Object*, Stage> op2stage_;
};


FeedGraph: The map between a tensor and the operations it feeds to.


using FeedGraph = std::unordered_map<Tensor, std::vector<Operation> >;


AttachPath: maps an op to a list of IterVars.


using AttachPath = Map<Operation, Array<IterVar> >;


ReadGraph: maps an Operation to the Tensors it reads.


using ReadGraph = Map<Operation, Array<Tensor> >;

2.1.3.2.3. schedule.ScheduleOps(sch, bounds)

ScheduleOps schedules the dependent operations, essentially converting the TE structures into Stmt structures.

First, some background on Stmt and StmtNode.

Stmt is defined as follows:


/*! \brief Container of all statements */
class Stmt : public ObjectRef {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(Stmt, ObjectRef, StmtNode);
};


Stmt is the other fundamental structure of TVM IR: it is the base class of syntax tree nodes. Each Stmt represents one independent syntax tree node, and nodes nest within each other; by following a Stmt's body (its usual child member) downward, one can see the complete abstract syntax tree (AST). The Stmt inheritance diagram:

[Figure: Stmt inheritance diagram]

StmtNode is defined as follows:


/*! \brief Base node of all statements. */
class StmtNode : public Object {
 public:
  /*!
   * \brief Span that points to the original source code.
   *        Reserved debug information.
   */
  mutable Span span;

  StmtNode() = default;
  explicit StmtNode(Span span) : span(span) {}

  static constexpr const char* _type_key = "tir.Stmt";
  static constexpr const bool _type_has_method_sequal_reduce = true;
  static constexpr const bool _type_has_method_shash_reduce = true;
  static constexpr const uint32_t _type_child_slots = 15;
  TVM_DECLARE_BASE_OBJECT_INFO(StmtNode, Object);
};

The StmtNode inheritance diagram:

[Figure: StmtNode inheritance diagram]

Related Stmt kinds

Compute node kinds:

	For, LetStmt (MakeLoopNest), IfThenElseNode (predicates)

Storage node kinds:

	ProducerStore (MakeReduction, MakeProvide), ProducerRealize (BuildRealize)

Attribute node kinds:

	pragma_scope, prefetch, virtual_thread, thread_extent, pipeline, loop_scope (MakeLoopNest),
	buffer_dim_align (BuildRealize), BufferBound (Flatten ProducerStore), realize_scope (MakePipeline)

Now let's look at the call stack of schedule.ScheduleOps(sch, bounds); its arguments are sch and the bounds produced by schedule.InferBound(sch).

ScheduleOps mainly creates a Stmt and then initializes and updates it based on the Stages in the Schedule.

src/te/schedule/schedule_ops.cc

TVM_REGISTER_GLOBAL("schedule.ScheduleOps").set_body([](TVMArgs args, TVMRetValue* ret) {
  if (args.size() == 2)
    *ret = ScheduleOps(args[0], args[1], false);
  else
    *ret = ScheduleOps(args[0], args[1], args[2]);
});


Stmt ScheduleOps(Schedule sch, Map<IterVar, Range> dom_map_, bool debug_keep_trivial_loop) {
  Stmt body = Stmt(); //创建 Stmt
  std::unordered_map<IterVar, Range> dom_map = as_unordered_map(dom_map_);
  // scan init and scan updates
  std::unordered_map<Operation, Operation> scan_init;
  for (Stage s : sch->stages) {
    const ScanOpNode* scan = s->op.as<ScanOpNode>();
    if (!scan) continue;
    for (Tensor t : scan->init) {
      if (scan_init.count(t->op)) {
      } else {
        scan_init[t->op] = s->op;
      }
    }
  }

  // reverse the post DFS order.
  for (size_t i = sch->stages.size(); i != 0; --i) {
    Stage s = sch->stages[i - 1];
    // no need to specify place holder op.
    if (s->op.as<PlaceholderOpNode>()) continue;
    // Remove grouping sugar, get the real attach spec.
    Stage attach_spec = s.GetAttachSpec();

    if (scan_init.count(s->op)) {
      InjectScanStep mu(s, scan_init.at(s->op), dom_map, true, debug_keep_trivial_loop);
      body = mu(std::move(body));
    } else if (attach_spec->attach_type == kScanUpdate) {
      // Handle scan update
      InjectScanStep mu(s, attach_spec->attach_stage->op, dom_map, false, debug_keep_trivial_loop);
      body = mu(std::move(body));
    } else if (attach_spec->attach_type == kInlinedAlready) {
      // do nothing
    } else if (attach_spec->attach_type == kGroupRoot) {
      body = MakePipeline(s, dom_map, body, debug_keep_trivial_loop);
    } else {
      ICHECK_EQ(attach_spec->attach_type, kScope);
      InjectAttach mutator(s, attach_spec, dom_map, debug_keep_trivial_loop);
      body = mutator(std::move(body));
   
    }
  }
  SchedulePostProc post_proc;
  post_proc.Init(sch);
  return post_proc(std::move(body));
}


The function above involves the classes InjectScanStep, InjectAttach, and SchedulePostProc, as well as the function MakePipeline; we introduce them briefly below.

InjectScanStep and InjectAttach inherit from StmtMutator; the inheritance diagram is shown below:

[Figure: StmtMutator inheritance diagram]

2.1.3.2.4. schedule.VerifyCompactBuffer(stmt)

Verify if there is any argument bound to compact buffer.

src/te/schedule/verify_compact_buffer.cc

TVM_REGISTER_GLOBAL("schedule.VerifyCompactBuffer").set_body_typed(VerifyCompactBuffer);

bool VerifyCompactBuffer(const Stmt& stmt) {
  VerifyBuffer verifier;
  return verifier.Verify(stmt);
}


class VerifyBuffer : public StmtVisitor {
 public:
  bool Verify(const Stmt& stmt) {
    this->VisitStmt(stmt);
    return is_compact_;
  }

  void VisitStmt_(const AttrStmtNode* op) final {
    StmtVisitor::VisitStmt_(op);
    if (op->attr_key == tir::attr::buffer_bind_scope) {
      is_compact_ = true;
    }
  }

 private:
  bool is_compact_{false};
};

2.1.3.2.5. get_binds(args, compact, binds)

get binds and arg_list given arguments


python/tvm/driver/build_module.py

def get_binds(args, compact=False, binds=None):
    """Internal function to get binds and arg_list given arguments.

    Parameters
    ----------
    args : list of Buffer or Tensor or Var
        The argument lists to the function.

    compact : bool
        If the statement has already bound to a compact buffer.

    binds : dict of :any:`Tensor` to :any:`Buffer`, optional
        Dictionary that maps the Tensor to Buffer which specified the data layout
        requirement of the function. By default, a new compact buffer is created
        for each tensor in the argument.

    Returns
    -------
    binds: dict
        The bind specification

    arg_list: list
        The list of symbolic buffers of arguments.
    """
    binds = {} if binds is None else binds.copy()
    arg_list = []
    for x in args:
        if isinstance(x, tensor.Tensor):
            any_dim = any(isinstance(i, tvm.tir.Var) for i in x.shape)
            buffer_type = "auto_broadcast" if any_dim and not compact else ""
            if x not in binds:
                buf = tvm.tir.decl_buffer(
                    x.shape, dtype=x.dtype, name=x.name, buffer_type=buffer_type
                )
                binds[x] = buf
                arg_list.append(buf)
            else:
                arg_list.append(binds[x])
        elif isinstance(x, schedule.Buffer):
            arg_list.append(x)
        elif isinstance(x, tvm.tir.Var):
            arg_list.append(x)
        else:
            raise ValueError("args must be Tensor, Buffer or Var")
    return binds, arg_list


2.1.3.2.6. schedule.SchedulePostProcRewriteForTensorCore(stmt, sch, binds)

Try to modify the AST generated by ScheduleOps to support TensorCore


src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc

TVM_REGISTER_GLOBAL("schedule.SchedulePostProcRewriteForTensorCore")
   .set_body_typed([](Stmt stmt, Schedule schedule, Map<te::Tensor, Buffer> extern_buffer) {
     return SchedulePostProcRewriteForTensorCore(stmt, schedule, extern_buffer);
   });

Stmt SchedulePostProcRewriteForTensorCore(Stmt stmt, Schedule schedule,
                                         Map<Tensor, Buffer> extern_buffer) {
 // Check if current lower target is CUDA
 auto target = tvm::Target::Current(true);
 if (target.defined() && target->kind->name != "cuda") {
   return stmt;
 }

 // Check if current runtime support GPU CUDA
 Device dev{kDLGPU, 0};
 auto api = tvm::runtime::DeviceAPI::Get(dev, true);
 if (api == nullptr) {
   return stmt;
 }

 MMAMatcher mma_matcher(extern_buffer);
 mma_matcher(stmt);
 if (!mma_matcher.Matched()) {
   return stmt;
 }

 ScheduleAnalyser schedule_analyser(mma_matcher);
 if (!schedule_analyser.MatrixIdentify(schedule)) {
   return stmt;
 }

 BufferAnalyser buffer_analyser(extern_buffer, schedule_analyser, mma_matcher);
 buffer_analyser(stmt);
 if (!buffer_analyser.QualifiedForTensorCore()) {
   return stmt;
 }

 return TensorCoreIRMutator(schedule_analyser, buffer_analyser)(std::move(stmt));
}

2.1.3.2.7. schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds)

Postprocessing the Stmt generated by ScheduleOps to create a PrimFunc that can then be used for further TIR optimizations.


src/te/schedule/schedule_postproc_to_primfunc.cc

TVM_REGISTER_GLOBAL("schedule.SchedulePostProcToPrimFunc")
    .set_body_typed(SchedulePostProcToPrimFunc);


PrimFunc SchedulePostProcToPrimFunc(Array<ObjectRef> arg_list, Stmt body,
                                    Optional<Map<Tensor, Buffer>> extern_buffer_opt) {
  std::unordered_map<Tensor, Buffer> extern_buffer;

  if (extern_buffer_opt.defined()) {
    auto v = extern_buffer_opt.value();
    extern_buffer = std::unordered_map<Tensor, Buffer>(v.begin(), v.end());
  }

  Array<tir::Var> params;
  Map<tir::Var, tir::Buffer> buffer_map;

  for (auto var : arg_list) {
    if (auto* n = var.as<tir::VarNode>()) {
      params.push_back(GetRef<tir::Var>(n));
    } else if (auto* n = var.as<te::TensorNode>()) {
      te::Tensor tensor = GetRef<te::Tensor>(n);
      ICHECK(!extern_buffer.count(tensor));

      tir::Buffer buffer = CreateBufferFor(tensor);
      tir::Var bptr(buffer->name, DataType::Handle());
      params.push_back(bptr);
      buffer_map.Set(bptr, buffer);
      extern_buffer[tensor] = buffer;
    } else {
      tir::Buffer buffer = Downcast<tir::Buffer>(var);
      tir::Var bptr(buffer->name, DataType::Handle());
      params.push_back(bptr);
      buffer_map.Set(bptr, buffer);
    }
  }

  body = TensorToBufferMapper(std::move(extern_buffer))(std::move(body));
  return tir::PrimFunc(params, body, VoidType(), buffer_map);
}  


/*!
 * \brief Postprocessing the Stmt generated by ScheduleOps to create
 *        a PrimFunc that can then be used for further TIR optimizations.
 *
 *  Perform this translation before running any TIR optimizations.
 *
 *  List of actions taken by the function:
 *  - Remove occurences of te::Tensor, te::Operation in the IR
 *    and replace them by corresponding IR nodes via tir::Buffer.
 *  - Add annotation of extern buffers using the buffer_map field
 *    in the PrimFunc type.
 *
 * \param arg_list Array of Tensor/Var/Buffer arguments to the function.
 * \param body The body of the function.
 * \param bindings potential Tensor to Buffer bindings for the Tensors in the body.
 */
PrimFunc SchedulePostProcToPrimFunc(Array<ObjectRef> arg_list, Stmt body,
                                    Optional<Map<Tensor, Buffer>> bindings);

2.1.3.3. Running pass optimizations on the generated IRModule

python/tvm/driver/build_module.py

mod = optimize(mod)


The optimize pass list:


optimize = tvm.transform.Sequential(pass_list)


Phase 1

  tvm.tir.transform.InjectPrefetch() 
	
  tvm.tir.transform.StorageFlatten(64, instrument_bound_checkers)

  tvm.tir.transform.BF16Legalize()[BF16Promote/BF16CastElimination/BF16TypeLowering]

  tvm.tir.transform.NarrowDataType(32)

  tvm.tir.transform.Simplify(),

Phase 2


  tvm.tir.transform.LoopPartition()
  
  tvm.tir.transform.VectorizeLoop(not disable_vectorize)
  
  tvm.tir.transform.InjectVirtualThread()
  
  tvm.tir.transform.InjectDoubleBuffer()
  
  tvm.tir.transform.StorageRewrite()
  
  tvm.tir.transform.UnrollLoop()


Phase 3


  tvm.tir.transform.Simplify(),
  
  tvm.tir.transform.RemoveNoOp(),

  tvm.tir.transform.RewriteUnsafeSelect()

  tvm.tir.transform.HoistIfThenElse()


Phase 4

    tvm.tir.transform.InstrumentBoundCheckers()

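All four phases are chained into a single tvm.transform.Sequential, exactly as in the tvm.lower source quoted above; a minimal sketch of the same mechanism with a trimmed pass list:

import tvm

pass_list = [
    tvm.tir.transform.Simplify(),     # a Phase 3 pass
    tvm.tir.transform.RemoveNoOp(),   # a Phase 3 pass
]
optimize = tvm.transform.Sequential(pass_list)
mod = optimize(mod)   # mod is the IRModule produced by form_irmodule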
2.1.4 tvm.build

tvm.build generates the device-specific code from the input function parameters, i.e.

Code generation, where target machine code is generated from the low level IR

We break the function into the following parts:

2.1.4.1. Function parameters

2.1.4.2. Setting and updating target and target_host

2.1.4.3. _build_for_device

2.1.4.4. codegen build_module

2.1.4.5. Generate a unified host module

2.1.4.6. tvm.build output

The tvm.build function is shown below:

def build(inputs, args=None, target=None, target_host=None, name="default_function", binds=None):
    
    if isinstance(inputs, schedule.Schedule):
        if args is None:
            raise ValueError("args must be given for build from schedule")
        input_mod = lower(inputs, args, name=name, binds=binds)
    elif isinstance(inputs, (list, tuple, container.Array)):
        merged_mod = tvm.IRModule({})
        for x in inputs:
            merged_mod.update(x)
        input_mod = merged_mod
    elif isinstance(inputs, tvm.IRModule):
        input_mod = inputs
    
    ......

    if not isinstance(inputs, (dict, container.Map)):
        target = Target.current() if target is None else target
        target = target if target else "llvm"
        target_input_mod = {target: input_mod}
    else:
        target_input_mod = inputs
    
    ......

    target_input_mod, target_host = Target.check_and_update_host_consist(
        target_input_mod, target_host
    )

    if not target_host:
        for tar, mod in target_input_mod.items():
            tar = Target(tar)
            device_type = ndarray.device(tar.kind.name, 0).device_type
            if device_type == ndarray.cpu(0).device_type:
                target_host = tar
                break
    if not target_host:
        target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm"

    target_input_mod, target_host = Target.check_and_update_host_consist(
        target_input_mod, target_host
    )

    mod_host_all = tvm.IRModule({})

    device_modules = []
    for tar, input_mod in target_input_mod.items():
        mod_host, mdev = _build_for_device(input_mod, tar, target_host)
        mod_host_all.update(mod_host)
        device_modules.append(mdev)
    # Generate a unified host module.
    rt_mod_host = codegen.build_module(mod_host_all, target_host)

    # Import all modules.
    for mdev in device_modules:
        if mdev:
            rt_mod_host.import_module(mdev)
    if not isinstance(target_host, Target):
        target_host = Target(target_host)
    if (
        target_host.attrs.get("runtime", tvm.runtime.String("c++")) == "c"
        and target_host.attrs.get("system-lib", 0).value == 1
    ):
        if target_host.kind.name == "c":
            create_csource_crt_metadata_module = tvm._ffi.get_global_func(
                "runtime.CreateCSourceCrtMetadataModule"
            )
            return create_csource_crt_metadata_module([rt_mod_host], target_host)

        if target_host.kind.name == "llvm":
            create_llvm_crt_metadata_module = tvm._ffi.get_global_func(
                "runtime.CreateLLVMCrtMetadataModule"
            )
            return create_llvm_crt_metadata_module([rt_mod_host], target_host)
    return rt_mod_host


Below we describe the contents of the tvm.build function in detail:

2.1.4.1. Function parameters

tvm.build takes the parameters inputs, args, target, target_host, name, and binds, described below:

inputs: per the implementation of tvm.build, the supported types are tvm.te.Schedule, IRModule, and dict of target to IRModule

args: a list of Buffer, Tensor, or Var; optional

For example:
  n = 10
  A = te.placeholder((n,), name="A")
  B = te.compute((n,), lambda *i: A(*i) + 1.0, name="B")
  s = te.create_schedule(B.op)
  f = tvm.build(s, [A, B], "ext_dev", "llvm") // 这里args指[A, B]

target: str or :any:`tvm.target.Target`, optional; the target to compile for.

The Target kinds currently supported by TVM are:
   
   'opencl', 'nvptx', 'cuda', 'llvm', 'metal', 'c', 'vulkan', 

   'webgpu', 'rocm', 'stackvm', 'ext_dev', 'sdaccel', 'hexagon', 

   'aocl', 'composite', 'aocl_sw_emu', 'hybrid'

target_host: str or :any:`tvm.target.Target`, optional.

     Host compilation target, if target is device.
     When TVM compiles device specific program such as CUDA,
     we also need host(CPU) side code to interact with the driver
     setup the dimensions and parameters correctly.
     target_host is used to specify the host side codegen target.
     By default, llvm is used if it is enabled,
     otherwise a stackvm interpreter is used.

name: str, optional; the name of the result function.

binds: dict, optional; a dictionary that maps symbolic Buffer bindings to Tensors. By default, a new buffer is created for each tensor in the arguments.
Returns: ret : tvm.runtime.Module (runtime::Module), a module that combines both host and device code.
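As an illustration, here is a minimal hedged sketch of binds; the buffer name Ab and the alignment value are assumptions, while A, B, and s come from the args example above:

Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name="Ab", data_alignment=64)  # hypothetical explicit buffer
f = tvm.build(s, [A, B], target="llvm", binds={A: Ab})  # A is bound to Ab instead of an auto-created buffer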

Depending on the type of inputs, tvm.build has three typical calling patterns.

inputs is an IRModule:

        n = 2
        A = te.placeholder((n,), name='A')
        B = te.placeholder((n,), name='B')
        C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
        s = tvm.te.create_schedule(C.op)
        m = tvm.lower(s, [A, B, C], name="test_add")  # tvm.lower returns an IRModule
        rt_mod = tvm.build(m, target="llvm")


inputs is a dict mapping compilation target to IRModule:

        n = 2
        A = te.placeholder((n,), name='A')
        B = te.placeholder((n,), name='B')
        C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
        s1 = tvm.te.create_schedule(C.op)
        with tvm.target.cuda() as cuda_tgt:
          s2 = topi.cuda.schedule_injective(cuda_tgt, [C])
          m1 = tvm.lower(s1, [A, B, C], name="test_add1")
          m2 = tvm.lower(s2, [A, B, C], name="test_add2")
          rt_mod = tvm.build({"llvm": m1, "cuda": m2}, target_host="llvm")


inputs is a Schedule:

    n = 10
    A = te.placeholder((n,), name="A")
    B = te.compute((n,), lambda *i: A(*i) + 1.0, name="B")
    s = te.create_schedule(B.op)

    f = tvm.build(s, [A, B], "llvm", "llvm")

2.1.4.2. Setting and updating target and target_host

Based on inputs and the targets and target_hosts supported by TVM, target_input_mod and target_host are set and updated.
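At the user level the same consistency rule can be observed on Target itself; a small hedged sketch (kind and host are standard Target fields):

tgt = tvm.target.Target("cuda", host="llvm")  # device target carrying its host
print(tgt.kind.name)       # cuda
print(tgt.host.kind.name)  # llvm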

2.1.4.3. build_for_device

To explain what this function does, we quote its original docstring:

Build the lowered functions for a device with the given compilation target.


That is, it builds the lowered functions (PrimFuncs) for the device matching the given compilation target. Let us walk through what happens inside this function.

Function signature:


Parameters:

input_mod : IRModule

target : str or :any:`tvm.target.Target`

target_host : str or :any:`tvm.target.Target`

Returns:

fhost : IRModule, the host IRModule.

mdev : tvm.Module, a module that contains device code.


Step 1: pass optimizations on the IRModule. The passes involved are listed below; a Python composition sketch follows the three lists.

opt_mixed

tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))

tvm.tir.transform.VerifyMemory()

tvm.tir.transform.Apply(lambda f: f.with_attr("tir.is_entry_func", True))

tvm.tir.transform.ThreadSync("global")

tvm.tir.transform.ThreadSync("shared")

tvm.tir.transform.ThreadSync("warp")

tvm.tir.transform.InferFragment(),

tvm.tir.transform.LowerThreadAllreduce()

tvm.tir.transform.MakePackedAPI()

tvm.tir.transform.SplitHostDevice()


opt_device

tvm.tir.transform.Filter(
                lambda f: "calling_conv" in f.attrs
                and f.attrs["calling_conv"].value == CallingConv.DEVICE_KERNEL_LAUNCH
            )
            
tvm.tir.transform.LowerWarpMemory()

tvm.tir.transform.Simplify()

tvm.tir.transform.LowerDeviceStorageAccessInfo()

tvm.tir.transform.LowerCustomDatatypes()

tvm.tir.transform.LowerIntrin()


opt_host

tvm.tir.transform.Filter(
                lambda f: "calling_conv" not in f.attrs
                or f.attrs["calling_conv"].value != CallingConv.DEVICE_KERNEL_LAUNCH
            )

tvm.tir.transform.Apply(lambda f: f.with_attr("target", target_host))

tvm.tir.transform.LowerTVMBuiltin()

tvm.tir.transform.LowerDeviceStorageAccessInfo()

tvm.tir.transform.LowerCustomDatatypes()

tvm.tir.transform.LowerIntrin()

tvm.tir.transform.CombineContextCall()
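The three lists above are applied as sequential pass pipelines over the IRModule. Below is a minimal hedged sketch of the mechanism, assuming m is the IRModule returned by tvm.lower in the earlier example and target is a tvm.target.Target:

opt_mixed = tvm.transform.Sequential(
    [
        tvm.tir.transform.Apply(lambda f: f.with_attr("target", target)),
        tvm.tir.transform.VerifyMemory(),
        tvm.tir.transform.MakePackedAPI(),
        tvm.tir.transform.SplitHostDevice(),
    ]
)
mod_mixed = opt_mixed(m)  # functions annotated with the target and split into host/device parts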


Step 2: run codegen build_module for the target.


rt_mod_dev = codegen.build_module(mod_dev, target) if len(mod_dev.functions) != 0 else None

2.1.4.4. codegen build_module
python/tvm/target/codegen.py

def build_module(mod, target):
    """Build IRModule into Module.

    Parameters
    ----------
    mod : tvm.IRModule
        The ir module.

    target : str
        The target module type.

    Returns
    -------
    module : runtime.Module
        The corresponding module.
    """
    target = Target(target) if isinstance(target, str) else target
    return _ffi_api.Build(mod, target)

src/target/codegen.cc

TVM_REGISTER_GLOBAL("target.Build").set_body_typed(Build);

runtime::Module Build(IRModule mod, Target target) {
  if (transform::PassContext::Current()
          ->GetConfig<Bool>("tir.disable_assert", Bool(false))
          .value()) {
    mod = tir::transform::SkipAssert()(mod);
  }
  std::string build_f_name;
  if (target->kind->name == "micro_dev") {
    build_f_name = "target.build.c";
  } else {
    build_f_name = "target.build." + target->kind->name;
  }

  // the build function.
  const PackedFunc* bf = runtime::Registry::Get(build_f_name);
  ICHECK(bf != nullptr) << build_f_name << " is not enabled";
  return (*bf)(mod, target);
}

Here we take llvm as an example, i.e. build_f_name = "target.build.llvm".

src/target/llvm/llvm_module.cc

TVM_REGISTER_GLOBAL("target.build.llvm")
    .set_body_typed([](IRModule mod, Target target) -> runtime::Module {
      LOG(INFO) << "TVM_REGISTER_GLOBAL target.build.llvm ";
      auto n = make_object<LLVMModuleNode>();
      n->Init(mod, target);
      return runtime::Module(n);
    });
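These build functions are plain entries in TVM's global function registry, so they can also be looked up from Python; a hedged sketch (assuming an LLVM-enabled TVM build):

import tvm
build_llvm = tvm._ffi.get_global_func("target.build.llvm")
print(build_llvm)  # a PackedFunc wrapping the C++ lambda registered above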


This creates an LLVMModuleNode, calls its Init function, and returns runtime::Module(n). Init consists of the following steps:

2.1.4.4.1. Step 1: initialize LLVMEnv, src/target/llvm/llvm_common.cc

2.1.4.4.2. GetLLVMTargetMachine(target)

2.1.4.4.3. CodeGenLLVM::Create(tm_.get())

2.1.4.4.4. std::vector<PrimFunc> funcs;

2.1.4.4.5. CodeGenLLVM init: cg->Init(…)

2.1.4.4.6. AddFunction

2.1.4.4.7. AddMainFunction

2.1.4.4.8. LinkParameters

2.1.4.4.9. CodeGenLLVM Finish: module_ = cg->Finish();

2.1.4.4.10. std::unique_ptr<llvm::Module> module_->addModuleFlag

src/target/llvm/llvm_module.cc

void Init(const IRModule& mod, const Target& target) {
    InitializeLLVM(); 
    tm_ = GetLLVMTargetMachine(target);
    bool system_lib = target->GetAttr<Bool>("system-lib").value_or(Bool(false));
    bool target_c_runtime = (target->GetAttr<String>("runtime").value_or("") == kTvmRuntimeCrt);

    ctx_ = std::make_shared<llvm::LLVMContext>();
    std::unique_ptr<CodeGenLLVM> cg = CodeGenLLVM::Create(tm_.get());

    std::vector<PrimFunc> funcs;
    std::string entry_func;

    Map<String, LinkedParam> linked_params;
    bool found_linked_params = false;
    bool could_have_linked_params = target->GetAttr<Bool>("link-params").value_or(Bool(false));

    for (auto kv : mod->functions) {
      if (could_have_linked_params &&
          kv.first->name_hint == ::tvm::runtime::symbol::tvm_lookup_linked_param) {
        Map<String, ObjectRef> attrs_dict =
            Downcast<Map<String, ObjectRef>>(kv.second->attrs->dict);
        linked_params =
            Downcast<Map<String, LinkedParam>>(attrs_dict[::tvm::tir::attr::kLinkedParams]);
        found_linked_params = true;
        continue;
      }

      auto f = Downcast<PrimFunc>(kv.second);
      auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);

      function_names_.push_back(global_symbol.value());
      if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
        entry_func = global_symbol.value();
      }
      funcs.push_back(f);
    }

    // TODO(tqchen): remove the entry function behavior as it does not
    // makes sense when we start to use multiple modules.
    cg->Init("TVMMod", tm_.get(), ctx_.get(), system_lib, system_lib, target_c_runtime);
    for (const auto& f : funcs) {
      cg->AddFunction(f);
    }
    if (entry_func.length() != 0) {
      cg->AddMainFunction(entry_func);
    }
    if (found_linked_params) {
      cg->LinkParameters(linked_params);
    }
    module_ = cg->Finish();
    module_->addModuleFlag(llvm::Module::Warning, "tvm_target",
                           llvm::MDString::get(*ctx_, LLVMTargetToString(target)));
    module_->addModuleFlag(llvm::Module::Override, "Debug Info Version",
                           llvm::DEBUG_METADATA_VERSION);

    if (tm_->getTargetTriple().isOSDarwin()) {
      module_->addModuleFlag(llvm::Module::Override, "Dwarf Version", 2);
    }

    target_ = target;
    mptr_ = module_.get();
  }

2.1.4.4.1. Step 1: initialize LLVMEnv

This calls the relevant LLVM APIs to initialize the target-related information.

src/target/llvm/llvm_common.cc

void InitializeLLVM() {
  LLVMEnv* e = LLVMEnv::Global();
  if (!e->all_initialized.load(std::memory_order::memory_order_acquire)) {
    std::lock_guard<std::mutex> lock(e->mu);
    if (!e->all_initialized.load(std::memory_order::memory_order_acquire)) {
      llvm::InitializeAllTargetInfos();
      llvm::InitializeAllTargets();
      llvm::InitializeAllTargetMCs();
      llvm::InitializeAllAsmParsers();
      llvm::InitializeAllAsmPrinters();
      e->all_initialized.store(true, std::memory_order::memory_order_release);
    }
  }
}

2.1.4.4.2. GetLLVMTargetMachine(target)

Based on the given target, this resolves the corresponding LLVM target machine information and finally calls llvm_target->createTargetMachine to create the llvm::TargetMachine.
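The triple, CPU, and feature strings that ParseLLVMTargetOptions extracts come from the Target's attributes; a hedged sketch of setting them from Python (the specific flag values are illustrative assumptions):

tgt = tvm.target.Target("llvm -mtriple=x86_64-linux-gnu -mcpu=skylake -mattr=+avx2")
print(tgt.attrs["mtriple"], tgt.attrs["mcpu"])  # x86_64-linux-gnu skylake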

std::unique_ptr<llvm::TargetMachine> GetLLVMTargetMachine(const Target& target, bool allow_null) {
  std::string target_triple, mcpu, mattr;
  llvm::TargetOptions opt;

  ParseLLVMTargetOptions(target, &target_triple, &mcpu, &mattr, &opt);

  if (target_triple.length() == 0 || target_triple == "default") {
    target_triple = llvm::sys::getDefaultTargetTriple();
  }
  if (mcpu.length() == 0) {
    mcpu = "generic";
  }

  std::string err;
  const llvm::Target* llvm_target = llvm::TargetRegistry::lookupTarget(target_triple, err);
  if (llvm_target == nullptr) {
    ICHECK(allow_null) << err << " target_triple=" << target_triple;
    return nullptr;
  }
  llvm::TargetMachine* tm =
      llvm_target->createTargetMachine(target_triple, mcpu, mattr, opt, llvm::Reloc::PIC_);
  return std::unique_ptr<llvm::TargetMachine>(tm);
}
2.1.4.4.3. CodeGenLLVM::Create(tm_.get())
src/target/llvm/codegen_llvm.h
src/target/llvm/codegen_llvm.cc

std::unique_ptr<CodeGenLLVM> CodeGenLLVM::Create(llvm::TargetMachine* tm) {
  std::string target = tm->getTarget().getName();
  std::string factory_name = "tvm.codegen.llvm.target_" + target;
  // e.g. on an x86-64 Ubuntu PC: factory_name -> tvm.codegen.llvm.target_x86-64
  const PackedFunc* f = runtime::Registry::Get(factory_name);
  if (f != nullptr) {
    void* handle = (*f)();
    return std::unique_ptr<CodeGenLLVM>(static_cast<CodeGenLLVM*>(handle));
  } else {
    return std::unique_ptr<CodeGenLLVM>(new CodeGenCPU());
  }
}

src/target/llvm/codegen_x86_64.cc


TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_x86-64")
    .set_body([](const TVMArgs& targs, TVMRetValue* rv) {
      CodeGenLLVM* cg = new CodeGenX86_64();
      *rv = static_cast<void*>(cg);
    });


class CodeGenX86_64 final : public CodeGenCPU {
 public:
  llvm::Value* VisitExpr_(const CastNode* op) override;

 private:
  llvm::Value* CallVectorIntrin(llvm::Intrinsic::ID id, size_t intrin_lanes, llvm::Type* result_ty,
                                const std::vector<llvm::Value*>& args);
};

src/target/llvm/codegen_cpu.h

// CPU host code generation
class CodeGenCPU : public CodeGenLLVM {...}


src/target/llvm/codegen_llvm.h

/*!
 * \brief A base class to generate a LLVM.
 */
class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
                    public StmtFunctor<void(const Stmt&)> {...}


As described above, CodeGenLLVM is the base class for generating LLVM code. In the TVM source tree, CodeGenLLVM has four subclasses:

  1. CodeGenAMDGPU:AMDGPU code generator.

  2. CodeGenCPU:CPU host code generator.

  3. CodeGenHexagon:Hexagon code generator.

  4. CodeGenNVPTX:NVPTX code generator.

2.1.4.4.4. std::vector<PrimFunc> funcs;

Create a vector of PrimFuncs and populate funcs with the corresponding PrimFuncs from mod->functions.

2.1.4.4.5. CodeGenLLVM init: cg->Init(…)
src/target/llvm/llvm_module.cc

void Init(const IRModule& mod, const Target& target) {
  ······

cg->Init("TVMMod", tm_.get(), ctx_.get(), system_lib, system_lib, target_c_runtime);
  ······
}

In the example above we compile on a PC, so the generator is CodeGenX86_64, which inherits from CodeGenCPU; hence CodeGenCPU::Init is called.

src/target/llvm/codegen_cpu.cc


void CodeGenCPU::Init(const std::string& module_name, llvm::TargetMachine* tm,
                      llvm::LLVMContext* ctx, bool system_lib, bool dynamic_lookup,
                      bool target_c_runtime) {
  CodeGenLLVM::Init(module_name, tm, ctx, system_lib, dynamic_lookup, target_c_runtime);
  dbg_info_ = CreateDebugInfo(module_.get());
  static_assert(sizeof(TVMValue) == sizeof(double), "invariant");
  func_handle_map_.clear();
  export_system_symbols_.clear();
  // TVM runtime types
  t_tvm_shape_index_ = llvm::Type::getIntNTy(*ctx, DataType::ShapeIndex().bits());
  t_tvm_context_ = llvm::StructType::create({t_int_, t_int_});
  t_tvm_type_ = llvm::StructType::create({t_int8_, t_int8_, t_int16_});
  t_tvm_func_handle_ = t_void_p_;
  t_tvm_array_ = llvm::StructType::create({t_void_p_, t_tvm_context_, t_int_, t_tvm_type_,
                                           t_tvm_shape_index_->getPointerTo(),
                                           t_tvm_shape_index_->getPointerTo(), t_int64_});
  t_tvm_value_ = llvm::StructType::create({t_float64_});
  t_tvm_parallel_group_env_ = llvm::StructType::create({t_int32_->getPointerTo(), t_int32_});
  ftype_tvm_backend_packed_c_func_ = llvm::FunctionType::get(
      t_int_,
      {t_tvm_func_handle_, t_tvm_value_->getPointerTo(), t_int_->getPointerTo(), t_int_,
       t_tvm_value_->getPointerTo(), t_int_->getPointerTo(), t_void_p_},
      false);
  t_tvm_crt_func_registry_ = llvm::StructType::create(
      {t_char_->getPointerTo(), ftype_tvm_backend_packed_c_func_->getPointerTo()});
  t_tvm_crt_module_ = llvm::StructType::create({t_tvm_crt_func_registry_->getPointerTo()});
  ftype_tvm_parallel_lambda_ = llvm::FunctionType::get(
      t_int_, {t_int_, t_tvm_parallel_group_env_->getPointerTo(), t_void_p_}, false);
  md_tbaa_ctx_ptr_ = md_builder_->createTBAAScalarTypeNode("ctx_ptr", md_tbaa_root_);
  // Runtime functions.
  ftype_tvm_func_call_ = llvm::FunctionType::get(
      t_int_,
      {t_tvm_func_handle_, t_tvm_value_->getPointerTo(), t_int_->getPointerTo(), t_int_,
       t_tvm_value_->getPointerTo(), t_int_->getPointerTo()},
      false);
  ftype_tvm_get_func_from_env_ = llvm::FunctionType::get(
      t_int_, {t_void_p_, t_char_->getPointerTo(), t_tvm_func_handle_->getPointerTo()}, false);
  ftype_tvm_api_set_last_error_ =
      llvm::FunctionType::get(t_void_, {t_char_->getPointerTo()}, false);
  ftype_tvm_parallel_launch_ = llvm::FunctionType::get(
      t_int_, {ftype_tvm_parallel_lambda_->getPointerTo(), t_void_p_, t_int_}, false);
  ftype_tvm_parallel_barrier_ =
      llvm::FunctionType::get(t_int_, {t_int_, t_tvm_parallel_group_env_->getPointerTo()}, false);
  ftype_tvm_static_init_callback_ = llvm::FunctionType::get(t_int_, {t_void_p_}, false);
  ftype_tvm_static_init_ =
      llvm::FunctionType::get(t_int_,
                              {t_void_p_->getPointerTo(),
                               ftype_tvm_static_init_callback_->getPointerTo(), t_void_p_, t_int_},
                              false);
  // initialize TVM runtime API
  if (system_lib && !target_c_runtime) {
    // We will need this in environment for backward registration.
    f_tvm_register_system_symbol_ = llvm::Function::Create(
        llvm::FunctionType::get(t_int_, {t_char_->getPointerTo(), t_void_p_}, false),
        llvm::Function::ExternalLinkage, "TVMBackendRegisterSystemLibSymbol", module_.get());
  } else {
    f_tvm_register_system_symbol_ = nullptr;
  }
  if (dynamic_lookup || system_lib) {
    f_tvm_func_call_ = llvm::Function::Create(ftype_tvm_func_call_, llvm::Function::ExternalLinkage,
                                              "TVMFuncCall", module_.get());
    f_tvm_get_func_from_env_ =
        llvm::Function::Create(ftype_tvm_get_func_from_env_, llvm::Function::ExternalLinkage,
                               "TVMBackendGetFuncFromEnv", module_.get());
    f_tvm_api_set_last_error_ =
        llvm::Function::Create(ftype_tvm_api_set_last_error_, llvm::Function::ExternalLinkage,
                               "TVMAPISetLastError", module_.get());
    f_tvm_parallel_launch_ =
        llvm::Function::Create(ftype_tvm_parallel_launch_, llvm::Function::ExternalLinkage,
                               "TVMBackendParallelLaunch", module_.get());
    f_tvm_parallel_barrier_ =
        llvm::Function::Create(ftype_tvm_parallel_barrier_, llvm::Function::ExternalLinkage,
                               "TVMBackendParallelBarrier", module_.get());
  }
  this->InitGlobalContext(dynamic_lookup);
  target_c_runtime_ = target_c_runtime;
  is_system_lib_ = system_lib;
}

CodeGenCPU::Init in turn calls CodeGenLLVM::Init:

src/target/llvm/codegen_llvm.cc

void CodeGenLLVM::Init(const std::string& module_name, llvm::TargetMachine* tm,
                       llvm::LLVMContext* ctx, bool system_lib, bool dynamic_lookup,
                       bool target_c_runtime) {
  InitializeLLVM();
  ctx_ = ctx;
  builder_.reset(new IRBuilder(*ctx_));
  module_.reset(new llvm::Module(module_name, *ctx_));  // std::unique_ptr<llvm::Module> module_;
  md_builder_.reset(new llvm::MDBuilder(*ctx_));
  // types
  t_void_ = llvm::Type::getVoidTy(*ctx_);
  t_void_p_ = llvm::Type::getInt8Ty(*ctx_)->getPointerTo();
  t_int_ = llvm::Type::getInt32Ty(*ctx_);
  t_char_ = llvm::Type::getInt8Ty(*ctx_);
  t_int8_ = llvm::Type::getInt8Ty(*ctx_);
  t_int16_ = llvm::Type::getInt16Ty(*ctx_);
  t_int32_ = llvm::Type::getInt32Ty(*ctx_);
  t_int64_ = llvm::Type::getInt64Ty(*ctx_);
  t_float64_ = llvm::Type::getDoubleTy(*ctx_);
  // meta data
  md_very_likely_branch_ = md_builder_->createBranchWeights(1 << 20, 1);
  md_tbaa_root_ = md_builder_->createTBAARoot("tvm-tbaa");
  md_tbaa_alias_set_ = md_builder_->createTBAANode("tvm-alias", md_tbaa_root_);
  this->InitTarget(tm);
}

src/target/llvm/codegen_llvm.h
  /*!
   * \brief Initialize the code generator with given context
   * \param module_name The name of the module.
   * \param tm Target machine model
   * \param ctx The context.
   * \param system_lib Whether to insert system library registration.
   * \param dynamic_lookup Whether dynamically lookup runtime function
   *                       or use the runtime function table passed by caller.
   * \param target_c_runtime If true, generate a module to be executed by the C runtime. In practice
   *                       this option influences whether global ctors are used.
   */
  virtual void Init(const std::string& module_name, llvm::TargetMachine* tm, llvm::LLVMContext* ctx,
                    bool system_lib, bool dynamic_lookup, bool target_c_runtime);

These two Init calls mainly initialize the attributes of CodeGenCPU and the underlying LLVM code generator.

2.1.4.4.6. AddFunction

AddFunction adds the funcs collected in step 4 to the current module (std::unique_ptr<llvm::Module> module_); in this example that means converting each PrimFunc into an LLVM Function.

As the flow above shows, CodeGenCPU::AddFunction is called here; it first calls CodeGenLLVM::AddFunction, which creates the llvm::Function and fills in its attributes from the PrimFunc body. The main code is shown below:

src/target/llvm/llvm_module.cc
  void Init(const IRModule& mod, const Target& target) {
    ......

    for (const auto& f : funcs) {
      cg->AddFunction(f);
    }
    ......
  }


src/target/llvm/codegen_cpu.cc

void CodeGenCPU::AddFunction(const PrimFunc& f) {
  CodeGenLLVM::AddFunction(f);
  if (f_tvm_register_system_symbol_ != nullptr) {
    auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
    ICHECK(global_symbol.defined())
        << "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute";
    export_system_symbols_.emplace_back(
        std::make_pair(global_symbol.value().operator std::string(), function_));
  }
  AddDebugInformation(function_);
}

src/target/llvm/codegen_llvm.h
  /*!
   * \brief Compile and add function f to the current module.
   * \param f The function to be added.
   */
  virtual void AddFunction(const PrimFunc& f);

src/target/llvm/codegen_llvm.cc

void CodeGenLLVM::AddFunction(const PrimFunc& f) { this->AddFunctionInternal(f, false); }

void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) {
  this->InitFuncState();

   
  std::vector<llvm::Type*> param_types;
  is_restricted_ = f->HasNonzeroAttr(tir::attr::kNoAlias);
  for (Var param : f->params) {
    param_types.push_back(GetLLVMType(param));
    if (!is_restricted_ && param.dtype().is_handle()) {
      alias_var_set_.insert(param.get());
    }
  }
  // TODO(tvm-team):
  // Update the function type to respect the ret_type field of f.
  // Once we allow more flexibility in the PrimFunc.
  llvm::FunctionType* ftype =
      llvm::FunctionType::get(ret_void ? t_void_ : t_int_, param_types, false);

  auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);

  function_ = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage,
                                     global_symbol.value().operator std::string(), module_.get());
  function_->setCallingConv(llvm::CallingConv::C);
  function_->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);

  // set var map and align information
  auto arg_it = function_->arg_begin();
  for (size_t i = 0; i < f->params.size(); ++i, ++arg_it) {
    llvm::Argument* v = &(*arg_it);
    const Var& var = f->params[i];
    var_map_[var.get()] = v;
    if (is_restricted_) {
      if (var.dtype().is_handle() && !alias_var_set_.count(var.get())) {
        // set non alias.
        function_->addParamAttr(i, llvm::Attribute::NoAlias);
      }
    }
  }
  llvm::BasicBlock* entry = llvm::BasicBlock::Create(*ctx_, "entry", function_);
  builder_->SetInsertPoint(entry);
  this->VisitStmt(f->body);

  llvm::StringRef fs = target_machine_->getTargetFeatureString();
  if (!fs.empty()) {
    function_->addFnAttr("target-features", fs);
  }

  if (ret_void) {
    builder_->CreateRetVoid();
  } else {
    builder_->CreateRet(ConstInt32(0));
  }
}
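The llvm::Functions emitted this way can be inspected from Python through the runtime module; a hedged sketch (f is the module built with target "llvm" in the earlier Schedule example):

print(f.get_source("ll"))   # textual LLVM IR of the generated module
print(f.get_source("asm"))  # target assembly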

2.1.4.4.7. AddMainFunction

This creates an llvm::GlobalVariable (runtime::symbol::tvm_module_main) and initializes it with the entry_func_name of module_.

src/target/llvm/codegen_cpu.cc

void CodeGenCPU::AddMainFunction(const std::string& entry_func_name) {
  llvm::Function* f = module_->getFunction(entry_func_name);
  llvm::Type* type = llvm::ArrayType::get(t_char_, entry_func_name.length() + 1);
  llvm::GlobalVariable* global =
      new llvm::GlobalVariable(*module_, type, true, llvm::GlobalValue::WeakAnyLinkage, nullptr,
                               runtime::symbol::tvm_module_main);
#if TVM_LLVM_VERSION >= 100
  global->setAlignment(llvm::Align(1));
#else
  global->setAlignment(1);
#endif
  // comdat is needed for windows select any linking to work
  // set comdat to Any(weak linking)
  if (target_machine_->getTargetTriple().isOSWindows()) {
    llvm::Comdat* comdat = module_->getOrInsertComdat(runtime::symbol::tvm_module_main);
    comdat->setSelectionKind(llvm::Comdat::Any);
    global->setComdat(comdat);
  }

  global->setInitializer(llvm::ConstantDataArray::getString(*ctx_, entry_func_name));
  global->setDLLStorageClass(llvm::GlobalVariable::DLLExportStorageClass);
}
2.1.4.4.8. LinkParameters

src/target/llvm/codegen_llvm.h

  /*!
   * \brief Link parameters into the module so they don't need to be supplied at runtime.
   * Parameters can be linked into the module so that the generated code is easier to use, or so
   * that RAM space doesn't need to be allocated for them. This function adds the given parameters
   * to the generated LLVM module.
   * \param storage_id_offset Offset added to the index of each entry in params_by_sid to form the
   *     storage_id of that parameter. Storage ids for parameters are expected to be contiguous.
   * \param params_by_sid Array of NDArray. Each entry is a parameter. The index of the array (added
   *     to sid_offset) is the storage_id of the param.
   * \param param_names Array containing the name for each param in params_by_sid.
   */
  void LinkParameters(const Map<String, LinkedParam> params);

2.1.4.4.9. CodeGenLLVM Finish: module_ = cg->Finish();
src/target/llvm/codegen_cpu.cc

std::unique_ptr<llvm::Module> CodeGenCPU::Finish() {
  // link modules
  if (dbg_info_ != nullptr) {
    dbg_info_->di_builder_->finalize();
  }
  return CodeGenLLVM::Finish();
}

src/target/llvm/codegen_llvm.cc

std::unique_ptr<llvm::Module> CodeGenLLVM::Finish() {
  this->AddStartupFunction();
  for (size_t i = 0; i < link_modules_.size(); ++i) {
    ICHECK(!llvm::Linker::linkModules(*module_, std::move(link_modules_[i])))
        << "Failed to link modules";
  }
  link_modules_.clear();
  // optimize
  this->Optimize();
  return std::move(module_);
}



CodeGenCPU::Finish calls CodeGenLLVM::Finish(), which does two main things:

  1. AddStartupFunction

  2. Optimize

Let us look at each in turn:

AddStartupFunction

src/target/llvm/codegen_cpu.cc

void CodeGenCPU::AddStartupFunction() {
  if (!target_c_runtime_) {
    llvm::FunctionType* ftype = llvm::FunctionType::get(t_void_, {}, false);
    function_ = llvm::Function::Create(ftype, llvm::Function::InternalLinkage,
                                       "__tvm_module_startup", module_.get());
    llvm::BasicBlock* startup_entry = llvm::BasicBlock::Create(*ctx_, "entry", function_);
    builder_->SetInsertPoint(startup_entry);
    for (const auto& kv : export_system_symbols_) {
      llvm::Value* name = GetConstString(kv.first);
      builder_->CreateCall(f_tvm_register_system_symbol_,
                           {name, builder_->CreateBitCast(kv.second, t_void_p_)});
    }
    llvm::appendToGlobalCtors(*module_, function_, 65535);
    builder_->CreateRet(nullptr);
  }
}


Optimize

This runs the LLVM optimization pipeline at optimization level 3.


src/target/llvm/codegen_llvm.cc

void CodeGenLLVM::Optimize() {
  // pass manager
  FPassManager fpass(module_.get());
  MPassManager mpass;
  mpass.add(llvm::createTargetTransformInfoWrapperPass(
      target_machine_ ? target_machine_->getTargetIRAnalysis() : llvm::TargetIRAnalysis()));
  fpass.add(llvm::createTargetTransformInfoWrapperPass(
      target_machine_ ? target_machine_->getTargetIRAnalysis() : llvm::TargetIRAnalysis()));

  // place optimization pass
  llvm::PassManagerBuilder builder;
  builder.OptLevel = 3;

#if TVM_LLVM_VERSION >= 50
  builder.Inliner = llvm::createFunctionInliningPass(builder.OptLevel, 0, false);
#else
  builder.Inliner = llvm::createFunctionInliningPass(builder.OptLevel, 0);
#endif
  builder.LoopVectorize = true;
  builder.SLPVectorize = true;
  this->InitPassManagerBuilder(&builder);

#if TVM_LLVM_VERSION >= 50
  target_machine_->adjustPassManager(builder);
#endif

  builder.populateFunctionPassManager(fpass);
  builder.populateModulePassManager(mpass);

  fpass.doInitialization();
  for (auto it = module_->begin(); it != module_->end(); ++it) {
    fpass.run(*it);
  }
  fpass.doFinalization();
  mpass.run(*module_);
}

2.1.4.4.10. std::unique_ptr<llvm::Module> module_->addModuleFlag

Add the relevant flags to the LLVM module:

    module_->addModuleFlag(llvm::Module::Warning, "tvm_target",
                           llvm::MDString::get(*ctx_, LLVMTargetToString(target)));
    module_->addModuleFlag(llvm::Module::Override, "Debug Info Version",
                           llvm::DEBUG_METADATA_VERSION);

    if (tm_->getTargetTriple().isOSDarwin()) {
      module_->addModuleFlag(llvm::Module::Override, "Dwarf Version", 2);
    }

2.1.4.4.11. Summary of codegen build_module

Following the flow above, codegen build_module calls LLVMModuleNode::Init, which translates the IRModule into code for the given target. In the example above we compiled on a PC, so the factory is tvm.codegen.llvm.target_x86-64 and the codegen is the LLVM CPU implementation. The result is an llvm::Module; since LLVMModuleNode is a subclass of runtime::ModuleNode, runtime::Module(n) is returned.
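A hedged sketch of invoking this path directly from Python, reusing m (the IRModule from tvm.lower in the earlier example):

from tvm.target import codegen

rt_mod = codegen.build_module(m, tvm.target.Target("llvm"))
print(rt_mod.type_key)  # "llvm", i.e. backed by an LLVMModuleNode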

2.1.4.5. Generate a unified host module

As described above, the build_for_device step builds against the user-specified target and target_host, producing mod_host plus the codegen result for each target. After that, codegen.build_module is called once more, this time with target_host, to generate the host-side code. Finally, each device codegen result is merged into the host module by assigning it to the host module's import field. The code is as follows:

    rt_mod_host = codegen.build_module(mod_host_all, target_host)
    # Import all modules.
    for mdev in device_modules:
        if mdev:
            rt_mod_host.import_module(mdev)
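After the import loop, each device module hangs off the host module and can be inspected; a hedged sketch:

for mdev in rt_mod_host.imported_modules:
    print(mdev.type_key)  # e.g. "cuda" when a CUDA device target was given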


At the end of tvm.build, a function is chosen based on the kind of target_host:

If target_host is "c", create_csource_crt_metadata_module([rt_mod_host], target_host) is called.

If target_host is "llvm", create_llvm_crt_metadata_module([rt_mod_host], target_host) is called.

    if not isinstance(target_host, Target):
        target_host = Target(target_host)
    if (
        target_host.attrs.get("runtime", tvm.runtime.String("c++")) == "c"
        and target_host.attrs.get("system-lib", 0).value == 1
    ):
        if target_host.kind.name == "c":
            create_csource_crt_metadata_module = tvm._ffi.get_global_func(
                "runtime.CreateCSourceCrtMetadataModule"
            )
            return create_csource_crt_metadata_module([rt_mod_host], target_host)

        if target_host.kind.name == "llvm":
            create_llvm_crt_metadata_module = tvm._ffi.get_global_func(
                "runtime.CreateLLVMCrtMetadataModule"
            )
            return create_llvm_crt_metadata_module([rt_mod_host], target_host)


Below are the call stacks of these two functions.

2.1.4.5.1. CreateCSourceCrtMetadataModule



src/target/source/source_module.cc

TVM_REGISTER_GLOBAL("runtime.CreateCSourceCrtMetadataModule")
    .set_body_typed([](const Array<runtime::Module>& modules, Target target) {
      return CreateCSourceCrtMetadataModule(modules, target);
    });

src/target/source/source_module.cc

runtime::Module CreateCSourceCrtMetadataModule(const Array<runtime::Module>& modules,
                                               Target target) {
  Array<String> func_names;
  for (runtime::Module mod : modules) {
    auto pf_funcs = mod.GetFunction("get_func_names");
    if (pf_funcs != nullptr) {
      Array<String> func_names_ = pf_funcs();
      for (const auto& fname : func_names_) {
        func_names.push_back(fname);
      }
    }
  }
  auto n = make_object<CSourceCrtMetadataModuleNode>(func_names, "cc", target);
  auto csrc_metadata_module = runtime::Module(n);
  for (const auto& mod : modules) {
    csrc_metadata_module.Import(mod);
  }
  return std::move(csrc_metadata_module);
}


  CSourceCrtMetadataModuleNode(const Array<String>& func_names, const std::string& fmt,
                               Target target)
      : fmt_(fmt), func_names_(func_names), target_(target) {
    CreateSource();
  }

  void CreateSource() {
    if (target_->GetAttr<Bool>("system-lib").value_or(Bool(false)) && !func_names_.empty()) {
      CreateFuncRegistry();
      GenerateCrtSystemLib();
    }
    code_ << ";";
  }
};

  void CreateFuncRegistry() {
    code_ << "#include <tvm/runtime/crt/module.h>\n";
    for (const auto& fname : func_names_) {
      code_ << "#ifdef __cplusplus\n";
      code_ << "extern \"C\"\n";
      code_ << "#endif\n";
      code_ << "TVM_DLL int32_t " << fname.data();
      code_ << "(TVMValue* args, int* type_code, int num_args, TVMValue* out_value, int* "
               "out_type_code);\n";
    }
    code_ << "static TVMBackendPackedCFunc _tvm_func_array[] = {\n";
    for (auto f : func_names_) {
      code_ << "    (TVMBackendPackedCFunc)" << f << ",\n";
    }
    code_ << "};\n";
    auto registry = target::GenerateFuncRegistryNames(func_names_);
    code_ << "static const TVMFuncRegistry _tvm_func_registry = {\n"
          << "    \"" << ::tvm::support::StrEscape(registry.data(), registry.size(), true) << "\","
          << "    _tvm_func_array,\n"
          << "};\n";
  }

  void GenerateCrtSystemLib() {
    code_ << "static const TVMModule _tvm_system_lib = {\n"
          << "    &_tvm_func_registry,\n"
          << "};\n"
          << "const TVMModule* TVMSystemLibEntryPoint(void) {\n"
          << "    return &_tvm_system_lib;\n"
          << "}\n";
  }

2.1.4.5.2. CreateLLVMCrtMetadataModule


src/target/llvm/llvm_module.cc

TVM_REGISTER_GLOBAL("runtime.CreateLLVMCrtMetadataModule")
    .set_body_typed([](const Array<runtime::Module>& modules, Target target) {
      return CreateLLVMCrtMetadataModule(modules, target);
    });


runtime::Module CreateLLVMCrtMetadataModule(const Array<runtime::Module>& modules, Target target) {
  Array<String> func_names;
  for (runtime::Module mod : modules) {
    auto pf_funcs = mod.GetFunction("get_func_names");
    if (pf_funcs != nullptr) {
      Array<String> func_names_ = pf_funcs();
      for (const auto& fname : func_names_) {
        func_names.push_back(fname);
      }
    }
  }

  InitializeLLVM();
  auto tm = GetLLVMTargetMachine(target);
  bool system_lib = target->GetAttr<Bool>("system-lib").value_or(Bool(false));
  bool target_c_runtime = (target->GetAttr<String>("runtime").value_or("") == kTvmRuntimeCrt);
  ICHECK(system_lib && target_c_runtime)
      << "For LLVM C-runtime metadata module, must include --system-lib and --runtime=c; "
      << "got target: " << target->str();
  auto ctx = std::make_shared<llvm::LLVMContext>();
  std::unique_ptr<CodeGenCPU> cg{new CodeGenCPU()};
  cg->Init("TVMMetadataMod", tm.get(), ctx.get(), system_lib, system_lib, target_c_runtime);

  cg->DefineFunctionRegistry(func_names);
  auto mod = cg->Finish();
  mod->addModuleFlag(llvm::Module::Warning, "tvm_target",
                     llvm::MDString::get(*ctx, LLVMTargetToString(target)));
  mod->addModuleFlag(llvm::Module::Override, "Debug Info Version", llvm::DEBUG_METADATA_VERSION);

  if (tm->getTargetTriple().isOSDarwin()) {
    mod->addModuleFlag(llvm::Module::Override, "Dwarf Version", 2);
  }

  std::string verify_errors_storage;
  llvm::raw_string_ostream verify_errors(verify_errors_storage);
  LOG_IF(FATAL, llvm::verifyModule(*mod, &verify_errors))
      << "LLVM module verification failed with the following errors: \n"
      << verify_errors.str();

  auto n = make_object<LLVMModuleNode>();
  n->Init(std::move(mod), ctx);
  for (auto m : modules) {
    n->Import(m);
  }
  return runtime::Module(n);
}

2.1.4.6. tvm.build output

target='c', target_host='c -keys=cpu -link-params=0'

tir:
 #[version = "0.0.5"]
primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
  attr = {"global_symbol": "myadd", "tir.noalias": True}
  buffers = {C: Buffer(C_2: Pointer(float32), float32, [2], []),
             A: Buffer(A_2: Pointer(float32), float32, [2], []),
             B: Buffer(B_2: Pointer(float32), float32, [2], [])}
  buffer_map = {A_1: A, B_1: B, C_1: C} {
  for (i: int32, 0, 2) {
    C_2[i] = ((float32*)A_2[i] + (float32*)B_2[i])
  }
}

#[metadata]
{
  "root": 1, 
  "nodes": [
    {
      "type_key": ""
    }, 
    {
      "type_key": "Map", 
      "keys": [
        "IntImm"
      ], 
      "data": [2]
    }, 
    {
      "type_key": "Array", 
      "data": [3]
    }, 
    {
      "type_key": "IntImm", 
      "attrs": {
        "dtype": "bool", 
        "span": "0", 
        "value": "1"
      }
    }
  ], 
  "b64ndarrays": [], 
  "attrs": {"tvm_version": "0.8.dev0"}
}
source code:
 // tvm target: c -keys=cpu -link-params=0
#define TVM_EXPORTS
#include "tvm/runtime/c_runtime_api.h"
#include "tvm/runtime/c_backend_api.h"
#include <math.h>
void* __tvm_module_ctx = NULL;
#ifdef __cplusplus
extern "C"
#endif
TVM_DLL int32_t myadd(void* args, void* arg_type_ids, int32_t num_args, void* out_ret_value, void* out_ret_tcode, void* resource_handle) {
  void* arg0 = (((TVMValue*)args)[0].v_handle);
  int32_t arg0_code = ((int32_t*)arg_type_ids)[(0)];
  void* arg1 = (((TVMValue*)args)[1].v_handle);
  int32_t arg1_code = ((int32_t*)arg_type_ids)[(1)];
  void* arg2 = (((TVMValue*)args)[2].v_handle);
  int32_t arg2_code = ((int32_t*)arg_type_ids)[(2)];
  void* A = (((DLTensor*)arg0)[0].data);
  void* arg0_shape = (((DLTensor*)arg0)[0].shape);
  void* arg0_strides = (((DLTensor*)arg0)[0].strides);
  int32_t dev_id = (((DLTensor*)arg0)[0].device.device_id);
  void* B = (((DLTensor*)arg1)[0].data);
  void* arg1_shape = (((DLTensor*)arg1)[0].shape);
  void* arg1_strides = (((DLTensor*)arg1)[0].strides);
  void* C = (((DLTensor*)arg2)[0].data);
  void* arg2_shape = (((DLTensor*)arg2)[0].shape);
  void* arg2_strides = (((DLTensor*)arg2)[0].strides);
  if (!(arg0_strides == NULL)) {
  }
  if (!(arg1_strides == NULL)) {
  }
  if (!(arg2_strides == NULL)) {
  }
  for (int32_t i = 0; i < 2; ++i) {
    ((float*)C)[(i)] = (((float*)A)[(i)] + ((float*)B)[(i)]);
  }
  return 0;
}

2.2 Example 2: parsing and compiling a model with the TVM Python API

As the code below shows, this example consists of two parts:

2.2.1 Load the pretrained PyTorch model resnet18 via the Relay frontend

2.2.2 Compile the model with relay.build at opt_level = 3

import tvm
from tvm import relay

# PyTorch imports
import torch
import torchvision

# Load a pretrained PyTorch model
model_name = "resnet18"
model = getattr(torchvision.models, model_name)(pretrained=True)
model = model.eval()

# We grab the TorchScripted model via tracing
input_shape = [1, 3, 224, 224]
input_data = torch.randn(input_shape)
scripted_model = torch.jit.trace(model, input_data).eval()


# Import the graph to Relay
# Convert PyTorch graph to Relay graph. The input name can be arbitrary.
input_name = "input0"
shape_list = [(input_name, input_shape)]
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)

# Relay Build
# Compile the graph to llvm target with given input specification.
target = tvm.target.Target("llvm", host="llvm")
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target, params=params)

2.2.1 Load the pretrained PyTorch model resnet18 via the Relay frontend

This converts the PyTorch model into a TVM IRModule expressed in Relay IR; the bulk of the work here is operator mapping.

mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
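A hedged sketch for inspecting the conversion result (mod and params come from the call above):

print(mod["main"])       # Relay IR of the traced ResNet-18
print(list(params)[:3])  # a few of the converted parameter names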

python/tvm/relay/frontend/pytorch.py

def from_pytorch(script_module, input_infos, custom_convert_map=None, default_dtype="float32"):
    """Load PyTorch model in the form of a scripted PyTorch model and convert into relay.
    The companion parameters will be handled automatically.

    Parameters
    ----------
    script_module : TopLevelTracedModule object
        TorchScripted PyTorch graph
        Note: We currently only support traces (ie: torch.jit.trace(model, input))

    input_infos : List of tuples
        Can be (input name, input shape) or (input name, (input shape, input types))
        Graph level input shape and type list
        The same input names need to be used for deployment, so choose easy to
        remember names (such as: input0, input1)
        e.g.
        [('input0', (1, 2)), ('input1', (3, 4))]
        or
        [('input0', ((1, 2), 'int')), ('input1', ((3, 4), 'float'))]

    custom_convert_map : Dictionary of str to Relay op
        A custom op conversion map in the same format as _convert_map above

    Returns
    -------
    mod : tvm.relay.Module
        The module that optimizations will be performed on.

    params : dict of str to tvm.runtime.NDArray
        Dict of converted parameters stored in tvm.runtime.ndarray format
    """
    import torch

    mod = tvm.IRModule()
    prelude = Prelude(mod)

    converter = PyTorchOpConverter(prelude, default_dtype)

    graph = script_module.graph.copy()
    _run_jit_passes(graph)

    if custom_convert_map:
        converter.update_convert_map(custom_convert_map)

    op_names = get_all_op_names(graph)
    converter.report_missing_conversion(op_names)

    is_module = isinstance(script_module, torch.jit.ScriptModule)
    params = script_module.state_dict() if is_module else {}
    outputs = _get_relay_input_vars(
        graph, input_infos, prelude, default_dtype=default_dtype, is_module=is_module
    )
    param_vars, tensors, packed_param_map = convert_params(graph, params)
    tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()}

    outputs.update(param_vars)
    ret_name = _get_input_names(graph.return_node())

    # For quantized models
    quantized_ops = set(["aten::quantize_per_tensor", "quantized::linear_dynamic"])
    if len(quantized_ops.intersection(set(op_names))) > 0:
        weight_quant_params = qnn_torch.get_weight_quant_params(script_module)
        qnn_torch.add_input_quant_params_to_op_inputs(graph)
        qnn_torch.add_quant_params_to_outputs(outputs, packed_param_map, weight_quant_params)
        qnn_torch.add_quant_params(tvm_params, weight_quant_params)
        converter.update_convert_map(qnn_torch.convert_map)

    ret = converter.convert_operators(_get_operator_nodes(graph.nodes()), outputs, ret_name)[0]
    if isinstance(ret, list):
        # ListConstruct kept original python list. Convert to tuple.
        ret = _expr.Tuple(ret)

    # Separate data inputs and parameters to make sure data inputs are always in the beginning.
    func_args = []
    data_inputs = []
    for arg in _analysis.free_vars(ret):
        if arg.name_hint not in tvm_params.keys():
            data_inputs.append(arg)
        else:
            func_args.append(arg)
    func_args = data_inputs + func_args
    mod["main"] = tvm.relay.Function(func_args, ret)
    return transform.RemoveUnusedFunctions()(mod), tvm_params

2.2.2 relay.build

This compiles the IRModule produced in the previous step. The build yields three artifacts: the graph executor graph_json, a tvm.Module, and the params of the final graph.

lib = relay.build(mod, target=target, params=params)
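A hedged usage sketch of the returned factory module; lib comes from the call above and the input name "input0" matches this example:

import numpy as np
from tvm.contrib import graph_executor

dev = tvm.cpu(0)
gmod = graph_executor.GraphModule(lib["default"](dev))
gmod.set_input("input0", tvm.nd.array(np.random.randn(1, 3, 224, 224).astype("float32")))
gmod.run()
out = gmod.get_output(0).asnumpy()  # logits of shape (1, 1000)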

python/tvm/relay/build_module.py

def build(ir_mod, target=None, target_host=None, params=None, mod_name="default"):

    """Helper function that builds a Relay function to run on TVM graph executor.

    Parameters
    ----------
    ir_mod : :py:class:`~tvm.IRModule`
        The IR module to build. Using relay.Function is deprecated.

    target : str, :any:`tvm.target.Target`, or dict of str(i.e. device/context name) to str/tvm.target.Target, optional
        For heterogeneous compilation, it is a dictionary indicating context to
        target mapping. For homogeneous compilation, it is a build target.

    target_host : str or :any:`tvm.target.Target`

    params : dict of str to NDArray
        Input parameters to the graph that do not change
        during inference time. Used for constant folding.

    mod_name: Optional[str]
        The module name we will build

    Returns
    -------
    graph_json : str
        The json string that can be accepted by graph executor.

    mod : tvm.Module
        The module containing necessary libraries.

    params : dict
        The parameters of the final graph.
    """
    
    ## ir_mod check : support tvm.IRModule

    ## target update

    ## target_host update

    # If current dispatch context is fallback context (the default root context),
    # then load pre-tuned parameters from TopHub
    if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext):
        tophub_context = autotvm.tophub.context(list(target.values()))
    else:
        tophub_context = autotvm.utils.EmptyContext()

    with tophub_context:
        bld_mod = BuildModule()
        graph_json, runtime_mod, params = bld_mod.build(mod=ir_mod, target=target, params=params)
        executor_factory = _graph_executor_factory.GraphExecutorFactoryModule(
            ir_mod, target, graph_json, runtime_mod, mod_name, params
        )
        return executor_factory


This function involves three main calls:

2.2.2.1 BuildModule

2.2.2.2 bld_mod.build

2.2.2.3 GraphExecutorFactoryModule

Each of the three is described below:

2.2.2.1 BuildModule

Calling BuildModule creates a RelayBuildModule, which inherits from runtime::ModuleNode.


python/tvm/relay/build_module.py

class BuildModule(object):
    """Build an IR module to run on TVM graph executor. This class is used
    to expose the `RelayBuildModule` APIs implemented in C++.
    """

    def __init__(self):
        self.mod = _build_module._BuildModule()

src/relay/backend/build_module.cc

TVM_REGISTER_GLOBAL("relay.build_module._BuildModule").set_body([](TVMArgs args, TVMRetValue* rv) {
  *rv = RelayBuildCreate();
});

runtime::Module RelayBuildCreate() {
  auto exec = make_object<RelayBuildModule>();
  return runtime::Module(exec);
}

/*!
 * \brief Relay build module
 *
 */
class RelayBuildModule : public runtime::ModuleNode {

2.2.2.2 bld_mod.build

bld_mod is the RelayBuildModule created in the previous step, so build here resolves to the build function of RelayBuildModule.


python/tvm/relay/build_module.py

def build(self, mod, target=None, target_host=None, params=None):
        """
       
        Returns
        -------
        factory_module : tvm.relay.backend.graph_executor_factory.GraphExecutorFactoryModule
            The runtime factory for the TVM graph executor.
        """
        target = _update_target(target)
        target, target_host = Target.check_and_update_host_consist(
            target, target_host, target_is_dict_key=False
        )

        # Setup the params.
        if params:
            self._set_params(params)

        # Build the IR module. If auto_scheduler is not enabled,
        # then use the TOPI-defined schedule.
        use_auto_scheduler = PassContext.current().config.get(
            "relay.backend.use_auto_scheduler", False
        )

        # Turn off AutoTVM config not found warnings if auto_scheduler is enabled.
        old_autotvm_silent = autotvm.GLOBAL_SCOPE.silent
        autotvm.GLOBAL_SCOPE.silent = use_auto_scheduler

        self._build(mod, target, target_host)
        autotvm.GLOBAL_SCOPE.silent = old_autotvm_silent

        # Get artifacts
        graph_json = self.get_json()
        mod = self.get_module()
        params = self.get_params()

        return graph_json, mod, params


As shown in the build function, this ultimately calls self._build(mod, target, target_host) and then returns the build artifacts graph_json, mod, and params.


src/relay/backend/build_module.cc

 /*!
   * \brief Build relay IRModule for graph executor
   *
   * \param mod Relay IRModule
   * \param target Target device
   * \param target_host Host target device
   */
  void Build(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host) {
    // Create protected variable targets_ from ground up
    targets_ = targets;
    target_host_ = target_host;
    CheckAndUpdateHostConsistency(&targets_, &target_host_);
    BuildRelay(mod, params_);
    // Clear compile engine so that tuning schedules can be changed between runs. See issue #6096.
    CompileEngine::Global()->Clear();
  }


self._build() maps to the C++ Build function above, which in turn calls BuildRelay:

/*!
   * \brief Compile a Relay IR module to runtime module.
   *
   * \param relay_module The Relay IR module.
   * \param params The parameters.
   */
  void BuildRelay(IRModule relay_module,
                  const std::unordered_map<std::string, tvm::runtime::NDArray>& params) {
    Target target_host = GetTargetHost();
    // If no target_host has been set, we choose a default one, which is
    // llvm if "codegen.LLVMModuleCreate" is accessible.
    const runtime::PackedFunc* pf = runtime::Registry::Get("codegen.LLVMModuleCreate");
    if (!target_host.defined()) target_host = (pf != nullptr) ? Target("llvm") : Target("stackvm");

    // Update all the targets in the targets_ TargetsMap
    CheckAndUpdateHostConsistency(&targets_, &target_host);

    // Relay IRModule -> IRModule optimizations.
    relay_module = Optimize(relay_module, targets_, params);
    // Get the updated function.
    auto func = Downcast<Function>(relay_module->Lookup("main"));

    // Generate code for the updated function.
    graph_codegen_ = std::unique_ptr<GraphCodegen>(new GraphCodegen());
    graph_codegen_->Init(nullptr, targets_);
    graph_codegen_->Codegen(func);

    ret_.graph_json = graph_codegen_->GetJSON();
    ret_.params = graph_codegen_->GetParams();

    auto lowered_funcs = graph_codegen_->GetIRModule();

    // Generate a placeholder function that attaches linked params as its arguments.
    if (target_host->GetAttr<Bool>("link-params").value_or(Bool(false))) {
      CHECK(pf != nullptr) << "Unable to link-params with no target_host and no llvm codegen.";
      auto param_ids = graph_codegen_->GetParamIds();
      auto link_params = Map<String, tir::LinkedParam>();
      for (auto param : ret_.params) {
        link_params.Set(param.first, tir::LinkedParam(param_ids[param.first], param.second));
      }

      Map<String, ObjectRef> dict;
      dict.Set(tvm::tir::attr::kLinkedParams, link_params);
      dict.Set(tvm::attr::kGlobalSymbol, String(::tvm::runtime::symbol::tvm_lookup_linked_param));
      DictAttrs attrs{dict};
      auto prim = tir::PrimFunc(Array<tir::Var>(), tir::SeqStmt(Array<tir::Stmt>()), VoidType(),
                                Map<tir::Var, tir::Buffer>(), attrs);
      if (lowered_funcs.find(target_host->str()) == lowered_funcs.end()) {
        lowered_funcs.Set(target_host->str(), IRModule(Map<GlobalVar, BaseFunc>({})));
      }
      lowered_funcs[target_host->str()]->Add(
          GlobalVar(::tvm::runtime::symbol::tvm_lookup_linked_param), prim);
    }

    // When there is no lowered_funcs due to reasons such as optimization.
    if (lowered_funcs.size() == 0) {
      if (target_host.defined() && target_host->kind->name == "llvm") {
        // If we can decide the target is LLVM, we then create an empty LLVM module.
        ret_.mod = (*pf)(target_host->str(), "empty_module");
      } else {
        // If we cannot decide the target is LLVM, we create an empty CSourceModule.
        // The code content is initialized with ";" to prevent complaining
        // from CSourceModuleNode::SaveToFile.
        ret_.mod = tvm::codegen::CSourceModuleCreate(";", "", Array<String>{});
      }
    } else {
      ret_.mod = tvm::build(lowered_funcs, target_host_);
    }

    auto ext_mods = graph_codegen_->GetExternalModules();
    ret_.mod = tvm::codegen::CreateMetadataModule(ret_.params, ret_.mod, ext_mods, GetTargetHost());
  }

BuildRelay compiles the IRModule down to a runtime Module. The main calls involved are:

2.2.2.2.1. Optimize

2.2.2.2.2. std::unique_ptr(new GraphCodegen()

2.2.2.2.3. graph_codegen_ Init

2.2.2.2.4. graph_codegen Codegen

2.2.2.2.5. graph_codegen: GetJSON / GetParams /GetIRModule()

2.2.2.2.6. llvm module

2.2.2.2.7. c source module

2.2.2.2.8. tvm::build(lowered_funcs, target_host_)

2.2.2.2.9. graph_codegen_->GetExternalModules()

2.2.2.2.10. tvm::codegen::CreateMetadataModule(…)

Each of these calls is described in turn below:

2.2.2.2.1. Optimize

Optimize applies the Relay-level optimization passes. The listing below omits some of the conditional checks and shows only the passes that may run.

    // Relay IRModule -> IRModule optimizations.
    relay_module = Optimize(relay_module, targets_, params);

  /*!
   * \brief Optimize a Relay IRModule.
   *
   * \param relay_module The input IRModule where optmization will be applied on.
   * \param targets The device type to `Target` mapping.
   * \param params The param name to value mapping.
   *
   * \return relay::IRModule The updated Relay IR module after optimization.
   */
  IRModule Optimize(IRModule relay_module, const TargetsMap& targets,
                    const std::unordered_map<std::string, runtime::NDArray>& params) {

    transform::RemoveUnusedFunctions(entry_functions)
    transform::ToBasicBlockNormalForm()

    // Run all dialect legalization passes.
    relay::qnn::transform::Legalize()

    // Legalize pass is restricted to homogeneous execution for now.
    transform::Legalize()
    
    transform::SimplifyInference()

    // Convert Dynamic ops to static versions
    transform::DynamicToStatic()
    transform::EliminateCommonSubexpr(fskip)
    transform::SimplifyExpr()
    transform::CombineParallelConv2D(3)
    transform::CombineParallelDense(3)
    transform::CombineParallelBatchMatmul(3)
    transform::FoldConstant()
    transform::FoldScaleAxis()
    transform::CanonicalizeCast()
    transform::CanonicalizeOps()

    // Alter layout transformation is only applied to homogeneous execution yet.
    transform::AlterOpLayout()

    // Fast math optimizations.
    transform::FastMath()
    transform::FoldConstant()

    // Fuse the operations if it is needed.
    transform::FuseOps()
    transform::DefuseOps()
    transform::FoldConstant()
    transform::Inline()

    return relay_module;
  }

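For illustration, a hedged sketch of a comparable pass pipeline driven from Python (the pass constructors mirror the C++ calls above; `mod` is assumed to be a Relay IRModule):

import tvm
from tvm import relay

# A subset of the passes applied by the C++ Optimize(); under
# PassContext(opt_level=3) all of them are enabled.
seq = tvm.transform.Sequential(
    [
        relay.transform.SimplifyInference(),
        relay.transform.EliminateCommonSubexpr(),
        relay.transform.FoldConstant(),
        relay.transform.FoldScaleAxis(),
        relay.transform.CanonicalizeOps(),
        relay.transform.AlterOpLayout(),
        relay.transform.FuseOps(),
    ]
)
with tvm.transform.PassContext(opt_level=3):
    mod = seq(mod)
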
2.2.2.2.2. std::unique_ptr(new GraphCodegen()

GraphCodegen is a thin wrapper around GraphExecutorCodegenModule (the graph executor codegen):

src/relay/backend/build_module.cc

/*!
 * \brief GraphCodegen module wrapper
 *
 */
struct GraphCodegen {
 public:
  GraphCodegen() {
    auto pf = GetPackedFunc("relay.build_module._GraphExecutorCodegen");
    mod = (*pf)();
  }

};

src/relay/backend/graph_executor_codegen.cc

TVM_REGISTER_GLOBAL("relay.build_module._GraphExecutorCodegen")
    .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = CreateGraphCodegenMod(); });


runtime::Module CreateGraphCodegenMod() {
  auto ptr = make_object<GraphExecutorCodegenModule>();
  return runtime::Module(ptr);
}

class GraphExecutorCodegenModule : public runtime::ModuleNode {
 public:
  GraphExecutorCodegenModule() {}

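Since the codegen is reached through the global registry, the same wiring can be confirmed from Python (a minimal sketch):

import tvm

# The global registered in graph_executor_codegen.cc above.
create_codegen = tvm.get_global_func("relay.build_module._GraphExecutorCodegen")
codegen_mod = create_codegen()   # runtime.Module wrapping GraphExecutorCodegenModule
init_fn = codegen_mod["init"]    # PackedFuncs are exposed through GetFunction
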
2.2.2.2.3. graph_codegen_ Init

Init here dispatches to the "init" PackedFunc of GraphExecutorCodegenModule through the GraphCodegen wrapper.


src/relay/backend/build_module.cc

  void Init(runtime::Module* m, TargetsMap targets) { CallFunc("init", m, targets); }

src/relay/backend/build_module.cc

  template <typename... Args>
  void CallFunc(const std::string& name, Args... args) {
    auto pf = mod.GetFunction(name, false);
    pf(std::forward<Args>(args)...);
    return;
  }
};

src/relay/backend/graph_executor_codegen.cc

class GraphExecutorCodegenModule : public runtime::ModuleNode {
 public:
  GraphExecutorCodegenModule() {}
  virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) {
    if (name == "init") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        ICHECK_EQ(args.num_args, 2) << "The expected of arguments are: "
                                    << "runtime::Module mod and Map<int, Target> targets";
        void* mod = args[0];
        Map<Integer, tvm::Target> tmp = args[1];
        TargetsMap targets;
        for (const auto& it : tmp) {
          auto dev_type = it.first.as<tir::IntImmNode>();
          ICHECK(dev_type);
          targets[dev_type->value] = it.second;
        }
        codegen_ = std::make_shared<GraphExecutorCodegen>(reinterpret_cast<runtime::Module*>(mod),
                                                          targets);
      });
    } 

src/relay/backend/graph_executor_codegen.cc

/*! \brief Code generator for graph executor */
class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<GraphNodeRef>> {
 public:
  GraphExecutorCodegen(runtime::Module* mod, const TargetsMap& targets) : mod_(mod) {
    compile_engine_ = CompileEngine::Global();
    targets_ = targets;
  }

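Note that the Map<Integer, Target> received by init is keyed by device type. A hedged sketch of that shape (relay.build normalizes user-facing target specs into this form before calling init):

import tvm

# 1 == kDLCPU; the init PackedFunc above reads each key as an IntImmNode.
targets = {1: tvm.target.Target("llvm")}
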
2.2.2.2.4. graph_codegen Codegen

Likewise, Codegen dispatches to the "codegen" PackedFunc of GraphExecutorCodegenModule.

    graph_codegen_->Codegen(func);

src/relay/backend/build_module.cc

  void Codegen(const Function& func) { CallFunc("codegen", func); }

src/relay/backend/build_module.cc

  template <typename... Args>
  void CallFunc(const std::string& name, Args... args) {
    auto pf = mod.GetFunction(name, false);
    pf(std::forward<Args>(args)...);
    return;
  }
};

src/relay/backend/graph_executor_codegen.cc
class GraphExecutorCodegenModule : public runtime::ModuleNode {
 public:
  GraphExecutorCodegenModule() {}
  virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) {

   } else if (name == "codegen") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        Function func = args[0];
        this->output_ = this->codegen_->Codegen(func);
      });
    } 


/*! \brief Code generator for graph executor */
class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<GraphNodeRef>> {

  LoweredOutput Codegen(relay::Function func) {
    auto pf = GetPackedFunc("relay.backend.GraphPlanMemory");
    storage_device_map_ = (*pf)(func);
    // First we convert all the parameters into input nodes.
    for (auto param : func->params) {
      auto node_ptr = GraphInputNode::make_node_ptr(param->name_hint(), GraphAttrs());
      var_map_[param.get()] = AddNode(node_ptr, param);
    }
    heads_ = VisitExpr(func->body);
    std::ostringstream os;
    dmlc::JSONWriter writer(&os);
    GetJSON(&writer);
    LoweredOutput ret;
    ret.graph_json = os.str();
    ret.params = std::unordered_map<std::string, std::pair<int, const tvm::runtime::NDArray>>();
    for (auto param : params_) {
      ret.params.emplace(std::make_pair(
          param.first,
          std::make_pair(static_cast<int>(param_storage_ids_[param.first]), param.second)));
    }

    for (auto& kv : lowered_funcs_) {
      if (ret.lowered_funcs.count(kv.first) == 0) {
        ret.lowered_funcs.Set(kv.first, IRModule(Map<GlobalVar, BaseFunc>({})));
      }
      auto& mod = ret.lowered_funcs[kv.first];
      mod->Update(kv.second);
      ret.lowered_funcs.Set(kv.first, mod);
    }
    ret.external_mods = compile_engine_->LowerExternalFunctions();
    return ret;
  }


/*! \brief Lowered outputs */
struct LoweredOutput {
  std::string graph_json;
  Map<String, IRModule> lowered_funcs;
  Array<tvm::runtime::Module> external_mods;
  std::unordered_map<std::string, std::pair<int, const tvm::runtime::NDArray>> params;
};

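The memory planner invoked at the top of Codegen is itself a registered global, so its presence can be checked from Python (a minimal sketch):

import tvm

plan_memory = tvm.get_global_func("relay.backend.GraphPlanMemory", allow_missing=True)
print(plan_memory is not None)  # False only if the relay backend is unavailable
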
2.2.2.2.5. graph_codegen: GetJSON / GetParams /GetIRModule()

GetJSON / GetParams / GetIRModule() simply return the outputs produced by the preceding Codegen call on GraphExecutorCodegenModule.


    ret_.graph_json = graph_codegen_->GetJSON();
    ret_.params = graph_codegen_->GetParams();

    auto lowered_funcs = graph_codegen_->GetIRModule();

src/relay/backend/build_module.cc

  std::string GetJSON() { return CallFunc<std::string>("get_graph_json", nullptr); }

  std::unordered_map<std::string, tvm::runtime::NDArray> GetParams() {
    std::unordered_map<std::string, tvm::runtime::NDArray> ret;
    auto names = CallFunc<Array<runtime::String>>("list_params_name", nullptr);
    for (const auto& expr : names) {
      // Implicit cast from runtime::String to std::string
      std::string key = expr;
      ret[key] = CallFunc<runtime::NDArray>("get_param_by_name", key);
    }
    return ret;
  }

  Map<String, IRModule> GetIRModule() {
    return CallFunc<Map<String, IRModule>>("get_irmodule", nullptr);
  }

src/relay/backend/graph_executor_codegen.cc
virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) {

    else if (name == "get_graph_json") {
      return PackedFunc(
          [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->output_.graph_json; });
    } else if (name == "list_params_name") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        Array<runtime::String> ret;
        for (const auto& kv : this->output_.params) {
          ret.push_back(kv.first);
        }
        *rv = ret;
      });
    } else if (name == "get_param_by_name") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        String key = args[0];
        auto it = this->output_.params.find(key);
        CHECK(it != this->output_.params.end()) << "no such parameter " << key;
        *rv = (*it).second.second;
      });
    } else if (name == "get_irmodule") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        *rv = this->output_.lowered_funcs;
      });


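On the Python side these outputs surface on the GraphExecutorFactoryModule returned by relay.build (section 2.2.2.3). A minimal sketch, assuming `lib` is that factory module:

import json

graph = json.loads(lib.get_json())  # the graph_json produced by GetJSON()
print(len(graph["nodes"]))          # graph nodes; inputs are added first (see Codegen above)
params = lib.get_params()           # dict of name -> tvm.runtime.NDArray, from GetParams()
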
2.2.2.2.6. llvm module

If lowered_funcs is empty and the target host is LLVM, the following path creates an (empty) LLVMModule, entering the LLVM codegen flow.

src/relay/backend/build_module.cc

    const runtime::PackedFunc* pf = runtime::Registry::Get("codegen.LLVMModuleCreate");

    ret_.mod = (*pf)(target_host->str(), "empty_module");

src/target/llvm/llvm_module.cc

TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate")
    .set_body_typed([](std::string target_str, std::string module_name) -> runtime::Module {
      Target target = Target(target_str);
      auto n = make_object<LLVMModuleNode>();
      // Generate a LLVM module from an input target string
      InitializeLLVM();
      auto tm = GetLLVMTargetMachine(target);
      auto ctx = std::make_shared<llvm::LLVMContext>();
      std::unique_ptr<llvm::Module> module(new llvm::Module(module_name, *ctx));
      // Use a default data layout and target triple
      auto triple = tm->getTargetTriple();
      module->setTargetTriple(triple.str());
      module->setDataLayout(tm->createDataLayout());
      n->Init(std::move(module), ctx);
      return runtime::Module(n);
    });
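For illustration, this registered global can be exercised directly from Python, mirroring what BuildRelay does in the empty-module case (a hedged sketch):

import tvm

llvm_create = tvm.get_global_func("codegen.LLVMModuleCreate", allow_missing=True)
if llvm_create is not None:  # None when TVM was built without LLVM
    empty_mod = llvm_create("llvm", "empty_module")
    print(empty_mod.type_key)  # "llvm"
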
2.2.2.2.7. c source module

Otherwise, when the target host cannot be determined to be LLVM, the following path creates an empty CSourceModule, entering the C source codegen flow.

  src/relay/backend/build_module.cc

  ret_.mod = tvm::codegen::CSourceModuleCreate(";", "", Array<String>{});

  src/target/source/source_module.cc

  runtime::Module CSourceModuleCreate(const String& code, const String& fmt,
                                    const Array<String>& func_names,
                                    const Array<String>& const_vars) {
  auto n = make_object<CSourceModuleNode>(code.operator std::string(), fmt.operator std::string(),
                                          func_names, const_vars);
  return runtime::Module(n);
  }

src/target/source/source_module.cc
// Simulator function
class CSourceModuleNode : public runtime::ModuleNode {
 public:
  CSourceModuleNode(const std::string& code, const std::string& fmt,
                    const Array<String>& func_names, const Array<String>& const_vars)
      : code_(code), fmt_(fmt), const_vars_(const_vars), func_names_(func_names) {}

 protected:
  std::string code_;
  std::string fmt_;
  Array<String> const_vars_;
  Array<String> func_names_;
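A quick way to confirm this path is available is to look up the corresponding global from Python; the global name is an assumption based on source_module.cc (it wraps the four-argument CSourceModuleCreate shown above):

import tvm

csource_create = tvm.get_global_func("runtime.CSourceModuleCreate", allow_missing=True)
print(csource_create is not None)
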
2.2.2.2.8. tvm::build(lowered_funcs, target_host_)

Finally, the normal case is when lowered_funcs is not empty (per the code comment, it can end up empty when optimizations eliminate every function, which the two cases above handle). Here tvm::build compiles each input IRModule for its corresponding target, following the same build flow as in example 1.

src/driver/driver_api.cc
// Build for heterogeneous execution when target is a string.
runtime::Module build(const Map<String, IRModule>& inputs_arg, const Target& target_host_arg) {
  Map<Target, IRModule> updated_inputs;
  Target target_host = target_host_arg;
  for (const auto& it : inputs_arg) {
    Target target = Target(it.first);
    CheckAndUpdateHostConsistency(&target, &target_host);
    Optional<String> device = target->GetAttr<String>("device");
    if (device.defined() && device.value() == "vta") {
      target = Target("ext_dev");
    }
    updated_inputs.Set(target, it.second);
  }
  return build(updated_inputs, target_host);
}


// Build for heterogeneous execution.
runtime::Module build(const Map<Target, IRModule>& inputs_arg, const Target& target_host_arg) {
  auto pass_ctx = transform::PassContext::Current();

  std::vector<runtime::Module> device_modules;
  Map<Target, IRModule> inputs = inputs_arg;
  Target target_host = target_host_arg;

  // Fetch previous defined target host in targets
  CheckAndUpdateHostConsistency(&inputs, &target_host);

  if (!target_host.defined()) {
    for (const auto& it : inputs) {
      if (it.first->kind->device_type == kDLCPU || it.first->kind->device_type == kDLMicroDev) {
        target_host = it.first;
        break;
      }
    }
  }

  if (!target_host.defined()) {
    target_host = DefaultTargetHost(target_host);
  }

  // Update target host for all targets
  CheckAndUpdateHostConsistency(&inputs, &target_host);

  IRModule mhost_all = IRModule(Map<GlobalVar, BaseFunc>());

  ICHECK(mhost_all.defined()) << "The host module must be defined";

  for (const auto& it : inputs) {
    if (it.second.defined()) {
      auto pair = SplitDevHostFuncs(it.second, it.first, target_host, pass_ctx);
      auto& mhost = pair.first;
      auto& mdevice = pair.second;

      ICHECK(mhost.defined()) << "The split host module must be defined";

      ICHECK(mhost_all.defined()) << "The host module must be defined";

      mhost_all->Update(mhost);

      if (mdevice->functions.size() != 0) {
        device_modules.push_back(codegen::Build(mdevice, it.first));
      }
    }
  }


src/target/codegen.cc
runtime::Module Build(IRModule mod, Target target) {
  if (transform::PassContext::Current()
          ->GetConfig<Bool>("tir.disable_assert", Bool(false))
          .value()) {
    mod = tir::transform::SkipAssert()(mod);
  }
  std::string build_f_name;
  if (target->kind->name == "micro_dev") {
    build_f_name = "target.build.c";
  } else {
    build_f_name = "target.build." + target->kind->name;
  }
  LOG(INFO) << "build_f_name: "<< build_f_name;
  // the build function.
  const PackedFunc* bf = runtime::Registry::Get(build_f_name);
  LOG(INFO) << "(bf != nullptr) : "<< (bf != nullptr);
  ICHECK(bf != nullptr) << build_f_name << " is not enabled";
  return (*bf)(mod, target);
}

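The per-target build functions resolved through build_f_name are ordinary registered globals, so their presence can be checked from Python (a minimal sketch):

import tvm

# "target.build.llvm" is what codegen::Build resolves for an LLVM target;
# the lookup returns None when TVM was built without the LLVM backend.
build_llvm = tvm.get_global_func("target.build.llvm", allow_missing=True)
print("llvm backend enabled:", build_llvm is not None)
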
2.2.2.2.9. graph_codegen_->GetExternalModules()

src/relay/backend/build_module.cc

  Array<tvm::runtime::Module> GetExternalModules() {
    return CallFunc<Array<tvm::runtime::Module>>("get_external_modules", nullptr);
  }

src/relay/backend/graph_executor_codegen.cc
virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) {

    } else if (name == "get_external_modules") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        *rv = this->output_.external_mods;
      });

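Any external (e.g. BYOC) modules collected here end up imported into the final runtime module. A minimal sketch to inspect them, assuming `lib` is the GraphExecutorFactoryModule from relay.build:

# External modules appear among the imports of the underlying runtime module.
print([m.type_key for m in lib.get_lib().imported_modules])
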
2.2.2.2.10. tvm::codegen::CreateMetadataModule(…)

tvm::codegen::CreateMetadataModule(ret_.params, ret_.mod, ext_mods, GetTargetHost())

src/target/metadata_module.cc

/*!
 * \brief Create a metadata module wrapper. The helper is used by different
 *        codegens, such as graph executor codegen and the vm compiler.
 * \return The created metadata module that manages initialization of metadata.
 */
runtime::Module CreateMetadataModule(
    const std::unordered_map<std::string, runtime::NDArray>& params,
    tvm::runtime::Module target_module, const Array<runtime::Module>& ext_modules, Target target) {
 
    // 1. target c 
      target_module = CreateCSourceCrtMetadataModule(crt_exportable_modules, target);

    // 2. target llvm
      target_module = CreateLLVMCrtMetadataModule(crt_exportable_modules, target);

    // 3. others
    runtime::Module binary_meta_mod = runtime::MetadataModuleCreate(params, sym_metadata);
    binary_meta_mod.Import(target_module);
    for (const auto& it : non_crt_exportable_modules) {
      binary_meta_mod.Import(it);
    }

  return target_module;
}

2.2.2.3 GraphExecutorFactoryModule

class GraphExecutorFactoryModule:
    """Graph executor factory module.
    This is a module of graph executor factory

    Parameters
    ----------
    graph_json_str : str
        The graph to be deployed in json format output by graph compiler.
        The graph can contain operator(tvm_op) that points to the name of
        PackedFunc in the libmod.
    target : tvm.Target
        The Target used to build this module.
    libmod : tvm.Module
        The module of the corresponding function
    libmod_name: str
        The name of module
    params : dict of str to NDArray
        The parameters of module
    """

    def __init__(self, ir_mod, target, graph_json_str, libmod, libmod_name, params):
        assert isinstance(graph_json_str, string_types)
        fcreate = get_global_func("tvm.graph_executor_factory.create")
        args = []
        for k, v in params.items():
            args.append(k)
            args.append(ndarray.array(v))
        self.ir_mod = ir_mod
        self.target = target
        self.module = fcreate(graph_json_str, libmod, libmod_name, *args)
        self.graph_json = graph_json_str
        self.lib = libmod
        self.libmod_name = libmod_name
        self.params = params
        self.iter_cnt = 0

    def export_library(self, file_name, fcompile=None, addons=None, **kwargs):
        return self.module.export_library(file_name, fcompile, addons, **kwargs)

    # Sometimes we want to get params explicitly.
    # For example, we want to save its params value to
    # an independent file.
    def get_params(self):
        return self.params

    def get_json(self):
        return self.graph_json

    def get_lib(self):
        return self.lib

    def __getitem__(self, item):
        return self.module.__getitem__(item)

    def __iter__(self):
        warnings.warn(
            "legacy graph executor behavior of producing json / lib / params will be "
            "removed in the next release."
            " Please see documents of tvm.contrib.graph_executor.GraphModule for the "
            " new recommended usage.",
            DeprecationWarning,
            2,
        )
        return self

    def __next__(self):
        if self.iter_cnt > 2:
            raise StopIteration

        objs = [self.graph_json, self.lib, self.params]
        obj = objs[self.iter_cnt]
        self.iter_cnt += 1
        return obj


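A minimal end-to-end usage sketch for the factory module (the file name and input name are illustrative; `lib` is the result of relay.build from section 2.2.2):

import numpy as np
import tvm
from tvm.contrib import graph_executor

lib.export_library("deploy_lib.so")                  # delegates to module.export_library
dev = tvm.cpu(0)
m = graph_executor.GraphModule(lib["default"](dev))  # "default" is the default libmod_name
m.set_input("input0", tvm.nd.array(np.zeros((1, 3, 224, 224), "float32")))  # hypothetical input name
m.run()
out = m.get_output(0).numpy()
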

Glossary

  1. IR: Intermediate Representation

  2. Relay IR: the second generation of NNVM (Neural Network Virtual Machine)

  3. TIR: Tensor-level IR. TIR contains the definitions of the low-level program representations.

  4. TE: Tensor Expression

  5. TOPI: Tensor Operator Inventory. Provides a set of pre-defined operators (in TE or TIR) defined by numpy and found in common deep learning workloads.

  6. Relay: the high-level functional IR used to represent full models.
