
Vision Transformer (ViT): A PyTorch Implementation


Introduction

The goal of this post is to deepen our understanding of the ViT model by implementing it in code. If you are not familiar with ViT yet, it is worth first reading an introduction to its overall structure.

This post is largely a translation of Implementing Vision Transformer (ViT) in PyTorch, with some annotations of my own added. If you prefer the original English version, you may want to read that article directly.

Overall Structure of the ViT Model

As usual, let's start with the overall architecture of the model:
[Figure: overall ViT architecture]
The input image is split into 16x16 patches. Each patch is fed through a fully connected layer to obtain embeddings, and a special cls token is prepended to these embeddings. A positional encoding is then added to all of the embeddings. The whole sequence is fed into the Transformer Encoder, and finally the output feature of the cls token is passed to an MLP Head for classification. That is the overall flow.

The code follows the structure of the ViT model and can be broken down into the following parts:

  • Data

  • Patches Embeddings

    • CLS Token
    • Positional Encoding
  • Transformer Encoder Block

    • Attention
    • Residuals
    • MLP
  • Transformer Encoder

  • Head

  • ViT

We will implement the ViT model step by step, from the bottom up.

Data

First, import the required dependencies:

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch import nn
from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from torchsummary import summary

First, we open an image:

img = Image.open('./test.jpeg')
fig = plt.figure()
plt.imshow(img)

[Figure: the test image displayed with matplotlib]
Next, we preprocess the image, mainly resizing it and converting it to a tensor:

# resize to ImageNet size 
transform = Compose([Resize((224, 224)), ToTensor()])
x = transform(img)
x = x.unsqueeze(0)  # add a batch dimension
x.shape
torch.Size([1, 3, 224, 224])

Patches Embeddings

Following the ViT architecture, the first step is to split the image into patches and flatten them, as illustrated below:
[Figure: splitting the image into patches]
The original paper describes this step as follows:
[Figure: the patch embedding equation from the ViT paper]
This looks complicated, but we can use the einops library to simplify the code:

patch_size = 16  # 16 pixels
patches = rearrange(x, 'b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=patch_size, s2=patch_size)
patches.shape
torch.Size([1, 196, 768])

Let me explain where the result [1, 196, 768] comes from. The original image tensor x has shape [1, 3, 224, 224]. When we split it into 16x16 patches, we get 224 x 224 / 16 / 16 = 196 patches, and each patch contains 16 x 16 x 3 = 768 values, hence the shape [1, 196, 768].
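
As a quick sanity check (my own addition, not from the original post), we can verify these two numbers directly:

# number of patches and patch dimension for a 224x224 RGB image with 16x16 patches
img_size, patch_size, in_channels = 224, 16, 3
num_patches = (img_size // patch_size) ** 2        # 14 * 14 = 196
patch_dim = patch_size * patch_size * in_channels  # 16 * 16 * 3 = 768
print(num_patches, patch_dim)  # 196 768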

Next, we pass these patches through a linear projection layer.
[Figure: the linear projection of the flattened patches]
We can define a class named PatchEmbedding to keep the code tidy:

class PatchEmbedding(nn.Module):
    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768):
        self.patch_size = patch_size
        super().__init__()
        self.projection = nn.Sequential(
            # split the image into 16x16 patches and flatten them
            Rearrange('b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=patch_size, s2=patch_size),
            # note that the embedding size here is also 768; it is configurable
            nn.Linear(patch_size * patch_size * in_channels, emb_size)
        )
                
    def forward(self, x: Tensor) -> Tensor:
        x = self.projection(x)
        return x
    
PatchEmbedding()(x).shape
torch.Size([1, 196, 768])

If you look at the original author's code, he does not actually use a linear layer for this step. For efficiency, he uses a Conv2d layer to achieve the same effect, by setting both the kernel size and the stride to patch_size. Intuitively, the convolution is then applied to each patch independently. So we can first apply a convolution and then flatten the result:

class PatchEmbedding(nn.Module):
    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768):
        self.patch_size = patch_size
        super().__init__()
        self.projection = nn.Sequential(
            # use a convolution instead of a linear layer -> better performance
            nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
            # flatten the patches produced by the convolution
            Rearrange('b e h w -> b (h w) e'),
        )
                
    def forward(self, x: Tensor) -> Tensor:
        x = self.projection(x)
        return x
    
PatchEmbedding()(x).shape
torch.Size([1, 196, 768])
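
As a small aside (my own sketch, not from the original post), we can check numerically that the convolution-based projection matches the linear-layer version when the weights are shared with the matching memory layout:

# sanity check: a Conv2d with kernel_size = stride = patch_size is equivalent to
# flattening each patch and applying a Linear layer with suitably reshaped weights
conv = nn.Conv2d(3, 768, kernel_size=16, stride=16, bias=False)
linear = nn.Linear(16 * 16 * 3, 768, bias=False)
with torch.no_grad():
    # match the (s1 s2 c) flattening order used by the Rearrange pattern above
    linear.weight.copy_(rearrange(conv.weight, 'e c s1 s2 -> e (s1 s2 c)'))

out_conv = rearrange(conv(x), 'b e h w -> b (h w) e')
out_linear = linear(rearrange(x, 'b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=16, s2=16))
print(torch.allclose(out_conv, out_linear, atol=1e-5))  # expected: True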

CLS Token

The next step is to add the cls token and the positional encoding to the projected patches. The cls token is a randomly initialized torch Parameter; in the forward method it is repeated b times (b is the batch size) and then prepended to the patch embeddings with torch.cat.

class PatchEmbedding(nn.Module):
    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768):
        self.patch_size = patch_size
        super().__init__()
        self.proj = nn.Sequential(
            # use a convolution instead of a linear layer -> better performance
            nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
            Rearrange('b e (h) (w) -> b (h w) e'),
        )
        # a learnable vector of size emb_size used as the cls token
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        
    def forward(self, x: Tensor) -> Tensor:
        b, _, _, _ = x.shape  # keep the batch size
        x = self.proj(x)  # apply the patch-projection convolution
        # repeat the cls token b times
        cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b)
        print(cls_tokens.shape)
        print(x.shape)
        # prepend the cls token along dimension 1 (the sequence dimension)
        x = torch.cat([cls_tokens, x], dim=1)
        return x
    
PatchEmbedding()(x).shape
torch.Size([1, 1, 768])
torch.Size([1, 196, 768])
torch.Size([1, 197, 768])

Positional Encoding

So far, the model knows nothing about the original position of each patch in the image, so we need to give it this spatial information. There are many ways to do this; in ViT we simply let the model learn it. The positional encoding is a learnable tensor of shape [N_PATCHES + 1 (for the cls token), EMBED_SIZE] that is added directly to the projected patches.
[Figure: positional embeddings added to the patch embeddings]
We first define the position embedding, then in the forward method add it to the projected patch embeddings (which already include the cls token).

class PatchEmbedding(nn.Module):
    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768, img_size: int = 224):
        self.patch_size = patch_size
        super().__init__()
        self.projection = nn.Sequential(
            # use a convolution instead of a linear layer -> better performance
            nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
            Rearrange('b e (h) (w) -> b (h w) e'),
        )
        self.cls_token = nn.Parameter(torch.randn(1,1, emb_size))
        # positional encoding: (img_size // patch_size)**2 + 1 (cls token) position vectors in total
        self.positions = nn.Parameter(torch.randn((img_size // patch_size)**2 + 1, emb_size))
        
    def forward(self, x: Tensor) -> Tensor:
        b, _, _, _ = x.shape
        x = self.projection(x)
        cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b)
        # prepend the cls token along dimension 1 (the sequence dimension)
        x = torch.cat([cls_tokens, x], dim=1)
        # add the positional encoding
        print(x.shape, self.positions.shape)
        x += self.positions
        return x
    
PatchEmbedding()(x).shape
torch.Size([1, 197, 768]) torch.Size([197, 768])
torch.Size([1, 197, 768])

Transformer Encoder Block

Now let's implement the Transformer Encoder block. ViT only uses the Encoder part of the Transformer. Each block is built from two sub-layers: LayerNorm followed by multi-head self-attention, and LayerNorm followed by an MLP, each wrapped in a residual connection.

Let's implement each piece in turn.

Attention

The attention block takes three inputs: the queries, keys and values matrices. The queries and keys are first used to compute an attention matrix, which, after a softmax, is multiplied with the values to produce the output. Multi-head attention splits the input into n parts and distributes the computation across n heads.

We could use PyTorch's nn.MultiheadAttention module, or implement it ourselves. For completeness, here is the full MultiHeadAttention code:

class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size: int = 512, num_heads: int = 8, dropout: float = 0):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)
        
    def forward(self, x : Tensor, mask: Tensor = None) -> Tensor:
        # split keys, queries and values in num_heads
        queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads)
        keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads)
        values  = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
        # sum up over the last axis
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys) # batch, num_heads, query_len, key_len
        if mask is not None:
            fill_value = torch.finfo(torch.float32).min
            energy = energy.masked_fill(~mask, fill_value)
            
        scaling = self.emb_size ** (1/2)
        att = F.softmax(energy, dim=-1) / scaling
        att = self.att_drop(att)
        # sum up over the third axis
        out = torch.einsum('bhal, bhlv -> bhav ', att, values)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.projection(out)
        return out

Let's walk through the code step by step. We define four fully connected layers: one each for the queries, keys and values, plus the final projection layer. For a more detailed explanation of this part, see The Illustrated Transformer. The main idea is to use the product between the queries and the keys to measure how strongly each patch in the input sequence matches every other patch, and then use these scores to weight the corresponding values, which are summed to produce the attention output.

The forward method takes the output of the previous layer as input and uses the three linear layers to produce the queries, keys and values. Since we are implementing multi-head attention, the result has to be rearranged into multiple heads. This is done with the rearrange function from the einops library.

The queries, keys and values all have the same shape and, for simplicity, they are all computed from the same input x:

queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads)
keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads)
values  = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)

After the rearrange, the queries, keys and values have shape [BATCH, HEADS, SEQUENCE_LEN, EMBEDDING_SIZE // HEADS]. At the end, the outputs of the individual heads are concatenated back together to form the final multi-head attention output.
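
To make the rearrange step concrete, here is a tiny standalone example (my own, not from the original post) of splitting an embedding of size 768 into 8 heads of size 96:

t = torch.randn(1, 197, 768)
heads = rearrange(t, 'b n (h d) -> b h n d', h=8)
print(heads.shape)  # torch.Size([1, 8, 197, 96])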

Note: to speed up the computation, we can produce the queries, keys and values with a single matrix multiplication.

The improved code is as follows:

class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size: int = 768, num_heads: int = 8, dropout: float = 0):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        # compute queries, keys and values with a single matrix multiplication
        self.qkv = nn.Linear(emb_size, emb_size * 3)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)
        
    def forward(self, x : Tensor, mask: Tensor = None) -> Tensor:
        # split queries, keys and values into num_heads heads
        print("1qkv's shape: ", self.qkv(x).shape)  # queries, keys and values computed in one go
        qkv = rearrange(self.qkv(x), "b n (h d qkv) -> (qkv) b h n d", h=self.num_heads, qkv=3)  # split across num_heads heads
        print("2qkv's shape: ", qkv.shape)
        
        queries, keys, values = qkv[0], qkv[1], qkv[2]
        print("queries's shape: ", queries.shape)
        print("keys's shape: ", keys.shape)
        print("values's shape: ", values.shape)
        
        # dot product over the last dimension
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys) # batch, num_heads, query_len, key_len
        print("energy's shape: ", energy.shape)
        if mask is not None:
            fill_value = torch.finfo(torch.float32).min
            energy = energy.masked_fill(~mask, fill_value)
        
        scaling = self.emb_size ** (1/2)
        print("scaling: ", scaling)
        att = F.softmax(energy, dim=-1) / scaling
        print("att1' shape: ", att.shape)
        att = self.att_drop(att)
        print("att2' shape: ", att.shape)
        
        # weighted sum over the key dimension (third axis)
        out = torch.einsum('bhal, bhlv -> bhav ', att, values)
        print("out1's shape: ", out.shape)
        out = rearrange(out, "b h n d -> b n (h d)")
        print("out2's shape: ", out.shape)
        out = self.projection(out)
        print("out3's shape: ", out.shape)
        return out
    
patches_embedded = PatchEmbedding()(x)
print("patches_embedding's shape: ", patches_embedded.shape)
MultiHeadAttention()(patches_embedded).shape
torch.Size([1, 197, 768]) torch.Size([197, 768])
patches_embedding's shape:  torch.Size([1, 197, 768])
1qkv's shape:  torch.Size([1, 197, 2304])
2qkv's shape:  torch.Size([3, 1, 8, 197, 96])
queries's shape:  torch.Size([1, 8, 197, 96])
keys's shape:  torch.Size([1, 8, 197, 96])
values's shape:  torch.Size([1, 8, 197, 96])
energy's shape:  torch.Size([1, 8, 197, 197])
scaling:  27.712812921102035
att1' shape:  torch.Size([1, 8, 197, 197])
att2' shape:  torch.Size([1, 8, 197, 197])
out1's shape:  torch.Size([1, 8, 197, 96])
out2's shape:  torch.Size([1, 197, 768])
out3's shape:  torch.Size([1, 197, 768])
torch.Size([1, 197, 768])
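One detail worth pointing out about the attention code above: in the standard Transformer formulation, the attention energies are divided by sqrt(d_head) before the softmax, whereas this code (following the original post) divides the softmax output by sqrt(emb_size). For comparison, here is a minimal sketch of the textbook version (my own addition, using made-up tensors of the shapes seen above):

# textbook scaled dot-product attention (schematic sketch, not the original author's code)
queries = keys = values = torch.randn(1, 8, 197, 96)
d_head = queries.shape[-1]
energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
att = F.softmax(energy / d_head ** 0.5, dim=-1)
out = torch.einsum('bhqk, bhkd -> bhqd', att, values)
print(out.shape)  # torch.Size([1, 8, 197, 96])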

Residuals

The Transformer block also uses residual connections around each sub-layer.

We can encapsulate the residual connection in a small class of its own:

class ResidualAdd(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
        
    def forward(self, x, **kwargs):
        res = x
        x = self.fn(x, **kwargs)
        x += res
        return x
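As a quick usage example (my own, not from the original post), ResidualAdd simply adds the input back onto the output of whatever module it wraps:

block = ResidualAdd(nn.Linear(768, 768))
t = torch.randn(1, 197, 768)
print(block(t).shape)  # torch.Size([1, 197, 768]), equal to nn.Linear(...)(t) + t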

MLP

After the attention layer, the output goes through a LayerNorm followed by a fully connected block, which uses an expansion factor to upsample the input and then projects it back down. As in ResNet, a residual connection is used around this sub-layer as well.

The code is as follows:

class FeedForwardBlock(nn.Sequential):
    def __init__(self, emb_size: int, expansion: int = 4, drop_p: float = 0.):
        super().__init__(
            nn.Linear(emb_size, expansion * emb_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(expansion * emb_size, emb_size),
        )
  • The author notes that, oddly enough, you rarely see people subclass nn.Sequential directly, even though it lets you avoid writing a forward method.
  • Translator's note: indeed, that's a nice trick to pick up.
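
To see the trick in action (my own example, not from the original post): because FeedForwardBlock inherits from nn.Sequential, its layers are applied in order without any explicit forward method:

ff = FeedForwardBlock(emb_size=768)
t = torch.randn(1, 197, 768)
print(ff(t).shape)  # torch.Size([1, 197, 768])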

Finally, we can put together a complete Transformer Encoder block.

Using the ResidualAdd class defined earlier, we can define the Transformer Encoder block quite elegantly:

class TransformerEncoderBlock(nn.Sequential):
    def __init__(self,
                 emb_size: int = 768,
                 drop_p: float = 0.,
                 forward_expansion: int = 4,
                 forward_drop_p: float = 0.,
                 ** kwargs):
        super().__init__(
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                MultiHeadAttention(emb_size, **kwargs),
                nn.Dropout(drop_p)
            )),
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                FeedForwardBlock(
                    emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
                nn.Dropout(drop_p)
            )
            ))

Let's test it:

patches_embedded = PatchEmbedding()(x)
TransformerEncoderBlock()(patches_embedded).shape
torch.Size([1, 197, 768]) torch.Size([197, 768])
1qkv's shape:  torch.Size([1, 197, 2304])
2qkv's shape:  torch.Size([3, 1, 8, 197, 96])
queries's shape:  torch.Size([1, 8, 197, 96])
keys's shape:  torch.Size([1, 8, 197, 96])
values's shape:  torch.Size([1, 8, 197, 96])
energy's shape:  torch.Size([1, 8, 197, 197])
scaling:  27.712812921102035
att1' shape:  torch.Size([1, 8, 197, 197])
att2' shape:  torch.Size([1, 8, 197, 197])
out1's shape:  torch.Size([1, 8, 197, 96])
out2's shape:  torch.Size([1, 197, 768])
out3's shape:  torch.Size([1, 197, 768])
torch.Size([1, 197, 768])

Transformer Encoder

ViT only uses the Encoder part of the original Transformer (and even that differs slightly from the Encoder in the original Transformer). The Encoder consists of L blocks, which we expose through the depth parameter:

class TransformerEncoder(nn.Sequential):
    def __init__(self, depth: int = 12, **kwargs):
        super().__init__(*[TransformerEncoderBlock(**kwargs) for _ in range(depth)])

Let's test it again:

patches_embedded = PatchEmbedding()(x)
TransformerEncoder()(patches_embedded).shape
torch.Size([1, 197, 768]) torch.Size([197, 768])
1qkv's shape:  torch.Size([1, 197, 2304])
2qkv's shape:  torch.Size([3, 1, 8, 197, 96])
queries's shape:  torch.Size([1, 8, 197, 96])
keys's shape:  torch.Size([1, 8, 197, 96])
values's shape:  torch.Size([1, 8, 197, 96])
energy's shape:  torch.Size([1, 8, 197, 197])
scaling:  27.712812921102035
att1' shape:  torch.Size([1, 8, 197, 197])
att2' shape:  torch.Size([1, 8, 197, 197])
out1's shape:  torch.Size([1, 8, 197, 96])
out2's shape:  torch.Size([1, 197, 768])
out3's shape:  torch.Size([1, 197, 768])
(... the same 12 lines of debug output are printed again for each of the remaining 11 encoder blocks ...)
torch.Size([1, 197, 768])

A quick aside on how torchsummary is used:

# summary prints the network structure and parameter counts
from torchsummary import summary
from torchvision.models import resnet18

model = resnet18()
summary(model, input_size=[(3, 256, 256)], batch_size=2, device="cpu")
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1          [2, 64, 128, 128]           9,408
       BatchNorm2d-2          [2, 64, 128, 128]             128
              ReLU-3          [2, 64, 128, 128]               0
         MaxPool2d-4            [2, 64, 64, 64]               0
            Conv2d-5            [2, 64, 64, 64]          36,864
       BatchNorm2d-6            [2, 64, 64, 64]             128
              ReLU-7            [2, 64, 64, 64]               0
            Conv2d-8            [2, 64, 64, 64]          36,864
       BatchNorm2d-9            [2, 64, 64, 64]             128
             ReLU-10            [2, 64, 64, 64]               0
       BasicBlock-11            [2, 64, 64, 64]               0
           Conv2d-12            [2, 64, 64, 64]          36,864
      BatchNorm2d-13            [2, 64, 64, 64]             128
             ReLU-14            [2, 64, 64, 64]               0
           Conv2d-15            [2, 64, 64, 64]          36,864
      BatchNorm2d-16            [2, 64, 64, 64]             128
             ReLU-17            [2, 64, 64, 64]               0
       BasicBlock-18            [2, 64, 64, 64]               0
           Conv2d-19           [2, 128, 32, 32]          73,728
      BatchNorm2d-20           [2, 128, 32, 32]             256
             ReLU-21           [2, 128, 32, 32]               0
           Conv2d-22           [2, 128, 32, 32]         147,456
      BatchNorm2d-23           [2, 128, 32, 32]             256
           Conv2d-24           [2, 128, 32, 32]           8,192
      BatchNorm2d-25           [2, 128, 32, 32]             256
             ReLU-26           [2, 128, 32, 32]               0
       BasicBlock-27           [2, 128, 32, 32]               0
           Conv2d-28           [2, 128, 32, 32]         147,456
      BatchNorm2d-29           [2, 128, 32, 32]             256
             ReLU-30           [2, 128, 32, 32]               0
           Conv2d-31           [2, 128, 32, 32]         147,456
      BatchNorm2d-32           [2, 128, 32, 32]             256
             ReLU-33           [2, 128, 32, 32]               0
       BasicBlock-34           [2, 128, 32, 32]               0
           Conv2d-35           [2, 256, 16, 16]         294,912
      BatchNorm2d-36           [2, 256, 16, 16]             512
             ReLU-37           [2, 256, 16, 16]               0
           Conv2d-38           [2, 256, 16, 16]         589,824
      BatchNorm2d-39           [2, 256, 16, 16]             512
           Conv2d-40           [2, 256, 16, 16]          32,768
      BatchNorm2d-41           [2, 256, 16, 16]             512
             ReLU-42           [2, 256, 16, 16]               0
       BasicBlock-43           [2, 256, 16, 16]               0
           Conv2d-44           [2, 256, 16, 16]         589,824
      BatchNorm2d-45           [2, 256, 16, 16]             512
             ReLU-46           [2, 256, 16, 16]               0
           Conv2d-47           [2, 256, 16, 16]         589,824
      BatchNorm2d-48           [2, 256, 16, 16]             512
             ReLU-49           [2, 256, 16, 16]               0
       BasicBlock-50           [2, 256, 16, 16]               0
           Conv2d-51             [2, 512, 8, 8]       1,179,648
      BatchNorm2d-52             [2, 512, 8, 8]           1,024
             ReLU-53             [2, 512, 8, 8]               0
           Conv2d-54             [2, 512, 8, 8]       2,359,296
      BatchNorm2d-55             [2, 512, 8, 8]           1,024
           Conv2d-56             [2, 512, 8, 8]         131,072
      BatchNorm2d-57             [2, 512, 8, 8]           1,024
             ReLU-58             [2, 512, 8, 8]               0
       BasicBlock-59             [2, 512, 8, 8]               0
           Conv2d-60             [2, 512, 8, 8]       2,359,296
      BatchNorm2d-61             [2, 512, 8, 8]           1,024
             ReLU-62             [2, 512, 8, 8]               0
           Conv2d-63             [2, 512, 8, 8]       2,359,296
      BatchNorm2d-64             [2, 512, 8, 8]           1,024
             ReLU-65             [2, 512, 8, 8]               0
       BasicBlock-66             [2, 512, 8, 8]               0
AdaptiveAvgPool2d-67             [2, 512, 1, 1]               0
           Linear-68                  [2, 1000]         513,000
================================================================
Total params: 11,689,512
Trainable params: 11,689,512
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.50
Forward/backward pass size (MB): 164.02
Params size (MB): 44.59
Estimated Total Size (MB): 210.12
----------------------------------------------------------------

Everything from the MLP Head up is what we still need, so let's get back to it.

MLP Head

The final stage of ViT is a simple fully connected layer that outputs the class scores. Before that, the sequence is reduced with a mean over all tokens. (Note that the ViT paper classifies from the output of the cls token; this implementation, following the original post, averages over the whole sequence instead.)

class ClassificationHead(nn.Sequential):
    def __init__(self, emb_size: int = 768, n_classes: int = 1000):
        super().__init__(
            Reduce('b n e -> b e', reduction='mean'),
            nn.LayerNorm(emb_size), 
            nn.Linear(emb_size, n_classes))
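A quick shape check (my own, not from the original post): the head reduces the 197 tokens to a single vector per image and maps it to n_classes logits:

head = ClassificationHead(emb_size=768, n_classes=1000)
t = torch.randn(1, 197, 768)
print(head(t).shape)  # torch.Size([1, 1000])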

We now combine the PatchEmbedding, TransformerEncoder and ClassificationHead modules defined above to build the final ViT model:

class ViT(nn.Sequential):
    def __init__(self,     
                in_channels: int = 3,
                patch_size: int = 16,
                emb_size: int = 768,
                img_size: int = 224,
                depth: int = 12,
                n_classes: int = 1000,
                **kwargs):
        super().__init__(
            PatchEmbedding(in_channels, patch_size, emb_size, img_size),
            TransformerEncoder(depth, emb_size=emb_size, **kwargs),
            ClassificationHead(emb_size, n_classes)
        )
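As a quick check before counting parameters (my own addition, not from the original post), a forward pass on our preprocessed image should produce one logit vector per image (the debug prints from PatchEmbedding and MultiHeadAttention will also appear):

model = ViT()
logits = model(x)      # x is the [1, 3, 224, 224] tensor from earlier
print(logits.shape)    # torch.Size([1, 1000])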

We can use torchsummary to count the parameters; the output is as follows:

model = ViT()
summary(model, input_size=[(3, 224, 224)], batch_size=1, device="cpu")
torch.Size([2, 197, 768]) torch.Size([197, 768])
1qkv's shape:  torch.Size([2, 197, 2304])
2qkv's shape:  torch.Size([3, 2, 8, 197, 96])
queries's shape:  torch.Size([2, 8, 197, 96])
keys's shape:  torch.Size([2, 8, 197, 96])
values's shape:  torch.Size([2, 8, 197, 96])
energy's shape:  torch.Size([2, 8, 197, 197])
scaling:  27.712812921102035
att1' shape:  torch.Size([2, 8, 197, 197])
att2' shape:  torch.Size([2, 8, 197, 197])
out1's shape:  torch.Size([2, 8, 197, 96])
out2's shape:  torch.Size([2, 197, 768])
out3's shape:  torch.Size([2, 197, 768])
(... the same 12 lines of debug output are printed again for each of the remaining 11 encoder blocks ...)
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [1, 768, 14, 14]         590,592
         Rearrange-2              [1, 196, 768]               0
    PatchEmbedding-3              [1, 197, 768]               0
         LayerNorm-4              [1, 197, 768]           1,536
            Linear-5             [1, 197, 2304]       1,771,776
            Linear-6             [1, 197, 2304]       1,771,776
           Dropout-7           [1, 8, 197, 197]               0
            Linear-8              [1, 197, 768]         590,592
MultiHeadAttention-9              [1, 197, 768]               0
          Dropout-10              [1, 197, 768]               0
      ResidualAdd-11              [1, 197, 768]               0
        LayerNorm-12              [1, 197, 768]           1,536
           Linear-13             [1, 197, 3072]       2,362,368
             GELU-14             [1, 197, 3072]               0
          Dropout-15             [1, 197, 3072]               0
           Linear-16              [1, 197, 768]       2,360,064
          Dropout-17              [1, 197, 768]               0
      ResidualAdd-18              [1, 197, 768]               0
        LayerNorm-19              [1, 197, 768]           1,536
           Linear-20             [1, 197, 2304]       1,771,776
           Linear-21             [1, 197, 2304]       1,771,776
          Dropout-22           [1, 8, 197, 197]               0
           Linear-23              [1, 197, 768]         590,592
MultiHeadAttention-24              [1, 197, 768]               0
          Dropout-25              [1, 197, 768]               0
      ResidualAdd-26              [1, 197, 768]               0
        LayerNorm-27              [1, 197, 768]           1,536
           Linear-28             [1, 197, 3072]       2,362,368
             GELU-29             [1, 197, 3072]               0
          Dropout-30             [1, 197, 3072]               0
           Linear-31              [1, 197, 768]       2,360,064
          Dropout-32              [1, 197, 768]               0
      ResidualAdd-33              [1, 197, 768]               0
        LayerNorm-34              [1, 197, 768]           1,536
           Linear-35             [1, 197, 2304]       1,771,776
           Linear-36             [1, 197, 2304]       1,771,776
          Dropout-37           [1, 8, 197, 197]               0
           Linear-38              [1, 197, 768]         590,592
MultiHeadAttention-39              [1, 197, 768]               0
          Dropout-40              [1, 197, 768]               0
      ResidualAdd-41              [1, 197, 768]               0
        LayerNorm-42              [1, 197, 768]           1,536
           Linear-43             [1, 197, 3072]       2,362,368
             GELU-44             [1, 197, 3072]               0
          Dropout-45             [1, 197, 3072]               0
           Linear-46              [1, 197, 768]       2,360,064
          Dropout-47              [1, 197, 768]               0
      ResidualAdd-48              [1, 197, 768]               0
        LayerNorm-49              [1, 197, 768]           1,536
           Linear-50             [1, 197, 2304]       1,771,776
           Linear-51             [1, 197, 2304]       1,771,776
          Dropout-52           [1, 8, 197, 197]               0
           Linear-53              [1, 197, 768]         590,592
MultiHeadAttention-54              [1, 197, 768]               0
          Dropout-55              [1, 197, 768]               0
      ResidualAdd-56              [1, 197, 768]               0
        LayerNorm-57              [1, 197, 768]           1,536
           Linear-58             [1, 197, 3072]       2,362,368
             GELU-59             [1, 197, 3072]               0
          Dropout-60             [1, 197, 3072]               0
           Linear-61              [1, 197, 768]       2,360,064
          Dropout-62              [1, 197, 768]               0
      ResidualAdd-63              [1, 197, 768]               0
        LayerNorm-64              [1, 197, 768]           1,536
           Linear-65             [1, 197, 2304]       1,771,776
           Linear-66             [1, 197, 2304]       1,771,776
          Dropout-67           [1, 8, 197, 197]               0
           Linear-68              [1, 197, 768]         590,592
MultiHeadAttention-69              [1, 197, 768]               0
          Dropout-70              [1, 197, 768]               0
      ResidualAdd-71              [1, 197, 768]               0
        LayerNorm-72              [1, 197, 768]           1,536
           Linear-73             [1, 197, 3072]       2,362,368
             GELU-74             [1, 197, 3072]               0
          Dropout-75             [1, 197, 3072]               0
           Linear-76              [1, 197, 768]       2,360,064
          Dropout-77              [1, 197, 768]               0
      ResidualAdd-78              [1, 197, 768]               0
        LayerNorm-79              [1, 197, 768]           1,536
           Linear-80             [1, 197, 2304]       1,771,776
           Linear-81             [1, 197, 2304]       1,771,776
          Dropout-82           [1, 8, 197, 197]               0
           Linear-83              [1, 197, 768]         590,592
MultiHeadAttention-84              [1, 197, 768]               0
          Dropout-85              [1, 197, 768]               0
      ResidualAdd-86              [1, 197, 768]               0
        LayerNorm-87              [1, 197, 768]           1,536
           Linear-88             [1, 197, 3072]       2,362,368
             GELU-89             [1, 197, 3072]               0
          Dropout-90             [1, 197, 3072]               0
           Linear-91              [1, 197, 768]       2,360,064
          Dropout-92              [1, 197, 768]               0
      ResidualAdd-93              [1, 197, 768]               0
        LayerNorm-94              [1, 197, 768]           1,536
           Linear-95             [1, 197, 2304]       1,771,776
           Linear-96             [1, 197, 2304]       1,771,776
          Dropout-97           [1, 8, 197, 197]               0
           Linear-98              [1, 197, 768]         590,592
MultiHeadAttention-99              [1, 197, 768]               0
         Dropout-100              [1, 197, 768]               0
     ResidualAdd-101              [1, 197, 768]               0
       LayerNorm-102              [1, 197, 768]           1,536
          Linear-103             [1, 197, 3072]       2,362,368
            GELU-104             [1, 197, 3072]               0
         Dropout-105             [1, 197, 3072]               0
          Linear-106              [1, 197, 768]       2,360,064
         Dropout-107              [1, 197, 768]               0
     ResidualAdd-108              [1, 197, 768]               0
       LayerNorm-109              [1, 197, 768]           1,536
          Linear-110             [1, 197, 2304]       1,771,776
          Linear-111             [1, 197, 2304]       1,771,776
         Dropout-112           [1, 8, 197, 197]               0
          Linear-113              [1, 197, 768]         590,592
MultiHeadAttention-114              [1, 197, 768]               0
         Dropout-115              [1, 197, 768]               0
     ResidualAdd-116              [1, 197, 768]               0
       LayerNorm-117              [1, 197, 768]           1,536
          Linear-118             [1, 197, 3072]       2,362,368
            GELU-119             [1, 197, 3072]               0
         Dropout-120             [1, 197, 3072]               0
          Linear-121              [1, 197, 768]       2,360,064
         Dropout-122              [1, 197, 768]               0
     ResidualAdd-123              [1, 197, 768]               0
       LayerNorm-124              [1, 197, 768]           1,536
          Linear-125             [1, 197, 2304]       1,771,776
          Linear-126             [1, 197, 2304]       1,771,776
         Dropout-127           [1, 8, 197, 197]               0
          Linear-128              [1, 197, 768]         590,592
MultiHeadAttention-129              [1, 197, 768]               0
         Dropout-130              [1, 197, 768]               0
     ResidualAdd-131              [1, 197, 768]               0
       LayerNorm-132              [1, 197, 768]           1,536
          Linear-133             [1, 197, 3072]       2,362,368
            GELU-134             [1, 197, 3072]               0
         Dropout-135             [1, 197, 3072]               0
          Linear-136              [1, 197, 768]       2,360,064
         Dropout-137              [1, 197, 768]               0
     ResidualAdd-138              [1, 197, 768]               0
       LayerNorm-139              [1, 197, 768]           1,536
          Linear-140             [1, 197, 2304]       1,771,776
          Linear-141             [1, 197, 2304]       1,771,776
         Dropout-142           [1, 8, 197, 197]               0
          Linear-143              [1, 197, 768]         590,592
MultiHeadAttention-144              [1, 197, 768]               0
         Dropout-145              [1, 197, 768]               0
     ResidualAdd-146              [1, 197, 768]               0
       LayerNorm-147              [1, 197, 768]           1,536
          Linear-148             [1, 197, 3072]       2,362,368
            GELU-149             [1, 197, 3072]               0
         Dropout-150             [1, 197, 3072]               0
          Linear-151              [1, 197, 768]       2,360,064
         Dropout-152              [1, 197, 768]               0
     ResidualAdd-153              [1, 197, 768]               0
       LayerNorm-154              [1, 197, 768]           1,536
          Linear-155             [1, 197, 2304]       1,771,776
          Linear-156             [1, 197, 2304]       1,771,776
         Dropout-157           [1, 8, 197, 197]               0
          Linear-158              [1, 197, 768]         590,592
MultiHeadAttention-159              [1, 197, 768]               0
         Dropout-160              [1, 197, 768]               0
     ResidualAdd-161              [1, 197, 768]               0
       LayerNorm-162              [1, 197, 768]           1,536
          Linear-163             [1, 197, 3072]       2,362,368
            GELU-164             [1, 197, 3072]               0
         Dropout-165             [1, 197, 3072]               0
          Linear-166              [1, 197, 768]       2,360,064
         Dropout-167              [1, 197, 768]               0
     ResidualAdd-168              [1, 197, 768]               0
       LayerNorm-169              [1, 197, 768]           1,536
          Linear-170             [1, 197, 2304]       1,771,776
          Linear-171             [1, 197, 2304]       1,771,776
         Dropout-172           [1, 8, 197, 197]               0
          Linear-173              [1, 197, 768]         590,592
MultiHeadAttention-174              [1, 197, 768]               0
         Dropout-175              [1, 197, 768]               0
     ResidualAdd-176              [1, 197, 768]               0
       LayerNorm-177              [1, 197, 768]           1,536
          Linear-178             [1, 197, 3072]       2,362,368
            GELU-179             [1, 197, 3072]               0
         Dropout-180             [1, 197, 3072]               0
          Linear-181              [1, 197, 768]       2,360,064
         Dropout-182              [1, 197, 768]               0
     ResidualAdd-183              [1, 197, 768]               0
          Reduce-184                   [1, 768]               0
       LayerNorm-185                   [1, 768]           1,536
          Linear-186                  [1, 1000]         769,000
================================================================
Total params: 107,676,904
Trainable params: 107,676,904
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 405.89
Params size (MB): 410.75
Estimated Total Size (MB): 817.22
----------------------------------------------------------------
  • Compared with other ViT implementations, this parameter count is in the same ballpark.
  • The original code repository is at https://github.com/FrancescoSaverioZuppichini/ViT.

References

  • https://towardsdatascience.com/implementing-visualttransformer-in-pytorch-184f9f16f632
  • https://arxiv.org/abs/2010.11929
  • https://www.jianshu.com/p/06a40338dc7c
  • https://www.jianshu.com/p/d4bc4f540c62