赞
踩
参考链接:
GitHub地址
End-to-End Object Detection with Transformers
DETR,即DEtection TRansformer,将目标检测任务转化成序列预测任务,使用transformer编码器-解码器结构和双边匹配的方法,由输入图像直接预测得到预测结果序列,整个过程就是使用CNN提取特征后编码解码得到预测输出。和SOTA(state-of-the-art)的检测方法不同,模型代替了复杂的目标检测传统套路,没有proposal(Faster R-CNN),没有anchor(YOLO),没有center(CenterNet),也没有NMS,直接预测检测框和类别,利用二分图匹配的算法,将CNN和transformer结合,实现目标检测的任务。
简言之,DETR将目标检测任务看做集合预测问题,对于一张图片,预测固定数量的物体(文中是100个,源码中可以指定)模型根据这些物体对象与图片中全局上下文的关系直接并行输出预测集,也就是Transformer一次性解码出图片中所有物体的预测结果。
两个关键部分:
backbone + transformer + prediction
CNN + encoder + decoder + FFN
传统的CNN网络,将输入的图像 3 × W 0 × W 0 3\times W_0\times W_0 3×W0×W0变成尺度为 2048 × W 0 / 32 × H 0 / 32 2048\times W_0/32\times H_0/32 2048×W0/32×H0/32的特征图。
将输入的image features降维并flatten,然后送入图中左半部分的编码器中,和空间位置编码一起并行经过多个自注意力分支、正则化和FFN,得到一组长度为N的预测目标序列。
总结一下,主要有这几个步骤:
其中位置编码的生成方式采用了原始transformer论文中的固定position encoding,即对于每个HW向量,用不同频率的sin函数对高(H)这个维度生成d/2维的position encoding,用不同频率cos函数对宽(W)这个维度生成d/2维的position encoding,然后将两个d/2维度的position encoding concat成d维的position encoding。
将编码器encoder得到的预测目标序列经过图中右半部分所示的transformer decoder中,并行地解码得到输出序列。每个层可以解码N个目标,由于解码器的位置不变性,除了每个像素本身的信息,位置信息也很重要,所以这N个输入嵌入必须不同以产生不用的结果,所以要在每层都加上position encoding。(使用transformer解决图片类输入的时候,一定要注意position信息的处理。)
总结一下,主要有两个输入:
使用共享参数的 FFNs 独立解码为包含两种:类别得分和预测框坐标的最终检测结果(N个),FFNs由一个具有ReLU激活函数和d维隐藏层的3层感知器和一个线性投影层构成。FFN预测框的标准化中心坐标、高度和宽度 w.r.t 输入图像,然后线性层使用softmax函数预测类标签。
基于序列预测的思想,将网络预测结果看做长度为N的固定顺序序列: y ^ \hat{y} y^, y ^ = y i ^ , i ∈ ( 1 , N ) \hat{y}=\hat{y_i},i\in(1,N) y^=yi^,i∈(1,N),其中N的值固定,且远大于图像中GT目标数。 y i ^ = ( c i ^ , b i ^ ) \hat{y_i}=(\hat{c_i},\hat{b_i}) yi^=(ci^,bi^)。
将ground truth也看成一个序列: y i = ( c i , b i ) y_i=(c_i,b_i) yi=(ci,bi),长度小于N,用∅对该序列进行填充使其长度等于N。 c i c_i ci表示该目标所属的真实类别, b i b_i bi表示为一个四元组,包含了目标框的中心点坐标和宽高,且均为相对图像的比例坐标。
预测任务可以看作
y
y
y和
y
^
\hat{y}
y^之间的二分图匹配问题,采用匈牙利算法作为二分匹配算法的求解方法,定义最小匹配的策略如下:
σ
^
=
argmin
σ
∈
S
N
∑
i
N
L
match
(
y
i
,
y
^
σ
(
i
)
)
\hat{\sigma}=\underset{\sigma \in \mathfrak{S}_{N}}{\operatorname{argmin}} \sum_{i}^{N} \mathcal{L}_{\operatorname{match}}\left(y_{i}, \hat{y}_{\sigma(i)}\right)
σ^=σ∈SNargmini∑NLmatch(yi,y^σ(i))
求出最小损失时的匹配策略
σ
^
\hat\sigma
σ^,对于
L
m
a
t
c
h
L_{match}
Lmatch同时考虑了类别预测损失和真实框之间的相似度预测。
对于 σ ( i ) \sigma(i) σ(i), c i c_i ci的预测类别置信度为 p ^ σ ( i ) ( c i ) \hat{p}_{\sigma(i)}\left(c_{i}\right) p^σ(i)(ci),边界框预测为 b ^ σ ( i ) \hat{b}_{\sigma(i)} b^σ(i),对于非空的匹配, L m a t c h L_{match} Lmatch定为:
−
1
{
c
i
≠
∅
}
p
^
σ
(
i
)
(
c
i
)
+
1
{
c
i
≠
∅
}
L
b
o
x
(
b
i
,
b
^
σ
(
i
)
)
-\mathbb{1}_{\left\{c_{i} \neq \varnothing\right\}} \hat{p}_{\sigma(i)}\left(c_{i}\right)+\mathbb{1}_{\left\{c_{i} \neq \varnothing\right\}} \mathcal{L}_{\mathrm{box}}\left(b_{i}, \hat{b}_{\sigma(i)}\right)
−1{ci=∅}p^σ(i)(ci)+1{ci=∅}Lbox(bi,b^σ(i))
总结一下,就是:
当配对的object是∅时,人为规定配对的cost
L
m
a
t
c
h
=
0
L_{match}=0
Lmatch=0;
当配对的object是真实物体时,如果预测的prediction box类别和image object类别相同的概率越大,或者二者的box差距越小,配对的cost
L
m
a
t
c
h
L_{match}
Lmatch就越小。
考虑到尺度问题,将L1
损
失
和
I
o
U
损
失
线
性
组
合
,
得
到
损失和IoU损失线性组合,得到
损失和IoU损失线性组合,得到L_{box}$:
λ
iou
L
iou
(
b
i
,
b
^
σ
(
i
)
)
+
λ
L
1
∥
b
i
−
b
^
σ
(
i
)
∥
1
\lambda_{\text {iou }} \mathcal{L}_{\text {iou }}\left(b_{i}, \hat{b}_{\sigma(i)}\right)+\lambda_{\mathrm{L} 1}\left\|b_{i}-\hat{b}_{\sigma(i)}\right\|_{1}
λiou Liou (bi,b^σ(i))+λL1∥∥∥bi−b^σ(i)∥∥∥1
第二项是两个box中心坐标的L1距离,如果只用L1距离,当box的大小不同时,即使L1距离相同,差距也会不同。所以又加入了与box大小无关的第一项。
得到最优二分图匹配后,来计算整体的损失,来评价生成的prediction boxes的好坏:
L
Hungarian
(
y
,
y
^
)
=
∑
i
=
1
N
[
−
log
p
^
σ
^
(
i
)
(
c
i
)
+
1
{
c
i
≠
∅
}
L
box
(
b
i
,
b
^
σ
^
(
i
)
)
]
\mathcal{L}_{\text {Hungarian }}(y, \hat{y})=\sum_{i=1}^{N}\left[-\log \hat{p}_{\hat{\sigma}(i)}\left(c_{i}\right)+\mathbb{1}_{\left\{c_{i} \neq \varnothing\right\}} \mathcal{L}_{\text {box }}\left(b_{i}, \hat{b}_{\hat{\sigma}}(i)\right)\right]
LHungarian (y,y^)=i=1∑N[−logp^σ^(i)(ci)+1{ci=∅}Lbox (bi,b^σ^(i))]
其中
σ
^
\hat{\sigma}
σ^为最佳匹配,将第i个image object匹配到第
σ
^
(
i
)
\hat{\sigma}(i)
σ^(i)个prediction box。
用 L Hungarian \mathcal{L}_{\text {Hungarian }} LHungarian 做反向传播即可优化transformer。
主要和Faster RCNN进行了对比。DC5后缀表示在主干网络的最后一个阶段加入一个dilation,并从这个阶段的第一个卷积中取出一个stride来增加特征分辨率。
DETR对于大目标的检测效果比faster RCNN有所提升,对小目标的检测表现较差。
不足在于训练阶段需要的时间和硬件资源需求都较大,训练难度较大。
下图为论文中给出的基于pytorch的关键代码:
A u t h o r : c h i e r Author: chier Author:chier
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。