赞
踩
Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality
公和众和号:EDPJ(进 Q 交流群:922230617 或加 VX:CV_EDPJ 进 V 交流群)
目录
9.2.3 混合模型:将 SSD 层与 MLP 和注意力结合
虽然 Transformers 一直是深度学习在语言建模方面取得成功的主要架构,但最近表明,诸如 Mamba 之类的状态空间模型(SSM)在小到中等规模上可以匹敌甚至超越 Transformer。我们展示了这些模型家族实际上是相当紧密相关的,并且开发了一个丰富的理论框架,将 SSM 和各种注意力机制的变体通过对一类研究良好的结构化半可分离矩阵的各种分解联系起来。我们的状态空间对偶(state space duality,SSD)框架使我们能够设计一个新架构(Mamba-2),其核心层是 Mamba 选择性 SSM 的精炼版本,速度提高了 2-8 倍,同时在语言建模方面继续与 Transformers 竞争。
模型代码和预训练检查点:https://github.com/state-spaces/mamba
Transformers,特别是仅解码模型(例如 GPT (Brown et al., 2020),Llama (Touvron, Lavril, 等, 2023)),以因果方式处理输入序列,是现代深度学习成功的主要驱动因素之一。许多方法试图近似核心注意力层,以解决其效率问题(Tay 等,2022),例如在训练期间序列长度成二次方增长,并且在自回归生成期间需要线性序列长度大小的缓存。同时,另一类替代序列模型,结构化状态空间模型(SSM),在训练期间序列长度线性增长,生成期间状态大小恒定。它们在长距离任务中表现强劲(例如 S4 (Gu, Goel, 和 Ré,2022)),最近在小到中等规模的语言建模中匹敌甚至超越了 Transformers(例如 Mamba (Gu 和 Dao,2023))。然而,SSM 的发展似乎与社区改善 Transformers 的集体努力相脱节,例如理论理解和在现代硬件上优化。因此,与 Transformer 相比,理解和实验 SSM 更为困难,从算法和系统的角度来看,训练 SSM 也更具挑战性。
我们的主要目标是建立一个丰富的理论连接体系,将结构化 SSM 与各种注意力机制联系起来。这将使我们能够将最初为 Transformers 开发的算法和系统优化转移到 SSMs 上,以构建比 Transformers 性能更好且在序列长度上更高效的基础模型。一个重要的贡献是线性注意力(Linear Attention,LA)框架(Katharopoulos 等,2020),通过展示二次核注意力(quadratic kernelized attention)的 “对偶形式” 与特定线性递归的等价性,导出了自回归注意力与线性 RNN 之间的联系。这种对偶性使得能够进行高效的可并行训练和高效的自回归推理。同样,本论文提供了多种视角,将线性复杂度的 SSM 与二次复杂度的形式联系起来,结合 SSM 和注意力的优势。
(2020|ICML PMLR,线性 Transformer,核函数,RNN)Transformer 是 RNN
状态空间对偶。我们称之为结构化状态空间对偶(structured state space duality,SSD)的框架,通过对结构化矩阵(具有亚二次(subquadratic)参数和乘法复杂度的矩阵)的抽象(abstraction)来连接结构化 SSM 和各种注意力机制。我们开发了两种广泛的表示序列模型的框架,一种是矩阵变换,另一种是张量收缩(tensor contraction),各自揭示了对偶性的不同视角。我们的技术贡献包括:
除了其内在的理论价值,我们的框架还开辟了广泛的方向来理解和改进序列模型。
高效算法。首先也是最重要的,我们的框架揭示了计算 SSM 的新高效且易于实现的算法(第 6 节)。我们引入了一种基于半可分离矩阵块分解的新 SSD 算法,利用了线性 SSM 递归和二次对偶形式,在所有主要效率轴(如训练和推理计算、内存使用和利用现代硬件上的矩阵乘法单元的能力)上获得了最佳权衡。专门的 SSD 实现比优化的 Mamba 选择性扫描实现快 2-8 倍,同时允许更大的递归状态尺寸(是 Mamba 的 8 倍或更高,几乎没有减速)。SSD 与 softmax 注意力(FlashAttention-2(Dao 2024))的优化实现竞争,在序列长度 2K 处交叉,在序列长度 16K 时快 6 倍。
架构设计。采用新架构如 SSM 的一个主要障碍是针对 Transformers 量身定制的生态系统,例如用于大规模训练的硬件高效优化和并行技术。我们的框架允许使用已建立的注意力惯例和技术来构建 SSM 的架构设计选择词汇,并进一步改进它们(第 7 节)。例如,我们引入了多头注意力(multi-head attention,MHA)的头部在 SSM 中的类似物。我们展示了 Mamba 架构是一个多输入 SSM(multi-input SSM,MIS),类似于多值注意力(multi-value attention,MVA),并比较了具有不同头部结构的其他 Mamba 变体。
我们还使用这些想法对 Mamba 模块进行轻微修改,以便实现张量并行(例如,Megatron(Shoeybi 等,2019)的风格)。主要思想包括引入分组值注意力(grouped-value attention,GVA)头部结构,并将所有依赖于数据的投影移至模块开始时并行发生。
修改后的并行 Mamba 模块与作为内部 SSM 层的 SSD 相结合,形成了 Mamba-2 架构。我们在与 Mamba 相同的设置中研究了 Mamba-2 的 Chinchilla 扩展定律(scaling laws),发现它在困惑度和实际运行时间上优于 Mamba 和 Transformer++。我们还在 Pile 上训练了不同大小的 Mamba-2 模型家族,显示它在标准下游评估中匹敌或超越了 Mamba 和开源 Transformers。例如,具有 2.7B 参数的 Mamba-2 在 Pile 上训练了 300B token,在相同数据集上超过了 Mamba-2.8B、Pythia-2.8B 甚至 Pythia-6.9B。
系统优化。SSD 框架连接了 SSM 和 Transformers,使我们能够利用大量针对 Transformers 开发的系统优化工作(第 8 节)。
第 9 节在语言建模、训练效率和一个难度较大的多查询关联回忆任务(Arora, Eyuboglu, Zhang, 等,2024)上实证验证了 Mamba-2。最后,在第 10 节中,我们提供了扩展的相关工作并讨论了我们框架带来的潜在研究方向。
(2023,SSM,门控 MLP,选择性输入,上下文压缩)Mamba:具有选择性状态空间的线性时间序列建模结构化状态空间序列模型(Structured state space sequence models,S4)是一类最近用于深度学习的序列模型,广泛关联于 RNNs、CNNs 和经典状态空间模型。它们受到特定连续系统(1)的启发,该系统通过隐式潜在状态 h ∈ R^(T,N) 将一维序列 x ∈ R^T 映射到 y ∈ R^T。
结构化 SSMs 的一般离散形式如下:
结构化 SSMs 之所以命名为 “结构化”,是因为控制时间动态的矩阵 A 必须是结构化的,以便在深度神经网络中高效地计算这种序列到序列的转换。最初引入的结构是对角加低秩(diagonal plus low-rank,DPLR)(Gu, Goel 和 Ré 2022)和对角(Gu, Gupta 等 2022;Gupta, Gu 和 Berant 2022;J. T. Smith, Warrington 和 Linderman 2023),这些结构仍然是最受欢迎的。
在本研究中,我们使用术语 “状态空间模型”(SSM)来指代结构化 SSMs。这类 SSMs 有许多变体,与神经序列模型的几大主要范式如连续时间、递归和卷积模型有深厚的联系(Gu, Johnson, Goel 等 2021)。我们在下文中提供了简要概述,并参考了先前的工作以获得更多背景和细节(Gu 2023;Gu 和 Dao 2023)。
连续时间模型。原始的结构化 SSMs 起源于函数 x(t)∈R ↦ y(t)∈R 上的连续时间映射,而不是直接作用于序列。在连续时间视角中,在方程(1a)中,矩阵 (A,B) 不是直接学习的,而是从底层(underlying)参数 (∘A,∘B) 生成的,并且带有一个参数化的步长 Δ。通过固定公式 A=fA(Δ,∘A) 和 B=fB(Δ,∘B) 将“连续参数” (Δ,∘A,∘B) 转换为“离散参数” (A,B),其中 (f_A,f_B) 被称为离散化规则。
备注 1: 尽管我们的主要模型采用了与先前工作相同的参数化和离散化步骤(见 Gu 和 Dao(2023)了解详细信息),为了简化阐述和符号表示,我们在本文其余部分中省略了它。我们注意到,先前关于结构化 SSMs 的工作将连续参数 (∘A,∘B) 和离散参数 (A,B) 分别称为 (A,B) 和 (ˉA,ˉB);我们更改了符号表示以简化展示并直接关注控制主要 SSM 递归的离散参数。
递归模型。方程(1)和(2)采用了输入 x 线性递归的形式。因此,结构化 SSMs 可以视为递归神经网络(RNNs)的一种类型,其中线性赋予了它们额外的性质,并使其避免了传统 RNNs 的顺序计算。相反,尽管有这种简化,SSMs 仍然是全表达的序列转换(在通用近似的意义上)(Kaul 2020;Orvieto 等 2023;Shida Wang 和 Xue 2023)。
卷积模型。当 SSM 的动态在时间上是恒定的,如方程(1)所示,模型被称为线性时不变(linear time-invariant,LTI)。在这种情况下,它们等同于卷积。因此,SSMs 也可以视为 CNNs 的一种类型,但(i)卷积核通过 SSM 参数 (A,B,C) 隐式参数化,(ii)卷积核通常是全局而非局部的。相反,通过经典信号处理理论,所有足够良好的卷积都可以表示为 SSMs。
通常,以前的 LTI SSMs 会使用卷积模式进行高效的可并行训练(提前看到整个输入序列),并切换到递归模式(1)以进行高效的自回归推理(逐步看到输入)。
选择性状态空间模型。Mamba 中引入了参数 (A,B,C) 也可以随时间变化的形式(2),称为选择性 SSM。与标准 LTI 公式(1)相比,该模型可以在每个时间步选择性地关注或忽略输入。事实证明,与信息密集型数据(如语言)上的 LTI SSMs 相比,尤其是当其状态大小 N 增加以允许更多信息容量时,表现更好。然而,它只能在递归模式下计算,而不能在卷积模式下计算,并且需要仔细的硬件感知实现才能高效。即便如此,它仍然不如硬件友好的模型(如 CNNs 和 Transformers)高效,因为它没有利用现代加速器(如 GPUs 和 TPUs)专门针对的矩阵乘法单元。
虽然时不变 SSMs 与连续、递归和卷积序列模型密切相关,但它们与注意力没有直接关系。在本文中,我们展示了选择性 SSMs 与注意力之间更深层次的关系,并利用它显著提高了 SSMs 的训练速度,同时允许更大的状态尺寸 NNN。
结构化 SSMs 作为序列转换(Sequence Transformations)。
定义 2.1。我们使用术语序列转换来指代对序列的参数化映射 Y=fθ(X),其中 X,Y ∈ R^(T,P),且 θ 是参数的任意集合。T 代表序列或时间轴;下标索引到第一维,例如 Xt,Yt ∈ R^P。
序列转换(例如 SSMs 或自注意力)是深度序列模型的基石,它们被纳入神经网络架构中(例如 Transformers)。方程(1)或(2)中的 SSM 是一个序列转换,其中 P=1;可以通过简单地在此维度上进行广播将其推广到 P>1(换句话说,将输入视为 P 个独立序列,并对每个序列应用 SSM)。我们可以将 P 视为一个头维度,我们将在第 7 节详细说明。
定义 2.2。我们定义 SSM 运算符
为序列转换 X ∈ R^(T,P) ↦ Y ∈ R^(T,P),由方程(2)定义。
在 SSMs 中,N 维度是一个自由参数,称为状态大小或状态维度。我们也称之为状态扩展因子,因为它将输入/输出的大小扩展了一个因子 N,这对这些模型的计算效率有影响。
最后,我们指出,许多类型的序列转换,例如注意力,可以表示为跨序列维度的单个矩阵乘法。
定义 2.3。如果序列转换 Y=fθ(X) 可以写成 Y=M_θ·X 的形式,其中 M 是依赖于参数 θ 的矩阵,我们称其为矩阵转换。我们用矩阵 M 识别序列转换,并在上下文清晰时通常省略对 θ 的依赖。
注意力机制广泛指一种计算方法,它为序列中的每对位置分配分数,使得每个元素能够 “关注” 其他元素。迄今为止,最常见和重要的注意力变体是 softmax 自注意力,其定义为:
其中 Q,K,V ∈ R^(T,P)。成对比较机制(通过 QK^T 产生)导致了注意力的特征性二次训练成本。
(2020|ICML PMLR,线性 Transformer,核函数,RNN)Transformer 是 RNN
虽然提出了许多注意力的变体,但它们都共享这些注意力分数的核心,并采用各种近似(Tay et al. 2022)。对本文工作最重要的变体是线性注意力(Katharopoulos et al. 2020)。大致来说,这类方法通过将 softmax 折叠到核特征映射中来省略 softmax,并利用矩阵乘法的结合性将 (QK^T)⋅V 重写为 Q⋅(K^T·V)
此外,在因果(自回归)注意力的重要情况下,当因果掩码(causal mask)被合并到左侧
中时,其中 L 是下三角全 1 矩阵,那么右侧可以展开(unrolling)为递归。最近的多项工作如 RetNet(Y. Sun et al. 2023)和 GateLoop(Katsch 2023)将这一点加强为更一般形式的 L(见第 10 节)。在本文中,我们的结构化掩码注意力的公式将强烈地概括这些思想。
一般矩阵 M∈R^(T,T) 需要 T^2 个参数来表示,并且需要 O(T^2) 时间来执行矩阵-向量乘法等基本操作。结构化矩阵是指:
也许最典型的结构化矩阵家族是稀疏和低秩矩阵。然而,还有许多其他家族,如 Toeplitz、Cauchy、Vandermonde 和 butterfly 矩阵,它们都已被用于机器学习中以实现高效模型(Dao, Gu, et al. 2019;D. Fu et al. 2024;Gu, Gupta, et al. 2022;Thomas et al. 2018)。结构化矩阵是高效表示和算法的强大抽象。在本文中,我们将展示 SSMs 等价于另一类以前未用于深度学习的结构化矩阵,并利用这一联系推导出高效的方法和算法。
虽然本文发展了 SSMs、注意力和结构化矩阵之间更丰富的联系框架,但我们提供了主要方法的简要总结,该方法在算法上实际上是非常自包含(self-contained)和简单的。
递归(线性)形式。状态空间对偶(SSD)层可以定义为选择性 SSM(2)的特例。SSM 的标准计算可以作为递归(或并行扫描)应用,其在序列长度上具有线性复杂度。与 Mamba 中使用的版本相比,SSD 有两个细微的差异:
与原始选择性 SSM 相比,这些变化可以看作是略微减少表达能力以换取显著的训练效率改进。特别是,我们的新算法将允许在现代加速器上使用矩阵乘法单元。
对偶(二次)形式。SSD 的对偶形式是与注意力紧密相关的二次计算,定义为:
其中,
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。