赞
踩
最近Mamba系列(Mamba、VMamba、Vision Mamba)比较火,在同样具备高效长距离建模能力的情况下,Transformer具有平方级计算复杂度,而Mamba架构则是线性级计算复杂度,并且推理速度更快。
秉承着公众号科研的思路扩展视野的思路,笔者觉得需要学习一下相关内容,于是挑选了目前较新的VMamba论文,准备开始学习。由于缺乏之前的基础知识储备,Preliminaries里面的状态空间模型及其离散化过程直接给我干蒙,想着不能出师未捷身先死,于是决定搜索相关资料,把这个过程弄明白,不过由于本人水平有限,如果内容存在错误,希望大家能给出指导进行纠正。
状态空间模型(State Space Model,简称SSM)是一种数学模型,用于描述和分析动态系统的行为。这种模型在多个领域都有应用,包括控制理论、信号处理、经济学和机器学习等。在深度学习领域,状态空间模型被用来处理序列数据,如时间序列分析、自然语言处理(NLP)和视频理解等。通过将序列数据映射到状态空间,可以更好地捕捉数据中的长期依赖关系。
状态空间模型的核心思想是将系统的当前状态(state)
x
(
t
)
∈
R
n
x(t) \in \mathbb{R}^n
x(t)∈Rn与输入(input)
u
(
t
)
∈
R
p
u(t) \in \mathbb{R}^p
u(t)∈Rp和输出(output)
y
(
t
)
∈
R
q
y(t) \in \mathbb{R}^q
y(t)∈Rq之间的关系用一组方程来表示:
x
˙
(
t
)
=
A
(
t
)
x
(
t
)
+
B
(
t
)
u
(
t
)
y
(
t
)
=
C
(
t
)
x
(
t
)
+
D
(
t
)
u
(
t
)
(1)
当式(1)中的所有矩阵均随着时间
t
t
t而变化时,此时所表示的线性时变系统,而当所有矩阵都不随时间
t
t
t变化时,此时表示的是线性非时变系统,在Mamba系列中,实际上是线性非时变系统 经Shom指出,在Mamba之前的SSM才是线性非时变系统,后续在Mamba中,相关矩阵不再是固定不变的,从而变成线性时变系统,这里的推导过程主要还是基于线性非时变系统:
x
˙
(
t
)
=
A
x
(
t
)
+
B
u
(
t
)
y
(
t
)
=
C
x
(
t
)
+
D
u
(
t
)
(2)
离散化(Discretization)是将连续的数学对象或过程转换为离散形式的过程。在不同的领域中,离散化有着不同的应用和含义,但核心思想是一致的:将连续的变量或函数映射到有限的、离散的集合中。这个过程在数学、工程、计算机科学和许多其他领域中都非常常见。
SSM作为一个连续时间系统,其难以直接集成到现代深度学习算法中:
这里再贴上状态方程公式
x
˙
(
t
)
=
A
x
(
t
)
+
B
u
(
t
)
(3)
\dot{x}(t)=A x(t)+B u(t) \tag{3}
x˙(t)=Ax(t)+Bu(t)(3)
为了进行离散化,我们首先要对状态方程(3)进行积分。
在离散化连续状态方程的过程中,积分是一个关键步骤,因为它涉及到状态变量随时间的累积效应,我们需要考虑在每个离散时间步长内状态变量是如何累积变化的。
在离散时间系统中,我们不能直接处理导数,因为离散时间点上没有导数的概念。相反,我们需要考虑在每个时间步长内状态变量的累积变化。这可以通过对连续时间积分进行离散化来实现,即将连续时间的积分转换为离散时间的求和。
在实际的数值模拟中,我们通常使用数值积分方法(如梯形法则、矩形法则、辛普森法则等)来近似连续时间积分。这些方法允许我们在离散时间点上近似连续时间的累积效应,从而得到离散时间状态方程。这个转换过程涉及到将连续时间的导数项替换为离散时间的差分项,这通常涉及到指数函数和采样间隔 T T T 的计算。
在式(3)中,假设我们直接对
x
˙
(
t
)
\dot{x}(t)
x˙(t)进行积分的话,结果如下:
x
(
t
)
=
x
(
0
)
+
∫
0
t
(
A
x
(
τ
)
+
B
u
(
τ
)
)
d
τ
(4)
x(t)=x(0)+\int_0^t(A x(\tau)+B u(\tau)) d \tau \tag{4}
x(t)=x(0)+∫0t(Ax(τ)+Bu(τ))dτ(4)
此时,积分项中会包含
x
(
τ
)
x(\tau)
x(τ)项本身,由于我们是离散系统,我们是无法获取在一个连续的时刻(
0
→
t
0\rightarrow t
0→t)内所有的
x
(
τ
)
x(\tau)
x(τ)值的,因此无法完成该积分结果的计算。
对于离散系统来说,我们希望将公式(4)这个积分表达式转变为以下形式:
x
(
k
+
1
)
=
x
(
k
)
+
∑
i
=
0
k
(
A
x
(
i
)
+
B
u
(
i
)
)
Δ
t
(5)
x(k+1)=x(k)+\sum_{i=0}^k(A x(i)+B u(i)) \Delta t \tag{5}
x(k+1)=x(k)+i=0∑k(Ax(i)+Bu(i))Δt(5)
这个形式要求我们对公式(3)进行一些改造,目标是消除
x
˙
(
t
)
\dot{x}(t)
x˙(t)表达式中的
x
(
t
)
x(t)
x(t)本身。
为了消除 x ˙ ( t ) \dot{x}(t) x˙(t)表达式中的 x ( t ) x(t) x(t)本身,我们通常会构造一个新的函数 α ( t ) x ( t ) \alpha(t)x(t) α(t)x(t),通过对这个新函数进行求导,来简化相应的导数项。
我们对 α ( t ) x ( t ) \alpha(t)x(t) α(t)x(t)进行求导
d
d
t
[
α
(
t
)
x
(
t
)
]
=
α
(
t
)
x
˙
(
t
)
+
x
(
t
)
d
α
(
t
)
d
t
(6)
\frac{d}{d t}[\alpha(t) x(t)]=\alpha(t) \dot{x}(t)+x(t) \frac{d \alpha(t)}{d t} \tag{6}
dtd[α(t)x(t)]=α(t)x˙(t)+x(t)dtdα(t)(6)
我们将公式(3)代入到公式(6)中,替换
x
˙
(
t
)
\dot{x}(t)
x˙(t):
d
d
t
[
α
(
t
)
x
(
t
)
]
=
α
(
t
)
(
A
x
(
t
)
+
B
u
(
t
)
)
+
x
(
t
)
d
α
(
t
)
d
t
(7)
\frac{d}{d t}[\alpha(t) x(t)]=\alpha(t) (A x(t)+B u(t))+x(t) \frac{d \alpha(t)}{d t} \tag{7}
dtd[α(t)x(t)]=α(t)(Ax(t)+Bu(t))+x(t)dtdα(t)(7)
我们进一步对公式(7)进行改写,合并
x
(
t
)
x(t)
x(t)的相关系数:
d
d
t
[
α
(
t
)
x
(
t
)
]
=
(
A
α
(
t
)
+
d
α
(
t
)
d
t
)
x
(
t
)
+
B
α
(
t
)
u
(
t
)
(8)
\frac{d}{d t}[\alpha(t) x(t)]=(A\alpha(t) + \frac{d \alpha(t)}{d t})x(t)+B \alpha(t) u(t) \tag{8}
dtd[α(t)x(t)]=(Aα(t)+dtdα(t))x(t)+Bα(t)u(t)(8)
由于我们的目的是消除导数项中的
x
(
t
)
x(t)
x(t),因此,我们令
x
(
t
)
x(t)
x(t)的系数项为0即可:
A
α
(
t
)
+
d
α
(
t
)
d
t
=
0
(9)
A\alpha(t) + \frac{d \alpha(t)}{d t} = 0 \tag{9}
Aα(t)+dtdα(t)=0(9)
此时,我们可以得到
α
(
t
)
\alpha(t)
α(t)的表达式:
α
(
t
)
=
e
−
A
t
(10)
\alpha(t)=e^{-At} \tag{10}
α(t)=e−At(10)
将
α
(
t
)
\alpha(t)
α(t)的表达式代入公式(8)可以得到:
d
d
t
[
e
−
A
t
x
(
t
)
]
=
B
e
−
A
t
u
(
t
)
(11)
\frac{d}{d t}[e^{-At} x(t)]=B e^{-At} u(t) \tag{11}
dtd[e−Atx(t)]=Be−Atu(t)(11)
这时我们已经完成了在导数项中消除
x
(
t
)
x(t)
x(t)的目标,对
e
−
A
t
x
(
t
)
e^{-At}x(t)
e−Atx(t)进行积分:
e
−
A
t
x
(
t
)
=
x
(
0
)
+
∫
0
t
e
−
A
τ
B
u
(
τ
)
d
τ
(12)
e^{-At}x(t)=x(0)+\int_0^t e^{-A\tau} B u(\tau) d \tau \tag{12}
e−Atx(t)=x(0)+∫0te−AτBu(τ)dτ(12)
对公式(12)进行整理:
x ( t ) = e A t x ( 0 ) + ∫ 0 t e A ( t − τ ) B u ( τ ) d τ (13) x(t)=e^{At}x(0)+\int_0^t e^{A(t-\tau)} B u(\tau) d \tau \tag{13} x(t)=eAtx(0)+∫0teA(t−τ)Bu(τ)dτ(13)
在离散系统中,我们需要将公式(13)转化为离散形式,大致步骤如下:
参数定义:采样时刻 t k t_k tk和 t k + 1 t_{k+1} tk+1,其中 k k k是采样索引, T T T是采样间隔,即 T = t k + 1 − t k T=t_{k+1}-t_k T=tk+1−tk
积分区间离散化:在连续时间积分中,我们通常有一个积分区间,例如从 t t t 到 t + △ t t+\triangle{t} t+△t。在离散时间系统中,我们需要将这个区间划分为 k k k 个等长的子区间,每个子区间的长度为 T T T。
在某个子区间内,公式(13)的形式变为:
x
(
t
k
+
1
)
=
e
A
(
t
k
+
1
−
t
k
)
x
(
t
k
)
+
∫
t
k
t
k
+
1
e
A
(
t
k
+
1
−
τ
)
B
u
(
τ
)
d
τ
(14)
x(t_{k+1})=e^{A(t_{k+1}-t_k)}x(t_{k})+\int_{t_{k}}^{t_{k+1}} e^{A(t_{k+1}-\tau)} B u(\tau) d \tau \tag{14}
x(tk+1)=eA(tk+1−tk)x(tk)+∫tktk+1eA(tk+1−τ)Bu(τ)dτ(14)
近似积分:对于每个子区间来说,考虑使用数值积分方法来近似积分,这里考虑对
u
(
t
)
u(t)
u(t)应用零阶保持法,即假设
u
(
t
)
u(t)
u(t)在采样时刻
t
k
t_k
tk和
t
k
+
1
t_{k+1}
tk+1之间是恒定的,此时,我们可以将
u
(
t
)
u(t)
u(t)当做常数项从积分项中取出:
∫
t
k
t
k
+
1
e
A
(
t
−
τ
)
B
u
(
τ
)
d
τ
=
∫
t
k
t
k
+
1
e
A
(
t
k
+
1
−
τ
)
d
τ
B
u
(
t
k
)
(15)
\int_{t_{k}}^{t_{k+1}} e^{A(t-\tau)} B u(\tau) d \tau = \int_{t_{k}}^{t_{k+1}} e^{A(t_{k+1}-\tau)} d \tau B u(t_k) \tag{15}
∫tktk+1eA(t−τ)Bu(τ)dτ=∫tktk+1eA(tk+1−τ)dτBu(tk)(15)
离散时间状态方程构建:将公式(15)的积分结果代入到公式(14)中,同时使用
T
=
t
k
+
1
−
t
k
T=t_{k+1}-t_k
T=tk+1−tk进行化简,我们可以得到:
x
(
t
k
+
1
)
=
e
A
T
x
(
t
k
)
+
∫
t
k
t
k
+
1
e
A
(
t
k
+
1
−
τ
)
d
τ
B
u
(
t
k
)
(16)
x(t_{k+1})=e^{AT}x(t_{k})+\int_{t_{k}}^{t_{k+1}} e^{A(t_{k+1}-\tau)} d \tau Bu\left(t_k\right) \tag{16}
x(tk+1)=eATx(tk)+∫tktk+1eA(tk+1−τ)dτBu(tk)(16)
引入新变量
λ
=
t
k
+
1
−
τ
\lambda=t_{k+1}-\tau
λ=tk+1−τ,对原积分进行简化得到:
x
(
t
k
+
1
)
=
e
A
T
x
(
t
k
)
+
B
u
(
t
k
)
∫
0
T
e
A
τ
d
τ
(17)
x(t_{k+1})=e^{AT}x(t_{k})+Bu\left(t_k\right)\int_{0}^{T} e^{A\tau} d \tau \tag{17}
x(tk+1)=eATx(tk)+Bu(tk)∫0TeAτdτ(17)
这里涉及到矩阵作为指数的积分,这个部分我是查阅一些资料得到的结果:
∫
0
T
e
A
τ
d
τ
=
A
−
1
(
e
A
T
−
I
)
(18)
\int_{0}^{T} e^{A\tau} d \tau=A^{-1}(e^{AT}- I) \tag{18}
∫0TeAτdτ=A−1(eAT−I)(18)
最终我们得到了离散时间状态方程:
x
(
t
k
+
1
)
=
e
A
T
x
(
t
k
)
+
(
e
A
T
−
I
)
A
−
1
B
u
(
t
k
)
(19)
x(t_{k+1})=e^{AT}x(t_{k})+(e^{AT}- I)A^{-1}B u\left(t_k\right) \tag{19}
x(tk+1)=eATx(tk)+(eAT−I)A−1Bu(tk)(19)
对比公式(19)和VMamba论文中的离散化结果:
两者形式基本一致,至此,我们完成了SSM的离散化过程的完整推导。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。