赞
踩
code:https://github.com/state-spaces/mamba
mamba ├── benchmarks │ └── benchmark_generation_mamba_simple.py // 示例模型的推理脚本 ├── csrc │ └── selective_scan // 选择性扫描的c++实现 ├── evals │ └── lm_harness_eval.py ├── mamba_ssm │ ├── models │ │ ├── config_mamba.py │ │ └── mixer_seq_simple.py // 使用mamba构建的一个完整的语言模型示例 │ ├── modules │ │ └── mamba_simple.py // mamba block的实现 │ ├── ops │ │ ├── triton │ │ │ ├── layernorm.py │ │ │ ├── selective_state_update.py │ │ └── selective_scan_interface.py // 选择性SSM层的实现 │ ├── utils │ │ ├── generation.py │ │ └── hf.py └── test └── ops ├── triton │ ├── test_selective_state_update.py └──test_selective_scan.py
代码中很多地方使用到了
if x.stride(-1) != 1:
x = x.contiguous()
这段代码的意思是:如果 x 在最后一个维度上的步长不是1(即元素在内存中不是紧密排列的),那么调用 .contiguous() 来重新排列 x,保证它在内存中是连续存储的。目的是为了确保后续操作的效率和正确性。
选择性SSM通过引入输入依赖的参数,如动态调整的 delta 、A、B、C 参数,来实现对序列数据的选择性处理。选择性扫描根据这些参数,执行具体的计算步骤,从而实现SSM的选择性功能。
是通过c++ ( mamba/csrc/selective_scan.cpp ) 实现的选择性扫描。
SelectiveScanFn 类的包装函数,简化使用。
是选择性扫描的参考实现,使用纯PyTorch操作。便于于理解操作的逻辑。
函数参数:
实现细节:
初始化和调整 delta: 对 delta 进行偏置调整(如果提供了 delta_bias)并应用 Softplus 激活(如果 delta_softplus 为 True),以确保 delta 为正值。
处理 B 和 C 的可变性:
根据 B 和 C 的维度,判断它们是否为变量(即是否依赖于输入)。如果是,将它们调整为适合状态空间模型的形状。
使用线性层由输入得到B,C的调整(由于该函数仅为参考实现所以没有使用线性层),再使用delta进行离散化
选择性扫描的主体逻辑:根据 delta 和参数 A、B、C 执行状态更新和输出生成。包括使用线性层由输入动态调整delta,B,C(由于该函数仅为参考所以没有使用线性层),再使用delta进行离散化,再进行状态更新和输出生成。x为状态,y为输出
生成输出序列: 根据更新后的状态和参数 C 生成输出序列。如果提供了 D 参数,还会考虑 u 对输出的直接影响。如果提供了额外信息 z,则在输出阶段对其进行处理。
Mamba模型的关键操作,不仅包含了选择性扫描操作,还整合了其他处理步骤,如一维因果卷积、线性变换等,使用了核融合,以实现Mamba模型的完整计算流程。
MambaInnerFn 类的包装函数,简化使用。
mamba_inner的参考实现,使用纯PyTorch操作。便于于理解操作的逻辑。
输入参数:
实现细节:
选择性的实现:
初始化方法(init):
forward 方法:
训练(卷积模式):
推理(递归模式):
step方法:
step方法的设计假定每次调用只处理一个时间步的输入(hidden_states的形状被断言为(batch_size, 1, feature_dim),表示一次只处理一个序列元素)。这意味着要生成整个序列,外部需要一个循环来反复调用step方法,每次生成一个序列元素。
import torch
from mamba_ssm import Mamba
batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
# 3 * expand * d_model^2 parameters
d_model=dim, # Model dimension d_model
d_state=16, # SSM state expansion factor
d_conv=4, # Local convolution width
expand=2, # Block expansion factor
).to("cuda")
y = model(x)
assert y.shape == x.shape
Mamba模块的输出序列的形状应该与输入序列的形状相同。意味着Mamba模块可以相对灵活地被插入到需要处理序列数据的神经网络架构中,例如可以替换Transformer的自注意力层,也可以作为额外的处理层被插入到现有的序列模型中,比如RNN、GRU或LSTM之后,用来进一步提取序列中的特征或增强模型对长期依赖性的捕捉能力。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。