当前位置:   article > 正文

第6周学习笔记:Vision Transformer & Swin Transformer学习_swin transformer与vision transformer结合

swin transformer与vision transformer结合

Vision Transformer模型详解

该模型将Transformer结构直接应用到图像上,即将一张图像分割成多个patches,这些patches看作是NLP的tokens (words),然后对每个patches做一系列linear embedding操作之后作为Transformer的input。

Vision Transformer 模型由三个模块组成:

  • Linear Projection of Flattened Patches(Embedding层)
  • Transformer Encoder(Transformer 层)
  • MLP Head(最终用于分类的层)
    在这里插入图片描述

Linear Projection of Flattened Patches(Embedding层)

对于标准的Transformer模块,要求输入的是token(向量)序列,即二维矩阵[num_token, token_dim],如下图,token0-9对应的都是向量,以ViT-B/16为例,每个token向量长度为768。
在这里插入图片描述

左边是一个个被切分好的图片块,假设原始输入的图片数据是 H x W x C,我们需要对图片进行块切割,假设图片块大小为S1 x S2,则最终的块数量N为:N = ( H / S1) * (W / S2)。然后将切分好的图片块展平为一维,那么每一个向量的长度为:Patch_dim = S1 * S2 * C,从而得到了一个N x Patch_dim的输入序列。

Transformer Encoder(Transformer 层)

Vision Transformer Encoder 有层归一化,多头注意力机制,残差连接和线性变换这四个操作

  • 给定输入编码矩阵 ,首先将其进行层归一化得到 ;
  • 利用矩阵 对 进行线性变换得到矩阵,再将矩阵输入到 Multi-Head Attention中得到矩阵 ,将最原始的输入矩阵 与 进行残差计算得到 ;
  • 将 进行第二次层归一化得到 ,然后再将 输入到全连接神经网络中进行线性变换得到 。最后将 与 进行残差操作得到该 Block 的输出;。一个 Encoder 可以将 个 Block 进行堆叠。
    在这里插入图片描述
    其中Multi-Head Attention就是让模型学习全方位、多层次、多角度的信息,学习更丰富的信息特征,对于同一张图片来说,每个人看到的、注意到的部分都会存在一定差异,而在图像中的多头恰恰是把这些差异综合起来进行学习。

MLP Head(最终用于分类的层)

结束了Transformer Encoder,就到了最终的分类处理部分,在之前进行Encoder的时候通过concat的方式多加了一个用于分类的可学习向量,这时把这个向量取出来输入到MLP Head中,即经过Layer Normal --> 全连接 --> GELU --> 全连接,得到了最终的输出。
在这里插入图片描述

Swin Transformer 网络详解

目前Transformer应用到图像领域主要有两大挑战:

  • 视觉实体变化大,在不同场景下视觉Transformer性能未必很好
  • 图像分辨率高,像素点多,Transformer基于全局自注意力的计算导致计算量较大

针对上述两个问题,提出了一种包含滑窗操作,具有层级设计的Swin Transformer。
其中滑窗操作包括不重叠的local window,和重叠的cross-window。将注意力计算限制在一个窗口中,一方面能引入CNN卷积操作的局部性,另一方面能节省计算量。
Swin Transformer的最大贡献是提出了一个可以广泛应用到所有计算机视觉领域的backbone,并且大多数在CNN网络中常见的超参数在Swin Transformer中也是可以人工调整的,例如可以调整的网络块数,每一块的层数,输入图像的大小等等。

网络整体架构

通过与CNN相似的分层结构来处理图片,使得模型能够灵活处理不同尺度的图片
在这里插入图片描述
接下来,简单看下原论文中给出的关于Swin Transformer(Swin-T)网络的架构图,可以看出整个框架的基本流程如下:
在这里插入图片描述
整个模型采取层次化的设计,一共包含4个Stage,每个stage都会缩小输入特征图的分辨率,像CNN一样逐层扩大感受野。

  • 在输入开始的时候,做了一个Patch Embedding,将图片切成一个个图块,并嵌入到Embedding。
  • 在每个Stage里,由Patch Merging和多个Block组成。
  • 其中Patch Merging模块主要在每个Stage一开始降低图片分辨率。
  • Block具体结构主要是LayerNorm,MLP,Window Attention 和 Shifted Window Attention组成
class SwinTransformer(nn.Module):
    def __init__(.
  • 1
声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop】
推荐阅读
相关标签
  

闽ICP备14008679号