当前位置:   article > 正文

详解VIT(Vision Transformer)模型原理, 代码级讲解_vit模型

vit模型

一、模型简介

1. 论文地位:VIT模型(Vision Transformer),这是一篇Google于2021年发表在计算机视觉顶级会议ICLR上的一篇文章。它首次将Transformer这种发源于NLP领域的模型引入到了CV领域,并在ImageNet数据集上击败了当时最先进的CNN网络。这是一个标志性的网络,代表transformer击败了CNN和RNN,同时在CV领域和NLP领域达到了统治地位,此后基本在ImageNet排行榜上都是基于transformer架构的模型了。

2. 论文下载链接AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE  ;

3. 推荐复现的代码仓库,可以star我这个GitHub开源项目,对每行代码有详尽的注释:VIT模型详解

二、模型亮点及整体架构介绍

1. 整体架构:VIT首次将transformer模型运用到CV领域并且取得了不错的分类效果,模型原理图如图1所示。该图表示了VIT模型的整体架构,可以看出VIT只用了transformer模型的编码器部分,并未涉及解码器。其实VIT模型不难理解,只需要将其拆成三个部分(1.图像特征嵌入模块;2.transformer编码器模块;3.MLP分类模块)就可以很容易捋顺它的结构。

        1.1.图像特征嵌入模块:标准的VIT模型对图像的输入尺寸有要求,必须为224*224.图像输入之后首先是需要进行patch分块,一般设置patch的尺寸为16*16,那么一共能生成(224/16)*(224/16)=196个patch块。这部分内容在代码中如何实现呢?其实很简单,就是用一个卷积层就可以实现,其卷积核大小为patch size=16, 步长为patch size=16.

具体的代码如下所示,每行代码均有详细注释,展示了图像分块和特征嵌入的完整过程,嵌入之后的特征维度是[196,768],之后我们还需要加上位置编码和类别token,前者使用直接相加的方法,后者使用concat的方法,所以加上类别token后,特征的维度变化为[197,768]:        

  1. class PatchEmbed(nn.Module): # 继承nn.Module
  2. """
  3. 所有注释均采用VIT-base进行说明
  4. 图像嵌入模块类
  5. 2D Image to Patch Embedding
  6. """
  7. # 初始化函数,设置默认参数
  8. def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768, norm_layer=None):
  9. super().__init__() # 继承父类的初始化方法
  10. # 输入图像的size为224*224
  11. img_size = (img_size, img_size)
  12. # patch_size为16*16
  13. patch_size = (patch_size, patch_size)
  14. self.img_size = img_size
  15. self.patch_size = patch_size
  16. # 滑动窗口的大小为14*14, 224/16=14
  17. self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
  18. # 图像切分的patch的总数为14*14=196
  19. self.num_patches = self.grid_size[0] * self.grid_size[1]
  20. # 使用一个卷积层来实现图像嵌入,输入维度为[BatchSize,3,224,224],输出维度为[BatchSize,768,14,14],
  21. # 计算公式 size= (224-16+2*0)/16 + 1= 14
  22. self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
  23. # 如果norm_layer为True,则使用,否则忽略
  24. self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
  25. def forward(self, x):
  26. # 获取BatchSize,Channels,Height,Width
  27. B, C, H, W = x.shape # 输入批量的图像
  28. # VIT模型对图像的输入尺寸有严格要求,故需要检查是否为224*224,如果不是则报错提示
  29. assert H == self.img_size[0] and W == self.img_size[1], \
  30. f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
  31. # flatten: [B, C, H, W] -> [B, C, HW]
  32. # transpose: [B, C, HW] -> [B, HW, C]
  33. # 将卷积嵌入后的图像特征图[BatchSize, 768, 14,14],从第3维度开始展平,得到BatchSize*768*196;
  34. # 然后转置1,2两个维度,得到BatchSize*196*768
  35. x = self.proj(x).flatten(2).transpose(1, 2)
  36. # norm_layer层规范化
  37. x = self.norm(x)
  38. return x

        1.2. Transformer Encoder:主要由LayerNorm层,多头注意力机制,MLP模块,残差连接这5个知识点模块构成组成。这是整个VIT模型最重要也是最需要花精力理解的地方,要搞清楚transformer的编码器部分,首先需要搞清楚多头注意力机制。

        由于注意力机制这块内容较多,引用博文供读者学习:狗都能看懂的self attention讲解 ;

 详解self attention以及 multi head self attention的原理 。后续将补上对于注意力机制这部分内容自己的理解,其实理解了注意力机制也就理解了transformer架构的基本原理。两种注意力机制的原理如下图所示。

        接下来我们按顺序往前讲解,LayerNorm层比较简单,直接调用pytorch中的nn.LayerNorm就可以实现。接下来看多头注意力机制的实现,如下所示。首先通过使用一个全连接层生成q,k,v的初始值,然后使用reshape和维度调换来进行调整,最后使用切片操作分别获得单独的Q,K,V,接下来就是transformer原始文章里提出的注意力机制的公式的实现了,公式如下:

  1. class Attention(nn.Module):
  2. # 经过注意力层的特征的输入和输出维度相同
  3. def __init__(self,
  4. dim, # 输入token的维度
  5. num_heads=8, # multiHead中 head的个数
  6. qkv_bias=False, # 决定生成Q,K,V时是否使用偏置
  7. qk_scale=None,
  8. attn_drop_ratio=0.,
  9. proj_drop_ratio=0.):
  10. super(Attention, self).__init__()
  11. # 设置多头注意力的注意力头的数目
  12. self.num_heads = num_heads
  13. # 针对每个head进行均分,它的Q,K,V对应的维度;
  14. head_dim = dim // num_heads
  15. # 放缩Q*(K的转置),就是根号head_dim(就是d(k))分之一,及和原论文保持一致
  16. self.scale = qk_scale or head_dim ** -0.5
  17. # 通过一个全连接层生成Q,K,V
  18. self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
  19. # 定义dropout层
  20. self.attn_drop = nn.Dropout(attn_drop_ratio)
  21. # 通过一个全连接层实现将每一个头得到的注意力进行拼接
  22. self.proj = nn.Linear(dim, dim)
  23. # 使用dropout层
  24. self.proj_drop = nn.Dropout(proj_drop_ratio)
  25. def forward(self, x):
  26. # [batch_size, num_patches + 1, total_embed_dim],其中,num_patches+1中的加1是为了保留class_token的位置
  27. B, N, C = x.shape # [batch_size, 197, 768]
  28. # 生成Q,K,V;qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim]
  29. # reshape: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head],
  30. # 调换维度位置,permute: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head]
  31. qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
  32. # 将Q,K,V分离出来,切片后的Q,K,V的形状[batch_size, num_heads, num_patches + 1, embed_dim_per_head]
  33. # 这个的维度相同,均为[BatchSize, 8, 197, 768]
  34. q, k, v = qkv[0], qkv[1], qkv[2]
  35. # transpose: -> [batch_size, num_heads, embed_dim_per_head, num_patches + 1]
  36. # @: multiply -> [batch_size, num_heads, num_patches + 1, num_patches + 1]
  37. # 即文章里的那个根据q,k, v 计算注意力的公式的Q*K的转置再除以根号dk
  38. attn = (q @ k.transpose(-2, -1)) * self.scale
  39. # 对得到的注意力结果的每一行进行softmax处理
  40. attn = attn.softmax(dim=-1)
  41. # 添加dropout层
  42. attn = self.attn_drop(attn)
  43. # @: multiply -> [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
  44. # transpose: -> [batch_size, num_patches + 1, num_heads, embed_dim_per_head]
  45. # reshape: -> [batch_size, num_patches + 1, total_embed_dim]
  46. x = (attn @ v).transpose(1, 2).reshape(B, N, C) # 乘V,通过reshape处理将每个head对应的结果进行拼接处理
  47. # 使用全连接层进行映射,维度不变,输入[BatchSize, 197, 768], 输出也相同
  48. x = self.proj(x)
  49. # 添加dropout层
  50. x = self.proj_drop(x)
  51. return x

        然后是MLP模块,也就是transformer原文中的前馈网络(feed forward),这一部分其实比较简单,没什么可讲的,就是两个全连接层加上dropout层实现:

  1. class Mlp(nn.Module):
  2. """
  3. MLP as used in Vision Transformer, MLP-Mixer and related networks
  4. """
  5. def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
  6. super().__init__()
  7. # 注意当or两边存在None值时,输出是不为None那个
  8. out_features = out_features or in_features
  9. hidden_features = hidden_features or in_features
  10. self.fc1 = nn.Linear(in_features, hidden_features)
  11. self.act = act_layer()
  12. self.fc2 = nn.Linear(hidden_features, out_features)
  13. self.drop = nn.Dropout(drop)
  14. def forward(self, x):
  15. # MLP模块中的第一个全连接层,输入维度in_features, 输出维度hidden_features
  16. x = self.fc1(x)
  17. x = self.act(x)
  18. x = self.drop(x)
  19. # MLP模块中的第二个全连接层,输入维度hidden_features, 输出维度out_features
  20. x = self.fc2(x)
  21. x = self.drop(x)
  22. return x

        残差连接模块,这个比较简单,代码实现如下:

  1. # 相当于resnet中的shortcut,原论文中图1的结构中的+和箭头就是这个意思。因为输入和输出的维度完全相同,所以二者可以相加
  2. # 将输入x与经过layer norm和多头注意力处理后的值进行残差相加
  3. x = x + self.drop_path(self.attn(self.norm1(x)))
  4. # 将输入x与经过layer norm和MLP处理后的值进行残差相加
  5. x = x + self.drop_path(self.mlp(self.norm2(x)))

         好了,接下来是将这些不同功能的模块进行包装,代码和注释如下:

  1. class Block(nn.Module):
  2. # 集成了Transformer Encoder的所有功能
  3. def __init__(self,
  4. dim,
  5. num_heads,
  6. mlp_ratio=4.,
  7. qkv_bias=False,
  8. qk_scale=None,
  9. drop_ratio=0.,
  10. attn_drop_ratio=0.,
  11. drop_path_ratio=0.,
  12. act_layer=nn.GELU,
  13. norm_layer=nn.LayerNorm):
  14. super(Block, self).__init__()
  15. self.norm1 = norm_layer(dim)
  16. self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
  17. attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio)
  18. # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
  19. self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
  20. self.norm2 = norm_layer(dim)
  21. mlp_hidden_dim = int(dim * mlp_ratio)
  22. # 定义全连接层,输入维度为embedding dim=768,隐藏层为embedding dim*4=3072,输出层为in_features=768
  23. self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio)
  24. def forward(self, x):
  25. # 相当于resnet中的shortcut,原论文中图1的结构中的+和箭头就是这个意思。因为输入和输出的维度完全相同,所以二者可以相加
  26. # 将输入x与经过layer norm和多头注意力处理后的值进行残差相加
  27. x = x + self.drop_path(self.attn(self.norm1(x)))
  28. # 将输入x与经过layer norm和MLP处理后的值进行残差相加
  29. x = x + self.drop_path(self.mlp(self.norm2(x)))
  30. return x

         其实讲到这里,基本已经理解了VIT模型知识体系的三分之二。接下来就是最后的MLP分类模块了,这一块比较简单,甚至可以只用一层全连接层来解决。之前没有了解过VIT的小伙伴,这里需要提示一下,我们输入到MLP类别分类器中的特征只有类别token。经过N层transformer编码器处理后的特征的维度与输入前相同,均为[197,768],我们只使用列表切片的方式提取出类别token,维度为[1,768].进行下一步的类别分类。有小伙伴可能不理解,那不是其它的特征没有用到吗?浪费了是不是。其实不是,多头注意力机制可以让不同位置的特征进行全面交互,这里输出的类别token和之前输入的类别token早已发生了巨变,这种变化是由其它特征影响的。

         最后提供一下,transformer模型的整体架构代码:

  1. class VisionTransformer(nn.Module):
  2. # 集成VIT模型架构
  3. def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=1000,
  4. embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True,
  5. qk_scale=None, representation_size=None, distilled=False, drop_ratio=0.,
  6. attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,
  7. act_layer=None):
  8. """
  9. Args:
  10. img_size (int, tuple): input image size
  11. patch_size (int, tuple): patch size
  12. in_c (int): number of input channels
  13. num_classes (int): number of classes for classification head
  14. embed_dim (int): embedding dimension
  15. depth (int): depth of transformer
  16. num_heads (int): number of attention heads
  17. mlp_ratio (int): ratio of mlp hidden dim to embedding dim
  18. qkv_bias (bool): enable bias for qkv if True
  19. qk_scale (float): override default qk scale of head_dim ** -0.5 if set
  20. representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
  21. distilled (bool): model includes a distillation token and head as in DeiT models
  22. drop_ratio (float): dropout rate
  23. attn_drop_ratio (float): attention dropout rate
  24. drop_path_ratio (float): stochastic depth rate
  25. embed_layer (nn.Module): patch embedding layer
  26. norm_layer: (nn.Module): normalization layer
  27. """
  28. super(VisionTransformer, self).__init__()
  29. self.num_classes = num_classes
  30. self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
  31. self.num_tokens = 2 if distilled else 1
  32. norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
  33. act_layer = act_layer or nn.GELU
  34. self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_c=in_c, embed_dim=embed_dim)
  35. num_patches = self.patch_embed.num_patches
  36. self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
  37. self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
  38. # 创建可学习的位置token,形状为 (1, num_patches + self.num_tokens, embed_dim),初始值全为0
  39. self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
  40. self.pos_drop = nn.Dropout(p=drop_ratio)
  41. dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)] # stochastic depth decay rule
  42. # nn.Sequential作为容器用来包装层,通过列表循环构建了depth个block
  43. self.blocks = nn.Sequential(*[
  44. Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
  45. drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],
  46. norm_layer=norm_layer, act_layer=act_layer)
  47. for i in range(depth)
  48. ])
  49. self.norm = norm_layer(embed_dim)
  50. # Representation layer,即pre_logits层。这个if语句是问,如果需要pre_logits层并且不进行模型蒸馏则...
  51. if representation_size and not distilled:
  52. self.has_logits = True
  53. self.num_features = representation_size
  54. self.pre_logits = nn.Sequential(OrderedDict([
  55. ("fc", nn.Linear(embed_dim, representation_size)),
  56. ("act", nn.Tanh())
  57. ]))
  58. else:
  59. self.has_logits = False
  60. # 跳过pre_logits层
  61. self.pre_logits = nn.Identity()
  62. # Classifier head(s),使用一个全连接层进行分类结果输出
  63. self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
  64. self.head_dist = None
  65. if distilled:
  66. self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
  67. # Weight init,对位置编码进行初始化
  68. nn.init.trunc_normal_(self.pos_embed, std=0.02)
  69. if self.dist_token is not None:
  70. nn.init.trunc_normal_(self.dist_token, std=0.02)
  71. # 对类别编码进行初始化
  72. nn.init.trunc_normal_(self.cls_token, std=0.02)
  73. self.apply(_init_vit_weights)
  74. def forward_features(self, x):
  75. # [B, C, H, W] -> [B, num_patches, embed_dim],即[B, 196, 768]
  76. x = self.patch_embed(x)
  77. # 定义类别token,维度为[1,1,768]
  78. cls_token = self.cls_token.expand(x.shape[0], -1, -1)
  79. # 注意dist_token是用于模型蒸馏的,此处设置为None
  80. if self.dist_token is None:
  81. # concat上类别token
  82. x = torch.cat((cls_token, x), dim=1) # [B, 197, 768]
  83. else:
  84. x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
  85. # 添加位置编码
  86. x = self.pos_drop(x + self.pos_embed)
  87. # 将特征输入transformer编码器
  88. x = self.blocks(x)
  89. # LayerNorm层
  90. x = self.norm(x)
  91. # 注意dist_token是用于模型蒸馏的,此处设置为None
  92. if self.dist_token is None:
  93. # 返回类别token
  94. return self.pre_logits(x[:, 0])
  95. else:
  96. return x[:, 0], x[:, 1]
  97. def forward(self, x):
  98. # 进行forward_features之后获得训练后的类别token
  99. x = self.forward_features(x)
  100. # head_dist等于None,直接执行else
  101. if self.head_dist is not None:
  102. x, x_dist = self.head(x[0]), self.head_dist(x[1])
  103. if self.training and not torch.jit.is_scripting():
  104. # during inference, return the average of both classifier predictions
  105. return x, x_dist
  106. else:
  107. return (x + x_dist) / 2
  108. else:
  109. # 使用全连接层输出分类结果
  110. x = self.head(x)
  111. return x

待更新!!!

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Li_阴宅/article/detail/764393
推荐阅读
相关标签
  

闽ICP备14008679号