当前位置:   article > 正文

Pytorch模型转为onnx,onnx 模型 inference测试_python onnx inference

python onnx inference

1、mobilenet_v2生成.pt

用torchvision加载mobilenet_v2网络结构,替换分类层后,将模型参数(state_dict)保存为mobilenet_v2.pt

  1. import torch
  2. from torch import nn
  3. from torchvision import models
  4. import torch.nn.functional as F
  class MobileNet_v2(nn.Module):
      """MobileNetV2 feature extractor followed by a small softmax classifier.

      The torchvision MobileNetV2 model is reused without its last child
      module; a new linear layer maps the pooled 1280-channel features to
      ``num_classes`` scores, which are turned into probabilities by a
      softmax over the last dimension.

      Args:
          num_classes: Number of output classes (defaults to 2).
          pretrained: Forwarded to ``models.mobilenet_v2``. Pass ``False``
              to skip pretrained weights, e.g. when a state dict is loaded
              right after construction.
      """

      # Channel count of the backbone output consumed by the linear head.
      FEATURE_CHANNELS = 1280

      def __init__(self, num_classes=2, pretrained=True):
          super().__init__()
          # NOTE(review): newer torchvision releases prefer a ``weights=``
          # argument over ``pretrained=`` -- confirm against the installed
          # version before switching.
          backbone = models.mobilenet_v2(pretrained=pretrained)
          # Drop the last child module (torchvision's classification head)
          # so that only the feature extractor remains.
          modules = list(backbone.children())[:-1]
          # The attribute name ``resnet`` is part of the state-dict keys,
          # so it is kept even though the backbone is a MobileNetV2.
          self.resnet = nn.Sequential(*modules)
          self.fc = nn.Linear(self.FEATURE_CHANNELS, num_classes)
          self.softmax = nn.Softmax(dim=-1)

      def forward(self, images):
          """Return class probabilities of shape [N, num_classes].

          Args:
              images: Batch of input images as a tensor; the scripts in
                  this article feed [N, 3, 224, 224].
          """
          x = self.resnet(images)               # [N, 1280, H', W']
          x = F.adaptive_avg_pool2d(x, (1, 1))  # [N, 1280, 1, 1]
          x = torch.flatten(x, 1)               # [N, 1280]
          x = self.fc(x)                        # [N, num_classes]
          return self.softmax(x)
  # Build the network and write its parameters to ``mobilenet_v2.pt``.
  model = MobileNet_v2()
  # Switch to inference mode so the sanity-check forward pass below does
  # not run the layers in training mode.
  model.eval()
  x = torch.rand(1, 3, 224, 224)
  torch.save(model.state_dict(), "mobilenet_v2.pt")
  # Sanity-check forward pass; no autograd graph is needed for it.
  with torch.no_grad():
      out = model(x)
  print(out)

2、mobilenet_v2.pt 转化为mobilenet_v2.onnx,此时便可脱离pytorch框架,进行跨平台部署。

  1. import torch
  2. from torch import nn
  3. from torchvision import models
  4. import torch.nn.functional as F
  5. import torch.onnx
  class MobileNet_v2(nn.Module):
      """MobileNetV2 feature extractor followed by a small softmax classifier.

      The torchvision MobileNetV2 model is reused without its last child
      module; a new linear layer maps the pooled 1280-channel features to
      ``num_classes`` scores, which are turned into probabilities by a
      softmax over the last dimension.

      Args:
          num_classes: Number of output classes (defaults to 2).
          pretrained: Forwarded to ``models.mobilenet_v2``. Pass ``False``
              to skip pretrained weights when a saved state dict is loaded
              right after construction, as the export script here does.
      """

      # Channel count of the backbone output consumed by the linear head.
      FEATURE_CHANNELS = 1280

      def __init__(self, num_classes=2, pretrained=True):
          super().__init__()
          # NOTE(review): newer torchvision releases prefer a ``weights=``
          # argument over ``pretrained=`` -- confirm against the installed
          # version before switching.
          backbone = models.mobilenet_v2(pretrained=pretrained)
          # Drop the last child module (torchvision's classification head)
          # so that only the feature extractor remains.
          modules = list(backbone.children())[:-1]
          # The attribute name ``resnet`` is part of the state-dict keys
          # stored in the saved ``.pt`` file, so it must not be renamed.
          self.resnet = nn.Sequential(*modules)
          self.fc = nn.Linear(self.FEATURE_CHANNELS, num_classes)
          self.softmax = nn.Softmax(dim=-1)

      def forward(self, images):
          """Return class probabilities of shape [N, num_classes].

          Args:
              images: Batch of input images as a tensor; the scripts in
                  this article feed [N, 3, 224, 224].
          """
          x = self.resnet(images)               # [N, 1280, H', W']
          x = F.adaptive_avg_pool2d(x, (1, 1))  # [N, 1280, 1, 1]
          x = torch.flatten(x, 1)               # [N, 1280]
          x = self.fc(x)                        # [N, num_classes]
          return self.softmax(x)
  # Rebuild the network and restore the parameters saved in step 1.
  model = MobileNet_v2()
  model.load_state_dict(torch.load("mobilenet_v2.pt", map_location=torch.device("cpu")))
  # Put the model in inference mode before exporting it.
  model.eval()
  # An example input you would normally provide to your model's forward() method
  x = torch.rand(1, 3, 224, 224)
  # Export the model; the ONNX file is written to disk and the return
  # value is not used further in this script.
  torch_out = torch.onnx.export(model, x, "mobilenet_v2.onnx", export_params=True)

3、onnx模型进行推理

注意:PyTorch模型(加载.pt权重后)的输入类型为torch.Tensor,而使用onnxruntime对onnx模型推理时,输入类型为numpy.ndarray

  1. import torch
  2. import torch.onnx
  3. import onnxruntime
  class OnnxModel:
      """Thin wrapper around an ``onnxruntime.InferenceSession``.

      The input and output names of the model are read once at construction
      time and reused for every ``forward`` call.
      """

      def __init__(self, onnx_path, providers=None):
          """Create the inference session and cache its input/output names.

          Args:
              onnx_path: Path of the ``.onnx`` file to load.
              providers: Optional list of execution providers forwarded to
                  ``onnxruntime.InferenceSession``. When ``None`` the
                  argument is omitted and onnxruntime's default applies.
          """
          # Only pass ``providers`` when requested, so the default call is
          # exactly ``InferenceSession(onnx_path)``.
          session_kwargs = {} if providers is None else {"providers": providers}
          self.onnx_session = onnxruntime.InferenceSession(onnx_path, **session_kwargs)
          self.input_name = self.get_input_name(self.onnx_session)
          self.output_name = self.get_output_name(self.onnx_session)

      def get_output_name(self, onnx_session):
          """Return the names of all outputs of ``onnx_session`` as a list.

          Args:
              onnx_session: Object exposing ``get_outputs()``, whose items
                  have a ``name`` attribute.
          """
          return [node.name for node in onnx_session.get_outputs()]

      def get_input_name(self, onnx_session):
          """Return the names of all inputs of ``onnx_session`` as a list.

          Args:
              onnx_session: Object exposing ``get_inputs()``, whose items
                  have a ``name`` attribute.
          """
          return [node.name for node in onnx_session.get_inputs()]

      def get_input_feed(self, input_name, image_numpy):
          """Build the feed dict passed to the session.

          Args:
              input_name: Iterable of input names.
              image_numpy: Array fed to the model; the same object is
                  assigned to every name in ``input_name``.

          Returns:
              A dict mapping each input name to ``image_numpy``.
          """
          return {name: image_numpy for name in input_name}

      def forward(self, image_numpy):
          """Run the session on ``image_numpy`` and return its outputs.

          Args:
              image_numpy: Input array for the model.

          Returns:
              Whatever ``onnx_session.run`` returns for the cached output
              names (one entry per requested output).
          """
          input_feed = self.get_input_feed(self.input_name, image_numpy)
          return self.onnx_session.run(self.output_name, input_feed=input_feed)
  def to_numpy(tensor):
      """Return the data of ``tensor`` as a NumPy array on the CPU.

      ``detach()`` is applied unconditionally, so tensors with and without
      ``requires_grad`` go through the same path.

      Args:
          tensor: A ``torch.Tensor``.

      Returns:
          The result of ``tensor.detach().cpu().numpy()``.
      """
      return tensor.detach().cpu().numpy()
  # Load the exported model and run it once on a random input.
  onnx_model_path = "mobilenet_v2.onnx"
  model = OnnxModel(onnx_model_path)
  x = torch.rand(1, 3, 224, 224)
  # The wrapper is fed a NumPy array, so the tensor is converted first.
  out = model.forward(to_numpy(x))
  print(out)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/从前慢现在也慢/article/detail/351396
推荐阅读
相关标签
  

闽ICP备14008679号