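This article is largely a translation of Implementing Vision Transformer (ViT) in PyTorch, with some of my own annotations added. Readers who are more comfortable with English may prefer to read the original directly.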
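The image is first split into patches, which are flattened and linearly projected. A special cls token is then prepended to the sequence, and positional encoding is added. The whole sequence is fed into the Transformer Encoder, and the output feature of the cls token is sent to the MLP Head for classification. That is the overall flow.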
Patches Embeddings
Transformer Encoder Block
Transformer Encoder
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch import nn
from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from torchsummary import summary
img = Image.open('./test.jpeg')
fig = plt.figure()
# resize to ImageNet size
transform = Compose([Resize((224, 224)), ToTensor()])
x = transform(img)
x = x.unsqueeze(0)  # add the batch dimension
torch.Size([1, 3, 224, 224])
patch_size = 16 # 16 pixels
patches = rearrange(x, 'b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=patch_size, s2=patch_size)
torch.Size([1, 196, 768])
Splitting the 224x224 image into 16x16 patches gives 224x224/16/16 = 196 patches, and each flattened patch holds 16x16x3 = 768 values, so the result has shape [1, 196, 768].
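The arithmetic above can be checked with a few lines of plain Python. This is a minimal sketch; the helper name `patch_grid` is hypothetical and not part of the original article, and it assumes a square image whose side is an exact multiple of the patch size.

```python
def patch_grid(img_size: int, patch_size: int, channels: int = 3):
    """Return (number of patches, values per flattened patch).

    Hypothetical helper for illustration only.
    """
    if img_size % patch_size != 0:
        raise ValueError("img_size must be divisible by patch_size")
    per_side = img_size // patch_size                # patches along one side
    num_patches = per_side * per_side                # patches in the whole image
    patch_dim = patch_size * patch_size * channels   # values in one flattened patch
    return num_patches, patch_dim


num_patches, patch_dim = patch_grid(224, 16)
print(num_patches, patch_dim)  # 196 768
```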
class PatchEmbedding(nn.Module):
    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768):
        super().__init__()
        self.patch_size = patch_size
        self.projection = nn.Sequential(
            # split the image into 16x16 patches and flatten them
            Rearrange('b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=patch_size, s2=patch_size),
            # the hidden size is also set to 768 here; it is configurable
            nn.Linear(patch_size * patch_size * in_channels, emb_size)
        )

    def forward(self, x: Tensor) -> Tensor:
        x = self.projection(x)
        return x

PatchEmbedding()(x).shape
torch.Size([1, 196, 768])
class PatchEmbedding(nn.Module):
    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768):
        super().__init__()
        self.patch_size = patch_size
        self.projection = nn.Sequential(
            # use a conv layer instead of a linear layer -> performance gain
            nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
            # flatten the patches produced by the convolution
            Rearrange('b e h w -> b (h w) e'),
        )

    def forward(self, x: Tensor) -> Tensor:
        x = self.projection(x)
        return x

PatchEmbedding()(x).shape
torch.Size([1, 196, 768])
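A convolution whose kernel size and stride both equal the patch size applies the same linear map to every patch as the `nn.Linear` version does. The sketch below checks this numerically on a small random tensor. The toy sizes and the weight-copying step are assumptions made for illustration, and the patches are flattened in (c, s1, s2) order to match the layout of the convolution weights. Whether the convolution is actually faster depends on hardware and library version, which this sketch does not measure.

```python
import torch
from torch import nn

torch.manual_seed(0)
patch, c, e = 4, 3, 8                  # toy sizes, not the 16 / 3 / 768 used above
x = torch.randn(2, c, 8, 8)
b, _, H, W = x.shape
h, w = H // patch, W // patch

# Path 1: strided convolution, then flatten the spatial grid into a sequence.
conv = nn.Conv2d(c, e, kernel_size=patch, stride=patch)
out_conv = conv(x).flatten(2).transpose(1, 2)            # [b, h*w, e]

# Path 2: cut out the patches by hand, flatten each one, apply a linear layer.
patches = (x.reshape(b, c, h, patch, w, patch)
             .permute(0, 2, 4, 1, 3, 5)                  # [b, h, w, c, s1, s2]
             .reshape(b, h * w, c * patch * patch))
linear = nn.Linear(c * patch * patch, e)
with torch.no_grad():
    # reuse the convolution weights so both paths compute the same map
    linear.weight.copy_(conv.weight.reshape(e, -1))
    linear.bias.copy_(conv.bias)
out_lin = linear(patches)

print(torch.allclose(out_conv, out_lin, atol=1e-4))
```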
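The next step is to add the cls token and the position embedding to the projected patches. The cls token is a randomly initialized torch Parameter; in the forward method it is copied b times (b being the batch size) and prepended to the projected patches with torch.cat.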
class PatchEmbedding(nn.Module):
    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Sequential(
            # use a conv layer instead of a linear layer -> performance gain
            nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
            Rearrange('b e (h) (w) -> b (h w) e'),
        )
        # create a vector of size emb_size to serve as the cls token
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))

    def forward(self, x: Tensor) -> Tensor:
        b, _, _, _ = x.shape  # keep the batch size for later
        x = self.proj(x)  # apply the convolution
        # repeat the cls token b times
        cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b)
        print(cls_tokens.shape)
        print(x.shape)
        # prepend the cls token to the input along dimension 1
        x = torch.cat([cls_tokens, x], dim=1)
        return x

PatchEmbedding()(x).shape
torch.Size([1, 1, 768])
torch.Size([1, 196, 768])
torch.Size([1, 197, 768])
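The `repeat` call copies the single cls token once per batch element. As a sketch using only core PyTorch (an alternative, not what the article uses), `Tensor.expand` yields the same values as a broadcast view, and `torch.cat` then builds the final sequence. The random tensors below merely stand in for the learnable parameter and the projected patches.

```python
import torch

torch.manual_seed(0)
b, n, e = 2, 196, 768
cls_token = torch.randn(1, 1, e)       # stands in for the learnable nn.Parameter
patch_tokens = torch.randn(b, n, e)    # stands in for the projected patches

cls_tokens = cls_token.expand(b, -1, -1)                 # [b, 1, e], no data copied yet
tokens = torch.cat([cls_tokens, patch_tokens], dim=1)    # [b, n + 1, e]

print(tokens.shape)  # torch.Size([2, 197, 768])
```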
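So far, the model knows nothing about the original position of the patches in the image. We need to pass this spatial information to the model. There are many ways to do this; in ViT, we let the model learn it. The position embedding is a tensor of shape [N_PATCHES + 1 (token), EMBED_SIZE] that is added directly to the projected patches.
We first define the position embedding, then add it in the forward function to the linearly projected patch vectors (including the cls token).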
class PatchEmbedding(nn.Module):
    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768, img_size: int = 224):
        super().__init__()
        self.patch_size = patch_size
        self.projection = nn.Sequential(
            # use a conv layer instead of a linear layer -> performance gain
            nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
            Rearrange('b e (h) (w) -> b (h w) e'),
        )
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        # position embedding: (img_size // patch_size)**2 + 1 (cls token) position vectors in total
        self.positions = nn.Parameter(torch.randn((img_size // patch_size)**2 + 1, emb_size))

    def forward(self, x: Tensor) -> Tensor:
        b, _, _, _ = x.shape
        x = self.projection(x)
        cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b)
        # prepend the cls token to the input along dimension 1
        x = torch.cat([cls_tokens, x], dim=1)
        # add the position embedding
        print(x.shape, self.positions.shape)
        x += self.positions
        return x

PatchEmbedding()(x).shape
torch.Size([1, 197, 768]) torch.Size([197, 768])
torch.Size([1, 197, 768])
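Adding a [197, 768] tensor to a [b, 197, 768] tensor relies on broadcasting: the same position vectors are added to every image in the batch. The sketch below shows this with toy sizes (the sizes are assumptions for illustration). It also shows that the number of position vectors is tied to the sequence length, so an input with a different number of tokens fails.

```python
import torch

torch.manual_seed(0)
b, n_tokens, e = 3, 5, 4
x = torch.zeros(b, n_tokens, e)          # zeros make the added values easy to inspect
positions = torch.randn(n_tokens, e)     # stands in for the learnable positions

out = x + positions                      # broadcast over the batch dimension

# A sequence with one extra token no longer matches the position table.
mismatch_raises = False
try:
    torch.zeros(b, n_tokens + 1, e) + positions
except RuntimeError:
    mismatch_raises = True

print(out.shape, mismatch_raises)
```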
Now let's implement the Transformer Encoder Block. ViT uses only the Encoder part of the Transformer; its overall architecture is as follows:
class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size: int = 512, num_heads: int = 8, dropout: float = 0):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        # split keys, queries and values into num_heads
        queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads)
        keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads)
        values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
        # sum up over the last axis
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)  # batch, num_heads, query_len, key_len
        if mask is not None:
            fill_value = torch.finfo(torch.float32).min
            # masked_fill is not in-place, so the result is assigned back
            energy = energy.masked_fill(~mask, fill_value)
        # note: as in the original article, the scaling uses emb_size and is applied after the softmax
        scaling = self.emb_size ** (1/2)
        att = F.softmax(energy, dim=-1) / scaling
        att = self.att_drop(att)
        # sum up over the third axis
        out = torch.einsum('bhal, bhlv -> bhav', att, values)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.projection(out)
        return out
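The attention module consists of linear layers for the queries, keys and values, a dropout layer, and the final linear projection layer. For more detail on this part, see The Illustrated Transformer. The main idea is to use the product of the queries and keys to determine how much each element of the sequence should attend to the others; the result is then used to weight the values.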
queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads)
keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads)
values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
We then concatenate the outputs of all the heads to obtain the final output of Multi-Head Attention.
class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size: int = 768, num_heads: int = 8, dropout: float = 0):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        # compute queries, keys and values at once with a single matrix
        self.qkv = nn.Linear(emb_size, emb_size * 3)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        # split queries, keys and values into num_heads
        print("1qkv's shape: ", self.qkv(x).shape)  # queries, keys and values computed in one pass
        qkv = rearrange(self.qkv(x), "b n (h d qkv) -> (qkv) b h n d", h=self.num_heads, qkv=3)  # split across num_heads heads
        print("2qkv's shape: ", qkv.shape)
        queries, keys, values = qkv[0], qkv[1], qkv[2]
        print("queries's shape: ", queries.shape)
        print("keys's shape: ", keys.shape)
        print("values's shape: ", values.shape)
        # sum up over the last axis
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)  # batch, num_heads, query_len, key_len
        print("energy's shape: ", energy.shape)
        if mask is not None:
            fill_value = torch.finfo(torch.float32).min
            # masked_fill is not in-place, so the result is assigned back
            energy = energy.masked_fill(~mask, fill_value)
        # note: as in the original article, the scaling uses emb_size and is applied after the softmax
        scaling = self.emb_size ** (1/2)
        print("scaling: ", scaling)
        att = F.softmax(energy, dim=-1) / scaling
        print("att1' shape: ", att.shape)
        att = self.att_drop(att)
        print("att2' shape: ", att.shape)
        # sum up over the third axis
        out = torch.einsum('bhal, bhlv -> bhav', att, values)
        print("out1's shape: ", out.shape)
        out = rearrange(out, "b h n d -> b n (h d)")
        print("out2's shape: ", out.shape)
        out = self.projection(out)
        print("out3's shape: ", out.shape)
        return out

patches_embedded = PatchEmbedding()(x)
print("patches_embedding's shape: ", patches_embedded.shape)
MultiHeadAttention()(patches_embedded).shape
torch.Size([1, 197, 768]) torch.Size([197, 768])
patches_embedding's shape: torch.Size([1, 197, 768])
1qkv's shape: torch.Size([1, 197, 2304])
2qkv's shape: torch.Size([3, 1, 8, 197, 96])
queries's shape: torch.Size([1, 8, 197, 96])
keys's shape: torch.Size([1, 8, 197, 96])
values's shape: torch.Size([1, 8, 197, 96])
energy's shape: torch.Size([1, 8, 197, 197])
scaling: 27.712812921102035
att1' shape: torch.Size([1, 8, 197, 197])
att2' shape: torch.Size([1, 8, 197, 197])
out1's shape: torch.Size([1, 8, 197, 96])
out2's shape: torch.Size([1, 197, 768])
out3's shape: torch.Size([1, 197, 768])
torch.Size([1, 197, 768])
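Note that the code above divides by the square root of emb_size after the softmax. The formulation in "Attention Is All You Need" instead divides the scores by the square root of the per-head dimension before the softmax, so that each row of attention weights sums to 1. For comparison, here is a minimal sketch of that standard variant. The function name `reference_attention` is hypothetical, and this is not the article's implementation.

```python
import math

import torch
import torch.nn.functional as F


def reference_attention(q, k, v):
    """Scaled dot-product attention on [batch, heads, tokens, head_dim] tensors.

    Hypothetical helper for illustration only.
    """
    head_dim = q.shape[-1]
    # scale the raw scores by sqrt(head_dim) BEFORE the softmax
    scores = torch.einsum('bhqd,bhkd->bhqk', q, k) / math.sqrt(head_dim)
    weights = F.softmax(scores, dim=-1)            # each row sums to 1
    out = torch.einsum('bhqk,bhkd->bhqd', weights, v)
    return out, weights


torch.manual_seed(0)
q = torch.randn(1, 8, 197, 96)
k = torch.randn(1, 8, 197, 96)
v = torch.randn(1, 8, 197, 96)
out, weights = reference_attention(q, k, v)
print(out.shape)  # torch.Size([1, 8, 197, 96])
```

With 8 heads and emb_size 768, the per-head dimension is 96, so the standard scale would be sqrt(96) rather than the sqrt(768) printed above.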
class ResidualAdd(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        res = x
        x = self.fn(x, **kwargs)
        x += res
        return x
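As a quick sanity check of the residual wrapper, wrapping a function `fn` should return `fn(x) + x`. The sketch below restates the class so that it runs on its own, uses an out-of-place add (a small variation on the in-place `x += res`), and takes a small `nn.Linear` as a stand-in for the wrapped sub-layer, which is an assumption for illustration.

```python
import torch
from torch import nn


class ResidualAdd(nn.Module):
    def __init__(self, fn):
        super().__init__()          # required before assigning sub-modules
        self.fn = fn

    def forward(self, x, **kwargs):
        # out-of-place add: the caller's tensor is left untouched
        return self.fn(x, **kwargs) + x


torch.manual_seed(0)
layer = nn.Linear(4, 4)             # stand-in for attention or the feed-forward block
block = ResidualAdd(layer)
x = torch.randn(2, 3, 4)
y = block(x)

print(torch.allclose(y, layer(x) + x))
```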
class FeedForwardBlock(nn.Sequential):
    def __init__(self, emb_size: int, expansion: int = 4, drop_p: float = 0.):
        super().__init__(
            nn.Linear(emb_size, expansion * emb_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(expansion * emb_size, emb_size),
        )
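- The author notes that, for whatever reason, people rarely subclass nn.Sequential directly, even though doing so avoids having to write the forward method.
- Translator's note: indeed, that is a new trick learned.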
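To make the point about subclassing nn.Sequential concrete, the sketch below defines a small class without a forward method and checks that calling it is the same as applying its children in order. The class name `TinyMLP` and its sizes are hypothetical and chosen only for illustration.

```python
import torch
from torch import nn


class TinyMLP(nn.Sequential):          # hypothetical example class
    def __init__(self, dim: int, expansion: int = 4):
        # the layers are passed to nn.Sequential, which provides forward()
        super().__init__(
            nn.Linear(dim, expansion * dim),
            nn.GELU(),
            nn.Linear(expansion * dim, dim),
        )


torch.manual_seed(0)
mlp = TinyMLP(8)
x = torch.randn(2, 5, 8)
y = mlp(x)                             # forward is inherited from nn.Sequential
manual = mlp[2](mlp[1](mlp[0](x)))     # the same computation, written out by hand

print(torch.allclose(y, manual))
```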
Finally, we can create a complete Transformer Encoder Block.
Using the ResidualAdd class defined earlier, the Transformer Encoder Block can be defined quite elegantly:
class TransformerEncoderBlock(nn.Sequential):
    def __init__(self,
                 emb_size: int = 768,
                 drop_p: float = 0.,
                 forward_expansion: int = 4,
                 forward_drop_p: float = 0.,
                 **kwargs):
        super().__init__(
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                MultiHeadAttention(emb_size, **kwargs),
                nn.Dropout(drop_p)
            )),
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                FeedForwardBlock(
                    emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
                nn.Dropout(drop_p)
            ))
        )
patches_embedded = PatchEmbedding()(x)
TransformerEncoderBlock()(patches_embedded).shape
torch.Size([1, 197, 768]) torch.Size([197, 768])
1qkv's shape: torch.Size([1, 197, 2304])
2qkv's shape: torch.Size([3, 1, 8, 197, 96])
queries's shape: torch.Size([1, 8, 197, 96])
keys's shape: torch.Size([1, 8, 197, 96])
values's shape: torch.Size([1, 8, 197, 96])
energy's shape: torch.Size([1, 8, 197, 197])
scaling: 27.712812921102035
att1' shape: torch.Size([1, 8, 197, 197])
att2' shape: torch.Size([1, 8, 197, 197])
out1's shape: torch.Size([1, 8, 197, 96])
out2's shape: torch.Size([1, 197, 768])
out3's shape: torch.Size([1, 197, 768])
torch.Size([1, 197, 768])
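For comparison, PyTorch ships a built-in `nn.TransformerEncoderLayer`. With `norm_first=True` it applies LayerNorm before the attention and feed-forward sub-layers, which is the same pre-norm layout as the block above. The sketch below only checks that the token shape is preserved. It is not claimed to be numerically identical to the article's implementation, and it assumes a PyTorch version that supports the `batch_first` and `norm_first` arguments.

```python
import torch
from torch import nn

layer = nn.TransformerEncoderLayer(
    d_model=768,
    nhead=8,
    dim_feedforward=4 * 768,   # same expansion factor of 4 as FeedForwardBlock
    dropout=0.0,
    activation="gelu",
    batch_first=True,          # inputs are [batch, tokens, emb]
    norm_first=True,           # LayerNorm before each sub-layer (pre-norm)
)
layer.eval()

tokens = torch.randn(1, 197, 768)   # stands in for the embedded patches
with torch.no_grad():
    out = layer(tokens)

print(out.shape)  # torch.Size([1, 197, 768])
```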
class TransformerEncoder(nn.Sequential):
    def __init__(self, depth: int = 12, **kwargs):
        super().__init__(*[TransformerEncoderBlock(**kwargs) for _ in range(depth)])
patches_embedded = PatchEmbedding()(x)
TransformerEncoder()(patches_embedded).shape
torch.Size([1, 197, 768]) torch.Size([197, 768])
1qkv's shape:  torch.Size([1, 197, 2304])
2qkv's shape:  torch.Size([3, 1, 8, 197, 96])
queries's shape:  torch.Size([1, 8, 197, 96])
keys's shape:  torch.Size([1, 8, 197, 96])
values's shape:  torch.Size([1, 8, 197, 96])
energy's shape:  torch.Size([1, 8, 197, 197])
scaling:  27.712812921102035
att1' shape:  torch.Size([1, 8, 197, 197])
att2' shape:  torch.Size([1, 8, 197, 197])
out1's shape:  torch.Size([1, 8, 197, 96])
out2's shape:  torch.Size([1, 197, 768])
out3's shape:  torch.Size([1, 197, 768])
... (the 13 lines from "1qkv's shape" through "out3's shape" are printed 12 times in total, once per encoder block)
torch.Size([1, 197, 768])
# summary prints the network structure and the parameter counts
from torchsummary import summary
from torchvision.models import resnet18
model = resnet18()
summary(model, input_size=[(3, 256, 256)], batch_size=2, device="cpu")
----------------------------------------------------------------
Layer (type)          Output Shape        Param #
================================================================
Conv2d-1              [2, 64, 128, 128]   9,408
BatchNorm2d-2         [2, 64, 128, 128]   128
ReLU-3                [2, 64, 128, 128]   0
MaxPool2d-4           [2, 64, 64, 64]     0
Conv2d-5              [2, 64, 64, 64]     36,864
BatchNorm2d-6         [2, 64, 64, 64]     128
ReLU-7                [2, 64, 64, 64]     0
Conv2d-8              [2, 64, 64, 64]     36,864
BatchNorm2d-9         [2, 64, 64, 64]     128
ReLU-10               [2, 64, 64, 64]     0
BasicBlock-11         [2, 64, 64, 64]     0
Conv2d-12             [2, 64, 64, 64]     36,864
BatchNorm2d-13        [2, 64, 64, 64]     128
ReLU-14               [2, 64, 64, 64]     0
Conv2d-15             [2, 64, 64, 64]     36,864
BatchNorm2d-16        [2, 64, 64, 64]     128
ReLU-17               [2, 64, 64, 64]     0
BasicBlock-18         [2, 64, 64, 64]     0
Conv2d-19             [2, 128, 32, 32]    73,728
BatchNorm2d-20        [2, 128, 32, 32]    256
ReLU-21               [2, 128, 32, 32]    0
Conv2d-22             [2, 128, 32, 32]    147,456
BatchNorm2d-23        [2, 128, 32, 32]    256
Conv2d-24             [2, 128, 32, 32]    8,192
BatchNorm2d-25        [2, 128, 32, 32]    256
ReLU-26               [2, 128, 32, 32]    0
BasicBlock-27         [2, 128, 32, 32]    0
Conv2d-28             [2, 128, 32, 32]    147,456
BatchNorm2d-29        [2, 128, 32, 32]    256
ReLU-30               [2, 128, 32, 32]    0
Conv2d-31             [2, 128, 32, 32]    147,456
BatchNorm2d-32        [2, 128, 32, 32]    256
ReLU-33               [2, 128, 32, 32]    0
BasicBlock-34         [2, 128, 32, 32]    0
Conv2d-35             [2, 256, 16, 16]    294,912
BatchNorm2d-36        [2, 256, 16, 16]    512
ReLU-37               [2, 256, 16, 16]    0
Conv2d-38             [2, 256, 16, 16]    589,824
BatchNorm2d-39        [2, 256, 16, 16]    512
Conv2d-40             [2, 256, 16, 16]    32,768
BatchNorm2d-41        [2, 256, 16, 16]    512
ReLU-42               [2, 256, 16, 16]    0
BasicBlock-43         [2, 256, 16, 16]    0
Conv2d-44             [2, 256, 16, 16]    589,824
BatchNorm2d-45        [2, 256, 16, 16]    512
ReLU-46               [2, 256, 16, 16]    0
Conv2d-47             [2, 256, 16, 16]    589,824
BatchNorm2d-48        [2, 256, 16, 16]    512
ReLU-49               [2, 256, 16, 16]    0
BasicBlock-50         [2, 256, 16, 16]    0
Conv2d-51             [2, 512, 8, 8]      1,179,648
BatchNorm2d-52        [2, 512, 8, 8]      1,024
ReLU-53               [2, 512, 8, 8]      0
Conv2d-54             [2, 512, 8, 8]      2,359,296
BatchNorm2d-55        [2, 512, 8, 8]      1,024
Conv2d-56             [2, 512, 8, 8]      131,072
BatchNorm2d-57        [2, 512, 8, 8]      1,024
ReLU-58               [2, 512, 8, 8]      0
BasicBlock-59         [2, 512, 8, 8]      0
Conv2d-60             [2, 512, 8, 8]      2,359,296
BatchNorm2d-61        [2, 512, 8, 8]      1,024
ReLU-62               [2, 512, 8, 8]      0
Conv2d-63             [2, 512, 8, 8]      2,359,296
BatchNorm2d-64        [2, 512, 8, 8]      1,024
ReLU-65               [2, 512, 8, 8]      0
BasicBlock-66         [2, 512, 8, 8]      0
AdaptiveAvgPool2d-67  [2, 512, 1, 1]      0
Linear-68             [2, 1000]           513,000
================================================================
Total params: 11,689,512
Trainable params: 11,689,512
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.50
Forward/backward pass size (MB): 164.02
Params size (MB): 44.59
Estimated Total Size (MB): 210.12
----------------------------------------------------------------
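If torchsummary is not available, the total parameter count can also be obtained with plain PyTorch. This is a minimal sketch; the helper name `count_parameters` is hypothetical, and it is shown on a single linear layer with the same shape as the final layer of resnet18 rather than on the full model.

```python
import torch
from torch import nn


def count_parameters(model: nn.Module, trainable_only: bool = True) -> int:
    """Sum the element counts of the model's parameters (hypothetical helper)."""
    return sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not trainable_only
    )


head = nn.Linear(512, 1000)        # 512 * 1000 weights + 1000 biases
print(count_parameters(head))      # 513000
```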
The part from the MLP Head upward is what I need to use.
class ClassificationHead(nn.Sequential):
    def __init__(self, emb_size: int = 768, n_classes: int = 1000):
        super().__init__(
            Reduce('b n e -> b e', reduction='mean'),
            nn.LayerNorm(emb_size),
            nn.Linear(emb_size, n_classes))
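The head above averages over all tokens (`reduction='mean'`), whereas the overview at the start describes classifying from the output of the cls token. A minimal sketch of that second variant follows, using only core PyTorch. The class name `ClsTokenHead` is hypothetical and this is not the article's code. It assumes the cls token sits at index 0 of the sequence, as in the PatchEmbedding above.

```python
import torch
from torch import nn


class ClsTokenHead(nn.Module):
    """Hypothetical variant: classify from the cls token only."""

    def __init__(self, emb_size: int = 768, n_classes: int = 1000):
        super().__init__()
        self.norm = nn.LayerNorm(emb_size)
        self.fc = nn.Linear(emb_size, n_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        cls = x[:, 0]                      # [b, e]: the first token is the cls token
        return self.fc(self.norm(cls))


torch.manual_seed(0)
head = ClsTokenHead(emb_size=16, n_classes=10)   # toy sizes for a quick check
tokens = torch.randn(2, 197, 16)
logits = head(tokens)
print(logits.shape)  # torch.Size([2, 10])
```

Which of the two pooling choices works better is an empirical question that this sketch does not address.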
class ViT(nn.Sequential):
    def __init__(self,
                 in_channels: int = 3,
                 patch_size: int = 16,
                 emb_size: int = 768,
                 img_size: int = 224,
                 depth: int = 12,
                 n_classes: int = 1000,
                 **kwargs):
        super().__init__(
            PatchEmbedding(in_channels, patch_size, emb_size, img_size),
            TransformerEncoder(depth, emb_size=emb_size, **kwargs),
            ClassificationHead(emb_size, n_classes)
        )
model = ViT()
summary(model, input_size=[(3, 224, 224)], batch_size=1, device="cpu")
torch.Size([2, 197, 768]) torch.Size([197, 768])
1qkv's shape:  torch.Size([2, 197, 2304])
2qkv's shape:  torch.Size([3, 2, 8, 197, 96])
queries's shape:  torch.Size([2, 8, 197, 96])
keys's shape:  torch.Size([2, 8, 197, 96])
values's shape:  torch.Size([2, 8, 197, 96])
energy's shape:  torch.Size([2, 8, 197, 197])
scaling:  27.712812921102035
att1' shape:  torch.Size([2, 8, 197, 197])
att2' shape:  torch.Size([2, 8, 197, 197])
out1's shape:  torch.Size([2, 8, 197, 96])
out2's shape:  torch.Size([2, 197, 768])
out3's shape:  torch.Size([2, 197, 768])
... (the 13 lines from "1qkv's shape" through "out3's shape" are printed 12 times in total, once per encoder block)
----------------------------------------------------------------
Layer (type)            Output Shape        Param #
================================================================
Conv2d-1                [1, 768, 14, 14]    590,592
Rearrange-2             [1, 196, 768]       0
PatchEmbedding-3        [1, 197, 768]       0
LayerNorm-4             [1, 197, 768]       1,536
Linear-5                [1, 197, 2304]      1,771,776
Linear-6                [1, 197, 2304]      1,771,776
Dropout-7               [1, 8, 197, 197]    0
Linear-8                [1, 197, 768]       590,592
MultiHeadAttention-9    [1, 197, 768]       0
Dropout-10              [1, 197, 768]       0
ResidualAdd-11          [1, 197, 768]       0
LayerNorm-12            [1, 197, 768]       1,536
Linear-13               [1, 197, 3072]      2,362,368
GELU-14                 [1, 197, 3072]      0
Dropout-15              [1, 197, 3072]      0
Linear-16               [1, 197, 768]       2,360,064
Dropout-17              [1, 197, 768]       0
ResidualAdd-18          [1, 197, 768]       0
... (rows 19 to 183 repeat the 15-row pattern of rows 4 to 18, with identical shapes and parameter counts, for the remaining 11 encoder blocks)
Reduce-184              [1, 768]            0
LayerNorm-185           [1, 768]            1,536
Linear-186              [1, 1000]           769,000
================================================================
Total params: 107,676,904
Trainable params: 107,676,904
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 405.89
Params size (MB): 410.75
Estimated Total Size (MB): 817.22
----------------------------------------------------------------
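The per-layer parameter counts in the summary can be reproduced by hand: a linear layer with bias has in * out + out parameters, a LayerNorm has 2 * dim, and a convolution with bias has in_channels * k * k * out_channels + out_channels. A small sketch in plain Python (the helper names are hypothetical):

```python
def linear_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


def conv2d_params(c_in: int, c_out: int, k: int) -> int:
    """Weights plus biases of a square k x k convolution."""
    return c_in * k * k * c_out + c_out


def layernorm_params(dim: int) -> int:
    """One scale and one shift value per feature."""
    return 2 * dim


emb = 768
print(conv2d_params(3, emb, 16))      # 590592   patch projection (Conv2d-1)
print(layernorm_params(emb))          # 1536     each LayerNorm
print(linear_params(emb, 3 * emb))    # 1771776  fused qkv layer
print(linear_params(emb, emb))        # 590592   attention output projection
print(linear_params(emb, 4 * emb))    # 2362368  first feed-forward layer
print(linear_params(4 * emb, emb))    # 2360064  second feed-forward layer
print(linear_params(emb, 1000))       # 769000   classification layer (Linear-186)
```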