
PyTorch: How to Train a Deep Neural Network with Multiple Losses [Continuously Updated]

Let's get straight to the point!

```python
import torch
import torch.nn.functional as F

def train(model, loss1, loss2, train_dataloader, optimizer_loss1, optimizer_loss2, epoch, writer, device_num):
    model.train()
    device = torch.device("cuda:" + str(device_num))
    correct = 0
    value_loss1 = 0
    value_loss2 = 0
    result_loss = 0
    for data_nnl in train_dataloader:
        data, target = data_nnl
        target = target.long()
        if torch.cuda.is_available():
            data = data.to(device)
            target = target.to(device)
        optimizer_loss1.zero_grad()
        optimizer_loss2.zero_grad()
        output = model(data)
        classifier_output = F.log_softmax(output[1], dim=1)
        value_loss1_batch = loss1(classifier_output, target)  # first loss term
        value_loss2_batch = loss2(output[0], target)          # second loss term
        weight_loss2 = 0.005
        result_loss_batch = value_loss1_batch + weight_loss2 * value_loss2_batch
        result_loss_batch.backward()
        optimizer_loss1.step()
        # Undo the weighting so loss2's own parameters receive full-scale gradients
        for param in loss2.parameters():
            param.grad.data *= (1. / weight_loss2)
        optimizer_loss2.step()
```
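The gradient-rescaling loop above only matters because loss2 carries its own learnable parameters (the center vectors), which are updated by a separate optimizer. As a minimal sketch of what such a trainable loss can look like, here is a center-loss-style module with one learnable center per class. The `CenterLoss` class, the dimensions, and the learning rate are illustrative assumptions, not the paper's actual code.

```python
import torch
import torch.nn as nn

class CenterLoss(nn.Module):
    """Pulls each feature vector toward a learnable per-class center."""
    def __init__(self, num_classes, feat_dim):
        super().__init__()
        # Learnable centers -> loss2.parameters() is non-empty,
        # which is why it needs its own optimizer.
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))

    def forward(self, features, target):
        # Mean squared distance from each feature to its class center
        return ((features - self.centers[target]) ** 2).sum(dim=1).mean()

loss2 = CenterLoss(num_classes=10, feat_dim=64)
optimizer_loss2 = torch.optim.SGD(loss2.parameters(), lr=0.5)

features = torch.randn(8, 64)
target = torch.randint(0, 10, (8,))
weight_loss2 = 0.005

optimizer_loss2.zero_grad()
value = weight_loss2 * loss2(features, target)  # weighted, as in the training loop
value.backward()
# Undo the 0.005 weighting so the centers are updated at full scale
for param in loss2.parameters():
    param.grad.data *= (1. / weight_loss2)
optimizer_loss2.step()
```

Without the rescaling, the centers would move at only 0.5% of the intended step size, since the small `weight_loss2` meant to balance the total loss also shrinks the gradients flowing into the centers.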

Here I use two loss terms: loss1 optimizes the network weights, while loss2 optimizes the center vectors, which are themselves trainable parameters. That is why there are two optimizers. If all of your loss terms optimize only the network weights, a single optimizer is enough, as shown below:

```python
import torch
import torch.nn.functional as F

def train(model, loss1, loss2, train_dataloader, optimizer, epoch, writer, device_num):
    model.train()
    device = torch.device("cuda:" + str(device_num))
    correct = 0
    value_loss1 = 0
    value_loss2 = 0
    result_loss = 0
    for data_nnl in train_dataloader:
        data, target = data_nnl
        target = target.long()
        if torch.cuda.is_available():
            data = data.to(device)
            target = target.to(device)
        optimizer.zero_grad()
        output = model(data)
        classifier_output = F.log_softmax(output[1], dim=1)
        value_loss1_batch = loss1(classifier_output, target)  # first loss term
        value_loss2_batch = loss2(output[0], target)          # second loss term
        weight_loss2 = 0.005
        result_loss_batch = value_loss1_batch + weight_loss2 * value_loss2_batch
        result_loss_batch.backward()
        optimizer.step()
```
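To make the single-optimizer pattern concrete, here is a self-contained toy training step. The `TinyNet` model, the random data, and the second loss (a simple L2 penalty on the features, standing in for whatever second criterion you actually use) are all illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyNet(nn.Module):
    """Returns (features, logits), mirroring the output[0]/output[1] convention above."""
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(16, 32)  # produces features (output[0])
        self.head = nn.Linear(32, 4)       # produces logits (output[1])

    def forward(self, x):
        feat = torch.relu(self.backbone(x))
        return feat, self.head(feat)

model = TinyNet()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss1 = nn.NLLLoss()   # expects log-probabilities, hence the log_softmax below
weight_loss2 = 0.005

data = torch.randn(8, 16)
target = torch.randint(0, 4, (8,))

optimizer.zero_grad()
features, logits = model(data)
classifier_output = F.log_softmax(logits, dim=1)
value_loss1 = loss1(classifier_output, target)
value_loss2 = features.pow(2).mean()  # stand-in second loss on the features
result_loss = value_loss1 + weight_loss2 * value_loss2
result_loss.backward()  # one backward pass propagates both terms
optimizer.step()        # one optimizer updates all network weights
```

Because the weighted sum is a single scalar, one `backward()` call accumulates gradients from both terms into the same network parameters, and one `optimizer.step()` applies them together.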

For the full code, please refer to our papers; the code is open-sourced, and the repository link can be found in each paper's abstract.

If this post helps your research or study, you are welcome to cite our papers.

[1] X. Fu et al., "Semi-Supervised Specific Emitter Identification Method Using Metric-Adversarial Training," in IEEE Internet of Things Journal, vol. 10, no. 12, pp. 10778-10789, 15 June 2023, doi: 10.1109/JIOT.2023.3240242.

[2] X. Fu et al., "Semi-Supervised Specific Emitter Identification via Dual Consistency Regularization," in IEEE Internet of Things Journal, doi: 10.1109/JIOT.2023.3281668.
