赞
踩
在深度学习训练过程中,损失计算是非常重要的一环。有一种论调叫做“损失函数的设计在一定程度上决定了模型能干什么”。
PyTorch提供了许多损失计算函数,包括torch.nn.L1Loss,torch.MSELoss,torch.nn.CrossEntropyLoss,torch.nn.BCELoss,torch.nn.BCEWithLogitsLoss等。
具体可以参考如下连接:torch.nn — PyTorch 2.2 documentation
当然,开发者也可以按照各种Loss的原理自行构建损失函数,PyTorch也提供了许多基础计算接口。
分类损失一般采用交叉熵损失,根据不同的情况又会采取不同的损失形式。如在DeepLab图像分割系列网络中,我们要对最后一层特征图中的每一个像素(或称为cell)都进行分类,最终要确定某一个像素属于哪一个类(有多种可能,但最终仅属于一个类),也就是说在通道轴上进行分类,那么我们采用的普遍是多分类交叉熵损失形式。
而在YOLO系列中,以YOLOv5为例,最终虽然需要对一个cell进行多分类(该cell可能归为多个类别,需要分别确认属于和不属于该类的概率),我们在最后处理的时候,是判断属于和不属于某一类的概率大小,确切地说是二分类问题,那么我们通常采用二院交叉熵损失。
PyTorch提供了许多用于分类的损失函数,如下是常用的几个分类损失函数。
原型如下所示:
torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean', label_smoothing=0.0)
比较关键的参数有两个,分别是weight和reduction。
weight:类别权重,分属于每一个类别的权重信息,其维度应该等同于类别数量。
reduction:归纳操作,默认取值为计算均值。如果设置为“none”则表示不归纳,此时输出的Tensor维度与类别数相同。设置为“sum”为求和。
使用时,参考如下。
- import torch
-
- from torch.nn import CrossEntropyLoss, BCELoss, BCEWithLogitsLoss
-
- CE_loss = CrossEntropyLoss(reduction='mean')
-
- CE_loss_result = CE_loss.forward(input, target) # input:(N, C, H, W) target:(H, H, W)
需要注意input和target的shape信息,一般情况下,input通常是带有C通道,而target则不带有C通道。主要原因是,target我们直接指定了类别的索引。
但是自PyTorch1.10版本以后,开始支持input和target具有相同的shape,此时,target不再是指定了类别索引,而是经过了one-hot或label-smooth操作。
另外一个需要注意的是,通过如下PyTorch中CrossEntropyLoss解释可以看出,CrossEntropyLoss实际上是对input做了softmax操作,因此开发者无需再网络输出后添加类似操作,可将网络输出直接导入CrossEntropyLoss执行计算。
二元交叉熵主要用来二分类,即标签只有0和1的场合。
PyTorch提供的torch.nn.BCELoss原型如下。
torch.nn.BCELoss(weight=None, size_average=None, reduce=None, reduction='mean')
使用时,参考如下。
- import torch
-
- from torch.nn import CrossEntropyLoss, BCELoss, BCEWithLogitsLoss
-
- BCE_LOG_LOSS = BCEWithLogitsLoss(reduction='mean')
-
- BCE_LOG_LOSS_result_normal = BCE_LOG_LOSS.forward(input, target_onehot.float())
其中,input和target的shape是相同的。
需要特别注意一点,nn.BCEloss计算损失时,需要对预测值执行sigmoid计算。sigmoid函数会将预测值映射到0-1之间。如果觉得手动加sigmoid函数麻烦,可以直接调用nn.BCEwithlogitsloss。
torch.nn.BCEWithLogitsLoss在使用上和torch.nn.BCELoss基本一致,不同点是无需对预测执行sigmoid操作。
原型如下。
torch.nn.BCEWithLogitsLoss(weight=None, size_average=None, reduce=None, reduction='mean', pos_weight=None)
如何选取损失函数?
单纯的以多分类和二分类来确定使用torch.nn.CrossEntropyLoss还是torch.nn.BCELoss并不准确,还是要结合实际的网络输出来定。
torch.nn.CrossEntropyLoss典型的是应用在某一个cell有多个分类输出,每个分类输出都有自己的概率,总的概率为1。这种场景下,一般网络输出先做softmax,然后计算交叉熵损失。
torch.nn.BCELoss主要还是针对“是与不是”的问题。以YOLOv5为例,每一个cell都有分类输出,且分类输出的是属于每一个类的概率。这样,我们可以单纯的认为,网络输出的是“是与不是”某个类的概率。如共有3各类,分别是“cat”,“dog”,“ant”,那么网络输出的是属于cat的概率,属于dog的概率和属于ant的概率。
L1损失主要用来计算平均绝对误差(MAE)。其公示如下。
L1损失函数的主要特点是它对异常值不敏感,因此,当数据中存在一些噪声或离群点时,L1损失函数可能会比L2损失(均方误差)更为稳健。
然而,L1损失在0点处不是光滑的,这可能会导致求解效率不如L2损失。在某些情况下,可能需要综合使用L1和L2损失,以获得在稳定性和求解效率之间的平衡。
PyTorch提供的L1Loss原型如下。
torch.nn.L1Loss(size_average=None, reduce=None, reduction='mean')
使用参考如下。
- loss = nn.L1Loss()
-
- input = torch.randn(3, 5, requires_grad=True)
-
- target = torch.randn(3, 5)
-
- output = loss(input, target)
-
- output.backward()
MSE,即均方误差损失,也称作L2损失,是一种常用的回归损失函数。它计算的是预测值与真实值之间差值的平方的均值。其公式如下。
MSE损失函数对预测误差进行了平方,这意味着它对大的误差赋予更大的权重。由于平方的性质,大的误差会产生更大的损失值,因此模型在训练过程中会倾向于首先减少那些大的预测误差。这使得MSE损失对于异常值或噪声数据比较敏感,可能会导致模型过于关注这些值而牺牲了整体的预测性能。
PyTorch提供的MSELoss原型如下。
torch.nn.MSELoss(size_average=None, reduce=None, reduction='mean')
其使用参考如下。
- loss = nn.MSELoss()
-
- input = torch.randn(3, 5, requires_grad=True)
-
- target = torch.randn(3, 5)
-
- output = loss(input, target)
-
- output.backward()
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。