当前位置:   article > 正文

【深度学习】注意力机制的改进:稀疏注意力、局部注意力、低秩/线性注意力_深度学习中精力分散度

深度学习中精力分散度

稀疏注意力

稀疏注意力(Sparse Attention)是一种通过选择性地处理部分token来减少整体计算负荷的方法。这在自然语言处理和计算机视觉中的注意力机制中尤为重要,因为它可以显著降低计算复杂度和内存使用。

在标准的全连接注意力机制中,每个token(词或图像patch)都与其他所有token计算注意力权重,这会导致计算复杂度为 O ( N 2 ) O(N^2) O(N2),其中 N N N 是token的数量。这种全连接的计算在处理长序列或高分辨率图像时会非常耗时且内存消耗巨大。稀疏注意力则通过只计算部分token之间的注意力权重,从而将复杂度降低到 O ( N log ⁡ N ) O(N \log N) O(NlogN) O ( N ) O(N) O(N)

PVT v2中的稀疏注意力

PVT v2(Pyramid Vision Transformer v2)是一种改进的视觉Transformer模型,它通过使用卷积核来压缩key和value的空间,从而降低计算注意力的复杂性。这意味着在计算注意力权重时,模型不需要考虑所有token,而只考虑压缩后的key和value。

公式

标准注意力机制计算如下:

Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V

其中:

  • ( Q ) 是查询矩阵(query)
  • ( K ) 是键矩阵(key)
  • ( V ) 是值矩阵(value)
  • ( d_k ) 是键的维度

在PVT v2中,通过使用卷积核对key和value进行空间压缩,这一过程可以表示为:

K ′ = Conv ( K ) K' = \text{Conv}(K) K=Conv(K)

V ′ = Conv ( V ) V' = \text{Conv}(V) V=Conv(V)

其中 (\text{Conv}) 表示卷积操作。这样,计算注意力时使用的是压缩后的key和value:

SparseAttention ( Q , K ′ , V ′ ) = softmax ( Q K ′ T d k ) V ′ \text{SparseAttention}(Q, K', V') = \text{softmax}\left(\frac{QK'^T}{\sqrt{d_k}}\right)V' SparseAttention(Q,K,V)=softmax(dk QKT)V

代码示例

下面是一个使用PyTorch实现的简化版本的稀疏注意力机制:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SparseAttention(nn.Module):
    def __init__(self, d_model, d_k, kernel_size=3, stride=1, padding=1):
        super(SparseAttention, self).__init__()
        self.query_conv = nn.Linear(d_model, d_k)
        self.key_conv = nn.Conv2d(d_model, d_k, kernel_size=kernel_size, stride=stride, padding=padding)
        self.value_conv = nn.Conv2d(d_model, d_k, kernel_size=kernel_size, stride=stride, padding=padding)
    
    def forward(self, q, k, v):
        q = self.query_conv(q)  # (batch_size, seq_len, d_k)
        
        # Assuming k and v are of shape (batch_size, d_model, height, width)
        k = self.key_conv(k)  # (batch_size, d_k, new_height, new_width)
        v = self.value_conv(v)  # (batch_size, d_k, new_height, new_width)
        
        # Flatten spatial dimensions
        k = k.flatten(2)  # (batch_size, d_k, new_height * new_width)
        v = v.flatten(2)  # (batch_size, d_k, new_height * new_width)
        
        # Compute attention weights
        attn_weights = F.softmax(torch.bmm(q, k.transpose(1, 2)) / (k.size(1) ** 0.5), dim=-1)  # (batch_size, seq_len, new_height * new_width)
        
        # Compute attention output
        attn_output = torch.bmm(attn_weights, v.transpose(1, 2))  # (batch_size, seq_len, d_k)
        
        return attn_output

# Example usage
batch_size = 2
seq_len = 10
d_model = 64
d_k = 32
height = width = 16

q = torch.randn(batch_size, seq_len, d_model)
k = torch.randn(batch_size, d_model, height, width)
v = torch.randn(batch_size, d_model, height, width)

sparse_attention = SparseAttention(d_model, d_k)
output = sparse_attention(q, k, v)
print(output.shape)  # (batch_size, seq_len, d_k)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44

这个示例展示了如何使用卷积操作压缩key和value,并在稀疏注意力机制中计算注意力权重和输出。通过这种方法,可以显著减少计算复杂度,提高模型的效率。

局部注意力

局部注意力(Local Attention)是一种通过将注意力集中在输入序列或图像的局部区域上来减少计算负荷的方法。这种方法通过限制每个token仅与其附近的token计算注意力,从而降低了计算复杂度。

局部注意力

局部注意力机制的主要思想是将输入划分为若干个固定大小的窗口,然后在每个窗口内独立计算注意力权重。这种方法适用于处理长序列或高分辨率图像时,因为它能够显著减少计算量,同时保留足够的上下文信息。

Swin Transformer中的基于窗口的注意力

Swin Transformer(Shifted Window Transformer)是一种基于局部注意力的Transformer模型,它通过引入基于窗口的注意力机制,将计算限制在指定的窗口大小内进行。具体来说,Swin Transformer将输入图像划分为多个非重叠的窗口,然后在每个窗口内独立计算注意力。为了进一步增强模型的全局上下文捕捉能力,Swin Transformer还引入了窗口的移位操作。

公式

标准的全连接注意力计算公式为:

Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V

在Swin Transformer中,输入图像被划分为多个窗口,每个窗口内计算局部注意力:

LocalAttention ( Q w , K w , V w ) = softmax ( Q w K w T d k ) V w \text{LocalAttention}(Q_w, K_w, V_w) = \text{softmax}\left(\frac{Q_w K_w^T}{\sqrt{d_k}}\right)V_w LocalAttention(Qw,Kw,Vw)=softmax(dk QwKwT)Vw

其中 (Q_w), (K_w), 和 (V_w) 分别是窗口内的查询、键和值矩阵。

为了捕捉全局信息,Swin Transformer引入了窗口移位操作(Shifted Window)。通过在每个注意力计算层之间对窗口进行移位,可以实现窗口之间的信息交互。

代码示例

下面是一个使用PyTorch实现的简化版基于窗口的局部注意力机制:

import torch
import torch.nn as nn
import torch.nn.functional as F

class WindowAttention(nn.Module):
    def __init__(self, d_model, window_size):
        super(WindowAttention, self).__init__()
        self.window_size = window_size
        self.query_conv = nn.Linear(d_model, d_model)
        self.key_conv = nn.Linear(d_model, d_model)
        self.value_conv = nn.Linear(d_model, d_model)
        self.proj = nn.Linear(d_model, d_model)
    
    def forward(self, x):
        B, N, C = x.shape
        x = x.view(B, N, C)
        windows = self.window_partition(x)
        
        # Apply linear transformations
        Q = self.query_conv(windows)
        K = self.key_conv(windows)
        V = self.value_conv(windows)
        
        # Calculate attention within each window
        attn_weights = F.softmax(torch.matmul(Q, K.transpose(-2, -1)) / (C ** 0.5), dim=-1)
        attn_output = torch.matmul(attn_weights, V)
        
        # Concatenate windows back to original shape
        attn_output = self.window_reverse(attn_output, B, N)
        return self.proj(attn_output)
    
    def window_partition(self, x):
        B, N, C = x.shape
        window_size = self.window_size
        x = x.view(B, int(N**0.5), int(N**0.5), C)  # Assume input is a square image
        windows = x.unfold(1, window_size, window_size).unfold(2, window_size, window_size)
        windows = windows.contiguous().view(-1, window_size * window_size, C)
        return windows
    
    def window_reverse(self, windows, B, N):
        window_size = self.window_size
        C = windows.shape[-1]
        windows = windows.view(B, int(N**0.5) // window_size, int(N**0.5) // window_size, window_size, window_size, C)
        x = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, N, C)
        return x

# Example usage
batch_size = 2
seq_len = 16
d_model = 64
window_size = 2

x = torch.randn(batch_size, seq_len, d_model)
window_attention = WindowAttention(d_model, window_size)
output = window_attention(x)
print(output.shape)  # (batch_size, seq_len, d_model)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56

这个代码实现了一个简化的基于窗口的局部注意力机制。在这里,我们假设输入是一个形状为 (B \times N \times C) 的张量,其中 (B) 是批次大小,(N) 是序列长度(例如图像被展平后的长度),(C) 是每个token的维度。窗口大小由 window_size 参数决定。通过这种方法,可以有效减少计算复杂度,同时保持较好的局部上下文信息。

窗口移位操作(Shifted Window)

窗口移位操作是Swin Transformer中的一个关键技术,它通过在注意力计算层之间对窗口进行移位,实现窗口之间的信息交互,从而增强模型的全局上下文捕捉能力。

在标准的窗口注意力机制中,每个窗口内的token只与同一窗口内的其他token进行注意力计算。这虽然减少了计算复杂度,但限制了跨窗口的信息交流。为了克服这一限制,Swin Transformer引入了窗口移位操作。在每个注意力计算层之间,对窗口进行移位,使得原本不在同一个窗口内的token可以在后续的注意力计算中进行交互。

实现细节

假设输入图像被划分为大小为 ( M \times M ) 的窗口,在第一个注意力计算层中,窗口是不重叠的。在第二个注意力计算层之前,对窗口进行移位(例如,水平和垂直方向各移位 ( M/2 ) 个像素),然后再进行注意力计算。

公式

标准的基于窗口的注意力计算公式为:

LocalAttention ( Q w , K w , V w ) = softmax ( Q w K w T d k ) V w \text{LocalAttention}(Q_w, K_w, V_w) = \text{softmax}\left(\frac{Q_w K_w^T}{\sqrt{d_k}}\right)V_w LocalAttention(Qw,Kw,Vw)=softmax(dk QwKwT)Vw

其中 (Q_w), (K_w), 和 (V_w) 分别是窗口内的查询、键和值矩阵。

在引入窗口移位操作后,计算公式保持不变,但窗口的划分方式在每层之间会有所不同。

代码示例

下面是一个包含窗口移位操作的局部注意力机制的PyTorch实现:

import torch
import torch.nn as nn
import torch.nn.functional as F

class WindowAttention(nn.Module):
    def __init__(self, d_model, window_size):
        super(WindowAttention, self).__init__()
        self.window_size = window_size
        self.query_conv = nn.Linear(d_model, d_model)
        self.key_conv = nn.Linear(d_model, d_model)
        self.value_conv = nn.Linear(d_model, d_model)
        self.proj = nn.Linear(d_model, d_model)
    
    def forward(self, x, shift_size=0):
        B, N, C = x.shape
        x = x.view(B, int(N**0.5), int(N**0.5), C)  # Assume input is a square image
        
        if shift_size > 0:
            x = self.shift_window(x, shift_size)  # Shift the window
        
        windows = self.window_partition(x)
        
        # Apply linear transformations
        Q = self.query_conv(windows)
        K = self.key_conv(windows)
        V = self.value_conv(windows)
        
        # Calculate attention within each window
        attn_weights = F.softmax(torch.matmul(Q, K.transpose(-2, -1)) / (C ** 0.5), dim=-1)
        attn_output = torch.matmul(attn_weights, V)
        
        # Concatenate windows back to original shape
        attn_output = self.window_reverse(attn_output, B, N)
        
        if shift_size > 0:
            attn_output = self.reverse_shift_window(attn_output, shift_size)  # Reverse the shift
        
        return self.proj(attn_output)
    
    def window_partition(self, x):
        B, H, W, C = x.shape
        window_size = self.window_size
        x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
        windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size * window_size, C)
        return windows
    
    def window_reverse(self, windows, B, N):
        window_size = self.window_size
        C = windows.shape[-1]
        x = windows.view(B, int(N**0.5) // window_size, int(N**0.5) // window_size, window_size, window_size, C)
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, N, C)
        return x
    
    def shift_window(self, x, shift_size):
        B, H, W, C = x.shape
        x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))
        return x
    
    def reverse_shift_window(self, x, shift_size):
        B, H, W, C = x.shape
        x = torch.roll(x, shifts=(shift_size, shift_size), dims=(1, 2))
        return x

# Example usage
batch_size = 2
seq_len = 16
d_model = 64
window_size = 2

x = torch.randn(batch_size, seq_len, d_model)
window_attention = WindowAttention(d_model, window_size)

# Without shift
output_no_shift = window_attention(x)
print(output_no_shift.shape)  # (batch_size, seq_len, d_model)

# With shift
output_shift = window_attention(x, shift_size=1)
print(output_shift.shape)  # (batch_size, seq_len, d_model)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79

这个代码实现了一个简化的基于窗口的局部注意力机制,并且包含了窗口移位操作。窗口移位操作通过 shift_windowreverse_shift_window 方法实现。在注意力计算前进行窗口移位,注意力计算后再将窗口移回原位置。这样可以有效地增强跨窗口的信息交互能力,从而提升模型的全局上下文捕捉能力。

低秩/线性注意力

低秩/线性注意力(Low-Rank/Linear Attention)是一种通过对自注意力机制进行低秩近似来减少计算复杂性的方法。Linformer 是一种具体的实现,它通过低秩近似大大降低了自注意力机制的计算和内存需求。

低秩/线性注意力

在标准的自注意力机制中,每个查询(query)与所有键(key)计算注意力权重,计算复杂度为 O ( N 2 ) O(N^2) O(N2),其中 N N N 是序列长度。低秩/线性注意力通过对注意力矩阵进行低秩近似,将计算复杂度降低为 O ( N ) O(N) O(N)

Linformer

Linformer 是一种通过低秩近似优化自注意力机制的模型。它的主要思想是将高维的键(key)和值(value)投影到一个低维空间,从而减少计算量。

公式

标准的自注意力计算公式为:

Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V

在Linformer中,键和值矩阵被投影到一个低维空间:

K ′ = E K , V ′ = E V K' = EK, \quad V' = EV K=EK,V=EV

其中,(E) 是一个投影矩阵,将原始的高维键和值矩阵 (K) 和 (V) 投影到一个低维空间。

因此,Linformer的自注意力计算公式变为:

LinformerAttention ( Q , K ′ , V ′ ) = softmax ( Q K ′ T d k ) V ′ \text{LinformerAttention}(Q, K', V') = \text{softmax}\left(\frac{QK'^T}{\sqrt{d_k}}\right)V' LinformerAttention(Q,K,V)=softmax(dk QKT)V

代码示例

下面是一个使用PyTorch实现的Linformer低秩注意力机制的简化版本:

import torch
import torch.nn as nn
import torch.nn.functional as F

class LinformerAttention(nn.Module):
    def __init__(self, d_model, seq_len, k_dim):
        super(LinformerAttention, self).__init__()
        self.seq_len = seq_len
        self.k_dim = k_dim
        self.query_proj = nn.Linear(d_model, d_model)
        self.key_proj = nn.Linear(d_model, d_model)
        self.value_proj = nn.Linear(d_model, d_model)
        self.E = nn.Parameter(torch.randn(seq_len, k_dim))
        self.proj = nn.Linear(d_model, d_model)
    
    def forward(self, q, k, v):
        B, N, C = q.shape
        Q = self.query_proj(q)
        K = self.key_proj(k)
        V = self.value_proj(v)
        
        # Project keys and values to low-dimensional space
        K = torch.matmul(self.E.T, K)
        V = torch.matmul(self.E.T, V)
        
        # Calculate attention
        attn_weights = F.softmax(torch.matmul(Q, K.transpose(-2, -1)) / (C ** 0.5), dim=-1)
        attn_output = torch.matmul(attn_weights, V)
        
        return self.proj(attn_output)

# Example usage
batch_size = 2
seq_len = 16
d_model = 64
k_dim = 8

q = torch.randn(batch_size, seq_len, d_model)
k = torch.randn(batch_size, seq_len, d_model)
v = torch.randn(batch_size, seq_len, d_model)

linformer_attention = LinformerAttention(d_model, seq_len, k_dim)
output = linformer_attention(q, k, v)
print(output.shape)  # (batch_size, seq_len, d_model)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44

代码解释

  1. LinformerAttention类: 定义了Linformer低秩注意力机制。

    • __init__方法中,定义了投影矩阵 E 和用于查询、键、值的线性变换 query_proj, key_proj, value_proj
    • forward方法中,首先对查询、键和值进行线性变换,然后使用投影矩阵 E 将键和值投影到低维空间,最后计算注意力权重和输出。
  2. Example usage: 创建输入张量 q, k, vLinformerAttention 实例,并进行前向计算。结果输出张量的形状为 (batch_size, seq_len, d_model)

通过这种方法,Linformer能够有效降低自注意力机制的计算复杂性和内存使用,同时保持良好的性能。

声明:本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号