赞
踩
import copy

# Snapshot the weights once BEFORE training so a checkpoint exists even if
# no epoch ever improves on the initial model.
# FIX: model.state_dict() returns references to the live parameter tensors,
# so without deepcopy the "saved" snapshot keeps changing as training
# continues. Deep-copy it to get an independent snapshot.
best_state_dict = copy.deepcopy(model.state_dict())  # before training
best_state_dict = copy.deepcopy(model.state_dict())  # refresh whenever accuracy improves during training
Complete code:
import copy

# Track the best test accuracy seen so far and snapshot its weights.
# FIX: best_acc / best_epoch were read (`if acc > best_acc`) before ever
# being assigned, which raises NameError on the first epoch.
best_acc = 0.0
best_epoch = 0
# FIX: state_dict() returns references to the live tensors; deepcopy keeps
# the "best" snapshot independent of subsequent optimizer steps.
best_state_dict = copy.deepcopy(model.state_dict())

for epoch in range(epochs):
    model.train()
    # Train the model weights for one epoch on the training set.
    for data, targets in tqdm.tqdm(train_dataloader):
        # Move the batch to the first device.
        data = data.to(devices[0])
        targets = targets.to(devices[0])
        # Forward pass.
        preds = model(data)
        loss = criterion(preds, targets)
        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Evaluate model accuracy on the test set.
    model.eval()
    num_correct = 0
    num_samples = 0
    with torch.no_grad():
        for x, y in tqdm.tqdm(test_dataloader):
            x = x.to(devices[0])
            y = y.to(devices[0])
            preds = model(x)
            # .max(1).indices is the column index of each row's maximum,
            # i.e. the predicted class per sample.
            predictions = preds.max(1).indices
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)
    acc = (num_correct / num_samples).item()

    if acc > best_acc:
        best_acc = acc
        best_epoch = epoch + 1
        # Save an independent copy of the current best weights.
        best_state_dict = copy.deepcopy(model.state_dict())
    model.train()

# Training finished: persist the best checkpoint to disk.
torch.save(best_state_dict, f"weights/{model_name}_{epochs}_{best_acc}.pth")
# FIX: the scrape flattened this excerpt's indentation, leaving it
# syntactically invalid — restore the block structure.
# Record the epoch whenever it beats the best accuracy so far.
if acc > best_acc:
    best_acc = acc
    best_epoch = epoch + 1

# After training, persist the best snapshot (filename encodes run info).
torch.save(best_state_dict, f"weights/{model_name}_{epochs}_{best_acc}.pth")
The `copy` module can be used to create a deep copy of an object. This means the copied model is completely independent of the original model, including its parameters.
import torch
import copy
import torch.nn as nn
# Build a small example network to copy.
original_model = nn.Sequential(
    nn.Linear(10, 5),
    nn.ReLU(),
    nn.Linear(5, 2),
)

# copy.deepcopy yields a fully independent clone of the model: its
# parameter tensors are separate objects from the original's.
model_copy = copy.deepcopy(original_model)
best_model = copy.deepcopy(model.state_dict()) # before training
best_model = copy.deepcopy(model.state_dict()) # refresh the best state_dict during training
Code example:
def fit_zsl(self):
    """Train the ZSL classifier for ``self.nepoch`` epochs, tracking the
    best unseen-class accuracy.

    Returns:
        tuple: ``(best_acc, best_model)`` — the best accuracy reached on
        the unseen classes (float) and a deep-copied state_dict holding
        the weights that achieved it.
    """
    best_acc = 0
    mean_loss = 0  # running sum of batch losses (accumulated but not read here)
    # Snapshot the initial weights so best_model is defined even if the
    # accuracy never improves; deepcopy detaches it from the live tensors.
    best_model = copy.deepcopy(self.model.state_dict())
    for epoch in range(self.nepoch):
        for i in range(0, self.ntrain, self.batch_size):
            self.model.zero_grad()
            batch_input, batch_label = self.next_batch(self.batch_size)
            self.input.copy_(batch_input)
            self.label.copy_(batch_label)
            # NOTE(review): Variable is a no-op on modern PyTorch; kept
            # for compatibility with the original code.
            inputv = Variable(self.input)
            labelv = Variable(self.label)
            output = self.model(inputv)
            loss = self.criterion(output, labelv)
            mean_loss += loss.item()
            loss.backward()
            self.optimizer.step()
        # Evaluate on the unseen classes once per epoch.
        # NOTE(review): the scraped source flattened indentation; per-epoch
        # evaluation matches the standard fit_zsl implementation — confirm.
        acc = self.val(
            self.test_unseen_feature,
            self.test_unseen_label,
            self.unseenclasses,
        )
        if acc > best_acc:
            best_acc = acc
            # Keep an independent copy of the best weights.
            best_model = copy.deepcopy(self.model.state_dict())
    # FIX: best_model is already a state_dict (an OrderedDict) and has no
    # .state_dict() method — the original `best_model.state_dict()` would
    # raise AttributeError. Save the dict directly.
    torch.save(best_model, f"weights/{self.nepoch}_{best_acc}.pth")
    return best_acc, best_model
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。