当前位置:   article > 正文

Mamba2 coming back-Transformers are SSMs

mamba2

Question 1: What are the conceptual connections between state space models and attention? Can we combine them?
从概念的角度来看,发现 SSM 如此迷人的原因之一是给人的感觉就像_基础_一样。其中一个例子就是它们与许多主要的序列模型范式有着密切的联系。正在结构化 SSM 方面的工作中所阐述的那样,它们似乎抓住了连续、卷积和循环序列模型的本质——所有这些都包含在一个简单而优雅的模型中。(底下会介绍)
当然,除了这些之外,还有另一个主要的序列模型范式:无处不在的注意力机制的变体。SSM 总是让人感觉与注意力有些脱节,我们尝试了一段时间来更好地理解它们之间的关系
_Question 2: _Can we speed up the training of Mamba models by recasting them as matrix multiplications?
从计算的角度来看,尽管 Mamba 为提高速度付出了很多努力(特别是其硬件感知选择性扫描实现),但它的硬件效率仍然远低于注意力机制。缺少的一点是,现代加速器(如 GPU 和 TPU)_高度_专门用于矩阵乘法。虽然这对于推理来说不是问题,因为推理的瓶颈在于不同的考虑因素,但这在训练期间可能是一个大问题。

The Model

The SSD Model

Mamba-2 论文的重点是结构化状态空间对偶性(SSD),它指的是几件事:

  1. SSD 模型是指可以纳入深度神经网络的特定独立层,如注意力机制或 SSM
  2. SSD 框架是推理此模型(以及更多理论联系)的通用框架
  3. SSD 算法是一种比以前的 SSM 更有效地计算 SSD 层的算法

线性时不变系统-线性(SSMs)模型

解释线性时不变系统(Linear Time-Invariant System,简称LTI系统):是控制理论和信号处理中的一个基本概念。它描述了一个系统,其输出与输入之间的关系是线性的,并且不随时间改变。
h t = A h t − 1 + B x t y t = C ⊤ h t

ht=Aht1+Bxtyt=Cht
ht=Aht1+Bxtyt=Cht (0) h t = A t h t − 1 + B t x t y t = C t ⊤ h t
ht=Atht1+Btxtyt=Ctht
htyt=Atht1+Btxt=Ctht
(1)
A selective state space model 是允许(A,B,C)随着时间变化而变化的,通常将其看作带有形状的张量,其中 A ∈ R ( T , N , N ) , B ∈ R ( T , N ) , a n d C ∈ R ( T , N ) A\in\mathbb{R}^{(\mathrm{T,N,N})},B\in\mathbb{R}^{(\mathrm{T,N})},\mathrm{and}C\in\mathbb{R}^{(\mathrm{T,N})} AR(T,N,N),BR(T,N),andCR(T,N)。结构化 SSM 需要_A_具有可有效计算的结构,例如最常用的对角线结构,在这种情况下_A_有形状(T ,N )其中只有对角线元素矩阵被存储。

RNN模型

h t = f ( W x h x t + W h h h t − 1 ) y t = g ( W h y h t ) ,

ht=f(Wxhxt+Whhht1)yt=g(Whyht),
ht=f(Wxhxt+Whhht1)yt=g(Whyht),
思考:SSM模型作为状态空间模型与RNN结构想类似为什么,可以避免梯度消失。(ssm线性)5077f061af13c80f081099cea71ebdb.jpg

CNN模型

当SSM的动力学特性一样,在时间上是常数时,该模型被称为线性时不变( LTI ) 。在这种情况下,它们等价于卷积。
29fb0caf2d2c508f39ea948331d637f.jpg


SSD:标量结构化SSM

原始的Mamba(或者更准确地说是其核心的“S6”层)正是具有对角结构的选择性SSM。
Mamba-2 的 SSD 层只做了一个小修改:它限制了对角线A甚至进一步变成标量乘以恒等结构;换句话说A必须全部为相同值。在这种情况下A可以用形状 (T)来表示还可以识别只是一个标量(所以有时会表示它)。
方程(1)仅针对单维输入进行定义,仅针对单维输入进行定义 x ∈ R T x\in\mathbb{R}^\mathrm{T} xRT。 如果 X ∈ R ( T , P ) X\in\mathbb{R}^{(\mathrm{T,P})} XR(T,P)有P个不同的通道,作者使用相同的动态(即相同的 SSM) 独立地针对每个通道。这可以解释为SSM 模型的_单个头。_在这里,作者认为 X {\boldsymbol{X}} X这是一个形状张量, ( T , P ) (\mathrm{T},\mathrm{P}) (T,P)其中 T \mathrm{T} T是序列(时间)维度, P \mathrm{P} P是“头部维度”。
多个头可以完全独立地构建;在这篇文章的其余部分中,作者假设使用的是一个头。请注意,这些头部与多头注意力模型中的头部的工作原理完全相似,在Mamba-2中,作者也选择了与现代变形金刚相似的尺寸,例如 P = 64 \mathrm{P}=64 P=64 P = 128 \mathrm{P}=128 P=128。(为了扩展到更大的模型宽度,保持这个固定,并增加独立头的数量。)
Y ( T , P ) = S S M ( A ( T , … ) , B ( T , N ) , C ( T , N ) ) ( X ( T , P ) ) Y^{(\mathrm{T},\mathrm{P})}=\mathrm{SSM}(A^{(\mathrm{T},\ldots)},B^{(\mathrm{T},\mathrm{N})},C^{(\mathrm{T},\mathrm{N})})(X^{(\mathrm{T},\mathrm{P})}) Y(T,P)=SSM(A(T,),B(T,N),C(T,N))(X(T,P)) (2)
一些变化包括:

  1. 结构_A_,这会影响其参数形状:
    • ⋯ = ( N , N )
      =(N,N)
      =(N,N)
      对于一般(非结构化)SSM
    • ⋯ = ( N )
      =(N)
      =(N)
      对于对角线 SSM(或其他结构,例如对角线加低秩)
    • ⋯ = ( )
      =()
      =()
      对于标量 SSM(即 SSD)
  2. The state dimension N(i.e. d_state)
  3. The head dimension P(i.e. d_head)

SSD层的双重注意样形式是
M = L ∘ C B ⊤ ∈ R ( T , T ) M=L\circ CB^\top\in\mathbb{R}^{(\mathrm{T},\mathrm{T})} M=LCBR(T,T)

State Space Models are Structured Matrices

状态空间模型与称为半可分离矩阵的结构化矩阵族之间的等价性

Matrix Transformations

其思想是许多序列模型,即序列变换 X ∈ R ( T , P ) ↦ Y ∈ R ( T , P ) X\in\mathbb{R}^{(\mathrm{T},\mathrm{P})}\mapsto Y\in\mathbb{R}^{(\mathrm{T},\mathrm{P})} XR(T,P)YR(T,P),可以写成单一矩阵乘法的形式 Y = M ( X ) ⋅ X Y=M(X)\cdot X Y=M(X)X,其中是一个本身可以依赖的矩阵 X X X。我们称之为矩阵序列变换,或简称矩阵变换。在文献中,序列变换也被称为“序列混合器”或“标记混合器”,而矩阵序列变换则被称为“矩阵混合器”。其中有许多这些例子,它们可以通过矩阵的结构来区分。事实上的例子是自我注意本身, M = s o f t m a x ( Q K ⊤ ) M=\mathrm{softmax}(QK^\top) M=softmax(QK)其中是注意力矩阵。其他的例子包括MLP-Mixer,FNet,和Monarch Mixer。
为什么我们要关心这些类型的模型呢?
将序列模型作为矩阵转换来编写,为理解模型的结构和特征提供了一个强大的工具
虽然一般的非线性RNNs,如LSTMs不能写成矩阵形式,但状态空间模型可以!事实上,仅仅通过展开SSM递归的定义,就很容易看到这一点。其结果是,SSM (2) 可以写成一个矩阵变换
Y = S S M ( A , B , C ) ( X ) = M X Y=\mathsf{SSM}(A,B,C)(X)=MX Y=SSM(A,B,C)(X)=MX
h t = A t … A 1 B 0 x 0 + A t … A 2 B 1 x 1 + ⋯ + A t A t − 1 B t − 2 x t − 2 + A t B t − 1 x t − 1 + B t x t = ∑ s = 0 t A t : s × B s x s .

ht=AtA1B0x0+AtA2B1x1++AtAt1Bt2xt2+AtBt1xt1+Btxt=s=0tAt:s×Bsxs.
ht=AtA1B0x0+AtA2B1x1++AtAt1Bt2xt2+AtBt1xt1+Btxt=s=0tAt:s×Bsxs.
73842958ee2d0a39e575c45ff0b51db.jpg
y t = ∑ s = 0 t C t ⊤ A t : s × B s x s y = S S M ( A , B , C ) ( x ) = M x M j i : = C j ⊤ A j ⋯ A i + 1 B i
yt=s=0tCtAt:s×Bsxsy=SSM(A,B,C)(x)=MxMji:=CjAjAi+1Bi
yt=s=0tCtAt:s×Bsxsy=SSM(A,B,C)(x)=MxMji:=CjAjAi+1Bi

M i j = 0 \textrm{}M_{ij}=0 Mij=0,即是一个下三角矩阵,如下所示
[ C 0 ⊤ B 0 C 1 ⊤ A 1 B 0 C 1 ⊤ B 1 C 2 ⊤ A 2 A 1 B 0 C 2 ⊤ A 2 B 1 C 2 ⊤ B 2 ⋮ ⋮ ⋱ ⋱ C T ⊤ A T − 1 … A 1 B 0 C T ⊤ A T − 1 … A 2 B 1 … C T ⊤ A T − 1 B T − 2 C T ⊤ B T − 1 ]
[C0B0C1A1B0C1B1C2A2A1B0C2A2B1C2B2\varvdots\varvdotsCTAT1A1B0CTAT1A2B1CTAT1BT2CTBT1]
C0B0C1A1B0C2A2A1B0CTAT1A1B0C1B1C2A2B1CTAT1A2B1C2B2CTAT1BT2CTBT1

Semiseparable Matrices

这种类型的矩阵实际上有一个名字:它被称为一个(三角形的)半可分割的矩阵,并且已经在工程和计算线性代数的其他领域进行了研究。它有好多性质。例如,半可分割矩阵的另一种特征是它们的结构秩性质,即包含在下三角部分中的每个子矩阵都是低秩的。
image.png
所有包含半可分矩阵对角线的子矩阵都是低秩的
好处:通过矩阵乘法计算SSMS 所有计算状态空间模型的算法都可以看作是半可分矩阵上的结构化矩阵乘法算法。
[ C j ⊤ A j : i ′ × B i ′ … C j ⊤ A j : i − 1 × B i − 1 ⋮ ⋮ C j ′ − 1 ⊤ A j ′ − 1 : i ′ × B i ′ … C j ′ − 1 ⊤ A j ′ − 1 : i − 1 × B i − 1 ] = [ C j ⊤ A j : j × ⋮ C j ′ − 1 ⊤ A j ′ − 1 : j × ] A j : i − 1 × [ A i − 1 : i ′ × B i ′ ⋯ A i − 1 : i − 1 × B i − 1 ] .

[CjAj:i×BiCjAj:i1×Bi1Cj1Aj1:i×BiCj1Aj1:i1×Bi1]
=
[CjAj:j×Cj1Aj1:j×]
A_{j:i-1}^\times
[Ai1:i×BiAi1:i1×Bi1]
. CjAj:i×BiCj1Aj1:i×BiCjAj:i1×Bi1Cj1Aj1:i1×Bi1 = CjAj:j×Cj1Aj1:j× Aj:i1×[Ai1:i×BiAi1:i1×Bi1].

The Quadratic (Attention) Mode

给定上面具有相同形状的相同张量,定义一个不同的对象。
首先,将定义以下矩阵
L = [ 1 a 1 1 a 2 a 1 a 2 1 ⋮ ⋮ ⋱ ⋱ a T − 1 … a 1 a T − 1 … a 2 … a T − 1 1 ] . L=

[1a11a2a1a21aT1a1aT1a2aT11]
. L= 1a1a2a1aT1a11a2aT1a21aT11 .
然后,定义以下矩阵
M = L ∘ C B ⊤ ∈ R ( T , T ) M=L\circ CB^\top\in\mathbb{R}^{(\mathrm{T},\mathrm{T})} M=LCBR(T,T) (3)
最后,编码一个序列转换 x ∈ R T → y ∈ R T x\in\mathbb{R}^\mathrm{T}\to y\in\mathbb{R}^\mathrm{T} xRTyRT将一维输入映射到一维输出,如式所示 (1) -通过基本的矩阵乘法 y = M x . y=Mx. y=Mx.
这有什么特别之处?会注意到它看起来和注意力计算非常相似。事实上,如果全部 a t = 1 a_{t}=1 at=1,然后就是下三角形的_causal mask_和 (3) 是等价于_causal linear attention _
Y = ( L ∘ Q K ⊤ ) V Y=(L\circ QK^\top)V Y=(LQK)V
代码

 def ssm(self, x):
        """Runs the SSM. See:
            - Algorithm 2 in Section 3.2 in the Mamba paper [1]
            - run_SSM(A, B, C, u) in The Annotated S4 [2]

        Args:
            x: shape (b, l, d_in)    (See Glossary at top for definitions of b, l, d_in, n...)
    
        Returns:
            output: shape (b, l, d_in)

        Official Implementation:
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
            
        """
        (d_in, n) = self.A_log.shape

        # Compute ∆ A B C D, the state space parameters.
        #     A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
        #     ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
        #                                  and is why Mamba is called **selective** state spaces)
        
        A = -torch.exp(self.A_log.float())  # shape (d_in, n)
        D = self.D.float()

        x_dbl = self.x_proj(x)  # (b, l, dt_rank + 2*n)
        
        (delta, B, C) = x_dbl.split(split_size=[self.args.dt_rank, n, n], dim=-1)  # delta: (b, l, dt_rank). B, C: (b, l, n)
        delta = F.softplus(self.dt_proj(delta))  # (b, l, d_in)
        
        y = self.selective_scan(x, delta, A, B, C, D)  # This is similar to run_SSM(A, B, C, u) in The Annotated S4 [2]
        
        return y

    
    def selective_scan(self, u, delta, A, B, C, D):
        """Does selective scan algorithm. See:
            - Section 2 State Space Models in the Mamba paper [1]
            - Algorithm 2 in Section 3.2 in the Mamba paper [1]
            - run_SSM(A, B, C, u) in The Annotated S4 [2]

        This is the classic discrete state space formula:
            x(t + 1) = Ax(t) + Bu(t)
            y(t)     = Cx(t) + Du(t)
        except B and C (and the step size delta, which is used for discretization) are dependent on the input x(t).
    
        Args:
            u: shape (b, l, d_in)    (See Glossary at top for definitions of b, l, d_in, n...)
            delta: shape (b, l, d_in)
            A: shape (d_in, n)
            B: shape (b, l, n)
            C: shape (b, l, n)
            D: shape (d_in,)
    
        Returns:
            output: shape (b, l, d_in)
    
        Official Implementation:
            selective_scan_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L86
            Note: I refactored some parts out of `selective_scan_ref` out, so the functionality doesn't match exactly.
            
        """
        (b, l, d_in) = u.shape
        n = A.shape[1]
        
        # Discretize continuous parameters (A, B)
        # - A is discretized using zero-order hold (ZOH) discretization (see Section 2 Equation 4 in the Mamba paper [1])
        # - B is discretized using a simplified Euler discretization instead of ZOH. From a discussion with authors:
        #   "A is the more important term and the performance doesn't change much with the simplification on B"
        deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))
        deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n')
        
        # Perform selective scan (see scan_SSM() in The Annotated S4 [2])
        # Note that the below is sequential, while the official implementation does a much faster parallel scan that
        # is additionally hardware-aware (like FlashAttention).
        x = torch.zeros((b, d_in, n), device=deltaA.device)
        ys = []    
        for i in range(l):
            x = deltaA[:, i] * x + deltaB_u[:, i]
            y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in')
            ys.append(y)
        y = torch.stack(ys, dim=1)  # shape (b, l, d_in)
        
        y = y + u * D
    
        return y
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
def segsum_unstable(x):
  """Naive segment sum calculation."""
  T = x.size(-1)
x_cumsum = torch.cumsum(x, dim=-1)
x_segsum = x_cumsum[..., :, None] - x_cumsum[..., None, :]
mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
return x_segsum

def segsum(x):
  """More stable segment sum calculation."""
  T = x.size(-1)
  x = repeat(x, "... d -> ... d e", e=T)
  mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1)
  x = x.masked_fill(~mask, 0)
  x_segsum = torch.cumsum(x, dim=-2)
  mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
  x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
  return x_segsum

  def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None):
    """
    Arguments:
        X: (batch, length, n_heads, d_head)
        A: (batch, length, n_heads)
        B: (batch, length, n_heads, d_state)
        C: (batch, length, n_heads, d_state)
    Return:
        Y: (batch, length, n_heads, d_head)
    """
    assert X.dtype == A.dtype == B.dtype == C.dtype
    assert X.shape[1] % block_len == 0

  # Rearrange into blocks/chunks
  X, A, B, C = [rearrange(x, "b (c l) ... -> b c l ...", l=block_len) for x in (X, A, B, C)]

  A = rearrange(A, "b c l h -> b h c l")
  A_cumsum = torch.cumsum(A, dim=-1)

  # 1. Compute the output for each intra-chunk (diagonal blocks)
  L = torch.exp(segsum(A))
  Y_diag  = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X)

  # 2. Compute the state for each intra-chunk
  # (right term of low-rank factorization of off-diagonal blocks; B terms)
  decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
  states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X)

  # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
  # (middle term of factorization of off-diag blocks; A terms)
  if initial_states is None:
    initial_states = torch.zeros_like(states[:, :1])
states = torch.cat([initial_states, states], dim=1)
decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
states, final_state = new_states[:, :-1], new_states[:, -1]

# 4. Compute state -> output conversion per chunk
# (left term of low-rank factorization of off-diagonal blocks; C terms)
state_decay_out = torch.exp(A_cumsum)
Y_off = torch.einsum('bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out)

# Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
Y = rearrange(Y_diag+Y_off, "b c l h p -> b (c l) h p")
return Y, final_state

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66

参考文献
Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality
Mamba: Linear-Time Sequence Modeling with Selective State Spaces
Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/码创造者/article/detail/970835
推荐阅读
相关标签
  

闽ICP备14008679号