赞
踩
在深度学习中,损失函数的选择对于模型的训练和性能至关重要。均方误差损失(Mean Squared Error Loss,MSE Loss)常用于回归问题,但对于分类任务来说,其效果并不理想。本文将深入探讨为什么MSE Loss不适合处理分类任务。
分类任务的目标是将输入样本分为不同的类别。对于分类问题,以下特性使得MSE Loss并不适用:
假设当前任务为猫狗二分类任务,猫的label为1, one-hot编码为[0, 1],狗的label是0,one-hot编码为[1, 0];
假设选取模型的最后输出维度为(N, 2), 其中N为Batch size,2为num_classes。
问题1: 为什么MSE Loss不适合处理分类任务?
如果我们选择MSE Loss作为猫狗二分类任务的损失函数,比如某个样本类别为猫,label为[0, 1], 模型的输出为[0.48, 0.52]。
那么MSE Loss所做的就是引导模型在处理这个样本时,模型输出的第一个值越接近0越好,模型输出的第二个值越接近1越好;但我们有必要让模型的输出精确到0/1吗?
分析
公式如下:
其中,N为样本个数, o u t p u t i output_{i} outputi表示第i个样本的输出(经过softmax函数,输出概率之和为1), l a b e l i label_{i} labeli表示第i个样本的标签(对于猫狗二分类,标签为0/1), 那么 o u t p u t i [ l a b e l i ] output_{i}[label_{i}] outputi[labeli]表示的便是目标类别的输出概率。
NLL Loss函数以目标类别的预测概率作为输入,其曲线如下所示:
从上图可以看出:
因此,相比于MSE Loss, NLL Loss更适用于处理分类问题,而分类任务常用的交叉熵损失正是基于NLL Loss!!!(基于NLL Loss的交叉熵函数的pytorch保姆级复现见下期博客)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。