赞
踩
Decoupling Representation and classifier for long-tailed recognition
PDF:https://arxiv.org/pdf/1910.09217.pdf
PyTorch: https://github.com/shanglianlm0525/PyTorch-Networks
Official :https://github.com/facebookresearch/classifier-balancing
在学习分类任务的过程中,将通常默认为联合起来学习的类别特征表征与分类器解耦(decoupling),寻求合适的表征来最小化长尾样本分类的负面影响。
作者将分类网络分解为representation learning 和 classification 两部分,系统的研究了这两部分对于Long-tailed问题的影响。通过实验得到的两点发现是:
数据不均衡问题不会影响高质量Representations的学习。即,random sampling策略往往会学到泛化性更好的representations;
使用最简单的random sampling 来学习representations,然后只调整classifier的学习也可以作为一个strong baseline。
(Sampling strategies)
该方法最为常见,即每一个训练样本都有均等的机会概率被选中,即上述公式中 q=1 的情况。
每个类别都有同等的概率被选中,即公平地选取每个类别,然后再从类别中进行样本选取,即上述公式中 q=0 的情况。
本质上是之前两种采样方式的变种,通常是将概率公式中的 q 定值为 0.5。
根据训练中的迭代次数 t(epoch)同时引入样本均衡(IB)与类别均衡(CB)采样并进行适当权重调整的一种新型采样模式,公式为
固定住representations部分,随机初始化classifier的weight和bias参数,并使用class-balanced sampling在训练少量epoch
首先将training set里的每个类别计算feature representaitions的均值,然后在test set上执行最近邻查找。或者将mean features进行L2-Normalization之后,使用余弦距离或者欧氏距离计算相似度。作者指出,余弦相似度可以通过其本身的normalization特性来缓解weight imbalance的问题。
对classifier权重normalized
当 τ = 1 \tau = 1 τ=1 时,就是标准的L2-Normalization;当 τ = 0 \tau = 0 τ=0 时,表示没有进行scaling操作。 τ ∈ ( 0 , 1 ) \tau \in (0,1) τ∈(0,1),其值是通过cross-validation来选择的。
将
f
i
f_{i}
fi 看作是一个可学习的参数,我们通过固定住representations和classifier两部分的weighs来只学习这个scaling factors
通过各类对比实验,该研究得到了如下观察:
学习过程中保持网络结构(比如 global pooling 之后不需要增加额外的全连接层)、超参数选择、学习率和 batch size 的关系和正常分类问题一致(比如 ImageNet),以确保表征学习的质量。
类别均衡采样:采用多 GPU 实现的时候,需要考虑使得每块设备上都有较为均衡的类别样本,避免出现样本种类在卡上过于单一,从而使得 BN 的参数估计不准。
渐进式均衡采样:为提升采样速度,该采样方式可以分两步进行。第一步先从类别中选择所需类别,第二步从对应类别中随机选择样本。
重新学习分类器(cRT):重新随机初始化分类器或者继承特征表示学习阶段的分类器,重点在于保证学习率重置到起始大小并选择 cosine 学习率。
τ-归一化(tau-normalization):τ 的选取在验证集上进行,如果没有验证集可以从训练集模仿平衡验证集,可参考原论文附录 B.5。
可学习参数放缩(LWS):学习率的选择与 cRT 一致,学习过程中要保证分类器参数固定不变,只学习放缩因子。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。