
BCELoss, BCEWithLogitsLoss, and CrossEntropyLoss


Contents

Binary classification

1. BCELoss

2. BCEWithLogitsLoss

Multi-class classification

1. CrossEntropyLoss

Examples


Binary classification

Two losses: BCELoss and BCEWithLogitsLoss

1. BCELoss

Input: ([B, C], [B, C]), the shapes of (prediction, target), where B is the batch size and C is the number of classes.

Output: a scalar

Equivalence: BCEWithLogitsLoss = sigmoid + BCELoss, i.e. BCELoss expects inputs that have already been passed through sigmoid.
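For reference, with the default mean reduction over the $N$ elements, BCELoss computes the standard binary cross-entropy:

$$\ell(x, y) = -\frac{1}{N}\sum_{n=1}^{N}\bigl[y_n \log x_n + (1 - y_n)\log(1 - x_n)\bigr]$$

where $x_n \in (0, 1)$ is the predicted probability (the sigmoid output) and $y_n \in \{0, 1\}$ is the target.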

  import torch
  from torch import nn
  input = torch.randn(3)  # shape (3,): a random input, not yet passed through sigmoid
  print(input)
  print(input.shape)
  target = torch.Tensor([0., 1., 1.])
  loss1 = nn.BCELoss()
  print("BCELoss:", loss1(torch.sigmoid(input), target))  # sigmoid is required first

Output:

  BCELoss: tensor(1.0053)
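To connect the code to the formula above, this minimal sketch recomputes the same loss by hand, reusing `input` and `target` from the example:

  p = torch.sigmoid(input)
  # binary cross-entropy, averaged over the 3 elements
  manual = -(target * torch.log(p) + (1 - target) * torch.log(1 - p)).mean()
  print(manual)  # matches nn.BCELoss()(p, target)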

2. BCEWithLogitsLoss

Input: ([B, C], [B, C]); output: a scalar. Equivalent to sigmoid followed by BCELoss.

  import torch
  from torch import nn
  input = torch.randn(3)  # shape (3,): a random input, not yet passed through sigmoid
  print(input)
  print(input.shape)
  target = torch.Tensor([0., 1., 1.])
  loss2 = nn.BCEWithLogitsLoss()
  print("BCEWithLogitsLoss:", loss2(input, target))  # no sigmoid needed

Output:

  BCEWithLogitsLoss: tensor(1.0053)

(Run on the same input as the BCELoss example above, hence the identical value.)
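The two losses agree because BCEWithLogitsLoss simply fuses the sigmoid into the loss (in a way that is more numerically stable for large-magnitude logits). A minimal sketch of the equivalence, reusing `input` and `target` from above:

  fused = nn.BCEWithLogitsLoss()(input, target)
  manual = nn.BCELoss()(torch.sigmoid(input), target)
  print(torch.allclose(fused, manual))  # True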

Multi-class classification

1. CrossEntropyLoss

Input: ([B, C], [B]); output: a scalar (the mean/sum of the loss over the minibatch).

How nn.CrossEntropyLoss computes the loss:
input: logits (the model's raw outputs, not passed through softmax)

  • softmax(input)
  • -log(softmax(input)) (i.e., log-softmax)
  • gather the entries selected by target (from the log-softmax)
  • mean

Equivalent to: nn.CrossEntropyLoss = nn.NLLLoss applied after nn.LogSoftmax
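Written out for logits $x$ of shape [B, C] and class indices $y$ of shape [B], with mean reduction this is:

$$\ell(x, y) = -\frac{1}{B}\sum_{n=1}^{B} \log \frac{\exp(x_{n, y_n})}{\sum_{c=1}^{C} \exp(x_{n, c})}$$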
 

  import torch
  from torch import nn
  loss2 = nn.CrossEntropyLoss(reduction="none")  # "none": keep the per-sample losses
  target2 = torch.tensor([0, 1, 2])
  predict2 = torch.tensor([[0.9, 0.2, 0.8], [0.5, 0.2, 0.4], [0.4, 0.2, 0.9]])
  print(predict2.shape)  # torch.Size([3, 3])
  print(target2.shape)   # torch.Size([3])
  print(loss2(predict2, target2))
  # Result:
  # tensor([0.8761, 1.2729, 0.7434])
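The decomposition into nn.LogSoftmax + nn.NLLLoss can be checked directly on the same tensors; a minimal sketch, reusing `predict2` and `target2` from above:

  log_softmax = nn.LogSoftmax(dim=1)
  nll = nn.NLLLoss(reduction="none")
  print(nll(log_softmax(predict2), target2))
  # tensor([0.8761, 1.2729, 0.7434])  -- identical to nn.CrossEntropyLoss above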

Examples

1. Computing accuracy and loss with BCEWithLogitsLoss:

Reference: https://github.com/Loche2/IMDB_RNN/blob/master/training.py

  criterion = nn.BCEWithLogitsLoss()

  # Compute accuracy: sigmoid squashes logits into (0, 1), rounding gives 0/1 predictions
  def binary_accuracy(predicts, y):
      rounded_predicts = torch.round(torch.sigmoid(predicts))
      correct = (rounded_predicts == y).float()
      accuracy = correct.sum() / len(correct)
      return accuracy

  # Training loop
  def train(model, iterator, optimizer, criterion):
      model.train()
      epoch_loss = 0
      epoch_accuracy = 0
      for batch in tqdm(iterator, desc=f'Epoch [{epoch + 1}/{EPOCHS}]', delay=0.1):
          optimizer.zero_grad()
          predictions = model(batch.text[0]).squeeze(1)
          loss = criterion(predictions, batch.label)
          accuracy = binary_accuracy(predictions, batch.label)
          loss.backward()
          optimizer.step()
          epoch_loss += loss.item()
          epoch_accuracy += accuracy.item()
      return epoch_loss / len(iterator), epoch_accuracy / len(iterator)
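As a quick sanity check of `binary_accuracy`, here is a tiny made-up example (the logits and labels are illustrative, not from the linked repo):

  predicts = torch.tensor([2.0, -1.0, 0.5])  # raw logits
  y = torch.tensor([1., 0., 0.])
  # sigmoid -> approx [0.88, 0.27, 0.62]; round -> [1., 0., 1.]; 2 of 3 match
  print(binary_accuracy(predicts, y))  # tensor(0.6667)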

2. Computing accuracy and loss with CrossEntropyLoss

  # Excerpt from a sentiment-analysis training loop
  criterion = nn.CrossEntropyLoss()
  total_loss = 0.0
  correct_predictions = 0
  total_predictions = 0
  for batch in train_loader:
      input_ids = batch['input_ids'].to(device)
      labels = batch['label'].to(device)
      optimizer.zero_grad()
      logits = model(input_ids)
      loss_sentiment = criterion(logits, labels.long())
      loss_sentiment.backward()
      optimizer.step()
      total_loss += loss_sentiment.item()
      # get sentiment accuracy
      predicted_labels = torch.argmax(logits, dim=1)
      correct_predictions += torch.sum(predicted_labels == labels).item()
      total_predictions += labels.size(0)
  accuracy = correct_predictions / total_predictions
  loss = total_loss / len(train_loader)
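The accuracy bookkeeping here is just an argmax over the class dimension; a small made-up example:

  logits = torch.tensor([[0.9, 0.2, 0.8], [0.5, 0.2, 0.4], [0.4, 0.2, 0.9]])
  labels = torch.tensor([0, 2, 2])
  predicted_labels = torch.argmax(logits, dim=1)  # tensor([0, 0, 2])
  print(torch.sum(predicted_labels == labels).item() / labels.size(0))  # ~0.667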

You can also look directly at examples others have written on GitHub: https://github.com/songyouwei/ABSA-PyTorch/blob/master/train.py

Reference:

深刻剖析与实战BCELoss详解(主)和BCEWithLogitsLoss(次)以及与普通CrossEntropyLoss的区别(次) (CSDN blog)

A further question:

Does binary classification have to use BCEWithLogitsLoss? Can't CrossEntropyLoss be used as well?

(1) If you use CrossEntropyLoss, just make the network's final fc layer nn.Linear(hidden_size, 2) and compute everything the same way as in the multi-class case. Also, CrossEntropyLoss already includes softmax, so you do not need to apply softmax before computing the loss.
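A minimal sketch of option (1); `hidden_size`, the fc layer, and the batch here are illustrative assumptions, not from the original post:

  import torch
  from torch import nn

  hidden_size = 8
  fc = nn.Linear(hidden_size, 2)             # two output logits, one per class
  features = torch.randn(4, hidden_size)     # a batch of 4 feature vectors
  logits = fc(features)                      # shape [4, 2]; no softmax needed
  labels = torch.tensor([0, 1, 1, 0])        # class indices (long)
  loss = nn.CrossEntropyLoss()(logits, labels)
  predictions = torch.argmax(logits, dim=1)  # accuracy via argmax, as in multi-class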

(2) If you use BCEWithLogitsLoss, compute the loss exactly as in the BCEWithLogitsLoss example above; the only difference, as that code shows, is that when computing accuracy you pass the predictions through sigmoid to squash them into (0, 1) and then count how many predictions are correct.
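And the matching sketch of option (2), continuing the same illustrative setup but with a single output logit:

  fc2 = nn.Linear(hidden_size, 1)
  logits2 = fc2(features).squeeze(1)              # shape [4]
  labels2 = torch.tensor([0., 1., 1., 0.])        # float targets for BCE
  loss2 = nn.BCEWithLogitsLoss()(logits2, labels2)
  predictions2 = torch.round(torch.sigmoid(logits2))  # accuracy via sigmoid + round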

Note: this is just a personal study record; if my understanding is wrong, please contact me.

Related question: for binary classification, should you choose sigmoid or softmax? (Zhihu)
