当前位置:   article > 正文

PyTorch入门学习(十六):网络模型的保存与读取_pytorch如何保存训练好的网络模型

pytorch如何保存训练好的网络模型

目录

一、保存模型的方式

二、加载模型


一、保存模型的方式

在PyTorch中,可以使用不同的方式来保存深度学习模型。两种常用的方式如下:

方式1:保存整个模型结构和参数

可以使用torch.save()函数将整个模型,包括模型的结构和参数,保存到文件中。这种方式保存的文件通常较大。

  1. import torch
  2. import torchvision
  3. model = torchvision.models.vgg16(weights=None)
  4. torch.save(model, "vgg16_method1.pth")

方式2:仅保存模型的参数

更常见的方式是只保存模型的参数,而不包括模型的结构。这可以通过model.state_dict()来实现。这种方式的文件通常较小,适合分享和部署。

  1. import torch
  2. import torchvision
  3. vgg16 = torchvision.models.vgg16(weights=None)
  4. torch.save(vgg16.state_dict(), "vgg16_method2.pth")

二、加载模型

一旦模型保存在文件中,可以使用torch.load()函数来加载模型。在加载模型时,确保代码中定义了与保存模型相同的模型结构,以避免出现错误。

方式1:加载整个模型

  1. import torch
  2. import torchvision
  3. model = torch.load("vgg16_method1.pth")
  4. print(model)

方式2:加载模型参数

  1. import torch
  2. import torchvision
  3. from torch import nn
  4. vgg16 = torchvision.models.vgg16(weights=None)
  5. vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
  6. print(vgg16)

注意事项

在使用方式1保存模型时,如果加载的代码中的模型结构与保存的模型结构不匹配,可能会导致错误。为了避免这种情况,建议将模型结构的定义放在单独的Python文件中,然后通过from导入,以确保加载模型时能够正确匹配模型结构。

  1. # P26_model_save.py
  2. import torch
  3. from torch import nn
  4. class Tudui(nn.Module):
  5. def __init__(self):
  6. super(Tudui, self).__init__()
  7. self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
  8. def forward(self, x):
  9. x = self.conv1(x)
  10. return x

然后在加载模型的代码中使用from P26_model_save import *导入模型结构。

完整代码如下:

  1. import torch
  2. import torchvision
  3. from P26_model_save import *
  4. # 方式1,保存方式1,加载模型
  5. model = torch.load("vgg16_method1.path")
  6. print(model)
  7. # 方式2,加载模型
  8. vgg16 = torchvision.models.vgg16(weights=None)
  9. vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
  10. print(vgg16)
  11. # 陷阱,用第一种方式保存时,如果是自己的模型,就需要在加载中,把class重新写一遍,但并不需要实例化,即可
  12. # 这个陷阱,也是可以避免的,最上面的 from model_save import *,就是在做这个事情,避免出现错误
  13. # class Tudui(nn.Module):
  14. # def __init__(self):
  15. # super(Tudui, self).__init__()
  16. # self.conv1 = nn.Conv2d(3, 64, ke nel_size=3)
  17. # def forward(self, x):
  18. # x = self.conv1(x)
  19. # return x
  20. # model = torch.load('tudui_method1.pth')
  21. # print(model)
  1. import torch
  2. import torchvision
  3. from torch import nn
  4. vgg16 = torchvision.models.vgg16(weights=None)
  5. # 保存方式1,模型结构+模型参数 模型 + 参数 都保存
  6. # torch.save(vgg16,"vgg16_method1.pth") # 引号里是保存路径
  7. # 保存方式2,模型参数(官方推荐) ,因为这个方式,储存量小,在terminal中,ls -all可以查看
  8. torch.save(vgg16.state_dict(),"vgg16_method2.pth")
  9. # 把网络模型的参数,保存下来,储存成字典的形式
  10. # 把网络模型的参数,保存下来,储存成字典的形式
  11. # 陷阱
  12. # class Tudui(nn.Module):
  13. # def __init__(self):
  14. # super(Tudui, self).__init__()
  15. # self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
  16. #
  17. # def forward(self, x):
  18. # x = self.conv1(x)
  19. # return x
  20. # tudui = Tudui()
  21. # torch.save(tudui, "tudui_method1.pth")

参考资料:

视频教程:PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小蓝xlanll/article/detail/103384
推荐阅读
相关标签
  

闽ICP备14008679号