赞
踩
Swin Transformer另一个重要的改进就是window-based的self-attention layer,之前提到过,ViT的一个缺点是计算复杂度是和patch数量成平方关系的,为了减少计算量,Swin的做法是将输入图片划分成不重合的windows,然后在不同的window内进行self-attention计算。假设一个图片有 的patches,每个window包含MxM个patches,作者将其成为window based self-attention(W-MSA)layer,W-MSA和multi-head self-attention(MSA)的计算复杂度分别为:
由于window内部的patch数量远小于图片patch数量,并且window数量是保持不变的,W-MSA的计算复杂度和图像尺寸呈线性关系,从而大大降低了模型的计算复杂度。
虽然W-MSA能够降低计算复杂度,但是不重合的window之间缺乏信息交流,这样其实就失去了transformer利用self-attention从全局构建关系的能力,于是文章进一步引入shifted window partition来跨window进行信息交流,作者将其成为shifted window based self-attention(SW-MSA)。
如上图所示,Layer 1中8x8尺寸feature map划分成2x2个patch,每个patch尺寸为4x4, 通过将patch位置整体平移1/2个patch大小,在下一层得到新的window,包括3x3个不重合的patch。移动window的划分方式使上一层相邻的不重合window之间引入连接,大大的增加了感受野。
但这样做带来的另一个问题就是window内部patch的数量从原本的4个增加到了9个,为了让patch数量保持不变,如下图所示,作者的解决思路是把平移之后左上角A,B,C部分的patch与右下角不满足4x4尺度的patch拼接,这样patch的数量还是4个,但是又满足了window外的信息交互,作者将其成为cyclic shift。
- # window_partition是划分,window_reverse是合并
- def window_partition(x, window_size):
- """
- Args:
- x: (B, H, W, C)
- window_size (int): window size
- Returns:
- windows: (num_windows*B, window_size, window_size, C)
- """
- B, H, W, C = x.shape
- x = x.reshape([B, H // window_size, window_size, W // window_size, window_size, C])
- windows = x.transpose([0, 1, 3, 2, 4, 5]).reshape([-1, window_size, window_size, C])
- return windows
-
-
- def window_reverse(windows, window_size, H, W):
- """
- Args:
- windows: (num_windows*B, window_size, window_size, C)
- window_size (int): Window size
- H (int): Height of image
- W (int): Width of image
- Returns:
- x: (B, H, W, C)
- """
- B = int(windows.shape[0] / (H * W / window_size / window_size))
- x = windows.reshape([B, H // window_size, W // window_size, window_size, window_size, -1])
- x = x.transpose([0, 1, 3, 2, 4, 5]).reshape([B, H, W, -1])
- return x
这是这篇文章的关键。传统的Transformer都是基于全局来计算注意力的,因此计算复杂度十分高。而Swin Transformer则将注意力的计算限制在每个窗口内,进而减少了计算量。
我们先简单看下公式
主要区别是:在原始计算Attention的公式中的Q,K时加入了相对位置编码。后续实验有证明相对位置编码的加入提升了模型性能。
总体代码:
- window_size =(2,2)
- coords_flatten = np.array([[0, 0, 1, 1],
- [0, 1, 0, 1]])
- new = torch.tensor(coords_flatten)
-
- new_first=new[:, :, None] # (2,4,1)
- new_second=new[:, None, :] # (2,1,4)
-
- relative_coords = (new_first-new_second).permute(1, 2, 0).contiguous() # (4,4,2):4个4行2列的矩阵
- relative_coords[:, :, 0] += window_size[0] - 1 # 在每个矩阵的第0列加1
- relative_coords[:, :, 1] += window_size[1] - 1 # 在每个矩阵的第1列加1
- relative_coords[:, :, 0] *= 2 * window_size[1] - 1 # 在每个矩阵的第0列*3
- relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 将4个矩阵的每一个矩阵按行求和,作为新的行(4,2)==》(1,4)
下面我把涉及到相关位置编码的逻辑给单独拿出来,这部分比较绕
首先QK计算出来的Attention张量形状为(numWindows*B, num_heads, window_size*window_size, window_size*window_size)
。
而对于Attention张量来说,以不同元素为原点,其他元素的坐标也是不同的,以window_size=2
为例,其相对位置编码如下图所示
首先我们利用torch.arange
和torch.meshgrid
函数生成对应的坐标,这里我们以windowsize=2
为例子
关于torch.meshgrid
函数请看: 【pytorch】torch.meshgrid()==>常用于生成二维网格,比如图像的坐标点_小马牛的博客-CSDN博客
- coords_h = torch.arange(self.window_size[0])
- coords_w = torch.arange(self.window_size[1])
- coords = torch.meshgrid([coords_h, coords_w]) # -> 2*(wh, ww)
- """
- 此时是两个张量,每个 张量是一个2*2的二维矩阵
- (tensor([[0, 0],
- [1, 1]]),
- tensor([[0, 1],
- [0, 1]]))
- """
然后堆叠起来,展开为一个二维向量
- coords = torch.stack(coords) # 2, Wh, Ww
- # 将两个tensor堆叠起来,就变成了一个2*2*2的三维矩阵
- coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
- # 再将三维矩阵压扁为2*4的一个二维矩阵
- """
- tensor([[0, 0, 1, 1],
- [0, 1, 0, 1]])
- """
利用广播机制,分别在第二维,第一维,插入一个维度,进行广播相减,得到 2, wh*ww, wh*ww
的张量
解释:
relative_coords_first = coords_flatten[:, :, None] # 2, wh*ww, 1
relative_coords_second = coords_flatten[:, None, :] # 2, 1, wh*ww
relative_coords_first的None是在第二维,那么就是在第二维插入了一个维度1,就是由 2, wh*ww ==>2, wh*ww, 1
relative_coords_second的None是在第一维,那么就是在第一维插入了一个维度1,就是由 2, wh*ww ==>2, 1,wh*ww
图二a :
- relative_coords_first = coords_flatten[:, :, None] # 2, wh*ww, 1
- relative_coords_second = coords_flatten[:, None, :] # 2, 1, wh*ww
- relative_coords = relative_coords_first - relative_coords_second # 最终得到 2, wh*ww, wh*ww 形状的张量
- # 此处加上下面这一句就是图二的第一个
- relative_coords = relative_coords.permute(1, 2, 0).contiguous()
因为采取的是相减,所以得到的索引是从负数开始的
图二a对应的索引矩阵:
图二b :
因为采取的是相减,所以得到的索引是从负数开始的,我们加上偏移量,让其从0开始。
- relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
- relative_coords[:, :, 0] += self.window_size[0] - 1
- relative_coords[:, :, 1] += self.window_size[1] - 1
此处相当于在矩阵中每个加(1,1),相当于在索引矩阵上加2
图二c :
后续我们需要将其展开成一维偏移量。而对于(1,2)和(2,1)这两个坐标。在二维上是不同的,但是通过将x,y坐标相加转换为一维偏移的时候,他的偏移量是相等的。
所以最后我们对其中做了个乘法操作,以进行区分
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
图二:
图二d :
然后再最后一维上进行求和,展开成一个一维坐标,并注册为一个不参与网络学习的变量
- relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
- self.register_buffer("relative_position_index", relative_position_index)
现在得到了位置索引,那么下一步就是输入进计算公式(4)进行forward
代码里对特征图移位是通过torch.roll
来实现的,下面是示意图
如果需要reverse cyclic shift
的话只需把参数shifts
设置为对应的正数值。
我认为这是Swin Transformer的精华,通过设置合理的mask,让Shifted Window Attention
在与Window Attention
相同的窗口个数下,达到等价的计算结果。
首先我们对Shift Window后的每个窗口都给上index,并且做一个roll
操作(window_size=2, shift_size=1)
我们希望在计算Attention的时候,让具有相同index QK进行计算,而忽略不同index QK计算结果。例如下图的5353与5353进行计算,那么只会5与5计算,3与3计算,而不会5与3计算
最后正确的结果如下图所示
而要想在原始四个窗口下得到正确的结果,我们就必须给Attention的结果加入一个mask
相关代码如下:
- if self.shift_size > 0:
- # calculate attention mask for SW-MSA
- H, W = self.input_resolution
- img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
- h_slices = (slice(0, -self.window_size),
- slice(-self.window_size, -self.shift_size),
- slice(-self.shift_size, None))
- w_slices = (slice(0, -self.window_size),
- slice(-self.window_size, -self.shift_size),
- slice(-self.shift_size, None))
- cnt = 0
- for h in h_slices:
- for w in w_slices:
- img_mask[:, h, w, :] = cnt
- cnt += 1
-
- mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
- mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
- attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
- attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
以上图的设置,我们用这段代码会得到这样的一个mask
- tensor([[[[[ 0., 0., 0., 0.],
- [ 0., 0., 0., 0.],
- [ 0., 0., 0., 0.],
- [ 0., 0., 0., 0.]]],
-
-
- [[[ 0., -100., 0., -100.],
- [-100., 0., -100., 0.],
- [ 0., -100., 0., -100.],
- [-100., 0., -100., 0.]]],
-
-
- [[[ 0., 0., -100., -100.],
- [ 0., 0., -100., -100.],
- [-100., -100., 0., 0.],
- [-100., -100., 0., 0.]]],
-
-
- [[[ 0., -100., -100., -100.],
- [-100., 0., -100., -100.],
- [-100., -100., 0., -100.],
- [-100., -100., -100., 0.]]]]])
在之前的window attention模块的前向代码里,包含这么一段
- if mask is not None:
- nW = mask.shape[0]
- attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
- attn = attn.view(-1, self.num_heads, N, N)
- attn = self.softmax(attn)
将mask加到attention的计算结果,并进行softmax。mask的值设置为-100,softmax后就会忽略掉对应的值
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。