当前位置:   article > 正文

关于C++ libtorch调用pytorch模型的总结_libtorch dqn

libtorch dqn

最近接到了一个需求,需要把一个用python基于pytorch实现的DQN强化学习模型移植到Arm平台。 经历了很多坑,最终都解决了,记录一下过程:

环境:
系统:Ubuntu20.04 LTS;
pytorch版本:1.9.0;
python版本:3.8;
libtorch版本:1.9;

准备步骤:
①如何安装python和pytorch请自行百度;
②进入pytorch官网,下载合适版本的libtorch https://pytorch.org/,因为我的需求是移植到Arm平台上,没有GPU以及CUDA,所以选择CPU版本;
选择libtorch如图所示
③解压缩后,得到libtorch的文件目录,这是官方在X86平台上已经编译好的一些库,参照
libtorch中文文档 https://pytorch.apachecn.org/docs/1.0/cpp_export.html,新建一个自己的目录,比如“example-app”。在“example-app”目录下新建如下文件:
(1)CMakeLists.txt

cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(custom_ops)
find_package(Torch REQUIRED)
add_executable(example-app example-app.cpp)
target_link_libraries(example-app "${TORCH_LIBRARIES}")
set_property(TARGET example-app PROPERTY CXX_STANDARD 14)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

(2)convert.py(将pytorch的模型转换为C++能用的libtorch模型)
此处要注意的是,你需要去看官方教程,不同的模型文件转换方式是不一样的,之前的那个中文文档里有写,此处记录我的模型转换过程。
我的实例:

import torch


model = torch.load('eval_net.pkl')

traced_script_module = torch.jit.script(model)

traced_script_module.save("eval_net.pt")
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

官方的实例:

import torch
import torchvision

# 获取模型实例
model = torchvision.models.resnet18()

# 生成一个样本供网络前向传播 forward()
example = torch.rand(1, 3, 224, 224)

# 使用 torch.jit.trace 生成 torch.jit.ScriptModule 来跟踪
traced_script_module = torch.jit.trace(model, example)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

执行convert.py后生成 eval_net.pt 文件。
(3)example-app.cpp(用来测试转换为C++模型后的程序)

#include <torch/script.h> 
#include <vector>
#include <iostream>

using namespace std;


int main() {
  torch::jit::script::Module module = torch::jit::load("./eval_net.pt");
  vector<torch::jit::IValue> inputs;
  //自定义一组输入数据并设置格式,否则会报错,对应pytorch中的FloatTensor
  inputs.push_back(torch::tensor({0, 1, 2, 3, 4, 6, 0, 0, 1, 0, 2, 1, 5, 6, 2, 3, 4, 6, 0, 5, 1, 3, 2, 1},torch::kFloat));
  //使用模型
  torch::Tensor res = module.forward(inputs).toTensor();
  for (size_t i = 0; i < res.itemsize(); i++)
  {
    if(res[i].equal(res.max())){
      cout << "result:"<< i << endl;
    }
  }
  cout << "加载成功\n";
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22

此时example-app目录下文件应为:
在这里插入图片描述

此时在example-app下打开终端,执行语句:

cmake -DCMAKE_PREFIX_PATH=/home/software/libtorch-shared-with-deps-1.9.0+cpu/libtorch
  • 1

运行后没有提示错误则为成功,目录下会生成Makefile文件。

执行语句:

make
  • 1

生成example-app可执行程序。
执行语句,运行example-app得到如下提示,说明调用成功:
在这里插入图片描述
至此,X86版本的libtorch已经正常跑通。Arm版本的libtorch相对比较麻烦,需要官网下载pytorch源码,单独交叉编译libtorch,其它的调用方式跟X86是基本一样的,后续再补充。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/笔触狂放9/article/detail/417794
推荐阅读
相关标签
  

闽ICP备14008679号