当前位置:   article > 正文

cross attention交叉熵注意力机制

交叉熵注意力

        交叉注意力(Cross-Attention)则是在两个不同序列上计算注意力,用于处理两个序列之间的语义关系。在两个不同的输入序列之间计算关联度和加权求和的机制。具体来说,给定两个输入序列,cross attention机制将一个序列中的每个元素与另一个序列中的所有元素计算关联度,并根据关联度对两个序列中的每个元素进行加权求和。这样的机制使模型能够建立不同序列之间的关联关系,并将两个序列的信息融合起来。例如,在翻译任务中,需要将源语言句子和目标语言句子进行对齐,就需要使用交叉注意力来计算两个句子之间的注意力权重。

        交叉注意力机制是一种特殊形式的多头注意力,它将输入张量拆分成两个部分 X1\epsilon R^{n*d1}  和 X2\epsilon R^{n*d2},然后将其中一个部分作为查询集合,另一个部分作为键值集合。它的输出是一个大小为n*d2 的张量,对于每个行向量,都给出了它对于所有行向量的注意力权重。

Q=X_{1} W^{Q} 和 K=V=X_{2} W^{K},则交叉注意力的计算如下:

\operatorname{CrossAttention}\left(X_{1}, X_{2}\right)=\operatorname{Softmax}\left(\frac{Q K^{T}}{\sqrt{d_{2}}}\right) V

  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. class CrossAttention(nn.Module):
  5. def __init__(self, embed_dim, hidden_dim, num_heads):
  6. super(CrossAttention, self).__init__()
  7. self.embed_dim = embed_dim
  8. self.hidden_dim = hidden_dim
  9. self.num_heads = num_heads
  10. self.query_proj = nn.Linear(embed_dim, hidden_dim * num_heads)
  11. self.key_proj = nn.Linear(embed_dim, hidden_dim * num_heads)
  12. self.value_proj = nn.Linear(embed_dim, hidden_dim * num_heads)
  13. self.out_proj = nn.Linear(hidden_dim * num_heads, embed_dim)
  14. def forward(self, query, context):
  15. """
  16. query: (batch_size, query_len, embed_dim)
  17. context: (batch_size, context_len, embed_dim)
  18. """
  19. batch_size, query_len, _ = query.size()
  20. context_len = context.size(1)
  21. # Project input embeddings
  22. query_proj = self.query_proj(query).view(batch_size, query_len, self.num_heads, self.hidden_dim)
  23. key_proj = self.key_proj(context).view(batch_size, context_len, self.num_heads, self.hidden_dim)
  24. value_proj = self.value_proj(context).view(batch_size, context_len, self.num_heads, self.hidden_dim)
  25. # Transpose to get dimensions (batch_size, num_heads, len, hidden_dim)
  26. query_proj = query_proj.permute(0, 2, 1, 3)
  27. key_proj = key_proj.permute(0, 2, 1, 3)
  28. value_proj = value_proj.permute(0, 2, 1, 3)
  29. # Compute attention scores
  30. scores = torch.matmul(query_proj, key_proj.transpose(-2, -1)) / (self.hidden_dim ** 0.5)
  31. attn_weights = F.softmax(scores, dim=-1)
  32. # Compute weighted context
  33. context = torch.matmul(attn_weights, value_proj)
  34. # Concatenate heads and project output
  35. context = context.permute(0, 2, 1, 3).contiguous().view(batch_size, query_len, -1)
  36. output = self.out_proj(context)
  37. return output, attn_weights
  38. # Example usage:
  39. embed_dim = 512
  40. hidden_dim = 64
  41. num_heads = 8
  42. cross_attention = CrossAttention(embed_dim, hidden_dim, num_heads)
  43. # Dummy data
  44. batch_size = 2
  45. query_len = 10
  46. context_len = 20
  47. query = torch.randn(batch_size, query_len, embed_dim)
  48. context = torch.randn(batch_size, context_len, embed_dim)
  49. output, attn_weights = cross_attention(query, context)
  50. print(output.size()) # Should be (batch_size, query_len, embed_dim)
  51. print(attn_weights.size()) # Should be (batch_size, num_heads, query_len, context_len)
  1. 类定义CrossAttention 类继承自 nn.Module,包含初始化函数 __init__ 和前向传播函数 forward
  2. 初始化
    • 定义了一些线性变换层:query_proj, key_proj, 和 value_proj,这些层将嵌入向量转换为多头注意力机制所需的维度。
    • 最终的输出通过 out_proj 再投影回原始的嵌入维度。
  3. 前向传播
    • 输入的 querycontext 分别通过线性变换层,并重新整形以适应多头注意力机制。
    • 计算注意力分数,并通过 softmax 得到注意力权重。
    • 利用注意力权重加权上下文向量,得到新的上下文表示。
    • 最后将多头的结果合并,并通过输出投影层得到最终的输出。
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/weixin_40725706/article/detail/912161
推荐阅读
相关标签
  

闽ICP备14008679号