赞
踩
用 torchvision 构建 mobilenet_v2 网络（去掉原分类层、换成新的全连接 + softmax 头），并将其权重（state_dict，不含网络结构）保存为 mobilenet_v2.pt。
- import torch
- from torch import nn
- from torchvision import models
- import torch.nn.functional as F
class MobileNet_v2(nn.Module):
    """torchvision MobileNetV2 backbone with a new linear + softmax head.

    The ImageNet classifier of ``models.mobilenet_v2`` is discarded and
    replaced by a freshly initialised ``nn.Linear`` that maps the globally
    pooled 1280-d feature vector to ``num_classes`` probabilities.

    Args:
        num_classes: number of outputs of the new head. Defaults to 2, the
            value that used to be hard-coded.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        # NOTE(review): torchvision >= 0.13 deprecates ``pretrained=`` in
        # favour of ``weights=``; kept for compatibility with older releases.
        model = models.mobilenet_v2(pretrained=True)
        # mobilenet_v2 has two children, ``features`` and ``classifier``
        # (Dropout + Linear). Dropping the last child removes only the
        # classifier; there is no pooling module, so global pooling is done
        # functionally in forward().
        modules = list(model.children())[:-1]
        # The attribute name is historical (this is not a ResNet) but it is
        # kept because it determines the state_dict keys of saved weights.
        self.resnet = nn.Sequential(*modules)
        self.fc = nn.Linear(1280, num_classes)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, images):
        """Return class probabilities for a batch of images.

        Args:
            images: float tensor of shape [N, 3, H, W], e.g. [1, 3, 224, 224].

        Returns:
            Tensor of shape [N, num_classes] whose rows sum to 1.
        """
        x = self.resnet(images)  # [N, 1280, H/32, W/32], 7x7 for 224x224
        x = F.adaptive_avg_pool2d(x, (1, 1))  # [N, 1280, 1, 1]
        x = torch.flatten(x, 1)  # [N, 1280], batch dim kept explicitly
        x = self.fc(x)  # [N, num_classes]
        return self.softmax(x)
-
model = MobileNet_v2()
# Only the weights (state_dict) are saved; the MobileNet_v2 class is needed
# again to rebuild the network before loading them.
torch.save(model.state_dict(), "mobilenet_v2.pt")
# Run the sanity check in inference mode. In train mode, BatchNorm would
# normalise with this single random input's batch statistics (and update its
# running statistics), so the printed output would not be inference output.
model.eval()
x = torch.rand(1, 3, 224, 224)
with torch.no_grad():
    out = model(x)
print(out)
- import torch
- from torch import nn
- from torchvision import models
- import torch.nn.functional as F
- import torch.onnx
class MobileNet_v2(nn.Module):
    """torchvision MobileNetV2 backbone with a new linear + softmax head.

    The ImageNet classifier of ``models.mobilenet_v2`` is discarded and
    replaced by a freshly initialised ``nn.Linear`` that maps the globally
    pooled 1280-d feature vector to ``num_classes`` probabilities.

    Args:
        num_classes: number of outputs of the new head. Defaults to 2, the
            value that used to be hard-coded.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        # NOTE(review): torchvision >= 0.13 deprecates ``pretrained=`` in
        # favour of ``weights=``; kept for compatibility with older releases.
        model = models.mobilenet_v2(pretrained=True)
        # mobilenet_v2 has two children, ``features`` and ``classifier``
        # (Dropout + Linear). Dropping the last child removes only the
        # classifier; there is no pooling module, so global pooling is done
        # functionally in forward().
        modules = list(model.children())[:-1]
        # The attribute name is historical (this is not a ResNet) but it is
        # kept because it determines the state_dict keys of saved weights.
        self.resnet = nn.Sequential(*modules)
        self.fc = nn.Linear(1280, num_classes)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, images):
        """Return class probabilities for a batch of images.

        Args:
            images: float tensor of shape [N, 3, H, W], e.g. [1, 3, 224, 224].

        Returns:
            Tensor of shape [N, num_classes] whose rows sum to 1.
        """
        x = self.resnet(images)  # [N, 1280, H/32, W/32], 7x7 for 224x224
        x = F.adaptive_avg_pool2d(x, (1, 1))  # [N, 1280, 1, 1]
        x = torch.flatten(x, 1)  # [N, 1280], batch dim kept explicitly
        x = self.fc(x)  # [N, num_classes]
        return self.softmax(x)
-
model = MobileNet_v2()
model.load_state_dict(torch.load("mobilenet_v2.pt", map_location=torch.device("cpu")))
# Export evaluation-mode semantics (BatchNorm running stats, no Dropout).
model.eval()
# An example input you would normally provide to your model's forward() method.
x = torch.rand(1, 3, 224, 224)
# torch.onnx.export writes the file but does not return the model's output,
# so compute a PyTorch reference output separately (e.g. to compare against
# ONNX Runtime later).
with torch.no_grad():
    torch_out = model(x)
# Export the model. Naming the graph input/output and marking dimension 0 as
# dynamic lets the ONNX model accept any batch size, not just 1.
torch.onnx.export(
    model,
    x,
    "mobilenet_v2.onnx",
    export_params=True,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
)
注意：PyTorch（.pt）模型的输入是 torch.Tensor，而 ONNX Runtime 推理时的输入必须是 numpy 数组（numpy.ndarray），因此下面先用 to_numpy() 把张量转换成 numpy 数组再送入 ONNX 模型。
- import torch
- import torch.onnx
- import onnxruntime
-
class OnnxModel:
    """Run inference on an ONNX model with ONNX Runtime.

    Every graph input is fed the same numpy array (see ``get_input_feed``),
    so this wrapper is intended for single-input models such as the exported
    ``mobilenet_v2.onnx``.
    """

    def __init__(self, onnx_path, providers=None):
        """Open an inference session and cache the graph's input/output names.

        :param onnx_path: path to the ``.onnx`` model file.
        :param providers: ONNX Runtime execution providers in priority order,
            e.g. ``["CUDAExecutionProvider", "CPUExecutionProvider"]``.
            Defaults to ``["CPUExecutionProvider"]``, which every onnxruntime
            build ships. An explicit list is required because, since
            onnxruntime 1.9, builds with several providers (such as
            onnxruntime-gpu) raise ``ValueError`` when none is given.
        """
        if providers is None:
            providers = ["CPUExecutionProvider"]
        self.onnx_session = onnxruntime.InferenceSession(
            onnx_path, providers=providers
        )
        self.input_name = self.get_input_name(self.onnx_session)
        self.output_name = self.get_output_name(self.onnx_session)

    def get_output_name(self, onnx_session):
        """Return the names of all graph outputs, in graph order.

        :param onnx_session: an ``onnxruntime.InferenceSession``.
        :return: list of output names.
        """
        return [node.name for node in onnx_session.get_outputs()]

    def get_input_name(self, onnx_session):
        """Return the names of all graph inputs, in graph order.

        :param onnx_session: an ``onnxruntime.InferenceSession``.
        :return: list of input names.
        """
        return [node.name for node in onnx_session.get_inputs()]

    def get_input_feed(self, input_name, image_numpy):
        """Build the ``{name: array}`` feed dict for ``InferenceSession.run``.

        The same array is mapped to every name in ``input_name``.

        :param input_name: iterable of graph input names.
        :param image_numpy: input array, e.g. float32 with shape (1, 3, 224, 224).
        :return: dict mapping each input name to ``image_numpy``.
        """
        return {name: image_numpy for name in input_name}

    def forward(self, image_numpy):
        """Run the model on ``image_numpy`` and return every graph output.

        The input must be a numpy array, not a torch tensor. A single HWC image
        must first be brought to the batched NCHW layout used at export time,
        e.g. ``image.transpose(2, 0, 1)[np.newaxis, :]``.

        :param image_numpy: batched input array.
        :return: the list returned by ``InferenceSession.run``, one entry per
            name in ``self.output_name``.
        """
        input_feed = self.get_input_feed(self.input_name, image_numpy)
        return self.onnx_session.run(self.output_name, input_feed=input_feed)
-
def to_numpy(tensor):
    """Convert a torch tensor into a numpy array that ONNX Runtime can consume.

    ``detach()`` drops any autograd link (``Tensor.numpy()`` refuses tensors
    that require grad; on tensors that don't it is a no-op view), and
    ``cpu()`` brings device tensors into host memory.
    """
    host_tensor = tensor.detach().cpu()
    return host_tensor.numpy()
# Random example batch in the same (1, 3, 224, 224) layout used for export.
# ONNX Runtime takes numpy arrays, so it is converted with to_numpy() below.
onnx_model_path = "mobilenet_v2.onnx"
x = torch.rand(1, 3, 224, 224)
model = OnnxModel(onnx_model_path)
out = model.forward(to_numpy(x))
print(out)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。