当前位置:   article > 正文

windows 下使用 c++ 调用 tensorflow(msvc/mingw) 进行 model inference_windows下编译tensorflow源码 用其c++接口调用训练好的模型

windows下编译tensorflow源码 用其c++接口调用训练好的模型

windows 下使用 c++ 调用 tensorflow(msvc/mingw) 进行 model inference

一、背景

假如想在 c++ 的应用中使用 tensorflow 做诸如 model inference 的操作,一个简单的方法是链接Python.dll,使用 python 脚本代替完成。但是众所周知,tensorflow 的底层是由 c++ 写的,上述方式虽然简单,但是丧失了性能和实现的优雅性。本文主要介绍如何在 windows 平台上直接调用 c++ 版本的tensorflow

二、模型固化为 pb 文件

模型在使用 model.save 保存后,必须固化为与语言无关的 pb 格式才能进行被 c++ 版 tensorflow 所调用,可以用从网上抄的一个 h5 转 pb 的脚本:

import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

def h5_to_pb(h5_save_path):
    model = tf.keras.models.load_model(h5_save_path, compile=False)
    model.summary()
    full_model = tf.function(lambda Input: model(Input))
    full_model = full_model.get_concrete_function(tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))

    # Get frozen ConcreteFunction
    frozen_func = convert_variables_to_constants_v2(full_model)
    frozen_func.graph.as_graph_def()

    layers = [op.name for op in frozen_func.graph.get_operations()]
    print("-" * 50)
    print("Frozen model layers: ")
    for layer in layers:
        print(layer)

    print("-" * 50)
    print("Frozen model inputs: ")
    print(frozen_func.inputs)
    print("Frozen model outputs: ")
    print(frozen_func.outputs)

    tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
                      logdir="./pb",
                      name="model.pb",
                      as_text=False) 


h5_to_pb('./model.h5')
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32

三、使用 c++ 调用 tensorflow

1. 编译 c++ 版 tensorflow

可以参考这篇文章在 windows 下编译 tensorflow(链接 ),但是最新版本的tensorflow在本机编译时导出的符号不全,导致报 undefined reference 的链接错误,可以改成用以下命令编译生成动态链接库

 bazel build --config=opt //tensorflow/tools/lib_package:libtensorflow
  • 1

需要注意的是,默认情况下 bazel 编译时会占用 cpu 所有的核,可能占用大量内存导致编译失败,可以用 --jobs 选项限制使用的 cpu 核数

2. msvc 调用 tensorflow

c++ 版 tensorflow 是 bazel 使用 msvc 编译的,所以直接链接 tensorflow_cc.lib 即可,代码和下文中的 “3.a 把 model inference 封装为 dll” 基本类似

3. mingw 调用 tensorflow

mingw 编译器直接链接 tensorflow 是有问题的,其直接原因是二者是不同的编译器,对于 c++ 来说 mangle 后的符号不同(mangle 相关的知识可以自行百度),肯定会报 undefined reference 这个错误。再者,由于标准库的实现方式也不同,所以即使正常链接接,生成了可执行文件,但是 dll 导出的涉及到以标准库为参或为返回值的函数其本身的执行也有可能出现问题。(诸如同一成员函数的偏移地址不一样,可能会导致程序直接崩溃)
有两个可能方案可以解决这个问题:

  1. 把 model inference 用 msvc 封装成一个新的 dll ,供 mingw 调用
  2. patch tensorflow_cc.lib 里面的符号,使其可以直接被 mingw 链接

但是方案2实现起来会有问题,一个原因是上文所说的二者标准库实现方式不同,直接链接肯定会导致程序崩溃,还有一个原因是 tensorflow 以及它所使用的诸如 protobuf 等库,内部的一些类在 mingw 和 msvc 下的实现也是不一样的(一个简单的判断方式是在 mingw 系的编译器引用 tensorflow 相关的头文件前加上 #define _MSC_VER 1900,可以发现有大量报错),即使想办法消除了标准库的差异,程序同样会面临崩溃

对于 windows 下的动态链接库,可以用如下命令生成导出符号的定义文件和 .a 文件

gendef xxxx.dll
dlltool -D xxxx.dll -d xxxx.def -l xxxx.a
  • 1
  • 2

gendef 和 dlltool 在mingw 的 bin 目录下就能找到,需要注意的是,dlltool 和 动态链接库的位数要保持相同。
假如直接用 extern “C” 导出和使用 C 风格的函数接口时,那么就不存在 mangle 的问题,但是为什么不导出可读性和封装性更强的 c++ 的接口呢?
可是,直接导出 c++ 接口会面临和上文中提到的相同的两个错误,然而都有相应的解决对策:

  1. 对于标准库实现方式不同,其实 model inference 的输入输出都是固定大小的向量,所以可以直接使用数组作为函数的参数,根本不需要用的标准库
  2. 对于不同编译器 mangle 后的符号不同,可以 patch 生成的 .a 文件,将 msvc mangle 后的符号替换为 mingw mangle后的符号
a. 把 model inference 封装为 dll

针对解决对策一,我们先把 model inference 的过程封装为 dll ,接口全部使用数组传参,代码如下:
libevaluate.h

#pragma once
class chessEvaluate {
private:
    _declspec(dllexport) chessEvaluate();
public:
    _declspec(dllexport) float evaluate(int map[10][9]);
    static chessEvaluate& instance() {
        static chessEvaluate instance_;
        return instance_;
    }
};
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

libevaluate.cpp

#include <fstream>
#include <utility>
#include <vector>
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/default_device.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/public/session.h"
template<typename T>
struct is_vector_type : std::false_type {};

template<typename T>
struct is_vector_type<std::vector<T>> : std::true_type {};

template <typename T>
void vectorSize(T& vec,std::vector<int64_t> &size) {
    if constexpr (is_vector_type<T>::value) {
        size.push_back(vec.size());
        vectorSize(vec[0], size);
    }
}
template <typename T>
constexpr int vectorDepth() {
    if constexpr (is_vector_type<T>::value) {
        return vectorDepth<typename T::value_type>() + 1;
    }
    else {
        return 0;
    }
}
template <typename Mat,typename Vec,typename ...Args>
void tensorAssign(Mat& mat, Vec& vec, Args... args) {
    if constexpr (is_vector_type<Vec>::value) {
        for (int i = 0; i < vec.size(); i++) {
            tensorAssign(mat, vec[i], args..., i);
        } 
    }
    else {
        mat(args...) = vec;
    }
}
 template <typename Mat, typename Vec>
 void tensorAssign(Mat& mat, Vec& vec) {
     for (int i = 0; i < vec.size(); i++) {
        tensorAssign(mat, vec[i], i);
     }
 }
template <typename T>
std::shared_ptr<tensorflow::Tensor> vector2Tensor(T vec) {
    std::vector<int64_t> vecSize_;
    vectorSize(vec, vecSize_);
    const std::vector<int64_t> vecSize=vecSize_;
    auto span = absl::Span<const int64_t>(vecSize.data(), vecSize.size());
    auto tensor=std::make_shared<tensorflow::Tensor>(tensorflow::DT_FLOAT, tensorflow::TensorShape(span));
    auto input_tensor_mapped = tensor->tensor<float, vectorDepth<T>()>();
    tensorAssign(input_tensor_mapped, vec);
    return tensor;
}


static tensorflow::Session* session;

static float evaluate_impl(std::vector<std::vector<std::vector<std::vector<float>>>>&& chessboard) {
    auto input_tensor_ptr = vector2Tensor(chessboard);

    std::vector<tensorflow::Tensor> outputs;
    std::string output_node = "Identity:0";

    //开始预测,这里的输入名images要和模型的输入相匹配
    tensorflow::Status status_run = session->Run({ {"Input:0", *input_tensor_ptr} }, { output_node }, {}, &outputs);
    if (!status_run.ok()) {
        std::cout << "ERROR: RUN failed..." << std::endl;
        std::cout << status_run.ToString() << "\n";
        return -1;
    }

    assert(outputs.size() == 1);
    auto p = outputs[0].flat<float>();
    return p(0) * 256;
}
#include "libevaluate.h"
chessEvaluate::chessEvaluate() {
    std::string model_file = "model.pb";
    session = tensorflow::NewSession(tensorflow::SessionOptions());           //创建新会话Session

    tensorflow::GraphDef graphdef;                                                //当前模型的图定义
    tensorflow::Status status_load = ReadBinaryProto(tensorflow::Env::Default(), model_file, &graphdef); //从pb文件中读取图模型;
    if (!status_load.ok()) {
        std::cout << "ERROR: Loading model failed..." << model_file << std::endl;
        std::cout << status_load.ToString() << "\n";
        return;
    }

    tensorflow::Status status_create = session->Create(graphdef);               //将图模型导入会话Session中;
    if (!status_create.ok()) {
        std::cout << "ERROR: Creating graph in session failed..." << status_create.ToString() << std::endl;
        return;
    }
    return;
}
float chessEvaluate::evaluate(int map[10][9]) {
    std::vector<std::vector<std::vector<float>>> chessboard;
    for (int i = 0; i < 15; i++) {
        std::vector<std::vector<float>> one_piece_chessboard;
        for (int j = 0; j < 10; j++) {
            std::vector<float> line;
            for (int k = 0; k < 9; k++) {
                line.push_back((map[j][k] + 1) == i);
            }
            one_piece_chessboard.emplace_back(line);
        }
        chessboard.emplace_back(one_piece_chessboard);
    }
    return evaluate_impl({chessboard});
    
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
b. 制作 mingw 可链接的 .a 文件

在前面生成的导出符号的定义文件中,可以找到函数名对应的 mangle 后的符号。

;
; Definition file of libevaluate.dll
; Automatic generated by gendef
; written by Kai Tietz 2008
;
LIBRARY "libevaluate.dll"
EXPORTS
; private: __cdecl chessEvaluate::chessEvaluate(void)__ptr64 
??0chessEvaluate@@AEAA@XZ
; public: float __cdecl chessEvaluate::evaluate(unknown ecsu[])__ptr64 throw()
?evaluate@chessEvaluate@@QEAAMQEAY08H@Z

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

以 chessEvaluate::evaluate 这个函数为例,可以看出 msvc 对其 mangle 后的符号是 ?evaluate@chessEvaluate@@QEAAMQEAY08H@Z。接下来的问题是,如何找出 mingw 对上述函数 mangle 后的符号表示?
一个简单的方法是,把上述的 libevaluate.h 加以改造,在 mingw 下编译一下,看看编译后的符号是什么。改造后的 test.cpp :

class chessEvaluate {
private:
    __attribute((used)) chessEvaluate() {
    	
	}
public:
    __attribute((used)) float evaluate(int map[10][9]) {
    	
	}
    static chessEvaluate& instance() {
        static chessEvaluate instance_;
        return instance_;
    }
};
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

然后通过以下命令编译并查看 mangle 后的符号

g++ -c test.cpp
nm test.o|grep evaluate
  • 1
  • 2

在这里插入图片描述
可以看出,chessEvaluate::evaluate 在 mingw 下 mangle 后的符号为 _ZN13chessEvaluate8evaluateEPA9_i 。同理,也可以按照上述方式找出 chessEvaluate的构造函数 mangle 后的符号为 _ZN13chessEvaluateC1Ev
此时可以使用 objcopy 命令,将 .a 文件里 msvc mangle 后的符号替换为 mingw mangle 后的符号。

objcopy --redefine-sym ?evaluate@chessEvaluate@@QEAAMQEAY08H@Z=_ZN13chessEvaluate8evaluateEPA9_i  libevaluate.a libevaluate.out.a
objcopy --redefine-sym ??0chessEvaluate@@AEAA@XZ=_ZN13chessEvaluateC1Ev  libevaluate.out.a libevaluate.out.a
  • 1
  • 2

此时,在 mingw 系的编译器里便可以链接 libevaluate.out.a 正常编译,从而可以在运行时动态链接 libevaluate.dll ,调用封装好的类进行 model inference

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/花生_TL007/article/detail/600241
推荐阅读
相关标签
  

闽ICP备14008679号