赞
踩
多头注意力融合了来自于多个注意力汇聚的不同知识,这些知识的不同来源于相同的查询、键和值的不同的子空间表示。
基于适当的张量操作,可以实现多头注意力的并行计算。
在实践中,当给定相同的查询、键和值的集合时, 我们希望模型可以基于相同的注意力机制学习到不同的行为, 然后将不同的行为作为知识组合起来, 捕获序列内各种范围的依赖关系 (例如,短距离依赖和长距离依赖关系)。 因此,允许注意力机制组合使用查询、键和值的不同 子空间表示(representation subspaces)可能是有益的。
为此,与其只使用单独一个注意力汇聚, 我们可以用独立学习得到的ℎ组不同的 线性投影(linear projections)来变换查询、键和值。 然后,这ℎ组变换后的查询、键和值将并行地送到注意力汇聚中。 最后,将这ℎ个注意力汇聚的输出拼接在一起, 并且通过另一个可以学习的线性投影进行变换, 以产生最终输出。 这种设计被称为多头注意力(multihead attention) (Vaswani et al., 2017)。 对于ℎ个注意力汇聚输出,每一个注意力汇聚都被称作一个头(head)。 图10.5.1 展示了使用全连接层来实现可学习的线性变换的多头注意力。
基于这种设计,每个头都可能会关注输入的不同部分, 可以表示比简单加权平均值更复杂的函数。
pip install mxnet==1.7.0.post1
pip install d2l==0.15.0
- import math
- from mxnet import autograd, np, npx
- from mxnet.gluon import nn
- from d2l import mxnet as d2l
-
- npx.set_np()
- #@save
- class MultiHeadAttention(nn.Block):
- """多头注意力"""
- def __init__(self, num_hiddens, num_heads, dropout, use_bias=False,
- **kwargs):
- super(MultiHeadAttention, self).__init__(**kwargs)
- self.num_heads = num_heads
- self.attention = d2l.DotProductAttention(dropout)
- self.W_q = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
- self.W_k = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
- self.W_v = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
- self.W_o = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
-
- def forward(self, queries, keys, values, valid_lens):
- # queries,keys,values的形状:
- # (batch_size,查询或者“键-值”对的个数,num_hiddens)
- # valid_lens 的形状:
- # (batch_size,)或(batch_size,查询的个数)
- # 经过变换后,输出的queries,keys,values 的形状:
- # (batch_size*num_heads,查询或者“键-值”对的个数,
- # num_hiddens/num_heads)
- queries = transpose_qkv(self.W_q(queries), self.num_heads)
- keys = transpose_qkv(self.W_k(keys), self.num_heads)
- values = transpose_qkv(self.W_v(values), self.num_heads)
-
- if valid_lens is not None:
- # 在轴0,将第一项(标量或者矢量)复制num_heads次,
- # 然后如此复制第二项,然后诸如此类。
- valid_lens = valid_lens.repeat(self.num_heads, axis=0)
-
- # output的形状:(batch_size*num_heads,查询的个数,
- # num_hiddens/num_heads)
- output = self.attention(queries, keys, values, valid_lens)
-
- # output_concat的形状:(batch_size,查询的个数,num_hiddens)
- output_concat = transpose_output(output, self.num_heads)
- return self.W_o(output_concat)
为了能够使多个头并行计算, 上面的MultiHeadAttention
类将使用下面定义的两个转置函数。 具体来说,transpose_output
函数反转了transpose_qkv
函数的操作。
- #@save
- def transpose_qkv(X, num_heads):
- """为了多注意力头的并行计算而变换形状"""
- # 输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens)
- # 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads,
- # num_hiddens/num_heads)
- X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
-
- # 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数,
- # num_hiddens/num_heads)
- X = X.transpose(0, 2, 1, 3)
-
- # 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数,
- # num_hiddens/num_heads)
- return X.reshape(-1, X.shape[2], X.shape[3])
-
-
- #@save
- def transpose_output(X, num_heads):
- """逆转transpose_qkv函数的操作"""
- X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
- X = X.transpose(0, 2, 1, 3)
- return X.reshape(X.shape[0], X.shape[1], -1)
'运行
下面使用键和值相同的小例子来测试我们编写的MultiHeadAttention
类。 多头注意力输出的形状是(batch_size
,num_queries
,num_hiddens
)。
- num_hiddens, num_heads = 100, 5
- attention = MultiHeadAttention(num_hiddens, num_heads, 0.5)
- attention.initialize()
-
- batch_size, num_queries = 2, 4
- num_kvpairs, valid_lens = 6, np.array([3, 2])
- X = np.ones((batch_size, num_queries, num_hiddens))
- Y = np.ones((batch_size, num_kvpairs, num_hiddens))
- attention(X, Y, Y, valid_lens).shape
(2, 4, 100)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。