当前位置:   article > 正文

Vision Transformer (ViT)初识:原理详解及代码_vit代码详解

vit代码详解

参考资源
前言
1. 整体架构
1.1 Embedding层
class token
Position Embedding
1.2 Transformer Encoder层
MLP block ,MLP head 和 FFN
Feed Forward
1.3 MLP head 分类层
2. 代码解析
ViT 调用
Class PreNorm:线性化
Class FeedForward:即 MLP Block
Class Attention:注意力模块
Class Transformer :Encoder
整体框架
参考资源
论文:An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
代码:https://github.com/google-research/vision_transformer(原论文对应源码)
代码:https://github.com/lucidrains/vit-pytorch
知乎:Vision Transformer
知乎:ViT(Vision Transformer)解析
CSDN:Vision Transformer详解

前言
Transformer最初提出是针对NLP领域的。

Vision Transformer将CV和NLP领域知识结合起来,对原始图片进行分块,展平成序列,输入进原始Transformer模型的编码器Encoder部分,最后接入一个全连接层对图片进行分类。

1. 整体架构
原论文中 Vision Transformer(ViT) 的模型框架。由三个模块组成:

Linear Projection of Flattened Patches(Embedding层)
Transformer Encoder(图右侧更加详细结构)
MLP Head(最终用于分类的层)

1.1 Embedding层
Transformer模块的输入:token(向量)序列,即二维矩阵[num_token, token_dim]

对于图像数据:将 [H, W, C] 三维矩阵通过一个Embedding层进行变换。

先对图片作分块,再进行展平

首先将一张图片按给定大小分成一堆Patches,分块的数目为:
N = H ∗ W / ( P ∗ P ) N=H*W/(P*P)
N=H∗W/(P∗P)

对每个图片块展平成一维向量,每个向量大小为
P ∗ P ∗ C P*P*C
P∗P∗C

以ViT-B/16为例:

将输入图片(224x224)按照16x16大小的Patch进行划分,会得到196个Patches.
N = ( 224 / 16 ) 2 = 196 N=(224/16)^2=196
N=(224/16) 
2
 =196

接着通过线性映射将每个Patch映射到一维向量中,每个Patche数据shape为 [16, 16, 3] 通过映射得到一个长度为 768 的向量(后面都直接称为 token)。
16 ∗ 16 ∗ 3 = 768 16*16*3=768
16∗16∗3=768

在代码实现中,通过一个卷积层来实现。 以ViT-B/16为例,使用一个卷积核大小为16x16,步距为16,卷积核个数为768的卷积来实现。通过卷积[224, 224, 3] -> [14, 14, 768],然后把H以及W两个维度展平即可[14, 14, 768] -> [196, 768],此时正好变成了一个二维矩阵,正是Transformer想要的。

在输入Transformer Encoder之前注意需要加上 [class]token 以及Position Embedding。 其中:

class token
假设我们按照论文切成了9块,但是在输入的时候变成了10个向量。这是人为增加的一个向量。

因为传统的Transformer采取的是类似seq2seq编解码的结构 而ViT只用到了Encoder编码器结构,缺少了解码的过程,假设你9个向量经过编码器之后,你该选择哪一个向量进入到最后的分类头呢?因此这里作者给了额外的一个用于分类的向量,与输入进行拼接。

[class]token是一个可训练的参数,数据格式和其他token一样都是一个向量

以ViT-B/16为例,就是: 一个长度为768的向量,与之前从图片中生成的tokens拼接在一起,
Cat([1, 768], [196, 768]) -> [197, 768]

Position Embedding
Position Embedding采用的是一个可训练的参数(1D Pos. Emb.),是直接叠加在tokens上的(add),所以shape要一样。

以ViT-B/16为例,刚刚拼接[class]token后shape是[197, 768],那么这里的Position Embedding的shape也是[197, 768]。

1.2 Transformer Encoder层
重复堆叠 Encoder Block L次,Encoder Block组成:

Layer Norm
主要是针对NLP领域提出的,是对每个token进行Norm处理
Multi-Head Attention
多头注意力,可以参考上篇
Dropout/DropPath
在原论文的代码中是直接使用的Dropout层,在但rwightman实现的代码中使用的是DropPath(stochastic depth),可能后者会更好一点。
MLP Block
全连接+GELU激活函数+Dropout
需要注意的是第一个全连接层会把输入节点个数翻4倍[197, 768] -> [197, 3072],第二个全连接层会还原回原节点个数[197, 3072] -> [197, 768]


MLP block ,MLP head 和 FFN
Feed Forward
Feed Forward 层是一个两层的全连接层,第一层的激活函数为 Gelu,第二层不使用激活函数,对应的公式如下。

Feed Forward 全称是:Position-wise Feed-Forward network

Transformer 结构里的 Feed Forward 实际上就是 ViT (VisionTransformer) 结构里的 MLP Block,下图用粉色圈标出,具体结构为右侧粉色框,是一个全连接网络,包含两个线性变换和一个非线性函数

左边Transformer 结构里的 Encoder 用红色框标出,对应右边红色框的 ViT (VisionTransformer),因为ViT不包含Decorder

FeedForward 即 ViT 的 MLP Block 代码:

两个全连接中间夹个激活函数,可以是RELU或者GELU,加入了dropout

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

1
2
3
4
5
6
7
8
9
10
11
12
13
1.3 MLP head 分类层
MLP Head 是ViT 里用于分类的预测头,结构和代码如下,也是由LN 和线性层组成

MLP head 代码:

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
1
2
3
代码中,取token的第一个,也就是用于分类的token,输入到分类头里,得到最后的分类结果
即[197, 768]中抽取出[class]token对应的[1, 768]。

2. 代码解析
VIT Vision Transformer | 先从PyTorch代码了解
VIT(vision transformer)模型介绍+pytorch代码炸裂解析
ViT源码阅读-PyTorch

ViT 调用
import torch
from vit_pytorch import SimpleViT

import torch
from vit_pytorch import ViT

v = ViT(
    image_size = 256,     # width and height
    patch_size = 32,      # Number of patches. 
    num_classes = 1000,   # Number of classes to classify.
    dim = 1024,           # Last dimension of output tensor after linear transformation
    depth = 6,            # Number of Transformer blocks.
    heads = 16,           # Number of heads in Multi-head Attention layer.
    mlp_dim = 2048        # Dimension of the MLP (FeedForward) layer.
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)   # [batch, channel, height, width]
preds2 = v(img)                              # (1, 1000)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
Class PreNorm:线性化
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        # 先LN
        return self.fn(self.norm(x), **kwargs)
1
2
3
4
5
6
7
8
Class FeedForward:即 MLP Block
# 两个全连接中间夹个激活函数,可以是RELU或者GELU,加入了dropout
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)
1
2
3
4
5
6
7
8
9
10
11
12
13
Class Attention:注意力模块

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5   # 论文里的\sqrt{d_k}

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        # 输入 x -> (batch, 197, 768)即(batch, num_patch + 1, hid_dims)
        # 按照最后一维(特征维度)分成3块,分别对应QKV
        # chunk后是一个tuple,即(q, k, v)
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        # q, k, v都做维度变换,(batch, 197, 768) -> (batch, 12, 197, 768 / 12 = 64)
        # 12是head的数量,目的是做**多头**注意力机制
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
        # Q @ K^T  /  \sqrt{d_k}
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        # 把多头拼回去 -> (batch, 197, 768)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
Class Transformer :Encoder
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                # 先对输入做lN,然后放到attention,然后和做lN之前的输入相加做一个残差链接;
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                # x->LayerNormalization->FeedForward线性层(即MLP block)->y, 然后这个y和输入的x相加,做残差连接。
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))
    def forward(self, x):
        for attn, ff in self.layers:  # attn为Multi-head Attention,ff就是FeedForward
            x = attn(x) + x           # 残差连接,图片中的边线
            x = ff(x) + x
        return x
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
整体框架

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)   # 默认为16   行和列上一共有224 / 16 = 14个patch

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        # # num patches -> (224 / 16) = 14, 14 * 14 = 196
        num_patches = (image_height // patch_height) * (image_width // patch_width)   # 分块数目: N = H *W/(P*P)
        # # path dim -> 3 * 16 * 16 = 768,和Bert-base一致
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # 步骤一:图像分块与映射。首先将图片分块,然后接一个线性层做映射
        #  [224, 224, 3] -> [14, 14, 768]  ——> # [14, 14, 768] -> [196, 768]
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),
        )

        # pos_embedding:位置编码;cls_token:在序列最前面插入一个cls token作为分类输出
        # Cat([1, 768], [196, 768]) -> [197, 768]
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        # 步骤二:Transformer Encoder结构来提特征 即  Transformer Encoder
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        # 线性层输出
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        # 1 x 1 x 768的CLS token重复至 batch x 1 x 768
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]   # 因为多了个CLS token所以要n+1
        x = self.dropout(x)
        # x.shape -> (batch, 196 + 1, 768)

        x = self.transformer(x)   # Transformer Encoder

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
本篇是Transformer 在视觉的应用 ViT 的原理和代码

ViT 的应用和变体的解析和代码参考下篇:
ViT 的应用和变体的解析和代码

Transformer 基本原理和知识参考上篇:
Transformer 初识:模型结构+attention原理详解
 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小桥流水78/article/detail/764348
推荐阅读
相关标签
  

闽ICP备14008679号