赞
踩
1、多头注意力的概念
自注意力模型可以看作为在一个线性投影空间中建立输入向量中不同形式之间的交互关系。多头注意力就是在多个不同的投影空间中建立不同的投影信息。将输入矩阵,进行不同的投影,得到许多输出矩阵后,将其拼接在一起。
从下图中可以看出V K Q 是固定的单个值,而Linear层有3个,Scaled Dot-Product Attention 有3个,即3个多头;最后cancat在一起,然后Linear层转换变成一个和单头一样的输出值;类似于集成;多头和单头的区别在于复制多个单头,但权重系数肯定是不一样的;类比于一个神经网络模型与多个一样的神经网络模型,但由于初始化不一样,会导致权重不一样,然后结果集成。
attention层本质就是特征提取器,是基于整个序列提取了潜在的特征,因为每个embedding向量都融入了其他的embedding的信息,能生成有意义的特征,去除了无关的噪音。
上面是我查到的关于多头注意力机制的模型描述,这里谈谈我对其的理解:在求hi的公式中,给定的查询、键、值应该是指输入矩阵X,而开始这些是相同的,通过相应的权重分别计算出对应的查询、键、值,而每一个头的权重是不同的,所以有Wi。接下来就分别在每一个头上各自计算对应的自注意输出,最后有几个头就得到几个结果,然后将其拼接,乘以相应的权重,从而得到最终的结果。(以上是自己的理解,有不对的地方请批评指正,谢谢!)
具体流程如下图所示
2、pytorch实现多头注意机制
3 代码
import torch.nn as nn class Attention(nn.Module): def __init__(self,dim,num_heads=8,qkv_bias=False,qk_scale=None,attn_drop=0.,proj_drop=0.): super(Attention, self).__init__() self.num_heads = num_heads head_dim = dim // num_heads #这里的dim指的是c,即input输入N个数据的维度。 # 另外,为避免繁琐,将8个头的权重拼在一起,而三个不同的权重由后面的Linear生成。 # 而self.qkv的作用是是将input X (N,C)与权重W(C,8*C1*3)相乘得到Q_K_V拼接在一起的矩阵。 # 所以,dim*3表示的就是所有头所有权重拼接的维度,即8*C1*3。即dim=C=C1*3。 self.scale = qk_scale or head_dim ** -0.5 self.qkv = nn.Linear(dim,dim *3,bias = qkv_bias) #bias默认为True,为了使该层学习额外的偏置。 self.attn_drop = nn.Dropout(attn_drop) #dropout忽略一半节点 self.proj = nn.Linear(dim,dim) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x): B,N,C = x.shape #B为batch_size qkv = self.qkv(x).reshape(B,N,3,self.num_heads,C // self.num_heads).permute(2,0,3,1,4) #将x构造成一个(B,N,3,8,c1)的一个矩阵,然后将其变成(3,B,8,N,c1)。 #是为了使得后面能将其分为三部分,分别作为不同的权重,维度为(B,8,N,c1) q,k,v = qkv[0],qkv[1],qkv[2] attn = (q @ k.transpose(-2,-1)) * self.scale #将k的维度从(B,8,N,c1)转变为(B,8,c1,N),其实就是对单头key矩阵转置,使Query*key^T,得到scores结果,然后×self.scale,即×C1的-0.5。 #就是做一个归一化。乘一个参数,不一定是这个值。 #维度为(B,8,N,N) attn = attn.softmax(dim = -1) #对归一化的scores做softmax # 维度为(B,8,N,N) x = (attn @ v).transpose(1,2).reshape(B,N,C) #将scores与values矩阵相乘,得到数据维度为(B,8,N,C1),使用transpose将其维度转换为(B,N,8,C1) x = self.proj(x) #做一个全连接 x=self.proj_drop(x) #做Dropout return x #得到多头注意力所有头拼接的结果。
参考:
https://blog.csdn.net/ningyanggege/article/details/89812558
https://blog.csdn.net/jerry_liufeng/article/details/123054063
https://blog.csdn.net/S10xuexi/article/details/122972766
https://zhuanlan.zhihu.com/p/376122835?ivk_sa=1024320u
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。