  • 一、网络模型整体架构
  • 二、Patch Partition模块详解
  • 三、Patch Merging模块
  • 四、W-MSA详解
  • 五、SW-MSA详解
    • masked MSA详解
  • 六、 Relative Position Bias详解
  • 七、模型详细配置参数
  • 八、重要模块代码实现:
    • 1、Patch Partition代码模块:
    • 2、Patch Merging代码模块:
    • 3、mask掩码生成代码模块:
    • 4、stage堆叠部分代码:
    • 5、SW-MSA或者W-MSA模块代码:
  • 九、模型整体流程代码实现:

论文名称:Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
原论文地址: https://arxiv.org/abs/2103.14030




二、Patch Partition模块详解


三、Patch Merging模块







masked MSA详解


六、 Relative Position Bias详解





1、Patch Partition代码模块:

class PatchEmbed(nn.Module):
    """2D image to patch embedding.

    Splits an image into non-overlapping ``patch_size x patch_size`` patches
    and projects every patch to ``embed_dim`` channels with a strided
    convolution.

    Args:
        patch_size (int): Side length of a square patch. Default: 4
        in_c (int): Number of input image channels. Default: 3
        embed_dim (int): Number of output channels per patch. Default: 96
        norm_layer (callable, optional): Factory called as
            ``norm_layer(embed_dim)``; ``nn.Identity`` is used when falsy.
    """

    def __init__(self, patch_size=4, in_c=3, embed_dim=96, norm_layer=None):
        super().__init__()
        patch_size = (patch_size, patch_size)
        self.patch_size = patch_size
        self.in_chans = in_c
        self.embed_dim = embed_dim
        # kernel == stride, so each output position sees exactly one patch.
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: Tensor of shape [B, C, H, W].

        Returns:
            Tuple ``(tokens, H, W)`` where ``tokens`` has shape
            [B, H*W, embed_dim] and ``H``/``W`` are the height and width of
            the patch grid after the projection.
        """
        _, _, H, W = x.shape
        patch_h, patch_w = self.patch_size

        # Amount needed to reach the next multiple of the patch size. The
        # trailing modulo makes this 0 for an axis that is already aligned;
        # without it an aligned axis would receive a whole extra patch of
        # zeros whenever the *other* axis needed padding.
        pad_b = (patch_h - H % patch_h) % patch_h
        pad_r = (patch_w - W % patch_w) % patch_w
        if pad_b or pad_r:
            # F.pad consumes pairs starting from the last dimension:
            # (W_left, W_right, H_top, H_bottom, C_front, C_back)
            x = F.pad(x, (0, pad_r, 0, pad_b, 0, 0))

        # Downsample by a factor of patch_size.
        x = self.proj(x)
        _, _, H, W = x.shape
        # flatten: [B, C, H, W] -> [B, C, HW]
        # transpose: [B, C, HW] -> [B, HW, C]
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x, H, W

2、Patch Merging代码模块:

class PatchMerging(nn.Module):
    r"""Patch merging layer.

    Samples the feature map with stride 2 at the four (row, column) parity
    offsets, concatenates the four samples along the channel axis, normalizes
    the result and projects it from ``4 * dim`` to ``2 * dim`` channels.

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.
            Default: nn.LayerNorm
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        """Merge 2x2 neighbourhoods of the token grid.

        Args:
            x: Tensor of shape [B, H*W, C].
            H: Height of the token grid.
            W: Width of the token grid.

        Returns:
            Tensor of shape [B, ceil(H/2)*ceil(W/2), 2*C].
        """
        batch, seq_len, channels = x.shape
        assert seq_len == H * W, "input feature has wrong size"

        grid = x.view(batch, H, W, channels)

        # An odd height or width gets one extra row/column so that stride-2
        # sampling yields four equally sized pieces.
        if H % 2 == 1 or W % 2 == 1:
            # The layout here is [B, H, W, C], so the F.pad pairs read
            # (C_front, C_back, W_left, W_right, H_top, H_bottom).
            grid = F.pad(grid, (0, 0, 0, W % 2, 0, H % 2))

        # Column offset is the outer loop so the concatenation order is
        # (row0,col0), (row1,col0), (row0,col1), (row1,col1).
        quadrants = [grid[:, row::2, col::2, :] for col in (0, 1) for row in (0, 1)]
        merged = torch.cat(quadrants, -1)  # [B, H/2, W/2, 4*C]
        merged = merged.view(batch, -1, 4 * channels)  # [B, H/2*W/2, 4*C]

        return self.reduction(self.norm(merged))  # [B, H/2*W/2, 2*C]


    def create_mask(self, x, H, W):
        """Build the additive attention mask used by SW-MSA.

        A label map padded up to a multiple of ``self.window_size`` is filled
        with a distinct id for each of the 3x3 regions delimited by
        ``window_size`` and ``shift_size``. After the map is cut into windows,
        every pair of positions inside a window gets 0 when both carry the
        same id and -100 otherwise.

        Args:
            x: Tensor whose ``device`` is used for the mask.
            H: Height of the feature map.
            W: Width of the feature map.

        Returns:
            Mask tensor; per the shapes noted below it is
            [nW, Mh*Mw, Mh*Mw] (assumes ``window_partition`` returns
            [nW, Mh, Mw, 1] -- it is defined elsewhere).
        """
        win, shift = self.window_size, self.shift_size

        # Round H and W up to whole multiples of the window size.
        padded_h = int(np.ceil(H / win)) * win
        padded_w = int(np.ceil(W / win)) * win

        # Same channel-last layout as the feature map so that
        # window_partition can be reused on it.
        region_map = torch.zeros((1, padded_h, padded_w, 1), device=x.device)  # [1, Hp, Wp, 1]

        # The same three bands apply along both spatial axes.
        bands = (slice(0, -win), slice(-win, -shift), slice(-shift, None))
        band_pairs = ((rows, cols) for rows in bands for cols in bands)
        for region_id, (rows, cols) in enumerate(band_pairs):
            region_map[:, rows, cols, :] = region_id

        # Cut the label map into windows and flatten each window.
        windows = window_partition(region_map, win)  # [nW, Mh, Mw, 1]
        windows = windows.view(-1, win * win)  # [nW, Mh*Mw]

        # Broadcasted pairwise difference: zero exactly where two positions
        # of a window belong to the same region.
        diff = windows.unsqueeze(1) - windows.unsqueeze(2)  # [nW, Mh*Mw, Mh*Mw]
        attn_mask = diff.masked_fill(diff != 0, float(-100.0)).masked_fill(diff == 0, float(0.0))
        return attn_mask


class BasicLayer(nn.Module):
    """A basic Swin Transformer layer for one stage.

    Stacks ``depth`` SwinTransformerBlock modules, alternating between a
    shift of 0 (even index) and ``window_size // 2`` (odd index), and
    optionally applies a downsample module at the end.

    Args:
        dim (int): Number of input channels.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query,
            key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | list[float] | tuple[float], optional): Stochastic
            depth rate. A list or tuple is indexed per block. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer.
            Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end
            of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory.
            Default: False.
    """

    def __init__(self, dim, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
        super().__init__()
        self.dim = dim
        self.depth = depth
        self.window_size = window_size
        self.use_checkpoint = use_checkpoint
        # How far the windows move right and down: half the window, floored.
        self.shift_size = window_size // 2

        # Build blocks. A shift_size of 0 selects W-MSA, a positive one SW-MSA.
        # Tuples are indexed per block just like lists, as the documented
        # ``tuple[float]`` form requires; passing the whole sequence through
        # would break the ``drop_path > 0.`` comparison inside the block.
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(
                dim=dim,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=0 if (i % 2 == 0) else self.shift_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path, (list, tuple)) else drop_path,
                norm_layer=norm_layer)
            for i in range(depth)])

        # Optional downsample module (e.g. a patch merging layer) applied
        # after the blocks.
        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def create_mask(self, x, H, W):
        """Build the additive attention mask used by SW-MSA.

        Args:
            x: Tensor whose ``device`` is used for the mask.
            H: Height of the feature map.
            W: Width of the feature map.

        Returns:
            Mask tensor holding 0 for position pairs in the same region and
            -100 otherwise; per the shapes noted below it is
            [nW, Mh*Mw, Mh*Mw] (assumes ``window_partition`` returns
            [nW, Mh, Mw, 1] -- it is defined elsewhere).
        """
        # Make Hp and Wp whole multiples of window_size.
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        # Same channel-last layout as the feature map, so window_partition
        # can be applied to it afterwards.
        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # [1, Hp, Wp, 1]
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        # Give each of the 3x3 regions its own integer id.
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        # Cut the label map into windows of the configured size.
        mask_windows = window_partition(img_mask, self.window_size)  # [nW, Mh, Mw, 1]
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)  # [nW, Mh*Mw]
        # Broadcasted pairwise difference: [nW, 1, Mh*Mw] - [nW, Mh*Mw, 1]
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        # [nW, Mh*Mw, Mh*Mw]
        # Pairs from the same region differ by 0 and keep 0; every other pair
        # is set to -100.
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        return attn_mask

    def forward(self, x, H, W):
        """Run all blocks and the optional downsample step.

        Args:
            x: Token tensor passed to every block.
            H: Height of the feature map.
            W: Width of the feature map.

        Returns:
            Tuple ``(x, H, W)``; when a downsample module is configured,
            ``H`` and ``W`` are halved (rounded up).
        """
        attn_mask = self.create_mask(x, H, W)  # [nW, Mh*Mw, Mh*Mw]
        for blk in self.blocks:
            # Each block reads the current resolution from these attributes.
            blk.H, blk.W = H, W
            if not torch.jit.is_scripting() and self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, attn_mask)
            else:
                x = blk(x, attn_mask)
        if self.downsample is not None:
            x = self.downsample(x, H, W)
            H, W = (H + 1) // 2, (W + 1) // 2

        return x, H, W


class SwinTransformerBlock(nn.Module):
    r"""Swin Transformer block.

    Applies layer norm, windowed attention (shifted when ``shift_size`` is
    positive) and a residual connection, followed by layer norm, an MLP and
    a second residual connection.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query,
            key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.
            Default: nn.LayerNorm
    """

    def __init__(self, dim, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        # Pre-attention normalization.
        self.norm1 = norm_layer(dim)
        # WindowAttention serves both W-MSA and SW-MSA; the difference lies
        # in the shift and mask applied in forward().
        self.attn = WindowAttention(
            dim,
            window_size=(self.window_size, self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x, attn_mask):
        """Apply the block to a token sequence.

        ``self.H`` and ``self.W`` must have been assigned by the caller
        before this method runs.

        Args:
            x: Tensor of shape [B, H*W, C].
            attn_mask: Mask forwarded to the attention module when
                ``shift_size`` is positive; ignored otherwise.

        Returns:
            Tensor of shape [B, H*W, C].
        """
        H, W = self.H, self.W
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        win, shift = self.window_size, self.shift_size
        use_shift = shift > 0
        residual = x

        feat = self.norm1(x).view(B, H, W, C)

        # Pad right and bottom so both spatial sizes become multiples of the
        # window size; left and top padding are always zero.
        pad_r = (win - W % win) % win
        pad_b = (win - H % win) % win
        feat = F.pad(feat, (0, 0, 0, pad_r, 0, pad_b))
        _, Hp, Wp, _ = feat.shape

        # Cyclic shift for SW-MSA; plain W-MSA uses no mask at all.
        if use_shift:
            feat = torch.roll(feat, shifts=(-shift, -shift), dims=(1, 2))
        else:
            attn_mask = None

        # Partition into windows and flatten each one.
        windows = window_partition(feat, win)  # [nW*B, Mh, Mw, C]
        windows = windows.view(-1, win * win, C)  # [nW*B, Mh*Mw, C]

        # W-MSA / SW-MSA: multi-head self-attention inside every window.
        attended = self.attn(windows, mask=attn_mask)  # [nW*B, Mh*Mw, C]

        # Stitch the windows back into one feature map.
        attended = attended.view(-1, win, win, C)  # [nW*B, Mh, Mw, C]
        feat = window_reverse(attended, win, Hp, Wp)  # [B, H', W', C]

        # Undo the cyclic shift.
        if use_shift:
            feat = torch.roll(feat, shifts=(shift, shift), dims=(1, 2))

        # Drop whatever was padded on earlier.
        if pad_r > 0 or pad_b > 0:
            feat = feat[:, :H, :W, :].contiguous()

        feat = feat.view(B, H * W, C)

        # Residual connections around attention and the feed-forward network.
        x = residual + self.drop_path(feat)
        return x + self.drop_path(self.mlp(self.norm2(x)))


