赞
踩
一般来说,深度学习模型是基于python训练的,在模型部署时,一般需要基于C++代码进行部署。本文以LeNet为例,介绍将一个深度学习模型从基于PyTorch训练到TensorRT部署的全过程,详细代码见:https://github.com/linghu8812/tensorrt_inference/tree/master/lenet。
通过命令python3 train.py
可以训练一个识别手写数字的模型,模型的识别准确率在99%左右。
通过命令python3 export_onnx.py
可以将PyTorch模型转为ONNX模型,TensorRT框架无法直接解析PyTorch模型,所以先将PyTorch模型进行转换,以下代码中,转换的batch size为10。
import onnx
import torch
# export from pytorch to onnx
net = torch.load('mnist_net.pt').to('cpu')
image = torch.randn(10, 1, 28, 28)
torch.onnx.export(net, image, 'mnist_net.onnx', input_names=['input'], output_names=['output'])
# check onnx model
onnx_model = onnx.load("mnist_net.onnx") # load onnx model
onnx.checker.check_model(onnx_model)
clone代码之后,直接编译即可。
mkdir build && cd build
cmake ..
make -j
完成编译后,直接运行 ./lenet_trt ../config.yaml ../samples/
。sample
文件夹中是需要识别的图片。config.yaml
中定义了模型运行的配置,如需要转换的onnx文件和转换后的tensorrt文件名,推理时的batch size大小以及模型推理时的宽、高和通道数。
lenet:
onnx_file: "../mnist_net.onnx"
engine_file: "../mnist_net.trt"
BATCH_SIZE: 10
INPUT_CHANNEL: 1
IMAGE_WIDTH: 28
IMAGE_HEIGHT: 28
运行结果如下:
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。