当前位置:   article > 正文

Vision Transformer (ViT)初识:原理详解及代码_vit encoder和block

vit encoder和block


参考资源

论文:An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
代码:https://github.com/google-research/vision_transformer(原论文对应源码)
代码:https://github.com/lucidrains/vit-pytorch
知乎:Vision Transformer
知乎:ViT(Vision Transformer)解析
CSDN:Vision Transformer详解


前言

Transformer最初提出是针对NLP领域的。

Vision Transformer将CV和NLP领域知识结合起来,对原始图片进行分块,展平成序列,输入进原始Transformer模型的编码器Encoder部分,最后接入一个全连接层对图片进行分类。

1. 整体架构

原论文中 Vision Transformer(ViT) 的模型框架。由三个模块组成:

  1. Linear Projection of Flattened Patches(Embedding层)
  2. Transformer Encoder(图右侧更加详细结构)
  3. MLP Head(最终用于分类的层)
    在这里插入图片描述

1.1 Embedding层

Transformer模块的输入:token(向量)序列,即二维矩阵[num_token, token_dim]

对于图像数据:将 [H, W, C] 三维矩阵通过一个Embedding层进行变换。

先对图片作分块,再进行展平

首先将一张图片按给定大小分成一堆Patches,分块的数目为:
N = H ∗ W / ( P ∗ P ) N=H*W/(P*P) N=HW/(PP)
对每个图片块展平成一维向量,每个向量大小为
P ∗ P ∗ C P*P*C PPC

以ViT-B/16为例:

将输入图片(224x224)按照16x16大小的Patch进行划分,会得到196个Patches.
N = ( 224 / 16 ) 2 = 196 N=(224/16)^2=196 N=(224/16)2=196

接着通过线性映射将每个Patch映射到一维向量中,每个Patche数据shape为 [16, 16, 3] 通过映射得到一个长度为 768 的向量(后面都直接称为 token)。
16 ∗ 16 ∗ 3 = 768 16*16*3=768 16163=768

代码实现中,通过一个卷积层来实现。 以ViT-B/16为例,使用一个卷积核大小为16x16,步距为16,卷积核个数为768的卷积来实现。通过卷积[224, 224, 3] -> [14, 14, 768],然后把H以及W两个维度展平即可[14, 14, 768] -> [196, 768],此时正好变成了一个二维矩阵,正是Transformer想要的。

在这里插入图片描述

在输入Transformer Encoder之前注意需要加上 [class]token 以及Position Embedding。 其中:

class token

假设我们按照论文切成了9块,但是在输入的时候变成了10个向量。这是人为增加的一个向量。

因为传统的Transformer采取的是类似seq2seq编解码的结构 而ViT只用到了Encoder编码器结构,缺少了解码的过程,假设你9个向量经过编码器之后,你该选择哪一个向量进入到最后的分类头呢?因此这里作者给了额外的一个用于分类的向量,与输入进行拼接。

[class]token是一个可训练的参数,数据格式和其他token一样都是一个向量

以ViT-B/16为例,就是: 一个长度为768的向量,与之前从图片中生成的tokens拼接在一起,
Cat([1, 768], [196, 768]) -> [197, 768]

Position Embedding

Position Embedding采用的是一个可训练的参数(1D Pos. Emb.),是直接叠加在tokens上的(add),所以shape要一样。

以ViT-B/16为例,刚刚拼接[class]token后shape是[197, 768],那么这里的Position Embedding的shape也是[197, 768]

1.2 Transformer Encoder层

重复堆叠 Encoder Block L次,Encoder Block组成:

  • Layer Norm
    主要是针对NLP领域提出的,是对每个token进行Norm处理
  • Multi-Head Attention
    多头注意力,可以参考上篇
  • Dropout/DropPath
    在原论文的代码中是直接使用的Dropout层,在但rwightman实现的代码中使用的是DropPath(stochastic depth),可能后者会更好一点。
  • MLP Block
    全连接+GELU激活函数+Dropout
    需要注意的是第一个全连接层会把输入节点个数翻4倍[197, 768] -> [197, 3072],第二个全连接层会还原回原节点个数[197, 3072] -> [197, 768]

在这里插入图片描述

MLP block ,MLP head 和 FFN

Feed Forward

Feed Forward 层是一个两层的全连接层,第一层的激活函数为 Gelu,第二层不使用激活函数,对应的公式如下。
在这里插入图片描述
Feed Forward 全称是:Position-wise Feed-Forward network

Transformer 结构里的 Feed Forward 实际上就是 ViT (VisionTransformer) 结构里的 MLP Block,下图用粉色圈标出,具体结构为右侧粉色框,是一个全连接网络,包含两个线性变换和一个非线性函数

左边Transformer 结构里的 Encoder 用红色框标出,对应右边红色框的 ViT (VisionTransformer),因为ViT不包含Decorder

在这里插入图片描述

FeedForwardViT 的 MLP Block 代码:

两个全连接中间夹个激活函数,可以是RELU或者GELU,加入了dropout

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

1.3 MLP head 分类层

MLP Head 是ViT 里用于分类的预测头,结构和代码如下,也是由LN 和线性层组成
在这里插入图片描述
MLP head 代码:

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
  • 1
  • 2
  • 3

代码中,取token的第一个,也就是用于分类的token,输入到分类头里,得到最后的分类结果
[197, 768]中抽取出[class]token对应的[1, 768]


2. 代码解析

VIT Vision Transformer | 先从PyTorch代码了解
VIT(vision transformer)模型介绍+pytorch代码炸裂解析
ViT源码阅读-PyTorch

ViT 调用

import torch
from vit_pytorch import SimpleViT

import torch
from vit_pytorch import ViT

v = ViT(
    image_size = 256,     # width and height
    patch_size = 32,      # Number of patches. 
    num_classes = 1000,   # Number of classes to classify.
    dim = 1024,           # Last dimension of output tensor after linear transformation
    depth = 6,            # Number of Transformer blocks.
    heads = 16,           # Number of heads in Multi-head Attention layer.
    mlp_dim = 2048        # Dimension of the MLP (FeedForward) layer.
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)   # [batch, channel, height, width]
preds2 = v(img)                              # (1, 1000)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20

Class PreNorm:线性化

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        # 先LN
        return self.fn(self.norm(x), **kwargs)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

Class FeedForward:即 MLP Block

# 两个全连接中间夹个激活函数,可以是RELU或者GELU,加入了dropout
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

Class Attention:注意力模块


class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5   # 论文里的\sqrt{d_k}

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        # 输入 x -> (batch, 197, 768)即(batch, num_patch + 1, hid_dims)
        # 按照最后一维(特征维度)分成3块,分别对应QKV
        # chunk后是一个tuple,即(q, k, v)
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        # q, k, v都做维度变换,(batch, 197, 768) -> (batch, 12, 197, 768 / 12 = 64)
        # 12是head的数量,目的是做**多头**注意力机制
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
        # Q @ K^T  /  \sqrt{d_k}
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        # 把多头拼回去 -> (batch, 197, 768)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38

Class Transformer :Encoder

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                # 先对输入做lN,然后放到attention,然后和做lN之前的输入相加做一个残差链接;
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                # x->LayerNormalization->FeedForward线性层(即MLP block)->y, 然后这个y和输入的x相加,做残差连接。
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))
    def forward(self, x):
        for attn, ff in self.layers:  # attn为Multi-head Attention,ff就是FeedForward
            x = attn(x) + x           # 残差连接,图片中的边线
            x = ff(x) + x
        return x
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16

整体框架


class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)   # 默认为16   行和列上一共有224 / 16 = 14个patch

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        # # num patches -> (224 / 16) = 14, 14 * 14 = 196
        num_patches = (image_height // patch_height) * (image_width // patch_width)   # 分块数目: N = H *W/(P*P)
        # # path dim -> 3 * 16 * 16 = 768,和Bert-base一致
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # 步骤一:图像分块与映射。首先将图片分块,然后接一个线性层做映射
        #  [224, 224, 3] -> [14, 14, 768]  ——> # [14, 14, 768] -> [196, 768]
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),
        )

        # pos_embedding:位置编码;cls_token:在序列最前面插入一个cls token作为分类输出
        # Cat([1, 768], [196, 768]) -> [197, 768]
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        # 步骤二:Transformer Encoder结构来提特征 即  Transformer Encoder
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        # 线性层输出
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        # 1 x 1 x 768的CLS token重复至 batch x 1 x 768
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]   # 因为多了个CLS token所以要n+1
        x = self.dropout(x)
        # x.shape -> (batch, 196 + 1, 768)

        x = self.transformer(x)   # Transformer Encoder

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58

本篇是Transformer 在视觉的应用 ViT 的原理和代码

ViT 的应用和变体的解析和代码参考下篇:
ViT 的应用和变体的解析和代码

Transformer 基本原理和知识参考上篇:
Transformer 初识:模型结构+attention原理详解

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/IT小白/article/detail/371505
推荐阅读
相关标签
  

闽ICP备14008679号