当前位置:   article > 正文

2、TensorRT学习笔记之PT转ONNX、可视化ONNX

pt转onnx

        摘要:主要是解析PT权重转ONNX的过程、代码。文章末尾附有完整运行代码。

目录:

        1、导入模型

        2、 关闭梯度更新

        3、精度转换、fuse等

        4、参数的设置

        5、 pt转onnx

        6、添加其余信息

        7、保存onnx

        8、 完整代码

        9、ONNX可视化


1、导入模型

  1. model = torch.load("yolov8n.pt")
  2. # 由于我用v8做测试,v8的权重中除了模型结构,还有配置文件的参数,这里只需要其模型结构
  3. ckpt = torch.load("yolov8n.pt")
  4. model = ckpt['model']

2、 关闭梯度更新

        作用:减少计算所需的资源

  1. for p in model.parameters():
  2. p.requires_grad = False

3、精度转换、fuse等

        fuse的作用:将conv和BN层融合,提高推理速度

  1. model.eval()
  2. model.float()
  3. model = model.fuse()

4、参数的设置

  1. device = 'cpu' # 如果cuda能用,设置'cuda:0'
  2. input = torch.zeros((1, 3, 640, 640)).to(device) # 根据模型输入尺寸设置(1, 3, 640, 640)
  3. f = 'name.onnx' # 文件名,保存路径
  4. opset_version = 11 # opset版本
  5. input_names = ['img'] # 输入名
  6. output_names=['out'] # 输出名
  7. # 通过以下规则设置动态的维度
  8. dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'}, # shape(1,3,640,640)
  9. 'output': {0: 'batch', 1: 'anchors'} # shape(1,25200,85)}

5、 pt转onnx

  1. torch.onnx.export(
  2. model, # 模型pt权重
  3. input, # 输入张量,模型输入,如[1,3,640,640]
  4. f, # 保存onnx模型的文件
  5. opset_version = opset_version, # Opset版本
  6. input_names=input_names, # 输入张量名称
  7. output_names=output_names, # 输出的张量名称
  8. dynamic_axes=dynamic_axes, # 通过以下规则设置动态的维度
  9. )

也可以结合步骤四直接这么写:

  1. torch.onnx.export(model = torch.load(weight)[model],
  2. im = torch.zeros(1, 3, 640, 640),
  3. f = ./weight.onnx,
  4. export_params = 12,
  5. input_names=['images'],
  6. output_names=['output'],
  7. dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'},
  8. 'output': {0: 'batch', 1: 'anchors'} # shape(1,25200,85)})

6、添加其余信息

        比如你想在onnx中保存一些模型之外的信息,比如作者、模型名、classes、Opset版本等信息的话。

  1. metadata = {
  2. 'author': 'NO郑',
  3. 'version': '11',
  4. 'stride': int(max(model.stride)),
  5. 'task': 'yolov8',
  6. 'batch': 1,
  7. 'imgsz': (640, 640),
  8. 'labels': ['car', 'person']} # model metadata
  9. model_onnx = onnx.load(f)
  10. for k, v in metadata.items():
  11. meta = model_onnx.metadata_props.add()
  12. meta.key, meta.value = k, str(v)

7、保存onnx

onnx.save(model, f)

8、 完整代码

  1. import torch
  2. import onnx
  3. ckpt = torch.load("yolov8n.pt")
  4. model = ckpt['model']
  5. im = torch.zeros((1, 3, 640, 640)).to('cpu')
  6. for p in model.parameters():
  7. p.requires_grad = False
  8. model.eval()
  9. model.float()
  10. model = model.fuse()
  11. f = 'onnx01.onnx'
  12. torch.onnx.export(model.cpu(), im, f, opset_version=11, input_names=['img'], output_names=['out'], dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'}, # shape(1,3,640,640)
  13. 'output': {0: 'batch', 1: 'anchors'} # shape(1,25200,85)
  14. })
  15. metadata = {
  16. 'description': '1',
  17. 'author': 'NO郑',
  18. 'version': '11',
  19. 'batch': 1,
  20. 'imgsz': (640, 640),
  21. 'names': ['class1', 'classes2']} # model metadata
  22. model_onnx = onnx.load(f)
  23. for k, v in metadata.items():
  24. meta = model_onnx.metadata_props.add()
  25. meta.key, meta.value = k, str(v)
  26. onnx.save(model_onnx, f)

9、ONNX可视化

  •         网页搜索netron
  •         将ONNX文件打开即可

        其中:METADATA是第6步添加的信息

上一篇:1、TensorRT学习笔记之安装TensorRT

下一篇:3、TensorRT学习笔记之ONNX转engine

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家小花儿/article/detail/850136
推荐阅读
相关标签
  

闽ICP备14008679号