当前位置:   article > 正文

Swin Transformer(W-MSA详解)代码+图解

w-msa

2. Window & Shifted Window based Self-Attention

Swin Transformer另一个重要的改进就是window-based的self-attention layer,之前提到过,ViT的一个缺点是计算复杂度是和patch数量成平方关系的,为了减少计算量,Swin的做法是将输入图片划分成不重合的windows,然后在不同的window内进行self-attention计算。假设一个图片有  的patches,每个window包含MxM个patches,作者将其成为window based self-attention(W-MSA)layer,W-MSA和multi-head self-attention(MSA)的计算复杂度分别为:

 

由于window内部的patch数量远小于图片patch数量,并且window数量是保持不变的,W-MSA的计算复杂度和图像尺寸呈线性关系,从而大大降低了模型的计算复杂度。

虽然W-MSA能够降低计算复杂度,但是不重合的window之间缺乏信息交流,这样其实就失去了transformer利用self-attention从全局构建关系的能力,于是文章进一步引入shifted window partition来跨window进行信息交流,作者将其成为shifted window based self-attention(SW-MSA)。

 如上图所示,Layer 1中8x8尺寸feature map划分成2x2个patch,每个patch尺寸为4x4, 通过将patch位置整体平移1/2个patch大小,在下一层得到新的window,包括3x3个不重合的patch。移动window的划分方式使上一层相邻的不重合window之间引入连接,大大的增加了感受野。

但这样做带来的另一个问题就是window内部patch的数量从原本的4个增加到了9个,为了让patch数量保持不变,如下图所示,作者的解决思路是把平移之后左上角A,B,C部分的patch与右下角不满足4x4尺度的patch拼接,这样patch的数量还是4个,但是又满足了window外的信息交互,作者将其成为cyclic shift。

 

 image.png

window的划分与合并

https://gitee.com/xiaomoon/image/raw/master/Img/image-20210530155145008.png

  1. # window_partition是划分,window_reverse是合并
  2. def window_partition(x, window_size):
  3. """
  4. Args:
  5. x: (B, H, W, C)
  6. window_size (int): window size
  7. Returns:
  8. windows: (num_windows*B, window_size, window_size, C)
  9. """
  10. B, H, W, C = x.shape
  11. x = x.reshape([B, H // window_size, window_size, W // window_size, window_size, C])
  12. windows = x.transpose([0, 1, 3, 2, 4, 5]).reshape([-1, window_size, window_size, C])
  13. return windows
  14. def window_reverse(windows, window_size, H, W):
  15. """
  16. Args:
  17. windows: (num_windows*B, window_size, window_size, C)
  18. window_size (int): Window size
  19. H (int): Height of image
  20. W (int): Width of image
  21. Returns:
  22. x: (B, H, W, C)
  23. """
  24. B = int(windows.shape[0] / (H * W / window_size / window_size))
  25. x = windows.reshape([B, H // window_size, W // window_size, window_size, window_size, -1])
  26. x = x.transpose([0, 1, 3, 2, 4, 5]).reshape([B, H, W, -1])
  27. return x

Window Attention

这是这篇文章的关键。传统的Transformer都是基于全局来计算注意力的,因此计算复杂度十分高。而Swin Transformer则将注意力的计算限制在每个窗口内,进而减少了计算量。

我们先简单看下公式

 主要区别是:在原始计算Attention的公式中的Q,K时加入了相对位置编码。后续实验有证明相对位置编码的加入提升了模型性能。

在原始计算Attention的公式中的Q,K时加入了相对位置编码

总体代码:

  1. window_size =(2,2)
  2. coords_flatten = np.array([[0, 0, 1, 1],
  3. [0, 1, 0, 1]])
  4. new = torch.tensor(coords_flatten)
  5. new_first=new[:, :, None] # (2,4,1)
  6. new_second=new[:, None, :] # (2,1,4)
  7. relative_coords = (new_first-new_second).permute(1, 2, 0).contiguous() # (4,4,2):4个4行2列的矩阵
  8. relative_coords[:, :, 0] += window_size[0] - 1 # 在每个矩阵的第0列加1
  9. relative_coords[:, :, 1] += window_size[1] - 1 # 在每个矩阵的第1列加1
  10. relative_coords[:, :, 0] *= 2 * window_size[1] - 1 # 在每个矩阵的第0列*3
  11. relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 将4个矩阵的每一个矩阵按行求和,作为新的行(4,2)==》(1,4)

下面我把涉及到相关位置编码的逻辑给单独拿出来,这部分比较绕

首先QK计算出来的Attention张量形状为(numWindows*B, num_heads, window_size*window_size, window_size*window_size)

而对于Attention张量来说,以不同元素为原点,其他元素的坐标也是不同的,以window_size=2为例,其相对位置编码如下图所示

首先我们利用torch.arangetorch.meshgrid函数生成对应的坐标,这里我们以windowsize=2为例子

关于torch.meshgrid函数请看: 【pytorch】torch.meshgrid()==>常用于生成二维网格,比如图像的坐标点_小马牛的博客-CSDN博客

  1. coords_h = torch.arange(self.window_size[0])
  2. coords_w = torch.arange(self.window_size[1])
  3. coords = torch.meshgrid([coords_h, coords_w]) # -> 2*(wh, ww)
  4. """
  5. 此时是两个张量,每个 张量是一个2*2的二维矩阵
  6. (tensor([[0, 0],
  7. [1, 1]]),
  8. tensor([[0, 1],
  9. [0, 1]]))
  10. """

然后堆叠起来,展开为一个二维向量

  1. coords = torch.stack(coords) # 2, Wh, Ww
  2. # 将两个tensor堆叠起来,就变成了一个2*2*2的三维矩阵
  3. coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
  4. # 再将三维矩阵压扁为2*4的一个二维矩阵
  5. """
  6. tensor([[0, 0, 1, 1],
  7. [0, 1, 0, 1]])
  8. """

利用广播机制,分别在第二维,第一维,插入一个维度,进行广播相减,得到 2, wh*ww, wh*ww的张量

解释:

relative_coords_first = coords_flatten[:, :, None] # 2, wh*ww, 1

relative_coords_second = coords_flatten[:, None, :] # 2, 1, wh*ww

relative_coords_first的None是在第二维,那么就是在第二维插入了一个维度1,就是由 2, wh*ww ==>2, wh*ww, 1

relative_coords_second的None是在第一维,那么就是在第一维插入了一个维度1,就是由 2, wh*ww ==>2, 1,wh*ww

图二a :

  1. relative_coords_first = coords_flatten[:, :, None] # 2, wh*ww, 1
  2. relative_coords_second = coords_flatten[:, None, :] # 2, 1, wh*ww
  3. relative_coords = relative_coords_first - relative_coords_second # 最终得到 2, wh*ww, wh*ww 形状的张量
  4. # 此处加上下面这一句就是图二的第一个
  5. relative_coords = relative_coords.permute(1, 2, 0).contiguous()

因为采取的是相减,所以得到的索引是从负数开始的

图二a对应的索引矩阵:

 

 

  

图二b :

因为采取的是相减,所以得到的索引是从负数开始的,我们加上偏移量,让其从0开始

  1. relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
  2. relative_coords[:, :, 0] += self.window_size[0] - 1
  3. relative_coords[:, :, 1] += self.window_size[1] - 1

此处相当于在矩阵中每个加(1,1),相当于在索引矩阵上加2 

图二c :

 后续我们需要将其展开成一维偏移量。而对于(1,2)和(2,1)这两个坐标。在二维上是不同的,但是通过将x,y坐标相加转换为一维偏移的时候,他的偏移量是相等的

所以最后我们对其中做了个乘法操作,以进行区分

relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1

图二:

图二d :

然后再最后一维上进行求和,展开成一个一维坐标,并注册为一个不参与网络学习的变量

  1. relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
  2. self.register_buffer("relative_position_index", relative_position_index)

现在得到了位置索引,那么下一步就是输入进计算公式(4)进行forward

特征图移位操作

代码里对特征图移位是通过torch.roll来实现的,下面是示意图

如果需要 reverse cyclic shift的话只需把参数 shifts设置为对应的正数值。

Attention Mask

我认为这是Swin Transformer的精华,通过设置合理的mask,让Shifted Window Attention在与Window Attention相同的窗口个数下,达到等价的计算结果。

首先我们对Shift Window后的每个窗口都给上index,并且做一个roll操作(window_size=2, shift_size=1)

我们希望在计算Attention的时候,让具有相同index QK进行计算,而忽略不同index QK计算结果。例如下图的5353与5353进行计算,那么只会5与5计算,3与3计算,而不会5与3计算

最后正确的结果如下图所示

而要想在原始四个窗口下得到正确的结果,我们就必须给Attention的结果加入一个mask

相关代码如下:

  1. if self.shift_size > 0:
  2. # calculate attention mask for SW-MSA
  3. H, W = self.input_resolution
  4. img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
  5. h_slices = (slice(0, -self.window_size),
  6. slice(-self.window_size, -self.shift_size),
  7. slice(-self.shift_size, None))
  8. w_slices = (slice(0, -self.window_size),
  9. slice(-self.window_size, -self.shift_size),
  10. slice(-self.shift_size, None))
  11. cnt = 0
  12. for h in h_slices:
  13. for w in w_slices:
  14. img_mask[:, h, w, :] = cnt
  15. cnt += 1
  16. mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
  17. mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
  18. attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
  19. attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

以上图的设置,我们用这段代码会得到这样的一个mask

  1. tensor([[[[[ 0., 0., 0., 0.],
  2. [ 0., 0., 0., 0.],
  3. [ 0., 0., 0., 0.],
  4. [ 0., 0., 0., 0.]]],
  5. [[[ 0., -100., 0., -100.],
  6. [-100., 0., -100., 0.],
  7. [ 0., -100., 0., -100.],
  8. [-100., 0., -100., 0.]]],
  9. [[[ 0., 0., -100., -100.],
  10. [ 0., 0., -100., -100.],
  11. [-100., -100., 0., 0.],
  12. [-100., -100., 0., 0.]]],
  13. [[[ 0., -100., -100., -100.],
  14. [-100., 0., -100., -100.],
  15. [-100., -100., 0., -100.],
  16. [-100., -100., -100., 0.]]]]])

在之前的window attention模块的前向代码里,包含这么一段

  1. if mask is not None:
  2. nW = mask.shape[0]
  3. attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
  4. attn = attn.view(-1, self.num_heads, N, N)
  5. attn = self.softmax(attn)

将mask加到attention的计算结果,并进行softmax。mask的值设置为-100,softmax后就会忽略掉对应的值

图解Swin Transformer - 知乎

SOTA 模型 Swin Transformer 是如何炼成的! - 极市社区

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/知新_RL/article/detail/361005
推荐阅读
相关标签
  

闽ICP备14008679号