赞
踩
交叉注意力(Cross-Attention)则是在两个不同序列上计算注意力,用于处理两个序列之间的语义关系。在两个不同的输入序列之间计算关联度和加权求和的机制。具体来说,给定两个输入序列,cross attention机制将一个序列中的每个元素与另一个序列中的所有元素计算关联度,并根据关联度对两个序列中的每个元素进行加权求和。这样的机制使模型能够建立不同序列之间的关联关系,并将两个序列的信息融合起来。例如,在翻译任务中,需要将源语言句子和目标语言句子进行对齐,就需要使用交叉注意力来计算两个句子之间的注意力权重。
交叉注意力机制是一种特殊形式的多头注意力,它将输入张量拆分成两个部分 和 ,然后将其中一个部分作为查询集合,另一个部分作为键值集合。它的输出是一个大小为 的张量,对于每个行向量,都给出了它对于所有行向量的注意力权重。
令 和 ,则交叉注意力的计算如下:
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
- class CrossAttention(nn.Module):
- def __init__(self, embed_dim, hidden_dim, num_heads):
- super(CrossAttention, self).__init__()
- self.embed_dim = embed_dim
- self.hidden_dim = hidden_dim
- self.num_heads = num_heads
-
- self.query_proj = nn.Linear(embed_dim, hidden_dim * num_heads)
- self.key_proj = nn.Linear(embed_dim, hidden_dim * num_heads)
- self.value_proj = nn.Linear(embed_dim, hidden_dim * num_heads)
-
- self.out_proj = nn.Linear(hidden_dim * num_heads, embed_dim)
-
- def forward(self, query, context):
- """
- query: (batch_size, query_len, embed_dim)
- context: (batch_size, context_len, embed_dim)
- """
- batch_size, query_len, _ = query.size()
- context_len = context.size(1)
-
- # Project input embeddings
- query_proj = self.query_proj(query).view(batch_size, query_len, self.num_heads, self.hidden_dim)
- key_proj = self.key_proj(context).view(batch_size, context_len, self.num_heads, self.hidden_dim)
- value_proj = self.value_proj(context).view(batch_size, context_len, self.num_heads, self.hidden_dim)
-
- # Transpose to get dimensions (batch_size, num_heads, len, hidden_dim)
- query_proj = query_proj.permute(0, 2, 1, 3)
- key_proj = key_proj.permute(0, 2, 1, 3)
- value_proj = value_proj.permute(0, 2, 1, 3)
-
- # Compute attention scores
- scores = torch.matmul(query_proj, key_proj.transpose(-2, -1)) / (self.hidden_dim ** 0.5)
- attn_weights = F.softmax(scores, dim=-1)
-
- # Compute weighted context
- context = torch.matmul(attn_weights, value_proj)
-
- # Concatenate heads and project output
- context = context.permute(0, 2, 1, 3).contiguous().view(batch_size, query_len, -1)
- output = self.out_proj(context)
-
- return output, attn_weights
-
- # Example usage:
- embed_dim = 512
- hidden_dim = 64
- num_heads = 8
-
- cross_attention = CrossAttention(embed_dim, hidden_dim, num_heads)
-
- # Dummy data
- batch_size = 2
- query_len = 10
- context_len = 20
-
- query = torch.randn(batch_size, query_len, embed_dim)
- context = torch.randn(batch_size, context_len, embed_dim)
-
- output, attn_weights = cross_attention(query, context)
- print(output.size()) # Should be (batch_size, query_len, embed_dim)
- print(attn_weights.size()) # Should be (batch_size, num_heads, query_len, context_len)
CrossAttention
类继承自 nn.Module
,包含初始化函数 __init__
和前向传播函数 forward
。query_proj
, key_proj
, 和 value_proj
,这些层将嵌入向量转换为多头注意力机制所需的维度。out_proj
再投影回原始的嵌入维度。query
和 context
分别通过线性变换层,并重新整形以适应多头注意力机制。Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。