赞
踩
目录
在PyTorch中,可以使用不同的方式来保存深度学习模型。两种常用的方式如下:
方式1:保存整个模型结构和参数
可以使用torch.save()
函数将整个模型,包括模型的结构和参数,保存到文件中。这种方式保存的文件通常较大。
- import torch
- import torchvision
-
- model = torchvision.models.vgg16(weights=None)
- torch.save(model, "vgg16_method1.pth")
方式2:仅保存模型的参数
更常见的方式是只保存模型的参数,而不包括模型的结构。这可以通过model.state_dict()
来实现。这种方式的文件通常较小,适合分享和部署。
- import torch
- import torchvision
-
- vgg16 = torchvision.models.vgg16(weights=None)
- torch.save(vgg16.state_dict(), "vgg16_method2.pth")
一旦模型保存在文件中,可以使用torch.load()
函数来加载模型。在加载模型时,确保代码中定义了与保存模型相同的模型结构,以避免出现错误。
方式1:加载整个模型
- import torch
- import torchvision
-
- model = torch.load("vgg16_method1.pth")
- print(model)
方式2:加载模型参数
- import torch
- import torchvision
- from torch import nn
-
- vgg16 = torchvision.models.vgg16(weights=None)
- vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
- print(vgg16)
注意事项
在使用方式1保存模型时,如果加载的代码中的模型结构与保存的模型结构不匹配,可能会导致错误。为了避免这种情况,建议将模型结构的定义放在单独的Python文件中,然后通过from
导入,以确保加载模型时能够正确匹配模型结构。
- # P26_model_save.py
- import torch
- from torch import nn
-
- class Tudui(nn.Module):
- def __init__(self):
- super(Tudui, self).__init__()
- self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
-
- def forward(self, x):
- x = self.conv1(x)
- return x
然后在加载模型的代码中使用from P26_model_save import *
导入模型结构。
完整代码如下:
- import torch
- import torchvision
- from P26_model_save import *
-
- # 方式1,保存方式1,加载模型
- model = torch.load("vgg16_method1.path")
- print(model)
-
- # 方式2,加载模型
- vgg16 = torchvision.models.vgg16(weights=None)
- vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
- print(vgg16)
-
- # 陷阱,用第一种方式保存时,如果是自己的模型,就需要在加载中,把class重新写一遍,但并不需要实例化,即可
- # 这个陷阱,也是可以避免的,最上面的 from model_save import *,就是在做这个事情,避免出现错误
- # class Tudui(nn.Module):
- # def __init__(self):
- # super(Tudui, self).__init__()
- # self.conv1 = nn.Conv2d(3, 64, ke nel_size=3)
-
- # def forward(self, x):
- # x = self.conv1(x)
- # return x
-
- # model = torch.load('tudui_method1.pth')
- # print(model)
- import torch
- import torchvision
- from torch import nn
-
- vgg16 = torchvision.models.vgg16(weights=None)
- # 保存方式1,模型结构+模型参数 模型 + 参数 都保存
- # torch.save(vgg16,"vgg16_method1.pth") # 引号里是保存路径
- # 保存方式2,模型参数(官方推荐) ,因为这个方式,储存量小,在terminal中,ls -all可以查看
- torch.save(vgg16.state_dict(),"vgg16_method2.pth")
- # 把网络模型的参数,保存下来,储存成字典的形式
-
- # 把网络模型的参数,保存下来,储存成字典的形式
-
- # 陷阱
- # class Tudui(nn.Module):
- # def __init__(self):
- # super(Tudui, self).__init__()
- # self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
- #
- # def forward(self, x):
- # x = self.conv1(x)
- # return x
-
-
- # tudui = Tudui()
- # torch.save(tudui, "tudui_method1.pth")
参考资料:
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。