当前位置:   article > 正文

一份关于 Mamba 和状态空间模型的可视化指南_mamba模型

mamba模型


原文链接:https://www.maartengrootendorst.com/blog/mamba/

作为语言建模的变种 Transformer 的替代方案

Transformer 架构是大型语言模型(LLM)成功的重要组成部分。几乎所有当前使用的 LLM 都使用了 Transformer 架构,从开源模型如 Mistral 到闭源模型如 ChatGPT。

为了进一步改进 LLM,人们开发了新的架构,甚至可能超越 Transformer 架构。其中一种方法是 Mamba,一种 状态空间模型

img

状态空间模型的基本架构。

Mamba 是在论文 Mamba: Linear-Time Sequence Modeling with Selective State Spaces 中提出的。你可以在其 代码库 中找到官方实现和模型检查点。

在本文中,我将介绍状态空间模型在语言建模中的应用,并逐步探索其中的概念,以便对该领域有所了解。然后,我们将讨论 Mamba 如何挑战 Transformer 架构。

作为一份可视化指南,你将看到许多可视化图表,以便更好地理解 Mamba 和状态空间模型!

第一部分:Transformer 的问题

为了说明为什么 Mamba 是一种有趣的架构,让我们先简要回顾一下 Transformer,并探讨其中的一个缺点。

Transformer 将任何文本输入视为由 标记 组成的 序列

img

Transformer 的一个重要优点是,无论接收到什么样的输入,它都可以回顾序列中的任何早期标记来推导出其表示。

img

Transformer 的核心组件

请记住,Transformer 由两个结构组成,一组编码器块用于表示文本,一组解码器块用于生成文本。这些结构可以用于多个任务,包括翻译。

img

我们可以采用这种结构,只使用解码器来创建生成模型。这种基于 Transformer 的模型称为 生成式预训练 Transformer(GPT),它使用解码器块来完成一些输入文本。

img

让我们看看它是如何工作的!

训练的福音…

单个解码器块由两个主要组件组成,掩码自注意力和前馈神经网络。

img

自注意力是这些模型工作得如此出色的一个重要原因。它使得可以快速训练时对整个序列进行无压缩的观察。

那么它是如何工作的呢?

它创建了一个矩阵,将每个标记与之前的每个标记进行比较。矩阵中的权重由标记对彼此的相关性决定。

img

在训练过程中,这个矩阵一次性创建出来。不需要先计算“My”和“name”之间的注意力,然后再计算“name”和“is”之间的注意力。

这使得训练可以进行并行化,大大加快了训练速度!

推理的诅咒!

然而,存在一个缺陷。在生成下一个标记时,我们需要重新计算整个序列的注意力,即使我们已经生成了一些标记。

img

对于长度为 L 的序列生成标记大约需要 次计算,如果序列长度增加,这可能会很昂贵。

img

这种需要重新计算整个序列的需求是 Transformer 架构的一个主要瓶颈。

让我们看看“经典”技术——循环神经网络(RNN)是如何解决这个慢推理问题的。

RNN 是一个解决方案吗?

循环神经网络(RNN)是一种基于序列的网络。它在序列的每个时间步骤中接收两个输入,即时间步骤 *t* 的输入和前一个时间步骤 *t-1* 的隐藏状态,以生成下一个隐藏状态并预测输出。

RNN 具有一个循环机制,允许它将信息从上一步传递到下一步。我们可以将这个可视化展开,使其更加明确。

img

在生成输出时,RNN 只需要考虑前一个隐藏状态和当前输入。它避免了重新计算所有先前的隐藏状态,这是 Transformer 所要做的。

换句话说,RNN 可以快速进行推理,因为它的推理时间与序列长度呈线性关系!从理论上讲,它甚至可以具有无限的上下文长度

为了说明这一点,让我们将 RNN 应用于之前使用过的输入文本。

img

每个隐藏状态是所有先前隐藏状态的聚合,并且通常是一个压缩的视图。

然而,存在一个问题…

请注意,当生成名称“Maarten”时,最后一个隐藏状态不再包含有关单词“Hello”的信息。RNN 随着时间的推移会遗忘信息,因为它只考虑一个先前状态。

RNN 的这种顺序性质导致了另一个问题。训练无法并行进行,因为它需要按顺序逐步进行。

img

与 Transformer 相比,RNN 的问题完全相反!它的推理非常快,但无法并行化。

img

我们能否找到一种既可以像 Transformer 那样并行化训练,同时又可以实现与序列长度呈线性关系的推理的架构呢?

是的!这就是 Mamba 提供的,但在深入了解其架构之前,让我们先探索一下状态空间模型的世界。

第二部分:状态空间模型(SSM)

状态空间模型(SSM)与 Transformer 和 RNN 一样,处理信息序列,如文本和信号。在本节中,我们将介绍 SSM 的基础知识以及它们与文本数据的关系。

什么是状态空间?

状态空间包含完全描述一个系统的最少数量的变量。它是一种通过定义系统的可能状态来数学表示问题的方式。

让我们简化一下。想象一下我们在迷宫中导航。"状态空间"是所有可能位置(状态)的地图。每个点代表迷宫中的一个唯一位置,具有特定的细节,比如离出口有多远。

"状态空间表示"是对这张地图的简化描述。它显示你的当前位置(当前状态),你可以下一步去哪里(可能的未来状态),以及如何到达下一个状态(向右或向左移动)。

img

尽管状态空间模型使用方程和矩阵来跟踪这种行为,但它只是一种跟踪你所在位置、你可以去哪里以及如何到达那里的方式。

描述状态的变量,例如 X 和 Y 坐标以及距离出口的距离,在我们的例子中可以表示为“状态向量”。

img

听起来很熟悉吗?那是因为语言模型中的嵌入或向量也经常用于描述输入序列的“状态”。例如,当前位置(状态向量)的向量可能是这样的:

img

在神经网络中,系统的“状态”通常是其隐藏状态,在大型语言模型的上下文中,生成新标记的最重要方面之一。

什么是状态空间模型?

SSM 是用于描述这些状态表示并根据一些输入预测下一个状态的模型。

传统上,在时间 *t*,SSM:

  • 将输入序列 *x(t)*(例如,在迷宫中向左和向下移动)
  • 映射到潜在状态表示 *h(t)*(例如,距离出口和 x/y 坐标)
  • 并推导出预测的输出序列 *y(t)*(例如,再次向左移动以更快地到达出口)

然而,它不是使用离散序列(如向左移动一次),而是以连续序列作为输入,并预测输出序列。

img

SSM 假设动态系统(例如在三维空间中移动的物体)可以通过其在时间 *t* 的状态通过两个方程来预测。

img

通过解这些方程,我们假设可以揭示基于观察数据(输入序列和先前状态)预测系统状态的统计原理。

它的目标是找到这种状态表示 *h(t)*,以便我们可以从输入到输出序列。

img

这两个方程是状态空间模型的核心。

本指南将引用这两个方程。为了使它们更加直观,它们以彩色进行了标记,以便您可以快速参考它们。

状态方程描述了状态如何根据输入影响状态(通过 矩阵 A)而发生变化。

img
输出方程描述了状态如何通过矩阵C转化为输出,以及输入如何影响输出(通过矩阵D)。

img

注意*:矩阵ABCD通常也被称为参数,因为它们是可学习的。

将这两个方程可视化后,我们得到以下架构:

img

让我们逐步了解这些矩阵如何影响学习过程。

假设我们有一些输入信号x(t),这个信号首先与描述输入如何影响系统的矩阵B相乘。

img

更新的状态(类似于神经网络的隐藏状态)是一个包含环境的核心“知识”的潜在空间。我们将状态与描述所有内部状态如何连接的矩阵A相乘,因为它们代表系统的基本动态。

img

正如你可能已经注意到的,矩阵A在创建状态表示之前应用,并在状态表示更新后进行更新。

然后,我们使用矩阵C来描述如何将状态转化为输出。

img

最后,我们可以利用矩阵D提供从输入到输出的直接信号。这通常也被称为跳跃连接

img

由于矩阵D类似于跳跃连接,因此SSM通常被认为是没有跳跃连接的下图所示。

img

回到我们简化的视角,我们现在可以将重点放在矩阵ABC上,它们是SSM的核心。

img

我们可以更新原始方程(并添加一些漂亮的颜色)以表示每个矩阵的目的,就像我们之前做的那样。

img

这两个方程共同旨在从观测数据中预测系统的状态。由于输入预期是连续的,SSM的主要表示是连续时间表示

从连续信号到离散信号

如果你有一个连续信号,找到状态表示***h(t)***在分析上是具有挑战性的。此外,由于我们通常有离散输入(如文本序列),我们希望将模型离散化。

为此,我们使用零阶保持技术。它的工作原理如下。首先,每当我们接收到一个离散信号时,我们保持其值,直到我们接收到一个新的离散信号。这个过程创建了一个连续信号,SSM可以使用:

img

我们保持值的时间由一个新的可学习参数表示,称为步长 。它表示输入的分辨率。

现在,我们有了连续信号作为输入,我们可以生成连续输出,并根据输入的时间步长仅对值进行采样。

img

这些采样值就是我们的离散化输出!

从数学上讲,我们可以如下应用零阶保持:

img

它们共同使我们能够从连续SSM转换为由一个公式表示的离散SSM,而不是一个函数到函数的转换,而是一个序列到序列的转换,*x*ₖ → *y*ₖ

img

在这里,矩阵AB现在表示模型的离散化参数。

我们使用**k*而不是t***来表示离散化的时间步长,并且在引用连续SSM和离散SSM时更加清晰。

注意: 在训练过程中,我们仍然保存矩阵A的连续形式,而不是离散化版本。在训练过程中,连续表示被离散化。

现在我们有了离散表示的公式,让我们探索如何实际计算模型。

循环表示

我们的离散SSM使我们能够根据特定的时间步长来制定问题,而不是连续信号。正如我们之前在RNN中看到的那样,循环方法在这里非常有用。

如果我们考虑离散时间步长而不是连续信号,我们可以用时间步长来重新制定问题:

img

在每个时间步长,我们计算当前输入(*Bx*ₖ)如何影响先前状态(Ahₖ₋₁),然后计算预测输出(*Ch*ₖ)。

img

这个表示可能已经有点熟悉了!我们可以像之前在RNN中所做的那样处理它。

img

我们可以展开(或展开)它如下:

img

请注意,我们可以使用这个离散化版本,使用RNN的基本方法。

这种技术给我们带来了RNN的优点和缺点,即快速推理和训练缓慢。

卷积表示

我们可以使用卷积来表示SSM的另一种表示方法。还记得经典图像识别任务中我们如何应用滤波器(卷积核)来得到聚合特征吗?

img

由于我们处理的是文本而不是图像,我们需要一个一维的视角:

img

我们用于表示这个“滤波器”的卷积核是从SSM公式中得出的:

img

让我们探索一下这个卷积核在实践中的工作原理。与卷积类似,我们可以使用SSM卷积核遍历每组标记并计算输出:

img

这也说明了填充对输出的影响。我改变了填充的顺序以改善可视化效果,但我们通常在句子的末尾应用填充。

在下一步中,卷积核移动一次以执行计算的下一个步骤:

img

在最后一步中,我们可以看到卷积核的完整效果:

img

将SSM表示为卷积的一个重要优点是它可以像卷积神经网络(CNN)一样进行并行训练。然而,由于固定的卷积核大小,它们的推理速度不如RNN快且无限制。

三种表示

这三种表示,连续循环卷积,都有不同的优点和缺点:

img

有趣的是,现在我们可以根据任务选择不同的表示。在训练过程中,我们使用可以并行化的卷积表示,而在推理过程中,我们使用高效的循环表示:

img

这个模型被称为线性状态空间层(LSSL)

这些表示共享一个重要的属性,即线性时不变性(LTI)。LTI表示SSM的参数ABC对于所有时间步长都是固定的。这意味着矩阵ABC对SSM生成的每个标记都是相同的。

换句话说,无论你给SSM什么序列,ABC的值都保持不变。我们有一个静态表示,它不具备内容感知性。

在我们探索Mamba如何解决这个问题之前,让我们来看看谜题的最后一块拼图,矩阵A的重要性。

矩阵A的重要性

可以说,SSM公式中最重要的一个方面是矩阵A。正如我们之前在循环表示中看到的那样,它捕捉了关于先前状态的信息,以构建状态。

img

本质上,矩阵A生成了隐藏状态:

img

因此,如何创建矩阵A可以决定只记住几个先前标记还是捕捉到目前为止看到的每个标记。特别是在循环表示的上下文中,因为它只回顾**上一个状态

那么,我们如何以保留大内存(上下文大小)的方式创建矩阵A呢?

我们使用饥饿的河马!或者HiPPO,它代表高阶多项式投影运算符。HiPPO试图将其迄今为止看到的所有输入信号压缩成一组系数的向量。
它使用矩阵A来构建一个状态表示,能够很好地捕捉最近的标记,并衰减较旧的标记。其公式可以表示如下:

img

假设我们有一个方阵A,这给我们:

img

使用HiPPO构建矩阵A被证明比将其初始化为随机矩阵要好得多。因此,与初始标记相比,它更准确地重构了新的信号(最近的标记)。

HiPPO矩阵背后的思想是它产生一个隐藏状态,可以记住其历史。

从数学上讲,它通过跟踪Legendre多项式的系数来实现这一点,这使得它能够近似所有先前的历史。

然后,HiPPO被应用于我们之前看到的循环和卷积表示,以处理长距离依赖关系。结果是Sequences的结构化状态空间(S4),这是一类可以高效处理长序列的SSM。

它由三个部分组成:

  • 状态空间模型
  • 用于处理长距离依赖关系的HiPPO
  • 用于创建循环和卷积表示的离散化

img

这类SSM具有多种优点,取决于您选择的表示(循环vs卷积)。它还可以处理长文本序列,并通过构建HiPPO矩阵来高效存储内存。

注意:如果您想深入了解如何计算HiPPO矩阵并自己构建S4模型的技术细节,我强烈建议阅读*Annotated S4

第三部分:Mamba - 一种选择性SSM

我们终于涵盖了理解Mamba的基本知识。状态空间模型可用于建模文本序列,但仍具有一些我们想要避免的缺点。

在本节中,我们将介绍Mamba的两个主要贡献:

  1. 一种选择性扫描算法,允许模型过滤(不)相关信息
  2. 一种硬件感知算法,通过并行扫描内核融合重计算实现对(中间)结果的高效存储。

它们共同创建了选择性SSMS6模型,可以像自注意力一样用于创建Mamba块

在探讨这两个主要贡献之前,让我们首先探讨为什么它们是必要的。

它试图解决什么问题?

状态空间模型,甚至是S4(结构化状态空间模型),在语言建模和生成中的某些关键任务上表现不佳,即聚焦或忽略特定输入的能力

我们可以通过两个合成任务来说明这一点,即选择性复制归纳头

选择性复制任务中,SSM的目标是复制输入的部分并按顺序输出它们:

img

然而,(循环/卷积)SSM在这个任务中表现不佳,因为它是线性时不变的。正如我们之前看到的,矩阵ABC对于SSM生成的每个标记都是相同的。

因此,SSM无法执行内容感知推理,因为它将每个标记都视为固定的A、B和C矩阵的结果。这是一个问题,因为我们希望SSM能够推理输入(提示)。

SSM在归纳头上表现不佳的第二个任务,其目标是复制输入中的模式:

img

在上面的示例中,我们实际上是进行一次性提示,我们试图“教”模型在每个“Q:”之后提供一个“A:”的响应。然而,由于SSM是时不变的,它无法选择从其历史中回忆起哪些先前的标记。

让我们通过关注矩阵B来说明这一点。无论输入**x*是什么,矩阵B都保持完全相同,因此与x***无关:

img

同样,矩阵A和C也保持不变,无论输入是什么。这说明了我们迄今为止看到的SSM的静态性质。

img

相比之下,这些任务对于Transformer来说相对容易,因为它们根据输入序列动态地改变其注意力。它们可以有选择地“查看”或“关注”序列的不同部分。

SSM在这些任务上的表现不佳说明了时不变SSM的潜在问题,即矩阵ABC的静态性导致了内容感知性的问题。

选择性保留信息

SSM的循环表示创建了一个相当高效的小状态,因为它压缩了整个历史。然而,与Transformer模型不压缩历史(通过注意力矩阵),它的能力要低得多。

Mamba旨在拥有两全其美。一个小的状态,与Transformer的状态一样强大:

img

如上所述,它通过有选择地将数据压缩到状态中来实现。当您有一个输入句子时,通常有一些信息(如停用词)没有太多意义。

为了有选择地压缩信息,我们需要使参数依赖于输入。为此,让我们首先探索训练期间SSM中输入和输出的维度:

img

在结构化状态空间模型(S4)中,矩阵ABC与输入无关,因为它们的维度**N*D***是静态的,不会改变。

img

相反,Mamba通过将序列长度和输入的批次大小纳入考虑,使矩阵BC,甚至步长 *,* 依赖于输入:

img

这意味着对于每个输入标记,我们现在有不同的BC矩阵,这解决了内容感知性的问题!

注意:矩阵A保持不变,因为我们希望状态本身保持静态,但通过BC的影响方式是动态的。

它们一起选择性地选择在隐藏状态中保留什么和忽略什么,因为它们现在依赖于输入。

较小的步长 导致忽略特定单词,而更大的步长 则更多地关注输入单词而不是上下文:

img

扫描操作

由于这些矩阵现在是动态的,所以不能使用卷积表示来计算它们,因为它假设有一个固定的内核。我们只能使用循环表示,并且失去了卷积提供的并行化。

为了实现并行化,让我们探索如何使用循环计算输出:

img

每个状态是前一个状态(乘以A)加上当前输入(乘以B)的和。这被称为扫描操作,可以使用for循环轻松计算。

相比之下,并行化似乎是不可能的,因为只有在我们有前一个状态时才能计算每个状态。然而,Mamba通过并行扫描算法使这成为可能。

它通过关联属性假设我们进行操作的顺序不重要。因此,我们可以分部计算序列并逐步组合它们:

img

动态矩阵BC以及并行扫描算法共同创建了选择性扫描算法,以表示使用循环表示的动态和快速性质。

硬件感知算法

最近的GPU的一个缺点是它们在小而高效的SRAM和大而稍微低效的DRAM之间的传输(IO)速度有限。频繁地在SRAM和DRAM之间复制信息成为瓶颈。

img

Mamba,就像Flash Attention一样,试图通过内核融合来限制我们需要从DRAM到SRAM和从SRAM到DRAM的次数。它通过允许模型防止写入中间结果并持续执行计算直到完成来实现。

img

我们可以通过可视化Mamba的基本架构来查看DRAM和SRAM分配的具体实例:

img

在这里,以下内容被融合到一个内核中:

  • 步长 进行离散化步骤
  • 选择性扫描算法
  • C相乘

硬件感知算法的最后一部分是重计算

中间状态不会被保存,但在反向传播中计算梯度时是必需的。相反,作者在反向传播过程中重新计算这些中间状态。
我们已经介绍了其架构的所有组件,如下图所示:

img

选择性 SSM。引自:Gu, Albert, and Tri Dao. “Mamba: Linear-time sequence modeling with selective state spaces.” arXiv preprint arXiv:2312.00752 (2023)。

这种架构通常被称为选择性 SSMS6模型,因为它本质上是使用选择性扫描算法计算的 S4 模型。

Mamba 块

到目前为止,我们所探讨的选择性 SSM可以被实现为一个块,就像我们可以在解码器块中表示自注意力一样。

img

与解码器一样,我们可以堆叠多个 Mamba 块,并将它们的输出用作下一个 Mamba 块的输入:

img

它从线性投影开始,以扩展输入嵌入。然后,在应用选择性 SSM之前,应用卷积以防止独立的令牌计算。

选择性 SSM具有以下特性:

  • 通过离散化创建的循环 SSM
  • 在矩阵A上使用HiPPO初始化以捕获长距离依赖
  • 使用选择性扫描算法以有选择性地压缩信息
  • 硬件感知算法以加快计算速度

当我们查看代码实现并探索端到端示例的外观时,我们可以进一步扩展这种架构:

img

注意一些变化,比如包括归一化层和 softmax 以选择输出令牌。

将所有内容整合在一起,我们既获得了快速的推理和训练,甚至获得了无限的上下文!

img

使用这种架构,作者发现它与同等大小的 Transformer 模型的性能相匹配,甚至有时超越!

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家自动化/article/detail/337954
推荐阅读
相关标签
  

闽ICP备14008679号