当前位置:   article > 正文

pytorch文档阅读(五)如何保存、加载网络模型_torch 只保存网络结构

torch 只保存网络结构

1.网络的保存

torch.save()有两种方法

1)仅保存网络参数

torch.save(net.state_dict(), 'net_params.pkl')

2)保存整个网络结构

 torch.save(net, 'net.pkl')

2.网络的加载

1)仅加载参数

model_object.load_state_dict(torch.load('net_params.pkl')) 

2)加载整个模型

model = torch.load('net.pkl') 

两种方法在载入模型时都需要有预设的网络结构,例如下边代码,否则会提示找不到相应的module

  1. #加载整个网络
  2. class AlexNet(nn.Module):
  3. def __init__(self):
  4. super(AlexNet,self).__init__()
  5. self.conv1 = nn.Conv2d(3, 64, 5)
  6. self.pool1 = nn.MaxPool2d(3, 2)
  7. self.conv2 = nn.Conv2d(64, 64, 5)
  8. self.pool2 = nn.MaxPool2d(3, 2)
  9. self.fc1 = nn.Linear(1024, 384)
  10. self.fc2 = nn.Linear(384, 192)
  11. self.fc3 = nn.Linear(192, 10)
  12. def forward(self, x):
  13. x = self.pool1(F.relu(self.conv1(x)))
  14. x = self.pool2(F.relu(self.conv2(x)))
  15. x = x.view(x.shape[0], -1)
  16. x = F.relu(self.fc1(x))
  17. x = F.relu(self.fc2(x))
  18. x = F.softmax(self.fc3(x))
  19. return x
  20. net = torch.load("TestSave.pkl")#加载整个模型时直接用这句就可以实例化网络,并且把CUDA上运行这个属性也继承了过来
  1. #只加载网络参数
  2. class AlexNet(nn.Module):
  3. def __init__(self):
  4. super(AlexNet,self).__init__()
  5. self.conv1 = nn.Conv2d(3, 64, 5)
  6. self.pool1 = nn.MaxPool2d(3, 2)
  7. self.conv2 = nn.Conv2d(64, 64, 5)
  8. self.pool2 = nn.MaxPool2d(3, 2)
  9. self.fc1 = nn.Linear(1024, 384)
  10. self.fc2 = nn.Linear(384, 192)
  11. self.fc3 = nn.Linear(192, 10)
  12. def forward(self, x):
  13. x = self.pool1(F.relu(self.conv1(x)))
  14. x = self.pool2(F.relu(self.conv2(x)))
  15. x = x.view(x.shape[0], -1)
  16. x = F.relu(self.fc1(x))
  17. x = F.relu(self.fc2(x))
  18. x = F.softmax(self.fc3(x))
  19. return x
  20. net = AlexNet()#只加载网络参数的时候需要自行实例化网络
  21. net.cuda()#并设置网络运行在cpu还是gpu上
  22. net.load_state_dict(torch.load('net_params.pkl'))#再加载网络的参数

注意:

1.只加载网络参数的速度比加载整个网络快得多

2.pth、pkl格式效果相同,ckpt是tensorflow的格式

参考链接:

https://www.jb51.net/article/139102.htm

https://www.jianshu.com/p/0eda629e4007

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/羊村懒王/article/detail/103363
推荐阅读
相关标签
  

闽ICP备14008679号