赞
踩
以下内容为小白学习vit内容记录,如理解有误,望帮助指出修正。基于Paddle框架学习,aistudio课程即可学习。此次记录课程里视觉问题中的注意力机制小节的学习内容
课程中注意力机制从NLP的方向为我们举例,我直接从公式开始。假设有三个image token输入,输入在中间会经过一次权重矩阵维度(out = proj * x)相乘变化,随着proj权重矩阵维度的不一样输出的维度也会跟着变化。得到的结果(Vi)再与注意力权重(a1—a3)运算求和,这就是注意力机制。
为了模型的复杂性、深度更深(为了它更强),在简单的注意力机制上,变得复杂一些。依旧输入image tokens,一个输入不只经过一次权重矩阵,而是经过三个,得到(v1、k1、q1),注意力权重经过K 、Q的点积得出。在计算X1的注意力时,C1 = a1 * v1 + a2 * v2 + a3 * v3 (注意力计算公式),此时注意力权重使用自己的q1与其他的K 做点积得到用于计算X1的注意力权重(a1、a2、a3)。同理X2,注意力权重使用X2的q2与其他的K做点积得到用于计算X2的注意力权重,再将权重带入公式就可以对每一个patch image进行计算(如下图展示)。
Wq、Wk、Wv是权重矩阵,是可学习的。
image token经过权重矩阵后,就开始计算每一个token的注意力。
这里对注意力权重再做了一次计算,在原来点积后的结果除于根号下Dk(dk是embedding的数值,如果是Multi--Head self attention,dk是embedding // num_head的数值),再经过softmax的结果,用于最后注意力计算时的注意力权重。
课程中对于Multi--Head self attention的讲解,结合代码展示的代码感觉不是很相符合。课上的理解就是把原来的image token经过的权重矩阵,复制几个维度一致的矩阵(矩阵权重不一致)这样就可以产生多个的qkv去计算注意力,对多个qkv计算的注意力经过一个权重矩阵变化得到最后的结果。
后来看了一个大佬写的解释感觉更符合实际代码过程。把原来计算的qkv均分成多个,一组一组的qkv去计算注意力(步骤和前面的过程一致)。大佬文章(下图取自大佬的文章)
代码是自己看完视频参考写的,如果有误可以指出修改。
- import paddle
- import paddle.nn
-
-
-
- class Atten(paddle.nn.Layer):
- def __init__(self,embed_dim,head_dim,dropout=0.):
- super(Atten, self).__init__()
- self.head_dim = embed_dim//self.head_dim
- self.qkv = paddle.nn.Linear(embed_dim,int(3*embed_dim))
- self.dropout = paddle.nn.Dropout(dropout)
- self.scale = self.head_dim ** -0.5
- self.proj = paddle.nn.Linear(embed_dim,embed_dim)
- self.softmax = paddle.nn.Softmax(-1)
- def forward(self,x):
- batchsize,num_patch,embedding_dim = x.shape
- qkv = self.qkv(x)
- #[4, 16, 288]
- qkv = qkv.reshape((batchsize,num_patch,3,embedding_dim//self.head_dim,-1)).transpose((2,0,4,1,3))
- q,k,v = qkv[0],qkv[1],qkv[2]
- attn = paddle.matmul(q, k, transpose_y=True)
- attn = self.scale * attn
-
- attn = self.softmax(attn)
- attn = self.dropout(attn)
- out = paddle.matmul(attn, v)
- out = out.transpose([0, 2, 1, 3])
- out = out.reshape([batchsize, num_patch, -1])
- out = self.proj(out)
- out = self.dropout(out)
- return out
-
-
-
- def main():
- ## [batchsize,num_patch,dim_embed]
- t = paddle.randn((4,16,96))
- model = Atten(96,8)
- out = model(t)
- print(out.shape)
-
-
- if __name__ =='__main__':
- main()
正在入门深度学习,其中的理解与解释比较牵强,后面会不断学习此课程,学习记录。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。