当前位置:   article > 正文

Pytorch损失函数之BCELoss与BCEWithLogitsLoss_nn.bcewithlogitsloss

nn.bcewithlogitsloss

1.先说结论

nn.BCEWithLogitsLoss等于nn.BCELoss+nn.Sigmoid
主要用于二分类问题,多标签分类问题。

图为Pytorch Document对于BCEWithLogitsLoss的描述,这个损失函数结合了Sigmoid和BCELoss。
BCEWithLogitsLoss

2.公式分解

  • BCEWithLogitsLoss
    假设有N个batch,每个batch预测n个标签,则Loss为:
    L o s s = { l 1 , . . . , l N } ,   l n = − [ y n ⋅ log ⁡ ( σ ( x n ) ) + ( 1 − y n ) ⋅ log ⁡ ( 1 − σ ( x n ) ) ] Loss = \{ l_1 , ... , l_N \} , \ l_n = - [ y_n \cdot \log ( \sigma { ( x_n ) }) + ( 1 - y_n ) \cdot \log ( 1 - \sigma { ( x_n ) } ) ] Loss={l1,...,lN}, ln=[ynlog(σ(xn))+(1yn)log(1σ(xn))]
    其中 σ ( x n ) σ(x_n) σ(xn)Sigmoid函数,可以把x映射到(0, 1)的区间:
    σ ( x ) = 1 1 + exp ⁡ ( − x ) \sigma ( x ) = \frac { 1 } { 1 + \exp ( -x ) } σ(x)=1+exp(x)1
    图自百度百科

  • BCELoss
    同样假设有N个batch,每个batch预测n个标签,则Loss为:
    L o s s = { l 1 , . . . , l N } ,   l n = − [ y n ⋅ log ⁡ ( x n ) + ( 1 − y n ) ⋅ log ⁡ ( 1 − x n ) ] Loss = \{ l_1 , ... , l_N \} , \ l_n = - [ y_n \cdot \log ( x_n ) + ( 1 - y_n ) \cdot \log ( 1 - x_n ) ] Loss={l1,...,lN}, ln=[ynlog(xn)+(1yn)log(1xn)]
    可见与BCEWithLogitsLoss差了一个 σ ( x ) \sigma(x) σ(x)函数

3.实验代码

# 随机初始化label值,两个Batch,每个含3个标签
label = torch.empty((2, 3)).random_(2)
# 注意这是多标签问题,因此每个样本可能同时对应多种标签
# 每个标签内则是二分类问题,属于或者不属于这个标签
# tensor([[0., 1., 0.],
#         [0., 1., 1.]])

# 随机初始化x值,代表模型的预测值
x = torch.randn((2, 3))
# tensor([[-0.6117,  0.1446,  0.0415],
#         [-1.5376, -0.2599, -0.9680]])

sigmoid = nn.Sigmoid()
x1 = sigmoid(x)
# 归一化至 (0, 1)区间
# tensor([[0.3517, 0.5361, 0.5104],
#         [0.1769, 0.4354, 0.2753]])

bceloss = nn.BCELoss()
bceloss(x1, label)
# tensor(0.6812)

# 再用BCEWithLogitsLoss计算,对比结果
bce_with_logits_loss = nn.BCEWithLogitsLoss()
bce_with_logits_loss(x, label)
# tensor(0.6812)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26

4.log-sum-exp数值稳定
当我们使用BCEWithLogitsLoss损失函数时,除了相比于BCELoss方便外,还因整合了Sigmoid函数,以实现LogSumExp的技巧,达到数值稳定的优势。

但经过测试,单纯使用Sigmoid+BCELoss也并没有出现 inf ⁡ \inf inf − inf ⁡ -\inf inf的溢出情况,还望小伙伴指点。

x = torch.tensor(1e+10)
x1 = sigmoid(x)
# tensor(1.)

label = torch.tensor(1.)
bceloss(x1, label)
# tensor(0.)

bce_with_logits_loss(x, label)
# tensor(0.)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/不正经/article/detail/708942
推荐阅读
相关标签
  

闽ICP备14008679号