当前位置:   article > 正文

注意力机制——Multi-Head Attention(MHA)_多头注意力机制公式

多头注意力机制公式

Multi-Head Attention(MHA)MHA是一种多头注意力模型,将注意力机制扩展到多个头,从而增强模型对于不同特征的关注度。

MHA 的输入包括三个向量:查询向量(query)、键向量(key)和值向量(value)。对于一个给定的查询向量,MHA 会对键向量进行加权求和,权重由查询向量和键向量之间的相似度计算得到,然后将得到的加权和乘以值向量进行输出。在计算相似度时,常用的方法是使用点积(dot product)或者是双线性(bilinear)计算。

MHA 的多头机制可以有效提高模型的表达能力,同时也可以使模型学习到更加多样化和复杂的特征。在多头机制下,输入的序列数据会被分成多个头,每个头进行独立的计算,得到不同的输出。这些输出最后被拼接在一起,形成最终的输出。

MHA 的计算可以表示为以下的公式:

 其中 Q, K, V 分别表示查询向量、键向量和值向量,ℎ 表示头的数量,headi​ 表示第 i 个头的输出,WO 是输出变换矩阵。每个头的输出 headi​ 可以表示为:

其中 WiQ​, WiK​, WiV​ 分别是第 i 个头的查询、键、值变换矩阵,AttentionAttention 是注意力计算函数。在 MHA 中,一般使用自注意力机制(Self-Attention)来计算注意力。

自注意力机制的计算可以表示为以下的公式:

其中 dk​ 是键向量的维度,softmax相似度进行归一化,将每个键向量的权重计算出来,然后将权重乘以值向量,最后进行加权求和得到注意力输出。

MHA 的流程可以总结为以下几步:

  1. 将输入的序列数据分成多个头;
  2. 对每个头进行独立的查询、键、值线性变换;
  3. 对每个头进行自注意力计算,得到该头的输出;
  4. 将所有头的输出拼接在一起,并进行输出线性变换。

 MultiHeadAttention模块pytorch实现:

  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. class MultiHeadAttention(nn.Module):
  5. def __init__(self, d_model, num_heads, dropout=0.1):
  6. super().__init__()
  7. # 初始化模块的属性
  8. self.num_heads = num_heads # 多头注意力头数
  9. self.d_model = d_model # 模型维度
  10. self.depth = d_model // num_heads # 每个头的维度
  11. # 定义权重矩阵
  12. self.Wq = nn.Linear(d_model, d_model)
  13. self.Wk = nn.Linear(d_model, d_model)
  14. self.Wv = nn.Linear(d_model, d_model)
  15. # 定义最终的线性层
  16. self.fc = nn.Linear(d_model, d_model)
  17. # 定义dropout层
  18. self.dropout = nn.Dropout(p=dropout)
  19. def scaled_dot_product_attention(self, Q, K, V, mask=None):
  20. # 计算注意力得分
  21. scores = torch.matmul(Q, K.transpose(-1, -2)) / torch.sqrt(torch.tensor(self.depth, dtype=torch.float32))
  22. # 如果存在掩码,应用掩码
  23. if mask is not None:
  24. scores += mask * -1e9
  25. # 计算softmax
  26. attention = F.softmax(scores, dim=-1)
  27. # 应用dropout
  28. attention = self.dropout(attention)
  29. # 将注意力得分乘以value向量
  30. output = torch.matmul(attention, V)
  31. return output, attention
  32. def forward(self, Q, K, V, mask=None):
  33. batch_size = Q.size(0)
  34. # 线性投影
  35. Q = self.Wq(Q).view(batch_size, -1, self.num_heads, self.depth).transpose(1, 2)
  36. K = self.Wk(K).view(batch_size, -1, self.num_heads, self.depth).transpose(1, 2)
  37. V = self.Wv(V).view(batch_size, -1, self.num_heads, self.depth).transpose(1, 2)
  38. # Scaled Dot-Product Attention
  39. scores, attention = self.scaled_dot_product_attention(Q, K, V, mask)
  40. concat = scores.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
  41. # 最终的线性投影
  42. output = self.fc(concat)
  43. return output, attention

在这个实现中,我们首先定义了一个MultiHeadAttention类,它继承了nn.Module,并包含以下属性:

  • num_heads:多头注意力中头的数量。
  • d_model:输入向量的维度。
  • depth:每个头的向量维度,即dmodel​/num_heads。
  • WqWkWv:输入的Q、K、V向量分别通过这些线性层进行转换。
  • fc:输出向量通过这个线性层进行转换。

 

scaled_dot_product_attention 函数是 MultiHeadAttention 模块中的一个重要方法,实现了Scaled Dot-Product Attention操作。

该函数的输入参数 QK 和 V 分别表示查询向量、键向量和值向量,均为张量。其中,Q 和 K 的 shape 为 (batch_size, num_heads, seq_len, depth),而 V 的 shape 为 (batch_size, num_heads, seq_len, depth),其中 seq_len 代表序列长度,depth 代表每个头的维度。

将 MultiHeadAttention 模块添加到 PyTorch 模型示例:

  1. import torch.nn as nn
  2. class MyModel(nn.Module):
  3. def __init__(self):
  4. super(MyModel, self).__init__()
  5. # 定义一个 MultiHeadAttention 模块
  6. self.attn = nn.MultiheadAttention(embed_dim=512, num_heads=8)
  7. # 其他模型层的定义
  8. self.linear1 = nn.Linear(512, 256)
  9. self.relu = nn.ReLU()
  10. self.linear2 = nn.Linear(256, 10)
  11. def forward(self, x):
  12. # 输入 x 的 shape 为 (batch_size, seq_len, input_dim)
  13. # 将输入张量转换为 (seq_len, batch_size, input_dim) 的形式
  14. x = x.permute(1, 0, 2)
  15. # 使用 MultiHeadAttention 模块进行注意力机制
  16. attn_output, attn_weights = self.attn(x, x, x)
  17. # 将输出张量转换回 (batch_size, seq_len, input_dim) 的形式
  18. attn_output = attn_output.permute(1, 0, 2)
  19. # 其他模型层的计算
  20. x = self.linear1(attn_output)
  21. x = self.relu(x)
  22. x = self.linear2(x)
  23. return x

在上面的示例中,我们定义了一个名为 MyModel 的 PyTorch 模型,其中包含一个 MultiHeadAttention 模块。在模型的 __init__ 方法中,我们创建了一个 MultiHeadAttention 实例,并将其存储在模型中的 self.attn 属性中。在模型的 forward 方法中,我们将输入张量 x 转换为 (seq_len, batch_size, input_dim) 的形式,并使用 self.attn 对其进行注意力机制。注意力机制的输出为元组 (attn_output, attn_weights),其中 attn_output 表示注意力机制的输出张量,attn_weights 表示注意力权重。最后,我们将 attn_output 转换回 (batch_size, seq_len, input_dim) 的形式,并将其输入到其他模型层中。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家小花儿/article/detail/346687
推荐阅读
相关标签
  

闽ICP备14008679号