当前位置:   article > 正文

【模型架构】学习最火热的Mamba、Vision Mamba、MambaOut模型_mambavision

mambavision

一、Mamba

论文链接:Mamba: Linear-Time Sequence Modeling with Selective State Spaces

代码链接:https://github.com/state-spaces/mamba

作者:Albert Gu,Tri Dao

发表单位:卡内基梅隆大学、普林斯顿大学

会议/期刊:暂无

Mamba的提出起源于RNN和Transformer本身存在的问题。

RNN的训练过程中当前时间步依赖于前一时间步的计算,因此不能并行计算,效率非常低,而结构并不复杂,所以推理速度还可以(线性计算);Transformer训练过程是矩阵运算,其训练是可以并行计算的,效率比较高,但是推理过程是一个词一个词去进行矩阵运算(即已经生成了一些token,当生成下一个token时,仍然需要重新计算整个序列的注意力),效率比较低。

那么,能不能提出一个训练和推理过程效率都很高的模型呢?这就有了Mamba。

Mamba是SSM(Structured State Space for Sequence Modeling,序列的结构化状态空间,因为有4个S,所以也称为S4)的改进,所以首先要介绍一下到底什么是SSM?

1.1 SSM的介绍

状态空间模型(State Space Model, SSM)是一种用于描述动态系统的数学模型,特别适用于时间序列分析和控制系统设计。它将系统的状态表示为一个状态向量,并通过状态方程和观测方程描述系统的动态行为和观测过程。

因此,SSM是可以用于描述这些状态表示并根据某些输入预测其下一个状态可能是什么的模型,这就符合了作为深度学习模型基础架构的条件。

SSM的计算示意图

具体来说,可以用下面的公式描述上述过程:

状态变量:描述系统当前状态的变量。状态变量通常是一个向量,包含系统当前时刻的所有信息。

状态方程:描述系统状态如何随时间变化,t+1时刻的状态变化通常形式为:

\mathbf{x}_{t+1}=\mathbf{A}\mathbf{x}_t+\mathbf{B}\mathbf{u}_t+\mathbf{w}_t

其中,xt 是时刻 t 的状态向量,A 是状态转移矩阵,B 是控制输入矩阵,ut​ 是控制输入,wt​ 是过程噪声。

观测方程:描述如何从状态变量获得观测值。通常形式为:

\mathbf{y}_t=\mathbf{C}\mathbf{x}_t+\mathbf{D}\mathbf{u}_t+\mathbf{v}_t

其中,yt 是时刻 t 的观测向量,C 是观测矩阵,D 是直接传输矩阵,vt 是观测噪声。

同样,如果简化噪声,状态方程可以表示系统的状态如何随着时间的推移和输入的变化而变化。

\mathbf h'(t)=\mathbf A\mathbf h(t)+\mathbf B\mathbf x(t)

  • h(t) 是状态向量,表示系统在时间 t 的状态。

  • A 是状态转移矩阵,描述了系统的动态特性。

  • B 是控制输入矩阵,描述了输入 x(t) 如何影响状态。

  • x(t) 是控制输入向量,表示在时间 t 的外部输入。

  • h′(t) 是状态向量对时间 t 的导数,表示状态的变化率。

\mathbf{y}(t)=\mathbf{Ch}(t)+\mathbf{Dx}(t)

  • y(t) 是输出向量,表示在时间 t 的系统输出。

  • C 是输出矩阵,描述了状态向量 h(t) 如何映射到输出 y(t)。

  • D 是直接传输矩阵,描述了输入 x(t) 直接对输出 y(t) 的影响。

这个方程表示如何从状态向量和输入向量计算系统的输出。

状态方程和输出方程

可以注意到,输入和输出此时都是连续的,但是实际应用到深度学习模型当中,需要进行离散化,比如NLP中的单词token输入,CV中的像素块输入。

这里涉及到一部分数学推理,不过有大学高数知识就可以解决。

离散化的常用方法是通过离散化时间步长 Δt 将连续系统转换为等间隔的离散系统。

零阶保持技术(Zero-order hold technique)

首先,决定系统的离散时间间隔,即每次采样的时间间隔 Δt。

使用零阶保持器(Zero-Order Hold, ZOH,假设在每个采样间隔内输入信号保持不变)法进行离散化:

\mathbf{h}[k+1]=\mathbf{A}_d\mathbf{h}[k]+\mathbf{B}_d\mathbf{x}[k]

现在要求解Ad和Bd,让我们回顾一下简单的线性常微分方程:

\mathbf{h}'(t)=\mathbf{Ah}(t)

这是一个线性齐次微分方程。我们期望找到一个解 h(t),它表示系统状态随时间 t 的变化。

对于一阶标量线性常微分方程:

h'(t)=\lambda h(t)

解的形式是:

h(t)=h(0)e^{\lambda t}

求解过程如下所示:

大学高数的知识

这个解表示,状态 h(t)在时间 t 处的值是初始状态 h(0) 乘以指数增长(或衰减)因子e^{\lambda t}

对于向量和矩阵的情形,我们希望找到一个类似的形式。那么,设A是一个矩阵。

\mathbf{h}(t)=e^{\mathbf{A}t}\mathbf{h}(0)

因此得到Ad的解:

\mathbf{A}_d=e^{\mathbf{A}\Delta t}

为了推导 Bd 需要考虑在一个离散时间步长内,输入 x(t) 对状态 h(t) 的影响。

\mathbf{h}^{\prime}(t)=\mathbf{A}\mathbf{h}(t)+\mathbf{B}\mathbf{x}(t)

求解得到:

\mathbf{h}(t)=e^{\mathbf{A}t}\mathbf{h}(0)+\int_0^te^{\mathbf{A}(t-\tau)}\mathbf{B}\mathbf{x}(\tau)d\tau

右边的第一项很好理解,第二项B那边是设置了一个特解,状态转移矩阵 eA(t−τ) 描述了系统从时间 τ 到时间 t 的自由响应。积分 ∫0t​ 累积了所有过去时刻的输入对当前状态的影响。

\mathbf{h}[k+1]=\mathbf{A}_d\mathbf{h}[k]+\mathbf{B}_d\mathbf{x}[k]

先考虑系统状态从 t=kΔt 到 t=(k+1)Δt 的变化。可以表示为:

\mathbf{h}((k+1)\Delta t)=e^{\mathbf{A}\Delta t}\mathbf{h}(k\Delta t)+\int_{k\Delta t}^{(k+1)\Delta t}e^{\mathbf{A}((k+1)\Delta t-\tau)}\mathbf{B}\mathbf{x}(\tau)d\tau

在零阶保持假设下,输入 x(t) 在时间间隔 [kΔt,(k+1)Δt] 内保持不变,即

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/知新_RL/article/detail/878346
推荐阅读
相关标签