当前位置:   article > 正文

pytorch模型的保存与加载_torch加载模型

torch加载模型

1 pytorch保存和加载模型的三种方法

PyTorch提供了三种方式来保存和加载模型,在这三种方式中,加载模型的代码和保存模型的代码必须相匹配,才能保证模型的加载成功。通常情况下,使用第一种方式(保存和加载模型状态字典)更加常见,因为它更轻量且不依赖于特定的模型类。

1.1 仅保存和加载模型参数(推荐)

1.1.1 保存模型参数

  1. import torch
  2. import torch.nn as nn
  3. model = nn.Sequential(nn.Linear(128, 16), nn.ReLU(), nn.Linear(16, 1))
  4. # 保存整个模型
  5. torch.save(model.state_dict(), 'sample_model.pt')

1.1.2 加载模型参数

  1. import torch
  2. import torch.nn as nn
  3. # 下载模型参数 并放到模型中
  4. loaded_model = nn.Sequential(nn.Linear(128, 16), nn.ReLU(), nn.Linear(16, 1))
  5. loaded_model.load_state_dict(torch.load('sample_model.pt'))
  6. print(loaded_model)

显示如下:

  1. Sequential(
  2. (0): Linear(in_features=128, out_features=16, bias=True)
  3. (1): ReLU()
  4. (2): Linear(in_features=16, out_features=1, bias=True)
  5. )

state_dict:PyTorch中的state_dict是一个python字典对象,将每个层映射到其参数Tensor。state_dict对象存储模型的可学习参数,即权重和偏差,并且可以非常容易地序列化和保存。

1.2 保存和加载整个模型

1.2.1 保存整个模型

  1. import torch
  2. import torch.nn as nn
  3. net = nn.Sequential(nn.Linear(128, 16), nn.ReLU(), nn.Linear(16, 1))
  4. # 保存整个模型,包含模型结构和参数
  5. torch.save(net, 'sample_model.pt')

1.2.2  加载整个模型

  1. import torch
  2. import torch.nn as nn
  3. # 加载整个模型,包含模型结构和参数
  4. loaded_model = torch.load('sample_model.pt')
  5. print(loaded_model)

显示如下:

  1. Sequential(
  2. (0): Linear(in_features=128, out_features=16, bias=True)
  3. (1): ReLU()
  4. (2): Linear(in_features=16, out_features=1, bias=True)
  5. )

1.3 导出和加载ONNX格式模型

1.3.1 保存模型

  1. import torch
  2. import torch.nn as nn
  3. model = nn.Sequential(nn.Linear(128, 16), nn.ReLU(), nn.Linear(16, 1))
  4. input_sample = torch.randn(16, 128) # 提供一个输入样本作为示例
  5. torch.onnx.export(model, input_sample, 'sample_model.onnx')

1.3.2 加载模型

  1. import torch
  2. import torch.nn as nn
  3. import onnx
  4. import onnxruntime
  5. loaded_model = onnx.load('sample_model.onnx')
  6. session = onnxruntime.InferenceSession('sample_model.onnx')
  7. print(session)

2 模型保存与加载使用的函数

2.1 保存模型函数torch.save

将对象序列化保存到磁盘中,该方法原理是基于python中的pickle来序列化,各种Models,tensors,dictionaries 都可以使用该方法保存。保存的模型文件名可以是.pth, .pt, .pkl

  1. def save(
  2. obj: object,
  3. f: FILE_LIKE,
  4. pickle_module: Any = pickle,
  5. pickle_protocol: int = DEFAULT_PROTOCOL,
  6. _use_new_zipfile_serialization: bool = True
  7. ) -> None:
  • obj:保存的对象
  • f:一个类似文件的对象(必须实现写入和刷新)或字符串或操作系统。包含文件名的类似路径对象
  • pickle_module:用于挑选元数据和对象的模块
  • pickle_protocol:可以指定以覆盖默认协议

备注:关于模型的后缀.pt、.pth、.pkl它们并不存在格式上的区别,只是后缀名不同而已。 torch.save()语句保存出来的模型文件没有什么不同。

2.2 加载模型函数torch.load

  1. def load(
  2. f: FILE_LIKE,
  3. map_location: MAP_LOCATION = None,
  4. pickle_module: Any = None,
  5. *,
  6. weights_only: bool = False,
  7. **pickle_load_args: Any
  8. ) -> Any:
  • f:类文件对象 (返回文件描述符)或一个保存文件名的字符串
  • map_location:一个函数或字典规定如何映射存储设备,torch.device对象
  • pickle_module:用于 unpickling 元数据和对象的模块 (必须匹配序列化文件时的 pickle_module )

2.3 加载模型参数torch.nn.Module.load_state_dict

序列化 (Serialization)是将对象的状态信息转换为可以存储或传输的形式的过程。 在序列化期间,对象将其当前状态写入到临时或持久性存储区。以后,可以通过从存储区中读取或反序列化对象的状态,重新创建该对象。

  1. def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]',
  2. strict: bool = True):
  • state_dict:保存 parameters 和 persistent buffers 的字典
  • strict:可选,bool型。state_dict 中的 key 是否和 model.state_dict() 返回的 key 一致。

2.4 状态字典state_dict

函数作用是“获取优化器当前状态信息字典”,在神经网络中模型上训练出来的模型参数,也就是权重和偏置值。在Pytorch中,定义网络模型是通过继承torch.nn.Module来实现的。其网络模型中包含可学习的参数(weights, bias, 和一些登记的缓存如batchnorm’s running_mean 等)。模型内部的可学习参数可通过两种方式进行调用:

  • 通过model.parameters()这个生成器来访问所有参数。
  • 通过model.state_dict()来为每一层和它的参数建立一个映射关系并存储在字典中,其键值由每个网络层和其对应的参数张量构成。
def state_dict(self, destination=None, prefix='', keep_vars=False):

除模型外,优化器对象(torch.optim)同样也有一个状态字典,包含的优化器状态信息以及使用的超参数。由于状态字典属于Python 字典,因此对 PyTorch 模型和优化器的保存、更新、替换、恢复等操作都比较便捷。

2.5 指定map_location加载模型

采用仅加载模型参数的方式,指定设备类型进行模型加载,代码如下:

  1. model_path = '/opt/sample_model.pth'
  2. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  3. map_location = torch.device(device)
  4. model.load_state_dict(torch.load(self.model_path, map_location=self.map_location))

3 模型知识补充

3.1 模型的保存与加载到底在做什么

我们在使用pytorch构建模型并且训练完成后,下一步要做的就是把这个模型放到实际场景中应用,或者是分享给其他人学习、研究、使用。因此,我们开始思考一个问题,提供哪些模型信息,能够让对方能够完全复现我们的模型?

  • 模型代码:
    • 包含了我们如何定义模型的结构,包括模型有多少层/每层有多少神经元等等信息;
    • 包含了我们如何定义的训练过程,包括epoch batch_size等参数;
    • 包含了我们如何加载数据和使用;
    • 包含了我们如何测试评估模型。
  • 模型参数:提供了模型代码之后,对方确实能够复现模型,但是运行的参数需要重新训练才能得到,而没有办法在我们的模型参数基础上继续训练,因此对方还希望我们能够把模型的参数也保存下来给对方。
    • 包含model.state_dict(),这是模型每一层可学习的节点的参数,比如weight/bias;
    • 包含optimizer.state_dict(),这是模型的优化器中的参数;
    • 包含我们其他参数信息,如epoch/batch_size/loss等。
  • 数据集:
    • 包含了我们训练模型使用的所有数据;
    • 可以提示对方如何去准备同样格式的数据来训练模型。
  • 使用文档:
    • 根据使用文档的步骤,每个人都可以重现模型;
    • 包含了模型的使用细节和我们相关参数的设置依据等信息。

可以看到,根据我们提供的模型代码/模型参数/数据集/使用文档,别人就可以复现整个模型。

3.2 为什么要约定格式

根据上一段的思路,我们知道模型重现的关键是模型结构/模型参数/数据集,那么我们提供或者希望别人提供这些信息,需要一个交流的规范,这样才不会1000个人给出1000种格式,而 .pt .pth .bin 以及 .onnx 就是约定的格式。

torch.save
: Saves a serialized object to disk. This function uses Python’s pickle
utility for serialization. Models, tensors, and dictionaries of all kinds of objects can be saved using this function.

不同的后缀只是用于提示我们文件可能包含的内容,但是具体的内容需要看模型提供者编写的README.md才知道。而在使用torch.load()方法加载模型信息的时候,并不是根据文件的后缀进行的读取,而是根据文件的实际内容自动识别的,因此对于torch.load()方法而言,不管你把后缀改成是什么,只要文件是对的都可以读取。

torch.load
: Uses pickle
’s unpickling facilities to deserialize pickled object files to memory. This function also facilitates the device to load the data into

“一切皆文件”的思维才是正确打开计算机世界的思维方式,文件后缀只作为提示作用,在Windows系统中也会用于提示系统默认如何打开或执行文件,除此之外,文件后缀不应该成为我们认识和了解文件阻碍。

3.3 格式汇总

下面是一个整理了 .pt.pth.bin、ONNX 和 TorchScript 等 PyTorch 模型文件格式的表格:

格式解释适用场景可对应的后缀
.pt 或 .pthPyTorch 的默认模型文件格式,用于保存和加载完整的 PyTorch 模型,包含模型的结构和参数等信息。需要保存和加载完整的 PyTorch 模型的场景,例如在训练中保存最佳的模型或在部署中加载训练好的模型。.pt 或 .pth
.bin一种通用的二进制格式,可以用于保存和加载各种类型的模型和数据。需要将 PyTorch 模型转换为通用的二进制格式的场景。.bin
ONNX一种通用的模型交换格式,可以用于将模型从一个深度学习框架转换到另一个深度学习框架或硬件平台。在 PyTorch 中,可以使用 torch.onnx.export 函数将 PyTorch 模型转换为 ONNX 格式。需要将 PyTorch 模型转换为其他深度学习框架或硬件平台可用的格式的场景。.onnx
TorchScriptPyTorch 提供的一种序列化和优化模型的方法,可以将 PyTorch 模型转换为一个序列化的程序,并使用 JIT 编译器对模型进行优化。在 PyTorch 中,可以使用 torch.jit.trace 或 torch.jit.script 函数将 PyTorch 模型转换为 TorchScript 格式。需要将 PyTorch 模型序列化和优化,并在没有 Python 环境的情况下运行模型的场景。.pt 或 .pth

.pt .pth格式

一个完整的Pytorch模型文件,包含了如下参数:

  • model_state_dict:模型参数
  • optimizer_state_dict:优化器的状态
  • epoch:当前的训练轮数
  • loss:当前的损失值
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家自动化/article/detail/537909
推荐阅读
相关标签
  

闽ICP备14008679号