赞
踩
因为Google Research官方的Vision Transformer源码是tensorflow版本,而笔者平时多用pytorch,所以在github上找了作者rwightman版本的代码:rwightman/pytorch-image-models/timm/models/vision_transformer.py
Vision Transformer介绍博客:论文阅读笔记:Vision Transformer
下面的代码介绍以vit_base_patch16_224
(ViT-B/16:patch_size=16, img_size=224)为例。
原文中模型由三个模块组成:
· Linear Projection of Flattened Patches
· Transformer Encoder
· MLP Head
对应代码中的三个模块:
· patch embedding layer
· Block
· Representation layer + Classifier head
如图,Linear Projection of Flattened Patches的实现的通过一个kernel_size=stride=16
的卷积加上一个flatten实现的。他的功能是将
224
×
224
×
3
224×224×3
224×224×3 的的2D Image转换为
196
×
768
196×768
196×768 的Patch Embedding。具体代码及注释如下:
class PatchEmbed(nn.Module): """ 2D Image to Patch Embedding """ def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768, norm_layer=None): super().__init__() ''' image_size = (224,224) patch_size = (16,16) gird_size = (224/16,224/16)=(14,14) num_patches = 14 * 14 = 196 ''' img_size = (img_size, img_size) patch_size = (patch_size, patch_size) self.img_size = img_size self.patch_size = patch_size self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) self.num_patches = self.grid_size[0] * self.grid_size[1] ''' 使用大小为16,stride为16的卷积核实现embeding, 输出14*14大小,通道为768(768 = 16*16*3,相当于将每个patch部分转换为1维向量)的patch ''' self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size) ''' 如果norm_layer为true则使用layerNorm,这里作者没有使用, 所以self.norm = nn.Identity(),对输入不做任何改变直接输出 ''' self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() def forward(self, x): B, C, H, W = x.shape assert H == self.img_size[0] and W == self.img_size[1], \ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." ''' self.proj(x):[B,3,224,224]->[B,768,14,14] flatten(2):[B,768,14,14]->[B,768,14*14]=[B,768,196] transpose(1, 2):[B,768,196]->[B,196,768] self.norm(x)不对输入做处理直接输出 ''' x = self.proj(x).flat1ten(2).transpose(1, 2) x = self.norm(x) return x
Transformer Encoder由Attention、MLP和DropPath代码组成,其结构图如下:
关于 Multi-Head Attention 的结构图和详细介绍可查看博文,论文阅读笔记:Attention Is All You Need。
Attention具体代码及注释如下:
class Attention(nn.Module): def __init__(self, dim, # 输入token的dim 768 num_heads=8, qkv_bias=False, qk_scale=None, attn_drop_ratio=0., proj_drop_ratio=0.): super(Attention, self).__init__() ''' num_heads = 12 head_dim = 768 // 12 = 64 (Attention is all you need论文中提到的dk=dv=dmodel/h) scale = 64 ^ -0.5 = 1/8(Attention is all you need论文中Scaled Dot-Product Attention提到的公式Attention(Q,K,V)中的根号dk分之一) qkv:将输入线性映射到q,k,v proj:Attention is all you need论文中Multi-Head Attention最后的融合矩阵 Wo,使用 Linear 的实现 ''' self.num_heads = num_heads head_dim = dim // num_heads self.scale = qk_scale or head_dim ** -0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop_ratio) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop_ratio) def forward(self, x): ''' B = batch_size N = 197 C = 768 ''' B, N, C = x.shape ''' qkv(x) : [B,197,768] -> [B,197,768*3] reshape : [B,197,768*3] -> [B,197,3,12,64] (3分别代表qkv,12个head,每个head为64维向量) permute:[B,197,3,12,64] -> [3,B,12,197,64] ''' qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) ''' q,k,v = [B,12,197,64] ''' q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) ''' K.transpose(-2, -1) : [B,12,197,64] = [B,12,64,197] q @ K.transpose(-2, -1) : [B,12,197,64] @ [B,12,64,197] = [B,12,197,197] attn : [B,12,197,197] attn.softmax(dim=-1)对最后一个维度(即每一行)进行softmax处理 ''' attn = (q @ k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) ''' attn @ v = [B,12,197,197] @ [B,12,197,64] = [B,12,197,64] transpose(1, 2) : [B,197,12,64] reshape : [B,197,768] ''' x = (attn @ v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x
MLP结构和代码都很简单,就是全连接加激活函数加dropout,这里的激活函数用的GELU:
G E L U ( x ) = 0.5 x ( 1 + t a n h [ 2 π ( x + 0.044715 x 3 ) ] ) GELU(x)=0.5x(1+tanh[\frac{2}{π}(x+0.044715x^3)]) GELU(x)=0.5x(1+tanh[π2(x+0.044715x3)])
MLP模块代码如下:
class Mlp(nn.Module): """ MLP as used in Vision Transformer, MLP-Mixer and related networks """ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x
在Transformer Encoder中代码使用DropPath代替论文中的Dropout,具体代码及注释如下:
def drop_path(x, drop_prob: float = 0., training: bool = False): ''' x.shape : [B,197,768] ''' if drop_prob == 0. or not training: return x keep_prob = 1 - drop_prob ''' shape = [B,1,1] 即将X的第一维度保留,其他维度改为1 ''' shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets ''' 生成形状为shape的随机张量并加上keep_prob ''' random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) ''' 将随机张量向下取整,一部分为0,一部分为1 ''' random_tensor.floor_() # binarize ''' 将x除以keep_prob再乘上随机张量,一部分变成0,一部分保留 ''' output = x.div(keep_prob) * random_tensor return output class DropPath(nn.Module): """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). """ def __init__(self, drop_prob=None): super(DropPath, self).__init__() self.drop_prob = drop_prob def forward(self, x): return drop_path(x, self.drop_prob, self.training)
原文中关于MLP Head的代码:
# Representation layer if representation_size and not distilled: self.has_logits = True self.num_features = representation_size self.pre_logits = nn.Sequential(OrderedDict([ ("fc", nn.Linear(embed_dim, representation_size)), ("act", nn.Tanh()) ])) else: self.has_logits = False self.pre_logits = nn.Identity() # Classifier head(s) self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() self.head_dist = None if distilled: self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
这里的代码也很简单,就不做过多注释了,代码中distilled = False
,所以:
self.pre_logits = nn.Sequential(nn.Linear,(embed_dim, representation_size)nn.Tanh())
self.head = nn.Linear(self.num_features, num_classes)
MLPHead(x) = self.head(self.pre_logits(x[:, 0]))
ViT-B/16整体网络结构如下图:
ViT-B/16模型使用的图像输入尺寸为 224×224×3,patch尺寸为16×16×3,每个patch embed的维度为768,transformer encoder block的个数为12, Multi-Head Attention的head个数为12,最后两个参数看调用模型时的参数设置,representation_size为pre_logits中全连接层节点个数,num_classes为预测的总分类数。
def vit_base_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True):
model = VisionTransformer(img_size=224,
patch_size=16,
embed_dim=768,
depth=12,
num_heads=12,
representation_size=768 if has_logits else None,
num_classes=num_classes)
return model
VisionTransformer具体代码及注释如下:
class VisionTransformer(nn.Module): def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=1000, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, representation_size=None, distilled=False, drop_ratio=0., attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None, act_layer=None): """ Args: img_size (int, tuple): 输入图像尺寸 patch_size (int, tuple): patch 尺寸 in_c (int): 输入通道 num_classes (int): 分类数 embed_dim (int): patchembed 维度 depth (int): transformer encoder 模块( Block 模块)个数 num_heads (int): Multi-Head Attention 中的 head 个数 mlp_ratio (int): MLP 隐藏层和 embed_dim 的比例 qkv_bias (bool): 是否使用 qkv 偏置(即使用 Linear 将输入映射到 qkv 时,Linear是否使用 bias ) qk_scale (float): qk缩放比例,默认使用根号 dim_k 分之一 representation_size (Optional[int]): pre-logits 中的全连接节点个数,如果是 None 则不要 pre-logits (MLP Head 中只有一个全连接层) distilled (bool): 是否使用 DeiT 模型(基于知识蒸馏的transformer),在 VIT 中默认为 False drop_ratio (float): dropout概率 attn_drop_ratio (float): attention 中的 dropout 概率 drop_path_ratio (float): attention 中的 droppath 概率 embed_layer (nn.Module): patch embedding 层 norm_layer: (nn.Module): normalization 层 """ super(VisionTransformer, self).__init__() self.num_classes = num_classes ''' self.num_features = self.embed_dim = 768 self.num_tokens = 1 norm_layer = nn.LayerNorm(eps=1e-6) act_layer = nn.GELU ''' self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models self.num_tokens = 2 if distilled else 1 norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) act_layer = act_layer or nn.GELU ''' 构建patch embeding layer num_patches = (224/16) * (224/16) = 196 ''' self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_c=in_c, embed_dim=embed_dim) num_patches = self.patch_embed.num_patches ''' 构建可学习参数: self.cls_token : [1,1,768] 分类token self.dist_token : None self.pos_embed : [1,197,768] 位置编码 ''' self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) self.pos_drop = nn.Dropout(p=drop_ratio) ''' 构建首项为0,长度为depth的等差数列,且每一项小于drop_path_ratio 也就是说 传入 Block 的 droppath 概率是递增的。 代码这里是让 drop_path_ratio 默认等于0 最后利用参数构建 depth(12) 层 block 层 并把 LayerNorm(embed_dim) 赋值给self.norm ''' dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)] # stochastic depth decay rule self.blocks = nn.Sequential(*[ Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i], norm_layer=norm_layer, act_layer=act_layer) for i in range(depth) ]) self.norm = norm_layer(embed_dim) ''' 构建 pre_logits : 1.全连接层:输入embed_dim(768),输出representation_size(768) 2.激活函数:Tanh ''' # Representation layer if representation_size and not distilled: self.has_logits = True self.num_features = representation_size self.pre_logits = nn.Sequential(OrderedDict([ ("fc", nn.Linear(embed_dim, representation_size)), ("act", nn.Tanh()) ])) else: self.has_logits = False self.pre_logits = nn.Identity() ''' 构建分类器: self.num_features = 768 ''' # Classifier head(s) self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() self.head_dist = None if distilled: self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() ''' 初始化pos_embed、cls_token 初始化网络其他层的权重 ''' # Weight init nn.init.trunc_normal_(self.pos_embed, std=0.02) if self.dist_token is not None: nn.init.trunc_normal_(self.dist_token, std=0.02) nn.init.trunc_normal_(self.cls_token, std=0.02) self.apply(_init_vit_weights) def forward_features(self, x): ''' self.patch_embed(x) : [B,3,224,224] -> [B,196,768] 合并 cls_token: self.cls_token : [1,1,768] cls_token : [B,1,768] x = torch.cat((cls_token, x), dim=1) : [B,197,768] ''' x = self.patch_embed(x) cls_token = self.cls_token.expand(x.shape[0], -1, -1) if self.dist_token is None: x = torch.cat((cls_token, x), dim=1) else: x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1) ''' 加上位置编码: x = x + self.pos_embed : [B,197,768] 经过 Attention blocks 和 LayerNorm : [B,197,768] 最后返回分类 token 并传入 pre_logits: return self.pre_logits(x[:, 0]) : [B,768] ''' x = self.pos_drop(x + self.pos_embed) x = self.blocks(x) x = self.norm(x) if self.dist_token is None: return self.pre_logits(x[:, 0]) else: return x[:, 0], x[:, 1] def forward(self, x): ''' self.forward_features(x) : [B,3,224,224] -> [B,768] x = self.head(x) : [B,768] -> [B,num_classes] ''' x = self.forward_features(x) if self.head_dist is not None: x, x_dist = self.head(x[0]), self.head_dist(x[1]) if self.training and not torch.jit.is_scripting(): # during inference, return the average of both classifier predictions return x, x_dist else: return (x + x_dist) / 2 else: x = self.head(x) return x def _init_vit_weights(m): """ ViT weight initialization :param m: module """ if isinstance(m, nn.Linear): nn.init.trunc_normal_(m.weight, std=.01) if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode="fan_out") if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, nn.LayerNorm): nn.init.zeros_(m.bias) nn.init.ones_(m.weight)
上述代码的distilled参数所涉及的 DeiT models 代码中并没有使用,论文中也没有提到,如有疑惑可查看ViT和DeiT的原理与使用。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。