当前位置:   article > 正文

cross_entropy、binary_cross_entropy、binary_cross_entropy_with_logits

binary_cross_entropy_with_logits

cross_entropy

原理

函数具体介绍参考torch.nn.functional.cross_entropy使用
交叉熵的计算参考交叉熵损失(Cross Entropy Loss)计算过程

实例

code

z = torch.tensor([[1,2],[1,4]],dtype=float)  # input tensor
y = torch.tensor([0,1])
print(z)
print(y)

loss1 = torch.nn.functional.cross_entropy(z,y)
print(loss1)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

output
tensor([[1., 2.],
[1., 4.]], dtype=torch.float64)
tensor([0, 1])
tensor(0.6809, dtype=torch.float64)

计算过程

− [ l n ( e e + e 2 ) + l n ( e 4 e + e 4 ) ] / 2 = 0.6809 -[ln(\frac{e}{e+e^2})+ln(\frac{e^4}{e+e^4})]/2 = 0.6809 [ln(e+e2e)+ln(e+e4e4)]/2=0.6809

binary_cross_entropy

原理

默认做二分类,将模型输出z和理想输出x的每个元素当作是一个二分类的结果,然后计算交叉熵(在该函数中,z的每个元素的范围是[0,1],y没有限制,但理论上y应该在[0,1]之间。z和y可以是任意维数的,但必须形状相同。)

实例

code

z = torch.ones(3,2)*0.8  # input tensor
y = torch.ones(3,2)*0.4
print(z)
print(y)

loss2 = torch.nn.functional.binary_cross_entropy(z, y)
print(loss2)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

output

tensor([[0.8000, 0.8000],
[0.8000, 0.8000],
[0.8000, 0.8000]])
tensor([[0.4000, 0.4000],
[0.4000, 0.4000],
[0.4000, 0.4000]])
tensor(1.0549)

计算过程

− ( 0.4 ∗ l n ( 0.8 ) + 0.6 ∗ l n ( 1 − 0.8 ) ) ∗ 2 ∗ 3 / ( 2 ∗ 3 ) = 1.0549 -(0.4*ln(0.8)+0.6*ln(1-0.8))*2*3/(2*3)=1.0549 (0.4ln(0.8)+0.6ln(10.8))23/(23)=1.0549

binary_cross_entropy_with_logits

原理

默认做二分类,将模型输出z做sigmoid后和理想输出x的每个元素当作是一个二分类的结果,然后计算交叉熵(在该函数中,由于z的每个元素是做完sigmoid后再做交叉熵,因此没有值的范围限制,y也没有,但理论上y应该在[0,1]之间。z和y可以是任意维数的,但必须形状相同。)

实例

code

z = torch.ones(3,2)*0.8  # input tensor
y = torch.ones(3,2)*0.4
print(z)
print(y)

loss3 = torch.nn.functional.binary_cross_entropy_with_logits(z, y,reduction='sum')
print(loss3)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

output

tensor([[0.8000, 0.8000],
[0.8000, 0.8000],
[0.8000, 0.8000]])
tensor([[0.4000, 0.4000],
[0.4000, 0.4000],
[0.4000, 0.4000]])
tensor(5.1066)

计算过程

− ( 0.4 ∗ l n ( s i g m o i d ( 0.8 ) ) + 0.6 ∗ l n ( 1 − s i g m o i d ( 0.8 ) ) ) ∗ 2 ∗ 3 = 5.1066 -(0.4*ln(sigmoid(0.8))+0.6*ln(1-sigmoid(0.8)))*2*3 = 5.1066 (0.4ln(sigmoid(0.8))+0.6ln(1sigmoid(0.8)))23=5.1066

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/菜鸟追梦旅行/article/detail/525379
推荐阅读
相关标签
  

闽ICP备14008679号