赞
踩
目录
一般来说,pytorch训练好的模型是不能够直接用于生产环境,有很多的地方没有优化
而ONNX 格式可以兼顾不同框架的模型,相当于一个中间人的角色。这样部署到不同的环境中,就不需要考虑兼容的问题
测试代码如下:
- import torch
- from torchvision import models
-
-
- # 设备
- DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
- print(DEVICE)
-
- # 载入预训练模型
- model = models.resnet34(weights=True)
- model.to(DEVICE)
- model.eval()
-
- # 模型输入的维度
- input_tensor = torch.randn(1,3,256,256).to(DEVICE)
- output_tensor = model(input_tensor)
- print(output_tensor.shape)
-
- # pytorch 转 ONNX 格式
- with torch.no_grad():
- torch.onnx.export(
- model, # 转换的模型
- input_tensor, # 输入的维度
- 'resnet34_imagenet.onnx', # 导出的 ONNX 文件名
- opset_version=11, # ONNX 算子集的版本
- input_names= ['input'], # 输入的 tensor名称,可变
- output_names= ['output'] # 输出的 tensor名称,可变
- )
这里采用的是官方imageNet 上预训练的resnet34模型
转ONNX格式的时候,需要提供一个输入的维度,没有意义,类似于与tensorboard中输入流经model就可以知道model的配置一样
如下,运行之后会生成一个.onnx 文件
测试代码如下:
- import onnx
-
-
- # 读取模型
- onnx_model = onnx.load('./resnet34_imagenet.onnx')
-
- # 检查模型格式是否正确,没有报错的话,说明载入成功
- onnx.checker.check_model(onnx_model)
-
- # 打印
- print(onnx.helper.printable_graph(onnx_model.graph))
-
print的信息大概就是网络的结构之类的
将生成的 onnx 文件载入可以看的网络的信息
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。