赞
踩
神经网络学习笔记1——ResNet残差网络、Batch Normalization理解与代码
神经网络学习笔记2——VGGNet神经网络结构与感受野理解与代码
参考博客1
参考博客2
swin-transformer是什么?
解决了什么问题?
优点是什么?
结论是什么?
效果怎么样?
ViT用的是16×16的patch size,也就是16倍的下采样率,从低到高,这些token每个patch的尺寸并不会发生改变,通过全局自注意力操作来实现全局建模,可是面对多尺寸的目标的学习会较差,一单一尺寸处理为主。且面对大图片时序列长度还是过大,计算复杂度平方式递增。
在密集预测型任务如检测和分割或者说在落地项目中使用的图片,多尺度问题是很重要的问题,成熟的模型都会有专门的多尺度特征处理方法。
Swin transformer是在小窗口中进行自注意力(窗口概念在第二块),这些patch组成的小窗口和ViT的patch不同是相较独立的。比如4倍下采样中,将特征图划分成了多个不相交的小窗口区域,Multi-Head Self-Attention只在每个窗口patch内进行。
面对不相交的窗口如何传递信息,如何学习多尺度信息,它提出了patch merging,简单来说就是由小窗口patch合成大窗口patch增大感受野,再通过序号选取的方式去提取出深度特征图,模拟出一种类似池化的操作。
详细来说就是通过一个Patch Merging层进行下采样,如下图所示,比如想下采样两倍,先将四个小patch合成大patch,再通过小patch身上的序号1、2、3、4进行提取,提取的时候是每隔一个点选一个也就是选择同序号,同样序号位置上的 patch 就会被 merge。经过提取之后,原来的这个张量就变成了四个张量,在深度方向进行concat拼接,维度从h × w × c变为h/2 × w/2 × 4c,然后在通过一个LayerNorm层。因为要类比CNN模式,每次经过pooling后通道数只会翻倍,所以这里也只想让他翻2倍,而不是变成4倍,所以紧接着又再做了一次操作,就是在 c 的维度上用一个1x1的卷积(或者全连接层),把通道数降下来变成2c,最后就得到了h/2 × w/2 × 2c的输出。即通过Patch Merging层后,feature map的高和宽会减半,深度会翻倍。
像ViT的全局自注意力的计算会导致平方倍的复杂度,同样当去做视觉里的下游任务,尤其是密集预测型的任务,或者说遇到非常大尺寸的图片时候,这种全局算自注意力的计算复杂度对比卷积就会有很大算力差别。
文章提出用窗口的方式去做自注意力,也就是Windows Multi-head Self-Attention(W-MSA),W-MSA模块是为了减少计算量。
如下图所示,左侧使用的是普通的Multi-head Self-Attention(MSA)模块,对于feature map中的每个像素(或称作token或patch)与Class序列在Self-Attention计算过程中需要和所有的像素去计算全局。
但在图右侧,将特征图拆分成一个个不重叠的window,使用W-MSA模块时,首先将feature map按照M×M(例子中的M=2)大小划分成一个个Windows,然后单独对每个Windows内部进行Self-Attention。
假设在Swin transformer中输入224×224×3的图片,那么一个patch的大小划分为4×4,那么就有56×56个patch,而每7个patch就组成一个窗口,也就是一个窗口有7×7个patch,一个224×224×3的图片会有8×8=64个窗口。
原论文中有给出下面两个公式,这里忽略了Softmax的计算复杂度:
对比公式(1)和公式(2),虽然这两个公式前面这两项是一样的,只有后面从 (hw) ^ 2变成了 M^2 * h * w,看起来好像差别不大,但其实如果仔细带入数字进去计算就会发现,计算复杂的差距是相当巨大的,因为这里的 hw 如果是56*56的话, M^2 其实只有49,所以是相差了几十甚至上百倍。
transformer初衷是理解上下文,是一种信息的传递交互,采用W-MSA模块时,只会在每个窗口内进行自注意力计算,所以窗口与窗口之间是无法进行信息传递的。为了解决这个问题,作者引入了Shifted Windows Multi-Head Self-Attention(SW-MSA)模块,即进行偏移的W-MSA。根据左右两幅图对比能够发现窗口发生了偏移(可以理解成窗口从左上角分别向右侧和下方各偏移了 M /2 个patch)。
在L1层使用的是W-MSA,L1+1层使用的是SW_MSA,在L1时每个窗口里的patch只能和同一个窗口里的patch相互学习,而到了L1+1层时,由于窗口的移动,导致一些patch进入新的窗口,这些带有上一层窗口信息的patch可以和别的带有上一层前窗口信息的patch相互学习。这就是跨窗连接cross window connection操作,使得窗口与窗口之间有着交互。再结合合并Patch Merging操作,在最后几层的时候,每个patch已经与特征图绝大部分的patch有过交流,也就是感受野已经很大了可以看见图片的绝大部分了。这些局部注意力信息最终会扩散到全局,变相达到全局注意力的效果。简单来说就L1+1层中心的4×4窗口学习融合的信息是L1层四个窗口的信息,因为中心的4×4窗口来源组成是L1层四个窗口的patch,L1层四个窗口的patch经过W-MSA时就已经学习到所在窗口的信息,带有自己窗口的信息。所以L1+1层中心的4×4窗口的学习就是L1层四个窗口的绝大部分相邻信息融合。
其实SW-MSA窗口的移动也存在问题,虽然实现让窗口里的patch可以和其他窗口的patch相互通信,交流到别的窗口的信息。可是移动的前后却带有一个问题,就比如移动前L1是四个窗口,每个窗口都是16个patch,移动后的L1+1是九个窗口,每个窗口大小不一,分别是4\8\16个patch。
有一种简单的做法,就是补零,比如把四个patch的窗口补多12个0,补成16个patch的窗口格式,这样4补12,8补8,补完后就得到9个窗口,再将9个窗口打成batch进行学习。虽然做法直白简单,但是一个batch里的窗口从4个提升到9个,实际上计算量提高了,复杂度也提高了。
Swin transformer提出利用掩码做一次循环移位cyclic shift,具体的做法就是:
这种操作即实现了不同窗口的patch交流,,又不会像补零操作那样窗口增加,计算复杂度提高。但是又产生新问题,就是原中心16patch窗口是不变的,里面的patch是本来就是像素意义上的邻居,是有关系的,可以两两相互做自注意力。可是对于另外3个拼接16patch窗口来说,它们是来自不同区域的特征图,如果它们之间做自注意力那么学习出来的特征可能是混乱的,也就是说它们之间不能当做一个纯粹的窗口去做自注意力。
如何处理拼接窗口,Swin transformer提出利用掩码masked操作
比如这里有一个已经进过移动拼接的14×14×3的特征图,0号窗口占7×7个patch,1号与3号是4×7个patch,2号与6号是3×7个patch,4号4×4,5号和7号是3×4,8号是3×3,一共就是14×14个patch(窗口从左上角分别向右侧和下方各偏移了 M /2 个patch)。
0号窗口是一个完整的窗口,可以直接使用自注意力,3号和6号是属于拼接窗口,它自身不可以直接做自注意力。
所以先执行前面的操作将3号方块和6号方块的patch提取出来,拉长为一个向量A,这个向量A中3号patch的值有4×7=28个,6号patch的值有3×7=21个。再通过向量A进行转置操作得到向量B。通过向量A、B的矩阵乘法进行自注意力计算,得到自注意力矩阵C,矩阵C中可以具体区分成四种类型,分别是:
其中3×3和6×6是符合自注意力理念的,3×6和6×3是拼接的混乱值,所以我们只需要3×3和6×6的数据,而3×6和6×3是需要masked掉的。
那么如何去处理3号方块+6号方块的窗口呢,Swin transformer提出一个巧妙的思路就是使用一个掩码模板矩阵D,让矩阵C与矩阵D相加,本来矩阵C里的那些数值是很小的值(大概是0点以下的值),3×3和6×6的数据加上0是不会变化的,而3×6和6×3的数据加上-100则会变成一个很大的负数,这是将这些值都进行softmax操作,那么那些负数就会归为0,剩下的也就是我们所需要的3×3和6×6数据。
讲完了3+6窗口那么继续看看1+2窗口,结合上面的思路,可以发现1+2窗口和3+6窗口是很不一样的,这个不一样产生于拉直向量A上。
可以看见向量A里1号patch的值和2号patch的值是交错排序的,这也导致转置向量B以及自注意矩阵C的变化。
主要说说矩阵C的变化,它依然是分成4种类型(1×1,1×2,2×1,2×2),但不再是集中化和区域化了,而是一个横竖条纹围棋格式的矩阵,这种变化也导致了掩码模板矩阵D的设计,由于这种格式比较麻烦,我就没有专门一个个画出来,可以参考一下,Swin transformer提供的掩码模板。
至于4+5+7+8窗口,其实就是3+6和1+2窗口的合体,我就画了一个拉直向量A的图,具体可以自己去理解,需要结合3+6和1+2的规律。具体的掩码模板在上图有。
做完了多头自注意力后,需要把拼接的特征图还原回去,以保证它的相对位置不变,语义信息不变。如果不还原的话,那么循环轮到下一次Blocks模块时,学习的W-MSA是混乱的,学习SW-MSA时又将移动过的特征图继续拆分拼接,向右下角拼接,多轮下来学到的特征会越来越混乱,特征图也会处于不停打乱的状态。
Swin Transformer Blocks有两种结构,区别在于窗口多头自注意力的计算一个使用了W-MSA结构,一个使用了SW-MSA结构。而且这两个结构是成对使用的,先使用一个W-MSA结构再使用一个SW-MSA结构。所以堆叠Swin Transformer Block的次数都是偶数(在整体模型里Swin Transformer Blocks下的×2、×6就是因为成对使用的意思)。
结合图片和公式进行前向模拟:
前向过程:
参考别的类似图:
Swin transformer在做实验的时候,表示做SW-MSA会比只做W-MSA好,做相对位置rel. pos.会比做绝对位置abs. pos.和没有位置no pos.好。
回到论文公式,会发现它的自注意力公式之前我们讲的多了一个+B操作,这个B就是相对位置偏置。
1、假设window的大小M×M是2×2patch,计算window内的自注意力时,先计算相对位置索引,这里的索引并不是偏置B,而是一个构成B的要素。
2、这里要区分出一个概念,就是绝对位置与相对位置,相对位置是可以结合参考系的方式理解,就是参考主体减去自身与参考客体计算而来。
3、不同的参考主体patch都可以计算出一种对应的相对位置索引,将每种计算出来的索引展平并拼接到一个矩阵A,A矩阵大小为(M×M)2。
4、我们可以观察A矩阵里的相对位置索引分布规律,就比如右边位置的位置概念,红色patch右边是蓝色,以红色为参考主体,蓝色的相对位置是[0,-1]。又比如黄色patch的右边是绿色,相对位置也会是[0,-1]。仔细观察后会发现上下左右,左上左下右上右下等位置概念是相同的索引的。
5、通过对行列做哈希变换,将2D的相对位置索引变为1D以精简计算,作者这里采用的哈希公式是(x+M-1)×(2M-1)+(y+M-1),其中x为行,y为列号,M为窗口大小。
6、将2D降为1D可以用相加或相乘等简单方法实现,但是会出现重复值,比如说红色patch右方是[0,-1],下方[-1,0],在2D时还是较明显的,但是用相加(-1)或相乘(0)时就会出现不同输入却输出相同的干扰,这时可以使用哈希算法来解决这个问题。
7、实现相对位置的特有值,相对位置索引总共有(2M-1)×(2M-1)种,那么就可以随机生成(2M-1)*(2M-1)个随机相对位置偏置(nn.Parameter可学参数),根据相对位置索引,去获取对应的相对位置偏置,也就是公式里面的B,进行多头自注意力的计算。
未完待续。。。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。