赞
踩
manba的简单最小限度实现,和原始论文实现state-spaces/mamba (github.com)](https://github.com/state-spaces/mamba/tree/main)相比,为了可读性对参数没有很好的初始化,原论文用CUDA写了并行扫描,所以速度会快。
这里介绍Mamba Block的实现
之后的数据尺寸以(b, l, d_in) 或者(b, l, d_model, d_state)简单表示
参数及简写 | Mamba论文简写 |
---|---|
batch_size b | B |
序列长度 l | L |
隐藏维度 d / d_model | |
潜在状态维度 n / d_state | N |
扩展因子 expand | E |
d_in / d_inner | D |
数据依赖步长 Δ \Delta Δ / delta | |
delta秩 dt_rank |
根据forward简单梳理MambaBlock的结构
中间变量 | 来源 | shape |
---|---|---|
输入x | (b, l, d_model) | |
x_and_res | x经过输入映射后 | (b, l, 2* d_in) |
x | 切分后作为ssm分支输入 | (b, l, d_in) |
res | 切分后作为门控分支输入 | (b, l, d_in) |
y | 经过卷积,激活,ssm,门控后的输出 | (b, l, d_in) |
output | y经过输出映射后得到 | (b, l, d_model) |
def forward(self, x): (b, l, d) = x.shape x_and_res = self.in_proj(x) # shape (b, l, 2 * d_in) (x, res) = x_and_res.split(split_size=[self.args.d_inner, self.args.d_inner], dim=-1) x = rearrange(x, 'b l d_in -> b d_in l') x = self.conv1d(x)[:, :, :l] x = rearrange(x, 'b d_in l -> b l d_in') x = F.silu(x) y = self.ssm(x) y = y * F.silu(res) output = self.out_proj(y) return output
初始化主要初始了几个部分
组件定义
操作及简写 | 维度变换 |
---|---|
输入映射 in_proj | (b, l, d_model) -> (b, l, 2*d_in) |
序列变换 conv1d | 只取前l (b, d_in, l) -> (b, d_in, l) |
非线性激活 silu | |
输出映射 out_proj | (b, l, d_in) -> (b, l, d) |
ssm初始化
操作及简写 | 作用 |
---|---|
参数生成映射 x_proj | 生成数据依赖的参数B, C, Δ \Delta Δ |
delta映射 dt_proj | 将 Δ \Delta Δ从dt_rank映射到d_in |
矩阵A初始化 | 简单初始化 |
矩阵D初始化 | 简单初始化 |
def __init__(self, args: ModelArgs): super().__init__() self.args = args self.in_proj = nn.Linear(args.d_model, args.d_inner * 2, bias=args.bias) self.conv1d = nn.Conv1d( in_channels=args.d_inner, out_channels=args.d_inner, bias=args.conv_bias, kernel_size=args.d_conv, groups=args.d_inner, padding=args.d_conv - 1, ) # ssm模型的初始化部分 # x_proj takes in `x` and outputs the input-specific Δ, B, C self.x_proj = nn.Linear(args.d_inner, args.dt_rank + args.d_state * 2, bias=False) # dt_proj projects Δ from dt_rank to d_in self.dt_proj = nn.Linear(args.dt_rank, args.d_inner, bias=True) A = repeat(torch.arange(1, args.d_state + 1), 'n -> d n', d=args.d_inner) self.A_log = nn.Parameter(torch.log(A)) self.D = nn.Parameter(torch.ones(args.d_inner)) self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=args.bias)
这是我们数据处理流水线的搭建,这一部分是ssm模型参数定义,是ssm模型中相对于数据“不变”的部分。
SSM参数 | shape | 来源 |
---|---|---|
状态矩阵A | (d_in, n) | 在初始化中定义,非数据依赖 |
输入矩阵B | (b, l, n) | 由x_db1切分而来,因此数据依赖 |
输出矩阵C | (b, l, n) | 由x_db1切分而来,因此数据依赖 |
直接传递矩阵D | (d_in) | 在初始化中定义,非数据依赖 |
数据依赖步长 Δ \Delta Δ | (b, l, d_in) | 由x_db1切分而来,因此数据依赖 |
其中一部分变量初始化于class MambaBlock的初始化部分
中间变量及简写 | 来源 |
---|---|
数据生成变量 x_db1 | x经过参数映射x_proj生成 |
最终delta Δ \Delta Δ | 切分而来的 Δ \Delta Δ经过映射和softplus |
def ssm(self, x):
(d_in, n) = self.A_log.shape
A = -torch.exp(self.A_log.float()) # shape (d_in, n)
D = self.D.float()
x_dbl = self.x_proj(x) # (b, l, dt_rank + 2*n)
(delta, B, C) = x_dbl.split(split_size=[self.args.dt_rank, n, n], dim=-1) # delta: (b, l, dt_rank). B, C: (b, l, n)
delta = F.softplus(self.dt_proj(delta)) # (b, l, d_in)
y = self.selective_scan(x, delta, A, B, C, D) # This is similar to run_SSM(A, B, C, u) in The Annotated S4 [2]
return y
SSM参数 | shape |
---|---|
状态矩阵A | (d_in, n) |
输入矩阵B | (b, l, n) |
输出矩阵C | (b, l, n) |
直接传递矩阵D | (d_in) |
我们的数据流水线搭建好以后,接下来就要让它动起来,这一部分是数据处理的动态或者动力。
在这里, A A A使用ZOH零阶保持离散化, B B B则简化为欧拉离散化
前向欧拉离散化
x
k
=
(
I
+
Δ
k
A
)
x
k
−
1
+
Δ
k
B
⋅
u
k
x
(
t
+
Δ
)
=
(
I
+
Δ
A
)
x
(
t
)
+
Δ
B
⋅
u
(
t
)
零阶保持离散化
x
k
=
e
Δ
k
A
x
k
−
1
+
(
Δ
k
A
)
−
1
(
e
Δ
k
A
−
I
)
⋅
Δ
k
B
⋅
u
k
x
(
t
+
Δ
)
=
e
Δ
A
x
(
t
)
+
(
Δ
A
)
−
1
(
e
Δ
A
−
I
)
⋅
Δ
B
⋅
u
(
t
)
这里selective_scan是顺序形式,因此与原论文CUDA编写的并行感知算法相比要慢
def selective_scan(self, u, delta, A, B, C, D): (b, l, d_in) = u.shape n = A.shape[1] deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n')) deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n') # Perform selective scan (see scan_SSM() in The Annotated S4 [2]) x = torch.zeros((b, d_in, n), device=deltaA.device) ys = [] for i in range(l): x = deltaA[:, i] * x + deltaB_u[:, i] y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in') ys.append(y) y = torch.stack(ys, dim=1) # shape (b, l, d_in) y = y + u * D return y
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。