当前位置:   article > 正文

Transformer架构中基于窗口的自注意力机制W-MSA和滑动窗口自注意力机制SW-MSA的实现_滑动窗口注意力机制

滑动窗口注意力机制

在Transformer架构中,自注意力机制(Self-Attention Mechanism)是实现序列建模和上下文信息捕捉的核心机制。自注意力机制允许模型在处理序列数据时根据序列中各个位置的信息来动态地分配注意力。

在基于窗口的自注意力机制W-MSA中,通过设置一个固定大小的窗口来约束注意力机制的计算范围,只计算窗口内的位置之间的相似度,并且将注意力分数归一化得到权重。这样可以减少计算量,并保持局部依赖性。W-MSA可以有效地处理长序列,但窗口的大小会直接影响模型性能。

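而滑动窗口自注意力机制SW-MSA是对W-MSA的改进。它并不是在同一层内多次计算W-MSA,而是先将特征图整体循环移位(下文代码中的torch.roll,位移量为shift_size),再按窗口划分并计算窗口内的注意力,同时用注意力掩码屏蔽那些因移位而被拼进同一窗口、但在原图中并不相邻的区域,最后再反向移位复位。这样原本分属不同窗口的位置之间也能交换信息,提高了模型的感知范围;随着层数的叠加,模型能够逐步捕捉到更大范围乃至整个序列的依赖关系。SW-MSA在处理长序列时可以在一定程度上解决窗口大小的限制问题。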

W-MSA和SW-MSA是自注意力机制在Transformer架构中的两种变体。它们可以根据具体任务和序列长度的不同选择合适的方式来实现自注意力机制,以提高模型的性能和表达能力。

1.窗口的划分:

  def window_partition(x, window_size):
      """Cut a feature map into non-overlapping square windows.

      Args:
          x: Tensor of shape (B, H, W, C). H and W are expected to be
              multiples of ``window_size``.
          window_size (int): Side length of each window.

      Returns:
          Tensor of shape (num_windows * B, window_size, window_size, C).
          Windows are ordered by batch index first, then row-major over the
          grid of windows inside each image.
      """
      batch, height, width, channels = x.shape
      grid_rows = height // window_size
      grid_cols = width // window_size
      # Split H into (grid_rows, window_size) and W into (grid_cols, window_size).
      # view() only re-interprets the shape; no data is moved here.
      tiled = x.view(batch, grid_rows, window_size, grid_cols, window_size, channels)
      # Bring the two grid axes next to each other so every window becomes one
      # block; permute() breaks memory contiguity, so contiguous() is required
      # before the final view().
      tiled = tiled.permute(0, 1, 3, 2, 4, 5).contiguous()
      return tiled.view(-1, window_size, window_size, channels)

2.经过移动的窗口的复位

  def window_reverse(windows, window_size, H, W):
      """Stitch windows back into full feature maps (inverse of window partitioning).

      Args:
          windows: Tensor of shape (num_windows * B, window_size, window_size, C).
          window_size (int): Side length of each window.
          H (int): Height of the feature map to rebuild.
          W (int): Width of the feature map to rebuild.

      Returns:
          Tensor of shape (B, H, W, C).
      """
      # Number of windows that make up one image; dividing the leading
      # dimension by it recovers the batch size.
      windows_per_image = H * W / window_size / window_size
      B = int(windows.shape[0] / windows_per_image)
      # Unfold the leading dimension into (batch, grid row, grid column).
      grid = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
      # Interleave grid rows with in-window rows (and likewise for columns) so
      # the axes read (B, H, W, C) once merged.
      stitched = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
      return stitched.view(B, H, W, -1)

3.窗口自注意力机制W-MSA:

  class WindowAttention(nn.Module):
      r"""Window based multi-head self attention (W-MSA) module with relative position bias.

      The same module serves both the non-shifted (W-MSA) and the shifted
      (SW-MSA) case; for shifted windows the caller passes an additive ``mask``
      to :meth:`forward`.

      Args:
          dim (int): Number of input channels.
          window_size (tuple[int]): The height and width of the window.
          num_heads (int): Number of attention heads.
          qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
          qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
          attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
          proj_drop (float, optional): Dropout ratio of output. Default: 0.0
      """

      def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
          super().__init__()
          self.dim = dim
          self.window_size = window_size  # (Wh, Ww)
          self.num_heads = num_heads
          head_dim = dim // num_heads
          # A falsy qk_scale (None or 0) falls back to the default scale.
          self.scale = qk_scale or head_dim ** -0.5

          win_h = self.window_size[0]
          win_w = self.window_size[1]

          # Learnable bias table: one row per distinct relative offset
          # (dy in [-(Wh-1), Wh-1], dx in [-(Ww-1), Ww-1]), one column per head.
          self.relative_position_bias_table = nn.Parameter(
              torch.zeros((2 * win_h - 1) * (2 * win_w - 1), num_heads))  # (2*Wh-1)*(2*Ww-1), nH

          # Row / column coordinate of every token inside the window. Built by
          # broadcasting instead of torch.meshgrid: the values are the same as
          # "ij"-indexed meshgrid, but no deprecation warning is raised for a
          # missing `indexing` argument and it works on every PyTorch version.
          coords_h = torch.arange(win_h)
          coords_w = torch.arange(win_w)
          coords = torch.stack([
              coords_h[:, None].expand(win_h, win_w),
              coords_w[None, :].expand(win_h, win_w),
          ])  # 2, Wh, Ww
          coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
          # Pair-wise offset between every two tokens of the window.
          relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
          relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
          # Shift both offsets so they start from 0 ...
          relative_coords[:, :, 0] += win_h - 1
          relative_coords[:, :, 1] += win_w - 1
          # ... then flatten (dy, dx) into one index: dy * (2*Ww-1) + dx.
          relative_coords[:, :, 0] *= 2 * win_w - 1
          relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
          # Buffer (not a parameter): saved with the module and moved with it
          # across devices, but never trained.
          self.register_buffer("relative_position_index", relative_position_index)

          # One linear layer produces q, k and v at once (3 * dim outputs).
          self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
          self.attn_drop = nn.Dropout(attn_drop)
          # Output projection; input and output width are both `dim`.
          self.proj = nn.Linear(dim, dim)
          self.proj_drop = nn.Dropout(proj_drop)

          # Initialise the bias table from a truncated normal with std 0.02.
          trunc_normal_(self.relative_position_bias_table, std=.02)
          self.softmax = nn.Softmax(dim=-1)

      def forward(self, x, mask=None):
          """Apply windowed multi-head self attention.

          Args:
              x: input features with shape of (num_windows*B, N, C)
              mask: additive mask with shape of (num_windows, Wh*Ww, Wh*Ww) or
                  None. It is added to the attention logits before the softmax,
                  so large negative entries suppress the corresponding pairs.

          Returns:
              Tensor with shape of (num_windows*B, N, C).
          """
          B_, N, C = x.shape
          # (B_, N, 3*C) -> (3, B_, nH, N, C // nH)
          qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
          q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

          q = q * self.scale
          attn = (q @ k.transpose(-2, -1))  # B_, nH, N, N

          # Look up the bias of every token pair and move the head axis first.
          relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
              self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
          relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
          attn = attn + relative_position_bias.unsqueeze(0)

          if mask is not None:
              # The mask is shared by all images of the batch: expose the
              # window axis, add it, then fold the batch back.
              nW = mask.shape[0]
              attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
              attn = attn.view(-1, self.num_heads, N, N)
          attn = self.softmax(attn)

          attn = self.attn_drop(attn)

          x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
          x = self.proj(x)
          x = self.proj_drop(x)
          return x

      def extra_repr(self) -> str:
          """Return the key hyper-parameters shown in the module's repr."""
          return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

      def flops(self, N):
          """Count the FLOPs of attention over one window containing ``N`` tokens."""
          # calculate flops for 1 window with token length of N
          flops = 0
          # qkv = self.qkv(x)
          flops += N * self.dim * 3 * self.dim
          # attn = (q @ k.transpose(-2, -1))
          flops += self.num_heads * N * (self.dim // self.num_heads) * N
          # x = (attn @ v)
          flops += self.num_heads * N * N * (self.dim // self.num_heads)
          # x = self.proj(x)
          flops += N * self.dim * self.dim
          return flops

4.基于滑动窗口SW-MSA自注意力机制的实现:

 

  class SwinTransformerBlock(nn.Module):
      r"""Swin Transformer Block.

      One block is: LayerNorm -> window attention (W-MSA when ``shift_size`` is
      0, SW-MSA otherwise) -> residual -> LayerNorm -> MLP -> residual.

      Args:
          dim (int): Number of input channels.
          input_resolution (tuple[int]): Input resolution.
          num_heads (int): Number of attention heads.
          window_size (int): Window size.
          shift_size (int): Shift size for SW-MSA.
          mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
          qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
          qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
          drop (float, optional): Dropout rate. Default: 0.0
          attn_drop (float, optional): Attention dropout rate. Default: 0.0
          drop_path (float, optional): Stochastic depth rate. Default: 0.0
          act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
          norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
      """

      def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                   mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                   act_layer=nn.GELU, norm_layer=nn.LayerNorm):
          super().__init__()
          self.dim = dim
          self.input_resolution = input_resolution
          self.num_heads = num_heads
          self.window_size = window_size
          self.shift_size = shift_size
          self.mlp_ratio = mlp_ratio
          if min(self.input_resolution) <= self.window_size:
              # if window size is larger than input resolution, we don't partition windows
              self.shift_size = 0
              self.window_size = min(self.input_resolution)
          assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

          self.norm1 = norm_layer(dim)
          self.attn = WindowAttention(
              dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
              qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

          self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
          self.norm2 = norm_layer(dim)
          mlp_hidden_dim = int(dim * mlp_ratio)
          self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

          # Pre-compute the SW-MSA mask for the nominal resolution; it is only
          # needed when the windows are actually shifted.
          if self.shift_size > 0:
              attn_mask = self.calculate_mask(self.input_resolution)
          else:
              attn_mask = None

          self.register_buffer("attn_mask", attn_mask)

      def calculate_mask(self, x_size):
          """Build the additive attention mask used by SW-MSA.

          Args:
              x_size: (H, W) of the feature map.

          Returns:
              Tensor of shape (nW, window_size**2, window_size**2) holding 0.0
              where two tokens of a window carry the same region label and
              -100.0 where they do not.
          """
          # calculate attention mask for SW-MSA
          H, W = x_size
          img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
          # Three bands per axis: everything before the last window, the last
          # window minus the shifted strip, and the shifted strip itself.
          h_slices = (slice(0, -self.window_size),
                      slice(-self.window_size, -self.shift_size),
                      slice(-self.shift_size, None))
          w_slices = (slice(0, -self.window_size),
                      slice(-self.window_size, -self.shift_size),
                      slice(-self.shift_size, None))
          # Give each of the 3 x 3 regions its own label (0..8).
          cnt = 0
          for h in h_slices:
              for w in w_slices:
                  img_mask[:, h, w, :] = cnt
                  cnt += 1

          mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
          mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
          # Pair-wise label difference: non-zero means the two tokens belong to
          # different regions and must not attend to each other.
          attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
          attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

          return attn_mask

      def forward(self, x, x_size):
          """Run the block.

          Args:
              x: Tensor of shape (B, H*W, C).
              x_size: (H, W) of the feature map that ``x`` was flattened from.

          Returns:
              Tensor of shape (B, H*W, C).
          """
          H, W = x_size
          B, L, C = x.shape
          # assert L == H * W, "input feature has wrong size"

          shortcut = x
          x = self.norm1(x)
          x = x.view(B, H, W, C)

          # cyclic shift
          if self.shift_size > 0:
              shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
          else:
              shifted_x = x

          # partition windows
          x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
          x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

          # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size)
          if self.shift_size == 0:
              # No shift means no wrapped-around regions: the mask would be all
              # zeros, so skip building and adding it.
              attn_mask = None
          elif tuple(x_size) == tuple(self.input_resolution):
              # Compare as tuples so that a list or torch.Size x_size still
              # matches a tuple input_resolution and the cached mask is reused.
              attn_mask = self.attn_mask
          else:
              # Different resolution: build the mask on the fly and match both
              # the device and the dtype of the features (the cached buffer
              # follows the module, a freshly built mask does not).
              attn_mask = self.calculate_mask(x_size).to(device=x.device, dtype=x.dtype)
          attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C

          # merge windows
          attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
          shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

          # reverse cyclic shift
          if self.shift_size > 0:
              x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
          else:
              x = shifted_x
          x = x.view(B, H * W, C)

          # FFN
          x = shortcut + self.drop_path(x)
          x = x + self.drop_path(self.mlp(self.norm2(x)))

          return x

      def extra_repr(self) -> str:
          """Return the key hyper-parameters shown in the module's repr."""
          return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
                 f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

      def flops(self):
          """Count the FLOPs of one forward pass at ``input_resolution``."""
          flops = 0
          H, W = self.input_resolution
          # norm1
          flops += self.dim * H * W
          # W-MSA/SW-MSA
          nW = H * W / self.window_size / self.window_size
          flops += nW * self.attn.flops(self.window_size * self.window_size)
          # mlp
          flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
          # norm2
          flops += self.dim * H * W
          return flops

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/繁依Fanyi0/article/detail/337449
推荐阅读
相关标签
  

闽ICP备14008679号