赞
踩
咱们直接进入正题!
- def train(model, loss1, loss2, train_dataloader, optimizer_loss1, optimizer_loss2, epoch, writer, device_num):
- model.train()
- device = torch.device("cuda:"+str(device_num))
- correct = 0
- value_loss1 = 0
- value_loss2 = 0
- result_loss = 0
- for data_nnl in train_dataloader:
- data, target = data_nnl
- target = target.long()
- if torch.cuda.is_available():
- data = data.to(device)
- target = target.to(device)
-
- optimizer_loss1.zero_grad()
- optimizer_loss2.zero_grad()
- output = model(data)
- classifier_output = F.log_softmax(output[1], dim=1)
- value_loss1_batch = loss1(classifier_output, target) //第一个损失项
- value_loss2_batch = loss2(output[0], target) //第二个损失项
-
- weight_loss2 = 0.005
-
- result_loss_batch = value_loss1_batch + weight_loss2 * value_loss2_batch
-
- result_loss_batch.backward()
- optimizer_loss1.step()
- for param in loss2.parameters():
- param.grad.data *= (1. / weight_loss2)
- optimizer_loss2.step()
我这里采用的是两项损失,loss1用于优化网络权重,loss2用于优化中心矢量,二者均是可训练的超参,因此包含两个优化器,如果多个损失项均用于优化网络权重,那么只采用一个优化器即可,如下所示
- def train(model, loss1, loss2, train_dataloader, optimizer, epoch, writer, device_num):
- model.train()
- device = torch.device("cuda:"+str(device_num))
- correct = 0
- value_loss1 = 0
- value_loss2 = 0
- result_loss = 0
- for data_nnl in train_dataloader:
- data, target = data_nnl
- target = target.long()
- if torch.cuda.is_available():
- data = data.to(device)
- target = target.to(device)
-
- optimizer.zero_grad()
- output = model(data)
- classifier_output = F.log_softmax(output[1], dim=1)
- value_loss1_batch = loss1(classifier_output, target) //第一个损失项
- value_loss2_batch = loss2(output[0], target) //第二个损失项
-
- weight_loss2 = 0.005
-
- result_loss_batch = value_loss1_batch + weight_loss2 * value_loss2_batch
-
- result_loss_batch.backward()
- optimizer.step()
详细代码,请翻阅我们的论文,代码已开源,开源链接可查论文摘要。
若该经验贴对您科研、学习有所帮助,欢迎您引用我们的论文。
[1] X. Fu et al., "Semi-Supervised Specific Emitter Identification Method Using Metric-Adversarial Training," in IEEE Internet of Things Journal, vol. 10, no. 12, pp. 10778-10789, 15 June15, 2023, doi: 10.1109/JIOT.2023.3240242.
[2] X. Fu et al., "Semi-Supervised Specific Emitter Identification via Dual Consistency Regularization," in IEEE Internet of Things Journal, doi: 10.1109/JIOT.2023.3281668.
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。