赞
踩
参考网址:知乎
Q:Query
K:Key
V:Value
其实是三个矩阵,矩阵如果表示为LxD,L是句子中词的个数,D是嵌入维度,在自注意力机制里,QKV是表示同一个句子的矩阵,否则KV一般是来自一个句子,而Q来自其他句子
我们直接用torch实现一个SelfAttention来说一说:
class BertSelfAttention(nn.Module):
self.query = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
self.key = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
self.value = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
可以通过这三个线性变换query,key,value得到我们想要的QKV,其中三个变换的输入都是768维,输出都是768维
将该矩阵输入上面的三个线性转换,就可以得到三个矩阵KQV,(6x768)X(768x768)=(6x768)
,维度其实没有改变。
代码表示为:
class BertSelfAttention(nn.Module):
def __init__(self, config):
self.query = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
self.key = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
self.value = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
def forward(self,hidden_states): # hidden_states 维度是(L, 768)
Q = self.query(hidden_states)
K = self.key(hidden_states)
V = self.value(hidden_states)
拿自注意力来举例(QKV都是同一个句子的矩阵)
① 首先是Q和K矩阵乘,(L, 768)*(L, 768)的转置=(L,L)
,看图:
最后得到(LxL)
的矩阵,其中图中蓝色圈圈代表的就是“我”对“我”的注意力值,其他位置的值亦然。
② 然后是除以根号dim,这个dim就是768,至于为什么要除以这个数值?主要是为了缩小点积范围,确保softmax梯度稳定性,再用softmax进行归一化操作(一种解释是为了保证注意力权重的非负性,同时增加非线性)
③ 然后就是刚才的注意力权重和V矩阵乘了,如图:
首先是“我”这个字对“我想吃酸菜鱼”这句话里面每个字的注意力权重,和V中“我想吃酸菜鱼”里面每个字的第一维特征进行相乘再求和,这个过程其实就相当于用每个字的权重对每个字的特征进行加权求和,然后再用“我”这个字对对“我想吃酸菜鱼”这句话里面每个字的注意力权重和V中“我想吃酸菜鱼”里面每个字的第二维特征进行相乘再求和,依次类推最终也就得到了(L,768)的结果矩阵,和输入保持一致
代码:
class BertSelfAttention(nn.Module): def __init__(self, config): self.query = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768 self.key = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768 self.value = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768 def forward(self,hidden_states): # hidden_states 维度是(L, 768) Q = self.query(hidden_states) K = self.key(hidden_states) V = self.value(hidden_states) attention_scores = torch.matmul(Q, K.transpose(-1, -2)) attention_scores = attention_scores / math.sqrt(self.attention_head_size) attention_probs = nn.Softmax(dim=-1)(attention_scores) out = torch.matmul(attention_probs, V) return out
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。