当前位置:   article > 正文

Swin-Transformer听课笔记_swintransformer的缺点

swintransformer的缺点

前言

Swin-Transformer是微软亚洲研究院发表于ICCV 2021的一篇论文,并获得了当年的最佳论文。对比 ViT,Swin-Transformer大大降低了计算量,提供了更加通用的基于Transformer的计算机视觉任务的主干网络,并且能应用到分类、检测、分割等多种计算机视觉任务中。

1. 模型的特点

Swin Transformer与Vision Transformer对比有两个突出特征:

  1. 首先Swin Transformer所构建的feature map是具有层次性的,与卷积网络类似,随着特征层的不断加深,feature map的高和宽越来越小。由于像CNN一样能够构建出具有层次性的特征图,Swin Transformer对于目标检测和分割任务都有更大的优势。ViT则一直保持16倍的下采样,没办法构建出具有层次性的特征图。
  2. Swin Transformer是用一个个窗口的形式将feature map分割开了,窗口与窗口之间没有重叠。而在ViT中feature map是一个整体,没有对其进行分割。窗口的划分使得Swin Transformer能够在每个窗口内部进行多头自注意力的计算,窗口之间不去进行信息传递。好处在于能够大大降低运算量,尤其是在下采样率较低的浅层网络,相比于直接对整个特征图进行多头自注意力的计算,大大降低了运算量。

在这里插入图片描述

2. 模型结构

假设输入一张高H宽W的3通道彩色图像,首先通过Patch Partition模块输出 H 4 × W 4 × 48 \frac{H}{4} \times \frac{W}{4} \times 48 4H×4W×48的特征图,之后与ResNet网络类似,通过不同的Stage对特征图进行下采样,并且每次下采样后Channel数就会翻倍。需要注意Stage1模块与Stage2、3、4不同之处在于其第一个模块是一个Linear Embeding层,而2、3、4第一个模块是Patch Merging层。实际上,Patch Partition + Linear Embeding的功能与Patch Merging差不多。
在这里插入图片描述

2.1 Patch Partition + Linear Embedding

例如输入一张图像,Patch Partition会用 4 × 4 4 \times 4 4×4大小的窗口对其进行分割。分割之后对每一个窗口在channel方向进行展平,也就是对每个像素沿深度方向进行拼接。由于每个像素都是RGB三个通道的,则 16 × 3 = 48 16 \times 3 = 48 16×3=48,所以通过Patch Partition之后,图像的高和宽就缩减为原来的 1 4 \frac{1}{4} 41,通道数变为48。

接下来再通过Linear Embedding层对输入特征矩阵进行调整,输出通道数变为 C C C C C C的大小与模型的选择有关。注意,在Linear Embedding中还包括了一个Layer Norm。在实际实现中,Patch Partition和Linear Embedding的操作都是通过卷积来完成的。
在这里插入图片描述

2.2 Patch Merging

Patch Merging的作用是进行下采样,特征图的高和宽会缩减为原来的一半,并且通道数会翻倍。假设输入特征矩阵的高和宽都是 4 × 4 4 \times 4 4×4, 输入通道数为1。以 2 × 2 2 \times 2 2×2大小作为一个窗口,将每个窗口中相同位置的像素取出来,就能得到4个特征矩阵。将这4个特征矩阵在深度方向进行拼接,然后在深度方向进行LayerNorm的处理,最后再通过一个全连接层,对每一个像素的深度方向进行线性映射,输出通道数为2。
在这里插入图片描述

2.3 Swin-Transformer Block

对于每个Stage还会重复堆叠每个Swin Transformer Block,注意重复次数均为偶数。**为什么都要重复偶数次呢?**结构如下:
在第一个Block中可以看到,是将ViT中的Multi-head Self-Attention模块替换成了W-MSA (Windows Multi-head Self-Attention) 。第二个Block中则是使用了SW-MSA (Shifted Windows Multi-head Self-Attention) 。 这两个模块都是成对去使用的。
在这里插入图片描述
对于Swin Transformer模型的架构当中,其实后面还有一些层结构,比如对于分类网络而言,在Stage 4后面还会接上Layer Norm、全局池化以及一个全连接层进行一个最终输出。

2.4 W-MSA

回顾Multi-head Self-Attention,会对每一个像素求Q、K、V,对于每一个像素所求得的Q会和特征图中的每一个K进行匹配,然后再进行一系列的操作。
在这里插入图片描述
对于Swin-Transformer中所提出的Windows Multi-head Self-Attention模块,会对特征图分成一个一个Window,对每一个Window的内部进行Multi-head Self-Attention,但是,Window和Window之间没有任何通信。

目的:减少计算量。
缺点:窗口之间无法进行信息交互。
在这里插入图片描述
理论计算量:
Ω ( M S A ) = 4 h w C 2 + 2 ( h w ) 2 C \Omega(MSA) = 4hwC^2 + 2(hw)^2C Ω(MSA)=4hwC2+2(hw)2C
Ω ( W − M S A ) = 4 h w C 2 + 2 M 2 h w C \Omega(W-MSA) = 4hwC^2 + 2M^2hwC Ω(WMSA)=4hwC2+2M2hwC

  • h代表feature map的高度
  • w代表feature map的宽度
  • C代表feature map的深度
  • M代表每个窗口(Windows)的大小

MSA计算量公式是怎么来的?
A t t e n t i o n ( Q , K , V ) = S o f t M a x ( Q K T d ) V Attention(Q,K,V) = SoftMax(\frac{QK^T}{\sqrt{d}})V Attention(Q,K,V)=SoftMax(d QKT)V

对于feature map中的每个像素(或称为token,patch),都要通过矩阵 W q , W k , W v W_q, W_k, W_v Wq,Wk,Wv生成对应的query(q),key(k)以及value(v)。这里假设q,k,v的向量长度与feature map的深度C保持一致。那么对应所有像素生成Q的过程如下式:
A h w × C ⋅ W q C × C = Q h w × C A_{hw \times C} \cdot {W^q}_{C \times C} = Q_{hw \times C} Ahw×CWqC×C=Qhw×C

  • A h w × C A_{hw \times C} Ahw×C是所有像素(token)拼接再一起得到的矩阵(一共有 h w hw hw个像素,每个像素深度为 C C C
  • W q C × C {W^q}_{C \times C} WqC×C为生成query的变换矩阵
  • Q h w × C Q_{hw \times C} Qhw×C是所有像素与变换矩阵 W q C × C W_q^{C \times C} WqC×C相乘而得到的所有query拼接后的矩阵

根据矩阵乘法运算规则,可以得到生成Q共进行了 h w C 2 hwC^2 hwC2次乘法运算(一共 h w C hwC hwC个像素,每个像素进行C次乘法)。同理,生成K和V都是 h w C 2 hwC^2 hwC2,那么总共 3 h w C 2 3hwC^2 3hwC2次乘法。

接下来 Q Q Q K T K^T KT相乘,对应计算量为 ( h w ) 2 C (hw)^2C (hw)2C
Q h w × C ⋅ K T C × h w = X h w × h w Q_{hw \times C} \cdot {K^T}_{C \times hw} = X_{hw \times hw} Qhw×CKTC×hw=Xhw×hw

忽略除以 d \sqrt{d} d 以及softmax的计算量,假设通过 S o f t M a x ( Q K T d ) SoftMax(\frac{QK^T}{\sqrt{d}}) SoftMax(d QKT)得到 Λ h w × h w \Lambda^{hw \times hw} Λhw×hw,最后还要乘上 V V V,对应的计算量是 ( h w ) 2 C (hw)^2C (hw)2C
Λ h w × h w ⋅ V h w × C = B h w × C \Lambda^{hw \times hw} \cdot V^{hw \times C} = B^{hw \times C} Λhw×hwVhw×C=Bhw×C

综上,单头Self-Attention总共需要 3 h w C 2 + ( h w ) 2 C + ( h w ) 2 C = 3 h w C 2 + 2 ( h w ) 2 C 3hwC^2+(hw)^2C+(hw)^2C=3hwC^2+2(hw)^2C 3hwC2+(hw)2C+(hw)2C=3hwC2+2(hw)2C。而对于多头注意力机制,有
M u l t i H e a d ( Q , K , V ) = C o n c a t ( h e a d 1 , … , h e a d h ) W O MultiHead(Q,K,V) = Concat(head_1,\ldots,head_h)W^O MultiHead(Q,K,V)=Concat(head1,,headh)WO
也就是说,多头注意力机制仅是将每个单头自注意力拼接起来乘上了矩阵 W O W^O WO,计算量为:
B h w × C ⋅ W O C × C = O h w × C B^{hw \times C} \cdot W_O^{C \times C}=O^{hw \times C} Bhw×CWOC×C=Ohw×C
所以总共加起来是: 4 h w C 2 + 2 ( h w ) 2 C 4hwC^2+2(hw)^2C 4hwC2+2(hw)2C

W-MSA的计算公式是怎么来的?
W-MSA是将整个feature map划分为一个个窗口(Windows),假设每个窗口的高和宽都是M,那么总共会得到 h M × w M \frac{h}{M} \times \frac{w}{M} Mh×Mw个窗口,然后再、在每个窗口内使用多头注意力模块。在MSA的计算中,高 h h h w w w深度为 C C C的feature map的计算量为 4 h w C 2 + 2 ( h w ) 2 C 4hwC^2+2(hw)^2C 4hwC2+2(hw)2C,带入 M × M M \times M M×M有:
4 ( M C ) 2 + 2 ( M ) 4 C 4(MC)^2+2(M)^4C 4(MC)2+2(M)4C
又因为有 h M × w M \frac{h}{M} \times \frac{w}{M} Mh×Mw个窗口,则:
h M × w M × ( 4 ( M C ) 2 + 2 ( M ) 4 C ) = 4 h w C 2 + 2 M 2 h w C \frac{h}{M} \times \frac{w}{M} \times (4(MC)^2+2(M)^4C)=4hwC^2+2M^2hwC Mh×Mw×(4(MC)2+2(M)4C)=4hwC2+2M2hwC
故W-MSA模块的计算量为: 4 h w C 2 + 2 M 2 h w C 4hwC^2+2M^2hwC 4hwC2+2M2hwC

假设feature map的h、w都为112,M=7,C=128,采用W-MSA模块相比MSA模块能够节省约40124743680 FLOPs:
2 ( h w ) 2 C − 2 M 2 h w C = 2 × 11 2 4 × 128 − 2 × 7 2 × 11 2 2 × 128 = 40124743680 2(hw)^2C - 2M^2hwC = 2 \times 112^4 \times 128 - 2 \times 7^2 \times 112^2 \times 128 = 40124743680 2(hw)2C2M2hwC=2×1124×1282×72×1122×128=40124743680

2.5 SW-MSA

使用W-MSA的问题在于窗口之间没有信息交互,为了解决这个问题作者引入了Shifted Windows Multi-Head Self-Attention(SW-MSA)模块。在第l层上采用的是W-MSA模块的话,在第l+1层则要使用SW-MSA模块。在SW-MSA模块中Window的划分有所不同,可以看作将每个Window向右和向下移动了两个像素,融合了不同Window之间的信息。
在这里插入图片描述
但是通过Shifted Window之后,如果要并行计算,就需要对不足 4 × 4 4 \times 4 4×4大小的Window进行填充,相当于计算9个Window的注意力,带来了计算量的增加。论文中提出了一种简化运算的方法:
将Window 0标记为区域A,Window 1 和 2标记区域C,Window 3 和 6标记为区域B。
在这里插入图片描述
先将A和C移动到下面:
在这里插入图片描述
再将A和B移动到右边:
在这里插入图片描述

移动之后可以重新划分Window。Window 4不变,将Window 3和5划分到一起,同样将7和1合并在一起,将8、6、2、0合并在一起。这样就在新的4个Window内进行自注意力计算,保持了计算量不变。

但是如果直接简单粗暴地在每个Window内进行计算,又会引入新的问题。例如,对于Window 5 和 Window 3来说,它们本来是两个不相邻的区域,但是现在强行划分在了同一个Window内,直接对它们进行MSA计算是有问题的。所以就希望能够在区域内单独计算Window 5和Window 3的MSA。
原论文中采用了Masked MSA来解决这一问题。假设求得了0位置处的 q 0 q_0 q0 q 0 q_0 q0要和所有位置的像素进行匹配,就会依次生成 α 0 , 0 , α 0 , 1 , … , α 0 , 15 \alpha_{0,0},\alpha_{0,1},\ldots,\alpha_{0,15} α0,0,α0,1,,α0,15,对应Attention计算公式中 Q K T QK^T QKT的过程,但是,在计算Window 5内部的MSA时,不希望引入Window 3的信息。源码实现中,将Window 3所对应的 α 0 , 2 , α 0 , 3 \alpha_{0,2},\alpha_{0,3} α0,2,α0,3等全部减去100。由于 α \alpha α本身数值是很小的,所以在减去100之后就变成了非常大的负数,在经过 S o f t m a x Softmax Softmax处理之后,Window 3所对应的 α \alpha α就全部近似为0了。
在这里插入图片描述
注意:全部计算完成之后,需要将数据还原回初始位置。

2.6 相对位置偏置(Relative position bias)

A t t e n t i o n ( Q , K , V ) = S o f t M a x ( Q K T d + B ) V Attention(Q,K,V)=SoftMax(\frac{QK^T}{\sqrt {d}}+B)V Attention(Q,K,V)=SoftMax(d QKT+B)V
在计算Attention时,加上了一个相对位置偏置 B B B。通过下表可以看出,对比不加入位置编码和加入ViT中的绝对位置编码,加入相对位置偏置后的结果,在分类、目标检测、分割任务中都会表现得更好。
在这里插入图片描述
**什么是相对位置偏置?**假设feature map是 2 × 2 2 \times 2 2×2大小的,那对于蓝色的像素,其绝对位置索引就是 ( 0 , 0 ) (0,0) (0,0),第0行,第0列。那么匹配蓝色像素时的相对位置索引,就是用蓝色的绝对位置索引减去相应位置的绝对位置索引。将每一个相对位置索引的矩阵在行方向上进行展平,拼接在一起可以得到一个大的矩阵。根据每个位置的相对位置索引都可以在Relative position bias当中取到一个对应的参数。
在这里插入图片描述
在原作者的代码当中使用的并不是一个二维的位置坐标,而是使用了一维坐标。如何将二维转化成一维呢?
首先将偏移从0开始,行、列标上加上 M − 1 M-1 M1 M M M对应窗口大小。然后在行标上乘上 2 M − 1 2M-1 2M1,最后再将行标和列表相加,得到最终的一维位置矩阵(relative position index)。然后在元素个数为 ( 2 M − 1 ) × ( 2 M − 1 ) (2M-1) \times (2M-1) (2M1)×(2M1)的relative position bias table中取出对应的值,最终得到relative position bias,即最终使用到的B。

在这里插入图片描述

参考

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/木道寻08/article/detail/909693
推荐阅读
相关标签
  

闽ICP备14008679号