Mamba environment installation pitfalls: a summary of problems and solutions (Windows solved)

error: could not build wheels for causal-conv1d, which is required to install pyproject.toml-based projects



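Mamba-related papers have recently drawn a lot of attention. Although the Mamba paper itself was rejected by ICLR 2024, derivative models keep appearing, such as Vim and U-Mamba. While setting up their environments (version requirements: PyTorch 1.12+ and CUDA 11.6+), I found that following their installation instructions [1][2] runs into a large number of bugs, concentrated in causal-conv1d and mamba-ssm. All of them come down to version incompatibilities, so I am recording them here.
P.S. After in-depth discussions with readers, this post has been substantially expanded. If you run into new problems or find fixes, please let me know; solutions we verify together will be added to this post promptly.
Installing directly with pip, or downloading the project source and running setup, produces errors including, but not limited to: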

  1. Building wheel for causal-conv1d (setup.py) ... error
  2. error: command '/usr/bin/gcc' failed with exit code 1
  3. RuntimeError: Error compiling objects for extension
  4. ERROR: Could not build wheels for causal-conv1d, which is required to install pyproject.toml-based projects
  5. Connection timed out> [end of output]
  6. ModuleNotFoundError: No module named 'packaging'
  7. FileNotFoundError: [Errno 2] No such file or directory: '/usr/local/cuda/bin/nvcc'
  8. error: subprocess-exited-with-error




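  1. Use a Docker environment prepared by another user; see the post "Fixing causal_conv1d and mamba_ssm failing to install -> use the Mamba base-environment Docker image directly".
    Command: docker pull kom4cr0/cuda11.7-pytorch1.13-mamba1.1.1:1.1.1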

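  2. Download the project source directly and run setup. For details, see the posts "Cannot pip install causal_conv1d and mamba_ssm directly when running the Mamba project" and "Reproducing U-Mamba". I still could not get it installed this way, but the original author and some people in the GitHub issues succeeded.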

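    git clone https://github.com/Dao-AILab/causal-conv1d.git
    cd causal-conv1d
    git checkout v1.1.1 # latest version tag at the time of writing
    CAUSAL_CONV1D_FORCE_BUILD=TRUE pip install . # build and install causal-conv1d from this checkout
    cd ..
    git clone https://github.com/state-spaces/mamba.git
    cd mamba
    git checkout v1.1.1 # latest version tag at the time of writing
    pip install . # option 1: download and install the prebuilt whl; use only one of the two options
    MAMBA_FORCE_BUILD=TRUE pip install . # option 2: force a local build; this VAR=VALUE syntax is not recognized on Windows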
  3. Inspired by the post "flash-attention pitfalls: managing CUDA with conda", reorder the installation: install CUDA first, including cuda-nvcc. The working steps are:

    conda create -n your_env_name python=3.10.13
    conda activate your_env_name
    conda install cudatoolkit==11.8 -c nvidia
    pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118
    conda install -c "nvidia/label/cuda-11.8.0" cuda-nvcc
    conda install packaging
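    pip install causal-conv1d==1.1.1  # pick the version you need, or omit the pin to install the latest
    pip install mamba-ssm==1.1.3.post1  # pick the version you need; 1.1 and 1.2 were found to have incompatible functions; without a pin the latest is installed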
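Once the commands above finish, it is worth confirming that the pinned versions really ended up in the environment before moving on. The sketch below reads installed package metadata with the standard library's `importlib.metadata` (Python 3.8+). The expected pins are copied from the commands above, and `installed_versions` / `check_pins` are hypothetical helpers, not part of any package.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, List, Optional

# Pins copied from the install commands above; adjust them to your own choice.
EXPECTED = {
    "torch": "2.1.1",
    "causal-conv1d": "1.1.1",
    "mamba-ssm": "1.1.3.post1",
}


def installed_versions(dists: List[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if it is missing."""
    found: Dict[str, Optional[str]] = {}
    for name in dists:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


def check_pins(installed: Dict[str, Optional[str]],
               expected: Dict[str, str]) -> List[str]:
    """Return human-readable mismatches between installed and expected versions."""
    problems = []
    for name, want in expected.items():
        got = installed.get(name)
        if got is None:
            problems.append(f"{name}: not installed (expected {want})")
        elif got.split("+")[0] != want:
            # Ignore a local version tag such as "+cu118" when comparing.
            problems.append(f"{name}: found {got}, expected {want}")
    return problems


if __name__ == "__main__":
    for line in check_pins(installed_versions(list(EXPECTED)), EXPECTED) or ["all pins match"]:
        print(line)
```

Run it with the virtual environment's own `python`. The local tag is stripped on purpose, because torch installed from the cu118 index reports itself as `2.1.1+cu118`.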
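On Windows, the options are:
  1. Use the Docker environment;
  2. Modify the code following the steps in the 20240313 update. Note that because this skips the core CUDA acceleration, it runs, but slowly;
  3. Follow the 20240329 update: use WSL on Windows, or a Linux virtual machine.
  4. All the Windows installation steps in this post now have a better solution that does not bypass CUDA acceleration; see the series of posts in the navigation at the top: "Mamba environment installation pitfalls on Windows: problems and solutions (no need to bypass selective_scan_cuda)", "Vim environment installation pitfalls on Windows: problems and solutions", and "VMamba installation tutorial on Windows (no need to change the CUDA version in the base environment, with acceleration)".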


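  1. If the second-to-last step of method 3 (pip install causal-conv1d) fails, build it from the project source instead.

  2. On Windows, installing mamba-ssm diverges from method 3 after its third-to-last step (conda install packaging): first install the 'triton' package, then build causal-conv1d and mamba from source, and patch the mamba source as follows: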

    # Original (mamba_ssm/ops/selective_scan_interface.py):
    def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                          return_last_state=False):
        """if return_last_state is True, returns (out, last_state)
        last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
        not considered in the backward pass.
        """
        return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
    def mamba_inner_fn(
        xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
        out_proj_weight, out_proj_bias,
        A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
        C_proj_bias=None, delta_softplus=True
    ):
        return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                                  out_proj_weight, out_proj_bias,
                                  A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
    # Modified: call the pure-PyTorch reference implementations instead of the CUDA kernels
    def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                          return_last_state=False):
        """if return_last_state is True, returns (out, last_state)
        last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
        not considered in the backward pass.
        """
        return selective_scan_ref(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
    def mamba_inner_fn(
        xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
        out_proj_weight, out_proj_bias,
        A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
        C_proj_bias=None, delta_softplus=True
    ):
        return mamba_inner_ref(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                               out_proj_weight, out_proj_bias,
                               A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
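    I also have Windows whl files that I built with the source already patched to bypass selective_scan_cuda: mamba-ssm-1.1.3 and mamba_ssm-1.2.0.post1. You can download and install them directly, or contact me on WeChat to get them. Update: there is no longer any need to bypass CUDA acceleration; see the series of posts in the navigation at the top!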
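To see what the fallback actually computes, and why it is slow, here is a stripped-down, pure-Python sketch of the recurrence inside `selective_scan_ref` (the `deltaA`, `deltaB_u` and per-step `x`/`y` updates shown in the Vim listing further down) for one batch element and one channel, with real-valued `A` and per-step `B`/`C`. It is an illustration under those assumptions only: `delta_bias`, softplus, the `z` gate and complex `A` are left out, and the name `scan_one_channel` is hypothetical.

```python
import math
from typing import List, Sequence


def scan_one_channel(u: Sequence[float], delta: Sequence[float], A: Sequence[float],
                     B: Sequence[Sequence[float]], C: Sequence[Sequence[float]],
                     D: float = 0.0) -> List[float]:
    """Sequential selective scan for a single channel.

    u, delta: length-L inputs; A: length-N diagonal state matrix;
    B, C: L x N per-step projections; D: skip-connection weight.
    """
    n_state = len(A)
    x = [0.0] * n_state                      # hidden state starts at zero
    out = []
    for t in range(len(u)):                  # one Python iteration per time step
        for n in range(n_state):
            decay = math.exp(delta[t] * A[n])               # discretized A: exp(delta * A)
            x[n] = decay * x[n] + delta[t] * B[t][n] * u[t]  # x_t = dA * x_{t-1} + delta * B_t * u_t
        y = sum(x[n] * C[t][n] for n in range(n_state))     # y_t = C_t . x_t
        out.append(y + D * u[t])                            # add the skip term D * u_t
    return out
```

With `A = [0.0]` (no decay) and all-ones `B`/`C`, the output is simply the running sum of `u`. The real `selective_scan_ref` vectorizes across batch, channel and state, but it still loops over the sequence length in Python (`for i in range(u.shape[2])`), which is why bypassing the fused CUDA kernel runs correctly but slowly.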


Besides Docker and patching/building the source, some people have also run the latest Mamba models on Windows through WSL; see:


1. About the CUDA version

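Many readers, even after installing cuda-nvcc, still get a wrong-CUDA-version error when installing causal-conv1d. This happens because the environment may still contain a CUDA_HOME (Linux) or CUDA_PATH (Windows) variable pointing to the wrong location. Check the following: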

nvcc -V
python -c "import torch.utils.cpp_extension; print(torch.utils.cpp_extension.CUDA_HOME)"
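On Linux, if the second command prints a path in the base environment, run which nvcc to get the correct path inside the virtual environment, then set CUDA_HOME in .bashrc with export CUDA_HOME='....' (CUDA_HOME is the directory that contains bin/nvcc, so drop the trailing /bin/nvcc from the which output), run source ~/.bashrc to apply it, and then continue the installation.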

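On Windows, run where nvcc to find the correct path inside the virtual environment, set the system environment variable CUDA_PATH to the directory that contains that bin folder (drop \bin\nvcc.exe from the path, because PyTorch's extension builder looks for nvcc under CUDA_PATH\bin), and then continue the installation.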
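The two checks above can be combined into one script that reports which nvcc is on PATH, which CUDA root PyTorch's extension builder will use, and whether they agree. This is a sketch that mirrors the lookup rule of `torch.utils.cpp_extension` (use `CUDA_HOME`/`CUDA_PATH` if set, otherwise take the directory two levels above `nvcc`); the helper names are hypothetical.

```python
import os
import re
import shutil
import subprocess
from typing import Optional


def cuda_home_from_nvcc(nvcc_path: str) -> str:
    """<root>/bin/nvcc -> <root>: the directory CUDA_HOME / CUDA_PATH should point to."""
    return os.path.dirname(os.path.dirname(nvcc_path))


def parse_nvcc_release(text: str) -> Optional[str]:
    """Extract e.g. '11.8' from 'Cuda compilation tools, release 11.8, V11.8.89'."""
    match = re.search(r"release (\d+\.\d+)", text)
    return match.group(1) if match else None


def same_path(a: str, b: str) -> bool:
    """Compare two paths after normalizing case (Windows) and relative segments."""
    return os.path.normcase(os.path.abspath(a)) == os.path.normcase(os.path.abspath(b))


def diagnose() -> None:
    nvcc = shutil.which("nvcc")
    print("nvcc on PATH:", nvcc)
    if nvcc:
        result = subprocess.run([nvcc, "-V"], capture_output=True, text=True)
        print("nvcc release:", parse_nvcc_release(result.stdout))
    try:
        import torch
        from torch.utils.cpp_extension import CUDA_HOME
    except ImportError:
        print("torch is not importable in this environment")
        return
    print("torch built with CUDA:", torch.version.cuda)
    print("CUDA_HOME used by torch:", CUDA_HOME)
    if nvcc and CUDA_HOME and not same_path(CUDA_HOME, cuda_home_from_nvcc(nvcc)):
        print("Mismatch: point CUDA_HOME (Linux) / CUDA_PATH (Windows) at",
              cuda_home_from_nvcc(nvcc))


if __name__ == "__main__":
    diagnose()
```

If it reports a mismatch, fix the variable as described above and open a new shell before re-running the installation, since environment changes only apply to newly started processes.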


2. Setup hangs with no progress

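On Linux, setup appears to hang because it is downloading the matching prebuilt *.whl file from GitHub, which needs a proxy if GitHub is blocked or slow for you. You can wait for the download to fail, at which point it prints the exact URL, then download that file manually and pip install the whl. You can also download and install the whl directly:
causal_conv1d download link: https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.1.1/causal_conv1d-1.1.1+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
mamba_ssm download link: https://github.com/state-spaces/mamba/releases/download/v1.1.3.post1/mamba_ssm-1.1.3.post1+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl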
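Both links follow the same naming pattern: package version, CUDA version, PyTorch major.minor, C++11 ABI flag, Python tag and platform. The sketch below rebuilds the two URLs above from those parts so you can work out which file matches your own environment. The pattern is inferred from these two links only, the helper names are hypothetical, `torch._C._GLIBCXX_USE_CXX11_ABI` is a private attribute, and a computed URL only helps if that exact combination was actually published on the release page.

```python
import sys


def wheel_url(repo: str, dist: str, version: str, cuda: str, torch_mm: str,
              cxx11_abi: bool, py_tag: str = "cp310",
              platform: str = "linux_x86_64") -> str:
    """Assemble a GitHub release wheel URL following the pattern of the links above."""
    cu = "cu" + cuda.replace(".", "")                  # "11.8" -> "cu118"
    abi = "TRUE" if cxx11_abi else "FALSE"
    filename = (f"{dist}-{version}+{cu}torch{torch_mm}cxx11abi{abi}"
                f"-{py_tag}-{py_tag}-{platform}.whl")
    return f"https://github.com/{repo}/releases/download/v{version}/{filename}"


def local_tags() -> dict:
    """Read the matching tags from the current environment (needs a CUDA build of torch)."""
    import torch
    return {
        "cuda": torch.version.cuda,                                          # e.g. "11.8"
        "torch_mm": ".".join(torch.__version__.split("+")[0].split(".")[:2]),  # e.g. "2.1"
        "cxx11_abi": bool(torch._C._GLIBCXX_USE_CXX11_ABI),                  # private attribute
        "py_tag": f"cp{sys.version_info.major}{sys.version_info.minor}",      # e.g. "cp310"
    }
```

For example, `wheel_url("state-spaces/mamba", "mamba_ssm", "1.1.3.post1", **local_tags())` gives the mamba-ssm file name for the current environment, which you can then download manually and install with `pip install <file>.whl`.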


1. causal_conv1d_cuda or selective_scan_cuda still fails after a successful install

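Many readers find that even after causal-conv1d installs successfully, import causal_conv1d_cuda fails with a message that there is no module named causal_conv1d_cuda; or, after mamba-ssm installs successfully, import selective_scan_cuda fails with no module named selective_scan_cuda. This is still caused by an incompatible CUDA environment. These two names are compiled extension modules (.so files on Linux, .pyd files on Windows). They are not part of the installed Python source; they sit directly under xxxx/envs/xxxx/lib/python3.xx/site-packages/, as causal_conv1d_cuda.cpython-310-x86_64-linux-gnu.so and selective_scan_cuda.cpython-310-x86_64-linux-gnu.so respectively (file names from my environment).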
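Because the failing names are compiled extension modules rather than Python source, it helps to separate two cases: the `.so`/`.pyd` file is missing, or it exists but fails to load (for example with `undefined symbol`). The standard library can tell these apart: `importlib.util.find_spec` locates a top-level module without executing it, and a real import then surfaces the loader error. The helper names below are hypothetical.

```python
import importlib
import importlib.util
from typing import Optional


def locate_extension(name: str) -> Optional[str]:
    """Return the file that would be loaded for a top-level module, or None if absent."""
    spec = importlib.util.find_spec(name)
    return spec.origin if spec is not None else None


def try_import(name: str) -> Optional[str]:
    """Import the module; return None on success, or the error text on failure."""
    try:
        importlib.import_module(name)
    except ImportError as exc:          # ModuleNotFoundError is a subclass of ImportError
        return f"{type(exc).__name__}: {exc}"
    return None


if __name__ == "__main__":
    try:
        import torch  # noqa: F401  load libtorch first; the extensions link against it
    except ImportError:
        print("torch is not importable; the CUDA extensions cannot load without it")
    for mod in ("causal_conv1d_cuda", "selective_scan_cuda"):
        path = locate_extension(mod)
        if path is None:
            print(f"{mod}: no compiled file found on sys.path")
        else:
            print(f"{mod}: {path}")
            print("  import ->", try_import(mod) or "OK")
```

"No compiled file found" points to a failed or skipped build, while a file that exists but raises `undefined symbol` points to a version mismatch, as discussed next.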

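In this case I recommend force-building from source locally (CAUSAL_CONV1D_FORCE_BUILD=TRUE pip install . or MAMBA_FORCE_BUILD=TRUE pip install .). Some readers then succeed; others still get an error, but now a specific one, for example ImportError xxxx selective_scan_cuda.cpython-xxx-linux-gnu.so undefined symbol (you can replace the file with a prebuilt one, selective-scan-cuda-linux-gnu.so) or ImportError xxxx causal_conv1d_cuda.cpython-xxx-linux-gnu.so undefined symbol (you can replace it with the prebuilt causal-conv1d-cuda.cpython-310-x86-64-linux-gnu.so). Since everyone's environment differs, fix each case according to its specific error. [All files are available from me on WeChat.]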
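The `VAR=TRUE pip install .` prefix is POSIX-shell syntax, which is why the Windows command prompt does not recognize it. One portable way to force a local build is to set the variable from Python and call pip through the current interpreter. A sketch, assuming you run it from inside the cloned `causal-conv1d` or `mamba` directory; the helper name is hypothetical, while the two variable names are the ones used in this post.

```python
import os
import subprocess
import sys
from typing import Dict, List, Tuple

# Environment variables that make each setup.py compile locally instead of
# fetching a prebuilt wheel (names as used in this post).
FORCE_BUILD_VARS = {
    "causal-conv1d": "CAUSAL_CONV1D_FORCE_BUILD",
    "mamba-ssm": "MAMBA_FORCE_BUILD",
}


def force_build_command(package: str, source_dir: str = ".") -> Tuple[List[str], Dict[str, str]]:
    """Return the pip command and a copied environment with the force-build flag set."""
    env = dict(os.environ)                      # copy, so the parent process is untouched
    env[FORCE_BUILD_VARS[package]] = "TRUE"
    cmd = [sys.executable, "-m", "pip", "install", source_dir]
    return cmd, env


if __name__ == "__main__":
    cmd, env = force_build_command("mamba-ssm")
    subprocess.run(cmd, env=env, check=True)
```

Using `sys.executable -m pip` guarantees the build goes into the interpreter of the active virtual environment rather than whichever `pip` happens to be first on PATH.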

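A .so undefined symbol error is usually caused by a CUDA version mismatch; see the 20240418 update, "About the CUDA version", in this post. For example, inside the virtual environment which nvcc finds the environment's own CUDA, but python -c "import torch.utils.cpp_extension; print(torch.utils.cpp_extension.CUDA_HOME)" prints the base environment's /usr/local/cuda.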

Alternatively, you can follow the Windows approach and patch the source files to bypass the calls to causal_conv1d_cuda and selective_scan_cuda.
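When causal_conv1d_cuda is bypassed, `causal_conv1d_fn` falls back to `causal_conv1d_ref`, which upstream is essentially a depthwise `F.conv1d` with `width - 1` padding, truncated to the input length, followed by an optional SiLU. The pure-Python sketch below shows that computation for one channel of one sequence so the "causal" part is explicit; the function name is hypothetical, and it omits batching, `seq_idx` and initial/final states.

```python
import math
from typing import List, Optional, Sequence


def causal_conv1d_channel(x: Sequence[float], w: Sequence[float], bias: float = 0.0,
                          activation: Optional[str] = None) -> List[float]:
    """Causal 1-D convolution of one channel with implicit left zero-padding.

    out[t] = bias + sum_k w[k] * x[t + k - (width - 1)], with x[i] = 0 for i < 0,
    so w[-1] multiplies the current input and earlier weights look further back.
    """
    width = len(w)
    out = []
    for t in range(len(x)):
        acc = bias
        for k in range(width):
            i = t + k - (width - 1)
            if i >= 0:                           # positions before the start count as zero
                acc += w[k] * x[i]
        if activation in ("silu", "swish"):
            acc = acc / (1.0 + math.exp(-acc))   # SiLU(a) = a * sigmoid(a)
        out.append(acc)
    return out
```

Output step `t` never reads inputs after `t`, which is what makes the convolution causal and why the fallback gives the same results as the CUDA kernel, only more slowly.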

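2. Bypassing causal_conv1d_cuda and selective_scan_cuda on Windows (update: there is no longer any need to bypass CUDA acceleration; see the series of posts in the navigation at the top!)

  • causal_conv1d_cuda: in causal_conv1d_interface.py, comment out import causal_conv1d_cuda, and change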

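    # Before:
    def causal_conv1d_fn(
        x,
        weight,
        bias=None,
        seq_idx=None,
        initial_states=None,
        return_final_states=False,
        final_states_out=None,
        activation=None,
    ):
        """
        x: (batch, dim, seqlen)
        weight: (dim, width)
        bias: (dim,)
        seq_idx: (batch, seqlen)
        initial_states: (batch, dim, width - 1)
        final_states_out: (batch, dim, width - 1), to be written to
        activation: either None or "silu" or "swish"

        out: (batch, dim, seqlen)
        """
        return CausalConv1dFn.apply(
            x, weight, bias, seq_idx, initial_states,
            return_final_states, final_states_out, activation,
        )
    # After: keep the signature and docstring; only the return statement changes to the
    # pure-PyTorch reference. causal_conv1d_ref takes no seq_idx, so pass the rest by keyword.
    def causal_conv1d_fn(x, weight, bias=None, seq_idx=None, initial_states=None,
                         return_final_states=False, final_states_out=None, activation=None):
        return causal_conv1d_ref(
            x, weight, bias,
            initial_states=initial_states,
            return_final_states=return_final_states,
            final_states_out=final_states_out,
            activation=activation,
        )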
  • causal_conv1d_cuda for Vim (Vision Mamba: Efficient Visual Representation Learning with Bidirectional State Space Model):

    # Before:
    def causal_conv1d_fn(x, weight, bias=None, activation=None):
        """
        x: (batch, dim, seqlen)
        weight: (dim, width)
        bias: (dim,)
        activation: either None or "silu" or "swish"

        out: (batch, dim, seqlen)
        """
        return CausalConv1dFn.apply(x, weight, bias, activation)
    # After (activation passed by keyword so it cannot land in another positional slot):
    def causal_conv1d_fn(x, weight, bias=None, activation=None):
        """
        x: (batch, dim, seqlen)
        weight: (dim, width)
        bias: (dim,)
        activation: either None or "silu" or "swish"

        out: (batch, dim, seqlen)
        """
        return causal_conv1d_ref(x, weight, bias, activation=activation)
  • selective_scan_cuda: see the 20240313 update

  • selective_scan_cuda for Vim (Vision Mamba: Efficient Visual Representation Learning with Bidirectional State Space Model): Vim ships its own modified version of this file, so running Vision Mamba on Windows requires patching it accordingly to avoid this function, as shown below.

    # Copyright (c) 2023, Tri Dao, Albert Gu.
    import torch
    import torch.nn.functional as F
    from torch.cuda.amp import custom_bwd, custom_fwd
    from einops import rearrange, repeat
    try:
        from causal_conv1d import causal_conv1d_fn
        import causal_conv1d_cuda
    except ImportError:
        causal_conv1d_fn = None
        causal_conv1d_cuda = None
    # import selective_scan_cuda
    class SelectiveScanFn(torch.autograd.Function):
        @staticmethod
        def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                    return_last_state=False):
            if u.stride(-1) != 1:
                u = u.contiguous()
            if delta.stride(-1) != 1:
                delta = delta.contiguous()
            if D is not None:
                D = D.contiguous()
            if B.stride(-1) != 1:
                B = B.contiguous()
            if C.stride(-1) != 1:
                C = C.contiguous()
            if z is not None and z.stride(-1) != 1:
                z = z.contiguous()
            if B.dim() == 3:
                B = rearrange(B, "b dstate l -> b 1 dstate l")
                ctx.squeeze_B = True
            if C.dim() == 3:
                C = rearrange(C, "b dstate l -> b 1 dstate l")
                ctx.squeeze_C = True
            out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
            ctx.delta_softplus = delta_softplus
            ctx.has_z = z is not None
            last_state = x[:, :, -1, 1::2]  # (batch, dim, dstate)
            if not ctx.has_z:
                ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
                return out if not return_last_state else (out, last_state)
            else:
                ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
                out_z = rest[0]
                return out_z if not return_last_state else (out_z, last_state)
        @staticmethod
        def backward(ctx, dout, *args):
            if not ctx.has_z:
                u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
                z = None
                out = None
            else:
                u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
            if dout.stride(-1) != 1:
                dout = dout.contiguous()
            # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
            # backward of selective_scan_cuda with the backward of chunk).
            # Here we just pass in None and dz will be allocated in the C++ code.
            du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
                u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus,
                False  # option to recompute out_z, not used here
            )
            dz = rest[0] if ctx.has_z else None
            dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
            dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
            return (du, ddelta, dA, dB, dC,
                    dD if D is not None else None,
                    dz,
                    ddelta_bias if delta_bias is not None else None,
                    None,
                    None)
    # Modified: use the pure-PyTorch reference instead of the CUDA-backed SelectiveScanFn
    def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                          return_last_state=False):
        """if return_last_state is True, returns (out, last_state)
        last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
        not considered in the backward pass.
        """
        return selective_scan_ref(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
    def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                           return_last_state=False):
        """
        u: r(B D L)
        delta: r(B D L)
        A: c(D N) or r(D N)
        B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
        C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
        D: r(D)
        z: r(B D L)
        delta_bias: r(D), fp32

        out: r(B D L)
        last_state (optional): r(B D dstate) or c(B D dstate)
        """
        dtype_in = u.dtype
        u = u.float()
        delta = delta.float()
        if delta_bias is not None:
            delta = delta + delta_bias[..., None].float()
        if delta_softplus:
            delta = F.softplus(delta)
        batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
        is_variable_B = B.dim() >= 3
        is_variable_C = C.dim() >= 3
        if A.is_complex():
            if is_variable_B:
                B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
            if is_variable_C:
                C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
        else:
            B = B.float()
            C = C.float()
        x = A.new_zeros((batch, dim, dstate))
        ys = []
        deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
        if not is_variable_B:
            deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
        else:
            if B.dim() == 3:
                deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
            else:
                B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
                deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
        if is_variable_C and C.dim() == 4:
            C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
        last_state = None
        for i in range(u.shape[2]):
            x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
            if not is_variable_C:
                y = torch.einsum('bdn,dn->bd', x, C)
            else:
                if C.dim() == 3:
                    y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
                else:
                    y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
            if i == u.shape[2] - 1:
                last_state = x
            if y.is_complex():
                y = y.real * 2
            ys.append(y)
        y = torch.stack(ys, dim=2)  # (batch dim L)
        out = y if D is None else y + u * rearrange(D, "d -> d 1")
        if z is not None:
            out = out * F.silu(z)
        out = out.to(dtype=dtype_in)
        return out if not return_last_state else (out, last_state)
    class MambaInnerFnNoOutProj(torch.autograd.Function):
        @staticmethod
        @custom_fwd
        def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                    A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
                    C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1):
            """
                 xz: (batch, dim, seqlen)
            """
            assert checkpoint_lvl in [0, 1]
            L = xz.shape[-1]
            delta_rank = delta_proj_weight.shape[1]
            d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
            if torch.is_autocast_enabled():
                x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
                delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
            if xz.stride(-1) != 1:
                xz = xz.contiguous()
            conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
            x, z = xz.chunk(2, dim=1)
            conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
            conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True)
            # We're being very careful here about the layout, to avoid extra transposes.
            # We want delta to have d as the slowest moving dimension
            # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
            x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight)  # (bl d)
            delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L)
            ctx.is_variable_B = B is None
            ctx.is_variable_C = C is None
            ctx.B_proj_bias_is_None = B_proj_bias is None
            ctx.C_proj_bias_is_None = C_proj_bias is None
            if B is None:  # variable B
                B = x_dbl[:, delta_rank:delta_rank + d_state]  # (bl dstate)
                if B_proj_bias is not None:
                    B = B + B_proj_bias.to(dtype=B.dtype)
                if not A.is_complex():
                    # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
                    B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
                else:
                    B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
                if B.stride(-1) != 1:
                    B = B.contiguous()
            if C is None:  # variable C
                C = x_dbl[:, -d_state:]  # (bl dstate)
                if C_proj_bias is not None:
                    C = C + C_proj_bias.to(dtype=C.dtype)
                if not A.is_complex():
                    # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
                    C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
                else:
                    C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
                if C.stride(-1) != 1:
                    C = C.contiguous()
            if D is not None:
                D = D.contiguous()
            out, scan_intermediates, out_z = selective_scan_cuda.fwd(
                conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
            )
            ctx.delta_softplus = delta_softplus
            ctx.checkpoint_lvl = checkpoint_lvl
            if checkpoint_lvl >= 1:  # Will recompute conv1d_out and delta in the backward pass
                conv1d_out, delta = None, None
            ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,
                                  delta_proj_weight, conv1d_out, delta,
                                  A, B, C, D, delta_bias, scan_intermediates, out)
            # return rearrange(out_z, "b d l -> b l d")
            return out_z
        @staticmethod
        @custom_bwd
        def backward(ctx, dout):
            # dout: (batch, seqlen, dim)
            (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight,
             conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors
            L = xz.shape[-1]
            delta_rank = delta_proj_weight.shape[1]
            d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
            x, z = xz.chunk(2, dim=1)
            if dout.stride(-1) != 1:
                dout = dout.contiguous()
            if ctx.checkpoint_lvl == 1:
                conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True)
                delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
                                  "d (b l) -> b d l", l=L)
            # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
            # backward of selective_scan_cuda with the backward of chunk).
            dxz = torch.empty_like(xz)  # (batch, dim, seqlen)
            dx, dz = dxz.chunk(2, dim=1)
            # dout_y = rearrange(dout, "b l d -> b d l") # because no arrange at end of forward, so dout shape is b d l
            dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(
                conv1d_out, delta, A, B, C, D, z, delta_bias, dout, scan_intermediates, out, dz,
                True  # option to recompute out_z
            )
            dD = dD if D is not None else None
            dx_dbl = torch.empty_like(x_dbl)
            dB_proj_bias = None
            if ctx.is_variable_B:
                if not A.is_complex():
                    dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
                else:
                    dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
                dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
                dx_dbl[:, delta_rank:delta_rank + d_state] = dB  # (bl d)
                dB = None
            dC_proj_bias = None
            if ctx.is_variable_C:
                if not A.is_complex():
                    dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
                else:
                    dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
                dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
                dx_dbl[:, -d_state:] = dC  # (bl d)
                dC = None
            ddelta = rearrange(ddelta, "b d l -> d (b l)")
            ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
            dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
            dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
            dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
            dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
            dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
            # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
            # backward of conv1d with the backward of chunk).
            dx, dconv1d_weight, dconv1d_bias = causal_conv1d_cuda.causal_conv1d_bwd(
                x, conv1d_weight, conv1d_bias, dconv1d_out, None, dx, True
            )
            dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
            dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
            return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
                    dA, dB, dC, dD,
                    ddelta_bias if delta_bias is not None else None,
                    dB_proj_bias, dC_proj_bias, None)
    class MambaInnerFn(torch.autograd.Function):
        @staticmethod
        @custom_fwd
        def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                    out_proj_weight, out_proj_bias,
                    A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
                    C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1):
            """
                 xz: (batch, dim, seqlen)
            """
            assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
            assert checkpoint_lvl in [0, 1]
            L = xz.shape[-1]
            delta_rank = delta_proj_weight.shape[1]
            d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
            if torch.is_autocast_enabled():
                x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
                delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
                out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
                out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
                                 if out_proj_bias is not None else None)
            if xz.stride(-1) != 1:
                xz = xz.contiguous()
            conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
            x, z = xz.chunk(2, dim=1)
            conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
            conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
                x, conv1d_weight, conv1d_bias, None, None, None, True
            )
            # We're being very careful here about the layout, to avoid extra transposes.
            # We want delta to have d as the slowest moving dimension
            # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
            x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight)  # (bl d)
            delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L)
            ctx.is_variable_B = B is None
            ctx.is_variable_C = C is None
            ctx.B_proj_bias_is_None = B_proj_bias is None
            ctx.C_proj_bias_is_None = C_proj_bias is None
            if B is None:  # variable B
                B = x_dbl[:, delta_rank:delta_rank + d_state]  # (bl dstate)
                if B_proj_bias is not None:
                    B = B + B_proj_bias.to(dtype=B.dtype)
                if not A.is_complex():
                    # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
                    B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
                else:
                    B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
                if B.stride(-1) != 1:
                    B = B.contiguous()
            if C is None:  # variable C
                C = x_dbl[:, -d_state:]  # (bl dstate)
                if C_proj_bias is not None:
                    C = C + C_proj_bias.to(dtype=C.dtype)
                if not A.is_complex():
                    # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
                    C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
                else:
                    C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
                if C.stride(-1) != 1:
                    C = C.contiguous()
            if D is not None:
                D = D.contiguous()
            out, scan_intermediates, out_z = selective_scan_cuda.fwd(
                conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
            )
            ctx.delta_softplus = delta_softplus
            ctx.out_proj_bias_is_None = out_proj_bias is None
            ctx.checkpoint_lvl = checkpoint_lvl
            if checkpoint_lvl >= 1:  # Will recompute conv1d_out and delta in the backward pass
                conv1d_out, delta = None, None
            ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,
                                  delta_proj_weight, out_proj_weight, conv1d_out, delta,
                                  A, B, C, D, delta_bias, scan_intermediates, out)
            return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
        @staticmethod
        @custom_bwd
        def backward(ctx, dout):
            # dout: (batch, seqlen, dim)
            assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
            (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
             conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors
            L = xz.shape[-1]
            delta_rank = delta_proj_weight.shape[1]
            d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
            x, z = xz.chunk(2, dim=1)
            if dout.stride(-1) != 1:
                dout = dout.contiguous()
            if ctx.checkpoint_lvl == 1:
                conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
                    x, conv1d_weight, conv1d_bias, None, None, None, True
                )
                delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
                                  "d (b l) -> b d l", l=L)
            # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
            # backward of selective_scan_cuda with the backward of chunk).
            dxz = torch.empty_like(xz)  # (batch, dim, seqlen)
            dx, dz = dxz.chunk(2, dim=1)
            dout = rearrange(dout, "b l e -> e (b l)")
            dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
            dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(
                conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates, out, dz,
                True  # option to recompute out_z
            )
            dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)"))
            dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
            dD = dD if D is not None else None
            dx_dbl = torch.empty_like(x_dbl)
            dB_proj_bias = None
            if ctx.is_variable_B:
                if not A.is_complex():
                    dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
                else:
                    dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
                dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
                dx_dbl[:, delta_rank:delta_rank + d_state] = dB  # (bl d)
                dB = None
            dC_proj_bias = None
            if ctx.is_variable_C:
                if not A.is_complex():
                    dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
                else:
                    dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
                dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
                dx_dbl[:, -d_state:] = dC  # (bl d)
                dC = None
            ddelta = rearrange(ddelta, "b d l -> d (b l)")
            ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
            dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
            dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
            dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
            dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
            dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
            # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
            # backward of conv1d with the backward of chunk).
            dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
                x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True
            )
            dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
            dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
            return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
                    dout_proj_weight, dout_proj_bias,
                    dA, dB, dC, dD,
                    ddelta_bias if delta_bias is not None else None,
                    dB_proj_bias, dC_proj_bias, None)
    class BiMambaInnerFn(torch.autograd.Function):
        @staticmethod
        @custom_fwd
        def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                    out_proj_weight, out_proj_bias,
                    A, A_b, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
                    C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1):
            """
                 xz: (batch, dim, seqlen)
            """
            assert checkpoint_lvl in [0, 1]
            L = xz.shape[-1]
            delta_rank = delta_proj_weight.shape[1]
            d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
            if torch.is_autocast_enabled():
                x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
                delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
                out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
                out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
                                 if out_proj_bias is not None else None)
            if xz.stride(-1) != 1:
                xz = xz.contiguous()
            conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
            x, z = xz.chunk(2, dim=1)
            conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
            conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True)
            # We're being very careful here about the layout, to avoid extra transposes.
            # We want delta to have d as the slowest moving dimension
            # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
            x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight)  # (bl d)
            delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L)
            ctx.is_variable_B = B is None
            ctx.is_variable_C = C is None
            ctx.B_proj_bias_is_None = B_proj_bias is None
            ctx.C_proj_bias_is_None = C_proj_bias is None
            if B is None:  # variable B
                B = x_dbl[:, delta_rank:delta_rank + d_state]  # (bl dstate)
                if B_proj_bias is not None:
                    B = B + B_proj_bias.to(dtype=B.dtype)
                if not A.is_complex():
                    # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
                    B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
                else:
                    B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
            else:
                if B.stride(-1) != 1:
                    B = B.contiguous()
            if C is None:  # variable C
                C = x_dbl[:, -d_state:]  # (bl dstate)
                if C_proj_bias is not None:
                    C = C + C_proj_bias.to(dtype=C.dtype)
                if not A.is_complex():
                    # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
                    C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
                else:
                    C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
            else:
                if C.stride(-1) != 1:
                    C = C.contiguous()
            if D is not None:
                D = D.contiguous()
            out_f, scan_intermediates_f, out_z_f = selective_scan_cuda.fwd(
                conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
            )
            assert not A_b.is_complex(), "A should not be complex!!"
            # Backward direction: same kernel on the reversed sequence, with its own A_b.
            out_b, scan_intermediates_b, out_z_b = selective_scan_cuda.fwd(
                conv1d_out.flip([-1]), delta.flip([-1]), A_b, B.flip([-1]), C.flip([-1]), D, z.flip([-1]), delta_bias,
                delta_softplus,
            )
            # Flip the backward output back to the original order before summing.
            out_z = out_z_f + out_z_b.flip([-1])
            ctx.delta_softplus = delta_softplus
            ctx.out_proj_bias_is_None = out_proj_bias is None
            ctx.checkpoint_lvl = checkpoint_lvl
            if checkpoint_lvl >= 1:  # Will recompute conv1d_out and delta in the backward pass
                conv1d_out, delta = None, None
            ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,
                                  delta_proj_weight, out_proj_weight, conv1d_out, delta,
                                  A, A_b, B, C, D, delta_bias, scan_intermediates_f, scan_intermediates_b, out_f, out_b)
            return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
        @staticmethod
        @custom_bwd
        def backward(ctx, dout):
            # dout: (batch, seqlen, dim)
            (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
             conv1d_out, delta, A, A_b, B, C, D, delta_bias, scan_intermediates_f, scan_intermediates_b, out_f,
             out_b) = ctx.saved_tensors
            L = xz.shape[-1]
            delta_rank = delta_proj_weight.shape[1]
            d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
            x, z = xz.chunk(2, dim=1)
            if dout.stride(-1) != 1:
                dout = dout.contiguous()
            if ctx.checkpoint_lvl == 1:
                conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True)
                delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
                                  "d (b l) -> b d l", l=L)
            # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
            # backward of selective_scan_cuda with the backward of chunk).
            dxz = torch.empty_like(xz)  # (batch, dim, seqlen)
            dx, dz = dxz.chunk(2, dim=1)
            dout = rearrange(dout, "b l e -> e (b l)")
            dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
            dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z_f = selective_scan_cuda.bwd(
                conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates_f, out_f, dz,
                ctx.delta_softplus,
                True  # option to recompute out_z
            )
            # Backward direction (reversed sequence)
            dz_b = torch.empty_like(dz)
            dconv1d_out_f_b, ddelta_f_b, dA_b, dB_f_b, dC_f_b, dD_b, ddelta_bias_b, dz_b, out_z_b = selective_scan_cuda.bwd(
                conv1d_out.flip([-1]), delta.flip([-1]), A_b, B.flip([-1]), C.flip([-1]), D, z.flip([-1]), delta_bias,
                dout_y.flip([-1]), scan_intermediates_b, out_b, dz_b,
                ctx.delta_softplus,
                True  # option to recompute out_z
            )
            dconv1d_out = dconv1d_out + dconv1d_out_f_b.flip([-1])
            ddelta = ddelta + ddelta_f_b.flip([-1])
            dB = dB + dB_f_b.flip([-1])
            dC = dC + dC_f_b.flip([-1])
            dD = dD + dD_b
            ddelta_bias = ddelta_bias + ddelta_bias_b
            dz = dz + dz_b.flip([-1])
            out_z = out_z_f + out_z_b.flip([-1])
            dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)"))
            dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
            dD = dD if D is not None else None
            dx_dbl = torch.empty_like(x_dbl)
            dB_proj_bias = None
            if ctx.is_variable_B:
                if not A.is_complex():
                    dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
                else:
                    dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
                dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
                dx_dbl[:, delta_rank:delta_rank + d_state] = dB  # (bl d)
                dB = None
            dC_proj_bias = None
            if ctx.is_variable_C:
                if not A.is_complex():
                    dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
                else:
                    dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
                dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
                dx_dbl[:, -d_state:] = dC  # (bl d)
                dC = None
            ddelta = rearrange(ddelta, "b d l -> d (b l)")
            ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
            dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
            dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
            dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
            dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
            dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
            # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
            # backward of conv1d with the backward of chunk).
            dx, dconv1d_weight, dconv1d_bias = causal_conv1d_cuda.causal_conv1d_bwd(
                x, conv1d_weight, conv1d_bias, dconv1d_out, None, dx, True
            )
            dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
            dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
            return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
                    dout_proj_weight, dout_proj_bias,
                    dA, dA_b, dB, dC, dD,
                    ddelta_bias if delta_bias is not None else None,
                    dB_proj_bias, dC_proj_bias, None)
    def mamba_inner_fn(
            xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
            out_proj_weight, out_proj_bias,
            A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
            C_proj_bias=None, delta_softplus=True
    ):
        # Unfused reference path (mamba_inner_ref) in place of the fused MambaInnerFn.apply.
        return mamba_inner_ref(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                               out_proj_weight, out_proj_bias,
                               A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
    def bimamba_inner_fn(
            xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
            out_proj_weight, out_proj_bias,
            A, A_b, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
            C_proj_bias=None, delta_softplus=True
    ):
        # Unfused reference path (bimamba_inner_ref) in place of the fused BiMambaInnerFn.apply.
        return bimamba_inner_ref(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                                 out_proj_weight, out_proj_bias,
                                 A, A_b, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
    def mamba_inner_fn_no_out_proj(
            xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
            A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
            C_proj_bias=None, delta_softplus=True
    ):
        return mamba_inner_ref_fn_no_out_proj(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                                              A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
    def mamba_inner_ref(
            xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
            out_proj_weight, out_proj_bias,
            A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
            C_proj_bias=None, delta_softplus=True
    ):
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        x, z = xz.chunk(2, dim=1)
        x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, "silu")
        # We're being very careful here about the layout, to avoid extra transposes.
        # We want delta to have d as the slowest moving dimension
        # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
        x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight)  # (bl d)
        delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
        delta = rearrange(delta, "d (b l) -> b d l", l=L)
        if B is None:  # variable B
            B = x_dbl[:, delta_rank:delta_rank + d_state]  # (bl d)
            if B_proj_bias is not None:
                B = B + B_proj_bias.to(dtype=B.dtype)
            if not A.is_complex():
                B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
            else:
                B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
        if C is None:  # variable C
            C = x_dbl[:, -d_state:]  # (bl d)
            if C_proj_bias is not None:
                C = C + C_proj_bias.to(dtype=C.dtype)
            if not A.is_complex():
                C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
            else:
                C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
        y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True)
        return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)
    def mamba_inner_ref_fn_no_out_proj(
            xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
            A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
            C_proj_bias=None, delta_softplus=True
    ):
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        x, z = xz.chunk(2, dim=1)
        x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, "silu")
        # We're being very careful here about the layout, to avoid extra transposes.
        # We want delta to have d as the slowest moving dimension
        # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
        x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight)  # (bl d)
        delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
        delta = rearrange(delta, "d (b l) -> b d l", l=L)
        if B is None:  # variable B
            B = x_dbl[:, delta_rank:delta_rank + d_state]  # (bl d)
            if B_proj_bias is not None:
                B = B + B_proj_bias.to(dtype=B.dtype)
            if not A.is_complex():
                B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
            else:
                B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
        if C is None:  # variable C
            C = x_dbl[:, -d_state:]  # (bl d)
            if C_proj_bias is not None:
                C = C + C_proj_bias.to(dtype=C.dtype)
            if not A.is_complex():
                C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
            else:
                C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
        y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True)
        # return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)
        return y
    def bimamba_inner_ref(
            xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
            out_proj_weight, out_proj_bias,
            A, A_b, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
            C_proj_bias=None, delta_softplus=True
    ):
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        x, z = xz.chunk(2, dim=1)
        x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, "silu")
        # We're being very careful here about the layout, to avoid extra transposes.
        # We want delta to have d as the slowest moving dimension
        # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
        x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight)  # (bl d)
        delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
        delta = rearrange(delta, "d (b l) -> b d l", l=L)
        if B is None:  # variable B
            B = x_dbl[:, delta_rank:delta_rank + d_state]  # (bl d)
            if B_proj_bias is not None:
                B = B + B_proj_bias.to(dtype=B.dtype)
            if not A.is_complex():
                B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
            else:
                B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
        if C is None:  # variable C
            C = x_dbl[:, -d_state:]  # (bl d)
            if C_proj_bias is not None:
                C = C + C_proj_bias.to(dtype=C.dtype)
            if not A.is_complex():
                C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
            else:
                C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
        y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True)
        y_b = selective_scan_fn(x.flip([-1]), delta.flip([-1]), A_b, B.flip([-1]), C.flip([-1]), D, z.flip([-1]),
                                delta_bias, delta_softplus=True)
        y = y + y_b.flip([-1])
        return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)
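bimamba_inner_ref differs from mamba_inner_ref in one structural way. It runs the same causal scan a second time on `x.flip([-1])`, with its own `A_b`, and flips that result back. It then sums the two outputs. The sketch below shows this flip-scan-flip trick in pure Python. It uses a toy scalar recurrence `h_t = a*h_{t-1} + x_t` as a stand-in for `selective_scan_fn`. Both function names are hypothetical and are not part of mamba_ssm.

```python
from typing import List


def toy_causal_scan(x: List[float], a: float) -> List[float]:
    """Toy stand-in for selective_scan_fn: h_t = a * h_{t-1} + x_t, output h_t."""
    h = 0.0
    out = []
    for xt in x:
        h = a * h + xt
        out.append(h)
    return out


def toy_bidirectional_scan(x: List[float], a_fwd: float, a_bwd: float) -> List[float]:
    """Mirrors `y = y + y_b.flip([-1])` in bimamba_inner_ref."""
    y_f = toy_causal_scan(x, a_fwd)
    # Scan the reversed sequence, then reverse the result back to the original order.
    y_b = toy_causal_scan(x[::-1], a_bwd)[::-1]
    return [f + b for f, b in zip(y_f, y_b)]


if __name__ == "__main__":
    # Impulse at t=0: the forward pass carries it to later steps,
    # while the backward pass only sees it at t=0.
    print(toy_bidirectional_scan([1.0, 0.0, 0.0], a_fwd=0.5, a_bwd=0.5))  # [2.0, 0.5, 0.25]
```

The backward direction has its own `A_b`. For that reason, BiMambaInnerFn.backward returns a separate `dA_b`. The gradients of the shared inputs (`dconv1d_out`, `ddelta`, `dB`, `dC`, `dD`, `dz`) are summed after the reversed-direction gradients are flipped back.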
20240506 Update

I installed mamba-ssm with pip install mamba-ssm. After that, code that had been running fine started failing with the following error:

File "/home/xxx/.conda/envs/mamba/lib/python3.10/site-packages/mamba_ssm/ops/selective_scan_interface.py", line 187, in forward
    conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
TypeError: causal_conv1d_fwd(): incompatible function arguments. The following argument types are supported:
    1. (arg0: torch.Tensor, arg1: torch.Tensor, arg2: Optional[torch.Tensor], arg3: Optional[torch.Tensor], arg4: bool) -> torch.Tensor

Invoked with: tensor(
        [-4.9056e-40, -4.9057e-40, -4.9074e-40, -4.9078e-40]], device='cuda:0',
       requires_grad=True), Parameter containing:
tensor([ 0.0322, -0.1139,  0.0770,  ..., -0.0320, -0.1266, -0.1096],
       device='cuda:0', requires_grad=True), None, None, None, True
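On inspection, the cause turned out to be the mamba-ssm version. The failing version was 1.2.0.post1, because pip install mamba-ssm installs the latest release. That release is partly incompatible with the earlier function interfaces: in the traceback, it passes more arguments to causal_conv1d_fwd() than the supported five-argument signature accepts. The version that had been working before was 1.1.3.post1.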
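The failure comes from a mismatch between the installed mamba-ssm release and the compiled causal_conv1d extension. It can therefore help to print the installed versions before running anything. The sketch below uses the standard-library importlib.metadata. Note the following assumptions:

- The helper names are hypothetical.
- The "known good" value 1.1.3.post1 is just the version that worked in this post. It is not an official compatibility guarantee.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is not installed."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


def report(dist_name: str, known_good: Optional[str] = None) -> str:
    """Hypothetical helper: describe the installed version relative to a known-good pin."""
    found = installed_version(dist_name)
    if found is None:
        return f"{dist_name}: not installed"
    if known_good is not None and found != known_good:
        return f"{dist_name}: {found} (this post's working version was {known_good})"
    return f"{dist_name}: {found}"


if __name__ == "__main__":
    # Distribution names as published on PyPI.
    print(report("mamba-ssm", known_good="1.1.3.post1"))
    print(report("causal-conv1d"))
    print(report("torch"))
```

If the report shows 1.2.0.post1, the direct way back, per this post's finding, should be to pin the previously working release: pip install mamba-ssm==1.1.3.post1.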

20240523 Update

A reader hit the following error when running Vision Mamba (Linux):

TypeError: Mamba.__init__() got an unexpected keyword argument 'bimamba_type'

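This happens because Vision Mamba modified Mamba's source code. The package installed through the official Mamba channels has no such parameter. You therefore need to uninstall the original Mamba first, then install manually from the Mamba source bundled with the Vision Mamba code rather than from the official Mamba channels. In practice, simply replacing files also works: overwrite the installed selective_scan_interface.py, causal_conv1d_interface.py and mamba_simple.py with Vision Mamba's versions of those files.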
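You can check whether the installed Mamba class is the official one or Vision Mamba's modified copy. To do this, inspect its constructor signature for `bimamba_type`. The sketch below uses the standard-library inspect module. `accepts_kwarg` is a hypothetical helper.

```python
import inspect
from typing import Any


def accepts_kwarg(fn: Any, name: str) -> bool:
    """Return True if calling `fn` (function or class) with keyword `name` would be accepted."""
    try:
        params = inspect.signature(fn).parameters
    except (TypeError, ValueError):
        # Some builtins / C extensions expose no signature.
        return False
    if name in params:
        return True
    # A **kwargs parameter also accepts unknown keywords.
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())
```

Usage assumes mamba_ssm exposes `Mamba` at package level. Run `from mamba_ssm import Mamba`, then `accepts_kwarg(Mamba, "bimamba_type")`. The result should be False for the official package and True once Vision Mamba's mamba_simple.py is in place.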

20240531 Update

Recently some readers hit this error during installation: ImportError: cannot import name 'packaging' from 'pkg_resources'. The cause is a setuptools version that is too new, typically 70.0.0. Downgrade it by running pip install setuptools==68.2.2.
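The error means that some build code runs `from pkg_resources import packaging`, and this import fails with the setuptools release reported above. The sketch below is a small pre-flight check using only the standard library. It assumes the breaking cutoff is major version 70, as observed in this post, and the helper names are hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def major_version(version_str: str) -> int:
    """Parse the leading integer of a version string such as '70.0.0' or '68.2.2'."""
    head = version_str.split(".", 1)[0]
    digits = "".join(ch for ch in head if ch.isdigit())
    return int(digits) if digits else 0


def setuptools_too_new(version_str: Optional[str], cutoff: int = 70) -> bool:
    """True if setuptools is at or above the major version observed to break the build."""
    return version_str is not None and major_version(version_str) >= cutoff


if __name__ == "__main__":
    try:
        current = version("setuptools")
    except PackageNotFoundError:
        current = None
    if setuptools_too_new(current):
        print(f"setuptools {current} detected; try: pip install setuptools==68.2.2")
    else:
        print(f"setuptools version: {current}")
```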

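20240604 Update

1. selective_scan_cuda not found on Linux

A reader configuring Vision Mamba hit the following error: NameError: name 'selective_scan_cuda' is not defined. Did you mean: 'selective_scan_fn'. The reader had never installed causal_conv1d and mamba_ssm as the original project instructs and had simply copied the source code in. That said, following the original project's installation method will probably fail too, so here is a simpler approach: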

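  1. Install the original causal_conv1d and mamba_ssm normally, following Solution (Linux) #3 in this post.
  2. Overwrite the installed causal_conv1d and mamba_ssm in the environment with the causal_conv1d and mamba_ssm from the Vision Mamba project.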
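Step 2 overlays Vision Mamba's copies of the two packages onto the installed ones. The sketch below shows how that could be scripted with the standard library. Note the following assumptions:

- `installed_package_dir` and `overlay_package` are hypothetical helpers.
- The Vim paths in the usage comment are placeholders, and the Vim checkout is assumed to keep the packages at those relative paths.
- `shutil.copytree(..., dirs_exist_ok=True)` needs Python 3.8+ and overwrites files that already exist.

```python
import importlib.util
import shutil
from pathlib import Path
from typing import Optional


def installed_package_dir(name: str) -> Optional[Path]:
    """Locate an installed package's directory without importing it."""
    spec = importlib.util.find_spec(name)
    if spec is None or not spec.submodule_search_locations:
        return None  # not installed, or a single-file module rather than a package
    return Path(list(spec.submodule_search_locations)[0])


def overlay_package(src_dir: Path, dst_dir: Path) -> None:
    """Copy src_dir over dst_dir, overwriting files with the same relative path."""
    shutil.copytree(src_dir, dst_dir, dirs_exist_ok=True)


# Usage (placeholder paths; back up site-packages first):
#   overlay_package(Path("Vim/mamba-1p1p1/mamba_ssm"), installed_package_dir("mamba_ssm"))
#   overlay_package(Path("Vim/causal-conv1d/causal_conv1d"), installed_package_dir("causal_conv1d"))
```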

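2. ModuleNotFoundError: No module named 'causal_conv1d_cuda' on Windows

This is actually an old problem; see the first issue in this post's 20240424 update. It means causal_conv1d was not installed successfully. After a successful install, xxxx\envs\xxxx\Lib\site-packages\ contains a file named causal_conv1d_cuda.cp310-win_amd64.pyd (taking my environment as an example). The download link for that file is causal-conv1d-cuda.cp310-win-amd64.pyd. You have two options:

  - Build from source on Windows as described earlier in this post.
  - After setting up the preceding environment, download and install the whl I compiled.

[All files are available on request via my WeChat.]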
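You can check whether the compiled extensions are visible to Python without importing them. The sketch below assumes the module names are the ones from the error messages (causal_conv1d_cuda, selective_scan_cuda). `importlib.machinery.EXTENSION_SUFFIXES` lists the file endings the running interpreter accepts for compiled extensions. Tags such as `.cp310-win_amd64.pyd` (CPython 3.10, 64-bit Windows) come from this list.

```python
import importlib.machinery
import importlib.util
from typing import Iterable, List


def missing_modules(names: Iterable[str]) -> List[str]:
    """Return the top-level module names that cannot be found on sys.path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    # A compiled extension must end with one of these suffixes to be loadable here.
    print("accepted extension suffixes:", importlib.machinery.EXTENSION_SUFFIXES)
    print("missing:", missing_modules(["causal_conv1d_cuda", "selective_scan_cuda"]))
```

If the .pyd file sits in site-packages but its tag does not match any listed suffix, the interpreter will not load it. A file built for a different Python version is one example.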

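20240607 Update

1. Setting up the Vim environment on Windows

First, configure Mamba following Solution (Win) described earlier. Then modify the source of the installed mamba package to match Vision Mamba's source. For details, see causal_conv1d_cuda for Vim and selective_scan_cuda for Vim in the earlier 20240424 update.
This topic has moved to a separate post: Vim Environment Installation on Windows: Pitfalls and Solutions

2. KeyError: 'HOME' when running Vim on Windows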


Traceback (most recent call last):
  File "xxx\models\vimamba.py", line 115, in forward
    hidden_states, residual = fused_add_norm_fn(
  File "D:\Anaconda\envs\xxx\lib\site-packages\mamba_ssm\ops\triton\layernorm.py", line 478, in rms_norm_fn
    return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
  File "D:\Anaconda\envs\xxx\lib\site-packages\torch\autograd\function.py", line 539, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "D:\Anaconda\envs\xxx\lib\site-packages\mamba_ssm\ops\triton\layernorm.py", line 411, in forward
    y, mean, rstd, residual_out = _layer_norm_fwd(
  File "D:\Anaconda\envs\xxx\lib\site-packages\mamba_ssm\ops\triton\layernorm.py", line 155, in _layer_norm_fwd
  File "D:\Anaconda\envs\xxx\lib\site-packages\triton\runtime\jit.py", line 106, in launcher
    return self.run(*args, grid=grid, **kwargs)
  File "D:\Anaconda\envs\xxx\lib\site-packages\triton\runtime\autotuner.py", line 73, in run
    timings = {config: self._bench(*args, config=config, **kwargs)
  File "D:\Anaconda\envs\xxx\lib\site-packages\triton\runtime\autotuner.py", line 73, in <dictcomp>
    timings = {config: self._bench(*args, config=config, **kwargs)
  File "D:\Anaconda\envs\xxx\lib\site-packages\triton\runtime\autotuner.py", line 63, in _bench
    return do_bench(kernel_call)
  File "D:\Anaconda\envs\xxx\lib\site-packages\triton\testing.py", line 136, in do_bench
  File "D:\Anaconda\envs\xxx\lib\site-packages\triton\runtime\autotuner.py", line 62, in kernel_call
    self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
  File "<string>", line 41, in _layer_norm_fwd_1pass_kernel
  File "D:\Anaconda\envs\xxx\lib\site-packages\triton\compiler.py", line 1230, in compile
    so_cache_manager = CacheManager(so_cache_key)
  File "D:\Anaconda\envs\xxx\lib\site-packages\triton\compiler.py", line 1102, in __init__
    self.cache_dir = os.environ.get('TRITON_CACHE_DIR', default_cache_dir())
  File "D:\Anaconda\envs\xxx\lib\site-packages\triton\compiler.py", line 1093, in default_cache_dir
    return os.path.join(os.environ["HOME"], ".triton", "cache")
  File "D:\Anaconda\envs\xxx\lib\os.py", line 680, in __getitem__
    raise KeyError(key) from None
KeyError: 'HOME'
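On Windows you also need to modify D:\Anaconda\envs\xxx\lib\site-packages\mamba_ssm\ops\triton\layernorm.py under the mamba install path. The traceback shows why: Triton's default_cache_dir() reads os.environ["HOME"], which Windows does not set by default. Specifically, replace the following in layernorm.py:

def layer_norm_fn(x, weight, bias, residual=None, eps=1e-6, prenorm=False,
                  residual_in_fp32=False, is_rms_norm=False):
    return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, is_rms_norm)

def rms_norm_fn(x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False, eps=1e-6):
    return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)

with:

def layer_norm_fn(x, weight, bias, residual=None, eps=1e-6, prenorm=False,
                  residual_in_fp32=False, is_rms_norm=False):
    # is_rms_norm is not forwarded: this always uses the layer-norm reference
    return layer_norm_ref(x, weight, bias, residual, eps, prenorm, residual_in_fp32)

def rms_norm_fn(x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False, eps=1e-6):
    return rms_norm_ref(x, weight, bias, residual, eps, prenorm, residual_in_fp32)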
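The patch works because both functions now call the reference implementations layer_norm_ref and rms_norm_ref instead of the Triton kernel. As a result, the Triton compile path that failed is never reached. The sketch below shows, for orientation, what an RMSNorm reference computes on a single vector. It assumes the usual definition y = x / sqrt(mean(x²) + eps) · weight (+ bias), with an optional residual added first. It is an illustration in plain Python, not mamba_ssm's code.

```python
import math
from typing import List, Optional, Tuple, Union

Vector = List[float]


def rms_norm_sketch(
    x: Vector,
    weight: Vector,
    bias: Optional[Vector] = None,
    residual: Optional[Vector] = None,
    eps: float = 1e-6,
    prenorm: bool = False,
) -> Union[Vector, Tuple[Vector, Vector]]:
    """RMSNorm over one vector: x / sqrt(mean(x^2) + eps) * weight (+ bias)."""
    if residual is not None:
        # The residual is added before normalizing; with prenorm the sum is also returned.
        x = [a + r for a, r in zip(x, residual)]
    rstd = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    out = [v * rstd * w for v, w in zip(x, weight)]
    if bias is not None:
        out = [o + b for o, b in zip(out, bias)]
    return (out, x) if prenorm else out
```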
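3. Summary: configuring the Vim environment on Windows

See the separate post: Vim Environment Installation on Windows: Pitfalls and Solutions
You can modify the following files before installing causal_conv1d and mamba_ssm from the Vim source with their respective setup.py:

  1. Modify Vim/mamba-1p1p1/mamba_ssm/ops/triton/layernorm.py
  2. Modify Vim/mamba-1p1p1/mamba_ssm/ops/selective_scan_interface.py
  3. Modify Vim/causal-conv1d/causal_conv1d/causal_conv1d_interface.py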

Alternatively, you can make the changes after installing causal_conv1d and mamba_ssm with setup.py:

  1. Modify xxx\Anaconda\envs\xxx\Lib\site-packages\mamba_ssm\ops\triton\layernorm.py
  2. Modify xxx\Anaconda\envs\xxx\Lib\site-packages\mamba_ssm\ops\selective_scan_interface.py
  3. Modify xxx\Anaconda\envs\xxx\Lib\site-packages\causal_conv1d\causal_conv1d_interface.py

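Suppose you configured Mamba with Solution (Win) described earlier and then run Vim. In that case, besides modifying these three source files, you must also replace xxx\Anaconda\envs\xxx\Lib\site-packages\mamba_ssm\modules\mamba_simple.py with the mamba_simple.py file from the Vim source.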
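The placeholder xxx\Anaconda\envs\xxx\Lib\site-packages refers to the active interpreter's purelib directory, which sysconfig can report. The sketch below builds the list of files named above for the current environment and flags any that are missing. It assumes both packages were installed as regular (non-editable) packages into site-packages, and `vim_patch_targets` is a hypothetical helper.

```python
import sysconfig
from pathlib import Path
from typing import List, Optional

# Relative paths of the files this post says must be patched or replaced.
RELATIVE_TARGETS = [
    ("mamba_ssm", "ops", "triton", "layernorm.py"),
    ("mamba_ssm", "ops", "selective_scan_interface.py"),
    ("causal_conv1d", "causal_conv1d_interface.py"),
    ("mamba_ssm", "modules", "mamba_simple.py"),  # only when Mamba came from Solution (Win)
]


def vim_patch_targets(site_packages: Optional[Path] = None) -> List[Path]:
    """Return the paths of the files to patch under a site-packages directory."""
    if site_packages is not None:
        root = Path(site_packages)
    else:
        root = Path(sysconfig.get_paths()["purelib"])
    return [root.joinpath(*parts) for parts in RELATIVE_TARGETS]


if __name__ == "__main__":
    for path in vim_patch_targets():
        print(("ok      " if path.exists() else "MISSING ") + str(path))
```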

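20240715 Update

  1. Experiments show that Mamba and VMamba both compile normally on Windows, with no need to bypass selective_scan_cuda. For the installation steps, see the latest post in the navigation at the top. For Vim, you can still install Mamba first and then overwrite the source in the virtual environment.
  2. Following a reader's reminder, all downloads now cost 5 points, and downloading from CSDN is discouraged. The correct installation steps, fixes for likely errors, and reference materials are all public in this blog series, in keeping with the open-source spirit. I encourage you to work through them hands-on.
  3. If you are truly stuck, contact me via the WeChat ID on my profile page. This covers installation problems (only after carefully reading the series linked at the top without finding a fix), resource requests, or idea collaboration. I receive too many private messages and comments, and replies are limited, so I read them as time permits.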

