当前位置:   article > 正文

多头注意力_什么 多头注意力

什么 多头注意力
  • 多头注意力融合了来自于多个注意力汇聚的不同知识,这些知识的不同来源于相同的查询、键和值的不同的子空间表示。

  • 基于适当的张量操作,可以实现多头注意力的并行计算。

在实践中,当给定相同的查询、键和值的集合时, 我们希望模型可以基于相同的注意力机制学习到不同的行为, 然后将不同的行为作为知识组合起来, 捕获序列内各种范围的依赖关系 (例如,短距离依赖和长距离依赖关系)。 因此,允许注意力机制组合使用查询、键和值的不同 子空间表示(representation subspaces)可能是有益的。

为此,与其只使用单独一个注意力汇聚, 我们可以用独立学习得到的ℎ组不同的 线性投影(linear projections)来变换查询、键和值。 然后,这ℎ组变换后的查询、键和值将并行地送到注意力汇聚中。 最后,将这ℎ个注意力汇聚的输出拼接在一起, 并且通过另一个可以学习的线性投影进行变换, 以产生最终输出。 这种设计被称为多头注意力(multihead attention) (Vaswani et al., 2017)。 对于ℎ个注意力汇聚输出,每一个注意力汇聚都被称作一个(head)。 图10.5.1 展示了使用全连接层来实现可学习的线性变换的多头注意力。

1.模型

基于这种设计,每个头都可能会关注输入的不同部分, 可以表示比简单加权平均值更复杂的函数。

pip install mxnet==1.7.0.post1
pip install d2l==0.15.0
  1. import math
  2. from mxnet import autograd, np, npx
  3. from mxnet.gluon import nn
  4. from d2l import mxnet as d2l
  5. npx.set_np()

2. 实现

  1. #@save
  2. class MultiHeadAttention(nn.Block):
  3. """多头注意力"""
  4. def __init__(self, num_hiddens, num_heads, dropout, use_bias=False,
  5. **kwargs):
  6. super(MultiHeadAttention, self).__init__(**kwargs)
  7. self.num_heads = num_heads
  8. self.attention = d2l.DotProductAttention(dropout)
  9. self.W_q = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
  10. self.W_k = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
  11. self.W_v = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
  12. self.W_o = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
  13. def forward(self, queries, keys, values, valid_lens):
  14. # queries,keys,values的形状:
  15. # (batch_size,查询或者“键-值”对的个数,num_hiddens)
  16. # valid_lens 的形状:
  17. # (batch_size,)或(batch_size,查询的个数)
  18. # 经过变换后,输出的queries,keys,values 的形状:
  19. # (batch_size*num_heads,查询或者“键-值”对的个数,
  20. # num_hiddens/num_heads)
  21. queries = transpose_qkv(self.W_q(queries), self.num_heads)
  22. keys = transpose_qkv(self.W_k(keys), self.num_heads)
  23. values = transpose_qkv(self.W_v(values), self.num_heads)
  24. if valid_lens is not None:
  25. # 在轴0,将第一项(标量或者矢量)复制num_heads次,
  26. # 然后如此复制第二项,然后诸如此类。
  27. valid_lens = valid_lens.repeat(self.num_heads, axis=0)
  28. # output的形状:(batch_size*num_heads,查询的个数,
  29. # num_hiddens/num_heads)
  30. output = self.attention(queries, keys, values, valid_lens)
  31. # output_concat的形状:(batch_size,查询的个数,num_hiddens)
  32. output_concat = transpose_output(output, self.num_heads)
  33. return self.W_o(output_concat)

为了能够使多个头并行计算, 上面的MultiHeadAttention类将使用下面定义的两个转置函数。 具体来说,transpose_output函数反转了transpose_qkv函数的操作。

  1. #@save
  2. def transpose_qkv(X, num_heads):
  3. """为了多注意力头的并行计算而变换形状"""
  4. # 输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens)
  5. # 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads,
  6. # num_hiddens/num_heads)
  7. X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
  8. # 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数,
  9. # num_hiddens/num_heads)
  10. X = X.transpose(0, 2, 1, 3)
  11. # 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数,
  12. # num_hiddens/num_heads)
  13. return X.reshape(-1, X.shape[2], X.shape[3])
  14. #@save
  15. def transpose_output(X, num_heads):
  16. """逆转transpose_qkv函数的操作"""
  17. X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
  18. X = X.transpose(0, 2, 1, 3)
  19. return X.reshape(X.shape[0], X.shape[1], -1)
'
运行

下面使用键和值相同的小例子来测试我们编写的MultiHeadAttention类。 多头注意力输出的形状是(batch_sizenum_queriesnum_hiddens)。

  1. num_hiddens, num_heads = 100, 5
  2. attention = MultiHeadAttention(num_hiddens, num_heads, 0.5)
  3. attention.initialize()
  4. batch_size, num_queries = 2, 4
  5. num_kvpairs, valid_lens = 6, np.array([3, 2])
  6. X = np.ones((batch_size, num_queries, num_hiddens))
  7. Y = np.ones((batch_size, num_kvpairs, num_hiddens))
  8. attention(X, Y, Y, valid_lens).shape
(2, 4, 100)

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/喵喵爱编程/article/detail/817473
推荐阅读
相关标签
  

闽ICP备14008679号