赞
踩
本篇文章给大家介绍一下如何在pytorch中保存和加载模型,对于这一个知识点,pytorch官方文档其实已经介绍的比较清楚,详情可以参考: 保存和加载模型 — PyTorch 教程 。大家可以将官方文档与本文章进行对照学习。
一、三个核心功能函数
在介绍pytorch中保存和加载模型的方法前,我们需要先了解三个核心功能函数:
1、torch.save(): 将序列化对象保存到对应路径。此函数将模型序列化为Python中的Pickle格式。各种模型、张量和字典都可以使用此功能保存对象。
2、torch.load():将相应的序列化对象反序列化到内存中。
3、torch.nn.Module.load_state_dict():使用反序列化state_dict加载模型的参数字典。
看完上面三个函数的功能介绍,也许会有一些小伙伴比较懵,这里简单解释一下什么是序列化和反序列化。
序列化:将某一对象转换为某一标准化的格式;
反序列化:复原序列化后的对象数据,并保持数据一致性,即原有内容不变。
简单来说,我们保存模型即为序列化的过程,而加载调用模型即为反序列化的过程。它的意义在于不仅可保持对象原有的内容和规范、减少资源的占用,同时可使被序列化后的对象跨程序甚至跨平台调用。
二、模型的保存和加载
我们根据以下例子说明一下pytorch中几种保存和加载模型的方法
# 搭建网络模型 class TheModelClass(nn.Module): def __init__(self): super(TheModelClass, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = x.view(-1, 16 * 5 * 5) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x # 实例化模型 model = TheModelClass()
方法一(官方推荐):
保存模型:
torch.save(model.state_dict(), 'model_path')
'model_path' 为模型保存路径,保存模型的后缀名推荐为 '.pth' 或 '.pt'
加载模型:
- # 实例化模型
- model = TheModelClass()
- # 加载模型
- model.load_state_dict(torch.load('model_path'))
- # 模型调整为验证模式
- model.eval()
注意:model.load_state_dict()函数接收的是一个字典对象,因此不能直接将序列化后的对象传入。我们需要先使用torch.load()函数对模型对象进行反序列化,以此加载模型的参数字典,再传入model.load_state_dict()函数中进行接收。
方法二:
保存模型:
torch.save(model, 'model_path')
加载模型:
- model = torch.load('model_path')
- model.eval()
这种保存和加载模型的方式很简单,但并不推荐,这是因为此方式在保存模型时会将模型的参数数据绑定到特定类和确切的目录结构中。因此如果采用此方式保存模型,那么在加载模型时很可能会因为当前代码所在文件的目录结构与原目录结构不一致而导致各种报错,比较常见的错误如下图所示:
可以看到上图报错的原因是导入的模型无法获取到module文件中的main函数,如果这种错误是在同一个项目文件中加载模型时产生的,那么可以将保存模型的文件引入到加载模型的代码中解决,即导入下面一句代码:
from 保存模型的代码文件名 import *
但如果是将保存的模型用于不同的项目之中呢?我们就需要将原项目中所有有关保存模型的代码都放在新项目中,再导入使用,虽然这是一种解决办法,但这是没有必要的。因此方法2是不推荐大家使用的。
方法三:
保存模型:
- torch.save({
- 'epoch': epoch,
- 'model_state_dict': model.state_dict(),
- 'optimizer_state_dict': optimizer.state_dict(),
- 'loss': loss,
- ...
- }, 'model_path')
这种保存模型的方法不仅将模型以字典的形式进行了保存,同时还将模型迭代训练轮数:epoch、定义的优化器:otimizer、损失值:loss等都进行了保存,因此这种保存方法很有利于我们因各种原因而导致训练中止后继续恢复训练。
加载模型:
- # 实例化模型
- model = TheModelClass()
- # 定义优化器(不唯一)
- optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
-
- # 加载模型
- checkpoint = torch.load('model_path')
- model.load_state_dict(checkpoint['model_state_dict'])
- optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
- epoch = checkpoint['epoch']
- loss = checkpoint['loss']
这种保存和加载模型的方法在很多目前主流的模型框架中是存在的,但是在我们日常模型训练过程中我们可能并未在意,因此有很多小伙伴可能不知道如何使用,这里给一个简单的示例进行解释。
- import torch
- import torchvision
- import torch.nn as nn
- import torch.nn.functional as F
- from torch.utils.data import DataLoader
- import torchvision.transforms as transforms
-
- # 加载数据集
- train_dataset = torchvision.datasets.ImageFolder(root=r"你的训练数据集路径",
- transform=torchvision.transforms.ToTensor())
- test_dataset = torchvision.datasets.ImageFolder(root=r"你的测试数据集路径",
- transform=torchvision.transforms.ToTensor())
-
- # 导入数据集
- train_data = DataLoader(dataset=train_dataset, batch_size=20,
- shuffle=True, num_workers=0)
- test_data = DataLoader(dataset=test_dataset, batch_size=10,
- shuffle=True, num_workers=0)
-
- # 搭建模型
- class TheModelClass(nn.Module):
- def __init__(self):
- super(TheModelClass, self).__init__()
- self.conv1 = nn.Conv2d(3, 6, 5)
- self.pool = nn.MaxPool2d(2, 2)
- self.conv2 = nn.Conv2d(6, 16, 5)
- self.fc1 = nn.Linear(16 * 5 * 5, 120)
- self.fc2 = nn.Linear(120, 84)
- self.fc3 = nn.Linear(84, 10)
-
- def forward(self, x):
- x = self.pool(F.relu(self.conv1(x)))
- x = self.pool(F.relu(self.conv2(x)))
- x = x.view(-1, 16 * 5 * 5)
- x = F.relu(self.fc1(x))
- x = F.relu(self.fc2(x))
- x = self.fc3(x)
- return x
-
-
- # 实例化模型
- model = TheModelClass()
- # 获取GPU设备
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- # 定义训练次数,学习率,损失函数与优化器
- epoch = 50 # 迭代训练次数
- learning_rate = 0.001 # 学习率
- Loss_func = nn.CrossEntropyLoss() # 损失函数
- optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) # 优化器
- model.to(device)
-
- # 加载意外中断训练的模型
- # checkpoint = torch.load('保存的训练中止模型') # 反序列化模型
- # model.load_state_dict(checkpoint['model_state_dict'])
- # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
- # start_epoch = checkpoint['epoch']
- # loss = checkpoint['loss']
-
- # 训练模型
- for epoch in range(epoch):
- train_num = 0 # 训练集样本总数初始设为0
- model.train() # 调整为训练模式
- for step, (data, target) in enumerate(train_data):
- data, target = data.to(device), target.to(device)
- optimizer.zero_grad()
- output = model(data)
- loss_value = Loss_func(output, target)
- loss_value.backward()
- optimizer.step()
- train_num += data.size(0)
-
- # 输出当前训练轮数与模型loss值
- print(f'当前为第{epoch + 1}轮训练 '
- f'Loss is {loss_value.item():.4f}')
-
- # 每轮训练结束后保存模型
- torch.save({
- 'epoch': epoch,
- 'model_state_dict': model.state_dict(),
- 'optimizer_state_dict': optimizer.state_dict(),
- 'loss': loss_value,
- }, './model_epoch_{}.pth'.format(epoch))
- print()

上述代码可以在模型每次训练结束后都保存一次模型,如果我们因某种原因而导致训练中止或者在设定的训练轮数结束后发现模型仍未收敛,都可以接着上一次保存下来的模型继续进行训练。
注意:此处只是一个示例,在实际训练过程中每隔多少轮进行一次模型的保存由大家设定的训练轮数决定,如果设定的epoch比较大,那么就无需每次训练结束后都进行一次模型的保存,这样会增加众多资源的开销。
假设我们在训练第10轮后因某种原因而导致训练中止,这时候代码就会保存第10次训练的模型,若我们需要接着第10轮保存的模型继续进行训练,则只需在模型训练前写入以下加载模型的代码:
- # 加载意外中断的第10轮模型
- checkpoint = torch.load('./model_epoch_10.pth') # 反序列化模型
- model.load_state_dict(checkpoint['model_state_dict'])
- optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
- start_epoch = checkpoint['epoch']
- loss = checkpoint['loss']
好了,以上就是pytorch中保存和加载模型的几个主要方法,不过本篇文章还未结束,假设我们有多个模型,且多个模型都要保存在同一文件中呢?其实这种情况pytorch官方文档中也给出了解决办法,如下:
保存模型:
- torch.save({
- 'modelA_state_dict': modelA.state_dict(),
- 'modelB_state_dict': modelB.state_dict(),
- 'optimizerA_state_dict': optimizerA.state_dict(),
- 'optimizerB_state_dict': optimizerB.state_dict(),
- ...
- }, 'model_path')
加载模型:
- # 实例化每一个模型对象
- modelA = TheModelAClass()
- modelB = TheModelBClass()
- # 分别定义每个模型的优化器对象
- optimizerA = TheOptimizerAClass()
- optimizerB = TheOptimizerBClass()
-
- # 分别加载每一个模型
- checkpoint = torch.load('modle_path')
- modelA.load_state_dict(checkpoint['modelA_state_dict'])
- modelB.load_state_dict(checkpoint['modelB_state_dict'])
- optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
- optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])
-
- modelA.eval()
- modelB.eval()
- # - or -
- modelA.train()
- modelB.train()

OK,以上就是本次文章的全部内容了,建议大家使用方法一和方法三来保存并加载模型,感谢大家的阅读!
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。