当前位置:   article > 正文

Pytorch之Save&Load(保存和加载模型)_torch.save

torch.save

模型的保存和加载都在系列化的模块下

先看保存的

更详细的可以参考这里https://pytorch.org/docs/stable/notes/serialization.html#preserve-storage-sharing
torch.save()并torch.load()让您轻松保存和加载张量:最简单的就是

  1. t = torch.tensor([1., 2.])
  2. torch.save(t, 'tensor.pth')
  3. torch.load('tensor.pth')

按照惯例,PyTorch 文件通常使用“.pt”或“.pth”扩展名编写。
torch.save()并torch.load()默认使用 Python 的 pickle,因此您还可以将多个张量保存为 Python 对象(如元组、列表和字典)的一部分:

  1. >>> d = {'a': torch.tensor([1., 2.]), 'b': torch.tensor([3., 4.])}
  2. >>> torch.save(d, 'tensor_dict.pth')
  3. >>> torch.load('tensor_dict.pth')
  4. {'a': tensor([1., 2.]), 'b': tensor([3., 4.])}

如果数据结构是pickle允许的格式,也可以保存包含 PyTorch 张量的自定义数据结构。

保存张量并保留视图关系

  1. >>> numbers = torch.arange(1, 10)
  2. >>> evens = numbers[1::2]
  3. >>> torch.save([numbers, evens], 'tensors.pt')
  4. >>> loaded_numbers, loaded_evens = torch.load('tensors.pt')
  5. >>> loaded_evens *= 2
  6. >>> loaded_numbers
  7. tensor([ 1, 4, 3, 8, 5, 12, 7, 16, 9])

当 PyTorch 保存张量时,会分别保存模型和张量数据。
在某些情况下,保存模型可能是不必要的,并且会创建过大的文件。在以下代码段中,将比保存的张量大得多的模型写入文件:

  1. >>> large = torch.arange(1, 1000)
  2. >>> small = large[0:5]
  3. >>> torch.save(small, 'small.pth')
  4. >>> loaded_small = torch.load('small.pth')
  5. >>> loaded_small.storage().size()
  6. 999

在仅保存5个张量的“small.pth”储存999个值也进行了保存与加载。

当保存元素少于其存储对象的张量时,可以通过首先克隆张量来减小保存文件的大小。克隆张量会产生一个新的张量,并带有一个仅包含张量中值的新存储对象:

  1. >>> large = torch.arange(1, 1000)
  2. >>> small = large[0:5]
  3. >>> torch.save(small.clone(), 'small.pth') # saves a clone of small
  4. >>> loaded_small = torch.load('small.pth')
  5. >>> loaded_small.storage().size()
  6. 5

然而,由于克隆的张量彼此独立,因此它们没有原始张量所具有的任何视图关系。

保存和加载 torch.nn.Modules

在 PyTorch 中,模块的状态经常使用“状态字典”进行序列化。模块的状态字典包含其所有参数和持续缓冲区:

  1. >>> bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
  2. >>> list(bn.named_parameters())
  3. [('weight', Parameter containing: tensor([1., 1., 1.], requires_grad=True)),
  4. ('bias', Parameter containing: tensor([0., 0., 0.], requires_grad=True))]
  5. >>> list(bn.named_buffers())
  6. [('running_mean', tensor([0., 0., 0.])),
  7. ('running_var', tensor([1., 1., 1.])),
  8. ('num_batches_tracked', tensor(0))]
  9. >>> bn.state_dict()
  10. OrderedDict([('weight', tensor([1., 1., 1.])),
  11. ('bias', tensor([0., 0., 0.])),
  12. ('running_mean', tensor([0., 0., 0.])),
  13. ('running_var', tensor([1., 1., 1.])),
  14. ('num_batches_tracked', tensor(0))])

出于兼容性原因,建议不要直接保存模块,而是只保存其状态字典。Python 模块的load_state_dict()可以从状态字典中恢复它们的状态:

  1. >>> torch.save(bn.state_dict(), 'bn.pth')
  2. >>> bn_state_dict = torch.load('bn.pth')
  3. >>> new_bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
  4. >>> new_bn.load_state_dict(bn_state_dict)
  5. <All keys matched successfully>

注意state_dict先从其文件中加载torch.load()然后使用load_state_dict() 自定义模块和包含其他模块的模块也有state_dict并且可以使用这种模式:

  1. # A module with two linear layers
  2. >>> class MyModule(torch.nn.Module):
  3. def __init__(self):
  4. super(MyModule, self).__init__()
  5. self.l0 = torch.nn.Linear(4, 2)
  6. self.l1 = torch.nn.Linear(2, 1)
  7. def forward(self, input):
  8. out0 = self.l0(input)
  9. out0_relu = torch.nn.functional.relu(out0)
  10. return self.l1(out0_relu)
  11. >>> m = MyModule()
  12. >>> m.state_dict()
  13. OrderedDict([('l0.weight', tensor([[ 0.1400, 0.4563, -0.0271, -0.4406],
  14. [-0.3289, 0.2827, 0.4588, 0.2031]])),
  15. ('l0.bias', tensor([ 0.0300, -0.1316])),
  16. ('l1.weight', tensor([[0.6533, 0.3413]])),
  17. ('l1.bias', tensor([-0.1112]))])
  18. >>> torch.save(m.state_dict(), 'mymodule.pt')
  19. >>> m_state_dict = torch.load('mymodule.pt')
  20. >>> new_m = MyModule()
  21. >>> new_m.load_state_dict(m_state_dict)
  22. <All keys matched successfully>

算了,上面都是文档
先有如下已经用烂的模型

  1. import torch
  2. from torch import nn
  3. class Module(nn.Module):
  4. def __init__(self):
  5. super(Module, self).__init__()
  6. self.model = nn.Sequential(
  7. nn.Conv2d(3, 16, 5),
  8. nn.MaxPool2d(2, 2),
  9. nn.Conv2d(16, 32, 5),
  10. nn.MaxPool2d(2, 2),
  11. nn.Flatten(), # 注意一下,线性层需要进行展平处理
  12. nn.Linear(32*5*5, 120),
  13. nn.Linear(120, 84),
  14. nn.Linear(84, 10)
  15. )
  16. def forward(self, x):
  17. x = self.model(x)
  18. return x

然后用以下方法加载和保存

  1. Module = Module()
  2. # 保存方式1,模型结构+张量
  3. torch.save(Module, "module.pth")
  4. # 保存方式2,张量(推荐)
  5. torch.save(Module.state_dict(), "module_state_dict.pth")
  6. # 加载方式1 对应保存方式1,同时加载模型结构+张量
  7. load_module = torch.load("module.pth")
  8. # 加载方式2 对应保存方式2,加载模型后加载张量(必须先实例化模型)
  9. Module.load_state_dict(torch.load("module_state_dict.pth"))
  10. print(module)

至于完整的操作在下面一节统一说

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/码创造者/article/detail/772751
推荐阅读
相关标签
  

闽ICP备14008679号