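Sparse attention reduces the overall computational load by letting each token attend to only a selected subset of tokens. This matters for attention mechanisms in both natural language processing and computer vision, because it can substantially reduce computational complexity and memory usage.
In standard full (dense) attention, every token (a word or an image patch) computes attention weights with every other token, so the complexity is $O(N^2)$, where $N$ is the number of tokens. This dense computation becomes very slow and memory-hungry for long sequences or high-resolution images. Sparse attention computes weights for only a subset of token pairs, which, depending on the sparsity pattern, lowers the complexity to $O(N \log N)$ or $O(N)$.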
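PVT v2 (Pyramid Vision Transformer v2) is an improved vision Transformer that lowers the cost of attention by using a strided convolution to compress the spatial size of the keys and values (spatial-reduction attention). When attention weights are computed, each query therefore attends not to all tokens but only to the smaller set of compressed keys and values.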
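To make the saving concrete, the following sketch counts attention-matrix entries with and without spatial reduction. The helper name attention_matrix_entries and the example sizes (a 56×56 feature map, reduction ratio R = 8) are illustrative assumptions, not values taken from the text. The reduction is a constant factor of $R^2$: the cost is still quadratic in the number of tokens, just smaller.

```python
def attention_matrix_entries(h, w, reduction_ratio=1):
    """Number of query-key scores for an h x w feature map (hypothetical helper).

    Queries keep full resolution; keys/values are downsampled by
    reduction_ratio in each spatial dimension.
    """
    n_queries = h * w
    n_keys = (h // reduction_ratio) * (w // reduction_ratio)
    return n_queries * n_keys


full = attention_matrix_entries(56, 56)  # every token attends to every token
reduced = attention_matrix_entries(56, 56, reduction_ratio=8)  # keys shrink to 7 x 7
print(full, reduced, full // reduced)  # 9834496 153664 64
```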
Standard scaled dot-product attention is:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
In PVT v2, the keys and values are spatially compressed by a convolution:

$$K' = \text{Conv}(K)$$

$$V' = \text{Conv}(V)$$

where $\text{Conv}$ denotes the convolution; with stride $R$ it shrinks an $H \times W$ feature map to $\frac{H}{R} \times \frac{W}{R}$. Attention is then computed against the compressed keys and values:

$$\text{SparseAttention}(Q, K', V') = \text{softmax}\left(\frac{QK'^T}{\sqrt{d_k}}\right)V'$$
import torch
import torch.nn as nn
import torch.nn.functional as F
class SparseAttention(nn.Module):
    """Simplified PVT-style spatial-reduction attention.

    A strided convolution downsamples the key/value feature map by `stride`
    in each spatial dimension, so each query attends to
    (height / stride) * (width / stride) tokens instead of height * width.
    """

    def __init__(self, d_model, d_k, kernel_size=2, stride=2, padding=0):
        super(SparseAttention, self).__init__()
        self.query_conv = nn.Linear(d_model, d_k)
        # kernel_size == stride gives non-overlapping patches (reduction ratio R = stride)
        self.key_conv = nn.Conv2d(d_model, d_k, kernel_size=kernel_size, stride=stride, padding=padding)
        self.value_conv = nn.Conv2d(d_model, d_k, kernel_size=kernel_size, stride=stride, padding=padding)

    def forward(self, q, k, v):
        q = self.query_conv(q)  # (batch_size, seq_len, d_k)
        # k and v are feature maps of shape (batch_size, d_model, height, width)
        k = self.key_conv(k)  # (batch_size, d_k, new_height, new_width)
        v = self.value_conv(v)  # (batch_size, d_k, new_height, new_width)
        # Flatten spatial dims and move channels last: (batch_size, new_height * new_width, d_k)
        k = k.flatten(2).transpose(1, 2)
        v = v.flatten(2).transpose(1, 2)
        # Scaled dot-product attention against the reduced set of keys
        scale = q.size(-1) ** 0.5
        attn_weights = F.softmax(torch.bmm(q, k.transpose(1, 2)) / scale, dim=-1)  # (batch_size, seq_len, new_height * new_width)
        attn_output = torch.bmm(attn_weights, v)  # (batch_size, seq_len, d_k)
        return attn_output
# Example usage
batch_size = 2
seq_len = 10
d_model = 64
d_k = 32
height = width = 16
q = torch.randn(batch_size, seq_len, d_model)
k = torch.randn(batch_size, d_model, height, width)
v = torch.randn(batch_size, d_model, height, width)
sparse_attention = SparseAttention(d_model, d_k)
output = sparse_attention(q, k, v)
print(output.shape) # (batch_size, seq_len, d_k)
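Local attention reduces computation by focusing attention on local regions of the input sequence or image. Each token computes attention only with the tokens in its neighborhood, which lowers the computational complexity.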
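As a one-dimensional illustration of local attention, the sketch below builds a sliding-window mask in which token $i$ may attend to token $j$ only when $|i - j| \le w$, counts the allowed pairs, and applies the mask in a dense reference computation. The window radius w and the helper names are assumptions for illustration. With a fixed $w$ the number of pairs grows as $O(N \cdot w)$ rather than $O(N^2)$; masking a dense score matrix, as done here, still costs $O(N^2)$, so the saving only materializes when an implementation computes just the allowed pairs (as window partitioning does).

```python
import torch
import torch.nn.functional as F


def local_attention_mask(n, w):
    # True where |i - j| <= w, i.e. token i is allowed to attend to token j
    idx = torch.arange(n)
    return (idx[:, None] - idx[None, :]).abs() <= w


def masked_local_attention(q, k, v, w):
    # q, k, v: (n, d). Dense reference: scores outside the window become -inf
    scores = q @ k.T / q.size(-1) ** 0.5
    scores = scores.masked_fill(~local_attention_mask(q.size(0), w), float("-inf"))
    return F.softmax(scores, dim=-1) @ v


n, w = 1024, 16
mask = local_attention_mask(n, w)
print(n * n, int(mask.sum()))  # 1048576 33520

q = k = v = torch.randn(n, 64)
out = masked_local_attention(q, k, v, w)
print(out.shape)  # torch.Size([1024, 64])
```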
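Swin Transformer (Shifted Window Transformer) is a Transformer built on local attention. It introduces window-based attention that restricts computation to windows of a fixed size: the input image is partitioned into non-overlapping windows, and attention is computed independently within each window. To strengthen the model's ability to capture global context, Swin Transformer also shifts the windows between layers.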
Starting from standard attention,

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Swin Transformer partitions the input image into windows and computes local attention within each window:

$$\text{LocalAttention}(Q_w, K_w, V_w) = \text{softmax}\left(\frac{Q_w K_w^T}{\sqrt{d_k}}\right)V_w$$

where $Q_w$, $K_w$, and $V_w$ are the query, key, and value matrices of the tokens in one window. With $M \times M$ windows, each token attends to $M^2$ tokens instead of all $N$.
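The Swin Transformer paper quantifies this saving for an $h \times w$ feature map with $C$ channels and $M \times M$ windows: global multi-head self-attention costs $4hwC^2 + 2(hw)^2C$, while window attention costs $4hwC^2 + 2M^2hwC$, which is linear in the number of tokens $hw$ for fixed $M$. The sketch below simply evaluates both expressions; the sizes ($56 \times 56$, $C = 96$, $M = 7$) are example values chosen for illustration.

```python
def msa_cost(h, w, c):
    # Global self-attention: projections (4hwC^2) plus QK^T and AV over all hw tokens
    return 4 * h * w * c**2 + 2 * (h * w) ** 2 * c


def wmsa_cost(h, w, c, m):
    # Window self-attention: same projections, but each token attends to only M*M tokens
    return 4 * h * w * c**2 + 2 * m**2 * h * w * c


h = w = 56
c, m = 96, 7
print(msa_cost(h, w, c))      # 2003828736
print(wmsa_cost(h, w, c, m))  # 145108992
```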
To capture global information, Swin Transformer introduces the shifted window operation. Shifting the windows between consecutive attention layers lets information flow between windows.
import torch
import torch.nn as nn
import torch.nn.functional as F
class WindowAttention(nn.Module):
    """Single-head self-attention computed independently inside
    non-overlapping window_size x window_size windows."""

    def __init__(self, d_model, window_size):
        super(WindowAttention, self).__init__()
        self.window_size = window_size
        self.query_conv = nn.Linear(d_model, d_model)
        self.key_conv = nn.Linear(d_model, d_model)
        self.value_conv = nn.Linear(d_model, d_model)
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        B, N, C = x.shape  # N = H * W tokens of a flattened square image
        windows = self.window_partition(x)  # (B * num_windows, window_size**2, C)
        # Apply linear transformations
        Q = self.query_conv(windows)
        K = self.key_conv(windows)
        V = self.value_conv(windows)
        # Attention is computed only among tokens of the same window
        attn_weights = F.softmax(torch.matmul(Q, K.transpose(-2, -1)) / (C ** 0.5), dim=-1)
        attn_output = torch.matmul(attn_weights, V)
        # Merge the windows back into a (B, N, C) sequence
        attn_output = self.window_reverse(attn_output, B, N)
        return self.proj(attn_output)

    def window_partition(self, x):
        B, N, C = x.shape
        window_size = self.window_size
        H = W = int(N ** 0.5)  # Assume a square image whose side is divisible by window_size
        x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
        # (B, H/ws, W/ws, ws, ws, C) -> one row of ws*ws tokens per window
        windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size * window_size, C)
        return windows

    def window_reverse(self, windows, B, N):
        window_size = self.window_size
        C = windows.shape[-1]
        H = W = int(N ** 0.5)
        x = windows.view(B, H // window_size, W // window_size, window_size, window_size, C)
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, N, C)
        return x
# Example usage
batch_size = 2
seq_len = 16
d_model = 64
window_size = 2
x = torch.randn(batch_size, seq_len, d_model)
window_attention = WindowAttention(d_model, window_size)
output = window_attention(x)
print(output.shape) # (batch_size, seq_len, d_model)
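This code implements a simplified window-based local attention. The input is assumed to be a tensor of shape $B \times N \times C$, where $B$ is the batch size, $N$ is the sequence length (for example, the length of the flattened image), and $C$ is the dimension of each token. The window size is set by window_size; the code assumes a square image whose side length $\sqrt{N}$ is divisible by window_size.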
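Because partition/reverse mistakes silently scramble tokens instead of raising errors, a round-trip check is a cheap safeguard. The sketch below restates the partition and reverse logic as standalone functions on (B, H, W, C) feature maps (a layout assumed here for convenience) and checks that reversing a partition returns the input and that the first window contains the top-left block of the image.

```python
import torch


def window_partition(x, window_size):
    # (B, H, W, C) -> (B * num_windows, window_size * window_size, C)
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size * window_size, C)


def window_reverse(windows, window_size, B, H, W):
    # (B * num_windows, window_size * window_size, C) -> (B, H, W, C)
    C = windows.shape[-1]
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(B, H, W, C)


x = torch.randn(2, 8, 8, 16)
windows = window_partition(x, window_size=4)
print(windows.shape)  # torch.Size([8, 16, 16])

# Reversing the partition must reproduce the input exactly
assert torch.equal(window_reverse(windows, 4, 2, 8, 8), x)
# The first window holds the top-left 4x4 block of the first image
assert torch.equal(windows[0].view(4, 4, 16), x[0, :4, :4])
```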
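The shifted window operation is a key technique in Swin Transformer. Shifting the windows between attention layers lets windows exchange information, strengthening the model's ability to capture global context.
In plain window attention, each token attends only to other tokens in the same window. This lowers the computational complexity but blocks information exchange across windows. Swin Transformer overcomes this by shifting the windows between attention layers, so that tokens that were in different windows can interact in a subsequent attention computation.
Suppose the input is partitioned into windows of size $M \times M$. In the first attention layer the windows do not overlap and are aligned with the image grid. Before the second attention layer, the feature map is cyclically shifted (for example, by $\lfloor M/2 \rfloor$ tokens both horizontally and vertically), and attention is computed again in the shifted windows.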
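The effect of the cyclic shift can be seen on a toy example. The sketch below labels the tokens of a 4×4 grid with their original indices, shifts them by one token (half of a 2×2 window) toward the top-left with torch.roll, and then applies the regular window partition. The grid size, window size, and shift are illustrative choices.

```python
import torch

H = W = 4
window_size, shift = 2, 1

# Label every token with its original position on the 4x4 grid
grid = torch.arange(H * W).view(H, W)
# Cyclic shift toward the top-left by `shift` tokens in both directions
shifted = torch.roll(grid, shifts=(-shift, -shift), dims=(0, 1))
# Regular (non-shifted) partition of the shifted grid: one row per 2x2 window
windows = shifted.reshape(H // window_size, window_size, W // window_size, window_size)
windows = windows.permute(0, 2, 1, 3).reshape(-1, window_size * window_size)
print(windows.tolist())  # [[5, 6, 9, 10], [7, 4, 11, 8], [13, 14, 1, 2], [15, 12, 3, 0]]
```

The first window [5, 6, 9, 10] now contains one token from each of the four original windows, which is how the shift lets information cross window boundaries. Windows such as [15, 12, 3, 0] combine tokens from opposite image borders that were never neighbours; Swin Transformer masks those pairs during attention.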
Within each shifted window, attention is computed with the same formula as before:

$$\text{LocalAttention}(Q_w, K_w, V_w) = \text{softmax}\left(\frac{Q_w K_w^T}{\sqrt{d_k}}\right)V_w$$

where $Q_w$, $K_w$, and $V_w$ are now the query, key, and value matrices of a shifted window.
import torch
import torch.nn as nn
import torch.nn.functional as F
class WindowAttention(nn.Module):
    def __init__(self, d_model, window_size):
        super(WindowAttention, self).__init__()
        self.window_size = window_size
        self.query_conv = nn.Linear(d_model, d_model)
        self.key_conv = nn.Linear(d_model, d_model)
        self.value_conv = nn.Linear(d_model, d_model)
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, x, shift_size=0):
        B, N, C = x.shape
        H = W = int(N ** 0.5)  # Assume input is a square image
        x = x.view(B, H, W, C)
        if shift_size > 0:
            x = self.shift_window(x, shift_size)  # Cyclically shift the feature map
        windows = self.window_partition(x)  # (B * num_windows, window_size**2, C)
        # Apply linear transformations
        Q = self.query_conv(windows)
        K = self.key_conv(windows)
        V = self.value_conv(windows)
        # Calculate attention within each window
        # (simplified: Swin's mask for wrapped-around regions is not applied)
        attn_weights = F.softmax(torch.matmul(Q, K.transpose(-2, -1)) / (C ** 0.5), dim=-1)
        attn_output = torch.matmul(attn_weights, V)
        # Merge windows back into a (B, H, W, C) feature map
        attn_output = self.window_reverse(attn_output, B, H, W)
        if shift_size > 0:
            attn_output = self.reverse_shift_window(attn_output, shift_size)  # Undo the shift
        return self.proj(attn_output.reshape(B, N, C))

    def window_partition(self, x):
        B, H, W, C = x.shape
        window_size = self.window_size
        x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
        windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size * window_size, C)
        return windows

    def window_reverse(self, windows, B, H, W):
        window_size = self.window_size
        C = windows.shape[-1]
        x = windows.view(B, H // window_size, W // window_size, window_size, window_size, C)
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, C)
        return x

    def shift_window(self, x, shift_size):
        # Roll toward the top-left along the height and width dims of (B, H, W, C)
        return torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))

    def reverse_shift_window(self, x, shift_size):
        # Roll back by the same amount to restore the original token positions
        return torch.roll(x, shifts=(shift_size, shift_size), dims=(1, 2))
# Example usage
batch_size = 2
seq_len = 16
d_model = 64
window_size = 2
x = torch.randn(batch_size, seq_len, d_model)
window_attention = WindowAttention(d_model, window_size)
# Without shift
output_no_shift = window_attention(x)
print(output_no_shift.shape) # (batch_size, seq_len, d_model)
# With shift
output_shift = window_attention(x, shift_size=1)
print(output_shift.shape) # (batch_size, seq_len, d_model)
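This code implements a simplified window-based local attention that includes the shifted window operation. The shift is implemented by the shift_window and reverse_shift_window methods, which use torch.roll to cyclically shift the feature map before window partitioning and to shift the result back afterwards.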
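The WindowAttention code above computes attention in the shifted windows without a mask, so tokens wrapped around from the opposite border can attend to each other. In the official Swin Transformer implementation, the regions created by the cyclic shift are labelled, and every query-key pair whose tokens carry different labels receives a large negative bias (-100) before the softmax. The sketch below follows that approach for one feature-map size; shifted_window_mask is a name chosen here, and window_partition is restated so the block runs on its own.

```python
import torch


def window_partition(x, window_size):
    # (B, H, W, C) -> (B * num_windows, window_size * window_size, C)
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size * window_size, C)


def shifted_window_mask(H, W, window_size, shift_size):
    # Give each region of the shifted feature map its own integer label
    img_mask = torch.zeros(1, H, W, 1)
    regions = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
    label = 0
    for h in regions:
        for w in regions:
            img_mask[:, h, w, :] = label
            label += 1
    # Labels grouped by window: (num_windows, window_size * window_size)
    mask_windows = window_partition(img_mask, window_size).squeeze(-1)
    # Nonzero where the query and key tokens come from different regions
    diff = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    return diff.masked_fill(diff != 0, -100.0)


mask = shifted_window_mask(4, 4, window_size=2, shift_size=1)
print(mask.shape)  # torch.Size([4, 4, 4])
```

To apply it, the single-head scores of shape (B * num_windows, L, L), with L = window_size², can be reshaped to (B, num_windows, L, L), have mask.unsqueeze(0) added, and be reshaped back before the softmax. The mask depends only on the feature-map size, window size, and shift, so it can be computed once.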
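Low-rank/linear attention reduces computational complexity by approximating self-attention with a low-rank factorization. Linformer is one concrete instance: its low-rank approximation greatly reduces the computation and memory requirements of self-attention.
In standard self-attention, every query computes attention weights against all keys, with complexity $O(N^2)$, where $N$ is the sequence length. Low-rank/linear attention approximates the attention matrix with a low-rank one, reducing the complexity to $O(N)$ for a fixed projected length.
Linformer optimizes self-attention through this low-rank approximation. Its main idea is to project the keys and values along the sequence dimension, from length $N$ down to a much smaller length $k$, which reduces the amount of computation.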
Standard attention is

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Linformer first projects the keys and values:

$$K' = EK, \quad V' = EV$$

where $E \in \mathbb{R}^{k \times N}$ (with $k \ll N$) is a projection matrix that compresses the $N$ rows of the key and value matrices $K$ and $V$ into $k$ rows. Here a single $E$ is shared by keys and values; Linformer also allows separate projections for the two. Attention is then computed with the projected keys and values, so the attention matrix is $N \times k$ instead of $N \times N$:

$$\text{LinformerAttention}(Q, K', V') = \text{softmax}\left(\frac{QK'^T}{\sqrt{d_k}}\right)V'$$
import torch
import torch.nn as nn
import torch.nn.functional as F
class LinformerAttention(nn.Module):
    def __init__(self, d_model, seq_len, k_dim):
        super(LinformerAttention, self).__init__()
        self.seq_len = seq_len  # Inputs must have exactly this sequence length N
        self.k_dim = k_dim  # Projected length k (k << N)
        self.query_proj = nn.Linear(d_model, d_model)
        self.key_proj = nn.Linear(d_model, d_model)
        self.value_proj = nn.Linear(d_model, d_model)
        # Projection along the sequence axis; the 1/sqrt(N) scaling (an init choice for
        # this example) keeps the projected keys/values at roughly the input scale
        self.E = nn.Parameter(torch.randn(seq_len, k_dim) / seq_len ** 0.5)
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, q, k, v):
        B, N, C = q.shape
        Q = self.query_proj(q)  # (B, N, C)
        K = self.key_proj(k)  # (B, N, C)
        V = self.value_proj(v)  # (B, N, C)
        # Project keys and values along the sequence axis: (k_dim, N) @ (B, N, C) -> (B, k_dim, C)
        K = torch.matmul(self.E.T, K)
        V = torch.matmul(self.E.T, V)
        # Attention matrix is (B, N, k_dim) instead of (B, N, N)
        attn_weights = F.softmax(torch.matmul(Q, K.transpose(-2, -1)) / (C ** 0.5), dim=-1)
        attn_output = torch.matmul(attn_weights, V)  # (B, N, C)
        return self.proj(attn_output)
# Example usage
batch_size = 2
seq_len = 16
d_model = 64
k_dim = 8
q = torch.randn(batch_size, seq_len, d_model)
k = torch.randn(batch_size, seq_len, d_model)
v = torch.randn(batch_size, seq_len, d_model)
linformer_attention = LinformerAttention(d_model, seq_len, k_dim)
output = linformer_attention(q, k, v)
print(output.shape) # (batch_size, seq_len, d_model)
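- LinformerAttention class: implements the Linformer low-rank attention mechanism.
- In the __init__ method, the projection matrix E is defined, together with the linear transformations query_proj, key_proj, and value_proj for the queries, keys, and values.
- In the forward method, the queries, keys, and values are first linearly transformed; the projection matrix E then projects the keys and values along the sequence axis to length k_dim; finally, the attention weights and the output are computed.
- Example usage: creates the input tensors q, k, v and a LinformerAttention instance, then runs a forward pass. The output tensor has shape (batch_size, seq_len, d_model).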
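Stripped of the module wrapper, the Linformer equations reduce to a few matrix products whose shapes show where the saving comes from. The sketch below uses plain tensors with illustrative sizes (N = 512, d = 64, k = 32) and a random, untrained projection E; it only traces shapes and cost and does not reproduce Linformer's learned projections.

```python
import torch
import torch.nn.functional as F

N, d, k = 512, 64, 32  # sequence length, feature dim, projected length
Q = torch.randn(N, d)
K = torch.randn(N, d)
V = torch.randn(N, d)
E = torch.randn(k, N) / N ** 0.5  # random (untrained) projection along the sequence axis

K_proj = E @ K  # (k, d): N key rows compressed to k rows
V_proj = E @ V  # (k, d)
scores = Q @ K_proj.T / d ** 0.5  # (N, k) instead of (N, N)
out = F.softmax(scores, dim=-1) @ V_proj  # (N, d)

print(scores.shape, out.shape)  # torch.Size([512, 32]) torch.Size([512, 64])
print(N * N, N * k)  # 262144 16384
```

For a fixed k, doubling N doubles the number of scores (N·k), whereas the N² of full attention quadruples; this is the sense in which Linformer is linear in sequence length.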
