当前位置:   article > 正文

交叉注意力机制CrossAttention

交叉注意力机制

CrossAttention

在Transformer中,CrossAttention实际上是指编码器和解码器之间的交叉注意力层。在这一层中,解码器会对编码器的输出进行注意力调整,以获得与当前解码位置相关的编码器信息。在Transformer的编码器-解码器架构中,编码器负责将输入序列编码为一系列特征向量,而解码器则根据这些特征向量逐步生成输出序列。为了使解码器能够对当前生成位置的上下文进行有效的建模,CrossAttention层被引入其中。

CrossAttention的计算过程:

  • 编码器输入(通常是来自编码器的输出):它们通常被表示为enc_inputs,大小为(batch_size, seq_len_enc, hidden_dim)。
  • 解码器的输入(已生成的部分序列):它们通常被表示为dec_inputs,大小为(batch_size, seq_len_dec, hidden_dim)。
  • 解码器的每个位置会生成一个查询向量(query),用来在编码器的所有位置进行注意力权重计算。
  • 编码器的所有位置会生成一组键向量(keys)和值向量(values)。
  • 使用查询向量(query)和键向量(keys)进行点积操作,并通过softmax函数获得注意力权重。
  • 注意力权重与值向量相乘,并对结果进行求和,得到编码器调整的输出。

image.png

 

torch.matmul

参数情况:torch.matmul(input, other, *, out=None) → Tensor

  • input (张量) – 第一个要乘法的张量
  • other(张量)– 要乘法的第二个张量

例子:

  1. tensor1 = torch.randn(10, 3, 4)
  2. tensor2 = torch.randn(4, 5)
  3. torch.matmul(tensor1, tensor2).size()
  4. torch.Size([10, 3, 5])

代码案例

为了方便理解,此代码只是定义了一个简单的带有线性映射的注意力模型,并没有完整地实现Transformer中的CrossAttention层。如果您想实现Transformer的CrossAttention层,请参考Transformer的详细实现代码或使用现有的Transformer库(如torch.nn.Transformer)来构建模型。

  1. import torch
  2. import torch.nn as nn
  3. class CrossAttention(nn.Module):
  4. def __init__(self, input_dim_a, input_dim_b, hidden_dim):
  5. super(CrossAttention, self).__init__()
  6. self.linear_a = nn.Linear(input_dim_a, hidden_dim)
  7. self.linear_b = nn.Linear(input_dim_b, hidden_dim)
  8. def forward(self, input_a, input_b):
  9. # 线性映射
  10. mapped_a = self.linear_a(input_a) # (batch_size, seq_len_a, hidden_dim)
  11. mapped_b = self.linear_b(input_b) # (batch_size, seq_len_b, hidden_dim)
  12. y = mapped_b.transpose(1, 2)
  13. # 计算注意力权重
  14. scores = torch.matmul(mapped_a, mapped_b.transpose(1, 2)) # (batch_size, seq_len_a, seq_len_b)
  15. attentions_a = torch.softmax(scores, dim=-1) # 在维度2上进行softmax,归一化为注意力权重 (batch_size, seq_len_a, seq_len_b)
  16. attentions_b = torch.softmax(scores.transpose(1, 2), dim=-1) # 在维度1上进行softmax,归一化为注意力权重 (batch_size, seq_len_b, seq_len_a)
  17. # 使用注意力权重来调整输入表示
  18. output_a = torch.matmul(attentions_b, input_b) # (batch_size, seq_len_a, input_dim_b)
  19. output_b = torch.matmul(attentions_a.transpose(1, 2), input_a) # (batch_size, seq_len_b, input_dim_a)
  20. return output_a, output_b
  21. # 准备数据
  22. input_a = torch.randn(16, 36, 192) # 输入序列A,大小为(batch_size, seq_len_a, input_dim_a)
  23. input_b = torch.randn(16, 192, 36) # 输入序列B,大小为(batch_size, seq_len_b, input_dim_b)
  24. # 定义模型
  25. input_dim_a = input_a.shape[-1]
  26. input_dim_b = input_b.shape[-1]
  27. hidden_dim = 64
  28. cross_attention = CrossAttention(input_dim_a, input_dim_b, hidden_dim)
  29. # 前向传播
  30. output_a, output_b = cross_attention(input_a, input_b)
  31. print("Adjusted output A:\n", output_a)
  32. print("Adjusted output B:\n", output_b)

CrossAttention模块输入的要求

编码器输入:

  • 形状:(batch_size, seq_len_enc, hidden_dim)
  • batch_size:批量大小
  • seq_len_enc:编码器输入序列的长度
  • hidden_dim:编码器的隐藏维度或特征维度

解码器输入:

  • 形状:(batch_size, seq_len_dec, hidden_dim)
  • batch_size:批量大小
  • seq_len_dec:解码器输入序列的长度
  • hidden_dim:解码器的隐藏维度或特征维度

对于案例中来说,编码器和解码器的输入维度并不需要完全相同,CrossAttention输入参数是一个三元组(input_a, input_b, hidden_dim),其中input_a表示编码器的输入,input_b表示解码器的输入,hidden_dim表示隐藏维度。

对于input_a和input_b的形状,它们可以有一定的差异,只要满足以下条件之一即可:

1、当input_a和input_b形状不同但维度相同(hidden_dim相同)时,可以通过一些线性变换将它们映射到相同的维度。

2、当input_a和input_b形状不同且维度也不同时,可以通过不同的注意力权重矩阵来分别对它们进行映射和计算注意力。

而Encoder-Decoder架构中CrossAttention的输入要求略有不同。具体而言,Encoder中的输入(input_a)形状通常是(batch_size, seq_len_enc, hidden_dim),而Decoder中的输入(input_b)形状通常是(batch_size, seq_len_dec, hidden_dim),其中seq_len_enc和seq_len_dec可以是不同的。

两篇涉及Cross-Attention的论文

论文链接地址: Cross-Attention is All You Need: Adapting Pretrained Transformers for Machine Translation 

代码地址: GitHub - MGheini/xattn-transfer-for-mt: Code and data to accompany the camera-ready version of "Cross-Attention is All You Need: Adapting Pretrained Transformers for Machine Translation" in EMNLP 2021

论文链接地址:CAT: Cross Attention in Vision Transformer 

代码地址:https://github.com/linhezheng19/CAT.

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小丑西瓜9/article/detail/357303
推荐阅读
相关标签
  

闽ICP备14008679号