当前位置:   article > 正文

binary_cross_entropy_with_logits中的weight参数与pos_weight参数

binary_cross_entropy_with_logits


一、weight参数

根据官方给出的binary_cross_entropy_with_logits函数的二分类交叉熵损失计算公式:
在这里插入图片描述
其中, N代表batch大小。
可以看到,weight参数代表每个样本的权重。


二、pos_weight参数

根据官方对pos_weight参数的解释:a weight of positive examples to be broadcasted with target. Must be a tensor with equal size along the class dimension to the number of classes.
我认为pos_weight参数代表每个类别的权重,结合官方给出的binary_cross_entropy_with_logits函数的多标签分类交叉熵损失计算公式:
l c ( x , y ) = L c = { l 1 , c , . . . , l N , c } , l n , c = − w n , c [ p c y n , c ∗ l o g σ ( x n , c ) + ( 1 − y n , c ) ∗ l o g ( 1 − σ ( x n , c ) ) ] \mathcal{l}_c(x,y)=L_c=\{l_{1,c},...,l_{N,c}\}, l_{n,c}=-w_{n,c}[p_cy_{n,c}*log\sigma (x_{n,c})+(1-y_{n,c})*log(1-\sigma (x_{n,c}))] lc(x,y)=Lc={l1,c,...,lN,c},ln,c=wn,c[pcyn,clogσ(xn,c)+(1yn,c)log(1σ(xn,c))]
其中, c c c代表类数,N代表batch大小。
binary_cross_entropy_with_logits在计算多标签分类任务的交叉熵损失时,分割成 c c c个单标签二分类任务,具体原理参考:https://www.cnblogs.com/Fish0403/p/17073047.html
在上述公式中, w n , c w_{n,c} wn,c就是weight参数,具体意思是:样本 n n n在第 c c c类二分类任务上的权重, p c p_c pc就是pos_weight参数,代表在第 c c c类二分类任务上,正样本类别所对应的权重。以一个猫狗分类的多标签分类模型为例:
在这里插入图片描述
上图是一个多分类任务,有三个类别(猫、狗、猪),每张图片可能包含1-3个类别。输入一张包含猫和狗的图片(假设为样本1),经过神经网络及sigmoid后,输出三个类别的概率分布。要计算交叉熵损失,可以先分别计算三个二分类任务的交叉熵损失,即:是否有狗,是否有猫,是否有猪。然后将三个损失相加,得到该样本的交叉熵损失。那么在计算过程中,weight参数与pos_weight参数分别作用于哪里呢?
以是否有狗这个二分类任务为例,若weight与pos_weight均为None,那么 l o s s 1 = − y 1 l o g y 1 ^ − ( 1 − y i ) l o g ( 1 − y i ^ ) loss_1=-y_1log\hat{y_1}-(1-y_i)log(1-\hat{y_i}) loss1=y1logy1^(1yi)log(1yi^),若weight与pos_weight均不为None,那么 l o s s 1 = − w 1 , 1 [ p 1 y 1 l o g y 1 ^ + ( 1 − y i ) l o g ( 1 − y i ^ ) ] loss_1=-w_{1,1}[p_1y_1log\hat{y_1}+(1-y_i)log(1-\hat{y_i})] loss1=w1,1[p1y1logy1^+(1yi)log(1yi^)]。这里, w 1 , 1 w_{1,1} w1,1代表样本1在是否有狗这个二分类任务上的权重, p 1 p_1 p1代表在是否有狗这个二分类任务上,“有狗”这一类别的权重。

总结

总结来说,weight参数代表样本权重,pos_weight参数代表类别权重,根据具体情况,两者可以结合使用。
这是我对这两个参数的理解,说实话我也不确定自己理解的对不对,只是在看过官方文档后,感觉这么理解能讲得通,可能有很多不对的地方,欢迎大家批评交流!

参考文献

https://www.cnblogs.com/Fish0403/p/17073047.html
https://pytorch.org/docs/stable/generated/torch.nn.functional.binary_cross_entropy_with_logits.html

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/很楠不爱3/article/detail/525365
推荐阅读
相关标签
  

闽ICP备14008679号