当前位置:   article > 正文

多头注意力(Multi-Head Attention)和交叉注意力(Cross-Attention)是两种常用的注意力机制的原理及区别

交叉注意力

多头注意力和交叉注意力

多头注意力和交叉注意力都是在自注意力机制的基础上发展而来的,它们的主要区别在于注意力矩阵的计算方式不同。以下是它们的原理和区别。

  1. 多头注意力机制

多头注意力(Multi-Head Attention)是一种基于自注意力机制(self-attention)的改进方法。自注意力是一种能够计算出输入序列中每个位置的权重,因此可以很好地处理序列中长距离依赖关系的问题。但在应用中,可能存在多个不同的关注点,因此就需要多个自注意力机制来处理不同的关注点。多头注意力就是在一个输入序列上使用多个自注意力机制,得到多组注意力结果,然后将这些结果进行拼接和线性投影得到最终输出。

多头注意力的优点是能够处理多个关注点的问题,可以较好地处理复杂语义关系。

多头注意力机制在计算注意力矩阵时,将输入张量 X X X 拆分成 h h h 个子张量,每个子张量都是以不同的方式学习到的注意力信息。然后,对于每个子张量,都执行一次自注意力计算,得到一个输出张量 O i O_i Oi。最后,将 h h h 个输出张量拼接在一起,得到最终的输出张量 O O O

具体地,设 X ∈ R n × d X\in \mathbb{R}^{n\times d} XRn×d 为输入张量, Q ∈ R d × d Q\in \mathbb{R}^{d\times d} QRd×d K ∈ R d × d K\in \mathbb{R}^{d\times d} KRd×d V ∈ R d × d V\in \mathbb{R}^{d\times d} VRd×d 分别为学习到的 d d d 维查询、键和值向量, h h h 为头数,具体来说,假设输入序列为 x 1 , x 2 , . . . , x n x_1,x_2,...,x_n x1,x2,...,xn,则多头注意力的计算如下:

M u l t i H e a d ( X ) = C o n c a t ( head 1 , … , head h ) W O \mathrm{MultiHead}(X)=\mathrm{Concat}(\text{head}_1,\dots,\text{head}_h)W^O MultiHead(X)=Concat(head1,,headh)WO

其中, Q , K , V Q,K,V Q,K,V分别表示输入序列的queries、keys和values, h e a d i head_i headi表示第 i i i个注意力头, W O W^O WO为线性投影的权重参数。每个注意力头都是通过对 Q , K , V Q,K,V Q,K,V进行自注意力计算得到的,计算公式如下:

Attention ( Q , K , V ) = Softmax ( Q K T d k ) V \textrm{Attention}(Q,K,V) = \textrm{Softmax}(\frac{QK^T}{\sqrt{d_k}})V Attention(Q,K,V)=Softmax(dk QKT)V

其中, d k d_k dk Q , K Q,K Q,K的维度大小,用于缩放注意力的大小。

head i = A t t e n t i o n ( X Q i , X K i , X V i ) , i = 1 , … , h \text{head}_i=\mathrm{Attention}(XQ_i,XK_i,XV_i),\qquad i=1,\dots,h headi=Attention(XQi,XKi,XVi),i=1,,h

Q i = X W i Q , K i = X W i K , V i = X W i V , i = 1 , … , h Q_i=XW_i^Q,\qquad K_i=XW_i^K,\qquad V_i=XW_i^V,\qquad i=1,\dots,h Qi=XWiQ,Ki=XWiK,Vi=XWiV,i=1,,h

上式中, W Q ∈ R d × h q W^Q\in \mathbb{R}^{d\times hq} WQRd×hq W K ∈ R d × h k W^K\in \mathbb{R}^{d\times hk} WKRd×hk W V ∈ R d × h v W^V\in \mathbb{R}^{d\times hv} WVRd×hv W O ∈ R h d × d W^O\in \mathbb{R}^{hd\times d} WORhd×d 分别为学习到的投影矩阵, q , k , v q,k,v q,k,v 分别为查询、键、值向量的维度。

  1. 交叉注意力机制

交叉注意力(Cross-Attention)则是在两个不同序列上计算注意力,用于处理两个序列之间的语义关系。例如,在翻译任务中,需要将源语言句子和目标语言句子进行对齐,就需要使用交叉注意力来计算两个句子之间的注意力权重。

交叉注意力机制是一种特殊形式的多头注意力,它将输入张量拆分成两个部分 X 1 ∈ R n × d 1 X_1\in\mathbb{R}^{n\times d_1} X1Rn×d1 X 2 ∈ R n × d 2 X_2\in\mathbb{R}^{n\times d_2} X2Rn×d2,然后将其中一个部分作为查询集合,另一个部分作为键值集合。它的输出是一个大小为 n × d 2 n\times d_2 n×d2 的张量,对于每个行向量,都给出了它对于所有行向量的注意力权重。

具体地,令 Q = X 1 W Q Q=X_1W^Q Q=X1WQ K = V = X 2 W K K=V=X_2W^K K=V=X2WK,则交叉注意力的计算如下:

C r o s s A t t e n t i o n ( X 1 , X 2 ) = S o f t m a x ( Q K T d 2 ) V \mathrm{CrossAttention}(X_1,X_2)=\mathrm{Softmax}\left(\frac{QK^T}{\sqrt{d_2}}\right)V CrossAttention(X1,X2)=Softmax(d2 QKT)V

其中, W Q ∈ R d 1 × d k W^Q\in\mathbb{R}^{d_1\times d_k} WQRd1×dk W K ∈ R d 2 × d k W^K\in\mathbb{R}^{d_2\times d_k} WKRd2×dk 是学习到的投影矩阵, d k d_k dk 为键值集合的维度(也是查询集合的维度)。

  1. PyTorch 代码示例

以下是一个使用pytorch实现多头注意力和交叉注意力的示例代码:

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, in_dim, k_dim, v_dim, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.k_dim = k_dim
        self.v_dim = v_dim
        
        # 定义线性投影层,用于将输入变换到多头注意力空间
        self.proj_q = nn.Linear(in_dim, k_dim * num_heads, bias=False)
        self.proj_k = nn.Linear(in_dim, k_dim * num_heads, bias=False)
        self.proj_v = nn.Linear(in_dim, v_dim * num_heads, bias=False)
		# 定义多头注意力的线性输出层
        self.proj_o = nn.Linear(v_dim * num_heads, in_dim)
        
    def forward(self, x, mask=None):
        batch_size, seq_len, in_dim = x.size()
        # 对输入进行线性投影, 将每个头的查询、键、值进行切分和拼接
        q = self.proj_q(x).view(batch_size, seq_len, self.num_heads, self.k_dim).permute(0, 2, 1, 3)
        k = self.proj_k(x).view(batch_size, seq_len, self.num_heads, self.k_dim).permute(0, 2, 3, 1)
        v = self.proj_v(x).view(batch_size, seq_len, self.num_heads, self.v_dim).permute(0, 2, 1, 3)
        # 计算注意力权重和输出结果
        attn = torch.matmul(q, k) / self.k_dim**0.5   # 注意力得分
        
        if mask is not None:
            attn = attn.masked_fill(mask == 0, -1e9)
        
        attn = F.softmax(attn, dim=-1)   # 注意力权重参数
        output = torch.matmul(attn, v).permute(0, 2, 1, 3).contiguous().view(batch_size, seq_len, -1)   # 输出结果
        # 对多头注意力输出进行线性变换和输出
        output = self.proj_o(output)
        
        return output

class CrossAttention(nn.Module):
    def __init__(self, in_dim1, in_dim2, k_dim, v_dim, num_heads):
        super(CrossAttention, self).__init__()
        self.num_heads = num_heads
        self.k_dim = k_dim
        self.v_dim = v_dim
        
        self.proj_q1 = nn.Linear(in_dim1, k_dim * num_heads, bias=False)
        self.proj_k2 = nn.Linear(in_dim2, k_dim * num_heads, bias=False)
        self.proj_v2 = nn.Linear(in_dim2, v_dim * num_heads, bias=False)
        self.proj_o = nn.Linear(v_dim * num_heads, in_dim1)
        
    def forward(self, x1, x2, mask=None):
        batch_size, seq_len1, in_dim1 = x1.size()
        seq_len2 = x2.size(1)
        
        q1 = self.proj_q1(x1).view(batch_size, seq_len1, self.num_heads, self.k_dim).permute(0, 2, 1, 3)
        k2 = self.proj_k2(x2).view(batch_size, seq_len2, self.num_heads, self.k_dim).permute(0, 2, 3, 1)
        v2 = self.proj_v2(x2).view(batch_size, seq_len2, self.num_heads, self.v_dim).permute(0, 2, 1, 3)
        
        attn = torch.matmul(q1, k2) / self.k_dim**0.5
        
        if mask is not None:
            attn = attn.masked_fill(mask == 0, -1e9)
        
        attn = F.softmax(attn, dim=-1)
        output = torch.matmul(attn, v2).permute(0, 2, 1, 3).contiguous().view(batch_size, seq_len1, -1)
        output = self.proj_o(output)
        
        return output
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67

其中,MultiHeadAttentionCrossAttention 分别实现了多头注意力和交叉注意力,可以通过调用它们的 forward 方法来进行注意力计算。对于MultiHeadAttention,它需要输入一个三维张量x,代表(batch_size, seq_len, in_dim),其中batch_size是批量大小,seq_len是序列长度,in_dim是输入维度。在类的初始化函数中,我们需要设置注意力头数num_heads,以及每个注意力头的k维度k_dim和v维度v_dim。在前向传播函数中,我们首先使用全连接层将x投影到k_dim * num_heads和v_dim * num_heads的空间,再将结果按注意力头和k_dim/v_dim维度进行调整。接着计算注意力得分,若有mask需要进行处理。注意力计算完成后,将得分与v相乘,再将结果按注意力头和v_dim维度进行调整,最终使用全连接层将结果投影会in_dim维度,即可得到多头注意力计算的结果。

对于CrossAttention,它需要输入两个三维张量x1和x2,分别代表(batch_size, seq_len1, in_dim1)和(batch_size, seq_len2, in_dim2),其中seq_len1和seq_len2可以不相同。在类的初始化函数中,我们需要设置注意力头数num_heads,以及每个注意力头的k维度k_dim和v维度v_dim。在前向传播函数中,我们使用全连接层将x1投影到k_dim * num_heads维度,x2投影到k_dim * num_heads和v_dim * num_heads的空间,再将结果按注意力头和k_dim/v_dim维度进行调整。接着计算注意力得分,若有mask需要进行处理。注意力计算完成后,将得分与v2相乘,再将结果按注意力头和v_dim维度进行调整,最终使用全连接层将结果投影会in_dim1维度,即可得到交叉注意力计算的结果。

mask是一个掩码,mask代表的是输入序列中需要被mask掉的部分,用于在注意力计算中屏蔽某些值。在自注意力中,mask通常用于避免当前位置看到后面的位置,以避免信息泄露。在交叉注意力中,除了要避免当前位置看到后面的位置外,还要避免查询向量和键向量在跨输入序列计算注意力时跨越了不同部分。因此,在计算注意力权重时,将不需要计算的位置处的权重设为负无穷,以使其在softmax函数中变为0。在代码中,mask需要被unsqueeze(1)以便进行广播操作。

如果传入了mask参数,会将为0的位置的得分设置为-1e9,使得在softmax之后对应位置的权重为0。例如,如果输入是一个batch_size为2、序列长度为3的矩阵,其中第二个序列的第一个位置和第三个位置是padding,那么对应的mask可以设置为:

mask = torch.tensor([
    [1, 1, 1],
    [1, 0, 0]
], dtype=torch.bool)
  • 1
  • 2
  • 3
  • 4
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Gausst松鼠会/article/detail/531413
推荐阅读
相关标签
  

闽ICP备14008679号