
(Foolproof Windows Tutorial) Installing Mamba and a Roundup of Common Problems (Causal-Conv1d & Mamba-ssm)


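Contents

Advantages of Mamba

1. My environment

1.1 Foolproof CUDA install

1.2 The fancier way: environment variables

2. Installing Causal-Conv1d

2.1 Direct install (did not really work for me)

2.2 Foolproof install

3. Installing mamba-ssm

4. References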


After finishing the install I realized what a genius I am (just kidding)

MambaVision has been getting popular lately, so my advisor's passion for research flared up once again, and my anxious heart finally died while I was setting up Mamba...


Advantages of Mamba

  1. Selection mechanism: introduces gating, similar to the gating mechanism in RNNs
  2. Linear scaling in sequence length: lowers the compute cost of training and inference
  3. Hardware-aware algorithm: lowers hardware overhead
  4. Architecture: combines H3 with a gated MLP

1. My environment

  1. python==3.10.14
  2. causal-conv1d==1.1.1
  3. triton==2.0.0
  4. mamba-ssm==1.2.0.post1
  5. torch==2.1.1+cu118
  6. torchaudio==2.1.1+cu118
  7. torchvision==0.16.1+cu118

The packages I found online (including the triton whl) were all built for cu118 + Python 3.10, so I installed against CUDA 11.8. Taking good advice is the whole point here ^.^
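Once the steps in this guide are done, it can be handy to compare what actually got installed against the list above. The following is a minimal sketch using only the standard library; the `PINS` dictionary and the `check_pins` helper are hypothetical names introduced for this example, and the pins are simply copied from the list above.

```python
from importlib import metadata

# Pins copied from the environment list above (python itself is not a pip package).
PINS = {
    "causal-conv1d": "1.1.1",
    "triton": "2.0.0",
    "mamba-ssm": "1.2.0.post1",
    "torch": "2.1.1+cu118",
    "torchaudio": "2.1.1+cu118",
    "torchvision": "0.16.1+cu118",
}


def check_pins(pins, get_version=metadata.version):
    """Return {package: (expected, found, matches)}; found is None if not installed."""
    report = {}
    for name, expected in pins.items():
        try:
            found = get_version(name)
        except metadata.PackageNotFoundError:
            found = None
        report[name] = (expected, found, found == expected)
    return report


if __name__ == "__main__":
    for name, (expected, found, ok) in check_pins(PINS).items():
        status = "ok" if ok else "MISMATCH"
        print(f"{name}: expected {expected}, found {found} [{status}]")
```

A mismatch is not necessarily fatal; it only tells you where your environment differs from the one this guide was written against.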

1.1 Foolproof CUDA install

Download CUDA 11.8 straight from the NVIDIA website: https://developer.nvidia.com/cuda-11-8-0-download-archive?target_os=Windows

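It is best to download the local installer and install offline; the network exe installer may need a VPN. After installing, open the PATH environment variable and check that CUDA 11.8 is listed. The CUDA version shown by nvidia-smi is the highest version the driver supports, not the toolkit you installed, so the only thing to check is that nvcc -V prints: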

  nvcc: NVIDIA (R) Cuda compiler driver
  Copyright (c) 2005-2022 NVIDIA Corporation
  Built on Wed_Sep_21_10:41:10_Pacific_Daylight_Time_2022
  Cuda compilation tools, release 11.8, V11.8.89
  Build cuda_11.8.r11.8/compiler.31833905_0
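If you would rather not eyeball that output, the release number can be pulled out with a short script. This is only a sketch: `parse_nvcc_release` and `installed_nvcc_release` are hypothetical helper names, and the regular expression assumes the "release X.Y" wording shown in the output above.

```python
import re
import shutil
import subprocess


def parse_nvcc_release(text):
    """Extract the toolkit release (e.g. '11.8') from `nvcc -V` output, or None."""
    match = re.search(r"release (\d+\.\d+)", text)
    return match.group(1) if match else None


def installed_nvcc_release():
    """Run `nvcc -V` if nvcc is on PATH and return its release string, else None."""
    if shutil.which("nvcc") is None:
        return None
    result = subprocess.run(["nvcc", "-V"], capture_output=True, text=True)
    return parse_nvcc_release(result.stdout)


if __name__ == "__main__":
    print("nvcc release:", installed_nvcc_release())
```

For this guide the printed release should be 11.8; `None` means nvcc was not found on PATH or its output did not match the expected wording.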

1.2 The fancier way: environment variables

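Create a virtual environment and install cudatoolkit==11.8 into it, so the environment carries its own extra CUDA version (quite convenient). The cuda-nvcc step at the end is mandatory; without it the environment may fail to find the matching CUDA version.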

  conda create -n your_env_name python=3.10.13
  conda activate your_env_name
  conda install cudatoolkit==11.8 -c nvidia
  pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118
  conda install -c "nvidia/label/cuda-11.8.0" cuda-nvcc
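After these commands it is worth confirming that the torch you got is really the cu118 build and that it can see the GPU. A small sketch, assuming torch is installed in the active environment; `cuda_tag` is a hypothetical helper that just reads the part after the "+" in the version string.

```python
def cuda_tag(torch_version):
    """Return the local build tag of a torch version string, e.g. 'cu118', or None."""
    _, sep, local = torch_version.partition("+")
    return local if sep else None


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        print("torch is not installed in this environment")
    else:
        print("torch version:", torch.__version__)
        print("wheel build tag:", cuda_tag(torch.__version__))
        print("CUDA version torch was built with:", torch.version.cuda)
        print("GPU visible to torch:", torch.cuda.is_available())
```

If the build tag is missing or is not cu118, the pip command above probably fell back to a different index, and the later source builds are likely to fail against it.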

2. Installing Causal-Conv1d

When installing Causal-Conv1d, make sure the version you check out matches your CUDA version.

First install packaging:

conda install packaging

2.1 Direct install (did not really work for me)

pip install causal-conv1d==1.1.1

2.2 Foolproof install

First install triton (it brings cmake along), which lays the groundwork for the build that follows:

triton 2.0.0 wheel (cp310, win_amd64): https://hf-mirror.com/r4ziel/xformers_pre_built/blob/main/triton-2.0.0-cp310-cp310-win_amd64.whl
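A prebuilt wheel only installs if its tags match your interpreter, which is why the Python 3.10 choice earlier matters. The sketch below splits the wheel filename into its tags and compares the CPython tag with the running interpreter; `parse_wheel_tags` and `matches_running_python` are hypothetical helpers, and the parser assumes the common five-part filename without an optional build tag. Once downloaded, the wheel can be installed by passing its local path to `pip install`.

```python
import sys


def parse_wheel_tags(filename):
    """Split a wheel filename into (name, version, python_tag, abi_tag, platform_tag)."""
    if not filename.endswith(".whl"):
        raise ValueError("not a wheel file: " + filename)
    parts = filename[: -len(".whl")].split("-")
    if len(parts) != 5:  # assumption: no optional build tag in the filename
        raise ValueError("unexpected wheel filename layout: " + filename)
    return tuple(parts)


def matches_running_python(filename):
    """True if the wheel's CPython tag matches the interpreter running this script."""
    python_tag = parse_wheel_tags(filename)[2]
    current = f"cp{sys.version_info.major}{sys.version_info.minor}"
    return python_tag == current


if __name__ == "__main__":
    wheel = "triton-2.0.0-cp310-cp310-win_amd64.whl"
    print(parse_wheel_tags(wheel))
    print("matches this Python:", matches_running_python(wheel))
```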

Next, download the source from GitHub and install it locally:

  git clone https://github.com/Dao-AILab/causal-conv1d.git
  cd causal-conv1d
  git checkout v1.1.1

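Why check out v1.1.1? Because it is the highest version that supports cu118.

Now for the actual causal-conv1d install.

Run it from Git Bash (if you do not have it, download Git first). Windows Terminal does not recognize this command line, because the VAR=value prefix in front of a command is POSIX shell syntax.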

CAUSAL_CONV1D_FORCE_BUILD=TRUE pip install .
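If Git Bash is not an option, the same environment variable can be set from Python, which avoids the shell syntax issue altogether. This is an untested alternative rather than the route used in this guide: `force_build_install` is a hypothetical helper, and it only passes `CAUSAL_CONV1D_FORCE_BUILD=TRUE` to pip; it does nothing about compiler or CUDA problems during the build itself.

```python
import os
import subprocess
import sys


def force_build_install(source_dir, run=subprocess.run):
    """Run `pip install .` in source_dir with CAUSAL_CONV1D_FORCE_BUILD=TRUE set."""
    env = dict(os.environ)  # copy, so the parent process environment is untouched
    env["CAUSAL_CONV1D_FORCE_BUILD"] = "TRUE"
    command = [sys.executable, "-m", "pip", "install", "."]
    return run(command, cwd=source_dir, env=env, check=True)


# Example (run from the directory that contains the cloned repository):
# force_build_install("causal-conv1d")
```

Using `sys.executable -m pip` keeps the install inside whichever environment is running the script, which matters when several conda environments are around.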

Following these steps, installing causal-conv1d should not be much of a problem ^.^

3. Installing mamba-ssm

No more hacks here: just the foolproof route, an offline install from source.

  git clone https://github.com/state-spaces/mamba.git
  cd mamba
  git checkout v1.1.1

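Download the mamba source folder first, then install from those files.

  • In setup.py of the mamba source, change the two flags so that they read:
  FORCE_BUILD = os.getenv("MAMBA_FORCE_BUILD", "FALSE") == "FALSE"
  SKIP_CUDA_BUILD = os.getenv("MAMBA_SKIP_CUDA_BUILD", "FALSE") == "FALSE"
  Because the comparison is now against "FALSE", both flags are True whenever the environment variables are unset, so the CUDA extension build is skipped.
  • Edit mamba_ssm/ops/selective_scan_interface.py

Replace the following code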

  def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                       return_last_state=False):
      """if return_last_state is True, returns (out, last_state)
      last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
      not considered in the backward pass.
      """
      return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)


  def mamba_inner_fn(
      xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
      out_proj_weight, out_proj_bias,
      A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
      C_proj_bias=None, delta_softplus=True
  ):
      return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                                out_proj_weight, out_proj_bias,
                                A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)

with

  def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                       return_last_state=False):
      """if return_last_state is True, returns (out, last_state)
      last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
      not considered in the backward pass.
      """
      return selective_scan_ref(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)


  def mamba_inner_fn(
      xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
      out_proj_weight, out_proj_bias,
      A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
      C_proj_bias=None, delta_softplus=True
  ):
      return mamba_inner_ref(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                             out_proj_weight, out_proj_bias,
                             A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
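The two edits go together: setup.py no longer compiles the CUDA kernel, and the interface file no longer calls it. To see why the setup.py edit has that effect, here is a small sketch of how the flag expression evaluates. `flag_enabled` is a hypothetical stand-in for the `os.getenv(...) == ...` pattern, and the "unmodified" case assumes the original file compares against "TRUE", which is worth confirming in your own checkout.

```python
def flag_enabled(env, name, compare_to):
    """Mimic the setup.py pattern: os.getenv(name, "FALSE") == compare_to."""
    return env.get(name, "FALSE") == compare_to


unset = {}  # stands in for os.environ with neither variable defined

# Unmodified pattern (assumed): compare against "TRUE".
print(flag_enabled(unset, "MAMBA_SKIP_CUDA_BUILD", "TRUE"))   # False
# Pattern after the setup.py edit: compare against "FALSE".
print(flag_enabled(unset, "MAMBA_SKIP_CUDA_BUILD", "FALSE"))  # True
# Side effect of the edit: explicitly setting the variable to "TRUE" now turns the flag off.
print(flag_enabled({"MAMBA_SKIP_CUDA_BUILD": "TRUE"}, "MAMBA_SKIP_CUDA_BUILD", "FALSE"))  # False
```

The last case is the catch: after the edit, do not set `MAMBA_SKIP_CUDA_BUILD=TRUE` by hand expecting it to skip the build, because the meaning is now inverted.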

Then install directly from inside the mamba folder:

pip install .

And with that, the Mamba technique is mastered.
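A quick way to see what actually ended up in the environment is to look for the modules without importing them. This is a sketch only: `module_available` is a hypothetical helper, and the compiled extension names `causal_conv1d_cuda` and `selective_scan_cuda` are assumptions based on what the two packages import internally.

```python
import importlib.util


def module_available(name):
    """True if a top-level module with this name can be found, without importing it."""
    return importlib.util.find_spec(name) is not None


if __name__ == "__main__":
    # Module names are assumptions based on what the two packages import internally.
    for name in ("mamba_ssm", "causal_conv1d", "causal_conv1d_cuda", "selective_scan_cuda"):
        print(f"{name}: {'found' if module_available(name) else 'missing'}")
```

With the CUDA build skipped as described above, `selective_scan_cuda` is expected to be missing. If importing `mamba_ssm` then fails on a module-level import of that name, that import may need to be removed as well; this is a guess on my part and not something the steps above cover.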


4. References

https://blog.csdn.net/yyywxk/article/details/136071016
