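As before, I recommend watching the original video for this topic so that the material really sinks in (a mind map is attached at the top).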

Link: 28、Vision Transformer(ViT)模型原理及PyTorch逐行实现_哔哩哔哩_bilibili



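For the classification task, ViT adds several embeddings on top of the patch embeddings and makes them learnable (trainable) parameters. The position embedding is one example.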


To understand ViT better, let's look at the original paper: 2010.11929.pdf (arxiv.org)

To make things clearer, I also found a Chinese-language walkthrough of the paper: 论文解读:AN IMAGE IS WORTH 16X16 WORDS:TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE - 知乎 (zhihu.com)

The paper's full title is: An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale



Here is a rough rendering of the abstract:

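While the Transformer architecture has become the de facto standard for natural language processing tasks, its applications to computer vision remain limited. In vision, attention is either applied in conjunction with convolutional networks, or used to replace certain components of convolutional networks while keeping their overall structure in place. We show that this reliance on CNNs is not necessary and a pure Transformer applied directly to sequences of image patches can perform very well on image classification tasks. When pre-trained on large amounts of data and transferred to multiple mid-sized or small image recognition benchmarks (ImageNet, CIFAR-100, VTAB, etc.), Vision Transformer (ViT) attains excellent results compared to state-of-the-art convolutional networks while requiring substantially fewer computational resources to train.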







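Figure 1: Model overview. We split an image into fixed-size patches, linearly embed each of them, add position embeddings, and feed the resulting sequence of vectors to a standard Transformer encoder. In order to perform classification, we use the standard approach of adding an extra learnable "classification token" to the sequence.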

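The "classification token" can be thought of as an extra token whose job is to gather information from all of the patches so that it can be used for classification.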

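In the example on the left, the image is split into many patches. The image size can vary, but the patch size is fixed within a given model. The patches are taken left to right, top to bottom, and each patch is flattened into a vector. A linear projection then maps each vector to the model dimension, producing the patch embeddings. Next, an extra learnable embedding (the classification token) is prepended to the sequence, and a position embedding is added to every token. The resulting sequence is fed into the Transformer encoder. The output state at the extra classification-token position goes through the MLP head (a multilayer perceptron), and a cross-entropy loss is computed on its output. Together, these steps make up the ViT model.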




step1 convert image to embedding vector sequence


1. Naive implementation

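```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def image2emb_naive(image, patch_size, weight):
    # image shape: bs * channel * h * w
    # unfold cuts out non-overlapping patch_size x patch_size blocks: (bs, patch_depth, num_patch)
    # transpose to (bs, num_patch, patch_depth) so each row is one flattened patch
    patch = F.unfold(image, kernel_size=patch_size, stride=patch_size).transpose(-1, -2)
    print(patch.shape)
```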


At this point some of you may still be a bit lost, so as usual, here is the official documentation for the function: torch.nn.functional.unfold (PyTorch 2.0 documentation)
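A tiny experiment makes the output layout of `F.unfold` concrete. The sketch below uses a made-up 1×1×4×4 image holding the values 0 to 15 and cuts it into 2×2 non-overlapping patches. After the transpose, each row is one flattened patch, and the patches are ordered left to right, top to bottom. For multi-channel input, each row is ordered channel by channel: all kh×kw values of channel 0 come first, then those of channel 1, and so on.

```python
import torch
import torch.nn.functional as F

# A 1-channel 4x4 "image" holding 0..15 so every pixel is recognisable
image = torch.arange(16.0).reshape(1, 1, 4, 4)

# unfold returns (bs, C * kh * kw, num_patches); stride == kernel_size means no overlap
cols = F.unfold(image, kernel_size=2, stride=2)
print(cols.shape)  # torch.Size([1, 4, 4])

# Transpose to (bs, num_patches, patch_depth): one flattened patch per row
patches = cols.transpose(-1, -2)
print(patches[0].tolist())
# [[0.0, 1.0, 4.0, 5.0], [2.0, 3.0, 6.0, 7.0], [8.0, 9.0, 12.0, 13.0], [10.0, 11.0, 14.0, 15.0]]
```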


To check that the function works, we define some constants and test it:

```python
# test code for image2emb (define constants)
bs, ic, image_h, image_w = 1, 3, 8, 8  # ic: number of input channels
patch_size = 4  # each patch is 4*4
model_dim = 8  # in the model, the patch embedding size equals the model dimension
patch_depth = patch_size * patch_size * ic
image = torch.randn(bs, ic, image_h, image_w)  # one random image
weight = torch.randn(patch_depth, model_dim)  # patch-to-embedding matrix (2D); its first dimension must equal patch_depth
image2emb_naive(image, patch_size, weight)
```


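The printed shape is 1×4×48, which is (bs, num_patch, patch_depth), where patch_depth = patch_size × patch_size × ic.

In detail: 1 is the batch size. 4 is the number of patches, because splitting an 8×8 image into 4×4 patches gives (8/4) × (8/4) = 4 patches. 48 is patch_size × patch_size × input channels = 4 × 4 × 3.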
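The same arithmetic applies to any image size that the patch size divides evenly. The helper below is a hypothetical illustration, not code from the video. It asserts divisibility because `F.unfold` with `stride=patch_size` would otherwise silently drop the leftover border pixels. The 224×224 RGB case with 16×16 patches is just a common example size.

```python
def patch_stats(image_h, image_w, patch_size, in_channels):
    """Hypothetical helper: returns (num_patch, patch_depth)."""
    # Non-overlapping patches only tile the image exactly if the sizes divide evenly
    assert image_h % patch_size == 0 and image_w % patch_size == 0
    num_patch = (image_h // patch_size) * (image_w // patch_size)
    patch_depth = patch_size * patch_size * in_channels
    return num_patch, patch_depth

print(patch_stats(8, 8, 4, 3))       # (4, 48): the toy example above
print(patch_stats(224, 224, 16, 3))  # (196, 768)
```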


The patch embedding is then the matrix product patch_embedding = patch @ weight, where (1, 4, 48) @ (48, 8) gives (1, 4, 8):

```python
def image2emb_naive(image, patch_size, weight):
    # image shape: bs*channel*h*w
    patch = F.unfold(image, kernel_size=patch_size, stride=patch_size).transpose(-1, -2)
    print(patch.shape)
    # (bs, num_patch, patch_depth) @ (patch_depth, model_dim) -> (bs, num_patch, model_dim)
    patch_embedding = patch @ weight
    return patch_embedding

# test code for image2emb (define constants)
bs, ic, image_h, image_w = 1, 3, 8, 8  # ic: number of input channels
patch_size = 4  # each patch is 4*4
model_dim = 8  # in the model, the patch embedding size equals the model dimension
patch_depth = patch_size * patch_size * ic
image = torch.randn(bs, ic, image_h, image_w)  # one random image
weight = torch.randn(patch_depth, model_dim)  # patch-to-embedding matrix; first dimension = patch_depth
patch_embedding_naive = image2emb_naive(image, patch_size, weight)
print(patch_embedding_naive.shape)  # torch.Size([1, 4, 8])
```
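`patch @ weight` is simply a bias-free linear projection of each flattened patch. As a sanity check, the sketch below reproduces the naive result with `nn.Linear(..., bias=False)`. Note that `nn.Linear` stores its weight as (out_features, in_features), which is the transpose of `weight` here. The seed and sizes are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
bs, ic, image_h, image_w = 1, 3, 8, 8
patch_size, model_dim = 4, 8
patch_depth = patch_size * patch_size * ic

image = torch.randn(bs, ic, image_h, image_w)
weight = torch.randn(patch_depth, model_dim)

# Naive version: flatten non-overlapping patches, then matrix-multiply
patch = F.unfold(image, kernel_size=patch_size, stride=patch_size).transpose(-1, -2)
emb_matmul = patch @ weight  # (bs, num_patch, model_dim)

# The same projection as a bias-free nn.Linear layer
proj = nn.Linear(patch_depth, model_dim, bias=False)
with torch.no_grad():
    proj.weight.copy_(weight.T)  # nn.Linear weight is (out_features, in_features)
emb_linear = proj(patch)

print(emb_matmul.shape)  # torch.Size([1, 4, 8])
print(torch.allclose(emb_matmul, emb_linear, atol=1e-4))
```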





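2. Convolution-based implementation

Because the patches do not overlap, multiplying every flattened patch by `weight` is the same computation as a 2D convolution with kernel size and stride both equal to patch_size and model_dim output channels:

```python
def image2emb_conv(image, kernel, stride):  # the three ingredients of a convolution: input, kernel (feature extractor), stride
    conv_output = F.conv2d(image, kernel, stride=stride)  # shape: bs*oc*oh*ow (batch_size*output_channel*output_height*output_width)
    # Flatten the output height and width into one sequence dimension
    bs, oc, oh, ow = conv_output.shape
    patch_embedding = conv_output.reshape(bs, oc, oh * ow).transpose(-1, -2)  # (bs, oh*ow, oc): sequence length in the middle
    return patch_embedding
```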


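The kernel comes from the same `weight` matrix. First transpose it so the output-channel dimension comes first, then reshape it to oc × ic × kh × kw and pass it to image2emb_conv:

```python
# kernel shape: oc*ic*kh*kw (output channels * input channels * kernel height * kernel width)
kernel = weight.transpose(0, 1).reshape((-1, ic, patch_size, patch_size))
patch_embedding_conv = image2emb_conv(image, kernel, patch_size)
```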
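In a real model, the projection is usually a learnable `nn.Conv2d` layer rather than a hand-built kernel. The following is a minimal sketch of that idea. The class name `PatchEmbed` is hypothetical. Unlike image2emb_conv, `nn.Conv2d` also adds a bias term by default.

```python
import torch
import torch.nn as nn

class PatchEmbed(nn.Module):
    """Hypothetical patch-embedding module: conv with kernel_size = stride = patch_size."""

    def __init__(self, in_channels, patch_size, model_dim):
        super().__init__()
        # Each output channel is one dimension of the patch embedding
        self.proj = nn.Conv2d(in_channels, model_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, image):
        x = self.proj(image)      # (bs, model_dim, oh, ow)
        x = x.flatten(2)          # (bs, model_dim, oh * ow)
        return x.transpose(1, 2)  # (bs, num_patch, model_dim)

embed = PatchEmbed(in_channels=3, patch_size=4, model_dim=8)
out = embed(torch.randn(2, 3, 8, 8))
print(out.shape)  # torch.Size([2, 4, 8])
```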


```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def image2emb_naive(image, patch_size, weight):  # patch_size: side length of a patch
    # image shape: bs * channel * h * w
    # Patches do not overlap, so stride = kernel_size (not 1)
    patch = F.unfold(image, kernel_size=patch_size, stride=patch_size).transpose(-1, -2)
    patch_embedding = patch @ weight
    return patch_embedding

def image2emb_conv(image, kernel, stride):  # the three ingredients of a convolution: input, kernel, stride
    conv_output = F.conv2d(image, kernel, stride=stride)  # shape: bs*oc*oh*ow
    # Flatten the output height and width into one sequence dimension
    bs, oc, oh, ow = conv_output.shape
    patch_embedding = conv_output.reshape(bs, oc, oh * ow).transpose(-1, -2)  # (bs, oh*ow, oc)
    return patch_embedding

# test code for image2emb (define constants)
bs, ic, image_h, image_w = 1, 3, 8, 8  # ic: number of input channels
patch_size = 4  # each patch is 4*4
model_dim = 8  # the patch embedding size equals the model dimension
patch_depth = patch_size * patch_size * ic
image = torch.randn(bs, ic, image_h, image_w)  # one random image
# Shared weight: for the naive version it is the patch-to-embedding matrix;
# for the conv version, model_dim is the number of output channels and
# patch_depth = kernel area * input channels
weight = torch.randn(patch_depth, model_dim)
patch_embedding_naive = image2emb_naive(image, patch_size, weight)  # embedding via patch splitting
# transpose so the output-channel dim comes first, then reshape to oc*ic*kh*kw
kernel = weight.transpose(0, 1).reshape((-1, ic, patch_size, patch_size))
patch_embedding_conv = image2emb_conv(image, kernel, patch_size)  # embedding via 2D convolution
print(patch_embedding_naive.shape)
print(patch_embedding_conv.shape)
```
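The two shapes match, but it is worth confirming that the values match too. The check below rebuilds the same setup, with a batch of 2 and a fixed seed chosen arbitrarily. The two methods agree because `F.unfold` flattens each patch in (channel, row, column) order, and `weight.transpose(0, 1).reshape(-1, ic, patch_size, patch_size)` lays the kernel out in that same order.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
bs, ic, image_h, image_w = 2, 3, 8, 8
patch_size, model_dim = 4, 8
patch_depth = patch_size * patch_size * ic

image = torch.randn(bs, ic, image_h, image_w)
weight = torch.randn(patch_depth, model_dim)

# Naive: unfold + matmul
patch = F.unfold(image, kernel_size=patch_size, stride=patch_size).transpose(-1, -2)
naive = patch @ weight

# Conv: reshape the same weights into an (oc, ic, kh, kw) kernel
kernel = weight.transpose(0, 1).reshape(-1, ic, patch_size, patch_size)
conv_out = F.conv2d(image, kernel, stride=patch_size)  # (bs, oc, oh, ow)
conv = conv_out.flatten(2).transpose(-1, -2)           # (bs, oh*ow, oc)

print(torch.allclose(naive, conv, atol=1e-4))
```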



step2 prepend CLS token embedding

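```python
# step2 prepend CLS token embedding (add one embedding at the start of the sequence)
cls_token_embedding = torch.randn(bs, 1, model_dim, requires_grad=True)  # learnable, hence requires_grad=True
token_embedding = torch.cat([cls_token_embedding, patch_embedding_conv], dim=1)  # concatenate along the sequence dimension
```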
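One caveat: `torch.randn(bs, 1, model_dim)` creates a separate CLS vector for each sample, and its size is tied to bs. In ViT, the classification token is one learned vector shared by all images. The sketch below shows a common way to express that, using an `nn.Parameter` broadcast over the batch with `Tensor.expand`. The sizes are toy values.

```python
import torch
import torch.nn as nn

bs, num_patch, model_dim = 2, 4, 8
patch_embedding = torch.randn(bs, num_patch, model_dim)

# One learnable CLS vector, shared across the whole batch
cls_token = nn.Parameter(torch.randn(1, 1, model_dim))

# expand() views it as (bs, 1, model_dim) without copying memory
cls_tokens = cls_token.expand(bs, -1, -1)
token_embedding = torch.cat([cls_tokens, patch_embedding], dim=1)

print(token_embedding.shape)  # torch.Size([2, 5, 8])
```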

step3 add position embedding


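The position-embedding table needs at least as many rows as the sequence length (here 4 patches + 1 CLS token = 5), so we set an upper bound max_num_token = 16 and then add the embeddings:

```python
max_num_token = 16

# step3 add position embedding
position_embedding_table = torch.randn(max_num_token, model_dim, requires_grad=True)
seq_len = token_embedding.shape[1]
# Take the first seq_len rows and repeat them once per sample in the batch
position_embedding = torch.tile(position_embedding_table[:seq_len], [token_embedding.shape[0], 1, 1])
token_embedding += position_embedding
```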
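`torch.tile` is not strictly necessary. Adding a (seq_len, model_dim) tensor to a (bs, seq_len, model_dim) tensor broadcasts over the batch dimension. The sketch below uses toy sizes to check that both forms produce identical results.

```python
import torch

torch.manual_seed(0)
bs, seq_len, model_dim, max_num_token = 2, 5, 8, 16
token_embedding = torch.randn(bs, seq_len, model_dim)
position_embedding_table = torch.randn(max_num_token, model_dim)

# Explicit copy per sample, as in the code above
tiled = token_embedding + torch.tile(position_embedding_table[:seq_len], [bs, 1, 1])

# Broadcasting: the (seq_len, model_dim) slice is implicitly repeated over the batch
broadcast = token_embedding + position_embedding_table[:seq_len]

print(torch.equal(tiled, broadcast))  # True
```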

step4 pass embedding to Transformer Encoder

Next, we pass the tokens through the Transformer encoder. To see what each argument does, here is the official documentation: TransformerEncoder (PyTorch 2.0 documentation)



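```python
# batch_first=True: our input is (bs, seq_len, model_dim); the default expects (seq_len, bs, model_dim)
encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=8, batch_first=True)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
encoder_output = transformer_encoder(token_embedding)
```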
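Without `batch_first=True`, `nn.TransformerEncoderLayer` treats dimension 0 as the sequence. With bs = 1, that means attention would run over a "sequence" of length 1, and no error would be raised. The ViT paper's encoder also applies LayerNorm before each sub-layer and uses GELU in its MLP. The sketch below approximates this with the existing `norm_first=True` and `activation="gelu"` arguments. The values `dim_feedforward = 4 * model_dim` (the same ratio as ViT-Base, 768 to 3072), `nhead=2`, and `num_layers=2` are toy choices.

```python
import torch
import torch.nn as nn

bs, seq_len, model_dim = 2, 5, 8
token_embedding = torch.randn(bs, seq_len, model_dim)

encoder_layer = nn.TransformerEncoderLayer(
    d_model=model_dim,
    nhead=2,                        # model_dim must be divisible by nhead
    dim_feedforward=4 * model_dim,  # hidden size of the MLP block
    activation="gelu",              # ViT's MLP uses GELU
    norm_first=True,                # pre-LayerNorm, as in the ViT encoder
    batch_first=True,               # input is (bs, seq_len, model_dim)
)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)

encoder_output = transformer_encoder(token_embedding)
print(encoder_output.shape)  # torch.Size([2, 5, 8])
```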

step5 do classification


```python
num_classes = 10
label = torch.randint(num_classes, (bs,))  # one random class index per sample
# encoder_output is (bs, position, channels); position 0 is the CLS token -> (bs, model_dim)
cls_token_output = encoder_output[:, 0, :]
linear_layer = nn.Linear(model_dim, num_classes)
logits = linear_layer(cls_token_output)
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits, label)
print(loss)
```
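`nn.CrossEntropyLoss` takes raw, unnormalized logits and integer class indices in [0, num_classes). It applies log_softmax internally, so no softmax layer is needed before it. The check below reproduces the loss by hand on random toy logits.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
bs, num_classes = 4, 10
logits = torch.randn(bs, num_classes)
label = torch.randint(num_classes, (bs,))  # class indices

loss = nn.CrossEntropyLoss()(logits, label)

# Manual version: negative log-probability of the true class, averaged over the batch
log_probs = F.log_softmax(logits, dim=-1)
manual = -log_probs[torch.arange(bs), label].mean()

print(torch.allclose(loss, manual))  # True
```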




```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# step1 convert image to embedding vector sequence
def image2emb_naive(image, patch_size, weight):  # patch_size: side length of a patch
    # image shape: bs * channel * h * w
    # Patches do not overlap, so stride = kernel_size (not 1)
    patch = F.unfold(image, kernel_size=patch_size, stride=patch_size).transpose(-1, -2)
    patch_embedding = patch @ weight
    return patch_embedding

def image2emb_conv(image, kernel, stride):  # the three ingredients of a convolution: input, kernel, stride
    conv_output = F.conv2d(image, kernel, stride=stride)  # shape: bs*oc*oh*ow
    # Flatten the output height and width into one sequence dimension
    bs, oc, oh, ow = conv_output.shape
    patch_embedding = conv_output.reshape(bs, oc, oh * ow).transpose(-1, -2)  # (bs, oh*ow, oc)
    return patch_embedding

# test code for image2emb (define constants)
bs, ic, image_h, image_w = 1, 3, 8, 8  # ic: number of input channels
patch_size = 4  # each patch is 4*4
model_dim = 8  # the patch embedding size equals the model dimension
max_num_token = 16
num_classes = 10
label = torch.randint(num_classes, (bs,))
patch_depth = patch_size * patch_size * ic
image = torch.randn(bs, ic, image_h, image_w)  # one random image
# Shared weight: patch-to-embedding matrix for the naive version; for conv,
# model_dim is the number of output channels and patch_depth = kernel area * input channels
weight = torch.randn(patch_depth, model_dim)
patch_embedding_naive = image2emb_naive(image, patch_size, weight)  # embedding via patch splitting
kernel = weight.transpose(0, 1).reshape((-1, ic, patch_size, patch_size))  # oc*ic*kh*kw
patch_embedding_conv = image2emb_conv(image, kernel, patch_size)  # embedding via 2D convolution
print(patch_embedding_naive.shape)
print(patch_embedding_conv.shape)
print(patch_embedding_naive)
print(patch_embedding_conv)

# step2 prepend CLS token embedding (add one embedding at the start of the sequence)
cls_token_embedding = torch.randn(bs, 1, model_dim, requires_grad=True)  # learnable, hence requires_grad=True
token_embedding = torch.cat([cls_token_embedding, patch_embedding_conv], dim=1)  # concatenate along dim 1 (the sequence dimension)

# step3 add position embedding
position_embedding_table = torch.randn(max_num_token, model_dim, requires_grad=True)
seq_len = token_embedding.shape[1]
# take the first seq_len rows of the table and repeat them token_embedding.shape[0] (= bs) times
position_embedding = torch.tile(position_embedding_table[:seq_len], [token_embedding.shape[0], 1, 1])
token_embedding += position_embedding

# step4 pass embedding to Transformer Encoder
encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=8, batch_first=True)  # input is (bs, seq_len, model_dim)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
encoder_output = transformer_encoder(token_embedding)

# step5 do classification
cls_token_output = encoder_output[:, 0, :]  # dims are (bs, position, channels); position 0 is the CLS token -> (bs, model_dim)
linear_layer = nn.Linear(model_dim, num_classes)
logits = linear_layer(cls_token_output)
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits, label)
print(loss)
```
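The script above runs the pipeline once on fixed random tensors. To actually train, the learnable pieces (patch projection, CLS token, position embeddings, encoder, head) must be registered as module parameters so an optimizer can update them. The sketch below packs the same five steps into a hypothetical `TinyViT` module and runs one SGD step. It uses `nn.Conv2d` in place of the hand-built kernel, shares one CLS token across the batch, and sets `batch_first=True`. All hyperparameters are toy values, not the paper's.

```python
import torch
import torch.nn as nn


class TinyViT(nn.Module):
    """Hypothetical minimal ViT following steps 1-5 above (toy sizes)."""

    def __init__(self, image_size=8, patch_size=4, in_channels=3,
                 model_dim=8, nhead=2, num_layers=2, num_classes=10):
        super().__init__()
        num_patch = (image_size // patch_size) ** 2
        # step1: non-overlapping patches -> model_dim via a strided conv
        self.patch_embed = nn.Conv2d(in_channels, model_dim,
                                     kernel_size=patch_size, stride=patch_size)
        # step2: one learnable CLS token shared by the batch
        self.cls_token = nn.Parameter(torch.randn(1, 1, model_dim))
        # step3: learnable position embeddings for num_patch + 1 positions
        self.pos_embed = nn.Parameter(torch.randn(1, num_patch + 1, model_dim))
        # step4: Transformer encoder on (bs, seq_len, model_dim)
        layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=nhead,
                                           dim_feedforward=4 * model_dim,
                                           batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        # step5: classification head on the CLS output
        self.head = nn.Linear(model_dim, num_classes)

    def forward(self, image):
        x = self.patch_embed(image).flatten(2).transpose(1, 2)  # (bs, num_patch, d)
        cls = self.cls_token.expand(x.shape[0], -1, -1)         # (bs, 1, d)
        x = torch.cat([cls, x], dim=1) + self.pos_embed         # (bs, num_patch + 1, d)
        x = self.encoder(x)
        return self.head(x[:, 0])                               # (bs, num_classes)


torch.manual_seed(0)
num_classes = 10
model = TinyViT(num_classes=num_classes)
image = torch.randn(2, 3, 8, 8)
label = torch.randint(num_classes, (2,))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# One training step
logits = model(image)
loss = nn.CrossEntropyLoss()(logits, label)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(logits.shape, loss.item())
```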

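ViT has a simple structure and is mainly used for image recognition. However, it is expensive: it needs a large amount of image data to perform well. It is therefore worth learning about other models too.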

Finally, here again is the recommended video: 28、Vision Transformer(ViT)模型原理及PyTorch逐行实现_哔哩哔哩_bilibili



