当前位置:   article > 正文

深度学习中常用的注意力模块及其原理和作用

注意力模块

以下是深度学习中常用的注意力模块及其原理和作用,以及相应的PyTorch代码示例。

1. Scaled Dot-Product Attention

Scaled Dot-Product Attention 是注意力机制的一种变体,常用于 Seq2Seq 模型和 Transformer 模型中。它通过计算 Query 和 Key 的内积,再除以一个 scaling factor 得到 Attention 分数,最后将 Attention 分数作为权重对 Value 做加权求和,来计算 Attention 输出。

Scaled Dot-Product Attention 的公式如下:

Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})V Attention(Q,K,V)=softmax(dk QKT)V

其中 Q , K , V Q, K, V Q,K,V 分别表示 Query, Key, Value, d k d_k dk 表示 Key 的维度,softmax 函数对每个 Query 计算一个 Attention Distribution。

PyTorch 实现代码示例:

import torch.nn.functional as F

class ScaledDotProductAttention(nn.Module):
    def __init__(self, dk):
        super(ScaledDotProductAttention, self).__init__()
        self.dk = dk

    def forward(self, Q, K, V):
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.dk ** 0.5)
        attn = F.softmax(scores, dim=-1)
        output = torch.matmul(attn, V)
        return output

# 使用 Scaled Dot-Product Attention
q = torch.randn(2, 3, 4)  # shape: [batch_size, query_len, hidden_size]
k = torch.randn(2, 5, 4)  # shape: [batch_size, key_len, hidden_size]
v = torch.randn(2, 5, 6)  # shape: [batch_size, key_len, value_size]
attention = ScaledDotProductAttention(dk=4)
output = attention(q, k, v)  # shape: [batch_size, query_len, value_size]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19

2. Multi-Head Attention

Multi-Head Attention 是一种将 Scaled Dot-Product Attention 扩展到多头的方法,它将 Query, Key, Value 分别经过多个线性变换(称为“头”)后再输入到 Scaled Dot-Product Attention 中计算,最后将多个 Attention 输出按照通道维度拼接起来。Multi-Head Attention 可以学习多种不同的表示,来提升模型的表现能力。

Multi-Head Attention 的公式如下:

MultiHead ( Q , K , V ) = Concat ( h e a d 1 , . . . , h e a d h ) W O \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,...,head_h)W^O MultiHead(Q,K,V)=Concat(head1,...,headh)WO

其中 h e a d i = Attention ( Q W i Q , K W i K , V W i V ) head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) headi=Attention(QWiQ,KWiK,VWiV) 表示第 i i i 个头, W O W^O WO 表示最终输出的线性变换。

PyTorch 实现代码示例:

class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_size, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.dk = hidden_size // num_heads

        # 定义 W^Q, W^K, W^V 矩阵
        self.Wq = nn.Linear(hidden_size, hidden_size)
        self.Wk = nn.Linear(hidden_size, hidden_size)
        self.Wv = nn.Linear(hidden_size, hidden_size)

        # 定义输出矩阵 W^O
        self.Wo = nn.Linear(hidden_size, hidden_size)

    def forward(self, Q, K, V):
        # 将 Query, Key, Value 分别经过 W^Q, W^K, W^V 线性变换
        Q = self.Wq(Q)
        K = self.Wk(K)
        V = self.Wv(V)

        # 将多个头拼接在一起
        Q = Q.view(Q.shape[0], Q.shape[1], self.num_heads, self.dk).transpose(1, 2)  # [batch_size, num_heads, query_len, dk]
        K = K.view(K.shape[0], K.shape[1], self.num_heads, self.dk).transpose(1, 2)  # [batch_size, num_heads, key_len, dk]
        V = V.view(V.shape[0], V.shape[1], self.num_heads, self.dk).transpose(1, 2)  # [batch_size, num_heads, key_len, dk]

        # 使用 Scaled Dot-Product Attention 计算 Attention 输出
        attn = ScaledDotProductAttention(self.dk)
        output = attn(Q, K, V)  # [batch_size, num_heads, query_len, dk]

        # 将多个头拼接在一起
        output = output.transpose(1, 2).reshape(output.shape[0], output.shape[2], -1)  # [batch_size, query_len, num_heads * dk]

        # 经过输出矩阵 W^O 线性变换
        output = self.Wo(output)  # [batch_size, query_len, hidden_size]
        return output

# 使用 Multi-Head Attention
q = torch.randn(2, 3, 4)  # shape: [batch_size, query_len, hidden_size]
k = torch.randn(2, 5, 4)  # shape: [batch_size, key_len, hidden_size]
v = torch.randn(2, 5, 6)  # shape: [batch_size, key_len, value_size]
attention = MultiHeadAttention(hidden_size=12, num_heads=2)
output = attention(q, k, v)  # shape: [batch_size, query_len, hidden_size]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43

3. Self-Attention

Self-Attention 是一种只包含一个输入序列(无需 Key/Value)的注意力模型,它通过将输入序列映射到 Query, Key, Value 向量后,计算向量之间的 Attention 分数,最后将 Attention 输出加权求和得到最终输出。Self-Attention 被广泛应用于自然语言处理任务中,如机器翻译、语言模型等。

Self-Attention 的公式如下:

SelfAttention ( X ) = softmax ( X W Q ( X W K ) T d k ) ( X W V ) \text{SelfAttention}(X) = \text{softmax}(\frac{XW_Q(XW_K)^T}{\sqrt{d_k}})(XW_V) SelfAttention(X)=softmax(dk XWQ(XWK)T)(XWV)

其中 X X X 表示输入序列, W Q , W K , W V W_Q, W_K, W_V WQ,WK,WV 分别表示映射到 Query, Key, Value 的线性变换矩阵, d k d_k dk 表示 Key 的维度,softmax 函数对每个 Query 计算一个 Attention Distribution。

PyTorch 实现代码示例:

class SelfAttention(nn.Module):
    def __init__(self, hidden_size):
        super(SelfAttention, self).__init__()
        self.hidden_size = hidden_size
        self.dk = hidden_size

        # 定义 W^Q, W^K, W^V 矩阵
        self.Wq = nn.Linear(hidden_size, hidden_size)
        self.Wk = nn.Linear(hidden_size, hidden_size)
        self.Wv = nn.Linear(hidden_size, hidden_size)

    def forward(self, X):
        # 将输入序列 X 分别经过 W^Q, W^K, W^V 线性变换
        Q = self.Wq(X)
        K = self.Wk(X)
        V = self.Wv(X)

        # 使用 Scaled Dot-Product Attention 计算 Self-Attention 输出
        attn = ScaledDotProductAttention(self.dk)
        output = attn(Q, K, V)

        return output

# 使用 Self-Attention
x = torch.randn(2, 3, 4)  # shape: [batch_size, seq_len, hidden_size]
attention = SelfAttention(hidden_size=4)
output = attention(x)  # shape: [batch_size, seq_len, hidden_size]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27

4. Relative Positional Encoding

Relative Positional Encoding 是一种用于自然语言处理任务的注意力机制,它考虑了词语之间的相对位置信息,通过加入位置编码矩阵来改善传统的位置编码方法。相对位置编码矩阵可以在 Transformers 中直接应用于 Self-Attention 和 Multi-Head Attention 等模块中,来提升模型的泛化性能。

Relative Positional Encoding 的公式如下:

Position ( p o s , 2 i ) = sin ⁡ ( p o s 1000 0 2 i / d ) \text{Position}(pos, 2i) = \sin(\frac{pos}{10000^{2i/d}}) Position(pos,2i)=sin(100002i/dpos)

Position ( p o s , 2 i + 1 ) = cos ⁡ ( p o s 1000 0 2 i / d ) \text{Position}(pos, 2i+1) = \cos(\frac{pos}{10000^{2i/d}}) Position(pos,2i+1)=cos(100002i/dpos)

其中 p o s pos pos 表示位置索引, d d d 表示 Embedding 的维度, i i i 表示位置编码矩阵中的第 i i i 维, Position ( p o s , i ) \text{Position}(pos, i) Position(pos,i) 表示位置编码矩阵中索引为 ( p o s , i ) (pos, i) (pos,i) 的值。

PyTorch 实现代码示例:

class RelativePositionalEncoding(nn.Module):
    def __init__(self, max_position, hidden_size):
        super(RelativePositionalEncoding, self).__init__()
        self.hidden_size = hidden_size

        # 定义位置编码矩阵
        self.position_encoding = nn.Parameter(torch.zeros(2 * max_position - 1, hidden_size))
        nn.init.normal_(self.position_encoding, mean=0, std=hidden_size ** -0.5)

    def forward(self, q, k):
        # 计算 Query 和 Key 的相对位置
        pos = torch.arange(q.size(-2), device=q.device).unsqueeze(-1) - torch.arange(k.size(-2), device=q.device)

        # 根据相对位置从位置编码矩阵中获取相应的值
        pos_enc = self.position_encoding[self.position_index(pos)].unsqueeze(0)

        # 将位置编码加到 Query, Key 上
        q = q + pos_enc[:, :, :q.size(-2)]
        k = k + pos_enc[:, :, :k.size(-2)]

        return q, k

    def position_index(self, pos):
        # 将相对位置索引映射到位置编码矩阵中的索引
        position_index = pos + self.position_encoding.size(0) // 2
        return torch.clamp(position_index, 0, self.position_encoding.size(0) - 1)

# 使用 Relative Positional Encoding
q = torch.randn(2, 3, 4)  # shape: [batch_size, query_len, hidden_size]
k = torch.randn(2, 5, 4)  # shape: [batch_size, key_len, hidden_size]
pos_enc = RelativePositionalEncoding(max_position=5, hidden_size=4)
q, k = pos_enc(q, k)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32

5. Multi-Head Self-Attention with Relative Positional Encoding

结合了 Multi-Head Attention 和 Relative Positional Encoding 的模型常用于自然语言处理任务中,如 Transformer 模型。它在 Multi-Head Self-Attention 中加入相对位置编码矩阵,以考虑词语之间的相对位置关系,从而提高模型性能。

Multi-Head Self-Attention with Relative Positional Encoding 的代码示例:

class MultiHeadSelfAttentionWithRPE(nn.Module):
    def __init__(self, hidden_size, num_heads, max_position):
        super(MultiHeadSelfAttentionWithRPE, self).__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.dk = hidden_size // num_heads

        # Multi-Head Attention
        self.Wq = nn.Linear(hidden_size, hidden_size)
        self.Wk = nn.Linear(hidden_size, hidden_size)
        self.Wv = nn.Linear(hidden_size, hidden_size)
        self.Wo = nn.Linear(hidden_size, hidden_size)

        # RPE
        self.pos_enc = RelativePositionalEncoding(max_position, hidden_size)

    def forward(self, X):
        # Multi-Head Attention
        Q = self.Wq(X)
        K = self.Wk(X)
        V = self.Wv(X)
        Q = Q.view(Q.shape[0], Q.shape[1], self.num_heads, self.dk).transpose(1, 2)  # [batch_size, num_heads, seq_len, dk]
        K = K.view(K.shape[0], K.shape[1], self.num_heads, self.dk).transpose(1, 2)  # [batch_size, num_heads, seq_len, dk]
        V = V.view(V.shape[0], V.shape[1], self.num_heads, self.dk).transpose(1, 2)  # [batch_size, num_heads, seq_len, dk]
        attn = ScaledDotProductAttention(self.dk)
        output = attn(Q, K, V)  # [batch_size, num_heads, seq_len, dk]

        # 将多个头拼接在一起
        output = output.transpose(1, 2).reshape(output.shape[0], output.shape[2], -1)  # [batch_size, seq_len, hidden_size]

        # RPE
        output = self.pos_enc(output, output)

        # 输出层
        output = self.Wo(output)

        return output

# 使用 Multi-Head Self-Attention with RPE
x = torch.randn(2, 3, 4)  # shape: [batch_size, seq_len, hidden_size]
attention = MultiHeadSelfAttentionWithRPE(hidden_size=12, num_heads=2, max_position=5)
output = attention(x)  # shape: [batch_size, seq_len, hidden_size]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42

6. Cross Attention

交叉注意力(Cross Attention)是指在两个不同的输入序列之间进行注意力计算的过程。它的原理是通过计算两个不同输入序列之间的相似度,让模型能够更好地关注和利用不同输入序列之间的信息,从而提高模型的性能。

在编程实现上,我们可以使用 PyTorch 中的 torch.nn.MultiheadAttention 来实现交叉注意力。下面是一个示例代码:

import torch
import torch.nn as nn

class CrossAttention(nn.Module):
    def __init__(self, hidden_size):
        super(CrossAttention, self).__init__()
        self.hidden_size = hidden_size
        self.multihead_attn = nn.MultiheadAttention(hidden_size, num_heads=8)
        
    def forward(self, input1, input2):
        # input1: [seq_len1, batch_size, hidden_size]
        # input2: [seq_len2, batch_size, hidden_size]
        # output: [seq_len1, batch_size, hidden_size]
        
        # 将 seq_len1 和 seq_len2 维度上合并,同时将 batch_size 维度放在第二维
        combined = torch.cat([input1, input2], dim=0).transpose(0, 1)  # [batch_size, seq_len1+seq_len2, hidden_size]
        attn_output, _ = self.multihead_attn(combined, combined, combined)  # [batch_size, seq_len1+seq_len2, hidden_size]
        
        # 将 seq_len1 和 seq_len2 维度上切分开,并将 seq_len1 放回原来的位置
        attn_output = attn_output.transpose(0, 1)  # [seq_len1+seq_len2, batch_size, hidden_size]
        output1 = attn_output[:input1.size(0), :, :]  # [seq_len1, batch_size, hidden_size]
        
        return output1
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23

以上代码实现了一个 CrossAttention 类,其中包含一个 MultiheadAttention 模块,该模块实现了注意力计算的过程。在 forward 函数中,我们首先将输入序列按照 batch_size 和 hidden_size 维度进行合并,然后将合并后的序列作为 Q、K、V 三个输入传入 MultiheadAttention 模块中进行计算,最后将计算结果按照原来的序列长度进行切分,并返回原输入序列 1 的注意力输出结果。

如果要在模型中使用 CrossAttention,只需要将上述代码加入到网络模型的 forward 函数中即可。

7. CBAM Module

CBAM模块是一种包含通道和空间注意力机制的注意力模块,可以在通道维度和空间维度分别对特征进行加权融合,以提高模型的表达能力。

在CBAM模块中,首先通过基于通道信息的全局最大池化和全局平均池化计算通道注意力系数,然后对通道信息进行加权融合。接着,将加权后的通道信息进行基于空间信息的自注意力计算,得到最终的特征表示。

在PyTorch中,可以通过nn.Sequential和nn.AdaptiveAvgPool2d来实现。

PyTorch代码示例:

import torch.nn as nn

# 定义一个包含CBAM模块的卷积层
class CBAMBlock(nn.Module):
    def __init__(self, in_channels, reduction=16):
        super(CBAMBlock, self).__init__()
        self.in_channels = in_channels
        self.reduction = reduction

        # 通道注意力计算
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels // reduction, kernel_size=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // reduction, in_channels, kernel_size=1, padding=0),
            nn.Sigmoid()
        )

        # 空间注意力计算
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // reduction, kernel_size=1, stride=1),
            nn.BatchNorm2d(in_channels // reduction),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // reduction, in_channels // reduction, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_channels // reduction),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // reduction, 1, kernel_size=1, stride=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        # 计算通道注意力系数
        channel_att = self.channel_attention(x)
        out = x * channel_att

        # 计算空间注意力系数
        spatial_att = self.spatial_attention(out)
        out = out * spatial_att        
        return out
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/代码探险家/article/detail/757396
推荐阅读
相关标签
  

闽ICP备14008679号