当前位置:   article > 正文

样本不均衡问题 (OHEM, Focal loss)_ohemceloss

ohemceloss

不均衡问题分析

正负样本不均衡

  • 对于物体检测算法,有核心价值的是对应着真实物体的正样本,在训练时会根据其 loss 来调整网络参数。相比之下,负样本对应着图像的背景,如果有大量的负样本参与训练,则会淹没正样本的损失,从而降低网络收敛的效率与检测精度
  • 以 Faster RCNN 为例,在 RPN 部分如果不对 RoI 进行筛选 (包括计算 RPN 损失时进行筛选、生成 Proposal 和进一步筛选 Proposal),就会生成接近 20000 个 Anchor,由于一张图中通常有 10 个左右的物体,导致可能只有 100 个左右的 Anchor 会是正样本,正负样本比例严重不均衡

难易样本不均衡

  • 根据是否容易学习及与标签的重叠程度,可以将所有样本分为 4 类:简单正样本 (Easy Positive)、难正样本 (Hard Positive)、简单负样本 (Easy Negative) 及难负样本 (Hard Negative)
    在这里插入图片描述
  • 难样本指的是分类不太明确的边框,处在前景与背景的过渡区域上,在网络训练中难样本损失会较大,也是我们希望模型去学习优化的样本,利用这部分训练可以提升检测的准确率。然而,大量的样本属于简单样本,虽然简单样本单个损失小,但由于数量众多,因此如果全都计算损失的话,其损失也会比难样本大很多,这种难易样本的不均衡也会影响模型的收敛与精度

值得注意的是,由于负样本中大量的是简单样本,导致难易样本与正负样本这两个不均衡问题有一定的重叠,解决方法往往能同时对这两个问题起作用

类别间样本不均衡

  • 在有些物体检测的数据集中,还会存在类别间的不均衡问题。举个例子,数据集中有 100 万个车辆、1000 个行人的实例标签,样本比例为 1000∶1,属于典型的类别不均衡。这种情况下,如果不做任何处理,使用该数据集进行训练,由于行人这一类别可参考标签太少,会使得模型主要关注车这一类别的检测,网络中的参数主要根据车辆的损失进行优化,导致行人的检测精度大大下降

常用的解决方法

  • (1) Faster RCNN、SSD 等算法在正负样本的筛选时,根据样本与真实物体的 IoU 大小,设置 3:1 的正负样本比例,这一点缓解了正负样本的不均衡,同时也对难易样本不均衡起到了作用
  • (2) Faster RCNN 在 RPN 模块中,通过前景得分排序筛选出了 2000 个左右的候选框,这也会将大量的负样本与简单样本过滤掉,缓解了前两个不均衡问题
  • (3) 权重惩罚:对于难易样本与类别间的不均衡,可以增大难样本与少类别的损失权重,从而增大模型对这些样本的惩罚,缓解不均衡问题
  • (4) 数据增强:从数据侧入手,可以在当前数据集上使用随机生成和添加扰动的方法,从而缓解难易样本和类别间样本等不均衡问题

在线难样本挖掘: OHEM

难负样本挖掘 (Hard Negative Mining, HNM)

  • 难负样本挖掘最初在机器学习中被广泛使用,用于解决类别的不均衡问题。以 SVM 为例,HNM 方法先让模型收敛于当前的工作数据集,然后固定该模型,在数据集中去除简单的样本,添加一些当前无法判断的样本,进行新的训练。这样的交替训练可以使得模型性能达到最优
  • 物体检测方法很难直接使用 HNM 算法进行挖掘。原因在于物体检测算法通常采用 SGD 等优化方法来进行优化,往往需要上万次的参数更新;而如果采用 HNM 交替训练的方法,每迭代几次就固定模型,训练的速度会大大下降

在线难样本挖掘 (Online Hard Example Mining, OHEM)

  • OHEM 可以看做是 HNM 在物体检测算法上的应用,以 Fast(er) RCNN 作为基础检测算法。在标准的 Fast(er) RCNN 中,RPN 网络通过生成 Proposal 并筛选得到 256 个锚框送入 RCNN 网络进行训练,在 Proposal 筛选阶段,需要控制正、负样本的比例为 1:3,总数量不超过 256。上述方法虽然有效缓解了正、负样本的不均衡,但是容易忽略一些较为重要的难负样本,并且固定了正、负样本的比例与最大数量,显然不是最优的选择以此为出发点,OHEM 将交替训练与 SGD 优化方法进行了结合,在每张图片的 RoI 中选择了较难的样本,实现了在线的难样本挖掘

OHEM 网络结构

在这里插入图片描述
如下图所示,OHEM 的改进主要集中在 RCNN 网络部分。图中包含了两个相同的 RCNN 网络,上半部的 a a a 部分是只可读的网络,只进行前向运算;下半部的 b b b 网络即可读也可写,需要完成前向计算与反向传播。在一个 batch 的训练中,基于 Fast(er) RCNN 的 OHEM 算法可以分为以下 4 步:

  • (1) 按照原始 Fast(er) RCNN 算法,经过卷积提取网络与 RoI Pooling 得到每一张图像的 RoI
  • (2) 上半部的 a a a 网络对所有的 RoI 进行前向计算,得到每一个 RoI 的损失
  • (3) 对 RoI 的损失进行排序,进行一步 NMS 操作,然后选择出固定数量损失较大的 RoI 作为难样本
  • (4) 将筛选出的难样本输入到可读写的 b b b 网络中,进行前向计算得到损失,然后反向传播更新网络,并将更新后的参数与上半部的 a a a 网络同步,完成一次迭代

当然,为了实现方便,OHEM 也可以仅采用一个 RCNN 网络,在选择完难样本后将剩下的简单样本损失置 0,可以起到相同的作用


总结

  • 总体上,OHEM 是一个很经典的难样本挖掘 Trick,实现方式简单,可以显著提升网络训练的效率和检测性能,被广泛地应用于难样本的挖掘场景中,并且数据集越大、难度越高,OHEM 对于检测的提升越明显
  • 但是,由于其特殊的损失计算方式,把简单的样本都舍弃了,导致模型无法提升对于简单样本的检测精度,这也是 OHEM 方法的一个弊端

专注难样本: Focal Loss

  • 当前一阶的物体检测算法,如 SSD 和 YOLO 等虽然实现了实时的速度,但精度始终无法与两阶的 Faster RCNN 相比。何凯明等人将其归咎于正、负样本的不均衡:Faster RCNN 在第一个阶段利用得分筛选出了 2000 个左右的 RoI,可以过滤掉大部分的负样本,在第二个阶段通过固定正、负样本比例或者 OHEM 等方法,可以有效解决正、负样本的不均衡问题。而对于 SSD 等一阶网络,由于其需要直接从所有的预选框中进行筛选,即使使用了固定正、负样本比例的方法,仍然效率低下,简单的负样本仍然占据主要地位,导致其精度不如两阶网络
  • 为了解决上述问题,何凯明等人提出了新的损失函数 Focal Loss 及网络结构 RetinaNet,在与同期一阶网络速度相同的前提下,其检测精度比同期最优的二阶网络还要高

Focal Loss

  • 标准交叉熵损失: 标准交叉熵中所有样本的权重都是相同的,因此如果正、负样本不均衡,大量简单的负样本会占据主导地位,少量的难样本与正样本会起不到作用,导致精度变差
    C E ( p , y ) = { − log ⁡ ( p )  if  ( y = 1 ) − log ⁡ ( 1 − p )  otherwise  C E(p, y)=\left\{
    log(p) if (y=1)log(1p) otherwise 
    \right.
    CE(p,y)={log(p)log(1p) if (y=1) otherwise 
    为了方便表示,将 p p p 标记为 p t p_t pt
    p t = { p  if  ( y = 1 ) 1 − p  otherwise  p_t=\left\{
    p if (y=1)1p otherwise 
    \right.
    pt={p1p if (y=1) otherwise 

    C E ( p , y ) = − log ⁡ ( p t ) C E(p, y)=-\log (p_t) CE(p,y)=log(pt)
  • 平衡交叉熵损失: 为了改善样本的不平衡问题,平衡交叉熵在标准的基础上增加了一个系数 α t α_t αt 来平衡正、负样本的权重
    C E ( p , y ) = − α t log ⁡ ( p t ) C E(p, y)=-\alpha_t\log (p_t) CE(p,y)=αtlog(pt)其中
    α t = { α  if  ( y = 1 ) 1 − α  otherwise  α_t=\left\{
    α if (y=1)1α otherwise 
    \right.
    αt={α1α if (y=1) otherwise 
    α α α 为超参. 尽管平衡交叉熵损失改善了正、负样本间的不平衡,但由于其缺乏对难易样本的区分,因此没有办法控制难易样本之间的不均衡
  • Focal Loss: Focal Loss 在平衡交叉熵损失的基础上引入了权重系数 ( 1 − p t ) γ (1-p_t)^\gamma (1pt)γ 调节难易样本的权重,使得越简单的样本 (i.e. 置信度大的样本) 权重越小。因此,Focal loss 可以同时调节正、负样本与难易样本
    F L ( p , y ) = − α t ( 1 − p t ) γ log ⁡ ( p t ) FL(p, y)=-\alpha_t(1-p_t)^\gamma\log (p_t) FL(p,y)=αt(1pt)γlog(pt)其中, γ γ γ 是一个调制因子, γ γ γ 越大,简单样本损失的贡献会越低。对于两个超参数,通常来讲, γ γ γ 增大时, α α α 应当适当减小实验中 γ γ γ 取 2、 α α α 取 0.25 时效果最好 ( α = 0.25 \alpha=0.25 α=0.25 时模型反而会更加注重对负样本的学习…)

Focal loss 对噪声特别敏感,如果数据集中有错误标注的数据,模型反而会着重对这些错误数据进行学习,这也是 Focal loss 的弊端之一了

RetinaNet

为了验证 Focal Loss 的效果,何凯明等人还提出了一个一阶物体检测结构 RetinaNet

在这里插入图片描述

  • 在 Backbone 部分,RetinaNet 利用 ResNetFPN 构建了一个多尺度特征的特征金字塔
  • RetinaNet 使用了类似于 Anchor 的预选框,在每一个金字塔层,使用了 9 个大小不同的预选框
  • 分类子网络:分类子网络为每一个预选框预测其类别,因此其输出特征大小为 K A × W × H KA×W×H KA×W×H A A A 默认为 9, K K K 代表类别数。中间使用全卷积网络与 ReLU 激活函数 (连续 4 个 3 × 3 3\times3 3×3 卷积层 + ReLU,最后再使用 3 × 3 3\times 3 3×3 卷积增加通道数),最后利用 Sigmoid 函数输出预测值
  • 回归子网络:回归子网络与分类子网络平行,预测每一个预选框的偏移量,最终输出特征大小为 4 A × W × W 4A×W×W 4A×W×W
  • Focal Loss:Focal Loss 在训练时作用到所有的预选框上

由于 batch 较小,RetinaNet 冻结了 backbone 的 BN 层,不参与训练,这一点需要注意

参考文献

  • 《深度学习之 PyTorch 物体检测实战》
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/很楠不爱3/article/detail/371607
推荐阅读
相关标签
  

闽ICP备14008679号