当前位置:   article > 正文

pytorch基础使用—自定义损失函数_使用pytorch自己写交叉熵损失函数

使用pytorch自己写交叉熵损失函数

1 模板

与定义一个模型类似,定义一个继承nn.Module的类:

  1. __init__:初始化超参数
  2. forward:定义损失的计算方式,并进行前向传播
  3. backward:反向传播(暂未遇到需要修改的情况)
import torch.nn as nn
import torch

class MyLoss(nn.Module):

	def __init__(self):
		# 超参数初始化,如
		slef.param1 = 0
	
	def forward(self, predict, label):  # 一般是预测值和label
		# 进行损失计算,即前向传播,如
		return torch.mean(torch.pow((predict - label), 2))  # 可以自己定义一些计算,但是所有的数学操作必须使用tensor提供的math。也可以用functional提供的一些损失计算,如交叉熵损失。
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

下面介绍一些损失函数:

2 损失函数

损失函数分为两类:

  1. 分类损失,如0-1 loss、熵与交叉熵loss、softmax loss及其变种、KL散度、Hinge loss、Exponential loss、Logistic loss、Focal Loss。
  2. 回归损失,如L1 loss、L2 loss、perceptual loss、生成对抗网络损失、GAN的基本损失、-log D trick、Wasserstein GAN、LS-GAN、Loss-sensitive-GAN。

2.1 交叉熵损失

交叉熵损失函数一般用于分类任务,在计算的时候,期望输出和实际输出一般是one-hot形式(只有一个是真实值1,其余都是0)

2.1.1 原理

交叉熵主要是用来判定实际的输出与期望的输出的接近程度
在这里插入图片描述
其中p为期望输出,q为实际输出。

假设期望输出为p=[1, 1, 0],实际输出q1=[0.4, 0.3, 0.3],q2=[0.6, 0.3, 0.1]:
在这里插入图片描述
在这里插入图片描述
可以看到q2和p的交叉熵更小,代表q2和p更加接近。

2.1.2 公式推导

假设有N条数据,out为网络输出,p为期望输出。
对于二分类问题:
首先我们先使用sigmod函数处理网络输出,限制其范围为0-1,结果为q,代表着实际输出:
在这里插入图片描述
对于一个样本i来说,在期望输出为pi的情况下,其正负样本的概率为:
在这里插入图片描述
假设所有样本相互独立,对应的似然函数为:
在这里插入图片描述
对似然函数取对数和相反数即为损失函数
在这里插入图片描述

2.1.3 扩展

交叉熵损失也可以应用到多分类问题,只是此时我们的网络输出out是一个one-hot变量,此时我们需要将out通过softmax函数,而不是sigmod。
假设网络输出N个样本,每个样本C个类别。一个样本的输出out(维度是1xC),其第i个数在这里插入图片描述经过softmax计算如下:
在这里插入图片描述
该样本中其余数也经过这样计算。该样本的编码这样处理后所有值相加为0,然后取其中最大的一个作为在这里插入图片描述。后面就与二分类问题一致了。

2.1.4 nn.CrossEntropyLoss

from torch.nn impiort CrossEntropyLoss  # 导入

loss = CrossEntropyLoss()  # 定义,后面去使用即可
  • 1
  • 2
  • 3

还有一种办法是使用functional中的cross_entropy函数

2.2 Focal Loss

Focal Loss以交叉熵损失为基础,引入主要是为了解决目标检测中正负样本数量极不平衡问题
交叉熵函数如下:
在这里插入图片描述
两个式子合并到一起为:
在这里插入图片描述
由该函数得到的交叉熵损失函数无法解决正负样本的平衡问题。因此经过三个阶段形成了Focal Loss:

  1. 平衡交叉熵
  2. 聚焦损失

2.2.1 平衡交叉熵

一个普遍的解决正负样本的问题的办法是增加权重参数在这里插入图片描述,公式为:在这里插入图片描述
样本t中,当为正样本y=1,负样本y=0。
结合了参数在这里插入图片描述的交叉熵函数为:
在这里插入图片描述

2.2.2 聚焦损失

参数在这里插入图片描述平衡了正负样本不均衡的问题。但是后面又发现难分样本的问题,为此,对于简单的样本增加一个小的权重,让损失函数聚焦在困难样本的训练
设置这样一个调节因子在这里插入图片描述,其中在这里插入图片描述
结合该调节因子后,交叉熵函数如下:
在这里插入图片描述
当p为1,即为易区分样本时,在这里插入图片描述接近0,即降低对易区分样本的损失比例。

2.2.3 Focal Loss

假设N个样本,最终的Focal Loss由上面CE(p. y)得到:
在这里插入图片描述
论文中提示在这里插入图片描述时效果最好。
公式推导与2.1.2小结中一致

2.2.4 Code

def focal_loss(y, p, alpha=0.25, gamma=2):
    p = K.clip(y_pred, 1e-8, 1 - 1e-8)
    return - alpha * y * K.log(p) * (1 - p)**gamma - (1 - alpha) * (1 - y) * K.log(1 - p) * p**gamma
  • 1
  • 2
  • 3

这里只是一个实现思路,配合着公式看,网络上也有通过pyotrch实现的Focal Loss。

3 代价函数、损失函数、目标函数

代价函数(Cost Function):指在整个数据集上衡量模型预测结果与真实结果之间差异的函数。代价函数通常用于监督学习问题中,用于评估模型的性能。代价函数的值越小,表示模型的预测结果与真实结果越接近。

损失函数(Loss Function):指衡量单个样本预测结果与真实结果之间差异的函数。损失函数通常用于监督学习问题中,用于衡量模型在单个样本上的预测误差。损失函数的值越小,表示模型在该样本上的预测结果越接近真实结果。

目标函数(Objective Function):是指在优化问题中需要最小化或最大化的函数。

  1. 目标函数可以是代价函数或损失函数的总和。例如,目标函数=经验风险(代价函数)+结构风险(Cost Function+正则化项)
  2. 也可以是在优化问题中需要优化的其他指标,其选择取决于具体的问题和优化目标。
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Cpp五条/article/detail/321676
推荐阅读
相关标签
  

闽ICP备14008679号