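Papers related to Mamba have recently drawn a great deal of attention. The Mamba paper itself was rejected by ICLR 2024, but derivative models keep appearing, such as Vim and U-Mamba. While setting up the environment (requirements: PyTorch 1.12+; CUDA 11.6+), I found that following their installation instructions [1][2] runs into a great many bugs, mostly in causal-conv1d and mamba-ssm. All of them come down to version compatibility, so I am recording them here.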
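Before going further, it can help to confirm that the installed PyTorch meets the stated minimums (PyTorch 1.12+, CUDA 11.6+). The sketch below is a minimal, hypothetical check: `major_minor` and `meets_requirements` are illustrative helper names, and the live part reads `torch.__version__` and `torch.version.cuda` (the latter is `None` on CPU-only builds).

```python
import re

# Minimum versions quoted in this post: PyTorch 1.12+ and CUDA 11.6+.
MIN_TORCH = (1, 12)
MIN_CUDA = (11, 6)


def major_minor(version: str) -> tuple:
    """Extract (major, minor) from strings such as '2.1.1+cu118' or '11.8'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"cannot parse version: {version!r}")
    return int(match.group(1)), int(match.group(2))


def meets_requirements(torch_version: str, cuda_version: str) -> bool:
    """True if both versions are at or above the stated minimums."""
    return major_minor(torch_version) >= MIN_TORCH and major_minor(cuda_version) >= MIN_CUDA


if __name__ == "__main__":
    try:
        import torch  # only needed for the live check
    except ImportError:
        print("PyTorch is not installed in this environment")
    else:
        if torch.version.cuda is None:
            print(f"torch {torch.__version__} is a CPU-only build")
        else:
            ok = meets_requirements(torch.__version__, torch.version.cuda)
            print(f"torch {torch.__version__}, CUDA {torch.version.cuda}: {'OK' if ok else 'too old'}")
```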
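P.S. After in-depth discussions with readers, this post has been greatly expanded. If you run into new problems or find solutions, please let me know; bug fixes that we verify together will be added to this post promptly.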
(For installation problems, resources, or paper collaboration ideas, add me on WeChat: 931744281)
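Installing directly with pip, or downloading the project files and running setup, produced errors including, but not limited to, the following: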
Building wheel for causal-conv1d (setup.py) ... error
error: command '/usr/bin/gcc' failed with exit code 1
RuntimeError: Error compiling objects for extension
ERROR: Could not build wheels for causal-conv1d, which is required to install pyproject.toml-based projects
Connection timed out> [end of output]
ModuleNotFoundError: No module named 'packaging'
FileNotFoundError: [Errno 2] No such file or directory: '/usr/local/cuda/bin/nvcc'
error: subprocess-exited-with-error
Most of these are caused by mismatched CUDA versions; some are network problems.
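The messages above can be triaged mechanically. Below is a hypothetical helper (not part of either project) that maps substrings from those exact messages to the causes discussed in this post; the rules are illustrative, not exhaustive, and the first matching rule wins for each log line.

```python
# Each rule: (substrings that must all appear on one log line, likely cause).
RULES = [
    (("No module named 'packaging'",), "missing 'packaging': run conda install packaging"),
    (("No such file or directory", "nvcc"), "nvcc not found: install cuda-nvcc / check CUDA_HOME"),
    (("Connection timed out",), "network problem while fetching the prebuilt wheel"),
    (("failed with exit code",), "local compilation failed: most often a CUDA version mismatch"),
    (("Error compiling objects",), "local compilation failed: most often a CUDA version mismatch"),
]


def diagnose(log: str) -> list:
    """Return the distinct likely causes found in a pip build log, in order of appearance."""
    causes = []
    for line in log.splitlines():
        for needles, cause in RULES:
            if all(needle in line for needle in needles):
                if cause not in causes:
                    causes.append(cause)
                break  # first matching rule wins for this line
    return causes
```

To use it, save the output of the failing `pip install` to a file and pass its contents to `diagnose`. The first real error in a build log is usually more informative than the final "Could not build wheels" line, which is why that line has no rule here.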
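Solution (Linux)
Method 1: use a Docker environment that another user has already configured. See: 'Fixing causal_conv1d and mamba_ssm installation failures -> Use the Mamba base-environment Docker image directly'.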
Docker Hub repository: https://hub.docker.com/repository/docker/kom4cr0/cuda11.7-pytorch1.13-mamba1.1.1/general
Command: docker pull kom4cr0/cuda11.7-pytorch1.13-mamba1.1.1:1.1.1
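Method 2: download the project files directly and run setup. For details, see 'Cannot install causal_conv1d and mamba_ssm with pip install when running the Mamba project' and 'Reproducing U-Mamba'. (I still have not managed to install it this way, but the original authors and some people in the GitHub issues have succeeded.)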
Reference steps:
git clone https://github.com/Dao-AILab/causal-conv1d.git
cd causal-conv1d
git checkout v1.1.1 # current latest version tag
CAUSAL_CONV1D_FORCE_BUILD=TRUE pip install .
cd ..
git clone https://github.com/state-spaces/mamba.git
cd mamba
git checkout v1.1.1 # current latest version tag
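pip install .  # Option 1: download and install the prebuilt wheel; choose one of the two options
MAMBA_FORCE_BUILD=TRUE pip install .  # Option 2: force a local build; this syntax is not recognized on Windows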
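The `VAR=TRUE pip install .` form only works in POSIX shells, which is why it fails on Windows. One hedged, cross-platform alternative is to pass the variable through Python's `subprocess`, as sketched below. `forced_build_env`, `pip_install_command` and `forced_build` are hypothetical helper names; the variable names `CAUSAL_CONV1D_FORCE_BUILD` and `MAMBA_FORCE_BUILD` are the ones used in the commands above. Nothing is installed unless `dry_run=False`.

```python
import os
import subprocess
import sys

# Force-build variables used in the commands above, keyed by the cloned directory name.
FORCE_BUILD_VARS = {
    "causal-conv1d": "CAUSAL_CONV1D_FORCE_BUILD",
    "mamba": "MAMBA_FORCE_BUILD",
}


def forced_build_env(project: str, base=None) -> dict:
    """Copy of the environment (or of `base`) with the project's force-build flag set."""
    env = dict(os.environ if base is None else base)
    env[FORCE_BUILD_VARS[project]] = "TRUE"
    return env


def pip_install_command() -> list:
    """'pip install .' run through the current interpreter, so the active env is used."""
    return [sys.executable, "-m", "pip", "install", "."]


def forced_build(project: str, source_dir: str, dry_run: bool = True) -> None:
    """Run `pip install .` inside source_dir with the force-build flag set."""
    cmd = pip_install_command()
    print(f"[{FORCE_BUILD_VARS[project]}=TRUE] {' '.join(cmd)}  (in {source_dir})")
    if not dry_run:
        subprocess.run(cmd, cwd=source_dir, env=forced_build_env(project), check=True)
```

For example, after cloning, `forced_build("causal-conv1d", "causal-conv1d", dry_run=False)` followed by `forced_build("mamba", "mamba", dry_run=False)`. Using `sys.executable -m pip` avoids accidentally invoking a `pip` from a different environment.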
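Method 3: inspired by the blog post 'flash-attention pitfalls: managing CUDA with conda', change the installation order: install CUDA first, including cuda-nvcc. The correct installation steps are: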
conda create -n your_env_name python=3.10.13
conda activate your_env_name
conda install cudatoolkit==11.8 -c nvidia
pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118
conda install -c "nvidia/label/cuda-11.8.0" cuda-nvcc
conda install packaging
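pip install causal-conv1d==1.1.1  # choose the version for your setup, or omit it to install the latest
pip install mamba-ssm==1.1.3.post1  # choose the version for your setup; 1.1 and 1.2 were found to have incompatible functions; without a pin the latest version is installed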
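After these steps, a quick way to confirm what actually got installed is `importlib.metadata` from the standard library. The sketch below (the helper name `installed_versions` is hypothetical) reports the installed versions of the distributions used above, with `None` for anything missing, so a mismatched or absent package stands out before you start debugging imports.

```python
from importlib import metadata

# Distribution names as used in the pip commands above (triton is only needed on Windows).
PACKAGES = ["torch", "causal-conv1d", "mamba-ssm", "triton"]


def installed_versions(names=PACKAGES) -> dict:
    """Map each distribution name to its installed version, or None if it is missing."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name:15s} {version or 'NOT INSTALLED'}")
```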
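Update notes:
20240313 update: installing mamba-ssm on Windows (see below). Note that the source modifications in those steps skip the core CUDA acceleration, so the model runs, but very slowly.
20240329 update: on Windows, use WSL or a Linux virtual machine.
If the second-to-last step of Method 3 fails to install, you need to compile from the project source; see Solution (Linux), Method 2.
Some readers need to skip causal_conv1d_cuda because of their environment; see the 20240424 update.

Installing mamba-ssm on Windows
After the third-to-last step of Method 3 the procedure differs: first install the triton package, then compile causal-conv1d and mamba from source and modify the source code.
pip install triton-2.0.0-cp310-cp310-win_amd64.whl
(The comments also provide a download link for triton 2.1.0.) I have compiled causal_conv1d-1.1.1-cp310-cp310-win_amd64.whl for Windows, which can be downloaded and installed directly.
For mamba, change the configuration in setup.py to:
FORCE_BUILD = os.getenv("MAMBA_FORCE_BUILD", "FALSE") == "FALSE"
SKIP_CUDA_BUILD = os.getenv("MAMBA_SKIP_CUDA_BUILD", "FALSE") == "FALSE"
Comparing against "FALSE" makes both flags True when the variables are unset. Alternatively, leave setup.py unchanged and run set MAMBA_FORCE_BUILD=TRUE and set MAMBA_SKIP_CUDA_BUILD=TRUE (in cmd) before the pip install . command. The Linux form MAMBA_FORCE_BUILD=TRUE pip install . fails on Windows.
Because the CUDA build is skipped, selective_scan_cuda is not included, and importing the module still fails. In ops/selective_scan_interface.py, comment out import selective_scan_cuda and change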
def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                      return_last_state=False):
    """if return_last_state is True, returns (out, last_state)
    last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
    not considered in the backward pass.
    """
    return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)


def mamba_inner_fn(
    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
    out_proj_weight, out_proj_bias,
    A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
    C_proj_bias=None, delta_softplus=True
):
    return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                              out_proj_weight, out_proj_bias,
                              A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
to:
def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                      return_last_state=False):
    """if return_last_state is True, returns (out, last_state)
    last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
    not considered in the backward pass.
    """
    return selective_scan_ref(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)


def mamba_inner_fn(
    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
    out_proj_weight, out_proj_bias,
    A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
    C_proj_bias=None, delta_softplus=True
):
    return mamba_inner_ref(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                           out_proj_weight, out_proj_bias,
                           A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
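The Windows edit to ops/selective_scan_interface.py described above (comment out `import selective_scan_cuda`, then route `selective_scan_fn` and `mamba_inner_fn` to the `*_ref` versions) can also be scripted. A hedged sketch: `patch_source` and `patch_file` are hypothetical helpers, and they assume the file contains the exact strings `import selective_scan_cuda`, `SelectiveScanFn.apply(` and `MambaInnerFn.apply(`, which may not hold for every mamba-ssm version.

```python
import re
from pathlib import Path

# (pattern, replacement) pairs mirroring the manual edit. The \b anchors keep
# e.g. "BiMambaInnerFn.apply(" untouched; "^...$" only matches the bare import line.
SUBSTITUTIONS = [
    (re.compile(r"^import selective_scan_cuda$", re.MULTILINE), "# import selective_scan_cuda"),
    (re.compile(r"\bSelectiveScanFn\.apply\("), "selective_scan_ref("),
    (re.compile(r"\bMambaInnerFn\.apply\("), "mamba_inner_ref("),
]


def patch_source(text: str) -> str:
    """Apply the substitutions; running it twice leaves the text unchanged."""
    for pattern, replacement in SUBSTITUTIONS:
        text = pattern.sub(replacement, text)
    return text


def patch_file(path: str) -> None:
    """Patch a file in place, keeping a .bak copy of the original for rollback."""
    file = Path(path)
    original = file.read_text(encoding="utf-8")
    file.with_name(file.name + ".bak").write_text(original, encoding="utf-8")
    file.write_text(patch_source(original), encoding="utf-8")
```

Usage: `patch_file("path/to/ops/selective_scan_interface.py")`, where the path is a placeholder for the copy inside your cloned mamba source or installed package.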
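Of course, I have also built Windows wheels with the source already modified to bypass selective_scan_cuda: mamba-ssm-1.1.3 or mamba_ssm-1.2.0.post1, which can be downloaded and installed directly or obtained from me via WeChat. For builds that do not need to bypass CUDA acceleration, see the blog series in the navigation at the top of this post!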
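Besides using Docker or modifying and compiling the source, some people have run the latest Mamba model on Windows through WSL; see:
'Successfully running the latest mamba model on native Windows via WSL'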
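Many readers still get a wrong-CUDA-version error when installing causal-conv1d after installing cuda-nvcc. This happens because a CUDA_HOME (Linux) or CUDA_PATH (Windows) variable in the environment may still point to the wrong location. In that case, check: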
nvcc -V
python -c "import torch.utils.cpp_extension; print(torch.utils.cpp_extension.CUDA_HOME)"
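Make sure both commands report the correct version and location; in particular, the location printed by the second command must be correct.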
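On Linux, if the second command prints a location in the base environment, run which nvcc to get the correct path inside the virtual environment, set it in .bashrc with export CUDA_HOME='....', run source ~/.bashrc to apply the change, and then continue the installation.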
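On Windows, run where nvcc to get the correct path in the virtual environment (the path up to bin, without nvcc.exe), change the system environment variable CUDA_PATH to that path, and then continue the installation.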
For the order in which PyTorch selects a CUDA installation, see the blog post 'The order in which PyTorch selects CUDA (about cudatoolkit and /usr/local/cuda)'.
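The two checks above can be combined into one script. A minimal sketch, assuming the helper names `nvcc_release`, `cuda_home_from_nvcc` and `check` (all hypothetical): it parses the `release X.Y` line of `nvcc -V`, derives the toolkit root as the parent of nvcc's `bin/` directory, and compares both with what PyTorch reports. To my understanding, this parent-of-`bin` rule is also how `torch.utils.cpp_extension` derives `CUDA_HOME` from `which`/`where nvcc` when neither `CUDA_HOME` nor `CUDA_PATH` is set.

```python
import os
import re
import shutil
import subprocess
from pathlib import PurePath


def nvcc_release(nvcc_output: str):
    """Extract 'X.Y' from `nvcc -V` output such as '... release 11.8, V11.8.89'."""
    match = re.search(r"release (\d+)\.(\d+)", nvcc_output)
    return f"{match.group(1)}.{match.group(2)}" if match else None


def cuda_home_from_nvcc(nvcc_path: str) -> str:
    """Toolkit root implied by an nvcc path: the parent of its bin/ directory."""
    return str(PurePath(nvcc_path).parent.parent)


def check() -> None:
    nvcc = shutil.which("nvcc")
    if nvcc is None:
        print("nvcc is not on PATH: install cuda-nvcc into this environment")
        return
    release = nvcc_release(subprocess.run([nvcc, "-V"], capture_output=True, text=True).stdout)
    implied_home = cuda_home_from_nvcc(nvcc)
    print(f"nvcc {release} at {nvcc} (toolkit root {implied_home})")
    try:
        import torch
        from torch.utils import cpp_extension
    except ImportError:
        print("PyTorch is not installed in this environment")
        return
    print(f"torch built for CUDA {torch.version.cuda}; CUDA_HOME = {cpp_extension.CUDA_HOME}")
    if release != torch.version.cuda:
        print("WARNING: nvcc and torch.version.cuda report different CUDA versions")
    home = cpp_extension.CUDA_HOME
    if home is None or os.path.realpath(home) != os.path.realpath(implied_home):
        print("WARNING: extensions would be built against a different toolkit than this nvcc;"
              " check CUDA_HOME (Linux) / CUDA_PATH (Windows)")


if __name__ == "__main__":
    check()
```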
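If the installation hangs on Linux, it is downloading the corresponding *.whl file, which may require a proxy/VPN depending on your network. You can wait for the download to fail so that it prints the correct URL, then download the file manually and pip install that whl file. You can also download the whl directly.
With my configuration: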
causal_conv1d download link: https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.1.1/causal_conv1d-1.1.1+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
mamba_ssm download link: https://github.com/state-spaces/mamba/releases/download/v1.1.3.post1/mamba_ssm-1.1.3.post1+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
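For other configurations, the URL can be assembled from your own versions. The sketch below assumes the naming pattern seen in the two links above (package, version, `cu` tag, torch major.minor, `cxx11abi` flag, Python tag, platform); other releases may name their assets differently, so treat a 404 as "no prebuilt wheel for this combination". `wheel_url` and `current_tags` are hypothetical helpers; `torch._C._GLIBCXX_USE_CXX11_ABI` is, as far as I know, the flag the projects' setup.py files read for the ABI part.

```python
import sys

# Release pages of the two projects, taken from the links above.
REPOS = {
    "causal_conv1d": "https://github.com/Dao-AILab/causal-conv1d",
    "mamba_ssm": "https://github.com/state-spaces/mamba",
}


def wheel_url(package: str, version: str, cuda: str, torch_mm: str, cxx11abi: bool,
              python_tag: str, platform: str = "linux_x86_64") -> str:
    """Build a release-asset URL, e.g. cuda='118', torch_mm='2.1', python_tag='cp310'."""
    abi = str(cxx11abi).upper()  # True -> 'TRUE', False -> 'FALSE'
    filename = (f"{package}-{version}+cu{cuda}torch{torch_mm}cxx11abi{abi}"
                f"-{python_tag}-{python_tag}-{platform}.whl")
    return f"{REPOS[package]}/releases/download/v{version}/{filename}"


def current_tags() -> dict:
    """Read the tag values from the running interpreter and PyTorch (needs torch)."""
    import torch

    return {
        "cuda": torch.version.cuda.replace(".", ""),  # '11.8' -> '118'
        "torch_mm": ".".join(torch.__version__.split("+")[0].split(".")[:2]),  # '2.1.1+cu118' -> '2.1'
        "cxx11abi": bool(torch._C._GLIBCXX_USE_CXX11_ABI),
        "python_tag": f"cp{sys.version_info.major}{sys.version_info.minor}",
    }
```

Usage: `print(wheel_url("mamba_ssm", "1.1.3.post1", **current_tags()))`, then download that file and `pip install` it.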
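Errors from causal_conv1d_cuda or selective_scan_cuda
Many readers find that, after causal-conv1d installs successfully, import causal_conv1d_cuda still reports that there is no module named causal_conv1d_cuda, or that, after mamba-ssm installs successfully, import selective_scan_cuda reports that there is no module named selective_scan_cuda. This is again caused by an incompatible CUDA environment. These two modules are compiled Python extension libraries (.so files on Linux, .pyd files on Windows). They are not in the installed source tree but under xxxx/envs/xxxx/lib/python3.xx/site-packages/, as causal_conv1d_cuda.cpython-310-x86_64-linux-gnu.so and selective_scan_cuda.cpython-310-x86_64-linux-gnu.so respectively (in my environment).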
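In that case, I recommend forcing a local build from source (CAUSAL_CONV1D_FORCE_BUILD=TRUE pip install . or MAMBA_FORCE_BUILD=TRUE pip install .). Some readers succeed this way; others still get errors, but the errors now contain specific information, for example ImportError xxxx selective_scan_cuda.cpython-xxx-linux-gnu.so undefined symbol (you can replace it directly with a compiled file, selective-scan-cuda-linux-gnu.so) or ImportError xxxx causal_conv1d_cuda.cpython-xxx-linux-gnu.so undefined symbol (you can replace it directly with a compiled file, causal-conv1d-cuda.cpython-310-x86-64-linux-gnu.so). Because everyone's environment is different, address each error according to its own message. [All files are available from me via WeChat.]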
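To see exactly which extension fails and why (missing module vs. `undefined symbol`), you can import each one and print the error text. A minimal sketch with a hypothetical `try_import` helper; `torch` is imported first because the compiled extensions link against its shared libraries, and `importlib.machinery.EXTENSION_SUFFIXES` shows which file suffixes (such as `.cpython-310-x86_64-linux-gnu.so`) this interpreter looks for.

```python
import importlib
import importlib.machinery


def try_import(name: str):
    """Import a module; return (True, file path) or (False, error message)."""
    try:
        module = importlib.import_module(name)
    except ImportError as exc:  # covers "No module named ..." and "undefined symbol" failures
        return False, str(exc)
    return True, getattr(module, "__file__", None)


if __name__ == "__main__":
    print("extension suffixes searched:", importlib.machinery.EXTENSION_SUFFIXES)
    for name in ["torch", "causal_conv1d_cuda", "selective_scan_cuda"]:
        ok, info = try_import(name)
        print(f"{name:20s} {'OK  ' if ok else 'FAIL'} {info}")
```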
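A .so undefined symbol error is generally caused by a CUDA version mismatch; see the 20240418 update on CUDA versions in this post. For example, inside the virtual environment which nvcc points to the virtual environment's CUDA, but python -c "import torch.utils.cpp_extension; print(torch.utils.cpp_extension.CUDA_HOME)" prints the system-wide /usr/local/cuda outside the virtual environment.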
Alternatively, as in the Windows method, you can modify the source files to bypass the calls to causal_conv1d_cuda and selective_scan_cuda.
Bypassing the calls to causal_conv1d_cuda and selective_scan_cuda
causal_conv1d_cuda: in causal_conv1d_interface.py, comment out import causal_conv1d_cuda and change
def causal_conv1d_fn(
    x,
    weight,
    bias=None,
    seq_idx=None,
    initial_states=None,
    return_final_states=False,
    final_states_out=None,
    activation=None,
):
    """
    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    seq_idx: (batch, seqlen)
    initial_states: (batch, dim, width - 1)
    final_states_out: (batch, dim, width - 1), to be written to
    activation: either None or "silu" or "swish"

    out: (batch, dim, seqlen)
    """
    return CausalConv1dFn.apply(
        x,
        weight,
        bias,
        seq_idx,
        initial_states,
        return_final_states,
        final_states_out,
        activation,
    )
to:
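def causal_conv1d_fn(
    x,
    weight,
    bias=None,
    seq_idx=None,
    initial_states=None,
    return_final_states=False,
    final_states_out=None,
    activation=None,
):
    """
    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    seq_idx: (batch, seqlen)
    initial_states: (batch, dim, width - 1)
    final_states_out: (batch, dim, width - 1), to be written to
    activation: either None or "silu" or "swish"

    out: (batch, dim, seqlen)
    """
    # causal_conv1d_ref has no seq_idx parameter, so passing all eight arguments
    # positionally would shift them into the wrong slots; pass the rest by keyword.
    assert seq_idx is None, "the reference implementation does not support seq_idx"
    return causal_conv1d_ref(
        x,
        weight,
        bias,
        initial_states=initial_states,
        return_final_states=return_final_states,
        final_states_out=final_states_out,
        activation=activation,
    )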
The code may differ slightly between versions, but this is always the function to change.
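What the reference path computes is a depthwise causal convolution: each output position sees only the current and the previous `width - 1` inputs of its own channel, with zero padding on the left, followed by an optional SiLU. The pure-Python sketch below (the name `causal_conv1d_ref_py` is hypothetical, and it works on nested lists instead of tensors) illustrates those semantics; to my understanding the real `causal_conv1d_ref` does the same with `F.conv1d(..., padding=width - 1, groups=dim)` truncated to `seqlen`.

```python
import math


def silu(v: float) -> float:
    """SiLU / swish: v * sigmoid(v)."""
    return v / (1.0 + math.exp(-v))


def causal_conv1d_ref_py(x, weight, bias=None, activation=None):
    """x: [dim][seqlen], weight: [dim][width], bias: [dim] or None -> [dim][seqlen]."""
    if activation not in (None, "silu", "swish"):
        raise NotImplementedError("activation must be None, 'silu' or 'swish'")
    out = []
    for d, (row, w) in enumerate(zip(x, weight)):
        width = len(w)
        channel = []
        for t in range(len(row)):
            acc = bias[d] if bias is not None else 0.0
            for k in range(width):
                src = t - (width - 1) + k  # only current and past positions
                if src >= 0:               # implicit zero padding on the left
                    acc += w[k] * row[src]
            channel.append(silu(acc) if activation else acc)
        out.append(channel)
    return out
```

For example, with `x = [[1.0, 2.0, 3.0]]` and `weight = [[1.0, 1.0]]`, each output is the sum of the current and previous input, giving `[[1.0, 3.0, 5.0]]`.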
causal_conv1d_cuda for Vim (Vision Mamba: Efficient Visual Representation Learning with Bidirectional State Space Model):
def causal_conv1d_fn(x, weight, bias=None, activation=None):
    """
    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    activation: either None or "silu" or "swish"
    out: (batch, dim, seqlen)
    """
    return CausalConv1dFn.apply(x, weight, bias, activation)
to:
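def causal_conv1d_fn(x, weight, bias=None, activation=None):
    """
    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    activation: either None or "silu" or "swish"
    out: (batch, dim, seqlen)
    """
    # Pass activation by keyword so it lands in the right slot even if this
    # version's causal_conv1d_ref has extra optional parameters before it.
    return causal_conv1d_ref(x, weight, bias, activation=activation)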
selective_scan_cuda: see the 20240313 update.
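The reason the bypass is slow is visible in `selective_scan_ref` (included in the Vim listing): it steps through the sequence one position at a time in a Python loop instead of running a fused CUDA kernel. The sketch below shows the same recurrence for a single (batch, channel) slice with real-valued `A` and time-varying `B`, `C`: the state is updated as `x = exp(dt * A) * x + dt * B_t * u_t` and read out as `y_t = C_t . x + D * u_t`. `selective_scan_1ch` is a hypothetical name; the `z` gating and complex `A` handled by the real function are omitted.

```python
import math


def softplus(v: float) -> float:
    # Same default threshold as torch.nn.functional.softplus: linear above 20.
    return v if v > 20 else math.log1p(math.exp(v))


def selective_scan_1ch(u, delta, A, B, C, D=0.0, delta_bias=0.0, delta_softplus=False):
    """One (batch, channel) slice of the selective-scan recurrence.

    u, delta: length-L sequences; A: length-N (diagonal state matrix);
    B, C: L x N per-step projections; D: scalar skip connection.
    """
    state = [0.0] * len(A)
    out = []
    for t in range(len(u)):  # strictly sequential over the sequence
        dt = delta[t] + delta_bias
        if delta_softplus:
            dt = softplus(dt)
        y = 0.0
        for n in range(len(A)):
            state[n] = math.exp(dt * A[n]) * state[n] + dt * B[t][n] * u[t]
            y += C[t][n] * state[n]
        out.append(y + D * u[t])
    return out
```

With `A = [0.0]` and `B`, `C`, `delta` all equal to 1, the state never decays, so the output is the running sum of `u`: `[1.0, 2.0, 3.0]` maps to `[1.0, 3.0, 6.0]`.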
selective_scan_cuda for Vim (Vision Mamba: Efficient Visual Representation Learning with Bidirectional State Space Model): because Vim modifies this source file, running Vision Mamba on Windows also requires corresponding changes to avoid this function, as shown below.
# Copyright (c) 2023, Tri Dao, Albert Gu.

import torch
import torch.nn.functional as F
from torch.cuda.amp import custom_bwd, custom_fwd

from einops import rearrange, repeat

try:
    from causal_conv1d import causal_conv1d_fn
    import causal_conv1d_cuda
except ImportError:
    causal_conv1d_fn = None
    causal_conv1d_cuda = None

# import selective_scan_cuda


class SelectiveScanFn(torch.autograd.Function):

    @staticmethod
    def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                return_last_state=False):
        if u.stride(-1) != 1:
            u = u.contiguous()
        if delta.stride(-1) != 1:
            delta = delta.contiguous()
        if D is not None:
            D = D.contiguous()
        if B.stride(-1) != 1:
            B = B.contiguous()
        if C.stride(-1) != 1:
            C = C.contiguous()
        if z is not None and z.stride(-1) != 1:
            z = z.contiguous()
        if B.dim() == 3:
            B = rearrange(B, "b dstate l -> b 1 dstate l")
            ctx.squeeze_B = True
        if C.dim() == 3:
            C = rearrange(C, "b dstate l -> b 1 dstate l")
            ctx.squeeze_C = True
        out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
        ctx.delta_softplus = delta_softplus
        ctx.has_z = z is not None
        last_state = x[:, :, -1, 1::2]  # (batch, dim, dstate)
        if not ctx.has_z:
            ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
            return out if not return_last_state else (out, last_state)
        else:
            ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
            out_z = rest[0]
            return out_z if not return_last_state else (out_z, last_state)

    @staticmethod
    def backward(ctx, dout, *args):
        if not ctx.has_z:
            u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
            z = None
            out = None
        else:
            u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
        # backward of selective_scan_cuda with the backward of chunk).
        # Here we just pass in None and dz will be allocated in the C++ code.
        du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
            u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus,
            False  # option to recompute out_z, not used here
        )
        dz = rest[0] if ctx.has_z else None
        dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
        dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
        return (du, ddelta, dA, dB, dC,
                dD if D is not None else None,
                dz,
                ddelta_bias if delta_bias is not None else None,
                None,
                None)


def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                      return_last_state=False):
    """if return_last_state is True, returns (out, last_state)
    last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
    not considered in the backward pass.
    """
    return selective_scan_ref(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)


def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                       return_last_state=False):
    """
    u: r(B D L)
    delta: r(B D L)
    A: c(D N) or r(D N)
    B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
    C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
    D: r(D)
    z: r(B D L)
    delta_bias: r(D), fp32

    out: r(B D L)
    last_state (optional): r(B D dstate) or c(B D dstate)
    """
    dtype_in = u.dtype
    u = u.float()
    delta = delta.float()
    if delta_bias is not None:
        delta = delta + delta_bias[..., None].float()
    if delta_softplus:
        delta = F.softplus(delta)
    batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
    is_variable_B = B.dim() >= 3
    is_variable_C = C.dim() >= 3
    if A.is_complex():
        if is_variable_B:
            B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
        if is_variable_C:
            C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
    else:
        B = B.float()
        C = C.float()
    x = A.new_zeros((batch, dim, dstate))
    ys = []
    deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
    if not is_variable_B:
        deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
    else:
        if B.dim() == 3:
            deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
        else:
            B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
            deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
    if is_variable_C and C.dim() == 4:
        C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
    last_state = None
    for i in range(u.shape[2]):
        x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
        if not is_variable_C:
            y = torch.einsum('bdn,dn->bd', x, C)
        else:
            if C.dim() == 3:
                y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
            else:
                y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
        if i == u.shape[2] - 1:
            last_state = x
        if y.is_complex():
            y = y.real * 2
        ys.append(y)
    y = torch.stack(ys, dim=2)  # (batch dim L)
    out = y if D is None else y + u * rearrange(D, "d -> d 1")
    if z is not None:
        out = out * F.silu(z)
    out = out.to(dtype=dtype_in)
    return out if not return_last_state else (out, last_state)


class MambaInnerFnNoOutProj(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
                C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1):
        """
             xz: (batch, dim, seqlen)
        """
        assert checkpoint_lvl in [0, 1]
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        if torch.is_autocast_enabled():
            x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
            delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
        if xz.stride(-1) != 1:
            xz = xz.contiguous()
        conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
        x, z = xz.chunk(2, dim=1)
        conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
        conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True)
        # We're being very careful here about the layout, to avoid extra transposes.
        # We want delta to have d as the slowest moving dimension
        # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
        x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight)  # (bl d)
        delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L)
        ctx.is_variable_B = B is None
        ctx.is_variable_C = C is None
        ctx.B_proj_bias_is_None = B_proj_bias is None
        ctx.C_proj_bias_is_None = C_proj_bias is None
        if B is None:  # variable B
            B = x_dbl[:, delta_rank:delta_rank + d_state]  # (bl dstate)
            if B_proj_bias is not None:
                B = B + B_proj_bias.to(dtype=B.dtype)
            if not A.is_complex():
                # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
                B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            else:
                B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
        else:
            if B.stride(-1) != 1:
                B = B.contiguous()
        if C is None:  # variable C
            C = x_dbl[:, -d_state:]  # (bl dstate)
            if C_proj_bias is not None:
                C = C + C_proj_bias.to(dtype=C.dtype)
            if not A.is_complex():
                # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
                C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            else:
                C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
        else:
            if C.stride(-1) != 1:
                C = C.contiguous()
        if D is not None:
            D = D.contiguous()
        out, scan_intermediates, out_z = selective_scan_cuda.fwd(
            conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
        )
        ctx.delta_softplus = delta_softplus
        ctx.checkpoint_lvl = checkpoint_lvl
        if checkpoint_lvl >= 1:  # Will recompute conv1d_out and delta in the backward pass
            conv1d_out, delta = None, None
        ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,
                              delta_proj_weight, conv1d_out, delta,
                              A, B, C, D, delta_bias, scan_intermediates, out)
        # return rearrange(out_z, "b d l -> b l d")
        return out_z

    @staticmethod
    @custom_bwd
    def backward(ctx, dout):
        # dout: (batch, seqlen, dim)
        (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight,
         conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        x, z = xz.chunk(2, dim=1)
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        if ctx.checkpoint_lvl == 1:
            conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True)
            delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
                              "d (b l) -> b d l", l=L)
        # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
        # backward of selective_scan_cuda with the backward of chunk).
        dxz = torch.empty_like(xz)  # (batch, dim, seqlen)
        dx, dz = dxz.chunk(2, dim=1)
        # dout_y = rearrange(dout, "b l d -> b d l") # because no arrange at end of forward, so dout shape is b d l
        dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(
            conv1d_out, delta, A, B, C, D, z, delta_bias, dout, scan_intermediates, out, dz,
            ctx.delta_softplus,
            True  # option to recompute out_z
        )
        dD = dD if D is not None else None
        dx_dbl = torch.empty_like(x_dbl)
        dB_proj_bias = None
        if ctx.is_variable_B:
            if not A.is_complex():
                dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
            else:
                dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
            dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
            dx_dbl[:, delta_rank:delta_rank + d_state] = dB  # (bl d)
            dB = None
        dC_proj_bias = None
        if ctx.is_variable_C:
            if not A.is_complex():
                dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
            else:
                dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
            dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
            dx_dbl[:, -d_state:] = dC  # (bl d)
            dC = None
        ddelta = rearrange(ddelta, "b d l -> d (b l)")
        ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
        dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
        dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
        dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
        dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
        dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
        # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
        # backward of conv1d with the backward of chunk).
        dx, dconv1d_weight, dconv1d_bias = causal_conv1d_cuda.causal_conv1d_bwd(
            x, conv1d_weight, conv1d_bias, dconv1d_out, None, dx, True
        )
        dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
        dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
        return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
                dA, dB, dC, dD,
                ddelta_bias if delta_bias is not None else None,
                dB_proj_bias, dC_proj_bias, None)


class MambaInnerFn(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                out_proj_weight, out_proj_bias,
                A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
                C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1):
        """
             xz: (batch, dim, seqlen)
        """
        assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
        assert checkpoint_lvl in [0, 1]
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        if torch.is_autocast_enabled():
            x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
            delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
            out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
            out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
                             if out_proj_bias is not None else None)
        if xz.stride(-1) != 1:
            xz = xz.contiguous()
        conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
        x, z = xz.chunk(2, dim=1)
        conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
        conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
            x, conv1d_weight, conv1d_bias, None, None, None, True
        )
        # We're being very careful here about the layout, to avoid extra transposes.
        # We want delta to have d as the slowest moving dimension
        # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
        x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight)  # (bl d)
        delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L)
        ctx.is_variable_B = B is None
        ctx.is_variable_C = C is None
        ctx.B_proj_bias_is_None = B_proj_bias is None
        ctx.C_proj_bias_is_None = C_proj_bias is None
        if B is None:  # variable B
            B = x_dbl[:, delta_rank:delta_rank + d_state]  # (bl dstate)
            if B_proj_bias is not None:
                B = B + B_proj_bias.to(dtype=B.dtype)
            if not A.is_complex():
                # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
                B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            else:
                B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
        else:
            if B.stride(-1) != 1:
                B = B.contiguous()
        if C is None:  # variable C
            C = x_dbl[:, -d_state:]  # (bl dstate)
            if C_proj_bias is not None:
                C = C + C_proj_bias.to(dtype=C.dtype)
            if not A.is_complex():
                # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
                C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            else:
                C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
        else:
            if C.stride(-1) != 1:
                C = C.contiguous()
        if D is not None:
            D = D.contiguous()
        out, scan_intermediates, out_z = selective_scan_cuda.fwd(
            conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
        )
        ctx.delta_softplus = delta_softplus
        ctx.out_proj_bias_is_None = out_proj_bias is None
        ctx.checkpoint_lvl = checkpoint_lvl
        if checkpoint_lvl >= 1:  # Will recompute conv1d_out and delta in the backward pass
            conv1d_out, delta = None, None
        ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,
                              delta_proj_weight, out_proj_weight, conv1d_out, delta,
                              A, B, C, D, delta_bias, scan_intermediates, out)
        return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)

    @staticmethod
    @custom_bwd
    def backward(ctx, dout):
        # dout: (batch, seqlen, dim)
        assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
        (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
         conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        x, z = xz.chunk(2, dim=1)
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        if ctx.checkpoint_lvl == 1:
            conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
                x, conv1d_weight, conv1d_bias, None, None, None, True
            )
            delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
                              "d (b l) -> b d l", l=L)
        # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
        # backward of selective_scan_cuda with the backward of chunk).
        dxz = torch.empty_like(xz)  # (batch, dim, seqlen)
        dx, dz = dxz.chunk(2, dim=1)
        dout = rearrange(dout, "b l e -> e (b l)")
        dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
        dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(
            conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates, out, dz,
            ctx.delta_softplus,
            True  # option to recompute out_z
        )
        dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)"))
        dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
        dD = dD if D is not None else None
        dx_dbl = torch.empty_like(x_dbl)
        dB_proj_bias = None
        if ctx.is_variable_B:
            if not A.is_complex():
                dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
            else:
                dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
            dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
            dx_dbl[:, delta_rank:delta_rank + d_state] = dB  # (bl d)
            dB = None
        dC_proj_bias = None
        if ctx.is_variable_C:
            if not A.is_complex():
                dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
            else:
                dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
            dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
            dx_dbl[:, -d_state:] = dC  # (bl d)
            dC = None
        ddelta = rearrange(ddelta, "b d l -> d (b l)")
        ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
        dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
        dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
        dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
        dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
        dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
        # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
        # backward of conv1d with the backward of chunk).
        dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
            x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True
        )
        dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
        dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
        return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
                dout_proj_weight, dout_proj_bias,
                dA, dB, dC, dD,
                ddelta_bias if delta_bias is not None else None,
                dB_proj_bias, dC_proj_bias, None)


class BiMambaInnerFn(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                out_proj_weight, out_proj_bias,
                A, A_b, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
                C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1):
        """
             xz: (batch, dim, seqlen)
        """
        assert checkpoint_lvl in [0, 1]
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        if torch.is_autocast_enabled():
            x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
            delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
            out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
            out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
                             if out_proj_bias is not None else None)
        if xz.stride(-1) != 1:
            xz = xz.contiguous()
        conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
        x, z = xz.chunk(2, dim=1)
        conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
        conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True)
        # We're being very careful here about the layout, to avoid extra transposes.
        # We want delta to have d as the slowest moving dimension
        # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
        x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight)  # (bl d)
        delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L)
        ctx.is_variable_B = B is None
        ctx.is_variable_C = C is None
        ctx.B_proj_bias_is_None = B_proj_bias is None
        ctx.C_proj_bias_is_None = C_proj_bias is None
        if B is None:  # variable B
            B = x_dbl[:, delta_rank:delta_rank + d_state]  # (bl dstate)
            if B_proj_bias is not None:
                B = B + B_proj_bias.to(dtype=B.dtype)
            if not A.is_complex():
                # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
                B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            else:
                B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
        else:
            if B.stride(-1) != 1:
                B = B.contiguous()
        if C is None:  # variable C
            C = x_dbl[:, -d_state:]  # (bl dstate)
            if C_proj_bias is not None:
                C = C + C_proj_bias.to(dtype=C.dtype)
            if not A.is_complex():
                # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
                C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            else:
                C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
        else:
            if C.stride(-1) != 1:
                C = C.contiguous()
        if D is not None:
            D = D.contiguous()
        out_f, scan_intermediates_f, out_z_f = selective_scan_cuda.fwd(
            conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
        )
        assert not A_b.is_complex(), "A should not be complex!!"
        out_b, scan_intermediates_b, out_z_b = selective_scan_cuda.fwd(
            conv1d_out.flip([-1]), delta.flip([-1]), A_b, B.flip([-1]), C.flip([-1]), D, z.flip([-1]),
            delta_bias, delta_softplus,
        )

        out_z = out_z_f + out_z_b.flip([-1])

        ctx.delta_softplus = delta_softplus
        ctx.out_proj_bias_is_None = out_proj_bias is None
        ctx.checkpoint_lvl = checkpoint_lvl
        if checkpoint_lvl >= 1:  # Will recompute conv1d_out and delta in the backward pass
            conv1d_out, delta = None, None
        ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,
                              delta_proj_weight, out_proj_weight, conv1d_out, delta,
                              A, A_b, B, C, D, delta_bias, scan_intermediates_f, scan_intermediates_b,
                              out_f, out_b)
        return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)

    @staticmethod
    @custom_bwd
    def backward(ctx, dout):
        # dout: (batch, seqlen, dim)
        (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
         conv1d_out, delta, A, A_b, B, C, D, delta_bias, scan_intermediates_f, scan_intermediates_b,
         out_f, out_b) = ctx.saved_tensors
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        x, z = xz.chunk(2, dim=1)
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        if ctx.checkpoint_lvl == 1:
            conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True)
            delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
                              "d (b l) -> b d l", l=L)
        # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
        # backward of selective_scan_cuda with the backward of chunk).
        dxz = torch.empty_like(xz)  # (batch, dim, seqlen)
        dx, dz = dxz.chunk(2, dim=1)
        dout = rearrange(dout, "b l e -> e (b l)")
        dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
        dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z_f = selective_scan_cuda.bwd(
            conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates_f, out_f, dz,
            ctx.delta_softplus,
            True  # option to recompute out_z
        )
        # flip one
        dz_b = torch.empty_like(dz)
        dconv1d_out_f_b, ddelta_f_b, dA_b, dB_f_b, dC_f_b, dD_b, ddelta_bias_b, dz_b, out_z_b = selective_scan_cuda.bwd(
            conv1d_out.flip([-1]), delta.flip([-1]), A_b, B.flip([-1]), C.flip([-1]), D, z.flip([-1]),
            delta_bias, dout_y.flip([-1]), scan_intermediates_b, out_b, dz_b,
            ctx.delta_softplus,
            True  # option to recompute out_z
        )

        dconv1d_out = dconv1d_out + dconv1d_out_f_b.flip([-1])
        ddelta = ddelta + ddelta_f_b.flip([-1])
        dB = dB + dB_f_b.flip([-1])
        dC = dC + dC_f_b.flip([-1])
        dD = dD + dD_b
        ddelta_bias = ddelta_bias + ddelta_bias_b
        dz = dz + dz_b.flip([-1])
        out_z = out_z_f + out_z_b.flip([-1])

        dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)"))
        dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
        dD = dD if D is not None else None
        dx_dbl = torch.empty_like(x_dbl)
        dB_proj_bias = None
        if ctx.is_variable_B:
            if not A.is_complex():
                dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
            else:
                dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
            dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
            dx_dbl[:, delta_rank:delta_rank + d_state] = dB  # (bl d)
            dB = None
        dC_proj_bias = None
        if ctx.is_variable_C:
            if not A.is_complex():
                dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
            else:
                dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
            dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
            dx_dbl[:, -d_state:] = dC  # (bl d)
            dC = None
        ddelta = rearrange(ddelta, "b d l -> d (b l)")
        ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
        dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
        dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
        dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
        dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
        dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
        # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
        # backward of conv1d with the backward of chunk).
        dx, dconv1d_weight, dconv1d_bias = causal_conv1d_cuda.causal_conv1d_bwd(
            x, conv1d_weight, conv1d_bias, dconv1d_out, None, dx, True
        )
        dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
        dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
        return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
                dout_proj_weight, dout_proj_bias,
                dA, dA_b, dB, dC, dD,
                ddelta_bias if delta_bias is not None else None,
                dB_proj_bias, dC_proj_bias, None)


def mamba_inner_fn(
    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
    out_proj_weight, out_proj_bias,
    A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
    C_proj_bias=None, delta_softplus=True
):
    return mamba_inner_ref(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                           out_proj_weight, out_proj_bias,
                           A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)


def bimamba_inner_fn(
    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
    out_proj_weight, out_proj_bias,
    A, A_b, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
    C_proj_bias=None, delta_softplus=True
):
    return bimamba_inner_ref(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                             out_proj_weight, out_proj_bias,
                             A, A_b, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)


def mamba_inner_fn_no_out_proj(
    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
    A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
    C_proj_bias=None, delta_softplus=True
):
    return mamba_inner_ref_fn_no_out_proj(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                                          A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)


def mamba_inner_ref(
    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
    out_proj_weight, out_proj_bias,
    A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
    C_proj_bias=None, delta_softplus=True
):
    L = xz.shape[-1]
    delta_rank = delta_proj_weight.shape[1]
    d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
    x, z = xz.chunk(2, dim=1)
    x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, "silu")
    # We're being very careful here about the layout, to avoid extra transposes.
    # We want delta to have d as the slowest moving dimension
    # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight) # (bl d) delta = delta_proj_weight @ x_dbl[:, :delta_rank].t() delta = rearrange(delta, "d (b l) -> b d l", l=L) if B is None: # variable B B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl d) if B_proj_bias is not None: B = B + B_proj_bias.to(dtype=B.dtype) if not A.is_complex(): B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous() else: B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous() if C is None: # variable B C = x_dbl[:, -d_state:] # (bl d) if C_proj_bias is not None: C = C + C_proj_bias.to(dtype=C.dtype) if not A.is_complex(): C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous() else: C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous() y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True) return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias) def mamba_inner_ref_fn_no_out_proj( xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, C_proj_bias=None, delta_softplus=True ): L = xz.shape[-1] delta_rank = delta_proj_weight.shape[1] d_state = A.shape[-1] * (1 if not A.is_complex() else 2) x, z = xz.chunk(2, dim=1) x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, "silu") # We're being very careful here about the layout, to avoid extra transposes. # We want delta to have d as the slowest moving dimension # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. 
x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight) # (bl d) delta = delta_proj_weight @ x_dbl[:, :delta_rank].t() delta = rearrange(delta, "d (b l) -> b d l", l=L) if B is None: # variable B B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl d) if B_proj_bias is not None: B = B + B_proj_bias.to(dtype=B.dtype) if not A.is_complex(): B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous() else: B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous() if C is None: # variable B C = x_dbl[:, -d_state:] # (bl d) if C_proj_bias is not None: C = C + C_proj_bias.to(dtype=C.dtype) if not A.is_complex(): C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous() else: C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous() y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True) # return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias) return y def bimamba_inner_ref( xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, A_b, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, C_proj_bias=None, delta_softplus=True ): L = xz.shape[-1] delta_rank = delta_proj_weight.shape[1] d_state = A.shape[-1] * (1 if not A.is_complex() else 2) x, z = xz.chunk(2, dim=1) x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, "silu") # We're being very careful here about the layout, to avoid extra transposes. # We want delta to have d as the slowest moving dimension # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. 
x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight) # (bl d) delta = delta_proj_weight @ x_dbl[:, :delta_rank].t() delta = rearrange(delta, "d (b l) -> b d l", l=L) if B is None: # variable B B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl d) if B_proj_bias is not None: B = B + B_proj_bias.to(dtype=B.dtype) if not A.is_complex(): B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous() else: B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous() if C is None: # variable B C = x_dbl[:, -d_state:] # (bl d) if C_proj_bias is not None: C = C + C_proj_bias.to(dtype=C.dtype) if not A.is_complex(): C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous() else: C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous() y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True) y_b = selective_scan_fn(x.flip([-1]), delta.flip([-1]), A_b, B.flip([-1]), C.flip([-1]), D, z.flip([-1]), delta_bias, delta_softplus=True) y = y + y_b.flip([-1]) return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)
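The bidirectional part of bimamba_inner_ref comes down to one pattern: run selective_scan_fn on the sequence, run it again on the time-reversed sequence (with its own A_b), flip that second result back, and add the two (y = y + y_b.flip([-1])). The sketch below shows only that data flow on plain Python lists. toy_scan is a hypothetical fixed-coefficient recurrence standing in for the real selective (input-dependent) scan, so it illustrates the flip-and-sum structure, not Mamba's math.

```python
from typing import List, Sequence, Tuple


def toy_scan(x: Sequence[float], a: float, b: float, c: float) -> List[float]:
    """Hypothetical scalar recurrence h_t = a*h_{t-1} + b*x_t, y_t = c*h_t."""
    h, ys = 0.0, []
    for xt in x:
        h = a * h + b * xt
        ys.append(c * h)
    return ys


def toy_bidirectional(x: Sequence[float],
                      fwd: Tuple[float, float, float] = (0.5, 1.0, 1.0),
                      bwd: Tuple[float, float, float] = (0.5, 1.0, 1.0)) -> List[float]:
    """Forward scan plus a scan over the reversed input, flipped back and summed."""
    y = toy_scan(x, *fwd)
    # Analogue of selective_scan_fn(x.flip([-1]), ...).flip([-1]) with its own coefficients (A_b).
    y_b = toy_scan(list(x)[::-1], *bwd)[::-1]
    return [p + q for p, q in zip(y, y_b)]


if __name__ == "__main__":
    print(toy_bidirectional([1.0, 0.0, 0.0]))
```

For a single impulse at the first step, the forward pass spreads it to later steps while the reversed pass only sees it at its own last step, so every position gets context from both directions.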
After installing mamba-ssm with pip install mamba-ssm, code that had previously run fine started failing with the following error:
File "/home/xxx/.conda/envs/mamba/lib/python3.10/site-packages/mamba_ssm/ops/selective_scan_interface.py", line 187, in forward
conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
TypeError: causal_conv1d_fwd(): incompatible function arguments. The following argument types are supported:
1. (arg0: torch.Tensor, arg1: torch.Tensor, arg2: Optional[torch.Tensor], arg3: Optional[torch.Tensor], arg4: bool) -> torch.Tensor
Invoked with: tensor(
[-4.9056e-40, -4.9057e-40, -4.9074e-40, -4.9078e-40]], device='cuda:0',
requires_grad=True), Parameter containing:
tensor([ 0.0322, -0.1139, 0.0770, ..., -0.0320, -0.1266, -0.1096],
device='cuda:0', requires_grad=True), None, None, None, True
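On inspection, the cause is the mamba-ssm version. The failing version is 1.2.0.post1: pip install mamba-ssm installs the latest release, whose functions are partly incompatible with the earlier ones, while the version that had been running correctly was 1.1.3.post1.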
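To confirm which mamba-ssm version the active environment actually has, before and after a pip install, a small stdlib-only check can help. This is a sketch: the KNOWN_GOOD table only records the working version reported above, and the pip command it prints is a suggestion, not a fix verified for every setup.

```python
from importlib.metadata import PackageNotFoundError, version

# Version reported as working in this post (1.2.0.post1 raised
# "causal_conv1d_fwd(): incompatible function arguments").
KNOWN_GOOD = {"mamba-ssm": "1.1.3.post1"}


def check_pin(dist_name: str, expected: str) -> str:
    """Return a one-line status comparing the installed distribution with `expected`."""
    try:
        installed = version(dist_name)
    except PackageNotFoundError:
        return f"{dist_name}: not installed"
    if installed == expected:
        return f"{dist_name}: {installed} (matches pin)"
    return (f"{dist_name}: {installed} (expected {expected}; "
            f"consider pip install {dist_name}=={expected})")


if __name__ == "__main__":
    for name, pin in KNOWN_GOOD.items():
        print(check_pin(name, pin))
```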
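One reader running Vision Mamba hit the following error (Linux):
TypeError: Mamba.__init__() got an unexpected keyword argument 'bimamba_type'
Vision Mamba modifies the Mamba source code, and the package installed through Mamba's official channels does not have this argument. You therefore need to uninstall the original Mamba first and then install it manually from the Mamba source bundled with the Vision Mamba code, rather than from the official channels. In practice, simply replacing files also works: overwrite the installed selective_scan_interface.py, causal_conv1d_interface.py and mamba_simple.py with the corresponding files from Vision Mamba.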
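To tell whether the Mamba class the environment loads is Vim's modified version, one option is to inspect its constructor signature instead of waiting for the TypeError. A minimal sketch with inspect.signature; OfficialMamba and VimMamba are hypothetical stand-ins so the example runs without mamba_ssm installed.

```python
import inspect


def accepts_kwarg(obj, name: str) -> bool:
    """True if calling `obj` (a function or class) accepts keyword argument `name`."""
    try:
        params = inspect.signature(obj).parameters
    except (TypeError, ValueError):  # some builtins expose no signature
        return False
    if name in params:
        return params[name].kind in (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        )
    # A **kwargs catch-all also accepts the name (it may still be rejected later).
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


# Hypothetical stand-ins for the two Mamba variants, for illustration only.
class OfficialMamba:
    def __init__(self, d_model, d_state=16):
        self.d_model, self.d_state = d_model, d_state


class VimMamba:
    def __init__(self, d_model, d_state=16, bimamba_type="none"):
        self.d_model, self.d_state, self.bimamba_type = d_model, d_state, bimamba_type
```

In a real environment, running from mamba_ssm.modules.mamba_simple import Mamba and then accepts_kwarg(Mamba, "bimamba_type") returning False suggests the installed mamba_simple.py is still the official one.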
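Recently a reader hit the following error during installation: ImportError: cannot import name 'packaging' from 'pkg_resources'. The cause is that setuptools is too new (typically 70.0.0) and needs to be downgraded; pip install setuptools==68.2.2 is enough.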
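A quick way to see whether the current setuptools falls into the problematic range before starting a long build. The ">= 70" threshold is an assumption based only on the version reported above, and the major-version parser is deliberately simple rather than a full PEP 440 implementation.

```python
from importlib.metadata import PackageNotFoundError, version

# Assumption from this post: 70.0.0 triggered the error and 68.2.2 worked.
MAX_MAJOR_EXCLUSIVE = 70


def major_version(ver: str) -> int:
    """Leading integer of a version string, e.g. '68.2.2' -> 68."""
    digits = ""
    for ch in ver:
        if not ch.isdigit():
            break
        digits += ch
    if not digits:
        raise ValueError(f"unparseable version: {ver!r}")
    return int(digits)


def setuptools_too_new(ver: str) -> bool:
    return major_version(ver) >= MAX_MAJOR_EXCLUSIVE


if __name__ == "__main__":
    try:
        v = version("setuptools")
    except PackageNotFoundError:
        print("setuptools is not installed")
    else:
        hint = " -> pip install setuptools==68.2.2" if setuptools_too_new(v) else ""
        print(f"setuptools {v}{hint}")
```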
selective_scan_cuda
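A reader configuring Vision Mamba ran into the following error:
NameError: name 'selective_scan_cuda' is not defined. Did you mean: 'selective_scan_fn'?
It happens because causal_conv1d and mamba_ssm were never installed as the original instructions describe; the source code was simply copied. That said, following the original installation instructions will probably fail too, so here is a simpler approach: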
Solution (Linux)
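Install the original causal_conv1d and mamba_ssm normally, then overwrite the installed causal_conv1d and mamba_ssm packages in the environment with the causal_conv1d and mamba_ssm directories from the Vision Mamba project.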
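Before overwriting, it helps to know exactly which directory the environment loads causal_conv1d and mamba_ssm from. A minimal sketch using importlib.util.find_spec, which locates a top-level package without executing it, so it still works when importing mamba_ssm itself fails.

```python
import importlib.util
from pathlib import Path
from typing import Optional


def installed_package_dir(name: str) -> Optional[Path]:
    """Directory of an installed top-level package, found without importing it."""
    spec = importlib.util.find_spec(name)
    if spec is None or not spec.submodule_search_locations:
        return None  # not installed, or a single-file module rather than a package
    return Path(list(spec.submodule_search_locations)[0])


if __name__ == "__main__":
    for pkg in ("causal_conv1d", "mamba_ssm"):
        print(pkg, "->", installed_package_dir(pkg))
```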
ModuleNotFoundError: No module named 'causal_conv1d_cuda'
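This is actually an old problem; see the first issue in this post's 20240424 update. The cause is that causal_conv1d was not installed successfully. After a successful installation there should normally be a file named causal_conv1d_cuda.cp310-win_amd64.pyd (in my environment) under xxxx\envs\xxxx\Lib\site-packages\; the download link for this file is causal-conv1d-cuda.cp310-win-amd64.pyd. You can either compile from source on Windows as described earlier in this post, or, once the earlier environment is set up, download and install the whl I compiled. (All files are available from me via WeChat.)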
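Whether Python can see causal_conv1d_cuda at all can be checked without importing mamba_ssm. The sketch below asks importlib.util.find_spec for the compiled modules and, if one is missing, lists the file names the current interpreter would accept, built from importlib.machinery.EXTENSION_SUFFIXES (on CPython 3.10 for 64-bit Windows one of them is causal_conv1d_cuda.cp310-win_amd64.pyd, matching the file named above). It assumes the extension belongs in the platlib site-packages directory, which is typical for pip installs but not guaranteed.

```python
import importlib.machinery
import importlib.util
import sysconfig
from pathlib import Path
from typing import List, Optional


def extension_candidates(module: str) -> List[str]:
    """File names this interpreter accepts for a compiled extension called `module`."""
    return [module + suffix for suffix in importlib.machinery.EXTENSION_SUFFIXES]


def find_extension(module: str) -> Optional[Path]:
    """Path Python would load `module` from, or None if it cannot be found."""
    spec = importlib.util.find_spec(module)
    if spec is None or spec.origin is None:
        return None
    return Path(spec.origin)


if __name__ == "__main__":
    site_dir = Path(sysconfig.get_paths()["platlib"])  # assumed install location
    for mod in ("causal_conv1d_cuda", "selective_scan_cuda"):
        found = find_extension(mod)
        print(f"{mod}: {found if found else 'NOT FOUND'}")
        if found is None:
            for name in extension_candidates(mod):
                print("  expected something like", site_dir / name)
```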
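First configure Mamba following the Solution (Win) described earlier, then modify the source of the installed mamba package to match the Vision Mamba source. See causal_conv1d_cuda for Vim and selective_scan_cuda for Vim in the 20240424 update above.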
Moved to a separate post: Summary of Pitfalls and Solutions for Installing the Vim Environment on Windows
KeyError: 'HOME'
Specifically, the following error is raised:
Traceback (most recent call last):
  .....
  File "xxx\models\vimamba.py", line 115, in forward
    hidden_states, residual = fused_add_norm_fn(
  File "D:\Anaconda\envs\xxx\lib\site-packages\mamba_ssm\ops\triton\layernorm.py", line 478, in rms_norm_fn
    return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
  File "D:\Anaconda\envs\xxx\lib\site-packages\torch\autograd\function.py", line 539, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "D:\Anaconda\envs\xxx\lib\site-packages\mamba_ssm\ops\triton\layernorm.py", line 411, in forward
    y, mean, rstd, residual_out = _layer_norm_fwd(
  File "D:\Anaconda\envs\xxx\lib\site-packages\mamba_ssm\ops\triton\layernorm.py", line 155, in _layer_norm_fwd
    _layer_norm_fwd_1pass_kernel[(M,)](
  File "D:\Anaconda\envs\xxx\lib\site-packages\triton\runtime\jit.py", line 106, in launcher
    return self.run(*args, grid=grid, **kwargs)
  File "D:\Anaconda\envs\xxx\lib\site-packages\triton\runtime\autotuner.py", line 73, in run
    timings = {config: self._bench(*args, config=config, **kwargs)
  File "D:\Anaconda\envs\xxx\lib\site-packages\triton\runtime\autotuner.py", line 73, in <dictcomp>
    timings = {config: self._bench(*args, config=config, **kwargs)
  File "D:\Anaconda\envs\xxx\lib\site-packages\triton\runtime\autotuner.py", line 63, in _bench
    return do_bench(kernel_call)
  File "D:\Anaconda\envs\xxx\lib\site-packages\triton\testing.py", line 136, in do_bench
    fn()
  File "D:\Anaconda\envs\xxx\lib\site-packages\triton\runtime\autotuner.py", line 62, in kernel_call
    self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
  File "<string>", line 41, in _layer_norm_fwd_1pass_kernel
  File "D:\Anaconda\envs\xxx\lib\site-packages\triton\compiler.py", line 1230, in compile
    so_cache_manager = CacheManager(so_cache_key)
  File "D:\Anaconda\envs\xxx\lib\site-packages\triton\compiler.py", line 1102, in __init__
    self.cache_dir = os.environ.get('TRITON_CACHE_DIR', default_cache_dir())
  File "D:\Anaconda\envs\xxx\lib\site-packages\triton\compiler.py", line 1093, in default_cache_dir
    return os.path.join(os.environ["HOME"], ".triton", "cache")
  File "D:\Anaconda\envs\xxx\lib\os.py", line 680, in __getitem__
    raise KeyError(key) from None
KeyError: 'HOME'
On Windows you also need to edit the file D:\Anaconda\envs\xxx\lib\site-packages\mamba_ssm\ops\triton\layernorm.py in the mamba install path. Specifically, change the following code in layernorm.py:
def layer_norm_fn(
    x,
    weight,
    bias,
    residual=None,
    eps=1e-6,
    prenorm=False,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, is_rms_norm)


def rms_norm_fn(x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False, eps=1e-6):
    return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
to:
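def layer_norm_fn(
    x,
    weight,
    bias,
    residual=None,
    eps=1e-6,
    prenorm=False,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    # Use the unfused PyTorch reference implementations defined in this same file,
    # so no Triton kernel is compiled (Triton's cache setup is what reads HOME).
    # residual_in_fp32 is passed through as the ref functions' upcast argument.
    if is_rms_norm:
        return rms_norm_ref(x, weight, bias, residual, eps, prenorm, residual_in_fp32)
    return layer_norm_ref(x, weight, bias, residual, eps, prenorm, residual_in_fp32)


def rms_norm_fn(x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False, eps=1e-6):
    return rms_norm_ref(x, weight, bias, residual, eps, prenorm, residual_in_fp32)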
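layer_norm_ref and rms_norm_ref compute the same normalization as the fused kernels, just unfused and slower, which is why swapping them in sidesteps the Triton compile step and its HOME lookup. As a rough illustration of what they compute for a single row, here is a pure-Python sketch. The *_py functions are hypothetical mirrors written for this example, not mamba_ssm code, and they ignore dtype handling (upcast).

```python
import math
from typing import List, Optional, Sequence


def rms_norm_ref_py(x: Sequence[float], weight: Sequence[float],
                    bias: Optional[Sequence[float]] = None,
                    residual: Optional[Sequence[float]] = None,
                    eps: float = 1e-6, prenorm: bool = False):
    # The residual (if any) is added before normalizing, as in the pre-norm block.
    x = [a + r for a, r in zip(x, residual)] if residual is not None else list(x)
    rstd = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    out: List[float] = [v * rstd * w for v, w in zip(x, weight)]
    if bias is not None:
        out = [o + b for o, b in zip(out, bias)]
    # With prenorm=True the summed input is returned too, to be used as the next residual.
    return (out, x) if prenorm else out


def layer_norm_ref_py(x: Sequence[float], weight: Sequence[float],
                      bias: Optional[Sequence[float]] = None,
                      residual: Optional[Sequence[float]] = None,
                      eps: float = 1e-6, prenorm: bool = False):
    x = [a + r for a, r in zip(x, residual)] if residual is not None else list(x)
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)  # biased variance, as in layer norm
    rstd = 1.0 / math.sqrt(var + eps)
    out: List[float] = [(v - mean) * rstd * w for v, w in zip(x, weight)]
    if bias is not None:
        out = [o + b for o, b in zip(out, bias)]
    return (out, x) if prenorm else out
```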
Moved to a separate post: Summary of Pitfalls and Solutions for Installing the Vim Environment on Windows
Either modify the following files before installing causal_conv1d and mamba_ssm separately from the Vim source via setup.py:
Vim/mamba-1p1p1/mamba_ssm/ops/triton/layernorm.py
Vim/mamba-1p1p1/mamba_ssm/ops/selective_scan_interface.py
Vim/causal-conv1d/causal_conv1d/causal_conv1d_interface.py
or modify the following files after installing causal_conv1d and mamba_ssm separately via setup.py:
xxx\Anaconda\envs\xxx\Lib\site-packages\mamba_ssm\ops\triton\layernorm.py
xxx\Anaconda\envs\xxx\Lib\site-packages\mamba_ssm\ops\selective_scan_interface.py
xxx\Anaconda\envs\xxx\Lib\site-packages\causal_conv1d\causal_conv1d_interface.py
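If you configured Mamba with the Solution (Win) described earlier and then run Vim, then besides modifying these three source files you also need to replace xxx\Anaconda\envs\xxx\Lib\site-packages\mamba_ssm\modules\mamba_simple.py with the mamba_simple.py file from the Vim source.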
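The copy-over step can be scripted so the original installed files are kept. A sketch under stated assumptions: the Vim-side paths follow the Vim/mamba-1p1p1/... and Vim/causal-conv1d/... layout listed above, the location of mamba_simple.py inside the Vim tree (mamba-1p1p1/mamba_ssm/modules/) is assumed, and you have already made the edits described above in the Vim tree, since the script only copies files and does not patch them.

```python
import shutil
from pathlib import Path
from typing import Iterable, List, Tuple

# (path inside the Vim repo, path inside site-packages). The Vim-side location of
# mamba_simple.py is an assumption; the other three follow the list in this post.
FILE_MAP: List[Tuple[str, str]] = [
    ("mamba-1p1p1/mamba_ssm/ops/triton/layernorm.py", "mamba_ssm/ops/triton/layernorm.py"),
    ("mamba-1p1p1/mamba_ssm/ops/selective_scan_interface.py", "mamba_ssm/ops/selective_scan_interface.py"),
    ("causal-conv1d/causal_conv1d/causal_conv1d_interface.py", "causal_conv1d/causal_conv1d_interface.py"),
    ("mamba-1p1p1/mamba_ssm/modules/mamba_simple.py", "mamba_ssm/modules/mamba_simple.py"),
]


def overwrite_with_backup(vim_root: Path, site_packages: Path,
                          file_map: Iterable[Tuple[str, str]] = FILE_MAP,
                          suffix: str = ".orig") -> List[Path]:
    """Copy each file from the Vim tree over the installed copy.

    The installed original is saved once as <name><suffix>; an existing backup is
    never overwritten, so running the script twice keeps the true original.
    """
    pairs = [(vim_root / s, site_packages / d) for s, d in file_map]
    missing = [str(src) for src, _ in pairs if not src.is_file()]
    if missing:  # check everything first so a typo does not leave a half-patched install
        raise FileNotFoundError("missing in Vim tree: " + ", ".join(missing))
    written: List[Path] = []
    for src, dst in pairs:
        backup = dst.with_name(dst.name + suffix)
        if dst.exists() and not backup.exists():
            shutil.copy2(dst, backup)
        dst.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(src, dst)
        written.append(dst)
    return written
```

A call would look like overwrite_with_backup(Path(r"D:\code\Vim"), Path(r"D:\Anaconda\envs\xxx\Lib\site-packages")), where both paths are placeholders for your own Vim checkout and environment.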
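For selective_scan_cuda, see the latest blog post linked in the navigation at the top for installation steps; for Vim, you can still install Mamba first and then overwrite the source files in the virtual environment.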