当前位置:   article > 正文

Pytorch:模型的加载和保存 torch.save,torch.load,torch.nn.Module.state_dict 和 torch.nn.Module.load_state_dict_torch.save module

torch.save module

在PyTorch中,模型的保存和加载是通过几个关键的函数来完成的,它们分别是 torch.savetorch.loadtorch.nn.Module.state_dicttorch.nn.Module.load_state_dict

torch.save

torch.save 是用于保存一个序列化对象到磁盘的函数。这个序列化对象可以是任何类型的对象,包括模型、张量和字典等。函数内部使用Python的 pickle 模块将对象进行序列化。

# object可以是模型、张量、字典等任何对象
torch.save(object, 'path_to_file.pth')
  • 1
  • 2

torch.load

torch.load 是用于从磁盘加载一个通过 torch.save 保存的序列化对象的函数。加载时,会根据保存的对象重新构造出原来的对象。

# 加载先前保存的对象
object = torch.load('path_to_file.pth')
  • 1
  • 2

torch.nn.Module.state_dict

state_dict 是模型(更确切地说是 nn.Module)的一个方法,返回模型的参数的字典。字典的键为层的名字,值为对应层的参数张量。state_dict 只包含模型的参数(卷积层、线性层等),不包含模型的结构。

# 实例化模型
model = SomeModel()

# 获取模型的参数
state_dict = model.state_dict()

# 保存模型的参数
torch.save(state_dict, 'model_state_dict.pth')
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

torch.nn.Module.load_state_dict

load_state_dict 是模型的一个方法,它将一个 state_dict 中的参数加载到模型中。这个方法通常与 torch.load 结合使用来恢复模型的状态。

# 实例化模型
model = SomeModel()

# 加载保存的模型的状态
state_dict = torch.load('model_state_dict.pth')

# 应用状态字典到模型中
model.load_state_dict(state_dict)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

在实际使用中,我们通常先训练好一个模型,然后调用它的 state_dict 方法来获取参数字典,并将这个字典通过 torch.save 函数保存到磁盘上。在需要的时候,我们可以用 torch.load加载保存的参数字典,并使用 load_state_dict将这些参数加载到模型中

值得注意的是,通常情况下应该只保存和加载 state_dict,而不是完整的模型。这是因为模型结构在代码中定义,只保存和加载参数可以使模型更加灵活,易于调整。此外,在跨版本或者跨环境使用时,state_dict 也更加稳定可靠。

load_state_dict

在深度学习中,load_state_dict 是用于加载模型参数的常用方法,特别是在PyTorch框架中。state_dict 是一个从参数名称映射到参数张量的字典对象。这么做的原理是将训练好的模型参数保存下来,以便之后的使用或继续训练。

当调用 load_state_dict 方法时,它会获取当前模型的所有层和对应参数,然后将传入的 state_dict 中的参数值加载到模型的各层中。这一过程中,在state_dict中的键(即参数名称)需要和当前模型中的键完全匹配,否则会触发错误。

使用 load_state_dict 的典型步骤如下:

  1. 定义模型结构:首先需要有一个和原来训练时相同结构的模型。
  2. 加载 state_dict:使用 torch.load 函数加载保存的模型参数。
  3. 应用 state_dict:将加载的 state_dict 应用到模型上。

下面是一个基本的使用示例:

import torch
import torch.nn as nn

# 定义一个简单的模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(20 * 12 * 12, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = x.view(-1, 20 * 12 * 12)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 实例化模型
model = SimpleModel()

# 加载保存的state_dict
state_dict = torch.load('model_parameters.pth')

# 应用state_dict
model.load_state_dict(state_dict)

# 现在模型已经加载了保存的参数,可以用于预测或继续训练
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29

在上述代码中,先是定义了一个神经网络模型 SimpleModel,然后通过 torch.load 加载了一个叫做 model_parameters.pth 的文件,该文件应该是一个先前训练好并保存下来的模型的 state_dict。最后,通过调用 load_state_dict 方法,将参数加载到定义好的模型中。

要注意的是,加载的 state_dict 必须与模型的结构严格对应。如果结构不一致,你会收到一个错误,指出无法匹配的键值。此外,load_state_dict 方法默认会严格执行键值匹配,但你可以设置 strict=False 以便在参数部分匹配时也能加载模型。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家自动化/article/detail/785879
推荐阅读
相关标签
  

闽ICP备14008679号