当前位置:   article > 正文

(简单易学)mamba2核心ssd算法逻辑整理(基于mamba2-minimal实现)

mamba2-minimal

本文最后也有部分代码,更详细的请参考官方链接https://github.com/tommyip/mamba2-minimal/blob/main/mamba2.py这段代码实现的是 Mamba-2 模型的核心算法 Structured State Space Duality (SSD),用于高效地处理序列数据。让我们来逐步拆解这段代码,理解其背后的原理和运作机制。

1. SSD 算法概述

  • SSD 是一种基于状态空间模型的序列处理方法,其核心思想是将序列分解成若干个块 (chunk),并在块内和块间进行高效的信息传递。

  • SSD 利用了矩阵的低秩分解和指数衰减特性,将原本复杂的序列建模问题转化为一系列高效的矩阵乘法运算,从而显著降低了计算复杂度。

2. 代码解析

segsum 函数解析:

  • 功能: 计算一个特殊的累积和,用于生成一个 1-半可分矩阵(1-semiseparable matrix),该矩阵等效于一个标量状态空间模型 (SSM)。

  • 输入:

    • x: 输入张量。

    • device: 计算设备。

  • 步骤:

    1. 复制扩展: 将输入 x 沿着最后一个维度复制 T 次,生成一个新的张量。

    2. 生成掩码: 创建两个下三角掩码矩阵,分别用于控制累积和的范围。

    3. 计算累积和: 使用 torch.cumsum 函数计算累积和,并利用掩码矩阵控制计算范围。

    4. 填充负无穷: 将不在计算范围内的元素填充为负无穷,这是为了在后续计算指数时将其置零。

SSD 函数解析:

  • 输入:

    • x: 输入序列,形状为 (batch, seqlen, n_heads, d_head)。

    • A: 控制状态转移的矩阵,形状为 (batch, seqlen, n_heads)。

    • B: 将输入映射到状态空间的矩阵,形状为 (batch, seqlen, n_heads, d_state)。

    • C: 将状态空间映射到输出的矩阵,形状为 (batch, seqlen, n_heads, d_state)。

    • chunk_size: 将序列分割成的块的大小。

    • initial_states: 初始状态,可选。

    • device: 计算设备,例如 CPU 或 GPU。

  • 步骤:

    1. 数据重排 (Rearrange into chunks):

      • 将输入序列、A、B、C 矩阵按照 chunk_size 分割成若干个块。

      • 代码中使用了 rearrange 函数进行高效的数据重排操作。

    2. 块内计算 (Intra-chunk computation):

      • 计算每个块内的输出 Y_diag:

        • L = torch.exp(segsum(A, device=device)) 计算状态转移矩阵的累积效应。

        • Y_diag 通过 C, B, L 和 x 的矩阵乘法得到。

      • 计算每个块内的最终状态 states:

        • decay_states 计算状态的衰减情况。

        • states 通过 B, decay_states 和 x 的矩阵乘法得到。

    3. 块间循环 (Inter-chunk recurrence):

      • 使用 initial_states 初始化状态,或使用默认的零向量。

      • decay_chunk 计算块间的衰减情况。

      • new_states 通过 decay_chunk 和 states 的矩阵乘法得到,表示更新后的状态。

      • final_state 保存最后一个块的最终状态。

    4. 状态到输出的转换 (State-to-output conversion):

      • state_decay_out 计算状态衰减对输出的影响。

      • Y_off 通过 C, states 和 state_decay_out 的矩阵乘法得到,表示块间信息传递对输出的贡献。

    5. 输出合并 (Output combination):

      • 将块内输出 Y_diag 和块间输出 Y_off 相加,得到最终的输出 Y。

  • 输出:

    • Y: 处理后的输出序列,形状为 (batch, seqlen, n_heads, d_head)。

    • final_state: 最后一个块的最终状态。

3. 代码亮点:

  • 高效的矩阵运算: 代码大量使用了 torch.einsum 函数,这是一种高效的矩阵乘法运算方法,可以充分利用硬件资源,加速计算。

  • 并行计算: 代码中提到,步骤 1、2 和 4 可以并行计算,这为进一步提升性能提供了空间。

总结:

  • segsum 函数是一个辅助函数,用于生成 SSD 算法中所需的特殊矩阵。

  • ssd 函数实现了 SSD 算法,通过将序列分解成块并利用矩阵的低秩分解和指数衰减特性,实现了高效的序列建模。

  1. def segsum(x: Tensor, device: Device = None) -> Tensor:
  2. """Stable segment sum calculation.
  3. `exp(segsum(A))` produces a 1-semiseparable matrix, which is equivalent to a scalar SSM.
  4. Source: https://github.com/state-spaces/mamba/blob/219f03c840d5a44e7d42e4e728134834fddccf45/mamba_ssm/modules/ssd_minimal.py#L23-L32
  5. """
  6. T = x.size(-1)
  7. x = repeat(x, "... d -> ... d e", e=T)
  8. mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=-1)
  9. x = x.masked_fill(~mask, 0)
  10. x_segsum = torch.cumsum(x, dim=-2)
  11. mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=0)
  12. x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
  13. return x_segsum
  14. def ssd(x, A, B, C, chunk_size, initial_states=None, device: Device = None):
  15. """Structed State Space Duality (SSD) - the core of Mamba-2
  16. This is almost the exact same minimal SSD code from the blog post.
  17. Arguments
  18. x: (batch, seqlen, n_heads, d_head)
  19. A: (batch, seqlen, n_heads)
  20. B: (batch, seqlen, n_heads, d_state)
  21. C: (batch, seqlen, n_heads, d_state)
  22. Return
  23. y: (batch, seqlen, n_heads, d_head)
  24. Source
  25. 1. https://tridao.me/blog/2024/mamba2-part3-algorithm/
  26. 2. https://github.com/state-spaces/mamba/blob/219f03c840d5a44e7d42e4e728134834fddccf45/mamba_ssm/modules/ssd_minimal.py#L34-L78
  27. """
  28. assert x.shape[1] % chunk_size == 0
  29. # Rearrange into chunks
  30. # Step 1, 2 and 4 of SSD can be computed in parallel for each chunk across devices (sequence parallel)
  31. # This is not implemented and left as an exercise for the reader 声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
    推荐阅读
    相关标签