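Mamba is a new architecture for modeling sequence data that drew wide attention around ICLR 2024. It builds on the S4 (Structured State Space) line of models and adds a selection mechanism, hence the name "selective state space model".
However, the environment setup described in the official Mamba repository, https://github.com/state-spaces/mamba, often fails in practice. Most of the errors come from installing two packages: causal-conv1d and mamba_ssm.
The steps below successfully set up the Mamba source environment on an RTX 3090: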
conda create -n mamba python=3.9
conda activate mamba
conda install cudatoolkit==11.7 -c nvidia
conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.7 -c pytorch -c nvidia
conda install -c "nvidia/label/cuda-11.7.0" cuda-nvcc
conda install packaging
If a package cannot be found during the installs above, the cause may be a network problem, or you may need to add a mirror channel.
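Before building causal-conv1d and mamba_ssm, it can help to confirm that the conda steps above produced a consistent toolchain. The following is a minimal sketch, not part of the original instructions: the helper name `check_environment` is hypothetical, and it only reports what it finds. On the setup described here one would expect PyTorch 1.13.1 built against CUDA 11.7, with `nvcc` on the PATH.

```python
import importlib.util
import shutil


def check_environment():
    """Collect basic facts about the CUDA/PyTorch toolchain (hypothetical helper)."""
    report = {"nvcc": shutil.which("nvcc")}  # path to nvcc, or None if not on PATH

    # Only import torch if it is actually installed in this environment.
    if importlib.util.find_spec("torch") is None:
        report["torch"] = None
        return report

    import torch

    report["torch"] = torch.__version__
    report["torch_cuda"] = torch.version.cuda  # CUDA version PyTorch was built with
    report["cuda_available"] = torch.cuda.is_available()
    if report["cuda_available"]:
        report["gpu"] = torch.cuda.get_device_name(0)
    return report


for key, value in check_environment().items():
    print(f"{key}: {value}")
```

If `nvcc` is reported as `None` or `cuda_available` is `False`, the later pip installs are likely to fail, so it is worth fixing that first.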
pip install causal-conv1d==1.0.0
If the install above fails, install from a prebuilt wheel offline instead:
wget https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.0.0/causal_conv1d-1.0.0+cu118torch1.13cxx11abiFALSE-cp39-cp39-linux_x86_64.whl
pip install causal_conv1d-1.0.0+cu118torch1.13cxx11abiFALSE-cp39-cp39-linux_x86_64.whl
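The wheel filename encodes which build you are downloading, so it is worth reading before you pick one from the releases page. The sketch below splits the name used above into its parts. The function `parse_wheel_name` is hypothetical, and the pattern for the part after `+` is an assumption inferred from this single filename rather than a documented format.

```python
import re

WHEEL = "causal_conv1d-1.0.0+cu118torch1.13cxx11abiFALSE-cp39-cp39-linux_x86_64.whl"


def parse_wheel_name(filename):
    """Split a wheel filename into package, version, and build tags (hypothetical helper)."""
    stem = filename[: -len(".whl")]
    # Standard wheel layout: name-version-pythontag-abitag-platformtag
    package, version, python_tag, abi_tag, platform_tag = stem.split("-")
    public, _, local = version.partition("+")
    info = {
        "package": package,
        "version": public,
        "python": python_tag,
        "abi": abi_tag,
        "platform": platform_tag,
    }
    # Assumed layout of the local version label: cu<CUDA>torch<X.Y>cxx11abi<TRUE|FALSE>
    match = re.fullmatch(r"cu(\d+)torch([\d.]+)cxx11abi(TRUE|FALSE)", local)
    if match:
        info.update(cuda=match.group(1), torch=match.group(2), cxx11abi=match.group(3))
    return info


for key, value in parse_wheel_name(WHEEL).items():
    print(f"{key}: {value}")
```

The `python` tag (`cp39`) and the `torch` tag (`1.13`) should match the conda environment created above. Note that this wheel is tagged `cu118` while the environment installs CUDA 11.7; that combination is the one reported as working here.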
pip install mamba_ssm==1.0.1
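To confirm that the pinned versions actually ended up in the environment, the installed package metadata can be queried from Python. This is a small sketch using only the standard library. The helper name `installed_versions` is hypothetical, and the distribution names `causal-conv1d` and `mamba-ssm` are assumed spellings; if a lookup returns `None`, try the underscore form.

```python
from importlib import metadata


def installed_versions(names=("torch", "causal-conv1d", "mamba-ssm")):
    """Return {distribution name: version or None if not installed} (hypothetical helper)."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


for name, version in installed_versions().items():
    print(f"{name}: {version}")
```

With the steps above, the expected result would be causal-conv1d at 1.0.0 and mamba-ssm at 1.0.1.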
After installation, run the following test script to verify the setup:
import torch
from mamba_ssm import Mamba
batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim,  # Model dimension d_model
    d_state=16,   # SSM state expansion factor
    d_conv=4,     # Local convolution width
    expand=2,     # Block expansion factor
).to("cuda")
y = model(x)
assert y.shape == x.shape
print(y)
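The comment in the test script says the module uses roughly 3 * expand * d_model^2 parameters. As a worked example of that rule of thumb, the sketch below evaluates it for the values used in the test. The function name `approx_mamba_params` is hypothetical, and the formula is only the estimate quoted in the comment, not an exact count.

```python
def approx_mamba_params(d_model, expand=2):
    """Rule-of-thumb parameter count quoted in the comment: 3 * expand * d_model^2."""
    return 3 * expand * d_model ** 2


# Values from the test script: d_model=16, expand=2
print(approx_mamba_params(16, 2))  # 1536
```

To compare against the real module, `sum(p.numel() for p in model.parameters())` can be added to the test script. For a very small `d_model` such as 16, the actual count may differ noticeably from the estimate, because terms that depend on `d_state` and `d_conv` are not negligible at that size.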