赞
踩
本文的tensorRT加速部署的流程主要有四步:
这里是上述流程的第一步,直接上PyTorch模型转换的代码,使用的是自带的torch.onnx.export。若是tensorflow模型可以参考这个链接:tensorflow转换为onnx
import torch
import torch.nn
import onnx
from torch.autograd import Variable
import torch.onnx
import torchvision
# 定义一个随机的输入变量,用来遍历模型参数,注意要放入cuda中
dummy_input = Variable(torch.randn(1, 3, 480, 640)).cuda()
# 加载一个PyTorch模型
model = torch.load(r"PyTorch_Model.pt")
# 导出为ONNX格式文件
torch.onnx.export(model, dummy_input, "save_your_onnx_model.onnx", verbose=True)
转换完后可以通过这个网站上传并查看转换后的onnx模型结构。
这一步相对较麻烦,包含了流程的第二步和第三步
1. 查看cuda版本和cudnn版本
要注意与电脑的系统版本(windows/linux/iOS)以及电脑调用tensorRT时要使用的cuda、cudnn版本对应。
①win+R,输入cmd打开终端;
②输入nvcc --version
,查看cuda版本;
③进入以下路径C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
,进入你的cuda版本文件夹,我的是v10.2
,进入include
文件夹,打开cudnn.h文件。
④往下翻一点,找到下面的版本号,如我的就是7.6.5
。后来我把cudnn7升级为cudnn8了,tensorrt也换成了8,因为tensorrt7没有python版本,如果有的话请告知。不过这都不重要hhh
2. 下载并配置tensorRT
下载的tensorRT到nvidia官网下载,我的电脑的cuda是10.2.89,cudnn版本是8.x,网上搜到对应的tensorRT版本就是8.x。下载好了之后就需要配置环境变量。在系统环境变量中添加:
你的tensorRT下载路径\TensorRT-8.2.1.8.Windows10.x86_64.cuda-10.2.cudnn8.2\TensorRT-8.2.1.8\bin
你的tensorRT下载路径\TensorRT-8.2.1.8.Windows10.x86_64.cuda-10.2.cudnn8.2\TensorRT-8.2.1.8\lib
3. 使用tensorRT转换成Engine
win+R,输入cmd打开终端,输入trtexec --onnx=model.onnx --saveEngine=xxx.trt
,回车,稍等一会,就可以将模型转换成trt文件啦。
4. 转换的时候记得添加优化参数,这里举个例子
trtexec --onnx=model.onnx --saveEngine=xxx.trt --fp16
①--fp16
通常情况下,fp16精度就可以把模型的遍历速度提高一倍,并损失非常少的精度,若使用int8,精度可能会大大降低,同时速度不一定能比fp16更快一倍。
②其他优化参数可以自行网上搜索添加。
######### 在转换过程中,我遇到了报错如下 ##########
①由于找不到nvparsers.dll,无法继续执行代码。重新安装程序可能会解决此问题。
解决方法:
将tensorrt8/lib添加系统环境变量后,就没出现这个报错。
如果还是有错,就把dll复制一份到tensorrt8/bin中试试。
②找不到cudnn64_8.dll,无法继续执行代码。
解决方法 :重新检查tensorRT版本下载的对不对,cuda和cudnn的版本一定都要校对。
C++调用加速,直接贴代码。
//读取 engine 文件 bool Widget::initial_the_trtfile(const std::string& enginePath){ char* trtModelStream; data = new float[1 * 3 * INPUT_H * INPUT_W]; output = new float[1 * OUTPUT_SIZE]; std::ifstream file(enginePath, std::ios::binary); if (file.good()) { file.seekg(0, file.end); size = file.tellg(); file.seekg(0, file.beg); trtModelStream = new char[size]; assert(trtModelStream); file.read(trtModelStream, size); file.close(); } runtime = nvinfer1::createInferRuntime(gLogger); assert(runtime != nullptr); nvinfer1::ICudaEngine* engine = runtime->deserializeCudaEngine(trtModelStream, size); assert(engine != nullptr); context = engine->createExecutionContext(); assert(context != nullptr); delete[] trtModelStream; assert(engine->getNbBindings() == 2); // In order to bind the buffers, we need to know the names of the input and output tensors. // Note that indices are guaranteed to be less than IEngine::getNbBindings() for (int i = 0; i < engine->getNbBindings(); i++){ nvinfer1::Dims dims = engine->getBindingDimensions(i); printf("index %d, dims: (",i); for (int d = 0; d < dims.nbDims; d++){ if (d < dims.nbDims - 1) printf("%d,", dims.d[d]); else printf("%d", dims.d[d]); } printf(")\n"); } const int inputIndex = 0, outputIndex = 1; assert(inputIndex == 0); assert(outputIndex == 1); // Create GPU buffers on device cudaMalloc(&buffers[inputIndex], 1 * 3 * INPUT_H * INPUT_W * sizeof(float)); cudaMalloc(&buffers[outputIndex], 1 * OUTPUT_SIZE * sizeof(float)); // Create stream cudaStreamCreate(&stream); return true; } //遍历engine void Widget::Engine_Inference(nvinfer1::IExecutionContext& context, cudaStream_t& stream, void **buffers, float* input, float* output, int batchSize) { // DMA input batch data to device, infer on the batch asynchronously, and DMA output back to host cudaMemcpyAsync(buffers[0], input, batchSize * 3 * INPUT_H * INPUT_W * sizeof(float), cudaMemcpyHostToDevice, stream); context.enqueue(batchSize, buffers, stream, nullptr); cudaMemcpyAsync(output, buffers[1], batchSize * OUTPUT_SIZE * sizeof(float), cudaMemcpyDeviceToHost, stream); cudaStreamSynchronize(stream); } //这几个参数可以放到头文件,上面初始化trt文件和遍历的两个函数都要用到。 nvinfer1::IExecutionContext* context; void* buffers[2]; size_t size; cudaStream_t stream; //这个是定义输入输出 data = new float[1 * 3 * INPUT_H * INPUT_W]; output = new float[1 * OUTPUT_SIZE]; //初始化trt文件 bool initial_res = initial_the_trtfile("path to your trt file."); //遍历的时候调用的是初始化trt文件后得到的context和stream,data就是输入(shape应与你设置输入trt文件的大小一致),output就是输出啦 Widget::Engine_Inference(*context, stream, buffers, data, output, batchsize);
完结撒花
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。