赞
踩
摘要:主要是解析PT权重转ONNX的过程、代码。文章末尾附有完整运行代码。
目录:
- model = torch.load("yolov8n.pt")
-
- # 由于我用v8做测试,v8的权重中除了模型结构,还有配置文件的参数,这里只需要其模型结构
- ckpt = torch.load("yolov8n.pt")
- model = ckpt['model']
作用:减少计算所需的资源
- for p in model.parameters():
- p.requires_grad = False
fuse的作用:将conv和BN层融合,提高推理速度
- model.eval()
- model.float()
- model = model.fuse()
- device = 'cpu' # 如果cuda能用,设置'cuda:0'
- input = torch.zeros((1, 3, 640, 640)).to(device) # 根据模型输入尺寸设置(1, 3, 640, 640)
- f = 'name.onnx' # 文件名,保存路径
- opset_version = 11 # opset版本
- input_names = ['img'] # 输入名
- output_names=['out'] # 输出名
- # 通过以下规则设置动态的维度
- dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'}, # shape(1,3,640,640)
- 'output': {0: 'batch', 1: 'anchors'} # shape(1,25200,85)}
- torch.onnx.export(
- model, # 模型pt权重
- input, # 输入张量,模型输入,如[1,3,640,640]
- f, # 保存onnx模型的文件
- opset_version = opset_version, # Opset版本
- input_names=input_names, # 输入张量名称
- output_names=output_names, # 输出的张量名称
- dynamic_axes=dynamic_axes, # 通过以下规则设置动态的维度
- )
也可以结合步骤四直接这么写:
- torch.onnx.export(model = torch.load(weight)[model],
- im = torch.zeros(1, 3, 640, 640),
- f = ./weight.onnx,
- export_params = 12,
- input_names=['images'],
- output_names=['output'],
- dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'},
- 'output': {0: 'batch', 1: 'anchors'} # shape(1,25200,85)})
比如你想在onnx中保存一些模型之外的信息,比如作者、模型名、classes、Opset版本等信息的话。
- metadata = {
- 'author': 'NO郑',
- 'version': '11',
- 'stride': int(max(model.stride)),
- 'task': 'yolov8',
- 'batch': 1,
- 'imgsz': (640, 640),
- 'labels': ['car', 'person']} # model metadata
- model_onnx = onnx.load(f)
- for k, v in metadata.items():
- meta = model_onnx.metadata_props.add()
- meta.key, meta.value = k, str(v)
onnx.save(model, f)
- import torch
- import onnx
-
- ckpt = torch.load("yolov8n.pt")
- model = ckpt['model']
- im = torch.zeros((1, 3, 640, 640)).to('cpu')
- for p in model.parameters():
- p.requires_grad = False
-
- model.eval()
- model.float()
- model = model.fuse()
- f = 'onnx01.onnx'
-
- torch.onnx.export(model.cpu(), im, f, opset_version=11, input_names=['img'], output_names=['out'], dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'}, # shape(1,3,640,640)
- 'output': {0: 'batch', 1: 'anchors'} # shape(1,25200,85)
- })
-
- metadata = {
- 'description': '1',
- 'author': 'NO郑',
- 'version': '11',
- 'batch': 1,
- 'imgsz': (640, 640),
- 'names': ['class1', 'classes2']} # model metadata
- model_onnx = onnx.load(f)
- for k, v in metadata.items():
- meta = model_onnx.metadata_props.add()
- meta.key, meta.value = k, str(v)
- onnx.save(model_onnx, f)
其中:METADATA是第6步添加的信息
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。