
Vision Transformer: Principles and Model Study Notes

Vision Transformer (ViT) is a sequence-to-sequence model built around self-attention. It consists of two parts: feature extraction and classification.

In the feature extraction part, ViT's job is to extract features from the image. In the architecture diagram, this part corresponds to Patch + Position Embedding and the Transformer Encoder.

Patch + Position Embedding splits the incoming image into blocks: patches of a fixed size are cut out at regular intervals, and the resulting patches are assembled into a sequence. Once the sequence is obtained, it is passed into the Transformer Encoder for feature extraction. This is the Transformer's characteristic multi-head self-attention structure, which uses self-attention to weigh the importance of every image patch.

In the classification part, ViT classifies the extracted features. During feature extraction, a class token is added to the image sequence. The class token is treated as one more element of the sequence and goes through feature extraction together with the patches; along the way it interacts with and fuses the features of all the image patches. Finally, the class token produced by the multi-head self-attention stack is fed into a fully connected layer for classification, as sketched below.
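As a minimal sketch of that last step (the shapes follow the 480×480 example used later in this post; the tensor and layer names are illustrative, not taken from the source code):

import torch
import torch.nn as nn

tokens = torch.randn(1, 901, 768)      # encoder output: 1 class token + 900 patch tokens, 768 channels
head   = nn.Linear(768, 1000)          # fully connected classification head (1000 classes assumed)

cls_feature = tokens[:, 0]             # take the class token, shape (1, 768)
logits      = head(cls_feature)        # class scores, shape (1, 1000)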

Vision Transformer source code link: ViT source code

Structure of ViT

The structure of the Vision Transformer is shown in the figure below.

 

Data Processing

Given an input image of size $[H, W, C]$ and a patch size $P$, the number of patches is $(H \times W)/(P \times P)$. Each patch is then flattened into a one-dimensional vector of length $P \times P \times C$.

For example, with an input image of size (480, 480, 3) and a patch size of 16, the number of patches is (480×480)/(16×16) = 900. Each patch is flattened into a one-dimensional vector of length 16×16×3 = 768.
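These numbers can be checked with a short sketch: a stride-16 convolution plays the role of the patch split, mirroring the PatchEmbed layer further below (all names here are illustrative):

import torch
import torch.nn as nn

img     = torch.randn(1, 3, 480, 480)                    # (B, C, H, W)
proj    = nn.Conv2d(3, 768, kernel_size=16, stride=16)   # one conv step covers exactly one 16x16 patch
patches = proj(img)                                      # (1, 768, 30, 30): a 30x30 grid of patches
seq     = patches.flatten(2).transpose(1, 2)             # (1, 900, 768): 900 patch vectors of length 768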

After serialization we obtain a feature map of shape (900, 768). Adding the class token gives a feature map of shape (901, 768). Positional information is then added to every feature so that the network can tell different regions apart. The pos_embedding generated at this point also has shape (901, 768) and encodes the position of each feature. Adding positional information simply means creating a trainable parameter matrix of shape (901, 768) and adding it to the (901, 768) feature map.

The 0* in the figure is the class token.

self.pos_embed      = nn.Parameter(torch.zeros(1, num_patches, num_features))
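A minimal sketch of stacking the class token and adding the learnable positional embedding (shapes follow the 480×480 example above; the variable names are illustrative):

import torch
import torch.nn as nn

seq       = torch.randn(1, 900, 768)                     # serialized patch features
cls_token = nn.Parameter(torch.zeros(1, 1, 768))         # learnable class token
pos_embed = nn.Parameter(torch.zeros(1, 901, 768))       # learnable positional embedding

x = torch.cat([cls_token.expand(seq.shape[0], -1, -1), seq], dim=1)   # (1, 901, 768)
x = x + pos_embed                                                     # every token receives its position information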

Having obtained the sequence of shape (901, 768) in the previous step, we pass it into the Transformer Encoder for feature extraction. This is the Transformer's characteristic multi-head self-attention structure, which uses self-attention to weigh the importance of every image patch.

The most common way to process a sequence is an RNN: its input is a sequence of vectors and its output is another sequence of vectors. With a single-directional RNN, by the time $b_4$ is produced, $a_1$, $a_2$, $a_3$, $a_4$ have all been seen. With a bi-directional RNN, by the time any $b_n$ is produced, $a_1$, $a_2$, $a_3$, $a_4$ have likewise all been seen. But RNNs have one problem: they are hard to parallelize. For example, in a single-directional RNN, to compute $b_4$ the network must first read $a_1$, then $a_2$, then $a_3$, then $a_4$; this process is difficult to run in parallel.

Self-attention has the same interface as an RNN: it takes in a sequence and outputs another sequence, and like a bi-directional RNN, each output $b_1$ to $b_4$ has seen the entire input sequence. The difference from an RNN is that the outputs $b_1$ to $b_4$ can all be computed in parallel.

Self-attention

As shown in the figure above, the input is the sequence $x^1$–$x^4$. Each of $x^1$–$x^4$ is multiplied by a matrix $W$ to obtain the embeddings $a^1$–$a^4$. These embeddings enter the self-attention layer, where each of $a^1$–$a^4$ is multiplied by three different transformation matrices $W^q$, $W^k$, $W^v$; for example, $a^1$ yields $q^1$, $k^1$, $v^1$.

Each query $q$ then attends to every key $k$; attention measures how closely the two vectors match. For example, attending $q^1$ to $k^1$ means taking the scaled inner product of the two vectors, giving $a_{1,1}$; attending $q^1$ to $k^2$ gives $a_{1,2}$; $q^1$ to $k^3$ gives $a_{1,3}$; and $q^1$ to $k^4$ gives $a_{1,4}$.

The scaled inner product is computed as

$a_{1,i} = \dfrac{q^{1} \cdot k^{i}}{\sqrt{d}}$

where $d$ is the dimension of $q$ and $k$. The value of $q \cdot k$ grows as the dimension grows, so it is divided by $\sqrt{d}$, which acts as a normalization.

Next, a softmax is applied over all the computed $a_{1,i}$ to obtain $\hat{a}_{1,i}$.

Then each $\hat{a}_{1,i}$ is multiplied by the corresponding $v$: $\hat{a}_{1,1}$ with $v^1$, $\hat{a}_{1,2}$ with $v^2$, $\hat{a}_{1,3}$ with $v^3$, $\hat{a}_{1,4}$ with $v^4$, and the results are summed to give $b_1$. The whole computation of $b_1$ uses information from the entire sequence. If only local information should be considered, the model just needs to learn the corresponding $\hat{a}_{1,i} = 0$, and $b_1$ will carry no information from that branch; if global information should be considered, it learns $\hat{a}_{1,i} \neq 0$, and $b_1$ carries information from every branch.

The same procedure yields $b_2$, $b_3$ and $b_4$.

Through this whole chain of computation, the self-attention layer does the same job as an RNN, but it can be computed in parallel.
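A minimal sketch of this per-query computation on a toy sequence of 4 vectors (all tensors are random and purely illustrative):

import torch

d = 64
q = torch.randn(4, d)                 # q^1 .. q^4
k = torch.randn(4, d)                 # k^1 .. k^4
v = torch.randn(4, d)                 # v^1 .. v^4

a_1     = (q[0] @ k.T) / d ** 0.5     # scaled inner products a_{1,1} .. a_{1,4}
a_1_hat = a_1.softmax(dim=-1)         # softmax over the four scores
b_1     = a_1_hat @ v                 # weighted sum of the values -> b_1, shape (d,)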

 

First, the input embeddings are gathered into $I = [a^1, a^2, a^3, a^4]$. Multiplying by $W^q$ gives $Q = [q^1, q^2, q^3, q^4]$, where each column is a vector $q$. Multiplying by $W^k$ gives $K = [k^1, k^2, k^3, k^4]$, where each column is a vector $k$. Multiplying by $W^v$ gives $V = [v^1, v^2, v^3, v^4]$, where each column is a vector $v$.

The input matrix $I \in \mathbb{R}^{d \times N}$ is multiplied by three different matrices $W^q$, $W^k$, $W^v$ to obtain the three intermediate matrices $Q, K, V \in \mathbb{R}^{d \times N}$. Transposing $K$ and multiplying it with $Q$ gives the attention matrix $A \in \mathbb{R}^{N \times N}$, which holds the pairwise attention between every two positions. Applying softmax to it gives $\hat{A} \in \mathbb{R}^{N \times N}$, and finally multiplying by the matrix $V$ gives the output $O \in \mathbb{R}^{d \times N}$.
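The same toy example in matrix form, following the column-vector convention above (a sketch only; the $1/\sqrt{d}$ scaling is kept for consistency with the earlier formula):

import torch

d, N = 64, 4
I = torch.randn(d, N)                          # each column is one embedding a^i
W_q, W_k, W_v = (torch.randn(d, d) for _ in range(3))

Q, K, V = W_q @ I, W_k @ I, W_v @ I            # each of shape (d, N)
A       = (K.T @ Q) / d ** 0.5                 # (N, N) pairwise attention scores
A_hat   = A.softmax(dim=0)                     # normalize over the keys of each query (column-wise)
O       = V @ A_hat                            # (d, N) output sequence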

 

 Multi-head Self-attention

Take 2 heads as an example. The $q^i$ generated from $a^i$ is further multiplied by two transformation matrices to give $q^{i,1}$ and $q^{i,2}$; the $k^i$ generated from $a^i$ is multiplied by two transformation matrices to give $k^{i,1}$ and $k^{i,2}$; and the $v^i$ generated from $a^i$ is multiplied by two transformation matrices to give $v^{i,1}$ and $v^{i,2}$. Next, $q^{i,1}$ attends to $k^{i,1}$ to obtain the weights $\alpha$, which are used to take a weighted sum of $v^{i,1}$, producing $b^{i,1}\ (i = 1, 2, \ldots, N)$; the second head likewise produces $b^{i,2}\ (i = 1, 2, \ldots, N)$. Now that we have $b^{i,1} \in \mathbb{R}^{d \times 1}$ and $b^{i,2} \in \mathbb{R}^{d \times 1}$, we concatenate them and pass them through a transformation matrix that adjusts the dimension so that the result $b^{i} \in \mathbb{R}^{d \times 1}$ matches the original dimension.

 

Multi-Head Attention contains multiple self-attention layers. The input $X$ is first passed to 2 different self-attention heads, producing 2 output matrices. The 2 output matrices are then concatenated (Concat) and passed through a Linear layer, giving the final output $Z$ of Multi-Head Attention. Note that the output matrix $Z$ has the same dimensions as the input matrix $X$.

For example, after the image is split in the first step, we obtain a feature map of shape (901, 768). To apply multiple heads, we simply split the last dimension of (901, 768); with 12 heads, the shape becomes (901, 12, 64). We then transpose (901, 12, 64) to move the 12 to the front, obtaining a feature map of shape (12, 901, 64). After that, the 12 is treated just like the batch-size dimension and ignored, and only the (901, 64) part is processed. That is the whole attention process.
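A short sketch of this split-and-transpose, with the batch dimension included (it mirrors the qkv reshape in the Attention code below):

import torch

x = torch.randn(1, 901, 768)                  # (B, N, C)
B, N, C = x.shape
num_heads = 12
head_dim  = C // num_heads                    # 64 channels per head

x = x.reshape(B, N, num_heads, head_dim)      # (1, 901, 12, 64)
x = x.permute(0, 2, 1, 3)                     # (1, 12, 901, 64): heads sit next to the batch dimension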

The attention mechanism code is as follows:

class Attention(nn.Module):
    def __init__(self, dim, num_heads=12, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        self.scale     = (dim // num_heads) ** -0.5

        self.qkv       = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj      = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        # (B, N, C) -> (3, B, num_heads, N, C // num_heads): generate q, k, v for every head
        qkv     = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # scaled dot-product attention, normalized with softmax over the keys
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # weighted sum of the values, then merge the heads back to (B, N, C)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
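A quick shape check of this module, assuming torch and torch.nn are imported as in the full listing further below; the (1, 901, 768) input follows the running example:

attn = Attention(dim=768, num_heads=12)
out  = attn(torch.randn(1, 901, 768))   # self-attention keeps the shape: (1, 901, 768)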

After building the multi-head self-attention, we add two fully connected layers after it (the MLP). This completes an entire Transformer Block.

class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=GELU, drop=0.):
        super().__init__()
        out_features    = out_features or in_features
        hidden_features = hidden_features or in_features
        drop_probs      = (drop, drop)

        self.fc1   = nn.Linear(in_features, hidden_features)
        self.act   = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.fc2   = nn.Linear(hidden_features, out_features)
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x

class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1     = norm_layer(dim)
        self.attn      = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.norm2     = norm_layer(dim)
        self.mlp       = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        #-----------------------------------------------#
        #   layer normalization, multi-head attention, residual connection
        #-----------------------------------------------#
        x = x + self.drop_path(self.attn(self.norm1(x)))
        #-----------------------------------------------#
        #   layer normalization, two fully connected layers, residual connection
        #-----------------------------------------------#
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
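A quick shape check of one block (GELU and DropPath are defined in the full listing below; the input follows the running example):

block = Block(dim=768, num_heads=12)
out   = block(torch.randn(1, 901, 768))   # the residual connections keep the shape at (1, 901, 768)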

The whole ViT model consists of one Patch + Position Embedding stage followed by a stack of Transformer Blocks. A typical number of Transformer Blocks is 12.

self.blocks = nn.Sequential(
    *[
        Block(
            dim        = num_features,
            num_heads  = num_heads,
            mlp_ratio  = mlp_ratio,
            qkv_bias   = qkv_bias,
            drop       = drop_rate,
            attn_drop  = attn_drop_rate,
            drop_path  = dpr[i],
            norm_layer = norm_layer,
            act_layer  = act_layer
        ) for i in range(depth)   # depth == 12
    ]
)

Full construction code for the Vision Transformer

import math
from collections import OrderedDict
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

#--------------------------------------#
#   GELU activation, implemented with
#   the approximate (tanh) formula
#--------------------------------------#
class GELU(nn.Module):
    def __init__(self):
        super(GELU, self).__init__()

    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * torch.pow(x, 3))))

def drop_path(x, drop_prob: float = 0., training: bool = False):
    if drop_prob == 0. or not training:
        return x
    keep_prob     = 1 - drop_prob
    shape         = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()
    output        = x.div(keep_prob) * random_tensor
    return output

class DropPath(nn.Module):
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

class PatchEmbed(nn.Module):
    def __init__(self, input_shape=[480, 480], patch_size=16, in_chans=3, num_features=768, norm_layer=None, flatten=True):
        super().__init__()
        self.num_patches = (input_shape[0] // patch_size) * (input_shape[1] // patch_size)
        self.flatten     = flatten

        self.proj = nn.Conv2d(in_chans, num_features, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(num_features) if norm_layer else nn.Identity()

    def forward(self, x):
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        x = self.norm(x)
        return x

#--------------------------------------------------------------------------------------------------------------------#
#   Attention mechanism
#   The input features are projected and split into query, key and value: query is the query vector,
#   key is the key vector and value is the value vector.
#   The query is multiplied with the transposed key; intuitively, the query interrogates the sequence
#   and obtains an importance score for every part of it.
#   The score is then multiplied with the value; intuitively, the importance of each part of the
#   sequence is re-applied to the sequence values.
#--------------------------------------------------------------------------------------------------------------------#
class Attention(nn.Module):
    def __init__(self, dim, num_heads=12, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        self.scale     = (dim // num_heads) ** -0.5

        self.qkv       = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj      = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv     = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=GELU, drop=0.):
        super().__init__()
        out_features    = out_features or in_features
        hidden_features = hidden_features or in_features
        drop_probs      = (drop, drop)

        self.fc1   = nn.Linear(in_features, hidden_features)
        self.act   = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.fc2   = nn.Linear(hidden_features, out_features)
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x

class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1     = norm_layer(dim)
        self.attn      = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.norm2     = norm_layer(dim)
        self.mlp       = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

class Conv2dReLU(nn.Sequential):
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            padding=0,
            stride=1,
            use_batchnorm=True,
    ):
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=not (use_batchnorm),
        )
        relu = nn.ReLU(inplace=True)
        bn   = nn.BatchNorm2d(out_channels)
        super(Conv2dReLU, self).__init__(conv, bn, relu)

class VisionTransformer(nn.Module):
    def __init__(
            self, input_shape=[480, 480], patch_size=16, in_chans=3, num_classes=2, num_features=768,
            depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0.1, attn_drop_rate=0.1, drop_path_rate=0.1,
            norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=GELU
        ):
        super().__init__()
        #-----------------------------------------------#
        #   480, 480, 3 -> 900, 768
        #-----------------------------------------------#
        self.patch_embed  = PatchEmbed(input_shape=input_shape, patch_size=patch_size, in_chans=in_chans, num_features=num_features)
        num_patches       = (480 // patch_size) * (480 // patch_size)
        self.num_features = num_features
        self.new_feature_shape = [int(input_shape[0] // patch_size), int(input_shape[1] // patch_size)]
        self.old_feature_shape = [int(480 // patch_size), int(480 // patch_size)]

        #--------------------------------------------------------------------------------------------------------------------#
        #   The class token is the Transformer's classification feature. It is stacked onto the serialized image
        #   features and treated as one more element of the sequence during feature extraction.
        #
        #   After the stride-16 convolution splits the 480x480 input into a 30x30 grid, the grid is flattened,
        #   so one image yields a feature sequence of length 900.
        #   A class token is then generated and stacked onto the length-900 sequence, giving a sequence of length 901.
        #   During feature extraction the class token interacts with the image features. At classification time,
        #   the class token feature is taken out and classified with a fully connected layer.
        #--------------------------------------------------------------------------------------------------------------------#
        #   900, 768 -> 901, 768
        self.cls_token = nn.Parameter(torch.zeros(1, 1, num_features))
        #--------------------------------------------------------------------------------------------------------------------#
        #   Add positional information to the features extracted by the network.
        #   With a 480, 480, 3 input, the serialized image features are 900, 768; with the class token this becomes 901, 768.
        #   The pos_embedding generated here has the matching shape and encodes the position of each feature.
        #--------------------------------------------------------------------------------------------------------------------#
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, num_features))
        self.pos_drop  = nn.Dropout(p=drop_rate)

        #-----------------------------------------------#
        #   The 12 Transformer blocks keep the sequence shape unchanged
        #-----------------------------------------------#
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.Sequential(
            *[
                Block(
                    dim        = num_features,
                    num_heads  = num_heads,
                    mlp_ratio  = mlp_ratio,
                    qkv_bias   = qkv_bias,
                    drop       = drop_rate,
                    attn_drop  = attn_drop_rate,
                    drop_path  = dpr[i],
                    norm_layer = norm_layer,
                    act_layer  = act_layer
                ) for i in range(depth)
            ]
        )
        self.norm = norm_layer(num_features)
        self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()

        n_patches = (480 // 16) * (480 // 16)
        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, 768))

        self.conv_more = Conv2dReLU(
            768,
            2048,
            kernel_size=3,
            padding=1,
            use_batchnorm=True,
        )

    def forward_features(self, x):
        x = self.patch_embed(x)
        # cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        # x = torch.cat((cls_token, x), dim=1)
        #
        # cls_token_pe = self.pos_embed[:, 0:1, :]
        # img_token_pe = self.pos_embed[:, 1: , :]
        #
        # img_token_pe = img_token_pe.view(1, *self.old_feature_shape, -1).permute(0, 3, 1, 2)
        # img_token_pe = F.interpolate(img_token_pe, size=self.new_feature_shape, mode='bicubic', align_corners=False)
        # img_token_pe = img_token_pe.permute(0, 2, 3, 1).flatten(2)
        # pos_embed    = torch.cat([cls_token_pe, img_token_pe], dim=1)
        # x = x.flatten(2)
        # x = x.transpose(-1, -2)

        x = x + self.position_embeddings
        x = self.pos_drop(x)
        x = self.blocks(x)
        x = self.norm(x)
        return x

    def forward1(self, hidden_states):
        # reshape from (B, n_patch, hidden) to (B, hidden, h, w), then map to 2048 channels
        B, n_patch, hidden = hidden_states.size()
        h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))
        x = hidden_states.permute(0, 2, 1)
        x = x.contiguous().view(B, hidden, h, w)
        x = self.conv_more(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward1(x)
        return x

    def freeze_backbone(self):
        backbone = [self.patch_embed, self.cls_token, self.pos_embed, self.pos_drop, self.blocks[:8]]
        for module in backbone:
            try:
                for param in module.parameters():
                    param.requires_grad = False
            except:
                module.requires_grad = False

    def Unfreeze_backbone(self):
        backbone = [self.patch_embed, self.cls_token, self.pos_embed, self.pos_drop, self.blocks[:8]]
        for module in backbone:
            try:
                for param in module.parameters():
                    param.requires_grad = True
            except:
                module.requires_grad = True

def vit(input_shape=[480, 480], pretrained=False, num_classes=2):
    model = VisionTransformer(input_shape)
    if pretrained:
        model.load_state_dict(torch.load("model_data/vit-patch_16.pth"))

    if num_classes != 1000:
        model.head = nn.Linear(model.num_features, num_classes)
    return model
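A quick instantiation check of the listing above (a sketch; with pretrained=False no weight file is needed):

model = vit(input_shape=[480, 480], pretrained=False, num_classes=2)
out   = model(torch.randn(1, 3, 480, 480))
print(out.shape)   # torch.Size([1, 2048, 30, 30]): forward1 reshapes the 900 tokens back into a 30x30 map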
