当前位置:   article > 正文

CrossEntropyLoss 和NLLLoss的关系

CrossEntropyLoss 和NLLLoss的关系

交叉熵损失在做一件什么事?

看公式:

x是预测(不需要softmax归一化),y是label, N是batch维度的数量,交叉熵损失,干了三件事.

1. 对输入在类别维度求softmax

2. 多softmax后的数,求log

3. 对(样本数, 类别数)为shape的tensor计算NLLLoss.

其中,NLLloss做的就是log取负, 和one-hot编码点乘.相加得到最终的总损失,因为reduction默认为mean,所以除以样本数.看以下代码.

代码实现

  1. #
  2. import torch
  3. import torch.nn as nn
  4. #
  5. # cross entropy loss = softmax + log + nllloss
  6. # 先softmax, 再 log,
  7. # 初始化 input_ 和 target
  8. input_ = torch.randn(3,3)
  9. target = torch.tensor([0,2,1])
  10. mask = torch.zeros(3,3)
  11. mask[0,0] = 1
  12. mask[1,2] = 1
  13. mask[2,1] = 1
  14. # 1.0 输入softmax
  15. sft_ = nn.Softmax(dim = -1)(input_)
  16. # 2.0 log
  17. log_ = torch.log(sft_)
  18. # 3.0 nllloss
  19. loss = nn.NLLLoss()
  20. print("split loss")
  21. print(loss(log_, target))
  22. # 4.0 crossentropy
  23. print("ce loss")
  24. loss = nn.CrossEntropyLoss()
  25. print(loss(input_, target))
  26. print("manual loss")
  27. neg_log = 0 - log_
  28. print(torch.sum(mask *neg_log ) / 3)
  29. # ----------输出--------------
  30. >> loss_function python crossEntropyLoss.py
  31. >> split loss
  32. >> tensor(1.2294)
  33. >> ce loss
  34. >> tensor(1.2294)
  35. >> manual loss
  36. >> tensor(1.2294)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/笔触狂放9/article/detail/243715
推荐阅读
相关标签
  

闽ICP备14008679号