当前位置:   article > 正文

Mamba系列日积月累(一):状态空间模型SSM的离散化过程推导_ssm模型原理与公式推倒

ssm模型原理与公式推倒


本文首发于: Mamba系列日积月累(一):状态空间模型SSM的离散化过程推导

最近Mamba系列(MambaVMambaVision Mamba)比较火,在同样具备高效长距离建模能力的情况下,Transformer具有平方级计算复杂度,而Mamba架构则是线性级计算复杂度,并且推理速度更快。

秉承着公众号科研的思路扩展视野的思路,笔者觉得需要学习一下相关内容,于是挑选了目前较新的VMamba论文,准备开始学习。由于缺乏之前的基础知识储备,Preliminaries里面的状态空间模型及其离散化过程直接给我干蒙,想着不能出师未捷身先死,于是决定搜索相关资料,把这个过程弄明白,不过由于本人水平有限,如果内容存在错误,希望大家能给出指导进行纠正。

1. 背景基础知识

1.1 什么是状态空间模型(State Space Model,SSM)?

状态空间模型(State Space Model,简称SSM)是一种数学模型,用于描述和分析动态系统的行为。这种模型在多个领域都有应用,包括控制理论、信号处理、经济学和机器学习等。在深度学习领域,状态空间模型被用来处理序列数据,如时间序列分析、自然语言处理(NLP)和视频理解等。通过将序列数据映射到状态空间,可以更好地捕捉数据中的长期依赖关系。

状态空间模型的核心思想是将系统的当前状态(state) x ( t ) ∈ R n x(t) \in \mathbb{R}^n x(t)Rn与输入(input) u ( t ) ∈ R p u(t) \in \mathbb{R}^p u(t)Rp和输出(output) y ( t ) ∈ R q y(t) \in \mathbb{R}^q y(t)Rq之间的关系用一组方程来表示:
x ˙ ( t ) = A ( t ) x ( t ) + B ( t ) u ( t ) y ( t ) = C ( t ) x ( t ) + D ( t ) u ( t ) (1)

x˙(t)=A(t)x(t)+B(t)u(t)y(t)=C(t)x(t)+D(t)u(t)
\tag{1} x˙(t)=A(t)x(t)+B(t)u(t)y(t)=C(t)x(t)+D(t)u(t)(1)

  1. 状态方程(State Equation):描述系统状态随时间的演变。状态方程通常包含当前状态和输入,以及可能的系统参数。数学上,状态方程可以表示为: x ˙ ( t ) = A ( t ) x ( t ) + B ( t ) u ( t ) \dot{x}(t)=A(t) x(t)+B(t) u(t) x˙(t)=A(t)x(t)+B(t)u(t), 其中, x ( t ) x(t) x(t)是在时间步 t t t 的系统状态, x ˙ ( t ) \dot{x}(t) x˙(t)是状态向量 x ( t ) x(t) x(t)关于时间 t t t的导数, u ( t ) u(t) u(t) 是在时间步 t t t的输入, A ( t ) A(t) A(t)是状态转移矩阵, dim ⁡ [ A ( ⋅ ) ] = n × n \operatorname{dim}[A(\cdot)]=n \times n dim[A()]=n×n B B B 是输入矩阵, dim ⁡ [ B ( ⋅ ) ] = n × p \operatorname{dim}[B(\cdot)]=n \times p dim[B()]=n×p
  2. 观测方程(Observation Equation):描述系统输出与状态之间的关系。观测方程允许我们从系统的输出中观察到系统的状态。数学上,观测方程可以表示为: y ( t ) = C ( t ) x ( t ) + D ( t ) u ( t ) y(t)=C(t) x(t)+D(t) u(t) y(t)=C(t)x(t)+D(t)u(t) 其中, y ( t ) y(t) y(t) 是在时间步 t t t 的系统输出, C ( t ) C(t) C(t)是观测矩阵, dim ⁡ [ C ( ⋅ ) ] = q × n \operatorname{dim}[C(\cdot)]=q \times n dim[C()]=q×n D ( t ) D(t) D(t) 是前馈矩阵, dim ⁡ [ D ( ⋅ ) ] = q × p \operatorname{dim}[D(\cdot)]=q \times p dim[D()]=q×p

当式(1)中的所有矩阵均随着时间 t t t而变化时,此时所表示的线性时变系统,而当所有矩阵都不随时间 t t t​变化时,此时表示的是线性非时变系统,在Mamba系列中,实际上是线性非时变系统Shom指出,在Mamba之前的SSM才是线性非时变系统,后续在Mamba中,相关矩阵不再是固定不变的,从而变成线性时变系统,这里的推导过程主要还是基于线性非时变系统:
x ˙ ( t ) = A x ( t ) + B u ( t ) y ( t ) = C x ( t ) + D u ( t ) (2)

x˙(t)=Ax(t)+Bu(t)y(t)=Cx(t)+Du(t)
\tag{2} x˙(t)=Ax(t)+Bu(t)y(t)=Cx(t)+Du(t)(2)

1.2 什么是离散化(Discretization)?

离散化(Discretization)是将连续的数学对象或过程转换为离散形式的过程。在不同的领域中,离散化有着不同的应用和含义,但核心思想是一致的:将连续的变量或函数映射到有限的、离散的集合中。这个过程在数学、工程、计算机科学和许多其他领域中都非常常见。

1.3 为什么需要离散化?

SSM作为一个连续时间系统,其难以直接集成到现代深度学习算法中:

  • 计算效率:现代深度学习框架和硬件通常是基于离散时间操作而设计的,对SSM进行离散化后,才能将其转化为可以在这些框架和硬件上高效运行的模型。
  • 训练算法:大多数深度学习训练算法,如梯度下降和反向传播,都是为离散时间模型设计的。离散化使得这些算法可以直接应用于状态空间模型,简化了训练过程。
  • 实际应用:在许多实际应用中,数据是离散的,如文本数据(单词序列)、时间序列数据(股票价格、传感器读数)等。离散时间模型更自然地与这些数据格式相匹配。
  • 模型复杂度:离散化过程可以通过选择合适的时间步长 T T T 来控制模型的复杂度。较小的时间步长可以提供更精细的控制,但计算成本更高;较大的时间步长可以减少计算量,但可能牺牲一些精度。

2. SSM离散化过程推导

这里再贴上状态方程公式
x ˙ ( t ) = A x ( t ) + B u ( t ) (3) \dot{x}(t)=A x(t)+B u(t) \tag{3} x˙(t)=Ax(t)+Bu(t)(3)
为了进行离散化,我们首先要对状态方程(3)进行积分。

2.1 为什么在离散化过程中要先进行积分?

在离散化连续状态方程的过程中,积分是一个关键步骤,因为它涉及到状态变量随时间的累积效应,我们需要考虑在每个离散时间步长内状态变量是如何累积变化的。

在离散时间系统中,我们不能直接处理导数,因为离散时间点上没有导数的概念。相反,我们需要考虑在每个时间步长内状态变量的累积变化。这可以通过对连续时间积分进行离散化来实现,即将连续时间的积分转换为离散时间的求和。

在实际的数值模拟中,我们通常使用数值积分方法(如梯形法则、矩形法则、辛普森法则等)来近似连续时间积分。这些方法允许我们在离散时间点上近似连续时间的累积效应,从而得到离散时间状态方程。这个转换过程涉及到将连续时间的导数项替换为离散时间的差分项,这通常涉及到指数函数和采样间隔 T T T​ 的计算。

2.2 为什么不直接对 x ˙ ( t ) \dot{x}(t) x˙(t)进行积分?

在式(3)中,假设我们直接对 x ˙ ( t ) \dot{x}(t) x˙(t)进行积分的话,结果如下:
x ( t ) = x ( 0 ) + ∫ 0 t ( A x ( τ ) + B u ( τ ) ) d τ (4) x(t)=x(0)+\int_0^t(A x(\tau)+B u(\tau)) d \tau \tag{4} x(t)=x(0)+0t(Ax(τ)+Bu(τ))dτ(4)
此时,积分项中会包含 x ( τ ) x(\tau) x(τ)项本身,由于我们是离散系统,我们是无法获取在一个连续的时刻( 0 → t 0\rightarrow t 0t)内所有的 x ( τ ) x(\tau) x(τ)值的,因此无法完成该积分结果的计算。

对于离散系统来说,我们希望将公式(4)这个积分表达式转变为以下形式:
x ( k + 1 ) = x ( k ) + ∑ i = 0 k ( A x ( i ) + B u ( i ) ) Δ t (5) x(k+1)=x(k)+\sum_{i=0}^k(A x(i)+B u(i)) \Delta t \tag{5} x(k+1)=x(k)+i=0k(Ax(i)+Bu(i))Δt(5)
这个形式要求我们对公式(3)进行一些改造,目标是消除 x ˙ ( t ) \dot{x}(t) x˙(t)表达式中的 x ( t ) x(t) x(t)本身。

2.3 状态方程的改造以及 α ( t ) \alpha(t) α(t)的设计

为了消除 x ˙ ( t ) \dot{x}(t) x˙(t)表达式中的 x ( t ) x(t) x(t)本身,我们通常会构造一个新的函数 α ( t ) x ( t ) \alpha(t)x(t) α(t)x(t),通过对这个新函数进行求导,来简化相应的导数项。

我们对 α ( t ) x ( t ) \alpha(t)x(t) α(t)x(t)​进行求导

d d t [ α ( t ) x ( t ) ] = α ( t ) x ˙ ( t ) + x ( t ) d α ( t ) d t (6) \frac{d}{d t}[\alpha(t) x(t)]=\alpha(t) \dot{x}(t)+x(t) \frac{d \alpha(t)}{d t} \tag{6} dtd[α(t)x(t)]=α(t)x˙(t)+x(t)dtdα(t)(6)
我们将公式(3)代入到公式(6)中,替换 x ˙ ( t ) \dot{x}(t) x˙(t)

d d t [ α ( t ) x ( t ) ] = α ( t ) ( A x ( t ) + B u ( t ) ) + x ( t ) d α ( t ) d t (7) \frac{d}{d t}[\alpha(t) x(t)]=\alpha(t) (A x(t)+B u(t))+x(t) \frac{d \alpha(t)}{d t} \tag{7} dtd[α(t)x(t)]=α(t)(Ax(t)+Bu(t))+x(t)dtdα(t)(7)
我们进一步对公式(7)进行改写,合并 x ( t ) x(t) x(t)的相关系数:

d d t [ α ( t ) x ( t ) ] = ( A α ( t ) + d α ( t ) d t ) x ( t ) + B α ( t ) u ( t ) (8) \frac{d}{d t}[\alpha(t) x(t)]=(A\alpha(t) + \frac{d \alpha(t)}{d t})x(t)+B \alpha(t) u(t) \tag{8} dtd[α(t)x(t)]=(Aα(t)+dtdα(t))x(t)+Bα(t)u(t)(8)
由于我们的目的是消除导数项中的 x ( t ) x(t) x(t),因此,我们令 x ( t ) x(t) x(t)的系数项为0即可:
A α ( t ) + d α ( t ) d t = 0 (9) A\alpha(t) + \frac{d \alpha(t)}{d t} = 0 \tag{9} Aα(t)+dtdα(t)=0(9)
此时,我们可以得到 α ( t ) \alpha(t) α(t)的表达式:
α ( t ) = e − A t (10) \alpha(t)=e^{-At} \tag{10} α(t)=eAt(10)
α ( t ) \alpha(t) α(t)的表达式代入公式(8)可以得到:
d d t [ e − A t x ( t ) ] = B e − A t u ( t ) (11) \frac{d}{d t}[e^{-At} x(t)]=B e^{-At} u(t) \tag{11} dtd[eAtx(t)]=BeAtu(t)(11)
这时我们已经完成了在导数项中消除 x ( t ) x(t) x(t)的目标,对 e − A t x ( t ) e^{-At}x(t) eAtx(t)进行积分:
e − A t x ( t ) = x ( 0 ) + ∫ 0 t e − A τ B u ( τ ) d τ (12) e^{-At}x(t)=x(0)+\int_0^t e^{-A\tau} B u(\tau) d \tau \tag{12} eAtx(t)=x(0)+0teAτBu(τ)dτ(12)
对公式(12)进行整理:

x ( t ) = e A t x ( 0 ) + ∫ 0 t e A ( t − τ ) B u ( τ ) d τ (13) x(t)=e^{At}x(0)+\int_0^t e^{A(t-\tau)} B u(\tau) d \tau \tag{13} x(t)=eAtx(0)+0teA(tτ)Bu(τ)dτ(13)

2.3 状态方程的离散化

在离散系统中,我们需要将公式(13)转化为离散形式,大致步骤如下:

  • 参数定义:采样时刻 t k t_k tk t k + 1 t_{k+1} tk+1,其中 k k k是采样索引, T T T是采样间隔,即 T = t k + 1 − t k T=t_{k+1}-t_k T=tk+1tk

  • 积分区间离散化:在连续时间积分中,我们通常有一个积分区间,例如从 t t t t + △ t t+\triangle{t} t+t。在离散时间系统中,我们需要将这个区间划分为 k k k 个等长的子区间,每个子区间的长度为 T T T​​。

    在某个子区间内,公式(13)的形式变为:
    x ( t k + 1 ) = e A ( t k + 1 − t k ) x ( t k ) + ∫ t k t k + 1 e A ( t k + 1 − τ ) B u ( τ ) d τ (14) x(t_{k+1})=e^{A(t_{k+1}-t_k)}x(t_{k})+\int_{t_{k}}^{t_{k+1}} e^{A(t_{k+1}-\tau)} B u(\tau) d \tau \tag{14} x(tk+1)=eA(tk+1tk)x(tk)+tktk+1eA(tk+1τ)Bu(τ)dτ(14)

  • 近似积分:对于每个子区间来说,考虑使用数值积分方法来近似积分,这里考虑对 u ( t ) u(t) u(t)应用零阶保持法,即假设 u ( t ) u(t) u(t)在采样时刻 t k t_k tk t k + 1 t_{k+1} tk+1之间是恒定的,此时,我们可以将 u ( t ) u(t) u(t)当做常数项从积分项中取出:
    ∫ t k t k + 1 e A ( t − τ ) B u ( τ ) d τ = ∫ t k t k + 1 e A ( t k + 1 − τ ) d τ B u ( t k ) (15) \int_{t_{k}}^{t_{k+1}} e^{A(t-\tau)} B u(\tau) d \tau = \int_{t_{k}}^{t_{k+1}} e^{A(t_{k+1}-\tau)} d \tau B u(t_k) \tag{15} tktk+1eA(tτ)Bu(τ)dτ=tktk+1eA(tk+1τ)dτBu(tk)(15)

  • 离散时间状态方程构建:将公式(15)的积分结果代入到公式(14)中,同时使用 T = t k + 1 − t k T=t_{k+1}-t_k T=tk+1tk​进行化简,我们可以得到:
    x ( t k + 1 ) = e A T x ( t k ) + ∫ t k t k + 1 e A ( t k + 1 − τ ) d τ B u ( t k ) (16) x(t_{k+1})=e^{AT}x(t_{k})+\int_{t_{k}}^{t_{k+1}} e^{A(t_{k+1}-\tau)} d \tau Bu\left(t_k\right) \tag{16} x(tk+1)=eATx(tk)+tktk+1eA(tk+1τ)dτBu(tk)(16)
    引入新变量 λ = t k + 1 − τ \lambda=t_{k+1}-\tau λ=tk+1τ,对原积分进行简化得到:
    x ( t k + 1 ) = e A T x ( t k ) + B u ( t k ) ∫ 0 T e A τ d τ (17) x(t_{k+1})=e^{AT}x(t_{k})+Bu\left(t_k\right)\int_{0}^{T} e^{A\tau} d \tau \tag{17} x(tk+1)=eATx(tk)+Bu(tk)0TeAτdτ(17)
    这里涉及到矩阵作为指数的积分,这个部分我是查阅一些资料得到的结果:
    ∫ 0 T e A τ d τ = A − 1 ( e A T − I ) (18) \int_{0}^{T} e^{A\tau} d \tau=A^{-1}(e^{AT}- I) \tag{18} 0TeAτdτ=A1(eATI)(18)
    最终我们得到了离散时间状态方程:
    x ( t k + 1 ) = e A T x ( t k ) + ( e A T − I ) A − 1 B u ( t k ) (19) x(t_{k+1})=e^{AT}x(t_{k})+(e^{AT}- I)A^{-1}B u\left(t_k\right) \tag{19} x(tk+1)=eATx(tk)+(eATI)A1Bu(tk)(19)

3. SSM离散化结果

对比公式(19)和VMamba论文中的离散化结果:

image-20240129012440256

两者形式基本一致,至此,我们完成了SSM的离散化过程的完整推导。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/笔触狂放9/article/detail/569362
推荐阅读
相关标签
  

闽ICP备14008679号