赞
踩
最近Mamba有关的论文引起了众多人的关注,笔者知道很多人开始复现官方代码,但是由于官方代码虚拟环境创建在ubuntu系统下,因此在windows系统下复现代码遇到各种各样的问题。笔者在经过一晚上的尝试,总算在win11系统anaconda中成功配置了环境。
Building wheel for causal-conv1d (setup.py) ... error
ERROR: Could not build wheels for causal-conv1d, which is required to install pyproject.toml-based projects
主要原因是CUDA版本不兼容。
本次复现主要参考这篇博客方法三,并根据自己电脑情况修改了一些步骤。
- conda create -n your_env_name python=3.10.13
- conda activate your_env_name
- conda install cudatoolkit==11.8 -c nvidia
- pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118
- conda install -c "nvidia/label/cuda-11.8.0" cuda-nvcc
- conda install packaging
接下来安装 'triton’包:tritan包安装教程,有大神编译了Windows下二进制文件,下载到本地后,在anacoda终端中,切换到tritan所在文件夹,输入
pip install triton-2.0.0-cp310-cp310-win_amd64.whl
然后使用源码编译安装causal-conv1d,注意在此之前,请检查torch的cuda版本和你自己的cuda版本是否一致。并输入
where nvcc
检查CUDA版本(最好在此电脑-属性-系统-高级系统设置-环境变量)中检查新安装的CUDA11.8是否添加(在这里卡了很久,也不知道是什么原因,where nvcc指令只能找到自己电脑的旧版本CUDA,笔者干脆手动安装了CUDA11.8和CUDNN安装教程),安装完成后,重启电脑,重新进入anaconda环境。接下来进行causal-conv1d安装,笔者采用源码编译的方式。首先请在causal-conv1d安装链接下载好对应版本的安装包(笔者下载的是1.0.0版本)。下载到本地后,解压,anaconda激活环境后进入该文件夹。输入
pip install .
在这里可能会出现
有时候缓存文件可能会导致安装出错。你可以尝试清理 pip 或 conda 的缓存
pip cache purge
然后再输入pip install .就可以啦。
接下来是mamba源码编译,请在mamba官方代码中setup.py文件修改配置
- FORCE_BUILD = os.getenv("MAMBA_FORCE_BUILD", "FALSE") == "FALSE"
- SKIP_CUDA_BUILD = os.getenv("MAMBA_SKIP_CUDA_BUILD", "FALSE") == "FALSE"
selective_scan_cuda
包括进去,导入模块还是会出错。请在mamba_ssm/ops/selective_scan_interface.py该文件中注释掉: import selective_scan_cuda
将
- def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
- return_last_state=False):
- """if return_last_state is True, returns (out, last_state)
- last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
- not considered in the backward pass.
- """
- return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
-
-
- def mamba_inner_fn(
- xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
- out_proj_weight, out_proj_bias,
- A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
- C_proj_bias=None, delta_softplus=True
- ):
- return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
- out_proj_weight, out_proj_bias,
- A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
改为
- def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
- return_last_state=False):
- """if return_last_state is True, returns (out, last_state)
- last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
- not considered in the backward pass.
- """
- return selective_scan_ref(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
-
- def mamba_inner_fn(
- xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
- out_proj_weight, out_proj_bias,
- A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
- C_proj_bias=None, delta_softplus=True
- ):
- return mamba_inner_ref(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
- out_proj_weight, out_proj_bias,
- A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
-
最后就可以正常做mamba框架的相关实验啦~
最后放一下Mamba模型的简要介绍:
Mamba是一种新颖的序列建模架构,被誉为Transformer的潜在竞争对手。相比于传统的模型,Mamba引入了选择性状态空间的概念,以更高效和有效地捕获相关信息。以下是Mamba的关键特点:
Mamba的核心组成包括固定主干和输入相关转换。在训练期间,它类似于Transformer,同时处理整个序列;在推理中,它更符合传统的循环模型,提供有效的序列处理。
此外,Mamba还使用了SRAM(Static Random-Access Memory)来优化内存需求,使其成为处理长序列的有前途的模型。
笔者研一菜鸡一枚~在安装的过程中conda下载完cuda11.8后又手动下载一遍(真的没搞懂),希望评论区有大神知道其中原理的欢迎交流~
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。