赞
踩
flyfish
本系列的主要内容是在2017年所写,GPT使用了交叉熵损失函数,所以就温故而知新,文中代码又用新版的PyTorch写了一遍,在看交叉熵损失函数遇到问题时,可先看链接提供的基础知识,可以有更深的理解。
深入理解交叉熵损失 CrossEntropyLoss - one-hot 编码
深入理解交叉熵损失 CrossEntropyLoss - 对数
深入理解交叉熵损失 CrossEntropyLoss - 概率基础
深入理解交叉熵损失 CrossEntropyLoss - 概率分布
深入理解交叉熵损失 CrossEntropyLoss - 损失函数
深入理解交叉熵损失 CrossEntropyLoss - 归一化
深入理解交叉熵损失 CrossEntropyLoss - 信息论(交叉熵)
深入理解交叉熵损失 CrossEntropyLoss - Softmax
深入理解交叉熵损失 CrossEntropyLoss - nn.LogSoftmax
深入理解交叉熵损失 CrossEntropyLoss - 似然
深入理解交叉熵损失CrossEntropyLoss - 乘积符号在似然函数中的应用
深入理解交叉熵损失 CrossEntropyLoss - nn.NLLLoss
深入理解交叉熵损失 CrossEntropyLoss - nn.CrossEntropyLoss
torch.nn.CrossEntropyLoss
是一个常用的
损失函数,主要用于多分类任务。它结合了
nn.LogSoftmax 和
nn.NLLLoss,并且内部进行了优化以避免
数值稳定性问题。
具体来说,torch.nn.CrossEntropyLoss 计算的是预测值与目标值之间的交叉熵损失。对于多分类问题,交叉熵损失是最常用的损失函数,因为它直接衡量了两个概率分布(预测概率分布和实际分布)之间的差异。
nn.CrossEntropyLoss 在内部已经包含了 LogSoftmax 和 NLLLoss 的操作。
编写代码验证,分别是 LogSoftmax和 NLLLoss两者的结合,对比立使用CrossEntropyLoss。
import torch import torch.nn as nn # 输入张量 (batch_size=2, num_classes=3) input_tensor = torch.tensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]) # 目标张量 (batch_size=2) target_tensor = torch.tensor([2, 0]) # 使用 nn.LogSoftmax 和 nn.NLLLoss log_softmax = nn.LogSoftmax(dim=1) log_probs = log_softmax(input_tensor) nll_loss = nn.NLLLoss() loss = nll_loss(log_probs, target_tensor) print(f'Loss using LogSoftmax and NLLLoss: {loss.item()}') # 使用 nn.CrossEntropyLoss cross_entropy_loss = nn.CrossEntropyLoss() loss_ce = cross_entropy_loss(input_tensor, target_tensor) print(f'Loss using CrossEntropyLoss: {loss_ce.item()}')
输出结果
Loss using LogSoftmax and NLLLoss: 1.4076058864593506
Loss using CrossEntropyLoss: 1.4076058864593506
对于单个样本,交叉熵损失的定义如下:
CrossEntropyLoss = − ∑ i = 1 C y i log ( y ^ i ) \text{CrossEntropyLoss} = -\sum_{i=1}^{C} y_i \log(\hat{y}_i) CrossEntropyLoss=−i=1∑Cyilog(y^i)
其中:
交叉熵损失结合了两个概念:
y ^ i = exp ( z i ) ∑ j = 1 C exp ( z j ) \hat{y}_i = \frac{\exp(z_i)}{\sum_{j=1}^{C} \exp(z_j)} y^i=∑j=1Cexp(zj)exp(zi)
其中 z i z_i zi 是第 i i i 类的 logit。
以下是一个简单的示例,展示如何计算交叉熵损失:
import torch
import torch.nn as nn
# 假设我们有两个样本,每个样本属于3个类别中的一个
logits = torch.tensor([[2.0, 1.0, 0.1], [0.5, 2.0, 0.3]])
# 真实标签
labels = torch.tensor([0, 1])
# 使用 nn.CrossEntropyLoss 计算损失
criterion = nn.CrossEntropyLoss()
loss = criterion(logits, labels)
print(f'Cross Entropy Loss: {loss.item()}')
Cross Entropy Loss: 0.37882310152053833
在这个示例中:
二分类交叉熵损失的公式为:
CrossEntropyLoss = − ( y log ( y ^ ) + ( 1 − y ) log ( 1 − y ^ ) ) \text{CrossEntropyLoss} = - (y \log(\hat{y}) + (1 - y) \log(1 - \hat{y})) CrossEntropyLoss=−(ylog(y^)+(1−y)log(1−y^))
假设:
我们使用更高精度来计算:
exp
(
−
1.5
)
≈
0.22313016014842982
\exp(-1.5) \approx 0.22313016014842982
exp(−1.5)≈0.22313016014842982
σ
(
z
)
=
1
1
+
0.22313016014842982
≈
1
1.22313016014842982
≈
0.8175744761936437
\sigma(z) = \frac{1}{1 + 0.22313016014842982} \approx \frac{1}{1.22313016014842982} \approx 0.8175744761936437
σ(z)=1+0.223130160148429821≈1.223130160148429821≈0.8175744761936437
CrossEntropyLoss
=
−
(
y
log
(
σ
(
z
)
)
+
(
1
−
y
)
log
(
1
−
σ
(
z
)
)
)
\text{CrossEntropyLoss} = - (y \log(\sigma(z)) + (1 - y) \log(1 - \sigma(z)))
CrossEntropyLoss=−(ylog(σ(z))+(1−y)log(1−σ(z)))
CrossEntropyLoss
=
−
log
(
0.8175744761936437
)
\text{CrossEntropyLoss} = - \log(0.8175744761936437)
CrossEntropyLoss=−log(0.8175744761936437)
log
(
0.8175744761936437
)
≈
−
0.2014132779827524
\log(0.8175744761936437) \approx -0.2014132779827524
log(0.8175744761936437)≈−0.2014132779827524
CrossEntropyLoss
≈
0.2014132779827524
\text{CrossEntropyLoss} \approx 0.2014132779827524
CrossEntropyLoss≈0.2014132779827524
import torch import torch.nn as nn import math # 真实标签和 logits labels = torch.tensor([1.0]) logits = torch.tensor([1.5]) # 使用 BCEWithLogitsLoss criterion = nn.BCEWithLogitsLoss() loss = criterion(logits, labels) print(f'Binary Classification Cross Entropy Loss: {loss.item()}') # 手动计算 sigmoid 和交叉熵损失 sigmoid = 1 / (1 + math.exp(-1.5)) manual_loss = - (1 * math.log(sigmoid) + (1 - 1) * math.log(1 - sigmoid)) print(f'Manually Computed Cross Entropy Loss: {manual_loss}')
输出结果
Binary Classification Cross Entropy Loss: 0.20141397416591644
Manually Computed Cross Entropy Loss: 0.2014132779827524
假设有3个类别:
具体计算:
y
^
1
=
exp
(
0.1
)
exp
(
0.1
)
+
exp
(
0.2
)
+
exp
(
0.7
)
\hat{y}_1 = \frac{\exp(0.1)}{\exp(0.1) + \exp(0.2) + \exp(0.7)}
y^1=exp(0.1)+exp(0.2)+exp(0.7)exp(0.1)
y
^
2
=
exp
(
0.2
)
exp
(
0.1
)
+
exp
(
0.2
)
+
exp
(
0.7
)
\hat{y}_2 = \frac{\exp(0.2)}{\exp(0.1) + \exp(0.2) + \exp(0.7)}
y^2=exp(0.1)+exp(0.2)+exp(0.7)exp(0.2)
y
^
3
=
exp
(
0.7
)
exp
(
0.1
)
+
exp
(
0.2
)
+
exp
(
0.7
)
\hat{y}_3 = \frac{\exp(0.7)}{\exp(0.1) + \exp(0.2) + \exp(0.7)}
y^3=exp(0.1)+exp(0.2)+exp(0.7)exp(0.7)
计算得到:
exp
(
0.1
)
≈
1.1052
\exp(0.1) \approx 1.1052
exp(0.1)≈1.1052
exp
(
0.2
)
≈
1.2214
\exp(0.2) \approx 1.2214
exp(0.2)≈1.2214
exp
(
0.7
)
≈
2.0138
\exp(0.7) \approx 2.0138
exp(0.7)≈2.0138
总和:
exp ( 0.1 ) + exp ( 0.2 ) + exp ( 0.7 ) ≈ 1.1052 + 1.2214 + 2.0138 = 4.3404 \exp(0.1) + \exp(0.2) + \exp(0.7) \approx 1.1052 + 1.2214 + 2.0138 = 4.3404 exp(0.1)+exp(0.2)+exp(0.7)≈1.1052+1.2214+2.0138=4.3404
各个概率:
y
^
1
=
1.1052
4.3404
≈
0.2546
\hat{y}_1 = \frac{1.1052}{4.3404} \approx 0.2546
y^1=4.34041.1052≈0.2546
y
^
2
=
1.2214
4.3404
≈
0.2814
\hat{y}_2 = \frac{1.2214}{4.3404} \approx 0.2814
y^2=4.34041.2214≈0.2814
y
^
3
=
2.0138
4.3404
≈
0.4639
\hat{y}_3 = \frac{2.0138}{4.3404} \approx 0.4639
y^3=4.34042.0138≈0.4639
import torch import torch.nn as nn import torch.nn.functional as F # 模拟输入的 logits 和真实标签 logits = torch.tensor([[0.1, 0.2, 0.7]], requires_grad=True) labels = torch.tensor([2]) # 使用 CrossEntropyLoss criterion = nn.CrossEntropyLoss() loss = criterion(logits, labels) print(f'Computed Cross Entropy Loss (using nn.CrossEntropyLoss): {loss.item()}') # 手动计算 softmax 和交叉熵损失 softmax_probs = F.softmax(logits, dim=1) manual_loss = -torch.log(softmax_probs[0, labels]) print(f'Manually Computed Cross Entropy Loss: {manual_loss.item()}')
输出结果
Computed Cross Entropy Loss (using nn.CrossEntropyLoss): 0.7679495811462402
Manually Computed Cross Entropy Loss: 0.7679495811462402
注意在多分类问题的代码中,我们提供了logits而不是softmax后的概率,因为nn.CrossEntropyLoss会在内部应用softmax。
在二分类问题中,我们可以使用 nn.BCEWithLogitsLoss,它会在内部应用 Sigmoid 激活函数,并计算二分类的交叉熵损失。
在多分类问题中,我们可以使用 nn.CrossEntropyLoss,它会在内部应用 Softmax 激活函数,并计算多分类的交叉熵损失
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。