赞
踩
如果使用pytorch,通常使用ONNX,也就是中间一条方案。
因为由pytorch到ONNX由pytorch官方维护,并且更新频率较快,由ONNX到TensorRT由TensorRT官方维护,所以采用下面的方案,GitHub地址:链接
对于第一点:是因为如果写成size或shape返回的参数时,会造成pytorch对size的跟踪,生成gather和shape等节点。
指定维度时不加int,会生成shape、gather、unsqueeze、concat等节点
代码:
import torch import torch.nn as nn class Module(nn.Module): def __init__(self): super(Module, self).__init__() self.conv = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1, bias=True) self.conv.weight.data.fill_(0.3) self.conv.bias.data.fill_(0.2) def forward(self, x): x = self.conv(x) return x.view(x.size(0), -1) model = Module().eval() x = torch.full((1, 1, 3, 3), 1.0) y = model(x) torch.onnx.export(model, (x,), "lesson1.onnx", verbose=True)
代码:
import torch import torch.nn as nn class Module(nn.Module): def __init__(self): super(Module, self).__init__() self.conv = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1, bias=True) self.conv.weight.data.fill_(0.3) self.conv.bias.data.fill_(0.2) def forward(self, x): x = self.conv(x) return x.view(-1, int(x.numel()//x.size(0))) model = Module().eval() x = torch.full((1, 1, 3, 3), 1.0) y = model(x) torch.onnx.export(model, (x,), "lesson1.onnx", verbose=True)
使用正确导出ONNX的第1、3条,batch指定为-1,其余维度指定前面加上 int;
关于动态batch,因为tensorRT对静态batch的处理,即使设的batch是32,如果输入的图片是1或者是32都是按32来处理的,所以耗时是固定的,且缺少灵活性。
关于动态宽高,如果使用trt的动态宽高,即是说可以接收分辨率输入为320×320、640×640的图片,这样做灵活性增高但复杂性也在增高。所以采用在编译时修改ONNX的输入实现相对动态,避免重回pytorch再做导出的操作。
如下图,在 {} 修改输入的shape。
视频也说到:一个trt引擎可以处理多个分辨率大小的图片是没有意义的,因为复杂度会增高,效率也会大打折扣,所以一个引擎一个固定大小的输入分辨率即可。
不建议使用dynamic_axes指定0意外的维度为动态,意思是说:batch维度为动态指定,指定为-1,气态维度固定大小即可。
编译时指定最大的batch,上图指定的最大batch为5,推理时使用的是2。
resize_single_dim(0, 2),指定单个维度,0:第一个维度即batch,2:指定batch为2。
修改两个地方,一个是return返回,-1指定batch,25是5×5的结果;编译时也修改为{{1,1,5,5}}
首先是test_plugin.py导出ONNX,再调用这些插件。先看一下test_plugin.py,
代码:
import torch import torch.nn.functional as F import torch.nn as nn import json class HSwishImplementation(torch.autograd.Function): # 主要是这里,对于autograd.Function这种自定义实现的op,只需要添加静态方法symbolic即可, # 除了g以外的参数应与forward函数的除ctx以外完全一样 # 这里演示了input->作为tensor输入,bias->作为参数输入,两者将会在tensorRT里面具有不同的处理方式 # 对于附加属性(attributes),以 "名称_类型简写" 方式定义,类型简写, # 请参考:torch/onnx/symbolic_helper.py中_parse_arg函数的实现【from torch.onnx.symbolic_helper import _parse_arg】 # 属性的定义会在对应节点生成attributes,并传给tensorRT的onnx解析器做处理 @staticmethod def symbolic(g, input, bias): # 如果配合当前tensorRT框架,则必须名称为Plugin,参考:tensorRT/src/tensorRT/onnx_parser/builtin_op_importers.cpp的160行定义 # 若你想自己命名,可以考虑做类似修改即可 # # name_s表示,name是string类型的,对应于C++插件的名称,参考:tensorRT/src/tensorRT/onnxplugin/plugins/HSwish.cu的82行定义的名称 # info_s表示,info是string类型的,通常我们可以利用json.dumps,传一个复杂的字符串结构,然后在CPP中json解码即可。参考: # sxai/tensorRT/src/tensorRT/onnxplugin/plugins/HSwish.cu的39行 return g.op("Plugin", input, bias, name_s="HSwish", info_s=json.dumps({"alpha": 3.5, "beta": 2.88})) # 这里的forward只是为了让onnx导出时可以执行,实际上写与不写意义不大,只需要返回同等的输出维度即可 @staticmethod def forward(ctx, i, bias): ctx.save_for_backward(i) return i * F.relu6(i + 3) / 6 # 这里省略了backward class MemoryEfficientHSwish(nn.Module): def __init__(self): super(MemoryEfficientHSwish, self).__init__() # 这里我们假设有bias作为权重参数 self.bias = nn.Parameter(torch.zeros((3, 3, 3, 3))) self.bias.data.fill_(3.15) def forward(self, x): # 我们假设丢一个bias进去 return HSwishImplementation.apply(x, self.bias) class FooModel(torch.nn.Module): def __init__(self): super(FooModel, self).__init__() self.hswish = MemoryEfficientHSwish() def forward(self, input1, input2): return F.relu(input2 * self.hswish(input1)) dummy_input1 = torch.zeros((1, 3, 3, 3)) dummy_input2 = torch.zeros((1, 3, 3, 3)) model = FooModel() # 这里演示了2个输入的情况,实际上你可以自己定义几个输入 torch.onnx.export( model, (dummy_input1, dummy_input2), 'hswish.plugin.onnx', input_names=["input.0", "input.1"], output_names=["output.0"], verbose=True, opset_version=11, # >=11支持性更好,默认等于9 # 动态指定全为batch dynamic_axes={"input.0": {0: "batch"}, "input.1": {0: "batch"}, "output.0": {0: "batch"}}, enable_onnx_checker=False # 作为插件而言老是报错,所以改为False ) print("Done")
输出:
其实就是先使用test_plugin.py文件先导出ONNX文件,再用tensorRT进行编译和推理,
#include <builder/trt_builder.hpp> #include <infer/trt_infer.hpp> #include <common/ilogger.hpp> #include "app_yolo/yolo.hpp" using namespace std; static void test_hswish(TRT::Mode mode){ // The plugin.onnx can be generated by the following code // cd workspace // python test_plugin.py iLogger::set_log_level(iLogger::LogLevel::Verbose); TRT::set_device(0); auto mode_name = TRT::mode_string(mode); auto engine_name = iLogger::format("hswish.plugin.%s.trtmodel", mode_name); TRT::compile( mode, 3, "hswish.plugin.onnx", engine_name, {} ); auto engine = TRT::load_infer(engine_name); engine->print(); auto input0 = engine->input(0); auto input1 = engine->input(1); auto output = engine->output(0); INFO("offset %d", output->offset(1, 0)); INFO("input0: %s", input0->shape_string()); INFO("input1: %s", input1->shape_string()); INFO("output: %s", output->shape_string()); float input0_val = 0.8; float input1_val = 2; input0->set_to(input0_val); input1->set_to(input1_val); auto hswish = [](float x){float a = x + 3; a=a<0?0:(a>=6?6:a); return x * a / 6;}; auto sigmoid = [](float x){return 1 / (1 + exp(-x));}; auto relu = [](float x){return max(0.0f, x);}; float output_real = relu(hswish(input0_val) * input1_val); engine->forward(true); INFO("output %f, output_real = %f", output->at<float>(0, 0), output_real); }
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。