当前位置:   article > 正文

loss盘点: BCE loss —— binary_cross_entropy_with_logits

binary_cross_entropy_with_logits

我的 torch 版本: 1.8.1+cu111
我的 paddle 版本: 2.4.1

torch API位置

torch.nn.functional.binary_cross_entropy_with_logits
  • 1

paddle API位置

paddle.nn.functional.binary_cross_entropy_with_logits
  • 1

二者除了 Deprecated 参数外,在大部分计算上基本都是对齐的

1. 计算方式

logit 是模型的输出,通过sigmoid激活函数 ( σ \sigma σ) 之后便可以转化为概率

l o s s loss loss 是这样计算的:
O u t = − L a b e l s ∗ log ⁡ ( σ ( L o g i t ) ) − ( 1 − L a b e l s ) ∗ log ⁡ ( 1 − σ ( L o g i t ) ) Out = -Labels * \log(\sigma(Logit)) - (1 - Labels) * \log(1 - \sigma(Logit)) Out=Labelslog(σ(Logit))(1Labels)log(1σ(Logit))
其实也就是,交叉熵 l o s s loss loss 的最基本公式:
O u t = − Y ∗ l o g ( y p r e d ) − ( 1 − Y ) ∗ l o g ( 1 − y p r e d ) Out = -Y * log(y_{pred}) - (1 - Y) * log(1 - y_{pred}) Out=Ylog(ypred)(1Y)log(1ypred)

σ ( L o g i t ) = 1 1 + e − L o g i t \sigma(Logit) = \frac{1}{1 + e^{-Logit}} σ(Logit)=1+eLogit1 带入可以简化计算,则:
O u t = L o g i t − L o g i t ∗ L a b e l s + log ⁡ ( 1 + e − L o g i t ) Out = Logit - Logit * Labels + \log(1 + e^{-Logit}) Out=LogitLogitLabels+log(1+eLogit)

文档上这样说:

该 OP 结合了 sigmoid 操作和 BCELoss 操作。同时,我们也可以认为该 OP 是sigmoid_cross_entrop_with_logits 和一些 reduce 操作的组合。

2. 实验代码

# -*- coding: utf-8 -*-
"""
Created on Wed Jan  4 22:36:50 2023

@author: zihao
"""

import numpy as np
import torch
import paddle


# ----------- numpy 参数 -----------
np.random.seed(1107)

# 假设 bs=4, 7种(多分类)
np_logit = np.random.rand(4, 7).astype("float32") 
np_target = np.random.randint(2, size=(4, 7)).astype("float32")

# 给每个类加权重
np_pos_weight = np.random.randint(2, 4, size=(7,)).astype("float32")

# 给每个 batch 的元素 加权重
np_weight = np.random.randint(2, 4, size=(7,)).astype("float32")


# ----------- torch -----------
t_logit = torch.tensor(np_logit)
t_target = torch.tensor(np_target)
t_pos_weight = torch.tensor(np_pos_weight)
t_weight = torch.tensor(np_weight)
t_out = torch.nn.functional.binary_cross_entropy_with_logits(t_logit, t_target,
                                                             weight=t_weight,
                                                             pos_weight=t_pos_weight,
                                                             reduction='none')

# torch 手动计算
t_out_hand = t_logit - t_logit * t_target + torch.log(1 + torch.exp(-t_logit))
t_pos_weight = t_target * t_pos_weight + (1 - t_target)
t_out_hand = t_out_hand * t_pos_weight 
t_out_hand = t_out_hand * t_weight 


# ----------- paddle -----------
p_logit = paddle.to_tensor(np_logit)
p_target = paddle.to_tensor(np_target)
p_pos_weight = paddle.to_tensor(np_pos_weight)
p_weight = paddle.to_tensor(np_weight)
p_out = paddle.nn.functional.binary_cross_entropy_with_logits(p_logit, p_target, 
                                                              weight=p_weight,
                                                              pos_weight=p_pos_weight,
                                                              reduction='none')

# paddle  手动计算
p_out_hand = p_logit - p_logit * p_target + paddle.log(1 + paddle.exp(-p_logit))
p_pos_weight = p_target * p_pos_weight + (1 - p_target)
p_out_hand = p_out_hand * p_pos_weight 
p_out_hand = p_out_hand * p_weight 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58

在以上代码中,t_out 和 t_out_hand 近乎相等,p_out 和 p_out_hand 近乎相等。前者是调用API计算的,后者是根据公式手动计算的

3. 稍稍看下源码

在 Paddle 源码新动态图部分是这样计算的:

    if in_dygraph_mode():
        one = _C_ops.full(
            [1],
            float(1.0),
            core.VarDesc.VarType.FP32,
            _current_expected_place(),
        )
		
		# 此处按照公式进行计算
        out = _C_ops.sigmoid_cross_entropy_with_logits(
            logit, label, False, -100
        )
		
		# 给每个正例乘以对应的权重 pos_weight 
        if pos_weight is not None:
            log_weight = _C_ops.add(
                _C_ops.multiply(label, _C_ops.subtract(pos_weight, one)), one
            )
            out = _C_ops.multiply(out, log_weight)
		
		# 给每个 batch 乘以对应权重
        if weight is not None:
            out = _C_ops.multiply(out, weight)
		
		# 做 reduce 操作
        if reduction == "sum":
            return _C_ops.sum(out, [], None, False)
        elif reduction == "mean":
            return _C_ops.mean_all(out)
        else:
            return out
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31

关于 pos_weight 的计算,诸位需要稍微认真看一下

我是这样计算的:

p_pos_weight = p_target * p_pos_weight + (1 - p_target)
  • 1

可以这样简化一下:

p_pos_weight = p_pos_weight * p_target - one * p_target + one
             = (p_pos_weight - one) * p_target + one
             = p_target * (p_pos_weight - one) + one
  • 1
  • 2
  • 3

也就是源码中这样的计算

log_weight = _C_ops.add(
                _C_ops.multiply(label, _C_ops.subtract(pos_weight, one)), one
            )
  • 1
  • 2
  • 3
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/盐析白兔/article/detail/525372
推荐阅读
相关标签
  

闽ICP备14008679号