当前位置:   article > 正文

pytorch模型转ONNX_netron输入的维度[1,3,256,256]

netron输入的维度[1,3,256,256]

目录

1. ONNX

2. pytorch 转 ONNX

3. 加载 ONNX 文件

4. Netron


1. ONNX

一般来说,pytorch训练好的模型是不能够直接用于生产环境,有很多的地方没有优化

而ONNX 格式可以兼顾不同框架的模型,相当于一个中间人的角色。这样部署到不同的环境中,就不需要考虑兼容的问题

 

2. pytorch 转 ONNX

测试代码如下:

  1. import torch
  2. from torchvision import models
  3. # 设备
  4. DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
  5. print(DEVICE)
  6. # 载入预训练模型
  7. model = models.resnet34(weights=True)
  8. model.to(DEVICE)
  9. model.eval()
  10. # 模型输入的维度
  11. input_tensor = torch.randn(1,3,256,256).to(DEVICE)
  12. output_tensor = model(input_tensor)
  13. print(output_tensor.shape)
  14. # pytorch 转 ONNX 格式
  15. with torch.no_grad():
  16. torch.onnx.export(
  17. model, # 转换的模型
  18. input_tensor, # 输入的维度
  19. 'resnet34_imagenet.onnx', # 导出的 ONNX 文件名
  20. opset_version=11, # ONNX 算子集的版本
  21. input_names= ['input'], # 输入的 tensor名称,可变
  22. output_names= ['output'] # 输出的 tensor名称,可变
  23. )

这里采用的是官方imageNet 上预训练的resnet34模型

转ONNX格式的时候,需要提供一个输入的维度,没有意义,类似于与tensorboard中输入流经model就可以知道model的配置一样

如下,运行之后会生成一个.onnx 文件 

 

3. 加载 ONNX 文件

测试代码如下:

  1. import onnx
  2. # 读取模型
  3. onnx_model = onnx.load('./resnet34_imagenet.onnx')
  4. # 检查模型格式是否正确,没有报错的话,说明载入成功
  5. onnx.checker.check_model(onnx_model)
  6. # 打印
  7. print(onnx.helper.printable_graph(onnx_model.graph))

print的信息大概就是网络的结构之类的

 

4. Netron

链接:https://netron.app/

 

将生成的 onnx 文件载入可以看的网络的信息

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/凡人多烦事01/article/detail/656117
推荐阅读
相关标签
  

闽ICP备14008679号