当前位置:   article > 正文

深度学习在训练时更新和保存最佳训练结果的方法(字典方法,本地保存方法,模型深拷贝方法)_深度学习怎么存储epoch

深度学习怎么存储epoch

1.用参数字典 model.state_dict()更新最优参数

best_state_dict = model.state_dict()  # 训练前
best_state_dict = model.state_dict()  # 训练时更新最优state_dict
  • 1
  • 2

完整代码:

 # 初始化一个变量来保存最优的state_dict
  best_state_dict = model.state_dict()
  for epoch in range(epochs):
      model.train()
      # 训练集上训练模型权重
      for data, targets in tqdm.tqdm(train_dataloader):
          # 把数据加载到GPU上
          data = data.to(devices[0])
          targets = targets.to(devices[0])

          # 前向传播
          preds = model(data)
          loss = criterion(preds, targets)

          # 反向传播
          optimizer.zero_grad()
          loss.backward()
          optimizer.step()

      # 测试集上评估模型性能
      model.eval()
      num_correct = 0
      num_samples = 0
      with torch.no_grad():
          for x, y in tqdm.tqdm(test_dataloader):
              x = x.to(devices[0])
              y = y.to(devices[0])
              preds = model(x)
              predictions = preds.max(1).indices  # 返回每一行的最大值和该最大值在该行的列索引
              num_correct += (predictions == y).sum()
              num_samples += predictions.size(0)
          acc = (num_correct / num_samples).item()
          if acc > best_acc:
              best_acc = acc
              best_epoch = epoch+1
              # 保存模型最优准确率的参数
              best_state_dict = model.state_dict()  # 更新最优state_dict
      model.train()
  # 训练结束保存
  torch.save(best_state_dict, f"weights/{model_name}_{epochs}_{best_acc}.pth")
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40

2.训练过程中保存最优参数

if acc > best_acc:
    best_acc = acc
    best_epoch = epoch+1
    torch.save(best_state_dict, f"weights/{model_name}_{epochs}_{best_acc}.pth")
  • 1
  • 2
  • 3
  • 4

3.对模型深拷贝方法保存最优模型

深拷贝方法介绍

copy模块可以用来创建一个对象的深拷贝。这意味着复制后的模型和原始模型是完全独立的,包括它们的参数。

import torch  
import copy  
import torch.nn as nn  
  
# 假设我们有一个模型实例  
original_model = nn.Sequential(  
    nn.Linear(10, 5),  
    nn.ReLU(),  
    nn.Linear(5, 2)  
)  
  
# 复制模型  
model_copy = copy.deepcopy(original_model)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

深拷贝方法保存最优模型

best_model = copy.deepcopy(model.state_dict())  # 训练前
best_model = copy.deepcopy(model.state_dict())  # 训练时更新最优state_dict
  • 1
  • 2

代码案例:

   def fit_zsl(self):
        best_acc = 0
        mean_loss = 0
        last_loss_epoch = 1e8
        # 定义best_model
        best_model = copy.deepcopy(self.model.state_dict())
        for epoch in range(self.nepoch):
            for i in range(0, self.ntrain, self.batch_size):
                self.model.zero_grad()
                batch_input, batch_label = self.next_batch(self.batch_size)
                self.input.copy_(batch_input)
                self.label.copy_(batch_label)

                inputv = Variable(self.input)
                labelv = Variable(self.label)
                output = self.model(inputv)
                loss = self.criterion(output, labelv)
                mean_loss += loss.item()
                loss.backward()
                self.optimizer.step()
            acc = self.val(
                self.test_unseen_feature,
                self.test_unseen_label,
                self.unseenclasses,
            )
            if acc > best_acc:
                best_acc = acc
                # 更新best_model
                best_model = copy.deepcopy(self.model.state_dict())
        #训练完毕本地保存
		torch.save(best_model.state_dict(), f"weights/{self.nepoch}_{best_acc}.pth")
        return best_acc, best_model
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/羊村懒王/article/detail/207702
推荐阅读
相关标签
  

闽ICP备14008679号