当前位置:   article > 正文

深度学习之图像分类(二十一)-- MLP-Mixer网络详解

mlp-mixer

深度学习之图像分类(二十一)MLP-Mixer网络详解

Transformer 之后,我们开启了一个新篇章,即无关卷积和注意力机制的最原始形态,全连接网络。在本章中我们学习全连接构成的 MLP-Mixer。(仔细发现,这个团队其实就是 ViT 团队…),作为一种“开创性”的工作,挖了很多很多的新坑,也引发了后续一系列工作。也许之后是 CNN、Transformer、MLP 三分天下了。

请添加图片描述

1. 前言

MLP-Mixer 是谷歌 AI 团队 2021 初的文章,论文为 MLP-Mixer: An all-MLP Architecture for Vision。卷积神经网络(CNN)是计算机视觉的首选模型。 最近,基于注意力的网络(例如Vision Transformer)也变得很流行。 在工作中,研究人员表明,尽管卷积和注意力都足以获得良好的性能,但它们都不是必需的,纯 MLP + 非线性激活函数 + Layer Normalization 也能取得 comparable 的性能,其预训练和推理成本可与最新模型相媲美。看了实验结果,只能说:有钱有钱有钱,常人即使有想法也跑不出来。

请添加图片描述

为什么要用全连接层有什么好处呢?它的归纳偏置(inductive bias)更低。归纳偏置可以看作学习算法自身在一个庞大的假设空间中对假设进行选择的启发式或者“价值观”。下图展示了机器学习中常用的归纳偏置。CNN 的归纳偏置在于卷积操作,只有感受野内部有相互作用,即图像的局部性特征。时序网络 RNN 的归纳偏置在于时间维度上的连续性和局部性。事实上,ViT 已经开始延续了一直以来想要在神经网络中移除手工视觉特征和归纳偏置的趋势,让模型只依赖于原始数据进行学习。MLP 则更进了一步。原文提到说:One could speculate and explain it again with the difference in inductive biases: self-attention layers in ViT lead to certain properties of the learned functions that are less compatible with the true underlying distribution than those discovered with Mixer architecture

请添加图片描述

2. MLP-Mixer 网络结构

MLP-Mixer 的其整体思路为:先将输入图片拆分成多个 patches(每个 patch 之间不重叠),通过 Per-patch Fully-connected 层的操作将每个 patch 转换成 feature embedding,然后送入N个Mixer Layer。最后,Mixer 将标准分类头与全局平均池化层配合使用,随后使用 Fully-connected 进行分类。细细看来,其第一步与 ViT 其实是一致的,Mixer Layer 替换了 Transformer Block,最后直接的输出接到全连接层而无需 class token。此外,Mixer 的输出基于输入的信息,因为因为全连接层,所以交换任意两个 token 会得到不同的结果(对应的权重不一样了),所以无需 position embedding。

请添加图片描述

假设我们有输入图像 224 × 224 × 3 224 \times 224 \times 3 224×224×3,首先我们切 patch,例如长宽都取 32,则我们可以切成 7 × 7 = 49 7 \times 7 = 49 7×7=49 个 patch,每个 patch 是 32 × 32 × 3 32 \times 32 \times 3 32×32×3。我们将每个 patch 展平就能成为 49 个 3072 3072 3072 维的向量。通过一个全连接层(Per-patch Fully-connected)进行降维,例如 512 维,就得到了 49 个 token,每个 token 的维度为 512。然后将他们馈入 Mixer Layer。

细看 Mixer Layer,Mixer 架构采用两种不同类型的 MLP 层:token-mixing MLP 和 channel-mixing MLP。token-mixing MLP 指的是 cross-location operation,即对于 49 个 512 512 512 维的 token,将每一个 token 内部进行自融合,将 49 维映射到 49 维,即“混合”空间信息;channel-mixing MLP 指的是 pre-location operation,即对于 49 个 512 512 512 维的 token,将每一维进行融合,将 512 维映射到 512 维,即“混合”每个位置特征。为了简单实现,其实将矩阵转置一下就可以了。这两种类型的层交替执行以促进两个维度间的信息交互。单个 MLP 是由两个全连接层和一个 GELU 激活函数组成的。此外,Mixer 还是用了跳跃连接(Skip-connection)和层归一化(Layer Norm),这里的跳跃连接其实不是 UNet 中的通道拼接,而是 ResNet 中的残差结构,将输入输出相加;而这里的层归一化与 DenseNet 等网络一致,作用于全连接层前面,进行前归一化(DenseNet 的 BN 层在卷积层前面,也有工作证明 LN 在 Multi-head Self-Attention 前会更好)。
U ∗ , i = X ∗ , i + W 2 σ ( W 1 LayerNorm ⁡ ( X ) ∗ , i ) ,  for  i = 1 … C Y j , ∗ = U j , ∗ + W 4 σ ( W 3 LayerNorm ⁡ ( U ) j , ∗ ) ,  for  j = 1 … S

U,i=X,i+W2σ(W1LayerNorm(X),i), for i=1CYj,=Uj,+W4σ(W3LayerNorm(U)j,), for j=1S
U,iYj,=X,i+W2σ(W1LayerNorm(X),i),=Uj,+W4σ(W3LayerNorm(U)j,), for i=1C for j=1S

请添加图片描述

为什么作者要采用 pre-location operation 和 cross-location operation 呢?因为信息融合无非就是分为固定位置的信息融合和位置间的信息融合。CNN 的卷积核大小为 K × K × C K \times K \times C K×K×C 其实是同时考虑了这两个,当 K = 1 K = 1 K=1 时则是 pre-location operation,当 C = 1 C = 1 C=1 时对应的 depth-wise convolution 其实就是考虑了 cross-location operation。有 CNN (MobileNet,EfficientNet…) 为了计算复杂度降低采用通道可分离卷积,先做 1 × 1 1 \times 1 1×1,再做 DW 卷积,效果也可,说明他们可以不一起做。所以 MLP-Mixer 也干脆将这两个正交开来处理。

注意到:这里的 channel-mixing MLP 其实就是卷积神经网络中的 1 × 1 1 \times 1 1×1 卷积,卷积核的个数和输入通道数一致。此外 token-mixing MLP 其实就是卷积神经网络中的 padding = 0 的超大型卷积,将 512 维映射到 512 维的全连接层对应的就是 512 个 32 × 32 32 \times 32 32×32 的卷积层。只不过在 Mixer Layer 中 49 个 token 是共用这个全连接层,也就是如果换成卷积层的话是 single-channel depth-wise convolutions(这个和组卷积还不一样,49个组卷积,每个组内有 512 个 32 × 32 32 \times 32 32×32 的卷积层,这样组间的卷积核可以不一样,但是 MLP-Mixer 中是同一个全连接层,所以组卷积每个组的卷积核是一样的)。所以作者说:MLP-Mixer 是 CNN 的特例,但 CNN 不是 MLP-Mixer 的特例

MLP-Mixer 有多个 Mixer Layer 拼接,Mixer 中的每一层 (除了初始块投影层外) 采用相同尺寸的输入(token 维度不变),这种“各向同性”设计类似于Transformer 和 RNN 中定宽;这与 CNN 中金字塔结构 (越深的层具有更低的分辨率、更多的通道数) 不同。

至于最后的全局池化和全连接,则是相对普通的操作了,按照 Channel 对每个 token 进行池化,最后得到 Channel 维度的向量,然后经过一个全连接层输出目标分类分数即可。

最终 MLP-Mixer 的网络配置如下表所示:

请添加图片描述

3. 总结

对于 MLP-Mixer 的一些反思:CV 和 Transformer 已经在目标检测和分割上展现了出色的性能,但是 MLP-Mixer 是否可以为分割、识别等方向提供太大的帮助呢?暂时还不得而知,因为 MLP-Mixer 的特征对分类有用,但是还有待下游其他任务验证。此外 MLP-Mixer 计算量很大,对于增大图像输入分辨率不太友好。还有一个重要的是,MLP 思想虽然看起来很哇塞,但是残差结构和 LN 的 block 设计感觉效用也很大,如果没有他们 MLP-Mixer 还能 work 吗?

对于 MLP-Mixer 的一些展望:MLP-Mixer 能不能也和 TNT 一样增加 patch 内的信息交流呢?虽然 token-mixing MLP 可以被视为 patch 内的信息交流,但是这也可以看起来和 Transformer Block 在 Multi-Head Self-Attention 之后的映射一样的,所以增加 TNT 一样增加 patch 内的信息交流应该会有帮助。CNN + MLP-Mixer 会比 CNN + Transformer 更好吗?我个人暂时看起来不一定,因为无论从可解释性还是训练速度等角度看起来还是持保有态度。

MLP-Mixer 有很多微弱修改和改进,例如 RepMLP,ResMLP,gMLP,CycleMLP 等有待后续学习。

关于实验部分,推荐大家去浏览原文,这里不做过多的描述。最后依然是一贯认识:只要算力和数据够,一切就都有可能

请添加图片描述

4. 代码

代码出处见 此处

import torch
from torch import nn
from functools import partial
from einops.layers.torch import Rearrange, Reduce

class PreNormResidual(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return self.fn(self.norm(x)) + x

def FeedForward(dim, expansion_factor = 4, dropout = 0., dense = nn.Linear):
    return nn.Sequential(
        dense(dim, dim * expansion_factor),
        nn.GELU(),
        nn.Dropout(dropout),
        dense(dim * expansion_factor, dim),
        nn.Dropout(dropout)
    )

def MLPMixer(*, image_size, channels, patch_size, dim, depth, num_classes, expansion_factor = 4, dropout = 0.):
    assert (image_size % patch_size) == 0, 'image must be divisible by patch size'
    num_patches = (image_size // patch_size) ** 2
    chan_first, chan_last = partial(nn.Conv1d, kernel_size = 1), nn.Linear

    return nn.Sequential(
        Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
        nn.Linear((patch_size ** 2) * channels, dim),
        *[nn.Sequential(
            PreNormResidual(dim, FeedForward(num_patches, expansion_factor, dropout, chan_first)),
            PreNormResidual(dim, FeedForward(dim, expansion_factor, dropout, chan_last))
        ) for _ in range(depth)],
        nn.LayerNorm(dim),
        Reduce('b n c -> b c', 'mean'),
        nn.Linear(dim, num_classes)
    )

model = MLPMixer(
    image_size = 256,
    channels = 3,
    patch_size = 16,
    dim = 512,
    depth = 12,
    num_classes = 1000
)

img = torch.randn(1, 3, 256, 256)
pred = model(img) # (1, 1000)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Monodyee/article/detail/248406
推荐阅读
相关标签
  

闽ICP备14008679号