当前位置:   article > 正文

Transformer 02:多头注意力机制的工作原理_多头注意力机制怎么实现的

多头注意力机制怎么实现的

上一篇博文介绍了自注意力机制原理,本文介绍多头注意力机制的工作原理,最后会附上代码示例,通过代码应用自注意力机制模块的步骤。

多头注意力机制是Transformer架构中的一个关键创新,它允许模型在不同的表示子空间中并行地学习输入数据的不同方面。这种机制增加了模型的灵活性和能力,使其能够捕捉到更复杂的特征关系。多头注意力机制的核心思想是将注意力操作分拆成多个“头”,每个头独立地进行注意力计算,然后将这些计算的结果合并起来。

1. 分割嵌入向量

首先,输入的嵌入向量被分割成多个较小的部分,每个部分对应一个注意力“头”。分割后的向量有更低的维度,这允许模型在更细粒度上学习数据的表示

多头注意力机制中分割嵌入向量的步骤是实现多头注意力核心功能的基础,它允许模型在多个不同的表示子空间中并行处理信息。这一步骤涉及将每个输入向量(比如词嵌入向量)分割成多个部分,每个部分对应一个注意力头。

1.1 分割嵌入向量的步骤

假设输入的嵌入向量维度为 d_{model},注意力头数为 h,则每个头处理的向量维度为 d_{k} = \tfrac{d_{model}}{h}。分割嵌入向量的具体步骤如下:

  1. 准备输入向量:首先,我们有输入向量X,其维度为 (N,L,d_{model})​,其中 N 是批次大小,L是序列长度。

  2. 应用线性变换:对输入向量 X 应用三个不同的线性变换(全连接层),分别生成查询(Q)、键(K)和值(V)向量。每个线性变换的权重矩阵维度为 (d_{model}, d_{model})

  3. 分割向量:将线性变换后的查询、键和值向量分割成 h个部分,每部分的维度为 (N,L,d_{k})。这一步通常通过调整张量的形状来实现。

1.2 数学方法解释

1. 线性变换:对于查询、键和值的生成,使用线性变换(全连接层)的数学表达式可以表示为:

Q = XW_{Q}          (1)       

K = XW^{K}        (2)

V = XW^{V}         (3)

W^{Q},W^{K},W^{V}分别代表查询、键、值的权重矩阵。

2. 张量重塑:为了实现多头处理,需要将W^{Q},W^{K},W^{V}重塑为(N,L,h,d_{k})的形状,其中,h是注意力头数,d_{k}是向量维度。然后在进行点积注意力计算之前,将批次大小和序列长度的维度合并,视为一个维度处理。

3. 点积注意力计算:在每个头上,使用缩放点积注意力计算公式,对于每个头 i,其计算可以表示为:Attention(Q_{i}, K_{i}, V_{i}) = softmax(\tfrac{Q_{i}K_{i}^{T}}{\sqrt{d_{k}}})V_{i}        (4)

其中,Q_{i},K_{i},V_{i}分别是第i个头的查询、键和值向量。

4. 输出合并:计算完所有头的注意力后,将它们的输出向量在d_{k}​ 维度上拼接起来,再次通过一个线性变换,得到多头注意力机制的最终输出。

通过这些步骤和方法,多头注意力机制能够有效地将输入向量分割成多个部分,让模型能够并行地在多个表示子空间中学习输入数据的不同特征,从而提高了模型处理信息的能力。

2. 独立计算注意力

对于每个头,我们分别计算其查询(Q)、键(K)和值(V)向量,然后进行标准的注意力计算(如之前介绍的自注意力机制)。由于每个头处理的是向量的不同部分,它们能够并行地捕捉到输入数据中不同的特征关系。

2.1 缩放点积注意力

缩放点积注意力是一种计算注意力权重的方法,它使用查询(Q)、键(K)和值(V)向量的点积来确定每个元素对其他元素的影响程度,然后通过缩放来控制梯度的稳定性。具体步骤如下:

1. 计算点积:首先,计算查询向量与所有键向量的点积。这一步骤会为序列中的每个元素生成一个得分(或权重),表示该元素与序列中其他元素的相关性。

给定查询矩阵 Q、键矩阵 K 和值矩阵 V,它们的维度分别为:(N,L_{q},d_{k}),(N,L_{k},d_{k}),(N,L_{v},d_{k}),其中L_{q}L_{k}L_{v}分别代表查询、键和值序列的长度,d_{k} 是每个头处理的维度大小。

2. 缩放:由于点积随着维度的增长而增大,直接使用点积的结果可能会导致梯度消失或爆炸的问题。因此,点积的结果会被缩放,通常是除以键向量维度的平方根\sqrt{d_{k}},以保持梯度的稳定性。

缩放点积注意力可以表示为:Attention(Q,K,V) = softmax(\tfrac{QK^{T}}{\sqrt{d_{k}}})V        (5)

其中,QK^{T}的结果是一个 (N,L_{q},L_{k})维度的矩阵,表示查询和键之间的点积得分;除以\sqrt{d_{k}}是为了缩放,以防止计算结果的梯度过大。softmax 函数是沿着L_{k} 维度应用的,为每个查询生成一个注意力权重分布;最后这些权重用来加权 V,生成输出。
3. 应用Softmax:接下来,使用softmax函数对每个元素的得分进行归一化,得到一个概率分布,表示每个元素对序列中其他元素的注意力权重。

公式(5)的Softmax函数应用于每一行(即对于每个查询,对所有键的点积),公式为:

softmax(x_{i}) = \frac{e^{x_{i}}}{\sum_{j}^{}e^{x_{j}}}        (6)

其中,x_i是点积缩放后的分数,而分母是对所有j(即序列中所有位置的键)进行求和,确保了得到的权重是一个有效的概率分布。

4. 计算加权和:得到了注意力权重之后,下一步是使用这些权重来计算每个头的输出,即通过对值(V)向量进行加权求和。这一步骤聚合了每个头中所有位置的信息,根据权重的不同给予不同的重要性。加权和的计算公式如下:

HeadOutput = AttentionWeights\cdot V        (7)

其中,AttentionWeights是前一步使用Softmax计算得到的注意力权重,V是值向量。这个操作实际上是一个加权求和,其中每个值向量的权重由对应的注意力权重给出。

5. 合并结果:在计算了所有头的输出之后,最后一步是将这些输出合并(通常是拼接)起来,然后可能通过一个额外的线性变换来整合信息,得到多头注意力的最终输出。

通过上述过程,多头注意力机制能够在处理序列数据时考虑到不同的表示子空间,从而捕获更丰富的信息。这种计算方式使得Transformer模型在处理各种复杂任务时具有更强的能力和灵活性。

3. 代码示例

3.1 代码

下面是一个简化版本的Transformer自注意力机制的代码示例,使用Python和PyTorch框架。这段代码将演示如何实现一个自注意力层,包括查询(Q)、键(K)、值(V)的生成和注意力权重的计算。

  1. import torch
  2. class SelfAttention(torch.nn.Module):
  3. def __init__(self, embed_size, heads):
  4. super(SelfAttention, self).__init__()
  5. self.embed_size = embed_size # 嵌入的大小
  6. self.heads = heads # 注意力头的数量
  7. self.head_dim = embed_size // heads # 每个注意力头的维度大小
  8. assert (
  9. self.head_dim * heads == embed_size
  10. ), "Embedding size needs to be divisible by heads" # 确保嵌入大小可以被注意力头数量整除
  11. # 定义值、键、查询的线性变换
  12. self.values = torch.nn.Linear(self.head_dim, self.head_dim, bias=False)
  13. self.keys = torch.nn.Linear(self.head_dim, self.head_dim, bias=False)
  14. self.queries = torch.nn.Linear(self.head_dim, self.head_dim, bias=False)
  15. self.fc_out = torch.nn.Linear(heads * self.head_dim, embed_size) # 最终线性层
  16. def forward(self, values, keys, queries, mask=None):
  17. N = queries.shape[0] # 批次大小
  18. value_len, key_len, query_len = values.shape[1], keys.shape[1], queries.shape[1]
  19. # 分割输入,使其可以并行处理多头注意力
  20. values = values.reshape(N, value_len, self.heads, self.head_dim)
  21. keys = keys.reshape(N, key_len, self.heads, self.head_dim)
  22. queries = queries.reshape(N, query_len, self.heads, self.head_dim)
  23. # 通过线性层获得值、键、查询
  24. values = self.values(values)
  25. keys = self.keys(keys)
  26. queries = self.queries(queries)
  27. # 计算查询和键的点积,得到注意力得分
  28. energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
  29. # 可选:如果提供了掩码,使用掩码来避免注意力机制关注未来的信息
  30. if mask is not None:
  31. energy = energy.masked_fill(mask == 0, float("-1e20"))
  32. # 应用softmax函数得到注意力权重
  33. attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)
  34. # 根据注意力权重加权值向量
  35. out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
  36. N, query_len, self.heads * self.head_dim
  37. )
  38. # 通过最终的线性层
  39. out = self.fc_out(out)
  40. return out
  41. if __name__ == '__main__':
  42. # 示例使用
  43. embed_size = 256 # 嵌入维度
  44. heads = 8 # 注意力头数量
  45. sentence = "Hello, this is a test sentence."
  46. tokens = sentence.split() # 这是一个简单的分词方法,仅用于演示
  47. print(f'tokens:\n{tokens}')
  48. print('*' * 100)
  49. # 假设每个标记由一个唯一的整数表示
  50. token_ids = torch.tensor([[i for i in range(len(tokens))]])
  51. print(f'token_ids:\n{token_ids}')
  52. print('*'*100)
  53. # 嵌入层(实际中,通常会使用预训练的嵌入,如Word2Vec或BERT)
  54. embedding = torch.nn.Embedding(len(tokens), embed_size)
  55. input_embeddings = embedding(token_ids)
  56. print(f'input embeddings:\n{input_embeddings}')
  57. print('*' * 100)
  58. # 初始化自注意力层
  59. self_attention = SelfAttention(embed_size, heads)
  60. print(f'self attention model:\n{self_attention}')
  61. print('*' * 100)
  62. # 虚构的掩码(在真实情况下,掩码会防止在解码器中关注未来的标记)
  63. mask = None
  64. # 通过自注意力层进行前向传播
  65. out = self_attention(input_embeddings, input_embeddings, input_embeddings, mask)
  66. print(f'out.shape:\n{out.shape}') # [批次大小, 标记数量, 嵌入维度]

3.2 运行结果

tokens:
['Hello,', 'this', 'is', 'a', 'test', 'sentence.']
****************************************************************************************************
token_ids:
tensor([[0, 1, 2, 3, 4, 5]])
****************************************************************************************************
input embeddings:
tensor([[[ 0.0401,  0.1971,  0.2713,  ...,  1.1898, -0.4191,  0.9417],
         [ 0.0914, -0.0192,  0.5706,  ...,  0.3632,  0.0879, -0.2977],
         [-0.1128,  0.2040, -0.5213,  ...,  0.0395, -0.0725, -0.1217],
         [ 1.1819,  0.2330, -0.0879,  ...,  0.4245, -1.3431,  0.7194],
         [-0.8293, -0.9667, -0.2948,  ...,  0.3293, -0.9030,  1.0991],
         [ 1.8977,  1.2574, -1.0044,  ...,  1.7029,  1.1557, -2.3370]]],
       grad_fn=<EmbeddingBackward0>)
****************************************************************************************************
self attention model:
SelfAttention(
  (values): Linear(in_features=32, out_features=32, bias=False)
  (keys): Linear(in_features=32, out_features=32, bias=False)
  (queries): Linear(in_features=32, out_features=32, bias=False)
  (fc_out): Linear(in_features=256, out_features=256, bias=True)
)
****************************************************************************************************
out.shape:
torch.Size([1, 6, 256])

Process finished with exit code 0

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家小花儿/article/detail/365034
推荐阅读
相关标签
  

闽ICP备14008679号