赞
踩
本文再次讲述一篇新的 Sparse-MLP 工作,其的 Sparse 主要描述在感受野层面,与 MLP-Mixer 的全局感受野相比,本网络的感受野是轴向的,所以是稀疏的。本文可以看作是 ConvMLP 和 ViP 的结合,但是其发布时间早 ConvMLP 一周。
自 AlexNet 提出以来,卷积神经网络(CNN)一直是计算机视觉的主导范式。随着 Vision Transformer 的提出,这种情况发生了改变。ViT 将一个图像被划分为不重叠的 patch,并用线性层将这些 patch 转换为 token,然后输入到 Transformer Block 中进行处理。无卷积的 Vision Transformer 主要存在两个核心的思想:全局依赖性建模很重要 ;自注意力机制很重要 。但是近期的工作也发现局部依赖性似乎对图像更好(例如 Swin,AS-MLP 等等),或者说局部依赖性是全局依赖性的特殊情况,全局依赖性有要在超大规模训练上才得行。那么这种情况下,似乎还是要注入局部依赖性。此外,自注意力机制计算量是 token 数量的平方量级,因此,网络结构不有利于高分辨率输入,否则计算量 hold 不住(所以 Swin 用来金字塔结构和多阶段处理)。MLP-Mixer继承了ViT的所有缺点(除了自注意力机制中的平方计算量级),且由于参数数量过多,容易发生过拟合 。那么,一个无自注意力机制的网络能否达到 Sota 的性能?
本文就在于提出这样一个网络。本工作的原始论文为 Sparse MLP for Image Recognition: Is Self-Attention Really Necessary?。2021.9.12 挂上 arXiv。其延续 MLP-Mixer 的交替使用 Token-Mixing MLP 和 Channel-Mixing MLP 的结构,不同之处在于修改了 Token-Mixing MLP。在 Token-MLP 中引入了 DWConv(类似于 ConvMLP) 和 轴向映射(类似于 ViP),此外使用了金字塔结构。最终 sMLPNet 在只有 24M 参数下达到 81.9% 的 Top-1 精度,比相同模型大小约束下的大多数 CNN 和视觉 Transformer 要好得多。当扩展到 66M 参数时,sMLPNet 达到了 83.4% 的 Top-1 精度,这与 SOTA 的 Swin Transformer 相当。
sMLPNet 采用多阶段金字塔模型,总共分为 4 个阶段,每个阶段交替使用 Token-mixing MLP 和 Channel-mixing MLP。Channel-mixing MLP 其实和 ViT 的 FFN 以及 MLP-Mixer 的 Channel-mixing MLP 其实是一样的,两个全连接中间有个 GELU,再加个残差结构。sMLPNet 网络结构图如下所示:
作者一共提出来了三种配置:
其中 C C C 为通道数目,后面的为每个阶段 Block 重复的次数。
sMLPNet 的一大核心在于修改了 Token-mixing MLP。修改后的 Token-mixing MLP 包含两部分:
DW 卷积能有效减少参数量,并且实现局部的信息交流,而 Sparse-MLP 则在轴向上具有长距离依赖。
Sparse-MLP 的结构图如下所示,其包含三个部分的并行结构:W 通道映射,H 通道映射,不动。这其实与 ViP 的三条并行支路很像,唯一不同的是 ViP 的第三条支路为通道 1 × 1 1 \times 1 1×1 卷积,而 Sparse-MLP 中为恒等映射。三条并行支路被通道拼接连在一起,然后经过一个 1 × 1 1 \times 1 1×1 卷积将通道数变为 1/3,即保持输入输出一致。
不难发现,Sparse-MLP 的感受野如图所示,有意思的点是:虽然在一层上看起来是十字形感受野,但是如果经过两个 Sparse-MLP 之后,其实能形成全局感受野!即如果该模块重复两次,每个 token 就可以聚合整个二维空间的信息。这个话之前 ViP,RaftMLP 等等这些没有提到哈,虽然他们也是一样的。
Sparse-MLP 实现代码如下所示:
import torch from torch import nn class sMLPBlock(nn.Module): def __init__(self,h=224,w=224,c=3): super().__init__() self.proj_h=nn.Linear(h,h) self.proj_w=nn.Linear(w,w) self.fuse=nn.Linear(3*c,c) def forward(self,x): x_h=self.proj_h(x.permute(0,1,3,2)).permute(0,1,3,2) x_w=self.proj_w(x) x_id=x x_fuse=torch.cat([x_h,x_w,x_id],dim=1) out=self.fuse(x_fuse.permute(0,2,3,1)).permute(0,3,1,2) return out if __name__ == '__main__': input=torch.randn(50,3,224,224) smlp=sMLPBlock(h=224,w=224) out=smlp(input) print(out.shape)
原文中是这样写的:
sMLP的复杂度为: Ω ( s M L P ) = H W C ( H + W ) + 3 H W C 2 \Omega(s M L P)=H W C(H+W)+3 H W C^{2} Ω(sMLP)=HWC(H+W)+3HWC2,这其实很好理解, H W C H HWCH HWCH 为 H 通道映射, H W C W HWCW HWCW 为 W 通道映射, 3 H W C 2 3HWC^2 3HWC2 为融合后的线性降维映射。
相比而言,MLP-Mixer 的 token 混合部分的复杂度为: Ω ( M L P ) = 2 α ( H W ) 2 C \Omega(M L P)=2 \alpha(H W)^{2} C Ω(MLP)=2α(HW)2C,其中 α \alpha α 为第一个 MLP 节点扩展系数。
但是这样其实并不正确!sMLP 其实还包含了 3 × 3 3 \times 3 3×3 DWConv,所以实际上的 sMLP 的复杂度为: Ω ( s M L P ) = H W C ( H + W ) + 3 H W C 2 + 9 H W C \Omega(s M L P)=H W C(H+W)+3 H W C^{2} + 9HWC Ω(sMLP)=HWC(H+W)+3HWC2+9HWC,
不过论文结论没有问题:可以看出,本文的方法将复杂度控制在了 O ( N N ) O(N\sqrt{N}) O(NN ) 内,而 MLP-Mixer为 O ( N 2 ) O(N^2) O(N2),其中 N = H W N = HW N=HW 为 Token 的数量。这使得本文的方法可以处理更大的 N N N,并最终在金字塔结构中实现多阶段处理。
论文的消融实验主要讨论了以下四点:
本文从网络结构上而言没有特殊的开创性的贡献,不过已经能看到,stride = 1 或 2,这已经被学术界所关注到。本文大胆地说了我们用了 DWConv,而且本文行文很多词汇描述都值得学习,这点很不错。如果结合了 Split Attention 或许会更好,如果放出特征图进行分析可能会更好,如果现在开源则会更更好。
本工作有个问题在于:使用了轴向映射,这就使得网络需要依赖图像的长宽尺寸进行构建,那么这就使得网络对于图像分辨率敏感,无法用于后续下游任务,这也是诸多 MLP-based model 的通病。
我自己实现的非官方 pytorch 代码见 此处,其中 PatchMerging 函数取自 Swin Transformer 官方。
import torch from torch import nn from functools import partial from einops.layers.torch import Rearrange, Reduce def pair(val): return (val, val) if not isinstance(val, tuple) else val class PreNormResidual(nn.Module): def __init__(self, dim, fn, norm = nn.LayerNorm): super().__init__() self.fn = fn self.norm = norm(dim) def forward(self, x): return self.fn(self.norm(x)) + x class sMLPBlock(nn.Module): def __init__(self, h = 224, w = 224, d_model = 3): super().__init__() self.proj_h = nn.Linear(h, h) self.proj_w = nn.Linear(w, w) self.fuse = nn.Conv2d(3 * d_model, d_model, kernel_size = 1) def forward(self,x): x_h = self.proj_h(x.permute(0,1,3,2)).permute(0,1,3,2) x_w = self.proj_w(x) x_id = x x_fuse = torch.cat([x_h, x_w, x_id], dim=1) out = self.fuse(x_fuse) return out class PatchMerging(nn.Module): r""" Patch Merging Layer. Args: input_resolution (tuple[int]): Resolution of input feature. dim (int): Number of input channels. norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm """ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): super().__init__() self.input_resolution = input_resolution self.dim = dim self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) self.norm = norm_layer(4 * dim) def forward(self, x): """ x: B, H*W, C """ H, W = self.input_resolution B, H, W, C = x.shape assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C x = x.view(B, H // 2, W // 2, 4 * C) # B H/2*W/2 4*C x = self.norm(x) x = self.reduction(x) return x def extra_repr(self) -> str: return f"input_resolution={self.input_resolution}, dim={self.dim}" def flops(self): H, W = self.input_resolution flops = H * W * self.dim flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim return flops class sMLPStage(nn.Module): def __init__(self, height, width, d_model, depth, expansion_factor = 2, dropout = 0., pooling = False): super().__init__() self.pooling = pooling self.patch_merge = nn.Sequential( Rearrange('b c h w -> b h w c'), PatchMerging((height, width), d_model), Rearrange('b h w c -> b c h w'), ) self.model = nn.Sequential( *[nn.Sequential( PreNormResidual(d_model, nn.Sequential( nn.Conv2d(d_model, d_model, kernel_size = 3, padding = 1, groups = d_model), ), norm = nn.BatchNorm2d), PreNormResidual(d_model, nn.Sequential( sMLPBlock( height, width, d_model ) ), norm = nn.BatchNorm2d), Rearrange('b c h w -> b h w c'), PreNormResidual(d_model, nn.Sequential( nn.Linear(d_model, d_model * expansion_factor), nn.GELU(), nn.Dropout(dropout), nn.Linear(d_model * expansion_factor, d_model), nn.Dropout(dropout), ), norm = nn.LayerNorm), Rearrange('b h w c -> b c h w'), ) for _ in range(depth)] ) def forward(self, x): x = self.model(x) if self.pooling: x = self.patch_merge(x) return x class SparseMLP(nn.Module): def __init__( self, image_size=224, patch_size=4, in_channels=3, num_classes=1000, d_model=96, depth=[2,10,24,2], expansion_factor = 2, patcher_norm = False, ): image_size = pair(image_size) patch_size = pair(patch_size) assert (image_size[0] % patch_size[0]) == 0, 'image must be divisible by patch size' assert (image_size[1] % patch_size[1]) == 0, 'image must be divisible by patch size' height = image_size[0] // patch_size[0] width = image_size[1] // patch_size[1] super().__init__() self.patcher = nn.Sequential( nn.Conv2d(in_channels, d_model, kernel_size=patch_size, stride=patch_size), nn.Identity() if (not patcher_norm) else nn.Sequential( Rearrange('b c h w -> b h w c'), nn.LayerNorm(d_model), Rearrange('b h w c -> b c h w'), ) ) self.layers = nn.ModuleList() for i_layer in range(len(depth)): i_depth = depth[i_layer] i_stage = sMLPStage(height // (2**i_layer), width // (2**i_layer), d_model, i_depth, expansion_factor = expansion_factor, pooling = ((i_layer + 1) < len(depth))) self.layers.append(i_stage) if (i_layer + 1) < len(depth): d_model = d_model * 2 self.mlp_head = nn.Sequential( Rearrange('b c h w -> b h w c'), nn.LayerNorm(d_model), Reduce('b h w c -> b c', 'mean'), nn.Linear(d_model, num_classes) ) def forward(self, x): i = 0 embedding = self.patcher(x) for layer in self.layers: i += 1 embedding = layer(embedding) out = self.mlp_head(embedding) return out
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。