当前位置:   article > 正文

VIT用于图像分类 学习笔记(附代码)_vit分类代码

vit分类代码

论文地址:https://arxiv.org/abs/2010.11929

代码地址:https://github.com/bubbliiiing/classification-pytorch

1.是什么?

Vision Transformer(VIT)是一种基于Transformer架构的图像分类模型。它将图像分割成一系列的图像块,并将每个图像块作为输入序列传递给Transformer模型。VIT通过自注意力机制来捕捉图像中的全局上下文信息,并使用多层感知机(MLP)来进行特征提取和分类。

VIT的核心思想是将图像转换为序列数据,这使得模型能够利用Transformer的强大表达能力来处理图像。通过将图像分割成图像块,并将它们展平为序列,VIT能够在不依赖传统卷积神经网络的情况下实现图像分类任务。

2.为什么?

从2020年,transformer开始在CV领域大放异彩:图像分类(ViT, DeiT),目标检测(DETR,Deformable DETR),语义分割(SETR,MedT),图像生成(GANsformer)等。而从深度学习暴发以来,CNN一直是CV领域的主流模型,而且取得了很好的效果,相比之下transformer却独霸NLP领域,transformer在CV领域的探索正是研究界想把transformer在NLP领域的成功借鉴到CV领域。对于图像问题,卷积具有天然的先天优势(inductive bias):平移等价性(translation equivariance)和局部性(locality)。而transformer虽然不并具备这些优势,但是transformer的核心self-attention的优势不像卷积那样有固定且有限的感受野,self-attention操作可以获得long-range信息(相比之下CNN要通过不断堆积Conv layers来获取更大的感受野),但训练的难度就比CNN要稍大一些。

ViT(vision transformer)是Google在2020年提出的直接将transformer应用在图像分类的模型,后面很多的工作都是基于ViT进行改进的。这篇论文也是受到其启发,尝试将Transformer应用到CV领域通过这篇文章的实验,给出的最佳模型在ImageNet1K上能够达到88.55%的准确率(先在Google自家的JFT数据集上进行了预训练),说明Transformer在CV领域确实是有效的,而且效果还挺惊人。

3.怎么样?

3.1网络结构

与寻常的分类网络类似,整个Vision Transformer可以分为两部分,一部分是特征提取部分,另一部分是分类部分。

在特征提取部分,VIT所做的工作是特征提取。特征提取部分在图片中的对应区域是Patch+Position Embedding和Transformer Encoder。Patch+Position Embedding的作用主要是对输入进来的图片进行分块处理,每隔一定的区域大小划分图片块。然后将划分后的图片块组合成序列。在获得序列信息后,传入Transformer Encoder进行特征提取,这是Transformer特有的Multi-head Self-attention结构,通过自注意力机制,关注每个图片块的重要程度。

在分类部分,VIT所做的工作是利用提取到的特征进行分类。在进行特征提取的时候,我们会在图片序列中添加上Cls Token,该Token会作为一个单位的序列信息一起进行特征提取,提取的过程中,该Cls Token会与其它的特征进行特征交互,融合其它图片序列的特征。最终,我们利用Multi-head Self-attention结构提取特征后的Cls Token进行全连接分类。

3.2特征提取部分介绍

3.2.1Patch

Patch的作用主要是对输入进来的图片进行分块处理,每隔一定的区域大小划分图片块。然后将划分后的图片块组合成序列

该部分首先对输入进来的图片进行分块处理,处理方式其实很简单,使用的是现成的卷积。也就是说,不是把图片分割,是做了一次简单的卷积,可以理解为初步特征提取,或者说是映射。

由于卷积使用的是滑动窗口的思想,我们只需要设定特定的步长,就可以输入进来的图片进行分块处理了。在VIT中,我们常设置这个卷积的卷积核大小为16x16,步长也为16x16,此时卷积就会每隔16个像素点进行一次特征提取,由于卷积核大小为16x16,两个图片区域的特征提取过程就不会有重叠。当我们输入的图片是224, 224, 3的时候,我们可以获得一个14, 14, 768的特征层。

在代码实现中,直接通过一个卷积层来实现。 以ViT-B/16为例,直接使用一个卷积核大小为16x16,步距为16,卷积核个数为768的卷积来实现。通过卷积[224, 224, 3] -> [14, 14, 768],然后把H以及W两个维度展平即可[14, 14, 768] -> [196, 768],此时正好变成了一个二维矩阵,正是Transformer想要的。
 

3.2.2Position Embedding

Position Embedding的作用主要是对组合序列加上[class]token以及Position Embedding

Position Embedding

除了patch embeddings,模型还需要另外一个特殊的position embedding。transformer和CNN不同,需要position embedding来编码tokens的位置信息,这主要是因为self-attention是permutation-invariant,即打乱sequence里的tokens的顺序并不会改变结果。如果不给模型提供patch的位置信息,那么模型就需要通过patchs的语义来学习拼图,这就额外增加了学习成本。ViT论文中对比了几种不同的position embedding方案(如下),最后发现如果不提供positional embedding效果会差,但其它各种类型的positional embedding效果都接近,这主要是因为ViT的输入是相对较大的patchs而不是pixels,所以学习位置信息相对容易很多。

  • 无positional embedding
  • 1-D positional embedding:把2-D的patchs看成1-D序列
  • 2-D positional embedding:考虑patchs的2-D位置(x, y)
  • Relative positional embeddings:patchs的相对位置

transformer原论文中是默认采用固定的positional embedding,但ViT中默认采用学习(训练的)的1-D positional embedding,在输入transformer的encoder之前直接将patch embeddings和positional embedding相加:

  1. # 这里多1是为了后面要说的class token,embed_dim即patch embed_dim
  2. self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
  3. # patch emded + pos_embed
  4. x = x + self.pos_embed

对于Position Embedding作者也有做一系列对比试验,在源码中默认使用的是1D Pos. Emb.,对比不使用Position Embedding准确率提升了大概3个点,和2D Pos. Emb.比起来没太大差别。

Class Token 

除了patch tokens,ViT借鉴BERT还增加了一个特殊的class token。后面会说,transformer的encoder输入是a sequence patch embeddings,输出也是同样长度的a sequence patch features,但图像分类最后需要获取image feature,简单的策略是采用pooling,比如求patch features的平均来获取image feature,但是ViT并没有采用类似的pooling策略,而是直接增加一个特殊的class token,其最后输出的特征加一个linear classifier就可以实现对图像的分类(ViT的pre-training时是接一个MLP head),所以输入ViT的sequence长度是�+1。class token对应的embedding在训练时随机初始化,然后通过训练得到,具体实现如下:

  1. # 随机初始化
  2. self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
  3. # Classifier head
  4. self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
  5. # 具体forward过程
  6. B = x.shape[0]
  7. x = self.patch_embed(x)
  8. cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
  9. x = torch.cat((cls_tokens, x), dim=1)
  10. x = x + self.pos_embed
3.2.3Transformer Encoder

Transformer Encoder的作用是将输入序列进行编码,生成一个高维表示,以便后续的任务处理。它由多个Encoder Block组成,每个Encoder Block包含一个多头自注意力层和一个前馈全连接层。在编码过程中,输入序列会经过多头自注意力层进行特征提取和关联性计算,然后再通过前馈全连接层进行非线性变换和特征融合。通过堆叠多个Encoder Block,Transformer Encoder能够捕捉输入序列中的语义信息和上下文关系,生成一个更加丰富的表示。

下图是太阳花的小绿豆绘制的Encoder Block,主要由以下几部分组成:

3.2.3.1Layer Norm

这种Normalization方法主要是针对NLP领域提出的,这里是对每个token进行Norm处理,之前也有讲过Layer Norm不懂的可以参考链接

3.2.3.2 Multi-Head Attention

Multi-Head Attention是一种注意力机制,它由多个独立的注意力头组成。每个注意力头都可以学习到不同的表示子空间,并在不同的位置上联合关注信息。通过使用多个注意力头,Multi-Head Attention可以更好地捕捉输入序列中的不同关系和特征。

下面是一个示例代码,演示了如何实现Multi-Head Attention:

import torch

import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, input_dim, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.head_dim = input_dim // num_heads
        
        self.query_linear = nn.Linear(input_dim, input_dim)
        self.key_linear = nn.Linear(input_dim, input_dim)
        self.value_linear = nn.Linear(input_dim, input_dim)
        self.output_linear = nn.Linear(input_dim, input_dim)
        
    def forward(self, query, key, value):
        batch_size = query.size(0)
        
        # 线性变换得到query、key、value
        query = self.query_linear(query)
        key = self.key_linear(key)
        value = self.value_linear(value)
        
        # 将query、key、value分成多个头
        query = query.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        
        # 计算注意力得分
        scores = torch.matmul(query, key.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))
        attention_weights = torch.softmax(scores, dim=-1)
        
        # 对value进行加权求和
        weighted_values = torch.matmul(attention_weights, value)
        
        # 将多个头的结果拼接起来
        weighted_values = weighted_values.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.head_dim)
        
        # 线性变换得到最终的输出
        output = self.output_linear(weighted_values)
        
        return output

Multi-Head Attention,看懂Self-attention结构,其实看懂下面这个动图就可以了,动图中存在一个序列的三个单位输入,每一个序列单位的输入都可以通过三个处理(比如全连接)获得Query、Key、Value,Query是查询向量、Key是键向量、Value值向量。

(1)Self-Attention

假设输入的序列长度为2,输入就两个节点x_{1},x_{2},然后通过Input Embedding也就是图中的f(x)将输入映射到a_{1},a_{2} 。紧接着分别将a_{1},a_{2} 分别通过三个变换矩阵W_{q},W_{k},W_{v} ,(这三个参数是可训练的,是共享的)得到对应的q^{i},k^{i},v^{i}(这里在源码中是直接使用全连接层实现的,这里为了方便理解,忽略偏执)。

其中

q 代表query,后续会去和每一个k kk进行匹配
k 代表key,后续会被每个q qq匹配
v 代表从a aa中提取得到的信息
后续q 和k 匹配的过程可以理解成计算两者的相关性,相关性越大对应v 的权重也就越大

假设a_{1}=(1,1),a_{2}=(1,0),W^{q}=(_{0,1}^{1,1}),那么

q^{1}=(1,1)(_{0,1}^{1,1})=(1,2),q^{2}=(1,0)(_{0,1}^{1,1})=(1,1)

前面有说Transformer是可以并行化的,所以可以直接写成:

 

同理我们可以得到(_{k^{2}}^{k^{1}})(_{v^{2}}^{v^{1}}),那么求得的(_{q^{2}}^{q^{1}})就是原论文中的Q ,(_{k^{2}}^{k^{1}})就是K,(_{v^{2}}^{v^{1}})就是V.接着先拿q^{1}和每个k kk进行match,点乘操作,接着除以\sqrt{d}得到对应的α ,其中d 代表向量k^{i}的长度,在本示例中等于2,除以\sqrt{d}的原因在论文中的解释是“进行点乘后的数值很大,导致通过softmax后梯度变的很小”,所以通过除以\sqrt{d}来进行缩放。比如计算a_{1,i}:

同理拿q^{2} 去匹配所有的k kk能得到a_{2,i},统一写成矩阵乘法形式:

接着对每一行即(a_{1,1},a_{1,2})(a_{2,1},a_{2,2})分别进行softmax处理得到(a{\hat{}}_{1,1},a{\hat{}}_{1,2})(a{\hat{}}_{2,1},a{\hat{}}_{2,2}),这里的a\hat{}相当于计算得到针对每个v vv的权重。到这我们就完成了 Attention(Q,K,V)公式中softmax(\frac{QK^{T}}{\sqrt{d_{k}}})的部分
 

上面已经计算得到α \alphaα,即针对每个v vv的权重,接着进行加权得到最终结果:

统一写成矩阵乘法形式:

 

 到这,Self-Attention的内容就讲完了。总结下来就是论文中的一个公式:

(2)Multi-Head Attention 

首先还是和Self-Attention模块一样将a_{i}分别通过W_{q},W_{k},W_{v}得到对应的q^{i},k^{i},v^{i}然后再根据使用的head的数目h hh进一步把得到的q^{i},k^{i},v^{i}均分成h hh份。比如下图中假设h = 2 h=2h=2然后q^{1}拆分成q^{1,1}q^{1,2},那么q^{1,1}就是与head1,q^{1,2}就属于head2

看到这里,如果读过原论文的人肯定有疑问,论文中不是写的通过W_{i}^{Q},W_{i}^{K},W_{i}^{V}映射得到每个head的Q_{i},K_{i},V_{i}

但我在github上看的一些源码中就是简单的进行均分,其实也可以将 W_{i}^{Q},W_{i}^{K},W_{i}^{V}设置成对应值来实现均分,比如下图中的Q通过W_{i}^{Q}就能得到均分后的Q_{i}

通过上述方法就能得到每个head_{1} 对应的Q_{i},K_{i},V_{i}参数,接下来针对每个head使用和Self-Attention中相同的方法即可得到对应的结果。

 

接着将每个head得到的结果进行concat拼接,比如下图中b_{1,1}(head_{1}得到的b_{1}) 和b_{1,2}(head_{2}得到的b_{1})拼接在一起,b_{2,1}(head_{1}得到的b_{2}) 和b_{2,2}(head_{2}得到的b_{2})拼接在一起

接着将拼接后的结果通过W^{O} (可学习的参数)进行融合,如下图所示,融合后得到最终的结果b_1, b_2

 

  • Dropout/DropPath,在原论文的代码中是直接使用的Dropout层,在但rwightman实现的代码中使用的是DropPath(stochastic depth),可能后者会更好一点。
  • MLP Block,如图右侧所示,就是全连接+GELU激活函数+Dropout组成也非常简单,需要注意的是第一个全连接层会把输入节点个数翻4倍[197, 768] -> [197, 3072],第二个全连接层会还原回原节点个数[197, 3072] -> [197, 768]
     

3.3 分类部分

MLP Head是指多层感知器头部,它是在Transformer Encoder后面用于分类任务的一部分。MLP Head通常由几个线性层组成,用于将Transformer Encoder的输出转换为最终的分类结果。在原始的MLP Head论文中,它由线性层、tanh激活函数和线性层组成。但是在迁移到ImageNet1K或其他数据集时,只需要使用一个线性层即可。

下面是一个示例代码,展示了如何使用PyTorch实现一个简单的MLP Head:

import torch
import torch.nn as nn

class MLPHead(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(MLPHead, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        x = self.linear(x)
        return x

# 假设Transformer Encoder的输出维度为768,分类任务的类别数为10
input_dim = 768
output_dim = 10

mlp_head = MLPHead(input_dim, output_dim)

# 假设Transformer Encoder的输出为transformer_output
transformer_output = torch.randn(197, 768)

# 将transformer_output输入到MLP Head中得到分类结果
classification_result = mlp_head(transformer_output)

print(classification_result.shape)  # 输出:torch.Size([197, 10])

在上述代码中,我们定义了一个MLPHead类,它接受一个输入维度和一个输出维度作为参数。在forward方法中,我们使用一个线性层将输入转换为输出。然后,我们可以将Transformer Encoder的输出输入到MLP Head中,得到最终的分类结果

上面VIT通过Transformer Encoder后输出的shape和输入的shape是保持不变的,以ViT-B/16为例,输入的是[197, 768]输出的还是[197, 768]。注意,在Transformer Encoder后其实还有一个Layer Norm没有画出来,后面有我自己画的ViT的模型可以看到详细结构。这里我们只是需要分类的信息,所以我们只需要提取出[class]token生成的对应结果就行,即[197, 768]中抽取出[class]token对应的[1, 768]。接着我们通过MLP Head得到我们最终的分类结果。MLP Head原论文中说在训练ImageNet21K时是由Linear+tanh激活函数+Linear组成。但是迁移到ImageNet1K上或者你自己的数据上时,只用一个Linear即可。

 3.4别人画的网络结构图

3.5代码实现

Patch+Position Embedding

  1. class PatchEmbed(nn.Module):
  2. def __init__(self, input_shape=[224, 224], patch_size=16, in_chans=3, num_features=768, norm_layer=None, flatten=True):
  3. super().__init__()
  4. self.num_patches = (input_shape[0] // patch_size) * (input_shape[1] // patch_size)
  5. self.flatten = flatten
  6. self.proj = nn.Conv2d(in_chans, num_features, kernel_size=patch_size, stride=patch_size)
  7. self.norm = norm_layer(num_features) if norm_layer else nn.Identity()
  8. def forward(self, x):
  9. x = self.proj(x)
  10. if self.flatten:
  11. x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
  12. x = self.norm(x)
  13. return x
  14. class VisionTransformer(nn.Module):
  15. def __init__(
  16. self, input_shape=[224, 224], patch_size=16, in_chans=3, num_classes=1000, num_features=768,
  17. depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0.1, attn_drop_rate=0.1, drop_path_rate=0.1,
  18. norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=GELU
  19. ):
  20. super().__init__()
  21. #-----------------------------------------------#
  22. # 224, 224, 3 -> 196, 768
  23. #-----------------------------------------------#
  24. self.patch_embed = PatchEmbed(input_shape=input_shape, patch_size=patch_size, in_chans=in_chans, num_features=num_features)
  25. num_patches = (224 // patch_size) * (224 // patch_size)
  26. self.num_features = num_features
  27. self.new_feature_shape = [int(input_shape[0] // patch_size), int(input_shape[1] // patch_size)]
  28. self.old_feature_shape = [int(224 // patch_size), int(224 // patch_size)]
  29. #--------------------------------------------------------------------------------------------------------------------#
  30. # classtoken部分是transformer的分类特征。用于堆叠到序列化后的图片特征中,作为一个单位的序列特征进行特征提取。
  31. #
  32. # 在利用步长为16x16的卷积将输入图片划分成14x14的部分后,将14x14部分的特征平铺,一幅图片会存在序列长度为196的特征。
  33. # 此时生成一个classtoken,将classtoken堆叠到序列长度为196的特征上,获得一个序列长度为197的特征。
  34. # 在特征提取的过程中,classtoken会与图片特征进行特征的交互。最终分类时,我们取出classtoken的特征,利用全连接分类。
  35. #--------------------------------------------------------------------------------------------------------------------#
  36. # 196, 768 -> 197, 768
  37. self.cls_token = nn.Parameter(torch.zeros(1, 1, num_features))
  38. #--------------------------------------------------------------------------------------------------------------------#
  39. # 为网络提取到的特征添加上位置信息。
  40. # 以输入图片为224, 224, 3为例,我们获得的序列化后的图片特征为196, 768。加上classtoken后就是197, 768
  41. # 此时生成的pos_Embedding的shape也为197, 768,代表每一个特征的位置信息。
  42. #--------------------------------------------------------------------------------------------------------------------#
  43. # 197, 768 -> 197, 768
  44. self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, num_features))
  45. def forward_features(self, x):
  46. x = self.patch_embed(x)
  47. cls_token = self.cls_token.expand(x.shape[0], -1, -1)
  48. x = torch.cat((cls_token, x), dim=1)
  49. cls_token_pe = self.pos_embed[:, 0:1, :]
  50. img_token_pe = self.pos_embed[:, 1: , :]
  51. img_token_pe = img_token_pe.view(1, *self.old_feature_shape, -1).permute(0, 3, 1, 2)
  52. img_token_pe = F.interpolate(img_token_pe, size=self.new_feature_shape, mode='bicubic', align_corners=False)
  53. img_token_pe = img_token_pe.permute(0, 2, 3, 1).flatten(1, 2)
  54. pos_embed = torch.cat([cls_token_pe, img_token_pe], dim=1)
  55. x = self.pos_drop(x + pos_embed)

TransformerBlock 

  1. class Mlp(nn.Module):
  2. """ MLP as used in Vision Transformer, MLP-Mixer and related networks
  3. """
  4. def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=GELU, drop=0.):
  5. super().__init__()
  6. out_features = out_features or in_features
  7. hidden_features = hidden_features or in_features
  8. drop_probs = (drop, drop)
  9. self.fc1 = nn.Linear(in_features, hidden_features)
  10. self.act = act_layer()
  11. self.drop1 = nn.Dropout(drop_probs[0])
  12. self.fc2 = nn.Linear(hidden_features, out_features)
  13. self.drop2 = nn.Dropout(drop_probs[1])
  14. def forward(self, x):
  15. x = self.fc1(x)
  16. x = self.act(x)
  17. x = self.drop1(x)
  18. x = self.fc2(x)
  19. x = self.drop2(x)
  20. return x
  21. class Block(nn.Module):
  22. def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
  23. drop_path=0., act_layer=GELU, norm_layer=nn.LayerNorm):
  24. super().__init__()
  25. self.norm1 = norm_layer(dim)
  26. self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
  27. self.norm2 = norm_layer(dim)
  28. self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
  29. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  30. def forward(self, x):
  31. x = x + self.drop_path(self.attn(self.norm1(x)))
  32. x = x + self.drop_path(self.mlp(self.norm2(x)))
  33. return x

VIT

整个VIT模型由一个Patch+Position Embedding加上多个TransformerBlock组成。典型的TransforerBlock的数量为12个。 

  1. class VisionTransformer(nn.Module):
  2. def __init__(
  3. self, input_shape=[224, 224], patch_size=16, in_chans=3, num_classes=1000, num_features=768,
  4. depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0.1, attn_drop_rate=0.1, drop_path_rate=0.1,
  5. norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=GELU
  6. ):
  7. super().__init__()
  8. #-----------------------------------------------#
  9. # 224, 224, 3 -> 196, 768
  10. #-----------------------------------------------#
  11. self.patch_embed = PatchEmbed(input_shape=input_shape, patch_size=patch_size, in_chans=in_chans, num_features=num_features)
  12. num_patches = (224 // patch_size) * (224 // patch_size)
  13. self.num_features = num_features
  14. self.new_feature_shape = [int(input_shape[0] // patch_size), int(input_shape[1] // patch_size)]
  15. self.old_feature_shape = [int(224 // patch_size), int(224 // patch_size)]
  16. #--------------------------------------------------------------------------------------------------------------------#
  17. # classtoken部分是transformer的分类特征。用于堆叠到序列化后的图片特征中,作为一个单位的序列特征进行特征提取。
  18. #
  19. # 在利用步长为16x16的卷积将输入图片划分成14x14的部分后,将14x14部分的特征平铺,一幅图片会存在序列长度为196的特征。
  20. # 此时生成一个classtoken,将classtoken堆叠到序列长度为196的特征上,获得一个序列长度为197的特征。
  21. # 在特征提取的过程中,classtoken会与图片特征进行特征的交互。最终分类时,我们取出classtoken的特征,利用全连接分类。
  22. #--------------------------------------------------------------------------------------------------------------------#
  23. # 196, 768 -> 197, 768
  24. self.cls_token = nn.Parameter(torch.zeros(1, 1, num_features))
  25. #--------------------------------------------------------------------------------------------------------------------#
  26. # 为网络提取到的特征添加上位置信息。
  27. # 以输入图片为224, 224, 3为例,我们获得的序列化后的图片特征为196, 768。加上classtoken后就是197, 768
  28. # 此时生成的pos_Embedding的shape也为197, 768,代表每一个特征的位置信息。
  29. #--------------------------------------------------------------------------------------------------------------------#
  30. # 197, 768 -> 197, 768
  31. self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, num_features))
  32. self.pos_drop = nn.Dropout(p=drop_rate)
  33. #-----------------------------------------------#
  34. # 197, 768 -> 197, 768 12次
  35. #-----------------------------------------------#
  36. dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
  37. self.blocks = nn.Sequential(
  38. *[
  39. Block(
  40. dim = num_features,
  41. num_heads = num_heads,
  42. mlp_ratio = mlp_ratio,
  43. qkv_bias = qkv_bias,
  44. drop = drop_rate,
  45. attn_drop = attn_drop_rate,
  46. drop_path = dpr[i],
  47. norm_layer = norm_layer,
  48. act_layer = act_layer
  49. )for i in range(depth)
  50. ]
  51. )
  52. self.norm = norm_layer(num_features)
  53. self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()
  54. def forward_features(self, x):
  55. x = self.patch_embed(x)
  56. cls_token = self.cls_token.expand(x.shape[0], -1, -1)
  57. x = torch.cat((cls_token, x), dim=1)
  58. cls_token_pe = self.pos_embed[:, 0:1, :]
  59. img_token_pe = self.pos_embed[:, 1: , :]
  60. img_token_pe = img_token_pe.view(1, *self.old_feature_shape, -1).permute(0, 3, 1, 2)
  61. img_token_pe = F.interpolate(img_token_pe, size=self.new_feature_shape, mode='bicubic', align_corners=False)
  62. img_token_pe = img_token_pe.permute(0, 2, 3, 1).flatten(1, 2)
  63. pos_embed = torch.cat([cls_token_pe, img_token_pe], dim=1)
  64. x = self.pos_drop(x + pos_embed)
  65. x = self.blocks(x)
  66. x = self.norm(x)
  67. return x[:, 0]
  68. def forward(self, x):
  69. x = self.forward_features(x)
  70. x = self.head(x)
  71. return x
  72. def freeze_backbone(self):
  73. backbone = [self.patch_embed, self.cls_token, self.pos_embed, self.pos_drop, self.blocks[:8]]
  74. for module in backbone:
  75. try:
  76. for param in module.parameters():
  77. param.requires_grad = False
  78. except:
  79. module.requires_grad = False
  80. def Unfreeze_backbone(self):
  81. backbone = [self.patch_embed, self.cls_token, self.pos_embed, self.pos_drop, self.blocks[:8]]
  82. for module in backbone:
  83. try:
  84. for param in module.parameters():
  85. param.requires_grad = True
  86. except:
  87. module.requires_grad = True

 Vision Transforme的构建代码

  1. import math
  2. from collections import OrderedDict
  3. from functools import partial
  4. import numpy as np
  5. import torch
  6. import torch.nn as nn
  7. import torch.nn.functional as F
  8. #--------------------------------------#
  9. # Gelu激活函数的实现
  10. # 利用近似的数学公式
  11. #--------------------------------------#
  12. class GELU(nn.Module):
  13. def __init__(self):
  14. super(GELU, self).__init__()
  15. def forward(self, x):
  16. return 0.5 * x * (1 + F.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * torch.pow(x,3))))
  17. def drop_path(x, drop_prob: float = 0., training: bool = False):
  18. if drop_prob == 0. or not training:
  19. return x
  20. keep_prob = 1 - drop_prob
  21. shape = (x.shape[0],) + (1,) * (x.ndim - 1)
  22. random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
  23. random_tensor.floor_()
  24. output = x.div(keep_prob) * random_tensor
  25. return output
  26. class DropPath(nn.Module):
  27. def __init__(self, drop_prob=None):
  28. super(DropPath, self).__init__()
  29. self.drop_prob = drop_prob
  30. def forward(self, x):
  31. return drop_path(x, self.drop_prob, self.training)
  32. class PatchEmbed(nn.Module):
  33. def __init__(self, input_shape=[224, 224], patch_size=16, in_chans=3, num_features=768, norm_layer=None, flatten=True):
  34. super().__init__()
  35. self.num_patches = (input_shape[0] // patch_size) * (input_shape[1] // patch_size)
  36. self.flatten = flatten
  37. self.proj = nn.Conv2d(in_chans, num_features, kernel_size=patch_size, stride=patch_size)
  38. self.norm = norm_layer(num_features) if norm_layer else nn.Identity()
  39. def forward(self, x):
  40. x = self.proj(x)
  41. if self.flatten:
  42. x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
  43. x = self.norm(x)
  44. return x
  45. #--------------------------------------------------------------------------------------------------------------------#
  46. # Attention机制
  47. # 将输入的特征qkv特征进行划分,首先生成query, key, value。query是查询向量、key是键向量、v是值向量。
  48. # 然后利用 查询向量query 叉乘 转置后的键向量key,这一步可以通俗的理解为,利用查询向量去查询序列的特征,获得序列每个部分的重要程度score。
  49. # 然后利用 score 叉乘 value,这一步可以通俗的理解为,将序列每个部分的重要程度重新施加到序列的值上去。
  50. #--------------------------------------------------------------------------------------------------------------------#
  51. class Attention(nn.Module):
  52. def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
  53. super().__init__()
  54. self.num_heads = num_heads
  55. self.scale = (dim // num_heads) ** -0.5
  56. self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
  57. self.attn_drop = nn.Dropout(attn_drop)
  58. self.proj = nn.Linear(dim, dim)
  59. self.proj_drop = nn.Dropout(proj_drop)
  60. def forward(self, x):
  61. B, N, C = x.shape
  62. qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
  63. q, k, v = qkv[0], qkv[1], qkv[2]
  64. attn = (q @ k.transpose(-2, -1)) * self.scale
  65. attn = attn.softmax(dim=-1)
  66. attn = self.attn_drop(attn)
  67. x = (attn @ v).transpose(1, 2).reshape(B, N, C)
  68. x = self.proj(x)
  69. x = self.proj_drop(x)
  70. return x
  71. class Mlp(nn.Module):
  72. """ MLP as used in Vision Transformer, MLP-Mixer and related networks
  73. """
  74. def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=GELU, drop=0.):
  75. super().__init__()
  76. out_features = out_features or in_features
  77. hidden_features = hidden_features or in_features
  78. drop_probs = (drop, drop)
  79. self.fc1 = nn.Linear(in_features, hidden_features)
  80. self.act = act_layer()
  81. self.drop1 = nn.Dropout(drop_probs[0])
  82. self.fc2 = nn.Linear(hidden_features, out_features)
  83. self.drop2 = nn.Dropout(drop_probs[1])
  84. def forward(self, x):
  85. x = self.fc1(x)
  86. x = self.act(x)
  87. x = self.drop1(x)
  88. x = self.fc2(x)
  89. x = self.drop2(x)
  90. return x
  91. class Block(nn.Module):
  92. def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
  93. drop_path=0., act_layer=GELU, norm_layer=nn.LayerNorm):
  94. super().__init__()
  95. self.norm1 = norm_layer(dim)
  96. self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
  97. self.norm2 = norm_layer(dim)
  98. self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
  99. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  100. def forward(self, x):
  101. x = x + self.drop_path(self.attn(self.norm1(x)))
  102. x = x + self.drop_path(self.mlp(self.norm2(x)))
  103. return x
  104. class VisionTransformer(nn.Module):
  105. def __init__(
  106. self, input_shape=[224, 224], patch_size=16, in_chans=3, num_classes=1000, num_features=768,
  107. depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0.1, attn_drop_rate=0.1, drop_path_rate=0.1,
  108. norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=GELU
  109. ):
  110. super().__init__()
  111. #-----------------------------------------------#
  112. # 224, 224, 3 -> 196, 768
  113. #-----------------------------------------------#
  114. self.patch_embed = PatchEmbed(input_shape=input_shape, patch_size=patch_size, in_chans=in_chans, num_features=num_features)
  115. num_patches = (224 // patch_size) * (224 // patch_size)
  116. self.num_features = num_features
  117. self.new_feature_shape = [int(input_shape[0] // patch_size), int(input_shape[1] // patch_size)]
  118. self.old_feature_shape = [int(224 // patch_size), int(224 // patch_size)]
  119. #--------------------------------------------------------------------------------------------------------------------#
  120. # classtoken部分是transformer的分类特征。用于堆叠到序列化后的图片特征中,作为一个单位的序列特征进行特征提取。
  121. #
  122. # 在利用步长为16x16的卷积将输入图片划分成14x14的部分后,将14x14部分的特征平铺,一幅图片会存在序列长度为196的特征。
  123. # 此时生成一个classtoken,将classtoken堆叠到序列长度为196的特征上,获得一个序列长度为197的特征。
  124. # 在特征提取的过程中,classtoken会与图片特征进行特征的交互。最终分类时,我们取出classtoken的特征,利用全连接分类。
  125. #--------------------------------------------------------------------------------------------------------------------#
  126. # 196, 768 -> 197, 768
  127. self.cls_token = nn.Parameter(torch.zeros(1, 1, num_features))
  128. #--------------------------------------------------------------------------------------------------------------------#
  129. # 为网络提取到的特征添加上位置信息。
  130. # 以输入图片为224, 224, 3为例,我们获得的序列化后的图片特征为196, 768。加上classtoken后就是197, 768
  131. # 此时生成的pos_Embedding的shape也为197, 768,代表每一个特征的位置信息。
  132. #--------------------------------------------------------------------------------------------------------------------#
  133. # 197, 768 -> 197, 768
  134. self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, num_features))
  135. self.pos_drop = nn.Dropout(p=drop_rate)
  136. #-----------------------------------------------#
  137. # 197, 768 -> 197, 768 12次
  138. #-----------------------------------------------#
  139. dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
  140. self.blocks = nn.Sequential(
  141. *[
  142. Block(
  143. dim = num_features,
  144. num_heads = num_heads,
  145. mlp_ratio = mlp_ratio,
  146. qkv_bias = qkv_bias,
  147. drop = drop_rate,
  148. attn_drop = attn_drop_rate,
  149. drop_path = dpr[i],
  150. norm_layer = norm_layer,
  151. act_layer = act_layer
  152. )for i in range(depth)
  153. ]
  154. )
  155. self.norm = norm_layer(num_features)
  156. self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()
  157. def forward_features(self, x):
  158. x = self.patch_embed(x)
  159. cls_token = self.cls_token.expand(x.shape[0], -1, -1)
  160. x = torch.cat((cls_token, x), dim=1)
  161. cls_token_pe = self.pos_embed[:, 0:1, :]
  162. img_token_pe = self.pos_embed[:, 1: , :]
  163. img_token_pe = img_token_pe.view(1, *self.old_feature_shape, -1).permute(0, 3, 1, 2)
  164. img_token_pe = F.interpolate(img_token_pe, size=self.new_feature_shape, mode='bicubic', align_corners=False)
  165. img_token_pe = img_token_pe.permute(0, 2, 3, 1).flatten(1, 2)
  166. pos_embed = torch.cat([cls_token_pe, img_token_pe], dim=1)
  167. x = self.pos_drop(x + pos_embed)
  168. x = self.blocks(x)
  169. x = self.norm(x)
  170. return x[:, 0]
  171. def forward(self, x):
  172. x = self.forward_features(x)
  173. x = self.head(x)
  174. return x
  175. def freeze_backbone(self):
  176. backbone = [self.patch_embed, self.cls_token, self.pos_embed, self.pos_drop, self.blocks[:8]]
  177. for module in backbone:
  178. try:
  179. for param in module.parameters():
  180. param.requires_grad = False
  181. except:
  182. module.requires_grad = False
  183. def Unfreeze_backbone(self):
  184. backbone = [self.patch_embed, self.cls_token, self.pos_embed, self.pos_drop, self.blocks[:8]]
  185. for module in backbone:
  186. try:
  187. for param in module.parameters():
  188. param.requires_grad = True
  189. except:
  190. module.requires_grad = True
  191. def vit(input_shape=[224, 224], pretrained=False, num_classes=1000):
  192. model = VisionTransformer(input_shape)
  193. if pretrained:
  194. model.load_state_dict(torch.load("model_data/vit-patch_16.pth"))
  195. if num_classes!=1000:
  196. model.head = nn.Linear(model.num_features, num_classes)
  197. return model

参考:Vision Transformer详解

神经网络学习小记录67——Pytorch版 Vision Transformer(VIT)模型的复现详解

"未来"的经典之作ViT:transformer is all you need!

全网最强ViT (Vision Transformer)原理及代码解析

详解Transformer中Self-Attention以及Multi-Head Attention

本文内容由网友自发贡献,转载请注明出处:https://www.wpsshop.cn/w/小小林熬夜学编程/article/detail/345768
推荐阅读
相关标签
  

闽ICP备14008679号