赞
踩
论文地址:https://arxiv.org/abs/2010.11929
代码地址:https://github.com/bubbliiiing/classification-pytorch
Vision Transformer(VIT)是一种基于Transformer架构的图像分类模型。它将图像分割成一系列的图像块,并将每个图像块作为输入序列传递给Transformer模型。VIT通过自注意力机制来捕捉图像中的全局上下文信息,并使用多层感知机(MLP)来进行特征提取和分类。
VIT的核心思想是将图像转换为序列数据,这使得模型能够利用Transformer的强大表达能力来处理图像。通过将图像分割成图像块,并将它们展平为序列,VIT能够在不依赖传统卷积神经网络的情况下实现图像分类任务。
从2020年,transformer开始在CV领域大放异彩:图像分类(ViT, DeiT),目标检测(DETR,Deformable DETR),语义分割(SETR,MedT),图像生成(GANsformer)等。而从深度学习暴发以来,CNN一直是CV领域的主流模型,而且取得了很好的效果,相比之下transformer却独霸NLP领域,transformer在CV领域的探索正是研究界想把transformer在NLP领域的成功借鉴到CV领域。对于图像问题,卷积具有天然的先天优势(inductive bias):平移等价性(translation equivariance)和局部性(locality)。而transformer虽然不并具备这些优势,但是transformer的核心self-attention的优势不像卷积那样有固定且有限的感受野,self-attention操作可以获得long-range信息(相比之下CNN要通过不断堆积Conv layers来获取更大的感受野),但训练的难度就比CNN要稍大一些。
ViT(vision transformer)是Google在2020年提出的直接将transformer应用在图像分类的模型,后面很多的工作都是基于ViT进行改进的。这篇论文也是受到其启发,尝试将Transformer应用到CV领域通过这篇文章的实验,给出的最佳模型在ImageNet1K上能够达到88.55%的准确率(先在Google自家的JFT数据集上进行了预训练),说明Transformer在CV领域确实是有效的,而且效果还挺惊人。
与寻常的分类网络类似,整个Vision Transformer可以分为两部分,一部分是特征提取部分,另一部分是分类部分。
在特征提取部分,VIT所做的工作是特征提取。特征提取部分在图片中的对应区域是Patch+Position Embedding和Transformer Encoder。Patch+Position Embedding的作用主要是对输入进来的图片进行分块处理,每隔一定的区域大小划分图片块。然后将划分后的图片块组合成序列。在获得序列信息后,传入Transformer Encoder进行特征提取,这是Transformer特有的Multi-head Self-attention结构,通过自注意力机制,关注每个图片块的重要程度。
在分类部分,VIT所做的工作是利用提取到的特征进行分类。在进行特征提取的时候,我们会在图片序列中添加上Cls Token,该Token会作为一个单位的序列信息一起进行特征提取,提取的过程中,该Cls Token会与其它的特征进行特征交互,融合其它图片序列的特征。最终,我们利用Multi-head Self-attention结构提取特征后的Cls Token进行全连接分类。
Patch的作用主要是对输入进来的图片进行分块处理,每隔一定的区域大小划分图片块。然后将划分后的图片块组合成序列。
该部分首先对输入进来的图片进行分块处理,处理方式其实很简单,使用的是现成的卷积。也就是说,不是把图片分割,是做了一次简单的卷积,可以理解为初步特征提取,或者说是映射。
由于卷积使用的是滑动窗口的思想,我们只需要设定特定的步长,就可以输入进来的图片进行分块处理了。在VIT中,我们常设置这个卷积的卷积核大小为16x16,步长也为16x16,此时卷积就会每隔16个像素点进行一次特征提取,由于卷积核大小为16x16,两个图片区域的特征提取过程就不会有重叠。当我们输入的图片是224, 224, 3的时候,我们可以获得一个14, 14, 768的特征层。
在代码实现中,直接通过一个卷积层来实现。 以ViT-B/16为例,直接使用一个卷积核大小为16x16,步距为16,卷积核个数为768的卷积来实现。通过卷积[224, 224, 3] -> [14, 14, 768],然后把H以及W两个维度展平即可[14, 14, 768] -> [196, 768],此时正好变成了一个二维矩阵,正是Transformer想要的。
Position Embedding的作用主要是对组合序列加上[class]token以及Position Embedding。
除了patch embeddings,模型还需要另外一个特殊的position embedding。transformer和CNN不同,需要position embedding来编码tokens的位置信息,这主要是因为self-attention是permutation-invariant,即打乱sequence里的tokens的顺序并不会改变结果。如果不给模型提供patch的位置信息,那么模型就需要通过patchs的语义来学习拼图,这就额外增加了学习成本。ViT论文中对比了几种不同的position embedding方案(如下),最后发现如果不提供positional embedding效果会差,但其它各种类型的positional embedding效果都接近,这主要是因为ViT的输入是相对较大的patchs而不是pixels,所以学习位置信息相对容易很多。
transformer原论文中是默认采用固定的positional embedding,但ViT中默认采用学习(训练的)的1-D positional embedding,在输入transformer的encoder之前直接将patch embeddings和positional embedding相加:
- # 这里多1是为了后面要说的class token,embed_dim即patch embed_dim
- self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
-
- # patch emded + pos_embed
- x = x + self.pos_embed
对于Position Embedding作者也有做一系列对比试验,在源码中默认使用的是1D Pos. Emb
.
,对比不使用Position Embedding准确率提升了大概3个点,和2D Pos. Emb.
比起来没太大差别。
Class Token
除了patch tokens,ViT借鉴BERT还增加了一个特殊的class token。后面会说,transformer的encoder输入是a sequence patch embeddings,输出也是同样长度的a sequence patch features,但图像分类最后需要获取image feature,简单的策略是采用pooling,比如求patch features的平均来获取image feature,但是ViT并没有采用类似的pooling策略,而是直接增加一个特殊的class token,其最后输出的特征加一个linear classifier就可以实现对图像的分类(ViT的pre-training时是接一个MLP head),所以输入ViT的sequence长度是�+1。class token对应的embedding在训练时随机初始化,然后通过训练得到,具体实现如下:
- # 随机初始化
- self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
-
- # Classifier head
- self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
-
- # 具体forward过程
- B = x.shape[0]
- x = self.patch_embed(x)
- cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
- x = torch.cat((cls_tokens, x), dim=1)
- x = x + self.pos_embed
Transformer Encoder的作用是将输入序列进行编码,生成一个高维表示,以便后续的任务处理。它由多个Encoder Block组成,每个Encoder Block包含一个多头自注意力层和一个前馈全连接层。在编码过程中,输入序列会经过多头自注意力层进行特征提取和关联性计算,然后再通过前馈全连接层进行非线性变换和特征融合。通过堆叠多个Encoder Block,Transformer Encoder能够捕捉输入序列中的语义信息和上下文关系,生成一个更加丰富的表示。
下图是太阳花的小绿豆绘制的Encoder Block,主要由以下几部分组成:
3.2.3.1Layer Norm
这种Normalization方法主要是针对NLP领域提出的,这里是对每个token进行Norm处理,之前也有讲过Layer Norm不懂的可以参考链接
3.2.3.2 Multi-Head Attention
Multi-Head Attention是一种注意力机制,它由多个独立的注意力头组成。每个注意力头都可以学习到不同的表示子空间,并在不同的位置上联合关注信息。通过使用多个注意力头,Multi-Head Attention可以更好地捕捉输入序列中的不同关系和特征。
下面是一个示例代码,演示了如何实现Multi-Head Attention:
import torch
import torch.nn as nn
class MultiHeadAttention(nn.Module):
def __init__(self, input_dim, num_heads):
super(MultiHeadAttention, self).__init__()
self.num_heads = num_heads
self.head_dim = input_dim // num_heads
self.query_linear = nn.Linear(input_dim, input_dim)
self.key_linear = nn.Linear(input_dim, input_dim)
self.value_linear = nn.Linear(input_dim, input_dim)
self.output_linear = nn.Linear(input_dim, input_dim)
def forward(self, query, key, value):
batch_size = query.size(0)
# 线性变换得到query、key、value
query = self.query_linear(query)
key = self.key_linear(key)
value = self.value_linear(value)
# 将query、key、value分成多个头
query = query.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
key = key.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
value = value.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
# 计算注意力得分
scores = torch.matmul(query, key.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))
attention_weights = torch.softmax(scores, dim=-1)
# 对value进行加权求和
weighted_values = torch.matmul(attention_weights, value)
# 将多个头的结果拼接起来
weighted_values = weighted_values.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.head_dim)
# 线性变换得到最终的输出
output = self.output_linear(weighted_values)
return output
Multi-Head Attention,看懂Self-attention结构,其实看懂下面这个动图就可以了,动图中存在一个序列的三个单位输入,每一个序列单位的输入都可以通过三个处理(比如全连接)获得Query、Key、Value,Query是查询向量、Key是键向量、Value值向量。
假设输入的序列长度为2,输入就两个节点,然后通过Input Embedding也就是图中的f(x)将输入映射到 。紧接着分别将 分别通过三个变换矩阵 ,(这三个参数是可训练的,是共享的)得到对应的(这里在源码中是直接使用全连接层实现的,这里为了方便理解,忽略偏执)。
其中
q 代表query,后续会去和每一个k kk进行匹配
k 代表key,后续会被每个q qq匹配
v 代表从a aa中提取得到的信息
后续q 和k 匹配的过程可以理解成计算两者的相关性,相关性越大对应v 的权重也就越大
假设,那么
前面有说Transformer是可以并行化的,所以可以直接写成:
同理我们可以得到和,那么求得的就是原论文中的Q ,就是K,就是V.接着先拿和每个k kk进行match,点乘操作,接着除以得到对应的α ,其中d 代表向量的长度,在本示例中等于2,除以的原因在论文中的解释是“进行点乘后的数值很大,导致通过softmax后梯度变的很小”,所以通过除以来进行缩放。比如计算:
同理拿 去匹配所有的k kk能得到,统一写成矩阵乘法形式:
接着对每一行即和分别进行softmax处理得到和,这里的相当于计算得到针对每个v vv的权重。到这我们就完成了 Attention(Q,K,V)公式中的部分
上面已经计算得到α \alphaα,即针对每个v vv的权重,接着进行加权得到最终结果:
统一写成矩阵乘法形式:
到这,Self-Attention
的内容就讲完了。总结下来就是论文中的一个公式:
(2)Multi-Head Attention
首先还是和Self-Attention模块一样将分别通过得到对应的然后再根据使用的head的数目h hh进一步把得到的均分成h hh份。比如下图中假设h = 2 h=2h=2然后拆分成和,那么就是与head1,就属于head2
看到这里,如果读过原论文的人肯定有疑问,论文中不是写的通过映射得到每个head的
但我在github上看的一些源码中就是简单的进行均分,其实也可以将 设置成对应值来实现均分,比如下图中的Q通过就能得到均分后的
通过上述方法就能得到每个 对应的参数,接下来针对每个head使用和Self-Attention中相同的方法即可得到对应的结果。
接着将每个head得到的结果进行concat拼接,比如下图中(得到的) 和(得到的)拼接在一起,(得到的) 和(得到的)拼接在一起
接着将拼接后的结果通过 (可学习的参数)进行融合,如下图所示,融合后得到最终的结果
MLP Head是指多层感知器头部,它是在Transformer Encoder后面用于分类任务的一部分。MLP Head通常由几个线性层组成,用于将Transformer Encoder的输出转换为最终的分类结果。在原始的MLP Head论文中,它由线性层、tanh激活函数和线性层组成。但是在迁移到ImageNet1K或其他数据集时,只需要使用一个线性层即可。
下面是一个示例代码,展示了如何使用PyTorch实现一个简单的MLP Head:
import torch
import torch.nn as nnclass MLPHead(nn.Module):
def __init__(self, input_dim, output_dim):
super(MLPHead, self).__init__()
self.linear = nn.Linear(input_dim, output_dim)def forward(self, x):
x = self.linear(x)
return x# 假设Transformer Encoder的输出维度为768,分类任务的类别数为10
input_dim = 768
output_dim = 10mlp_head = MLPHead(input_dim, output_dim)
# 假设Transformer Encoder的输出为transformer_output
transformer_output = torch.randn(197, 768)# 将transformer_output输入到MLP Head中得到分类结果
classification_result = mlp_head(transformer_output)print(classification_result.shape) # 输出:torch.Size([197, 10])
在上述代码中,我们定义了一个MLPHead类,它接受一个输入维度和一个输出维度作为参数。在forward方法中,我们使用一个线性层将输入转换为输出。然后,我们可以将Transformer Encoder的输出输入到MLP Head中,得到最终的分类结果
上面VIT通过Transformer Encoder后输出的shape和输入的shape是保持不变的,以ViT-B/16为例,输入的是[197, 768]输出的还是[197, 768]。注意,在Transformer Encoder后其实还有一个Layer Norm没有画出来,后面有我自己画的ViT的模型可以看到详细结构。这里我们只是需要分类的信息,所以我们只需要提取出[class]token生成的对应结果就行,即[197, 768]中抽取出[class]token对应的[1, 768]。接着我们通过MLP Head得到我们最终的分类结果。MLP Head原论文中说在训练ImageNet21K时是由Linear+tanh激活函数+Linear组成。但是迁移到ImageNet1K上或者你自己的数据上时,只用一个Linear即可。
Patch+Position Embedding
- class PatchEmbed(nn.Module):
- def __init__(self, input_shape=[224, 224], patch_size=16, in_chans=3, num_features=768, norm_layer=None, flatten=True):
- super().__init__()
- self.num_patches = (input_shape[0] // patch_size) * (input_shape[1] // patch_size)
- self.flatten = flatten
-
- self.proj = nn.Conv2d(in_chans, num_features, kernel_size=patch_size, stride=patch_size)
- self.norm = norm_layer(num_features) if norm_layer else nn.Identity()
-
- def forward(self, x):
- x = self.proj(x)
- if self.flatten:
- x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
- x = self.norm(x)
- return x
-
- class VisionTransformer(nn.Module):
- def __init__(
- self, input_shape=[224, 224], patch_size=16, in_chans=3, num_classes=1000, num_features=768,
- depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0.1, attn_drop_rate=0.1, drop_path_rate=0.1,
- norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=GELU
- ):
- super().__init__()
- #-----------------------------------------------#
- # 224, 224, 3 -> 196, 768
- #-----------------------------------------------#
- self.patch_embed = PatchEmbed(input_shape=input_shape, patch_size=patch_size, in_chans=in_chans, num_features=num_features)
- num_patches = (224 // patch_size) * (224 // patch_size)
- self.num_features = num_features
- self.new_feature_shape = [int(input_shape[0] // patch_size), int(input_shape[1] // patch_size)]
- self.old_feature_shape = [int(224 // patch_size), int(224 // patch_size)]
-
- #--------------------------------------------------------------------------------------------------------------------#
- # classtoken部分是transformer的分类特征。用于堆叠到序列化后的图片特征中,作为一个单位的序列特征进行特征提取。
- #
- # 在利用步长为16x16的卷积将输入图片划分成14x14的部分后,将14x14部分的特征平铺,一幅图片会存在序列长度为196的特征。
- # 此时生成一个classtoken,将classtoken堆叠到序列长度为196的特征上,获得一个序列长度为197的特征。
- # 在特征提取的过程中,classtoken会与图片特征进行特征的交互。最终分类时,我们取出classtoken的特征,利用全连接分类。
- #--------------------------------------------------------------------------------------------------------------------#
- # 196, 768 -> 197, 768
- self.cls_token = nn.Parameter(torch.zeros(1, 1, num_features))
- #--------------------------------------------------------------------------------------------------------------------#
- # 为网络提取到的特征添加上位置信息。
- # 以输入图片为224, 224, 3为例,我们获得的序列化后的图片特征为196, 768。加上classtoken后就是197, 768
- # 此时生成的pos_Embedding的shape也为197, 768,代表每一个特征的位置信息。
- #--------------------------------------------------------------------------------------------------------------------#
- # 197, 768 -> 197, 768
- self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, num_features))
-
- def forward_features(self, x):
- x = self.patch_embed(x)
- cls_token = self.cls_token.expand(x.shape[0], -1, -1)
- x = torch.cat((cls_token, x), dim=1)
-
- cls_token_pe = self.pos_embed[:, 0:1, :]
- img_token_pe = self.pos_embed[:, 1: , :]
-
- img_token_pe = img_token_pe.view(1, *self.old_feature_shape, -1).permute(0, 3, 1, 2)
- img_token_pe = F.interpolate(img_token_pe, size=self.new_feature_shape, mode='bicubic', align_corners=False)
- img_token_pe = img_token_pe.permute(0, 2, 3, 1).flatten(1, 2)
- pos_embed = torch.cat([cls_token_pe, img_token_pe], dim=1)
-
- x = self.pos_drop(x + pos_embed)
TransformerBlock
- class Mlp(nn.Module):
- """ MLP as used in Vision Transformer, MLP-Mixer and related networks
- """
- def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=GELU, drop=0.):
- super().__init__()
- out_features = out_features or in_features
- hidden_features = hidden_features or in_features
- drop_probs = (drop, drop)
-
- self.fc1 = nn.Linear(in_features, hidden_features)
- self.act = act_layer()
- self.drop1 = nn.Dropout(drop_probs[0])
- self.fc2 = nn.Linear(hidden_features, out_features)
- self.drop2 = nn.Dropout(drop_probs[1])
-
- def forward(self, x):
- x = self.fc1(x)
- x = self.act(x)
- x = self.drop1(x)
- x = self.fc2(x)
- x = self.drop2(x)
- return x
-
- class Block(nn.Module):
- def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
- drop_path=0., act_layer=GELU, norm_layer=nn.LayerNorm):
- super().__init__()
- self.norm1 = norm_layer(dim)
- self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
- self.norm2 = norm_layer(dim)
- self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
- self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
-
- def forward(self, x):
- x = x + self.drop_path(self.attn(self.norm1(x)))
- x = x + self.drop_path(self.mlp(self.norm2(x)))
- return x
VIT
整个VIT模型由一个Patch+Position Embedding加上多个TransformerBlock组成。典型的TransforerBlock的数量为12个。
- class VisionTransformer(nn.Module):
- def __init__(
- self, input_shape=[224, 224], patch_size=16, in_chans=3, num_classes=1000, num_features=768,
- depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0.1, attn_drop_rate=0.1, drop_path_rate=0.1,
- norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=GELU
- ):
- super().__init__()
- #-----------------------------------------------#
- # 224, 224, 3 -> 196, 768
- #-----------------------------------------------#
- self.patch_embed = PatchEmbed(input_shape=input_shape, patch_size=patch_size, in_chans=in_chans, num_features=num_features)
- num_patches = (224 // patch_size) * (224 // patch_size)
- self.num_features = num_features
- self.new_feature_shape = [int(input_shape[0] // patch_size), int(input_shape[1] // patch_size)]
- self.old_feature_shape = [int(224 // patch_size), int(224 // patch_size)]
-
- #--------------------------------------------------------------------------------------------------------------------#
- # classtoken部分是transformer的分类特征。用于堆叠到序列化后的图片特征中,作为一个单位的序列特征进行特征提取。
- #
- # 在利用步长为16x16的卷积将输入图片划分成14x14的部分后,将14x14部分的特征平铺,一幅图片会存在序列长度为196的特征。
- # 此时生成一个classtoken,将classtoken堆叠到序列长度为196的特征上,获得一个序列长度为197的特征。
- # 在特征提取的过程中,classtoken会与图片特征进行特征的交互。最终分类时,我们取出classtoken的特征,利用全连接分类。
- #--------------------------------------------------------------------------------------------------------------------#
- # 196, 768 -> 197, 768
- self.cls_token = nn.Parameter(torch.zeros(1, 1, num_features))
- #--------------------------------------------------------------------------------------------------------------------#
- # 为网络提取到的特征添加上位置信息。
- # 以输入图片为224, 224, 3为例,我们获得的序列化后的图片特征为196, 768。加上classtoken后就是197, 768
- # 此时生成的pos_Embedding的shape也为197, 768,代表每一个特征的位置信息。
- #--------------------------------------------------------------------------------------------------------------------#
- # 197, 768 -> 197, 768
- self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, num_features))
- self.pos_drop = nn.Dropout(p=drop_rate)
-
- #-----------------------------------------------#
- # 197, 768 -> 197, 768 12次
- #-----------------------------------------------#
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
- self.blocks = nn.Sequential(
- *[
- Block(
- dim = num_features,
- num_heads = num_heads,
- mlp_ratio = mlp_ratio,
- qkv_bias = qkv_bias,
- drop = drop_rate,
- attn_drop = attn_drop_rate,
- drop_path = dpr[i],
- norm_layer = norm_layer,
- act_layer = act_layer
- )for i in range(depth)
- ]
- )
- self.norm = norm_layer(num_features)
- self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()
-
- def forward_features(self, x):
- x = self.patch_embed(x)
- cls_token = self.cls_token.expand(x.shape[0], -1, -1)
- x = torch.cat((cls_token, x), dim=1)
-
- cls_token_pe = self.pos_embed[:, 0:1, :]
- img_token_pe = self.pos_embed[:, 1: , :]
-
- img_token_pe = img_token_pe.view(1, *self.old_feature_shape, -1).permute(0, 3, 1, 2)
- img_token_pe = F.interpolate(img_token_pe, size=self.new_feature_shape, mode='bicubic', align_corners=False)
- img_token_pe = img_token_pe.permute(0, 2, 3, 1).flatten(1, 2)
- pos_embed = torch.cat([cls_token_pe, img_token_pe], dim=1)
-
- x = self.pos_drop(x + pos_embed)
- x = self.blocks(x)
- x = self.norm(x)
- return x[:, 0]
-
- def forward(self, x):
- x = self.forward_features(x)
- x = self.head(x)
- return x
-
- def freeze_backbone(self):
- backbone = [self.patch_embed, self.cls_token, self.pos_embed, self.pos_drop, self.blocks[:8]]
- for module in backbone:
- try:
- for param in module.parameters():
- param.requires_grad = False
- except:
- module.requires_grad = False
-
- def Unfreeze_backbone(self):
- backbone = [self.patch_embed, self.cls_token, self.pos_embed, self.pos_drop, self.blocks[:8]]
- for module in backbone:
- try:
- for param in module.parameters():
- param.requires_grad = True
- except:
- module.requires_grad = True
Vision Transforme的构建代码
- import math
- from collections import OrderedDict
- from functools import partial
-
- import numpy as np
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
- #--------------------------------------#
- # Gelu激活函数的实现
- # 利用近似的数学公式
- #--------------------------------------#
- class GELU(nn.Module):
- def __init__(self):
- super(GELU, self).__init__()
-
- def forward(self, x):
- return 0.5 * x * (1 + F.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * torch.pow(x,3))))
-
- def drop_path(x, drop_prob: float = 0., training: bool = False):
- if drop_prob == 0. or not training:
- return x
- keep_prob = 1 - drop_prob
- shape = (x.shape[0],) + (1,) * (x.ndim - 1)
- random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
- random_tensor.floor_()
- output = x.div(keep_prob) * random_tensor
- return output
-
- class DropPath(nn.Module):
- def __init__(self, drop_prob=None):
- super(DropPath, self).__init__()
- self.drop_prob = drop_prob
-
- def forward(self, x):
- return drop_path(x, self.drop_prob, self.training)
-
- class PatchEmbed(nn.Module):
- def __init__(self, input_shape=[224, 224], patch_size=16, in_chans=3, num_features=768, norm_layer=None, flatten=True):
- super().__init__()
- self.num_patches = (input_shape[0] // patch_size) * (input_shape[1] // patch_size)
- self.flatten = flatten
-
- self.proj = nn.Conv2d(in_chans, num_features, kernel_size=patch_size, stride=patch_size)
- self.norm = norm_layer(num_features) if norm_layer else nn.Identity()
-
- def forward(self, x):
- x = self.proj(x)
- if self.flatten:
- x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
- x = self.norm(x)
- return x
-
- #--------------------------------------------------------------------------------------------------------------------#
- # Attention机制
- # 将输入的特征qkv特征进行划分,首先生成query, key, value。query是查询向量、key是键向量、v是值向量。
- # 然后利用 查询向量query 叉乘 转置后的键向量key,这一步可以通俗的理解为,利用查询向量去查询序列的特征,获得序列每个部分的重要程度score。
- # 然后利用 score 叉乘 value,这一步可以通俗的理解为,将序列每个部分的重要程度重新施加到序列的值上去。
- #--------------------------------------------------------------------------------------------------------------------#
- class Attention(nn.Module):
- def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
- super().__init__()
- self.num_heads = num_heads
- self.scale = (dim // num_heads) ** -0.5
-
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
- self.attn_drop = nn.Dropout(attn_drop)
- self.proj = nn.Linear(dim, dim)
- self.proj_drop = nn.Dropout(proj_drop)
-
- def forward(self, x):
- B, N, C = x.shape
- qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
- q, k, v = qkv[0], qkv[1], qkv[2]
-
- attn = (q @ k.transpose(-2, -1)) * self.scale
- attn = attn.softmax(dim=-1)
- attn = self.attn_drop(attn)
-
- x = (attn @ v).transpose(1, 2).reshape(B, N, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
- class Mlp(nn.Module):
- """ MLP as used in Vision Transformer, MLP-Mixer and related networks
- """
- def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=GELU, drop=0.):
- super().__init__()
- out_features = out_features or in_features
- hidden_features = hidden_features or in_features
- drop_probs = (drop, drop)
-
- self.fc1 = nn.Linear(in_features, hidden_features)
- self.act = act_layer()
- self.drop1 = nn.Dropout(drop_probs[0])
- self.fc2 = nn.Linear(hidden_features, out_features)
- self.drop2 = nn.Dropout(drop_probs[1])
-
- def forward(self, x):
- x = self.fc1(x)
- x = self.act(x)
- x = self.drop1(x)
- x = self.fc2(x)
- x = self.drop2(x)
- return x
-
- class Block(nn.Module):
- def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
- drop_path=0., act_layer=GELU, norm_layer=nn.LayerNorm):
- super().__init__()
- self.norm1 = norm_layer(dim)
- self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
- self.norm2 = norm_layer(dim)
- self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
- self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
-
- def forward(self, x):
- x = x + self.drop_path(self.attn(self.norm1(x)))
- x = x + self.drop_path(self.mlp(self.norm2(x)))
- return x
-
- class VisionTransformer(nn.Module):
- def __init__(
- self, input_shape=[224, 224], patch_size=16, in_chans=3, num_classes=1000, num_features=768,
- depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0.1, attn_drop_rate=0.1, drop_path_rate=0.1,
- norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=GELU
- ):
- super().__init__()
- #-----------------------------------------------#
- # 224, 224, 3 -> 196, 768
- #-----------------------------------------------#
- self.patch_embed = PatchEmbed(input_shape=input_shape, patch_size=patch_size, in_chans=in_chans, num_features=num_features)
- num_patches = (224 // patch_size) * (224 // patch_size)
- self.num_features = num_features
- self.new_feature_shape = [int(input_shape[0] // patch_size), int(input_shape[1] // patch_size)]
- self.old_feature_shape = [int(224 // patch_size), int(224 // patch_size)]
-
- #--------------------------------------------------------------------------------------------------------------------#
- # classtoken部分是transformer的分类特征。用于堆叠到序列化后的图片特征中,作为一个单位的序列特征进行特征提取。
- #
- # 在利用步长为16x16的卷积将输入图片划分成14x14的部分后,将14x14部分的特征平铺,一幅图片会存在序列长度为196的特征。
- # 此时生成一个classtoken,将classtoken堆叠到序列长度为196的特征上,获得一个序列长度为197的特征。
- # 在特征提取的过程中,classtoken会与图片特征进行特征的交互。最终分类时,我们取出classtoken的特征,利用全连接分类。
- #--------------------------------------------------------------------------------------------------------------------#
- # 196, 768 -> 197, 768
- self.cls_token = nn.Parameter(torch.zeros(1, 1, num_features))
- #--------------------------------------------------------------------------------------------------------------------#
- # 为网络提取到的特征添加上位置信息。
- # 以输入图片为224, 224, 3为例,我们获得的序列化后的图片特征为196, 768。加上classtoken后就是197, 768
- # 此时生成的pos_Embedding的shape也为197, 768,代表每一个特征的位置信息。
- #--------------------------------------------------------------------------------------------------------------------#
- # 197, 768 -> 197, 768
- self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, num_features))
- self.pos_drop = nn.Dropout(p=drop_rate)
-
- #-----------------------------------------------#
- # 197, 768 -> 197, 768 12次
- #-----------------------------------------------#
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
- self.blocks = nn.Sequential(
- *[
- Block(
- dim = num_features,
- num_heads = num_heads,
- mlp_ratio = mlp_ratio,
- qkv_bias = qkv_bias,
- drop = drop_rate,
- attn_drop = attn_drop_rate,
- drop_path = dpr[i],
- norm_layer = norm_layer,
- act_layer = act_layer
- )for i in range(depth)
- ]
- )
- self.norm = norm_layer(num_features)
- self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()
-
- def forward_features(self, x):
- x = self.patch_embed(x)
- cls_token = self.cls_token.expand(x.shape[0], -1, -1)
- x = torch.cat((cls_token, x), dim=1)
-
- cls_token_pe = self.pos_embed[:, 0:1, :]
- img_token_pe = self.pos_embed[:, 1: , :]
-
- img_token_pe = img_token_pe.view(1, *self.old_feature_shape, -1).permute(0, 3, 1, 2)
- img_token_pe = F.interpolate(img_token_pe, size=self.new_feature_shape, mode='bicubic', align_corners=False)
- img_token_pe = img_token_pe.permute(0, 2, 3, 1).flatten(1, 2)
- pos_embed = torch.cat([cls_token_pe, img_token_pe], dim=1)
-
- x = self.pos_drop(x + pos_embed)
- x = self.blocks(x)
- x = self.norm(x)
- return x[:, 0]
-
- def forward(self, x):
- x = self.forward_features(x)
- x = self.head(x)
- return x
-
- def freeze_backbone(self):
- backbone = [self.patch_embed, self.cls_token, self.pos_embed, self.pos_drop, self.blocks[:8]]
- for module in backbone:
- try:
- for param in module.parameters():
- param.requires_grad = False
- except:
- module.requires_grad = False
-
- def Unfreeze_backbone(self):
- backbone = [self.patch_embed, self.cls_token, self.pos_embed, self.pos_drop, self.blocks[:8]]
- for module in backbone:
- try:
- for param in module.parameters():
- param.requires_grad = True
- except:
- module.requires_grad = True
-
-
- def vit(input_shape=[224, 224], pretrained=False, num_classes=1000):
- model = VisionTransformer(input_shape)
- if pretrained:
- model.load_state_dict(torch.load("model_data/vit-patch_16.pth"))
-
- if num_classes!=1000:
- model.head = nn.Linear(model.num_features, num_classes)
- return model
神经网络学习小记录67——Pytorch版 Vision Transformer(VIT)模型的复现详解
"未来"的经典之作ViT:transformer is all you need!
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。