MNN 核心类介绍

深入分析MNN框架的核心类设计理念、实现细节以及它们之间的相互关系

MNN 介绍

MNN框架作为一个高性能的深度学习推理引擎,核心设计围绕着几个关键的抽象类展开。其中,Varp,Expr,Op等类是整个框架的基石,它们不仅定义了数据的表示方式,还构建了计算图的基本结构。本文将深入分析这些核心类的设计理念、实现细节以及它们之间的相互关系。

1. MNN核心类

MNN 中的核心类主要包括:

类名 职责
VARP 智能指针包装的 Variable,表示表达式中的变量节点
Variable 表达式图中的变量节点,持有张量数据或计算信息
Expr 表达式边,表示一个计算操作 以及 输入节点和输出节点
Tensor 张量数据容器,存储实际的多维数据
Op 算子描述符,定义计算的类型和参数、模型权重等
Pipeline 计算流水线,TODO:
Session 执行会话,推理数据的持有者, TODO
Interpreter 解释器, 模型数据的持有者
Executor 执行器, TODO
ExecutorScope 执行器作用域, TODO

1.1 类之间关系

  • **逻辑依赖关系: **VARP和Expr是MNN计算图格式的核心,分别表示计算图的节点和边,Op类是整个计算图的核心,表示计算图的边Expr的计算操作。关系图如下:

    1
    2
    3
    4
    5
    6
    7
    
    VARP (Variable Ptr)
        ↓ 指向
    Variable
        ↓ 包含
    Expr ──→ Op (操作描述)
    Tensor (存储数据)
    

    每个边Expr都有属性std::vector<VARP> mInputs;指示输入来自的节点; 每个VARP节点有属性EXPRP mFrom; // typedef std::shared_ptr<Expr> EXPRP;指示输入的边,同时VARP的节点的Tensor数据也存放在入边mFrom中,数据读取和写入在Variable::readMapVariable::writeMap中,需要转换成对应数据格式的指针进行读取/写入。一个简单的计算图可以如下:

     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    
    Expr (mFrom) -> Tensor x(data)        Expr (mFrom) -> Tensor y(data) 
             	│									    │
             	│									    │
             VARP x                    				  VARP y
                ↓                        	  		    ↓
                └────────────────────┬──────────────────┘
                    				 │ Expr -------> Op(Add)
    	                             │   └─--------> Tensor z
    	             			  VARP z
    
  • 执行依赖关系: MNN的计算图有两种模式,Defer(延迟计算)模式或Eager(立即计算)模式:Defer模式下,调用表达式相关API不直接计算,而是搭建模型,在需要获取输出值时才执行;Eager模式下,直接进行计算,对应地无法搭建模型。 下面以Eager模式为例,梳理MNN中表达式计算顺序,更具体的代码分析见readmap详解:

     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    13
    14
    15
    16
    17
    
    用户定义VARP计算代码 // 如:VARP x = _Input({2, 3});VARP y = _Input({2, 3});VARP z = _Add(x, y);
        
    varp->readmap()  // 触发计算
        
    ExecutorScope::Current()->computeInfo() //计算节点信息  ─→ SizeComputer::computeOutputSize() // 动态计算中间节点/输出的形状
        
    ExecutorScope::Current()->makeCache() // 计算缓存, 可复用, 按数据依赖顺序准备中间节点的Tensor 以及Session会话的计算信息
        
    Executor::ComputeCache::compute() // 按数据依赖顺序执行计算缓存中: 算子的resize和执行 算子执行会从计算缓存的mSession进入下一步
        
    Session::run() // 这里继续进入Session::mPipelines的执行
        
    Pipeline::execute() // 这里执行execution->onExecute 进入算子的后端执行
        
    后端执行计算	// 计算结果存在tensor中 最后从readmap读取为指针
        
    获得指针
    

    这当中主要配置的信息是在ExecutorScope::Current()->makeCache()中配置Session,后续的执行大多依赖Session的数据

1.2 VARP类

VARP本质是 Variable 的智能指针包装类,利用Variable指针的地址重载了比较运算符,核心代码如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
// express/Expr.cpp
class MNN_PUBLIC VARP {
public:
    // 构造函数
    VARP();
    VARP(std::shared_ptr<Variable> c);
    VARP(Variable* c);
    
    // 获取内部 Variable 变量
    Variable* get() const;
    
    // 重载运算符
    bool operator==(const VARP& var) const = 0;
    bool operator<(const VARP& var) const = 0;
    bool operator<=(const VARP& var) const = 0;
	
    // ... 其它代码  

    // 类型标记
    enum InputType {
        INPUT = 0,      // 输入变量
        CONSTANT = 1,   // 常量变量
        TRAINABLE = 2   // 可训练变量
    };
    // 固定变量类型
    bool fix(InputType type) const;
    // 设置数据格式
    void setOrder(Dimensionformat format); // enum Dimensionformat { NHWC, NC4HW4, NCHW };
    
private:
    std::shared_ptr<Variable> mContent;
};

1.2.1 Variable 类

描述算子图的一个节点,主要刻画了对节点和数据的部分操作,核心代码如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
// express/Expr.cpp
class MNN_PUBLIC Variable {
public:
    // 获取节点名称
    const std::string& name() const;
    void setName(const std::string& name);
    
    // 获取数据信息
    const Info* getInfo();
	struct Info {
        Dimensionformat order;  // 数据格式(NCHW/NHWC/NC4HW4)
        INTS dim;              // 维度
        halide_type_t type;    // 数据类型
        size_t size;           // 元素总数
        void syncSize();       // 同步大小
    };
    
    // 设置和获取(输入边 以及 数据是“输入边”的第几个输出Tensor)
	void setExpr(EXPRP expr, int index);
    std::pair<EXPRP, int> expr() const;
    
    // 获取 Tensor
    const Tensor* getTensor() const;
    
    // 调整大小
    bool resize(INTS dims);
    
    // 读取/保存节点到文件
	static std::vector<VARP> load(const char* fileName);
    static void save(const std::vector<VARP>& vars, NetT* dest);
    
    // 获得数据的指针映射
    template <typename T>
    const T* readMap();
    template <typename T>
    T* writeMap();
};

1.2.2 readMap详解

接下来详细解释readMap是怎么执行的,TODO

1.3 Expr类

Expr表示计算图的边,核心属性包括边的输入(VARP) 输出(Tensor) 以及计算算子(Op),这里的输出是Tensor格式,被输出节点VARP通过边引用,核心代码如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
// express/Expr.cpp
class MNN_PUBLIC Expr {
public:
    struct Inside;
    enum MemoryType {
        COPY,
        MOVE,
        REF
    };
    // 多种构造函数 使用各种信息构造Expr类
    static EXPRP create(Tensor* tensor, bool own = false);
    static EXPRP create(Variable::Info&& info, const void* ptr, VARP::InputType type, MemoryType copy = COPY);
    static EXPRP create(const OpT* op, std::vector<VARP> inputs, int outputSize = 1);
    static EXPRP create(std::shared_ptr<BufferStorage> extra, std::vector<VARP>&& inputs, int outputSize = 1);
    static EXPRP create(std::unique_ptr<OpT>&& op, std::vector<VARP> inputs, int outputSize = 1);
    
    // 获取/设置节点信息
    const Op* get() const;
    const std::vector<VARP>& inputs() const = 0;
    int outputSize() const = 0;
    void setName(const std::string& name);
    const std::string& name() const = 0;
    const std::string& outputName(int index);
    static void replace(EXPRP oldExpr, EXPRP newExpr);
    VARP::InputType inputType() const;
    
    // 判断当前操作依赖的输入节点 以及输入节点往前依赖的所有节点的信息是否正确, 过程中会推导节点的shape
    bool requireInfo();
    
    // 遍历,过程中会调用用户传入的 回调函数
    void visitOutputs(const std::function<bool(EXPRP, int)>& visit);
    static void visit(EXPRP expr, const std::function<bool(EXPRP)>& before, const std::function<bool(EXPRP)>& after);
   

private:
    const Op* mOp;				// 表示的计算算子
    std::vector<VARP> mInputs;	// 输入节点
	
    // 节点信息
    VARP::InputType mType;
    std::string mName;
    std::vector<std::string> mOutputNames;
    // 其中inside记录了很多更多有关边的信息
    std::shared_ptr<Inside> mInside = nullptr;
    struct Expr::Inside {
		// 构造 析构函数
        Inside(int outputSize);
        Inside(Tensor* tensor, bool own = false);
        ~ Inside();
        
        std::vector<Variable::Info> mOutputInfos;	// 输出tensor的信息 与tensor一一对应
        std::vector<Tensor*> mOutputTensors;	// 边对应的输出tensor	
        
        std::shared_ptr<Executor::ComputeCache> mCache; // 计算cache缓存 在1.1 类之间关系提到 通过cache 
        int mCacheOffset = 0;	// 用来索引tensor
        
        // 标记位
        Executor::Requirement mReq;	
        bool mInfoDirty = true;
        bool mContentDirty = true;
        bool mOwnTensor = true;
        
        Tensor* mHostTensor = nullptr;			// 设备内存(GPU/NPU)中的tensor
        std::shared_ptr<Backend> mHoldBackend;	// 持有数据的后端
    };
};

1.3.1 requireInfo 详解

TODO:

1.4 Tensor类

张量数据类,包括数据指针,数据格式,数据维度等信息,所有属性存储在下面两个结构体对象中,

1
2
3
4
5
6
7
// include/MNN/Tensor.hpp
class MNN_PUBLIC Tensor{
// 其它代码
private:
    halide_buffer_t mBuffer;
    struct InsideDescribe* mDescribe;
};

1.4.1 数据格式

MNN支持常见的数据格式,其中N C H W分别表示 批次大小 通道数 高度 宽度。这个格式是对图片格式的兼容,在大模型推理中输入embedding的shape是(batch_size, seq_len, hidden_size),依次是输入、序列长度和隐藏层大小,不需要太关注这个NCHW,(大模型推理过程中的shape变化可以参考这篇介绍)。

1
2
3
4
5
6
7
// include/MNN/Tensor.hpp
// 维度类型
    enum DimensionType {
        TENSORFLOW,  // TensorFlow 格式:NHWC
        CAFFE,       // Caffe 格式:NCHW
        CAFFE_C4     // Caffe 格式:NC4HW4(4通道对齐)
    };

值得一提的是 MNN默认只支持1条数据输入,输入embedding维度是(seq_len, hidden_size),但是除了Attention算子的大部分算子都做了batch维度的适配,对MNN中大模型推理感兴趣可以看这篇介绍

我尝试在MNN框架上做了Chunk prefill,把不同输入请求合并在seq_len维度上,并在Attention算子中展开:传送门

1.4.2 底层数据格式

从更底层出发,这个数据格式信息由下面结构体存储,对应Tensor类的halide_buffer_t mBuffer属性,其中存储了数据的指针、数据类型、数据维度等信息,核心代码如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
// include/MNN/HalideRuntime.h
/**
 * The raw representation of an image passed around by generated
 * Halide code. It includes some stuff to track whether the image is
 * not actually in main memory, but instead on a device (like a
 * GPU). For a more convenient C++ wrapper, use Halide::Buffer<T>. */
typedef struct halide_buffer_t {
    /** A device-handle for e.g. GPU memory used to back this buffer. */
    uint64_t device; // 设备句柄

    /** The interface used to interpret the above handle. */
    const struct halide_device_interface_t *device_interface; // 接口指针

    /** A pointer to the start of the data in main memory. In terms of
     * the Halide coordinate system, this is the address of the min
     * coordinates (defined below). */
    uint8_t* host; // 指针 指向数据

    /** flags with various meanings. */
    uint64_t flags;	

    /** The type of each buffer element. */
    struct halide_type_t type;	// 数据类型

    /** The dimensionality of the buffer. */
    int32_t dimensions;	// 数据维度,如[batch_size, seq_len, hidden_size]大小的Tensor数据就是2维

    /** The shape of the buffer. Halide does not own this array - you
     * must manage the memory for it yourself. */
    halide_dimension_t *dim; // 数据各维度的数值, 如如[batch_size, seq_len, hidden_size]大小的Tensor数据 dim[0]就表示batch_size维度的信息

    /** Pads the buffer up to a multiple of 8 bytes */
    void *padding; // 用来对齐内存
} halide_buffer_t;

其中的数据类型halide_type_t type通过数据占比特数和数据的性质 判断数据的类型,例如code = halide_type_float并且bits = 16表示半精度浮点数。核心代码如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
// include/MNN/HalideRuntime.h
/** A runtime tag for a type in the halide type system. Can be ints,
 * unsigned ints, or floats of various bit-widths (the 'bits'
 * field). Can also be vectors of the same (by setting the 'lanes'
 * field to something larger than one). This struct should be
 * exactly 32-bits in size. */
struct halide_type_t {
    /** The basic type code: signed integer, unsigned integer, or floating point. */
    
    // 这个code表示数据的性质 见本代码块最下方结构体,例如 halide_type_int = 0, 表示 signed integers 有符号整型
#ifndef _MSC_VER
    HALIDE_ATTRIBUTE_ALIGN(1) halide_type_code_t code; // halide_type_code_t
#else
    HALIDE_ATTRIBUTE_ALIGN(1) uint8_t code; // halide_type_code_t
#endif

    /** The number of bits of precision of a single scalar value of this type. */
    HALIDE_ATTRIBUTE_ALIGN(1) uint8_t bits;	// 数据占用比特数

    /** How many elements in a vector. This is 1 for scalar types. */
    HALIDE_ATTRIBUTE_ALIGN(2) uint16_t lanes; // 一次处理的数据宽度 用于SIMD

    // 构造函数
#ifdef __cplusplus
    /** Construct a runtime representation of a Halide type from:
     * code: The fundamental type from an enum.
     * bits: The bit size of one element.
     * lanes: The number of vector elements in the type. */
    HALIDE_ALWAYS_INLINE halide_type_t(halide_type_code_t code, uint8_t bits, uint16_t lanes = 1)
        : code(code), bits(bits), lanes(lanes) {
    }
    /** Default constructor is required e.g. to declare halide_trace_event
     * instances. */
    HALIDE_ALWAYS_INLINE halide_type_t() : code((halide_type_code_t)0), bits(0), lanes(0) {}

    // 重载了对比函数
    /** Compare two types for equality. */
    HALIDE_ALWAYS_INLINE bool operator==(const halide_type_t &other) const {
        return (code == other.code &&
                bits == other.bits &&
                lanes == other.lanes);
    }
    HALIDE_ALWAYS_INLINE bool operator!=(const halide_type_t &other) const {
        return !(*this == other);
    }
    
	// 单个数据占据的内存字节数, 按8bit向上对齐
    /** Size in bytes for a single element, even if width is not 1, of this type. */
    HALIDE_ALWAYS_INLINE int bytes() const { return (bits + 7) / 8; }
#endif
};


typedef enum halide_type_code_t
{
    halide_type_int = 0,   //!< signed integers
    halide_type_uint = 1,  //!< unsigned integers
    halide_type_float = 2, //!< IEEE floating point numbers
    halide_type_handle = 3, //!< opaque pointer type (void *)
    halide_type_bfloat = 4  //!< floating point numbers in the bfloat format
} halide_type_code_t;

这里数据维度信息halide_dimension_t *dim 包含该维度的元素个数(extend)和在该维度移动一步时内存地址的偏移量信息(stride,通常在TensorUtils::setLinearLayout中利用extend计算stride)。核心代码如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
// include/MNN/HalideRuntime.h
typedef struct halide_dimension_t { 
    // extend: 该维度的元素个数
    // stride: 在该维度移动一步时内存地址的偏移量信息
    int32_t min, extent, stride;

    // Per-dimension flags. None are defined yet (This is reserved for future use).
    uint32_t flags;

#ifdef __cplusplus
    HALIDE_ALWAYS_INLINE halide_dimension_t() : min(0), extent(0), stride(0), flags(0) {}
    HALIDE_ALWAYS_INLINE halide_dimension_t(int32_t m, int32_t e, int32_t s, uint32_t f = 0) :
        min(m), extent(e), stride(s), flags(f) {}

    HALIDE_ALWAYS_INLINE bool operator==(const halide_dimension_t &other) const {
        return (min == other.min) &&
            (extent == other.extent) &&
            (stride == other.stride) &&
            (flags == other.flags);
    }

    HALIDE_ALWAYS_INLINE bool operator!=(const halide_dimension_t &other) const {
        return !(*this == other);
    }
#endif
} halide_dimension_t;

1.4.2 核心接口

Tensor类的核心接口包括设置/获取数据信息、把数据映射到执行设备、调整大小等,MNN文档有详细介绍, 需要用到时可以自己查看,比较常用的调试接口是打印数据和打印形状,这里打印数据中有根据数据底层的bits和code信息自动转换成对应数据类型的指针进行打印的转化。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
// include/MNN/Tensor.hpp
class MNN_PUBLIC Tensor {
public:
    /**
        * @brief print tensor data. for DEBUG use only.
        */
    void print() const;

    /**
        *@brief print tensor shape
        */
    void printShape() const;
}

1.5 Op类

Op 是算子的描述类,定义了神经网络中各种操作的类型和参数。MNN 使用内存高效的FlatBuffers库来序列化/反序列化来表示Op等信息,核心结构体是OpT,在需要修改算子信息时通常使用该结构体,通常代码中使用的就是该结构体,核心代码如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
// schema/current/MNN_generated.h
   struct OpT : public flatbuffers::NativeTable {
     std::vector<int32_t> inputIndexes;      // 输入张量索引列表
     OpParameterUnion main;                   // 算子参数(联合体)
     std::string name;                        // 算子名称
     std::vector<int32_t> outputIndexes;     // 输出张量索引列表
     OpType type;                             // 算子类型枚举
     MNN_DATA_FORMAT defaultDimentionFormat; // 默认数据格式(NHWC/NCHW等)
     std::string externalPath;               // 外部权重路径
   };

这里OpType 是一个枚举类型,用于表示不同的算子类型,例如 Conv2D, Add, Relu 等。MNN 中定义了多个算子类型,每个类型对应一个具体的算子实现,部分代码如下:

1
2
3
4
5
6
7
8
9
// schema/current/MNN_generated.h
enum OpType {
    OpType_AbsVal = 0,
    OpType_QuantizedAdd = 1,
    OpType_ArgMax = 2,
    OpType_AsString = 3,
    OpType_InstanceNorm = 4,
    // ...
};

继续往底层是flatbuffers::Table的数据接口,提供了获取输入索引、算子类型等信息的接口,可以通过unpack操作转换为上层的OpT结构体,主要在读取、解析、构建算子的Execution时用到Op结构体,核心代码如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
struct Op FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
  typedef OpT NativeTableType;
  static const flatbuffers::TypeTable *MiniReflectTypeTable() {
    return OpTypeTable();
  }
  // 获取输入索引列表
  const flatbuffers::Vector<int32_t> *inputIndexes() const {
    return GetPointer<const flatbuffers::Vector<int32_t> *>(4);
  }
  // 获取参数类型
  OpParameter main_type() const {
    return static_cast<OpParameter>(GetField<uint8_t>(6, 0));
  }
  // 获取指针
  const void *main() const {
    return GetPointer<const void *>(8);
  }
  
  // 通过一堆main_as_XXX函数根据算子类型转换成对应的参数结构体指针,例如main_as_ArgMax算子就转换成ArgMax结构体指针
  template<typename T> const T *main_as() const;
  const QuantizedAdd *main_as_QuantizedAdd() const {
    return main_type() == OpParameter_QuantizedAdd ? static_cast<const QuantizedAdd *>(main()) : nullptr;
  }
  const ArgMax *main_as_ArgMax() const {
    return main_type() == OpParameter_ArgMax ? static_cast<const ArgMax *>(main()) : nullptr;
  }
  // ... 其它代码

  // 通过FlatBuffers进行序列化和反序列化
  OpT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
  void UnPackTo(OpT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
  static flatbuffers::Offset<Op> Pack(flatbuffers::FlatBufferBuilder &_fbb, const OpT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
};

接下来各个算子都有一个自己的参数解析类,通过判断main_type()的参数类型来转换成对应的参数结构体指针,例如main_as_ArgMax()函数会判断参数类型是否是ArgMax,如果是就转换成ArgMax结构体指针,这里OpParameter是一个枚举类型,用于表示不同的算子参数类型,例如 QuantizedAdd, ArgMax, InstanceNorm 等,部分代码如下:

1
2
3
4
5
6
7
// schema/current/MNN_generated.h
enum OpParameter {
  OpParameter_NONE = 0,
  OpParameter_QuantizedAdd = 1,
  OpParameter_ArgMax = 2,
  // ... 其它代码 
};

ArgMax为例,算子可以通过main_as_XXX转换成对应算子读取时使用的底层Flatbuffer结构体,可以继续通过UnPack函数转换成修改/写入时用到ArgMaxT结构体,核心代码如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
// schema/current/CaffeOp_generated.h
struct ArgMax FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
  typedef ArgMaxT NativeTableType;
  static const flatbuffers::TypeTable *MiniReflectTypeTable() {
    return ArgMaxTypeTable();
  }
  // 获取各种算子的参数
  int32_t outMaxVal() const {
    return GetField<int32_t>(4, 0);
  }
  int32_t topK() const {
    return GetField<int32_t>(6, 0);
  }
  int32_t axis() const {
    return GetField<int32_t>(8, 0);
  }
  int32_t softmaxThreshold() const {
    return GetField<int32_t>(10, 0);
  }
  bool Verify(flatbuffers::Verifier &verifier) const {
    return VerifyTableStart(verifier) &&
           VerifyField<int32_t>(verifier, 4) &&
           VerifyField<int32_t>(verifier, 6) &&
           VerifyField<int32_t>(verifier, 8) &&
           VerifyField<int32_t>(verifier, 10) &&
           verifier.EndTable();
  }
  // 序列化和反序列化
  ArgMaxT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
  void UnPackTo(ArgMaxT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
  static flatbuffers::Offset<ArgMax> Pack(flatbuffers::FlatBufferBuilder &_fbb, const ArgMaxT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
};

反序列化后的ArgMaxT结构体就是我们平时上层使用的算子参数结构体,属性包含了算子参数的具体数值,例如outMaxVal, topK, axis, softmaxThreshold等,核心代码如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
struct ArgMaxT : public flatbuffers::NativeTable {
  typedef ArgMax TableType;
  int32_t outMaxVal;
  int32_t topK;
  int32_t axis;
  int32_t softmaxThreshold;
  ArgMaxT()
      : outMaxVal(0),
        topK(0),
        axis(0),
        softmaxThreshold(0) {
  }
};

Op类再往上层就是后端算子的具体实现,通过继承Execution类实现,该类解释可见后端介绍,例如:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
// source/backend/cpu/CPUArgMax.hpp
class CPUArgMax : public Execution {
public:
    enum ArgMinOrMax {
        ARGMIN,
        ARGMAX
    };
    CPUArgMax(Backend *backend, ArgMinOrMax mode, int topk, int outMaxVal, int softmaxThreshold, int axis);
    virtual ~CPUArgMax() = default;
    virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
    virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;

private:
    Tensor mInputBuffer;
    Tensor mOutputBuffer;
    int mTopk;
    int mOutMaxVal;
    int mSoftmaxThreshold;
    int mAxis;
    int mNum;
    int mDim;
    int mKeyExtent;
    bool mFromNHWC;
    ArgMinOrMax mMode;
};

在构造不同后端的算子实现时,会把Op结构体中的参数传入到成员变量中,例如:topK, outMaxVal, softmaxThreshold, axis等,后续在执行时就可以直接使用这些参数进行计算了,核心代码如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
// source/backend/cpu/CPUArgMax.cpp
class CPUArgMaxCreator : public CPUBackend::Creator {
public:
    virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
                                const MNN::Op *op, Backend *backend) const {
        auto argMax = op->main_as_ArgMax();
        if (op->type() == OpType_ArgMin) {
            return new CPUArgMax(backend, CPUArgMax::ArgMinOrMax::ARGMIN,
                    argMax->topK(), argMax->outMaxVal(), argMax->softmaxThreshold(), argMax->axis());
        } else {
            return new CPUArgMax(backend, CPUArgMax::ArgMinOrMax::ARGMAX,
                    argMax->topK(), argMax->outMaxVal(), argMax->softmaxThreshold(), argMax->axis());
        }
    }
};

总结一下:读取算子时通常用Op结构体,修改/写入时用OpT结构体。阅读代码时可以查看带T的结构体名称,例如ArgMaxTConvolutionT,便于阅读。

MNN中会把常见的线性层转换为卷积Convolution算子,常用算子UnaryOpBinaryOp分别表示一元和二元算子。

常用接口:可以使用EnumNameXXX方式获取OpType、OpParameter等枚举类型的字符串名称,便于调试,例如:

1
2
3
4
5
6
// schema/current/MNN_generated.h
inline const char *EnumNameOpType(OpType e);
inline const char *EnumNameOpParameter(OpParameter e);

// schema/current/TensorflowOp_generated.h
inline const char *EnumNameBinaryOpOperation(BinaryOpOperation e);

这里的获取都是从一个静态数组中取值实现的,例如二元操作的各个名称存储在静态数组中,并且通过EnumNameBinaryOpOperation函数根据枚举值获取对应的名称字符串,核心代码如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
// schema/current/TensorflowOp_generated.h
inline const char * const *EnumNamesBinaryOpOperation() {
  static const char * const names[] = {
    "ADD",
    "SUB",
    "MUL",
    "DIV",
    "MAX_TEMP",
    "MIN_TEMP",
    "POW",
    "REALDIV",
    "MINIMUM",
    "MAXIMUM",
    "GREATER",
    "GREATER_EQUAL",
    "LESS",
    "FLOORDIV",
    "SquaredDifference",
    "EQUAL",
    "LESS_EQUAL",
    "FLOORMOD",
    "",
    "MOD",
    "ATAN2",
    "LOGICALOR",
    "NOTEQUAL",
    "BITWISE_AND",
    "BITWISE_OR",
    "BITWISE_XOR",
    "LOGICALXOR",
    "LEFTSHIFT",
    "RIGHTSHIFT",
    nullptr
  };
  return names;
}
inline const char *EnumNameBinaryOpOperation(BinaryOpOperation e) {
  if (e < BinaryOpOperation_ADD || e > BinaryOpOperation_RIGHTSHIFT) return "";
  const size_t index = static_cast<int>(e);
  return EnumNamesBinaryOpOperation()[index];
}

1.5.1 GEMM转卷积算子的理解

MNN中会把常见的线性层转换为$1\times1$的卷积Convolution算子,

TODO:

1.6 Pipeline类

TODO:

1.7 Session类

TODO:

1.8 Executor & ExecutorScope类