当前位置:   article > 正文

一文详解Vision Transformer(附代码)

visiontransformer

38187a62b43581490957de80de976559.gif

©PaperWeekly 原创 · 作者 | 孙裕道

单位 | 北京邮电大学博士生

研究方向 | GAN图像生成、情绪对抗样本生成

2aaef3a69f27414d742818355616d300.png

引言

Transformer 在 NLP 中大获成功,Vision Transformer 则将 Transformer 模型架构扩展到计算机视觉的领域中,并且它可以很好的地取代卷积操作,在不依赖卷积的情况下,依然可以在图像分类任务上达到很好的效果。卷积操作只能考虑到局部的特征信息,而 Transformer 中的注意力机制可以综合考量全局的特征信息。

Vision Transformer 尽力做到在不改变 Transformer 中 Encoder 架构的前提下,直接将其从 NLP 领域迁移到计算机视觉领域中,目的是让原始的 Transformer 模型开箱即用。如果想要了解 Transformer 原理详细的介绍可以看我的上一篇文章《矩阵视角下的 Transformer 详解(附代码)》

b00be8f3cd5fff9584f716dbe7c5735b.png

注意力机制应用

在正式详细介绍 Vision Transformer 之前,先介绍两个注意力机制在计算机视觉中应用的例子。Vision Transformer 并不是第一个将注意力机制应用到计算机视觉的领域中去的,其中 SAGAN 和 AttnGAN 就早已经在 GAN 的框架中引入了注意力机制,并且它们大大提高了图像生成的质量。

2.1 Self-Attention GAN

82e9a9ff8b47151f658c65db6f85079e.png

SAGAN 在 GAN 的框架中利用自注意力机制来捕获图像特征的长距离依赖关系,使得合成的图像中考量了所有的图像特征信息。SAGAN 中自注意力机制的操作原理如上图所示。

给定一个 3 通道的输入特征图 ,其中 ,。将 分别输入到三个不同的 的卷积层中,并生成 query 特征图 ,key 特征图 和 value 特征图 。生成 具体的计算过程为,给定三个卷积核 , 和 ,并用这三个卷积核分别与 做卷积运算得到 , 和 ,即:

a25145903bfd489222314e4075e94e54.png

其中 表示卷积运算符号。同理生成 和 的计算过程与 的计算过程类似。然后再利用 和 进行注意力分数的计算得到矩阵 ,其中矩阵 的元素 的计算公式为:

b4b531e6a64dea57611df3b515fcf1f4.png

再对矩阵 利用 softmax 函数进行注意力分布的计算得到注意力分布矩阵 ,其中矩阵 的元素 的计算公式为:

8e56b1c9166bb30d57b1abe9240deddf.png

最后利用注意力分布矩阵 和value特征图 得到最后的输出 ,即:

4c1af573c6387ec9c0ac60879e3f6939.png

2.2 AttnGAN

4b3f448e03383fd51dd5186733f9d2f9.png

AttnGAN 通过利用注意力机制来实现多阶段细颗粒度的文本到图像的生成,它可以通过关注自然语言中的一些重要单词来对图像的不同子区域进行合成。比如通过文本“一只鸟有黄色的羽毛和黑色的眼睛”来生成图像时,会对关键词“鸟”,“羽毛”,“眼睛”,“黄色”,“黑色”给予不同的生成权重,并根据这些关键词的引导在图像的不同的子区域中进行细节的丰富。AttnGAN 中注意力机制的操作原理如上图所示。

给定输入图像特征向量 和词特征向量 ,其中 ,,。首先利用矩阵 进行线性变换将词特征空间 的向量转换成图像特征空间 的向量,则有:

22e0cd8f42d8740aabda7e53bbd1b417.png

然后再利用转换后的词特征 与图像特征 进行注意力分数的计算得到注意力分数矩阵 ,其中的分量 的计算公式为:

840b0b807f6a7fad27775eed80ab30c3.png

再对矩阵 利用 函数进行注意力分布的计算得到注意力分布矩阵 ,其中矩阵 的元素 的计算公式为:

a7aac5a2d5cbce2517dd8ff187f1b52d.png

最后利用注意力分布矩阵 和图像特征 得到最后的输出 ,即:

29031397278956ea5d5e3efed63dd27d.png

e2f3f0faebd294891a2631fb13d92d48.png

Vision Transformer

本节主要详细介绍 Vision Transformer 的工作原理,3.1 节是关于 Vision Transformer 的整体框架,3.2 节是关于 Transformer Encoder 的内部操作细节。对于 Transformer Encoder 中 Multi-Head Attention 的原理本文不会赘述,具体想了解的可以参考上一篇文章《矩阵视角下的 Transformer 详解(附代码)》中相关原理的介绍。

不难发现,不管是自然语言处理中的 Transformer,还是计算机视觉中图像生成的 SAGAN,以及文本生成图像的 AttnGAN,它们核心模块中注意力机制的主要目的就是求出注意力分布。

3.1 Vision Transformer 整体框架

如果下图所示为 Vision Transformer 的整体框架以及相应的训练流程。

  • 给定一张图片 ,并将它分割成 9 个 patch 分别为 。然后再将这个 9 个 patch 拉平,则有 ;

  • 利用矩阵 将拉平后的向量 经过线性变换得到图像编码向量 ,具体的计算公式为:

    03b706107bd3cef80188f784ae2e8375.png

  • 然后将图像编码向量 和类编码向量 分别与对应的位置编进行加和得到输入编码向量,则有:

    9d786471bd5de9e302eeffdd68bfd8fe.png

  • 接着将输入编码向量输入到 Vision Transformer Encoder 中得到对应的输出 ;

  • 最后将类编码向量 输入全连接神经网络中 MLP 得到类别预测向量 ,并与真实类别向量 计算交叉熵损失得到损失值 loss,利用优化算法更新模型的权重参数。

注意事项:

看到这里可能会有一个疑问为什么预测类别的时候只用到了类别编码向量 ,Vision Transformer Encoder 其它的输出为什么没有输入到 MLP 中?为了回答这个问题,我们令函数 为 Vision Transformer Encoder},则类编码向量 可以表示为:

b47ee53e06e6a4aa21b32468423009e8.png

由上公式可以发现,类编码向量 是属于高层特征,其实它综合了所有的图像编码信息,所以可以用它来进行分类,这个可以类比在卷积神经网络中最后的类别输出向量其实就是一层层卷积得到的高层特征。

5c6eacbb3241cc519795a15f9c438de1.png

3.2 Transformer Encoder操作原理

如下图所示分别为 Vision Transformer Encoder 模型结构图和原始 Transformer Encoder 的模型结构图。可以直观的发现 Vision Transformer Encoder 和 Transformer Encoder 都有层归一化,多头注意力机制,残差连接和线性变换这四个操作,只是在操作顺序有所不同。在以下的 Transformer 代码实例中,将以下两种 Encoder 网络结构都进行了实现,可以发现两种网络结构都可以进行很好的训练。

下图左半部分 Vision Transformer Encoder 具体的操作流程为:

  • 给定输入编码矩阵 ,首先将其进行层归一化得到 ;

  • 利用矩阵 对 进行线性变换得到矩阵 。具体的计算过程为:

    9dfa68c399fb86214b8f8f2c0754353f.png

    再将这三个矩阵输入到 Multi-Head Attention(该原理参考《矩阵视角下的 Transformer 详解(附代码)》)中得到矩阵 ,将最原始的输入矩阵 与 进行残差计算得到 ;

  • 将 进行第二次层归一化得到 ,然后再将 输入到全连接神经网络中进行线性变换得到 。最后将 与 进行残差操作得到该 Block 的输出;。一个 Encoder 可以将 个 Block 进行堆叠,最后得到的输出为 。

cecf6c82f68804794181735d503506c4.png

7476881474ec6d77d7604c2b7a852043.png

程序代码

Vision Transformer 的代码示例如下所示。该代码是由上一篇《矩阵视角下的Transformer详解(附代码)》的代码的基础上改编而来。Vision Transformer 的作者的本意就是想让在 NLP 中的 Transformer 模型架构做尽可能少的修改可以直接迁移到 CV 中,所以以下程序尽可能保持作者的愿意,并在代码实现了两种Encoder 的网络结构,即 3.2 节图片所示的两个网络结构,一种是最原始的Encoder 网络结构,一种是 Vision Transformer。论文里的 Encoder 的网络结构。

这里需要注意的是,Vision Transformer 里并能没有 Decoder 模块,所以不需要计算 Encoder 和 Decoder 的交叉注意力分布,这就进一步给 Vision Transformer 的编程带来了简便。Vision Transformer的开源代码的网址为:

https://github.com/lucidrains/vit-pytorch/tree/main/vit_pytorch

  1. import torch
  2. import torch.nn as nn
  3. import os
  4. from einops import rearrange
  5. from einops import repeat
  6. from einops.layers.torch import Rearrange
  7. def inputs_deal(inputs):
  8.     return inputs if isinstance(inputs, tuple) else(inputs, inputs)
  9. class SelfAttention(nn.Module):
  10.     def __init__(self, embed_size, heads):
  11.         super(SelfAttention, self).__init__()
  12.         self.embed_size = embed_size
  13.         self.heads = heads
  14.         self.head_dim = embed_size // heads
  15.         assert (self.head_dim * heads == embed_size), "Embed size needs to be div by heads"
  16.         self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
  17.         self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
  18.         self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
  19.         self.fc_out = nn.Linear(heads * self.head_dim, embed_size)
  20.     def forward(self, values, keys, query):
  21.         N =query.shape[0]
  22.         value_len , key_len , query_len = values.shape[1], keys.shape[1], query.shape[1]
  23.         # split embedding into self.heads pieces
  24.         values = values.reshape(N, value_len, self.heads, self.head_dim)
  25.         keys = keys.reshape(N, key_len, self.heads, self.head_dim)
  26.         queries = query.reshape(N, query_len, self.heads, self.head_dim)
  27.         values = self.values(values)
  28.         keys = self.keys(keys)
  29.         queries = self.queries(queries)
  30.         energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)
  31.         # queries shape: (N, query_len, heads, heads_dim)
  32.         # keys shape : (N, key_len, heads, heads_dim)
  33.         # energy shape: (N, heads, query_len, key_len)
  34.         attention = torch.softmax(energy/ (self.embed_size ** (1/2)), dim=3)
  35.         out = torch.einsum("nhql, nlhd->nqhd", [attention, values]).reshape(N, query_len, self.heads*self.head_dim)
  36.         # attention shape: (N, heads, query_len, key_len)
  37.         # values shape: (N, value_len, heads, heads_dim)
  38.         # (N, query_len, heads, head_dim)
  39.         out = self.fc_out(out)
  40.         return out
  41. class TransformerBlock(nn.Module):
  42.     def __init__(self, embed_size, heads, dropout, forward_expansion):
  43.         super(TransformerBlock, self).__init__()
  44.         self.attention = SelfAttention(embed_size, heads)
  45.         self.norm = nn.LayerNorm(embed_size)
  46.         self.feed_forward = nn.Sequential(
  47.             nn.Linear(embed_size, forward_expansion*embed_size),
  48.             nn.ReLU(),
  49.             nn.Linear(forward_expansion*embed_size, embed_size)
  50.         )
  51.         self.dropout = nn.Dropout(dropout)
  52.     def forward(self, value, key, query, x, type_mode):
  53.         if type_mode == 'original':
  54.             attention = self.attention(value, key, query)
  55.             x = self.dropout(self.norm(attention + x))
  56.             forward = self.feed_forward(x)
  57.             out = self.dropout(self.norm(forward + x))
  58.             return out
  59.         else:
  60.             attention = self.attention(self.norm(value), self.norm(key), self.norm(query))
  61.             x =self.dropout(attention + x)
  62.             forward = self.feed_forward(self.norm(x))
  63.             out = self.dropout(forward + x)
  64.             return out
  65. class TransformerEncoder(nn.Module):
  66.     def __init__(
  67.             self,
  68.             embed_size,
  69.             num_layers,
  70.             heads,
  71.             forward_expansion,
  72.             dropout = 0,
  73.             type_mode = 'original'
  74.         ):
  75.         super(TransformerEncoder, self).__init__()
  76.         self.embed_size = embed_size
  77.         self.type_mode = type_mode
  78.         self.Query_Key_Value = nn.Linear(embed_size, embed_size * 3, bias = False)
  79.         self.layers = nn.ModuleList(
  80.             [
  81.                 TransformerBlock(
  82.                     embed_size,
  83.                     heads,
  84.                     dropout=dropout,
  85.                     forward_expansion=forward_expansion,
  86.                     )
  87.                 for _ in range(num_layers)]
  88.         )
  89.         self.dropout = nn.Dropout(dropout)
  90.     def forward(self, x):
  91.         for layer in self.layers:
  92.             QKV_list = self.Query_Key_Value(x).chunk(3, dim = -1)
  93.             x = layer(QKV_list[0], QKV_list[1], QKV_list[2], x, self.type_mode)
  94.         return x
  95. class VisionTransformer(nn.Module):
  96.     def __init__(self,
  97.                 image_size,
  98.                 patch_size,
  99.                 num_classes,
  100.                 embed_size,
  101.                 num_layers,
  102.                 heads,
  103.                 mlp_dim,
  104.                 pool = 'cls',
  105.                 channels = 3,
  106.                 dropout = 0,
  107.                 emb_dropout = 0.1,
  108.                 type_mode = 'vit'):
  109.         super(VisionTransformer, self).__init__()
  110.         img_h, img_w = inputs_deal(image_size)
  111.         patch_h, patch_w = inputs_deal(patch_size)
  112.         assert img_h % patch_h == 0 and img_w % patch_w == 0'Img dimensions can be divisible by the patch dimensions'
  113.         num_patches = (img_h // patch_h) * (img_w // patch_w)
  114.         patch_size = channels * patch_h * patch_w
  115.         self.patch_embedding = nn.Sequential(
  116.             Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_h, p2=patch_w),
  117.             nn.Linear(patch_size, embed_size, bias=False)
  118.         )
  119.         self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, embed_size))
  120.         self.cls_token = nn.Parameter(torch.randn(11, embed_size))
  121.         self.dropout = nn.Dropout(emb_dropout)
  122.         self.transformer = TransformerEncoder(embed_size,
  123.                                     num_layers,
  124.                                     heads,
  125.                                     mlp_dim,
  126.                                     dropout)
  127.         self.pool = pool
  128.         self.to_latent = nn.Identity()
  129.         self.mlp_head = nn.Sequential(
  130.             nn.LayerNorm(embed_size),
  131.             nn.Linear(embed_size, num_classes)
  132.         )
  133.     def forward(self, img):
  134.         x = self.patch_embedding(img)
  135.         b, n, _ = x.shape
  136.         cls_tokens = repeat(self.cls_token, '() n d ->b n d', b = b)
  137.         x = torch.cat((cls_tokens, x), dim = 1)
  138.         x += self.pos_embedding[:, :(n + 1)]
  139.         x = self.dropout(x)
  140.         x = self.transformer(x)
  141.         x = x.mean(dim = 1if self.pool == 'mean' else x[:, 0]
  142.         x = self.to_latent(x)
  143.         return self.mlp_head(x)
  144. if __name__ == '__main__':
  145.     vit = VisionTransformer(
  146.             image_size = 256,
  147.             patch_size = 16,
  148.             num_classes = 10,
  149.             embed_size = 256,
  150.             num_layers = 6,
  151.             heads = 8,
  152.             mlp_dim = 512,
  153.             dropout = 0.1,
  154.             emb_dropout = 0.1
  155.         )
  156.     img = torch.randn(33256256)
  157.     pred = vit(img)
  158.     print(pred)

以下代码是利用 Vision Transformer 网络结构训练一个分类 mnist 数据集的主程序代码。

  1. from torchvision import datasets, transforms
  2. from torch.utils.data import DataLoader, Dataset
  3. import torch
  4. import torch.nn as nn
  5. import torch.optim as optim
  6. import torch.nn.functional as F
  7. import VIT
  8. import os
  9. def train():
  10.     batch_size = 4
  11.     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  12.     epoches = 20
  13.     mnist_train = datasets.MNIST("mnist-data", train=True, download=True, transform=transforms.ToTensor())
  14.     train_loader = torch.utils.data.DataLoader(mnist_train, batch_size= batch_size, shuffle=True)
  15.     mnist_model = VIT.VisionTransformer(
  16.         image_size = 28,
  17.         patch_size = 7,
  18.         num_classes = 10,
  19.         channels = 1,
  20.         embed_size = 512,
  21.         num_layers = 1,
  22.         heads = 2,
  23.         mlp_dim =1024,
  24.         dropout = 0,
  25.         emb_dropout = 0)
  26.     loss_fn = nn.CrossEntropyLoss()
  27.     mnist_model = mnist_model.to(device)
  28.     opitimizer = optim.Adam(mnist_model.parameters(), lr=0.00001)
  29.     mnist_model.train()
  30.     for epoch in range(epoches):
  31.         total_loss = 0
  32.         corrects = 0
  33.         num = 0
  34.         for batch_X, batch_Y in train_loader:
  35.             batch_X, batch_Y = batch_X.to(device), batch_Y.to(device)
  36.             opitimizer.zero_grad()
  37.             outputs = mnist_model(batch_X)
  38.             _, pred = torch.max(outputs.data, 1)
  39.             loss = loss_fn(outputs, batch_Y)
  40.             loss.backward()
  41.             opitimizer.step()
  42.             total_loss += loss.item()
  43.             corrects = torch.sum(pred == batch_Y.data)
  44.             num += batch_size
  45.             print(epoch, total_loss/float(num), corrects.item()/float(batch_size))
  46. if __name__ == '__main__':
  47.     train()

训练的过程如下所示,可以发现损失函数可以稳定下降。但是训练一个 Vision Transformer 模型真的是很烧硬件,跟训练一个普通的 CNN 模型相比,训练一个 Vision Transformer 模型更加耗时耗力。

a22a1806b3bc0165e9146f8279a74bee.png

特别鸣谢

感谢 TCCI 天桥脑科学研究院对于 PaperWeekly 的支持。TCCI 关注大脑探知、大脑功能和大脑健康。

更多阅读

1e9e20c6e4c39b5503d120be7155974f.png

beefe50427b704fc60f08c6583423c8c.png

f9dbbd0f4afb1186003d22ddd19d13a2.png

0d6a0dd3253633bb09a699aa395aeb5a.gif

#投 稿 通 道#

 让你的文字被更多人看到 

如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。

总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 

PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/繁依Fanyi0/article/detail/92603
推荐阅读
相关标签