The Transformer has been hugely successful in NLP, and the Vision Transformer extends the Transformer architecture to computer vision. It can replace convolution entirely: without relying on convolution at all, it still reaches strong results on image classification. A convolution only takes local feature information into account, whereas the attention mechanism in the Transformer can weigh global feature information. The Vision Transformer tries to migrate the Transformer Encoder from NLP to computer vision without changing its architecture, so that the original Transformer model can be used out of the box. For a detailed introduction to how the Transformer works, see my previous article 《Transformer详解(附代码)》.
Before introducing the Vision Transformer in detail, here are two earlier examples of attention mechanisms applied to computer vision. The Vision Transformer was not the first to bring attention into vision: SAGAN and AttnGAN had already introduced attention mechanisms into the GAN framework, and both greatly improved the quality of the generated images.
SAGAN uses self-attention inside the GAN framework to capture long-range dependencies among image features, so that the synthesized image takes all of the image feature information into account. The operation of the self-attention mechanism in SAGAN is illustrated in the figure below.
Given a 3-channel input feature map $X=(X^1,X^2,X^3)\in \mathbb{R}^{3\times 3\times 3}$, where $X^{i}\in \mathbb{R}^{3\times 3}$, $i\in\{1,2,3\}$, the map $X$ is fed into three different $1\times 1$ convolution layers to generate the $\mathrm{query}$ feature map $Q\in \mathbb{R}^{3\times 3\times 3}$, the $\mathrm{key}$ feature map $K\in \mathbb{R}^{3\times 3\times 3}$ and the $\mathrm{value}$ feature map $V\in \mathbb{R}^{3\times 3\times 3}$. Concretely, $Q$ is computed as follows: given three convolution kernels $W^{q1}$, $W^{q2}$ and $W^{q3}\in\mathbb{R}^{1\times1\times3}$, each kernel is convolved with $X$ to obtain $Q^1$, $Q^2$ and $Q^3\in \mathbb{R}^{3 \times 3}$, i.e.
$$\left\{\begin{aligned} Q^1&=X\ast W^{q1}\\ Q^2&=X\ast W^{q2}\\ Q^3&=X\ast W^{q3} \end{aligned}\right.$$
$K$ and $V$ are generated in the same way with their own $1\times 1$ convolution kernels.
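To make the mechanism above concrete, here is a minimal PyTorch sketch of SAGAN-style self-attention over a small feature map. The $1\times 1$ convolutions play the role of the kernels $W^{q}$, $W^{k}$ and $W^{v}$ above; the class name is my own, the query/key channels are deliberately not reduced (so $Q$, $K$, $V$ keep the same shape as in the $3\times3\times3$ example), and the learnable residual weight gamma follows the SAGAN paper.

import torch
import torch.nn as nn

class SAGANSelfAttention(nn.Module):
    """Minimal sketch of SAGAN-style self-attention over a feature map."""
    def __init__(self, channels):
        super().__init__()
        # 1x1 convolutions produce the query, key and value feature maps
        self.query_conv = nn.Conv2d(channels, channels, kernel_size=1)
        self.key_conv = nn.Conv2d(channels, channels, kernel_size=1)
        self.value_conv = nn.Conv2d(channels, channels, kernel_size=1)
        # learnable weight of the attention branch, initialized to zero as in SAGAN
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        b, c, h, w = x.shape
        q = self.query_conv(x).view(b, c, h * w)    # (b, c, n) with n = h*w positions
        k = self.key_conv(x).view(b, c, h * w)
        v = self.value_conv(x).view(b, c, h * w)
        scores = torch.bmm(q.transpose(1, 2), k)    # (b, n, n) attention scores
        attn = torch.softmax(scores, dim=-1)        # attention distribution over positions
        out = torch.bmm(v, attn.transpose(1, 2))    # weighted sum of value features
        out = out.view(b, c, h, w)
        return self.gamma * out + x                 # residual connection

x = torch.randn(1, 3, 3, 3)                         # a 3-channel 3x3 feature map like X above
print(SAGANSelfAttention(3)(x).shape)               # torch.Size([1, 3, 3, 3])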
AttnGAN uses an attention mechanism for multi-stage, fine-grained text-to-image generation: it attends to the important words in a natural-language description when synthesizing different sub-regions of the image. For example, when generating an image from the text "a bird has yellow feathers and black eyes", the keywords "bird", "feathers", "eyes", "yellow" and "black" receive different generation weights, and details are filled in within different sub-regions of the image under the guidance of these keywords. The operation of the attention mechanism in AttnGAN is illustrated in the figure below.
Given image feature vectors $h=(h^1,h^2,h^3,h^4)\in\mathbb{R}^{\hat{D}\times 4}$ and word feature vectors $e=(e^1,e^2,e^3,e^4)$, where $h^i\in \mathbb{R}^{\hat{D}\times 1}$, $e^i\in \mathbb{R}^{D\times 1}$, $i\in \{1,2,3,4\}$. First, a matrix $W$ applies a linear transformation that maps the word features from the word feature space $\mathbb{R}^{D}$ into the image feature space $\mathbb{R}^{\hat{D}}$:
$$\hat{e}=W\cdot e=(\hat{e}^1,\hat{e}^2,\hat{e}^3,\hat{e}^4)\in \mathbb{R}^{\hat{D}\times 4}$$
Then attention scores between the transformed word features $\hat{e}$ and the image features $h$ form the score matrix $S$, whose entries are
$$s_{ij}=(h^i)^{\top}\cdot \hat{e}^j,\quad i\in \{1,2,3,4\},\ j\in\{1,2,3,4\}$$
Applying the $\mathrm{softmax}$ function to $S$ gives the attention distribution matrix $\beta\in \mathbb{R}^{4\times 4}$, whose entries are
$$\beta_{ij}=\frac{\exp(s_{ij})}{\sum\limits_{k=1}^{4}\exp(s_{ik})},\quad i \in \{1,2,3,4\},\ j\in\{1,2,3,4\}$$
Finally, the attention distribution matrix $\beta$ is combined with the image features $h$ to obtain the output $o=(o^1,o^2,o^3,o^4)\in \mathbb{R}^{\hat{D}\times 4}$:
$$\left\{\begin{aligned}
o^1&=\beta_{11}\cdot h^1+\beta_{12}\cdot h^2+\beta_{13}\cdot h^3+\beta_{14}\cdot h^4\\
o^2&=\beta_{21}\cdot h^1+\beta_{22}\cdot h^2+\beta_{23}\cdot h^3+\beta_{24}\cdot h^4\\
o^3&=\beta_{31}\cdot h^1+\beta_{32}\cdot h^2+\beta_{33}\cdot h^3+\beta_{34}\cdot h^4\\
o^4&=\beta_{41}\cdot h^1+\beta_{42}\cdot h^2+\beta_{43}\cdot h^3+\beta_{44}\cdot h^4
\end{aligned}\right.$$
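The same computation can be written in a few lines of PyTorch. This is only a sketch of the word-to-image attention described above, using randomly initialized $h$, $e$ and $W$; the names D_hat and D stand for $\hat{D}$ and $D$ and are my own choice.

import torch

D_hat, D, n = 8, 6, 4                    # image feature dim, word feature dim, number of items
h = torch.randn(D_hat, n)                # image features h = (h^1, ..., h^4)
e = torch.randn(D, n)                    # word features  e = (e^1, ..., e^4)
W = torch.randn(D_hat, D)                # linear map from the word space to the image space

e_hat = W @ e                            # e_hat in R^{D_hat x 4}
S = h.t() @ e_hat                        # s_ij = (h^i)^T · e_hat^j, so S in R^{4 x 4}
beta = torch.softmax(S, dim=1)           # row-wise softmax gives the attention distribution beta_ij
o = h @ beta.t()                         # o^i = sum_j beta_ij · h^j, so o in R^{D_hat x 4}
print(o.shape)                           # torch.Size([8, 4])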
This section explains in detail how the Vision Transformer works: Section 3.1 covers the overall Vision Transformer framework, and Section 3.2 covers the internal operations of the Transformer Encoder. The principle of Multi-Head Attention inside the Transformer Encoder is not repeated here; see the previous article 《Transformer详解(附代码)》 for that material. It is easy to see that whether in the Transformer for natural language processing, in SAGAN for image generation, or in AttnGAN for text-to-image generation, the main purpose of the attention mechanism at the core of these models is to compute an attention distribution.
The figure below shows the overall framework of the Vision Transformer together with the corresponding training pipeline.
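As a concrete illustration of the patch-embedding step in that framework (the numbers here are my own, chosen to match the MNIST experiment later in this article): a 1-channel 28×28 image cut into 7×7 patches yields 16 patches, each flattened into a 49-dimensional vector and linearly projected to the embedding size; a class token is prepended and positional embeddings are added, giving 17 input vectors for the encoder.

import torch
import torch.nn as nn
from einops.layers.torch import Rearrange

embed_size = 64                                   # illustrative embedding size
to_patches = nn.Sequential(
    Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=7, p2=7),  # 16 patches of 49 values
    nn.Linear(7 * 7 * 1, embed_size),             # linear projection of each flattened patch
)

img = torch.randn(2, 1, 28, 28)                   # a batch of two MNIST-sized images
tokens = to_patches(img)                          # (2, 16, 64)
cls_token = torch.zeros(2, 1, embed_size)         # class token prepended to the sequence
pos_embedding = torch.randn(1, 17, embed_size)    # one positional embedding per token
encoder_input = torch.cat((cls_token, tokens), dim=1) + pos_embedding
print(encoder_input.shape)                        # torch.Size([2, 17, 64])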
Note: at this point a natural question is why only the class-token vector $o^0$ is used for class prediction, and why the other outputs of the Vision Transformer Encoder are not fed into the MLP head. To answer this, let the function $f_0(\cdot)$ denote the Vision Transformer Encoder; the class-token output $o^{0}$ can then be written as
$$o^0=f_0(z^0+p^0,\cdots,z^9+p^9)$$
From this formula we can see that the class-token vector $o^{0}$ is a high-level feature that aggregates the information of all the image patch embeddings, so it can be used for classification. This is analogous to a convolutional neural network, where the final class output vector is likewise a high-level feature obtained through layer after layer of convolutions.
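A tiny sketch of this step (the shapes are illustrative, assuming the encoder output has shape (batch, 1 + num_patches, embed_size)): only the first output vector, the class-token position $o^0$, is passed to the MLP head.

import torch
import torch.nn as nn

embed_size, num_classes = 64, 10
mlp_head = nn.Sequential(nn.LayerNorm(embed_size), nn.Linear(embed_size, num_classes))

encoder_output = torch.randn(2, 10, embed_size)   # 10 outputs o^0, ..., o^9 per image
o0 = encoder_output[:, 0]                         # keep only the class-token output o^0
logits = mlp_head(o0)                             # (2, 10) class scores
print(logits.shape)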
The figures below show the model structure of the Vision Transformer Encoder and of the original Transformer Encoder. It is easy to see that both consist of the same four operations, layer normalization, multi-head attention, residual connections and linear transformations, and differ only in the order in which these operations are applied. The Transformer code example below implements both Encoder structures, and both can be trained well.
The concrete operation flow of the Vision Transformer Encoder, shown in the left half of the figure below, is: the input embedding sequence is first layer-normalized and passed through multi-head attention, and the result is added back to the input through a residual connection; that sum is layer-normalized again and passed through the feed-forward (MLP) layers, followed by a second residual connection. The two block structures are summarized in formulas right after this paragraph.
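In formulas (a compact summary of the two structures implemented in the code below; dropout is omitted, and $\mathrm{LN}$, $\mathrm{MHA}$, $\mathrm{MLP}$ denote layer normalization, multi-head attention and the feed-forward layers), the original post-norm Encoder block computes
$$x'=\mathrm{LN}\big(x+\mathrm{MHA}(x)\big),\qquad y=\mathrm{LN}\big(x'+\mathrm{MLP}(x')\big)$$
while the pre-norm Vision Transformer Encoder block computes
$$x'=x+\mathrm{MHA}\big(\mathrm{LN}(x)\big),\qquad y=x'+\mathrm{MLP}\big(\mathrm{LN}(x')\big)$$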
The code example for the Vision Transformer is shown below. It is adapted from the code in the previous article 《Transformer详解(附代码)》. The intention of the Vision Transformer authors was to transfer the Transformer architecture from NLP to CV with as few modifications as possible, so the program below keeps to that spirit and implements both Encoder network structures shown in Section 3.2: the original Encoder structure and the Encoder structure used in the Vision Transformer paper. Note that the Vision Transformer has no Decoder module, so there is no cross-attention between an Encoder and a Decoder to compute, which further simplifies the implementation. The open-source Vision Transformer code is available at https://github.com/lucidrains/vit-pytorch/tree/main/vit_pytorch.
import torch
import torch.nn as nn
import os
from einops import rearrange
from einops import repeat
from einops.layers.torch import Rearrange
def inputs_deal(inputs):
    return inputs if isinstance(inputs, tuple) else (inputs, inputs)
class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads
        assert (self.head_dim * heads == embed_size), "Embed size needs to be div by heads"
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query):
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # split embedding into self.heads pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)
        # queries shape: (N, query_len, heads, heads_dim)
        # keys shape   : (N, key_len, heads, heads_dim)
        # energy shape : (N, heads, query_len, key_len)

        attention = torch.softmax(energy / (self.embed_size ** (1/2)), dim=3)
        out = torch.einsum("nhql,nlhd->nqhd", attention, values).reshape(N, query_len, self.heads * self.head_dim)
        # attention shape: (N, heads, query_len, key_len)
        # values shape   : (N, value_len, heads, heads_dim)
        # out shape      : (N, query_len, heads, head_dim) flattened to (N, query_len, embed_size)

        out = self.fc_out(out)
        return out
class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super(TransformerBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm = nn.LayerNorm(embed_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, x, type_mode):
        if type_mode == 'original':
            # original post-norm block: attention -> add & norm -> MLP -> add & norm
            attention = self.attention(value, key, query)
            x = self.dropout(self.norm(attention + x))
            forward = self.feed_forward(x)
            out = self.dropout(self.norm(forward + x))
            return out
        else:
            # Vision Transformer pre-norm block: norm -> attention -> add, norm -> MLP -> add
            attention = self.attention(self.norm(value), self.norm(key), self.norm(query))
            x = self.dropout(attention + x)
            forward = self.feed_forward(self.norm(x))
            out = self.dropout(forward + x)
            return out
class TransformerEncoder(nn.Module):
    def __init__(
        self,
        embed_size,
        num_layers,
        heads,
        forward_expansion,
        dropout = 0,
        type_mode = 'original'
    ):
        super(TransformerEncoder, self).__init__()
        self.embed_size = embed_size
        self.type_mode = type_mode
        # a single linear layer produces the query, key and value projections, split into three chunks
        self.Query_Key_Value = nn.Linear(embed_size, embed_size * 3, bias = False)
        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    embed_size,
                    heads,
                    dropout=dropout,
                    forward_expansion=forward_expansion,
                )
                for _ in range(num_layers)
            ]
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        for layer in self.layers:
            QKV_list = self.Query_Key_Value(x).chunk(3, dim = -1)
            x = layer(QKV_list[0], QKV_list[1], QKV_list[2], x, self.type_mode)
        return x
class VisionTransformer(nn.Module):
    def __init__(self,
                 image_size,
                 patch_size,
                 num_classes,
                 embed_size,
                 num_layers,
                 heads,
                 mlp_dim,
                 pool = 'cls',
                 channels = 3,
                 dropout = 0,
                 emb_dropout = 0.1,
                 type_mode = 'vit'):
        super(VisionTransformer, self).__init__()
        img_h, img_w = inputs_deal(image_size)
        patch_h, patch_w = inputs_deal(patch_size)
        assert img_h % patch_h == 0 and img_w % patch_w == 0, 'Image dimensions must be divisible by the patch dimensions'
        num_patches = (img_h // patch_h) * (img_w // patch_w)
        patch_dim = channels * patch_h * patch_w
        # split the image into patches and project each flattened patch to embed_size
        self.patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_h, p2 = patch_w),
            nn.Linear(patch_dim, embed_size, bias=False)
        )
        # learnable positional embeddings and class token
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, embed_size))
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_size))
        self.dropout = nn.Dropout(emb_dropout)
        # mlp_dim is the hidden width of the feed-forward layer,
        # so the expansion factor passed to the encoder is mlp_dim // embed_size
        self.transformer = TransformerEncoder(embed_size,
                                              num_layers,
                                              heads,
                                              mlp_dim // embed_size,
                                              dropout,
                                              type_mode)
        self.pool = pool
        self.to_latent = nn.Identity()
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_size),
            nn.Linear(embed_size, num_classes)
        )

    def forward(self, img):
        x = self.patch_embedding(img)
        b, n, _ = x.shape
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim = 1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)
        x = self.transformer(x)
        # classification uses either the class token or mean pooling over all tokens
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
        x = self.to_latent(x)
        return self.mlp_head(x)
if __name__ == '__main__':
    vit = VisionTransformer(
        image_size = 256,
        patch_size = 16,
        num_classes = 10,
        embed_size = 256,
        num_layers = 6,
        heads = 8,
        mlp_dim = 512,
        dropout = 0.1,
        emb_dropout = 0.1
    )

    img = torch.randn(3, 3, 256, 256)
    pred = vit(img)
    print(pred)
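Since both Encoder variants from Section 3.2 are selected through the type_mode argument, either structure can be tried when constructing the model. A minimal usage sketch (the parameter values are illustrative):

# post-norm Encoder structure of the original Transformer
vit_original = VisionTransformer(image_size=256, patch_size=16, num_classes=10,
                                 embed_size=256, num_layers=6, heads=8, mlp_dim=512,
                                 type_mode='original')

# pre-norm Encoder structure used in the Vision Transformer paper (the default here)
vit_prenorm = VisionTransformer(image_size=256, patch_size=16, num_classes=10,
                                embed_size=256, num_layers=6, heads=8, mlp_dim=512,
                                type_mode='vit')

x = torch.randn(2, 3, 256, 256)
print(vit_original(x).shape, vit_prenorm(x).shape)   # torch.Size([2, 10]) for both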
The following code is the main program that uses the Vision Transformer network to train a classifier on the MNIST dataset.
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import VIT
import os
def train():
    batch_size = 4
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    epoches = 20
    mnist_train = datasets.MNIST("mnist-data", train=True, download=True, transform=transforms.ToTensor())
    train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True)

    mnist_model = VIT.VisionTransformer(
        image_size = 28,
        patch_size = 7,
        num_classes = 10,
        channels = 1,
        embed_size = 512,
        num_layers = 1,
        heads = 2,
        mlp_dim = 1024,
        dropout = 0,
        emb_dropout = 0)

    loss_fn = nn.CrossEntropyLoss()
    mnist_model = mnist_model.to(device)
    optimizer = optim.Adam(mnist_model.parameters(), lr=0.00001)

    mnist_model.train()
    for epoch in range(epoches):
        total_loss = 0
        corrects = 0
        num = 0
        for batch_X, batch_Y in train_loader:
            batch_X, batch_Y = batch_X.to(device), batch_Y.to(device)
            optimizer.zero_grad()
            outputs = mnist_model(batch_X)
            _, pred = torch.max(outputs.data, 1)
            loss = loss_fn(outputs, batch_Y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            # accumulate correct predictions so the printed accuracy covers the whole epoch
            corrects += torch.sum(pred == batch_Y.data).item()
            num += batch_size
        print(epoch, total_loss / float(num), corrects / float(num))

if __name__ == '__main__':
    train()
The training process is shown below; the loss decreases steadily. However, training a Vision Transformer model is genuinely hard on the hardware: compared with training an ordinary CNN, training a Vision Transformer takes far more time and compute.