当前位置:   article > 正文

通俗易懂理解PyTorch模型的保存和加载_pytorch训练的模型如何保存为非zip格式

pytorch训练的模型如何保存为非zip格式

温故而知新,可以为师矣!

一、参考资料

[译]保存和加载模型
save and load PyTorch tensors
碎片篇——Pytorch模型 .pt, .pth, .pkl的区别及模型不同保存加载方式的区别
SAVING AND LOADING MODELS
pytorch模型保存、加载与续训练
PyTorch保存和加载模型详解(一)

二、相关介绍

1. 序列化与反序列化

序列化是把内存中的数据保存到磁盘中,保存模型就是序列化;而反序列化则是将硬盘中的数据加载到内存当中,加载模型的过程就是反序列化过程。

在这里插入图片描述

2. pkl模型格式

加载不同pytorch版本之间,训练得到的pkl文件

2.1 引言

我们将服务器上训练好的pkl文件下载到本地电脑上,经常会由于pytorch版本不统一问题,例如出现这种问题“_pickle.UnpicklingError: A load persistent id instruction was encountered…”

而无法直接加载,因此需要将pkl文件转成pth文件,这样不同版本的pytorch就可以互相加载。

2.2 pkl转成pth

示例1:

import pickle as pkl
import torch
 
#服务器端上的pkl保存位置
info_dict = torch.load('VGG_pre/VGG_4.pkl') 
with open('pkl_model_vgg16.pth', 'wb') as f: 
    #转成pth文件,并以pkl_model_vgg16.pth命名
    pkl.dump(info_dict, f)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

示例2:

import torch
# An instance of your model.
from models.MIMOUNet import MIMOUNetPlus
import torchvision.transforms as transforms
 
model = MIMOUNetPlus()
state_dict = torch.load('model.pkl') # pkl模型文件路径
model.load_state_dict(state_dict) 

model.eval()
 
example = torch.rand(1, 3, 256, 256) # 与训练模型时输入大小相同
 
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("model.pt") # 更换为自己的模型输出路径
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

2.3 pkl与pth对比

如果使用 torch.save() 方法来进行模型参数的保存,那保存文件的后缀其实没有任何影响,结果都是一样的,很多 .pkl 的文件也是用 torch.save()保存下来的,和 .pth 文件一模一样的。但是,这两种格式的文件有以下区别:

  • .pkl 文件是python里面保存文件的一种格式,如果直接打开会显示一堆序列化的东西,就是以二进制形式存储的,如果去read这些文件,则需要用 rb 而不是 r 模式。
  • Python在遍历已知的库文件目录过程中,如果见到一个 .pth 文件,就会将文件中所记录的路径加入到 sys.path() 设置中,于是 .pth 文件指明的库也就可以被 Python 运行环境找到。

三、模型保存与加载

后缀名为 .pt.pth.pklpytorch模型文件,在格式上没有区别,只是后缀名不同而已(仅此而已)!

在用 torch.save() 函数保存模型文件时,各人有不同的喜好,有些人喜欢用.pt后缀,有些人喜欢用.pth.pkl。用相同的 torch.save() 语句保存出来的模型文件没有什么不同。

0. 搭建网络模型

import torch
import torch.nn as nn
import torch.nn.functional as F


# 模型定义
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
    
#模型初始化
model = Net()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28

1. 方式一(推荐)

仅保存模型权重参数,不保存模型结构。

这种方式需要自己定义网络,并且其中的参数名称与结构要与保存的模型中的一致(可以是部分网络,比如只使用VGG的前几层),相对灵活,便于对网络进行修改。

1.1 保存模型

torch.save 将序列化的对象保存到disk。这个函数使用Python的pickle库进行序列化。

# 保存模型
torch.save(model.state_dict(), '/PATH/TO/params.pt') 
  • 1
  • 2

1.2 加载模型

  • torch.load:使用pickle库的unpickle工具将pickle对象文件反序列化到内存,也就是将保存的模型参数反序列化。
  • torch.nn.Module.load_state_dict:加载模型的参数字典。
# 加载模型
# 实例化model模型,重构模型结构
model = My_model(*args, **kwargs)  
# 根据模型结构,加载模型参数
model.load_state_dict(torch.load('/PATH/TO/params.pt'))  
model.eval()  
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

2. 方式二(不推荐)

保存/载入整个pytorch模型

这种方式无需自定义网络,保存时已把网络结构保存,比较死板,不能调整网络结构。

2.1 保存模型

# 保存
torch.save(model, '/PATH/TO/mymodel.pth')
  • 1
  • 2

2.2 加载模型

# 加载
# 不需要重构模型结构,直接load即可
model = torch.load('/PATH/TO/mymodel.pth')  
model.eval()
  • 1
  • 2
  • 3
  • 4

优点:

  • 以这种方式保存模型将使用Pythonpickle模块保存整个model的状态;
  • 保存/加载过程使用最直观的语法,涉及的代码量最少;

缺点:

  • 序列化数据绑定到特定的类;
  • 保存模型时,使用确切目录结构。因此,当在其他项目中使用或重构后,代码可能会以各种方式中断。

2.3 通俗理解

这种方式是不推荐使用的,因为使用这种方式保存模型,在加载时会遇到各种各样的错误。为了加深大家理解,举一个例子。
文件的结构如下图所示,models.py文件是模型的定义,其位于models文件夹下。save_model.py文件是保存模型的代码,load_mode.py文件是加载模型的代码,

在这里插入图片描述

2.3.1 save_model.py

执行 save_model.py 文件可保存模型,生成models.pt文件。

from models.models import Net
from torch import optim
import torch


# 实例化模型
model = Net()

# 初始化优化器
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# 保存模型
torch.save(model, './models/models.pt')
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
2.3.2 load_mode.py

执行load_mode.py文件可加载模型。

from models.models import Net
import torch


# 加载模型
model_test2 = Net()
model_test2 = torch.load('./models/models.pt')     
model_test2.eval() 
print(model_test2)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

此时可以正常加载。但如果我们将models文件夹修改为model,如下所示:

在这里插入图片描述

from models.models import Net
import torch


# 加载模型
model_test2 = Net()
model_test2 = torch.load('./model/models.pt')     #这里需要修改模型文件路径  
model_test2.eval()
print(model_test2)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

此时加载模型就会出现错误:

在这里插入图片描述
出现这种错误的原因是,使用本方式进行模型保存的时候,会把模型结构定义文件路径记录下来,加载的时候就会根据路径解析加载参数;当把模型定义文件路径修改以后,使用torch.load(path) 就会报错。

其实,使用本方式进行模型的保存和加载还会存在各种问题,感兴趣的可以看看这篇 torch模型保存和加载中的一些问题记录 。总之,在我们今后的使用中,尽量不要用本方式来加载模型。

3. 方式三(推荐)

如果因为某种原因导致训练异常中止,采用checkpoint方式可以很方便的接着上次继续训练。正因为这样,非常推荐使用这种方式进行模型的保存与加载。

3.1 保存模型

# 保存checkpoint
torch.save({
            'epoch':epoch,
            'model_state_dict':model.state_dict(),
            'optimizer_state_dict':optimizer.state_dict(),
            'loss':loss
            }, './model/model_checkpoint.tar'    #这里的后缀名官方推荐使用.tar
            )
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

3.2 加载模型

# 加载checkpoint
model = Net()
optimizer =  torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
checkpoint = torch.load('./model/model_checkpoint.tar')    # 先反序列化模型
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

3.3 代码示例

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 1、准备数据集
train_dataset = torchvision.datasets.CIFAR10("./data", train=True, transform=torchvision.transforms.ToTensor(), download= True)
test_dataset = torchvision.datasets.CIFAR10("./data", train=False, transform=torchvision.transforms.ToTensor(), download= True)

# 2、加载数据集
train_dataset_loader = DataLoader(dataset=train_dataset, batch_size=100)
test_dataset_loader = DataLoader(dataset=test_dataset, batch_size=100)

# 3、搭建神经网络
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.model1 = nn.Sequential(
            nn.Conv2d(3, 32, 5, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, padding=2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(1024, 64),
            nn.Linear(64, 10)
        )

    def forward(self, input):
        input = self.model1(input)
        return input


# 4、创建网络模型
model = Net()
model.to(device)

# 5、设置损失函数
loss_fun = nn.CrossEntropyLoss()   #交叉熵函数

# 设置优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(model.parameters(), learning_rate)   #SGD:梯度下降算法

# 6、设置网络训练中的一些参数
Max_epoch = 10    #设置训练轮数
total_train_step = 0   #记录总训练次数
total_test_step = 0    #记录总测试次数

# 7、开始进行训练
for epoch in range(Max_epoch):
    print("---第{}轮训练开始---".format(epoch))

    model.train()     #开始训练,不是必须的,在网络中有BN,dropout时需要
    #由于训练集数据较多,这里我没用训练集训练,而是采用测试集(test_dataset_loader)当训练集,但思想是一致的
    for data in test_dataset_loader:  # 遍历所有batch
        imgs, targets = data
        imgs, targets = imgs.to(device), targets.to(device)
        
        #反向传播,更新参数
        optimizer.zero_grad()  # 重置每个batch的梯度
        outputs = model(imgs)  # 前向传播计算预测值
        loss = loss_fun(outputs, targets) # 计算当前损失
        loss.backward()  # 反向传播计算梯度
        optimizer.step()  # 更新所有的参数

        total_train_step += 1

        if total_train_step % 50 == 0:
            print("---第{}次训练结束, Loss:{})".format(total_train_step, loss.item()))

    if epoch > 5:
        print("---意外中断---")
        break
	
	if (epoch+1) % 2 == 0:
        # 保存checkpoint
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss
        }, './model/model_checkpoint_epoch_{}.tar'.format(epoch)  # 这里的后缀名官方推荐使用.tar
        )
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88

两个epoch保存一次。当epoch=6时,设置一个break模拟程序意外中断,中断后可以来看到终端的输出信息,如下图所示:

在这里插入图片描述

从上图可以看到,在进行第6轮循环时,程序中断,此时最新的保存的模型是第五次训练结果,如下图所示:

在这里插入图片描述
同时注意到第5次训练结束的loss在2.0左右,如果下次接着训练,损失应该是在2.0附近。

此时,接着上次训练的结果继续训练,代码如下所示:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 1、准备数据集
train_dataset = torchvision.datasets.CIFAR10("./data", train=True, transform=torchvision.transforms.ToTensor(), download= True)
test_dataset = torchvision.datasets.CIFAR10("./data", train=False, transform=torchvision.transforms.ToTensor(), download= True)

# 2、加载数据集
train_dataset_loader = DataLoader(dataset=train_dataset, batch_size=100)
test_dataset_loader = DataLoader(dataset=test_dataset, batch_size=100)

# 3、搭建神经网络
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.model1 = nn.Sequential(
            nn.Conv2d(3, 32, 5, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, padding=2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(1024, 64),
            nn.Linear(64, 10)
        )

    def forward(self, input):
        input = self.model1(input)
        return input


# 4、创建网络模型
model = Net()
model.to(device)

# 5、设置损失函数
loss_fun = nn.CrossEntropyLoss()   #交叉熵损失

# 设置优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(model.parameters(), learning_rate)   #SGD:梯度下降算法

# 6、设置网络训练中的一些参数
Max_epoch = 10    #设置训练轮数
total_train_step = 0   #记录总训练次数
total_test_step = 0    #记录总测试次数

##########################################################################################
# 加载checkpoint
checkpoint = torch.load('./model/model_checkpoint_epoch_5.tar')    # 先反序列化模型
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']
loss = checkpoint['loss']
##########################################################################################

# 7、开始进行训练
for epoch in range(start_epoch+1, Max_epoch):
    print("---第{}轮训练开始---".format(epoch))

    model.train()     #开始训练,不是必须的,在网络中有BN,dropout时需要
    for data in test_dataset_loader:  # 遍历所有batch
        imgs, targets = data
        imgs, targets = imgs.to(device), targets.to(device)

        #反向传播,更新参数
        optimizer.zero_grad()  # 重置每个batch的梯度
        outputs = model(imgs)  # 前向传播计算预测值
        loss = loss_fun(outputs, targets)  # 计算当前损失
        loss.backward()  # 反向传播计算梯度
        optimizer.step()  # 更新所有的参数

        total_train_step += 1

        if total_train_step % 50 == 0:
            print("---第{}次训练结束, Loss:{})".format(total_train_step, loss.item()))

    if (epoch+1) % 2 == 0:
        # 保存checkpoint
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss
        }, './model/model_checkpoint_epoch_{}.tar'.format(epoch)  # 这里的后缀名官方推荐使用.tar
        )
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92

这里的代码相较之前的多了一个加载checkpoint的过程,将其截取出来,如下所示:

##########################################################################################
# 加载checkpoint
checkpoint = torch.load('./model/model_checkpoint_epoch_5.tar')    # 先反序列化模型
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']
loss = checkpoint['loss']
##########################################################################################
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

通过加载checkpoint,就保存了之前训练的参数,进而实现断点继续训练。执行该代码,结果如下图所示:

在这里插入图片描述
从上图可以看出,训练是从第6轮开始的,并且初始的loss为1.99,和2.0接近。这就说明已经实现了中断后恢复训练的操作。

三、可能出现的问题

  • pytorch模型加载问题

    raise RuntimeError("{} is a zip archive (did you mean to use torch.jit.load()?)".format(f.name))
    RuntimeError: yolov5s.pt is a zip archive (did you mean to use torch.jit.load()?)
    
    • 1
    • 2
    参考资料 [RuntimeError: xxx.pth is a zip archive (did you mean to use torch.jit.load()?)](https://blog.csdn.net/studyeboy/article/details/116451980)
    错误原因:
    pytorch版本不匹配的问题。比如,用torch 1.8.1训练保存的模型,用torch 1.1.0进行模型加载。PyTorch的1.6版本将torch.save切换为使用新的基于zipfile的文件格式,但 torch.load仍然保留以旧格式加载文件的功能。
    
    方法一:
    保存模型时,传递kwarg _use_new_zipfile_serialization = False参数,使用旧格式。
    torch.save(net.state_dict(), 'model.pth', _use_new_zipfile_serialization=False)
    
    方法二:
    保存模型时,将压缩格式转换为非压缩格式
    import torch
    from model import U2NET
    
    net = U2NET(3, 1)
    state_dict = torch.load('model.pth')
    net.load_state_dict(state_dict)
    torch.save(net.state_dict(), 'model.pth',_use_new_zipfile_serialization=False)
    
    方法三:
    下载新版本的pytorch
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop】
推荐阅读
相关标签
  

闽ICP备14008679号