赞
踩
根据官方给出的binary_cross_entropy_with_logits函数的二分类交叉熵损失计算公式:
其中, N代表batch大小。
可以看到,weight参数代表每个样本的权重。
根据官方对pos_weight参数的解释:a weight of positive examples to be broadcasted with target. Must be a tensor with equal size along the class dimension to the number of classes.
我认为pos_weight参数代表每个类别的权重,结合官方给出的binary_cross_entropy_with_logits函数的多标签分类交叉熵损失计算公式:
l
c
(
x
,
y
)
=
L
c
=
{
l
1
,
c
,
.
.
.
,
l
N
,
c
}
,
l
n
,
c
=
−
w
n
,
c
[
p
c
y
n
,
c
∗
l
o
g
σ
(
x
n
,
c
)
+
(
1
−
y
n
,
c
)
∗
l
o
g
(
1
−
σ
(
x
n
,
c
)
)
]
\mathcal{l}_c(x,y)=L_c=\{l_{1,c},...,l_{N,c}\}, l_{n,c}=-w_{n,c}[p_cy_{n,c}*log\sigma (x_{n,c})+(1-y_{n,c})*log(1-\sigma (x_{n,c}))]
lc(x,y)=Lc={l1,c,...,lN,c},ln,c=−wn,c[pcyn,c∗logσ(xn,c)+(1−yn,c)∗log(1−σ(xn,c))]
其中,
c
c
c代表类数,N代表batch大小。
binary_cross_entropy_with_logits在计算多标签分类任务的交叉熵损失时,分割成
c
c
c个单标签二分类任务,具体原理参考:https://www.cnblogs.com/Fish0403/p/17073047.html。
在上述公式中,
w
n
,
c
w_{n,c}
wn,c就是weight参数,具体意思是:样本
n
n
n在第
c
c
c类二分类任务上的权重,
p
c
p_c
pc就是pos_weight参数,代表在第
c
c
c类二分类任务上,正样本类别所对应的权重。以一个猫狗分类的多标签分类模型为例:
上图是一个多分类任务,有三个类别(猫、狗、猪),每张图片可能包含1-3个类别。输入一张包含猫和狗的图片(假设为样本1),经过神经网络及sigmoid后,输出三个类别的概率分布。要计算交叉熵损失,可以先分别计算三个二分类任务的交叉熵损失,即:是否有狗,是否有猫,是否有猪。然后将三个损失相加,得到该样本的交叉熵损失。那么在计算过程中,weight参数与pos_weight参数分别作用于哪里呢?
以是否有狗这个二分类任务为例,若weight与pos_weight均为None,那么
l
o
s
s
1
=
−
y
1
l
o
g
y
1
^
−
(
1
−
y
i
)
l
o
g
(
1
−
y
i
^
)
loss_1=-y_1log\hat{y_1}-(1-y_i)log(1-\hat{y_i})
loss1=−y1logy1^−(1−yi)log(1−yi^),若weight与pos_weight均不为None,那么
l
o
s
s
1
=
−
w
1
,
1
[
p
1
y
1
l
o
g
y
1
^
+
(
1
−
y
i
)
l
o
g
(
1
−
y
i
^
)
]
loss_1=-w_{1,1}[p_1y_1log\hat{y_1}+(1-y_i)log(1-\hat{y_i})]
loss1=−w1,1[p1y1logy1^+(1−yi)log(1−yi^)]。这里,
w
1
,
1
w_{1,1}
w1,1代表样本1在是否有狗这个二分类任务上的权重,
p
1
p_1
p1代表在是否有狗这个二分类任务上,“有狗”这一类别的权重。
总结来说,weight参数代表样本权重,pos_weight参数代表类别权重,根据具体情况,两者可以结合使用。
这是我对这两个参数的理解,说实话我也不确定自己理解的对不对,只是在看过官方文档后,感觉这么理解能讲得通,可能有很多不对的地方,欢迎大家批评交流!
https://www.cnblogs.com/Fish0403/p/17073047.html
https://pytorch.org/docs/stable/generated/torch.nn.functional.binary_cross_entropy_with_logits.html
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。