当前位置:   article > 正文

【ONNX】pytorch模型导出成ONNX格式:支持多参数与动态输入_导出模型onnx

导出模型onnx

        pytorch格式的模型在部署之前一般需要做格式转换。本文介绍了如何将pytorch格式的模型导出到ONNX格式的模型。ONNX(Open Neural Network Exchange)格式是一种常用的开源神经网络格式,被较多推理引擎支持,比如:ONNXRuntime, Intel OpenVINO, TensorRT等。

1. 网络结构定义        

我们以一个Image Super Resolution的模型为例。首先,需要知道模型的网络定义SuperResolutionNet,并创建模型对象torch_model:

  1. # Super Resolution model definition in PyTorch
  2. import torch.nn as nn
  3. import torch.nn.init as init
  4. class SuperResolutionNet(nn.Module):
  5.     def __init__(self, upscale_factor, inplace=False):
  6.         super(SuperResolutionNet, self).__init__()
  7.         self.relu = nn.ReLU(inplace=inplace)
  8.         self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
  9.         self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
  10.         self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
  11.         self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
  12.         self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
  13.         self._initialize_weights()
  14.     def forward(self, x):
  15.         x = self.relu(self.conv1(x))
  16.         x = self.relu(self.conv2(x))
  17.         x = self.relu(self.conv3(x))
  18.         x = self.pixel_shuffle(self.conv4(x))
  19.         return x
  20.     def _initialize_weights(self):
  21.         init.orthogonal_(self.conv1.weight, init.calculate_gain('relu'))
  22.         init.orthogonal_(self.conv2.weight, init.calculate_gain('relu'))
  23.         init.orthogonal_(self.conv3.weight, init.calculate_gain('relu'))
  24.         init.orthogonal_(self.conv4.weight)
  25.         init.zeros_(self.conv4.bias)
  26. # Create the super-resolution model by using the above model definition.
  27. torch_model = SuperResolutionNet(upscale_factor=3)

2. 加载模型文件

        Pytorch模型的参数信息存储在state_dict中。 state_dict是一个Python字典结构的对象,里面存储了神经网络中每层对应的参数张量。将每层的参数结构以及最后一层的bias打印出来:

  1. def print_state_dict(state_dict):    
  2.     print(len(state_dict))
  3.     for layer in state_dict:
  4.         print(layer, '\t', state_dict[layer].shape)
  5.     print(state_dict['conv4.bias'])
  6. print_state_dict(model.state_dict())

输出:

  1. 8
  2. conv1.weight      torch.Size([64, 1, 5, 5])
  3. conv1.bias      torch.Size([64])
  4. conv2.weight      torch.Size([64, 64, 3, 3])
  5. conv2.bias      torch.Size([64])
  6. conv3.weight      torch.Size([32, 64, 3, 3])
  7. conv3.bias      torch.Size([32])
  8. conv4.weight      torch.Size([9, 32, 3, 3])
  9. conv4.bias      torch.Size([9])
  10. tensor([0., 0., 0., 0., 0., 0., 0., 0., 0.])

        因为之前将第四层的bias初始化为0,所以输出是全零。然后调用load_state_dict加载模型文件,可以看到加载之后参数的变化。eval将模型设置为推理状态。

  1. # Load pretrained model weights
  2. model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'
  3. model.load_state_dict(model_zoo.load_url(model_url))
  4. print_state_dict(model.state_dict())
  5. # set the model to inference mode
  6. model.eval()

输出

  1. 8
  2. conv1.weight      torch.Size([64, 1, 5, 5])
  3. conv1.bias      torch.Size([64])
  4. conv2.weight      torch.Size([64, 64, 3, 3])
  5. conv2.bias      torch.Size([64])
  6. conv3.weight      torch.Size([32, 64, 3, 3])
  7. conv3.bias      torch.Size([32])
  8. conv4.weight      torch.Size([9, 32, 3, 3])
  9. conv4.bias      torch.Size([9])
  10. tensor([-0.0151, -0.0191, -0.0362, -0.0224,  0.0548,  0.0113,  0.0529,  0.0258,
  11.         -0.0180])
  12. SuperResolutionNet(
  13.   (relu): ReLU()
  14.   (conv1): Conv2d(1, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  15.   (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  16.   (conv3): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  17.   (conv4): Conv2d(32, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  18.   (pixel_shuffle): PixelShuffle(upscale_factor=3)
  19. )


3. 输出成ONNX格式

        在调用torch.onnx.export之前,需要先创建输入数据。因为模型的导出实际上是执行了一次推理过程。在执行的过程中记录使用到的操作。输入数据可以是随机的:

  1. # Input to the model
  2. x = torch.randn(1, 1, 224, 224, requires_grad=True)
  3. # Export the model
  4. torch.onnx.export(model,               # model being run
  5.                   x,                         # model input 
  6.                   "D:\\super_resolution.onnx",   # where to save the model (can be a file or file-like object)                  
  7.                   opset_version=11,          # the ONNX version to export the model to                  
  8.                   input_names = ['input'],   # the model's input names
  9.                   output_names = ['output']  # the model's output names
  10.                   )

        export的第一个参数是模型对象,第二个参数是输入数据,第三个参数是输出的模型文件名,这三个参数是必须指定的。还有一些常用的可选参数:

        opset_version, 指定的操作版本,一般越高的版本会支持更多的操作。如果遇到某个操作不支持,可以将版本号设置的高一点试试。
        input_names, 输入参数名。如果不指定,会使用默认名字。
        output_names, 输出参数名。如果不知道,会使用默认名字。
        输出成功后,可以使用Netron查看网络结构。Netron是一个开源的神经网络模型可视化工具,可以使用在线网页版的https://netron.app/,或者下载安装桌面版的https://github.com/lutzroeder/netron。打开导出的模型,结构如下:

在这里插入图片描述

4. 导出动态输入模型

        可以看到上面导出的模型输入是固定的1 x 1 x 224 x 224输出是固定的1 x 1 x 672 x 672.实际应用的时候输入图片的尺寸是不固定的,而且可能一次输入多种图片一起处理。我们可以通过指定dynamic_axes参数来导出动态输入的模型。dynamic_axes的参数是一个字典类型,字典的key就是输入或者输出的名字,对应key的value可以是一个字典或者列表,指定了输入或者输出的index以及对应的名字。比如想要让输入的index为0的维度表示动态的batch_size那么就指定{0: 'batch_size'}。同样的方法可以指定宽高所在的维度输出成动态的。

  1. input_name = 'input'
  2. output_name = 'output'
  3. torch.onnx.export(model,               # model being run
  4.                   x,                         # model input 
  5.                   "D:\\super_resolution_2.onnx",   # where to save the model (can be a file or file-like object)                  
  6.                   opset_version=11,          # the ONNX version to export the model to                  
  7.                   input_names = [input_name],   # the model's input names
  8.                   output_names = [output_name],  # the model's output names
  9.                   dynamic_axes= {
  10.                         input_name: {0: 'batch_size', 2 : 'in_width', 3: 'int_height'},
  11.                         output_name: {0: 'batch_size', 2: 'out_width', 3:'out_height'}}
  12.                   )

输出的模型使用Netron打开,结构如下:

在这里插入图片描述

        查看输入输出信息可以看到,输入的维度变成:[batch_size,1,in_width,int_height],输出的维度变成:[batch_size,1,out_width,out_height]。表示这个模型可以接收动态的批次大小和宽高尺寸。

在这里插入图片描述

5. 多参数输入

5.1 多参数输入模型的导出

        有时候可能会遇到比较复杂的模型,推理时需要输入多个参数的情况。我们可以通过将参数列表包在一个list中来输出ONNX模型。我们先将模型的forward方法修改一下,增加一个输入参数scale:

  1. class SuperResolutionNet2(nn.Module):
  2.     def __init__(self, upscale_factor, inplace=False):
  3.         super(SuperResolutionNet2, self).__init__()
  4.         self.relu = nn.ReLU(inplace=inplace)
  5.         self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
  6.         self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
  7.         self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
  8.         self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
  9.         self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
  10.         self._initialize_weights()
  11.     def forward(self, x, scale):
  12.         x = self.relu(self.conv1(x))
  13.         x = self.relu(self.conv2(x))
  14.         x = self.relu(self.conv3(x))
  15.         x = self.pixel_shuffle(self.conv4(x))
  16.         return x
  17.     def _initialize_weights(self):
  18.         init.orthogonal_(self.conv1.weight, init.calculate_gain('relu'))
  19.         init.orthogonal_(self.conv2.weight, init.calculate_gain('relu'))
  20.         init.orthogonal_(self.conv3.weight, init.calculate_gain('relu'))
  21.         init.orthogonal_(self.conv4.weight)
  22.         init.zeros_(self.conv4.bias)  
  23. # Create the super-resolution model by using the above model definition.
  24. model2 = SuperResolutionNet2(upscale_factor=3)

调用export输出到ONNX:

  1. input_name = 'input'
  2. output_name = 'output'
  3. torch.onnx.export(model2,               
  4.                   (x, 2),                         
  5.                   "D:\\super_resolution_3.onnx",   
  6.                   opset_version=11,          
  7.                   input_names = [input_name],  
  8.                   output_names = [output_name],
  9.                   dynamic_axes= {
  10.                         input_name: {0: 'batch_size', 2 : 'in_width', 3: 'int_height'},
  11.                         output_name: {0: 'batch_size', 2: 'out_width', 3:'out_height'}}
  12.                   )

5.2 易错点

        由于export函数的机制,会把模型输入的参数自动转换成tensor类型,比如上面的scale参数,虽然传入的时候是int32类型,但是export在执行时会调用到forward函数,此时scale已经变成一个tensor类型。我们可以做个测试,打印一下scale的类型来验证:

  1. def forward(self, x, scale):
  2.     print(scale)
  3.     x = self.relu(self.conv1(x))
  4.     x = self.relu(self.conv2(x))
  5.     x = self.relu(self.conv3(x))
  6.     x = self.pixel_shuffle(self.conv4(x))
  7.     return x

重新运行export后输出:

tensor(2)
        这种机制带来的影响是,在使用scale参数时可能需要做一个转换,比如转换成float类型。否则某些函数的调用会失败。以插值函数为例做个测试,将forward修改一下:

  1. def forward(self, x, scale):
  2.     print(scale)        
  3.     y = F.interpolate(x, scale_factor= 1./scale, mode="bilinear")
  4.     x = self.relu(self.conv1(x))
  5.     x = self.relu(self.conv2(x))
  6.     x = self.relu(self.conv3(x))
  7.     x = self.pixel_shuffle(self.conv4(x))
  8.     return x

这个时候运行export会报错,因为插值函数的scale_factor参数不能是一个tensor类型。修改后的正确版本:

  1. def forward(self, x, scale):
  2.     print(scale)        
  3.     y = F.interpolate(x, scale_factor= 1./float(scale), mode="bilinear")
  4.     x = self.relu(self.conv1(x))
  5.     x = self.relu(self.conv2(x))
  6.     x = self.relu(self.conv3(x))
  7.     x = self.pixel_shuffle(self.conv4(x))
  8.     return x

6. 完整代码

在这里 https://github.com/jb2020-super/pytorch-utils/blob/main/to_onnx_ex.ipynb

7. 参考

https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html
https://pytorch.org/docs/stable/onnx.html?highlight=export#torch.onnx.export
https://onnxruntime.ai/docs/get-started/with-python.html
————————————————
Thanks to:https://blog.csdn.net/superbinlovemiaomi/article/details/121344667

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/笔触狂放9/article/detail/367218
推荐阅读
相关标签
  

闽ICP备14008679号