当前位置:   article > 正文

windows系统下anaconda中配置Mamba官方代码环境_windows配置mamba-ssm

windows配置mamba-ssm

#项目场景#

最近Mamba有关的论文引起了众多人的关注,笔者知道很多人开始复现官方代码,但是由于官方代码虚拟环境创建在ubuntu系统下,因此在windows系统下复现代码遇到各种各样的问题。笔者在经过一晚上的尝试,总算在win11系统anaconda中成功配置了环境。

问题描述 

  1. Building wheel for causal-conv1d (setup.py) ... error
  2. ERROR: Could not build wheels for causal-conv1d, which is required to install pyproject.toml-based projects
  3. ERROR: Could not build wheels for mamba-ssm, which is required to install pyproject.toml-based projects

主要原因是CUDA版本不兼容。

解决方案 

本次复现主要参考这篇博客方法三,并根据自己电脑情况修改了一些步骤。

  1. conda create -n your_env_name python=3.10.13
  2. conda activate your_env_name
  3. conda install cudatoolkit==11.8 -c nvidia
  4. pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118
  5. conda install -c "nvidia/label/cuda-11.8.0" cuda-nvcc
  6. conda install packaging

接下来安装 'triton’包:tritan包安装教程,有大神编译了Windows下二进制文件,下载到本地后,在anacoda终端中,切换到tritan所在文件夹,输入

pip install triton-2.0.0-cp310-cp310-win_amd64.whl

然后使用源码编译安装causal-conv1d,注意在此之前,请检查torch的cuda版本和你自己的cuda版本是否一致。并输入

where nvcc

检查CUDA版本(最好在此电脑-属性-系统-高级系统设置-环境变量)中检查新安装的CUDA11.8是否添加(在这里卡了很久,也不知道是什么原因,where nvcc指令只能找到自己电脑的旧版本CUDA,笔者干脆手动安装了CUDA11.8和CUDNN安装教程),安装完成后,重启电脑,重新进入anaconda环境。接下来进行causal-conv1d安装,笔者采用源码编译的方式。首先请在causal-conv1d安装链接下载好对应版本的安装包(笔者下载的是1.0.0版本)。下载到本地后,解压,anaconda激活环境后进入该文件夹。输入

pip install .

在这里可能会出现

有时候缓存文件可能会导致安装出错。你可以尝试清理 pip 或 conda 的缓存

pip cache purge

然后再输入pip install .就可以啦。

接下来是mamba源码编译,请在mamba官方代码中setup.py文件修改配置

  1. FORCE_BUILD = os.getenv("MAMBA_FORCE_BUILD", "FALSE") == "FALSE"
  2. SKIP_CUDA_BUILD = os.getenv("MAMBA_SKIP_CUDA_BUILD", "FALSE") == "FALSE"
  • 此时,可以编译完成,但是无法将 selective_scan_cuda 包括进去,导入模块还是会出错。请在mamba_ssm/ops/selective_scan_interface.py该文件中注释掉:
    import selective_scan_cuda
    

  1. def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
  2. return_last_state=False):
  3. """if return_last_state is True, returns (out, last_state)
  4. last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
  5. not considered in the backward pass.
  6. """
  7. return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
  8. def mamba_inner_fn(
  9. xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
  10. out_proj_weight, out_proj_bias,
  11. A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
  12. C_proj_bias=None, delta_softplus=True
  13. ):
  14. return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
  15. out_proj_weight, out_proj_bias,
  16. A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)

 改为

  1. def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
  2. return_last_state=False):
  3. """if return_last_state is True, returns (out, last_state)
  4. last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
  5. not considered in the backward pass.
  6. """
  7. return selective_scan_ref(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
  8. def mamba_inner_fn(
  9. xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
  10. out_proj_weight, out_proj_bias,
  11. A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
  12. C_proj_bias=None, delta_softplus=True
  13. ):
  14. return mamba_inner_ref(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
  15. out_proj_weight, out_proj_bias,
  16. A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)

最后就可以正常做mamba框架的相关实验啦~

最后放一下Mamba模型的简要介绍:

Mamba模型:选择性状态空间的新架构

Mamba是一种新颖的序列建模架构,被誉为Transformer的潜在竞争对手。相比于传统的模型,Mamba引入了选择性状态空间的概念,以更高效和有效地捕获相关信息。以下是Mamba的关键特点:

  1. 线性时间复杂度:与Transformer不同,Mamba在序列长度方面以线性时间运行,适用于处理非常长的序列任务。
  2. 灵活性和效率:Mamba结合了传统状态空间模型和循环神经网络的优点,具有高效计算和灵活性。
  3. 硬件感知算法:Mamba使用一种硬件感知算法,通过扫描操作而不是卷积,在GPU上高效地执行计算。

Mamba的核心组成包括固定主干和输入相关转换。在训练期间,它类似于Transformer,同时处理整个序列;在推理中,它更符合传统的循环模型,提供有效的序列处理。

此外,Mamba还使用了SRAM(Static Random-Access Memory)来优化内存需求,使其成为处理长序列的有前途的模型。

笔者研一菜鸡一枚~在安装的过程中conda下载完cuda11.8后又手动下载一遍(真的没搞懂),希望评论区有大神知道其中原理的欢迎交流~

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/很楠不爱3/article/detail/565386
推荐阅读
相关标签
  

闽ICP备14008679号