赞
踩
环境配置与训练测试的注意点请参考文章:《基于pytorch的深度学习图像分类》
本文将resnet18部署在jetson-nx平台,关于其他模型的部署步骤也是类似的。
resnet系列的分类模型是常用的分类模型,一般opencv-dnn也是支持的(博主没有尝试,如果大家想简单方便一点可以尝试使用最新版的opencv试下,看看是否能正常加载推理),但在jetson系列平台部署一般都是使用tensorrt进行推理。github的作者给出了两种c++的推理部署方法,分别是基于libtorch和tensorrt的。博主以tensorrt为例改写了作者的一些代码,使得阅读更为简单方便。
目录
模型转tensorrt引擎一般步骤如下:
其中第2-6步需要在平台端进行,以保证GPU型号的一致性。
脚本在./pytorch_classification-master\trt_inference目录,核心代码如下
- def torch_convert_onnx():
-
- model = ClsModel('resnet50', num_classes=2, dropout=0, is_pretrained=False)
- sd = torch.load('../cpp_inference/traced_model/trained_model.pth', map_location='cpu')
- model.load_state_dict(sd)
- model.eval()
-
- dummy_input = torch.randn(1, 3, 224, 224)
- torch.onnx.export(
- model,
- dummy_input,
- "./saved_model/torch_res50.onnx",
- export_params=True,
- input_names=["input_image"],
- output_names = ["model_output"],
- dynamic_axes = {
- "input_image":{0: "batch"},
- "model_output":{0: "batch"}
- },
- opset_version=11
- )
主要修改模型名称,类别数目,以及模型的路径。输入大小如果修改的话也要改过来。
本文API以tensorrt8.x为例进行说明:
tensorrt8.x以后,默认是只支持量化到int8的,第2-6步C++核心代码如下:
- int TrtClassification::build_engine(const string onnx_path, const string trt_engine_path)
- {
- nvinfer1::IBuilder* builder = nvinfer1::createInferBuilder(logger); ///创建编译引擎
-
- nvinfer1::IBuilderConfig* config = builder->createBuilderConfig(); ///创建引擎配置
-
- nvinfer1::INetworkDefinition* network=builder->createNetworkV2(1);///创建网络
-
- nvonnxparser::IParser* parser=nvonnxparser::createParser(*network, logger);///创建onnx模型解析器
- if (!parser->parseFromFile((char*)onnx_path.c_str(), 1))
- {
- printf("Faile to parse %s \n", onnx_path);
- }
- config->setMaxWorkspaceSize(1 << 30);//设置每层最大可利用显存空间
- nvinfer1::IOptimizationProfile* profile = builder->createOptimizationProfile();
- nvinfer1::ITensor* input_tensor = network->getInput(0);
- nvinfer1::Dims input_dims = input_tensor->getDimensions();
- input_dims.d[0] = 1;
- profile->setDimensions(input_tensor->getName(), nvinfer1::OptProfileSelector::kMIN, input_dims);
- profile->setDimensions(input_tensor->getName(), nvinfer1::OptProfileSelector::kOPT, input_dims);
- profile->setDimensions(input_tensor->getName(), nvinfer1::OptProfileSelector::kMAX, input_dims);
- config->addOptimizationProfile(profile);
- nvinfer1::ICudaEngine* bengine = builder->buildEngineWithConfig(*network, *config);///创建推理引擎
- if (bengine == nullptr)
- {
- printf("build engine failed \n");
- return -1;
- }
- ///将推理引擎保存下来,需要将引擎序列化,保存为二进制文件
- nvinfer1::IHostMemory* model_data = bengine->serialize();
- FILE* f = fopen((char*)trt_engine_path.c_str(), "wb");
- fwrite(model_data->data(), 1, model_data->size(), f);
- fclose(f);
- printf("build engine done. \n");
- return 0;
- }
说明:
(1). builder:构建器,搜索cuda内核目录以获得最快的可用实现,必须使用和运行时的GPU相同的GPU来构建优化引擎。在构建引擎时,TensorRT会复制权重。
- nvinfer1::IBuilder* builder = nvinfer1::createInferBuilder(logger); ///创建编译引擎
- nvinfer1::INetworkDefinition* network=builder->createNetworkV2(1);///创建网络
(2). engine:引擎,不能跨平台和TensorRT版本移植。若要存储,需要将引擎转化为一种格式,即序列化,若要推理,需要反序列化引擎。引擎用于保存网络定义和模型参数。
- nvinfer1::IBuilderConfig* config = builder->createBuilderConfig();
- config->setMaxWorkspaceSize(1 << 30);
- nvinfer1::ICudaEngine* bengine = builder->buildEngineWithConfig(*network, *config);
- nvinfer1::IHostMemory* model_data = bengine->serialize();
- FILE* f = fopen((char*)trt_engine_path.c_str(), "wb");
- fwrite(model_data->data(), 1, model_data->size(), f);
- fclose(f);
其中:setMaxWorkspaceSize设置比较重要,设置的小可能发挥不出并行计算的性能,设置过大可能会导致运行的显存不够而导致推理失败,甚至程序异常退出。setMaxWorkspaceSize(1 << 30)表示设置的显存大小为1G,即2的30次方,如果设置256M,则为setMaxWorkspaceSize(1 << 28)。
使用推理引擎执行推理,首先需要将推理引擎反序列化,然后创建推理上下文,使用上下文执行推理过程得到推理结果,核心代码如下:
- auto engine_data = load_engine_data(trt_engine_path);
- auto runtime = make_nvshared(nvinfer1::createInferRuntime(logger));
- engine = make_nvshared(runtime->deserializeCudaEngine(engine_data.data(), engine_data.size()));///推理引擎反序列化
- if (engine == nullptr)
- {
- printf("Deserialize cuda engine failed.\n");
- return -1;
- }
-
- cudaStream_t stream = nullptr;
- checkRuntime(cudaStreamCreate(&stream));
- auto execution_context = make_nvshared(engine->createExecutionContext());由engine创建上下文context
需要完整推理代码的请私信我
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。