Source: model.py from the johnma2006/mamba-minimal repository on GitHub (https://github.com/johnma2006/mamba-minimal).
from dataclasses import dataclass
from einops import rearrange, repeat, einsum
import torch
import torch.nn as nn
import torch.nn.functional as F  # used by MambaBlock.forward later in this post
import json                      # used by Mamba.from_pretrained later in this post
import math
from typing import Union

@dataclass
class ModelArgs:
    d_model: int                          # hidden (model) dimension
    n_layer: int                          # number of layers
    vocab_size: int                       # vocabulary size
    d_state: int = 16                     # SSM state dimension, default 16
    expand: int = 2                       # expansion factor, default 2
    dt_rank: Union[int, str] = 'auto'     # rank of the Δ (step size) projection, default 'auto'
    d_conv: int = 4                       # local convolution width (kernel size), default 4
    pad_vocab_size_multiple: int = 8      # pad the vocabulary size up to a multiple of this value
    conv_bias: bool = True                # whether the convolution uses a bias, default True
    bias: bool = False                    # whether the linear layers use a bias, default False

    def __post_init__(self):
        self.d_inner = int(self.expand * self.d_model)  # inner dimension = expand * d_model

        if self.dt_rank == 'auto':
            self.dt_rank = math.ceil(self.d_model / 16)  # resolve 'auto' to an actual rank

        if self.vocab_size % self.pad_vocab_size_multiple != 0:
            # round vocab_size up to the next multiple of pad_vocab_size_multiple
            self.vocab_size += (self.pad_vocab_size_multiple
                                - self.vocab_size % self.pad_vocab_size_multiple)

class Mamba(nn.Module):
    def __init__(self, args: ModelArgs):
        """The full Mamba model."""
        super().__init__()
        self.args = args

        self.embedding = nn.Embedding(args.vocab_size, args.d_model)                     # embedding layer
        self.layers = nn.ModuleList([ResidualBlock(args) for _ in range(args.n_layer)])  # stack of residual blocks
        self.norm_f = RMSNorm(args.d_model)                                              # final normalization layer

        self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False)  # projects hidden states to vocabulary logits
        self.lm_head.weight = self.embedding.weight  # tie the output projection to the embedding weights (weight tying; see the "Weight Tying" paper)

    def forward(self, input_ids):
        """
        Args:
            input_ids (long tensor): shape (b, l), where b is the batch size and l the sequence length

        Returns:
            logits: shape (b, l, vocab_size)

        Official Implementation:
            class MambaLMHeadModel, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py#L173
        """
        x = self.embedding(input_ids)   # map input ids to embedding vectors

        for layer in self.layers:
            x = layer(x)                # pass through each residual block in turn

        x = self.norm_f(x)              # final normalization
        logits = self.lm_head(x)        # compute logits

        return logits
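Before the walkthrough below, here is a minimal usage sketch. The import path (a local model.py containing the full repository code, including ResidualBlock, MambaBlock, and RMSNorm) and the small dimensions are assumptions made purely for illustration; the printed values follow from the __post_init__ logic above.

# Minimal usage sketch (assumes the complete model.py from johnma2006/mamba-minimal is importable).
import torch
from model import ModelArgs, Mamba

# vocab_size=1003 is not a multiple of 8, so __post_init__ pads it to 1008;
# dt_rank='auto' resolves to ceil(64 / 16) = 4, and d_inner = 2 * 64 = 128.
args = ModelArgs(d_model=64, n_layer=2, vocab_size=1003)
print(args.d_inner, args.dt_rank, args.vocab_size)        # 128 4 1008

model = Mamba(args)
input_ids = torch.randint(0, args.vocab_size, (2, 16))    # (batch=2, length=16)
logits = model(input_ids)
print(logits.shape)                                       # torch.Size([2, 16, 1008])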
The ModelArgs dataclass: d_model, n_layer, and vocab_size define the model's basic configuration, while d_state, expand, dt_rank, d_conv, pad_vocab_size_multiple, conv_bias, and bias are defaulted parameters that refine it. The __post_init__ method runs after initialization: it computes the inner dimension (d_inner), resolves dt_rank when it is 'auto', and rounds the vocabulary size (vocab_size) up to the specified multiple.

The Mamba class: the __init__ method builds the full model, creating an embedding layer (embedding), a stack of residual blocks (layers), a final normalization layer (norm_f), and the output head (lm_head), whose weight is tied to the embedding. The forward method is the model's forward pass: input_ids are mapped to embedding vectors, passed through each residual block and the final norm, and projected to vocabulary logits.

The next snippet is a simplified, illustrative residual block and "selective SSM" rather than the actual mamba-minimal code; the real ResidualBlock and MambaBlock from the repository are shown and explained later in this post.

class ResidualBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        """Residual block combining a (simplified) selective state space model and a convolution."""
        super().__init__()
        self.args = args
        self.norm1 = RMSNorm(args.d_model)   # first normalization layer
        self.norm2 = RMSNorm(args.d_model)   # second normalization layer
        self.ssm = SelectiveSSM(args)        # simplified selective state space model
        # 1D convolution; padding='same' keeps the sequence length so the residual can be added back
        self.conv = nn.Conv1d(args.d_model, args.d_model, kernel_size=args.d_conv,
                              bias=args.conv_bias, padding='same')

    def forward(self, x):
        """Forward pass of the residual block."""
        residual = x                          # keep the input for the residual connection
        x = self.norm1(x)                     # first normalization
        x = self.ssm(x)                       # simplified selective SSM
        x = self.norm2(x)                     # second normalization
        x = rearrange(x, 'b l d -> b d l')    # move channels first for the convolution
        x = self.conv(x)                      # 1D convolution over the sequence
        x = rearrange(x, 'b d l -> b l d')    # back to (batch, length, dim)
        return x + residual                   # add the residual back

class SelectiveSSM(nn.Module):
    def __init__(self, args: ModelArgs):
        """Simplified selective state space model (SSM)."""
        super().__init__()
        self.args = args
        # learned hidden-state offset; shaped (d_state,) so it broadcasts over (b, l, d_state)
        self.hidden = nn.Parameter(torch.randn(args.d_state))
        self.linear = nn.Linear(args.d_model, args.d_state, bias=args.bias)   # project the input to the state dimension
        self.output = nn.Linear(args.d_state, args.d_model, bias=args.bias)   # project the state back to the model dimension

    def forward(self, x):
        """Forward pass of the simplified selective SSM."""
        state = torch.tanh(self.linear(x) + self.hidden)   # compute the new state
        output = self.output(state)                        # map the state back to the output dimension
        return output                                      # shape (b, l, d_model)
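A quick shape check of this simplified sketch (illustrative dimensions only; it assumes the ModelArgs above and some RMSNorm implementation, such as the one sketched later in this post — this is not the repository's own block):

# Shape check for the simplified ResidualBlock/SelectiveSSM sketch above.
import torch

args = ModelArgs(d_model=64, n_layer=2, vocab_size=1000)
block = ResidualBlock(args)            # the simplified block defined above
x = torch.randn(2, 16, args.d_model)   # (batch=2, length=16, d_model=64)
y = block(x)
print(y.shape)                         # torch.Size([2, 16, 64]); padding='same' keeps the length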
The (simplified) ResidualBlock class: __init__ sets up two normalization layers (norm1 and norm2), a selective state space model (ssm), and a 1D convolutional layer (conv). forward is the block's forward pass: the input is saved as the residual (residual), normalized, passed through the selective SSM, normalized again, reshaped with rearrange to fit the convolution, convolved, reshaped back, and finally the residual is added to the output.

The (simplified) SelectiveSSM class: __init__ creates a learned hidden-state parameter (hidden), a linear layer that maps the input to the state dimension (linear), and a linear layer that maps the state back to the output dimension (output). forward computes the new state from the input and the hidden parameter through a tanh activation, then maps the state back to the output dimension.

From the code above we can see how the Mamba model combines a selective state space model with convolution to model sequences efficiently. The residual connections and normalization layers help stabilize training and improve performance. These design choices reflect the improvements described in the paper, which aim to address the computational inefficiency of Transformers while keeping strong performance across multiple modalities.
@staticmethod
def from_pretrained(pretrained_model_name: str):
    """Load pretrained weights from HuggingFace into model.

    Args:
        pretrained_model_name: One of
            * 'state-spaces/mamba-2.8b-slimpj'
            * 'state-spaces/mamba-2.8b'
            * 'state-spaces/mamba-1.4b'
            * 'state-spaces/mamba-790m'
            * 'state-spaces/mamba-370m'
            * 'state-spaces/mamba-130m'

    Returns:
        model: Mamba model with weights loaded

    """
    from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
    from transformers.utils.hub import cached_file

    def load_config_hf(model_name):
        resolved_archive_file = cached_file(model_name, CONFIG_NAME,
                                            _raise_exceptions_for_missing_entries=False)
        return json.load(open(resolved_archive_file))

    def load_state_dict_hf(model_name, device=None, dtype=None):
        resolved_archive_file = cached_file(model_name, WEIGHTS_NAME,
                                            _raise_exceptions_for_missing_entries=False)
        return torch.load(resolved_archive_file, weights_only=True, map_location='cpu', mmap=True)

    config_data = load_config_hf(pretrained_model_name)
    args = ModelArgs(
        d_model=config_data['d_model'],
        n_layer=config_data['n_layer'],
        vocab_size=config_data['vocab_size']
    )
    model = Mamba(args)

    state_dict = load_state_dict_hf(pretrained_model_name)
    new_state_dict = {}
    for key in state_dict:
        new_key = key.replace('backbone.', '')
        new_state_dict[new_key] = state_dict[key]
    model.load_state_dict(new_state_dict)

    return model
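As a usage sketch (hypothetical: it needs network access, the transformers package, and the complete model.py from the repository), loading a small checkpoint and computing next-token logits might look like this:

# Usage sketch: load a pretrained checkpoint and pick a greedy next token.
import torch

model = Mamba.from_pretrained('state-spaces/mamba-130m')
model.eval()

input_ids = torch.tensor([[1, 2, 3, 4]])        # toy token ids; a real run would use the matching tokenizer
with torch.no_grad():
    logits = model(input_ids)                   # shape (1, 4, vocab_size)
next_token = logits[0, -1].argmax().item()      # greedy choice of the next token id
print(logits.shape, next_token)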
The from_pretrained static method loads a pretrained configuration and weights from the HuggingFace Hub and applies them to a Mamba model. Step by step:

- The method takes a single string argument, pretrained_model_name (one of the 'state-spaces/mamba-*' checkpoints listed in the docstring), and returns a Mamba model with the weights loaded.
- It imports the WEIGHTS_NAME and CONFIG_NAME constants from transformers.utils and the cached_file function from transformers.utils.hub.
- The helper load_config_hf resolves the configuration file with cached_file (without raising on missing entries) and parses it with json.load.
- The helper load_state_dict_hf resolves the weights file with cached_file and loads it with torch.load, loading only the weights, mapping them to CPU, and memory-mapping the file.
- The configuration is loaded with load_config_hf, and a ModelArgs object is built from its d_model, n_layer, and vocab_size entries.
- A Mamba instance is created from those args.
- The state dict is loaded with load_state_dict_hf; each key's 'backbone.' prefix is stripped, the renamed entries are collected into new_state_dict, and the result is loaded into the model with load_state_dict.
- The model is returned.

In short: cached_file fetches the configuration and weight files, the configuration populates a ModelArgs object, a Mamba instance is built from it, and the renamed weights are loaded into it before the model is returned.

The next snippet is the actual ResidualBlock from mamba-minimal:

class ResidualBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        """Simple block wrapping Mamba block with normalization and residual connection."""
        super().__init__()
        self.args = args
        self.mixer = MambaBlock(args)
        self.norm = RMSNorm(args.d_model)

    def forward(self, x):
        """
        Args:
            x: shape (b, l, d) (See Glossary at top for definitions of b, l, d_in, n...)

        Returns:
            output: shape (b, l, d)

        Official Implementation:
            Block.forward(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L297

            Note: the official repo chains residual blocks that look like
                [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> ...
            where the first Add is a no-op. This is purely for performance reasons as this
            allows them to fuse the Add->Norm.

            We instead implement our blocks as the more familiar, simpler, and numerically equivalent
                [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> ....

        """
        output = self.mixer(self.norm(x)) + x

        return output
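Both residual blocks in this post rely on RMSNorm, which the excerpt never defines. For reference, here is a minimal RMSNorm sketch in the spirit of the one in mamba-minimal (the repository's exact code may differ in details):

# Minimal RMSNorm sketch (reference implementation; not guaranteed to match the repo line for line).
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))   # learned per-feature scale

    def forward(self, x):
        # normalize by the root mean square over the last (feature) dimension, then rescale
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight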
Here is a line-by-line explanation of the ResidualBlock class from mamba-minimal, part of the Mamba model architecture:

- class ResidualBlock(nn.Module): defines the class ResidualBlock, inheriting from nn.Module, the base class of all PyTorch neural network modules; any custom module should extend it.
- def __init__(self, args: ModelArgs): the constructor accepts args, expected to be a ModelArgs instance holding the model's configuration parameters.
- super().__init__() calls the parent (nn.Module) constructor, which is required to initialize the PyTorch module correctly, and self.args = args stores the configuration on the instance for later use.
- self.mixer = MambaBlock(args) creates a MambaBlock, the key component that actually processes the input data.
- self.norm = RMSNorm(args.d_model) creates an RMSNorm layer; args.d_model is the embedding dimension the normalization operates over.
- def forward(self, x) defines the block's forward pass; x is the input tensor of shape (b, l, d), where b is the batch size, l the sequence length, and d the feature dimension.
- output = self.mixer(self.norm(x)) + x normalizes the input, feeds the normalized data through the MambaBlock (self.mixer), and adds the result back to the original input x, forming a residual connection. Residual connections are a standard deep-learning technique that helps gradients flow through deep networks and mitigates vanishing gradients.
- return output returns a tensor of the same shape (b, l, d) as the input.

The docstring also notes that the official repository chains blocks as [Add -> Norm -> Mamba] (with the first Add a no-op) purely so the Add -> Norm can be fused for performance, while this implementation uses the simpler and numerically equivalent [Norm -> Mamba -> Add] ordering, illustrating the flexibility implementations have in balancing framework-level performance optimizations against simplicity.

Next is MambaBlock itself:

class MambaBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        """A single Mamba block, as described in Figure 3 in Section 3.4 in the Mamba paper [1]."""
        super().__init__()
        self.args = args

        self.in_proj = nn.Linear(args.d_model, args.d_inner * 2, bias=args.bias)

        self.conv1d = nn.Conv1d(
            in_channels=args.d_inner,
            out_channels=args.d_inner,
            bias=args.conv_bias,
            kernel_size=args.d_conv,
            groups=args.d_inner,
            padding=args.d_conv - 1,
        )

        # x_proj takes in `x` and outputs the input-specific Δ, B, C
        self.x_proj = nn.Linear(args.d_inner, args.dt_rank + args.d_state * 2, bias=False)

        # dt_proj projects Δ from dt_rank to d_in
        self.dt_proj = nn.Linear(args.dt_rank, args.d_inner, bias=True)

        A = repeat(torch.arange(1, args.d_state + 1), 'n -> d n', d=args.d_inner)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(args.d_inner))
        self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=args.bias)

    def forward(self, x):
        """Mamba block forward. This looks the same as Figure 3 in Section 3.4 in the Mamba paper [1].

        Args:
            x: shape (b, l, d) (See Glossary at top for definitions of b, l, d_in, n...)

        Returns:
            output: shape (b, l, d)

        Official Implementation:
            class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311

        """
        (b, l, d) = x.shape

        x_and_res = self.in_proj(x)  # shape (b, l, 2 * d_in)
        (x, res) = x_and_res.split(split_size=[self.args.d_inner, self.args.d_inner], dim=-1)

        x = rearrange(x, 'b l d_in -> b d_in l')
        x = self.conv1d(x)[:, :, :l]
        x = rearrange(x, 'b d_in l -> b l d_in')

        x = F.silu(x)

        y = self.ssm(x)

        y = y * F.silu(res)

        output = self.out_proj(y)

        return output
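Before the walkthrough, one detail is worth checking concretely: the depthwise convolution (groups=d_inner) with padding=d_conv - 1, followed by the [:, :, :l] slice, keeps the sequence length and makes the convolution causal, so each output position only sees current and past inputs. A small self-contained check (values made up for illustration):

# Causality check for the depthwise convolution pattern used in MambaBlock.
import torch
import torch.nn as nn

d_inner, d_conv, l = 3, 4, 8
conv = nn.Conv1d(d_inner, d_inner, kernel_size=d_conv,
                 groups=d_inner, padding=d_conv - 1, bias=False)

x = torch.randn(1, d_inner, l)
y = conv(x)[:, :, :l]                               # sliced back to the original length
print(y.shape)                                      # torch.Size([1, 3, 8])

# Perturbing a future position leaves earlier outputs unchanged.
x2 = x.clone()
x2[:, :, 5] += 10.0                                 # change position 5 only
y2 = conv(x2)[:, :, :l]
print(torch.allclose(y[:, :, :5], y2[:, :, :5]))    # True: outputs 0..4 do not depend on position 5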
Continuing the line-by-line walkthrough, the MambaBlock class is the core building block of the Mamba architecture.

The constructor:

- class MambaBlock(nn.Module): defines the MambaBlock class, inheriting from nn.Module, the base class of all PyTorch neural network modules.
- def __init__(self, args: ModelArgs): the constructor takes args, a ModelArgs instance holding the configuration; super().__init__() initializes the module, and self.args = args keeps the configuration on the instance.
- self.in_proj is a linear layer (nn.Linear) that projects the input from the model dimension d_model to twice the inner dimension, d_inner * 2; args.bias decides whether it has a bias.
- self.conv1d is a 1D convolution with d_inner input and output channels. Because groups=d_inner it is a depthwise convolution, so each channel is processed independently; the kernel size is d_conv, the padding is d_conv - 1 so the sequence length can be preserved, and args.conv_bias decides whether it has a bias.
- self.x_proj is a bias-free linear layer mapping d_inner to dt_rank + d_state * 2; it produces the input-dependent Δ, B, and C.
- self.dt_proj is a linear layer with bias that projects Δ from dt_rank back up to d_inner.
- A is initialized as the sequence 1..d_state repeated across d_inner rows; its logarithm is stored as the trainable parameter A_log. D is initialized as a vector of ones and is also trainable.
- self.out_proj is a linear layer that maps d_inner back to the output dimension d_model; args.bias decides whether it has a bias.

The forward pass:

- (b, l, d) = x.shape unpacks the batch size b, sequence length l, and feature dimension d.
- x_and_res = self.in_proj(x) projects the input to shape (b, l, 2 * d_inner), so the result can be split into two streams used for different purposes.
- (x, res) = x_and_res.split(...) splits the projection along the last dimension into x and res, each of width d_inner.
- rearrange(x, 'b l d_in -> b d_in l') moves the channel dimension forward, the layout the convolution expects (batch, channels, length).
- x = self.conv1d(x)[:, :, :l] applies the 1D convolution and slices the output back to the original sequence length.
- rearrange(x, 'b d_in l -> b l d_in') restores the (batch, length, channels) layout.
- x = F.silu(x) applies the SiLU activation.
- y = self.ssm(x) feeds x through the state space model. The ssm method is not defined in this excerpt; it is presumably defined elsewhere in model.py (a sketch of the idea follows this walkthrough).
- y = y * F.silu(res) multiplies the SSM output with the SiLU-activated res stream, gating the result and merging the two information flows to strengthen the model's expressiveness.
- output = self.out_proj(y) maps y from d_inner back to the d_model dimension.
- return output returns the processed tensor of shape (batch_size, sequence_length, d_model), the same shape as the input x.

Through this detailed walkthrough we can see how MambaBlock transforms its input through a pipeline of linear projections, a depthwise convolution, SiLU activations, and a state space model. This kind of modular processing is common in deep learning and helps the model capture complex patterns in sequence data.
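Since ssm() is not shown in this excerpt, here is a rough sketch of what it does in mamba-minimal, written as a free function against the parameters defined in __init__ above. It illustrates the selective-scan idea (compute input-dependent Δ, B, C, discretize, then scan sequentially over time); it is an assumption-laden illustration, not a copy of the repository's code.

# Sketch of the selective SSM computation that self.ssm(x) performs (illustrative, not verbatim).
import torch
import torch.nn.functional as F
from einops import einsum

def ssm_sketch(block, x):
    """block: a MambaBlock as defined above; x: (b, l, d_inner) -> (b, l, d_inner)."""
    d_in, n = block.A_log.shape
    A = -torch.exp(block.A_log.float())            # (d_inner, d_state); kept negative for stability
    D = block.D.float()

    x_dbl = block.x_proj(x)                        # (b, l, dt_rank + 2 * d_state)
    delta, B, C = x_dbl.split([block.args.dt_rank, n, n], dim=-1)
    delta = F.softplus(block.dt_proj(delta))       # (b, l, d_inner), positive step sizes

    # Discretize the continuous parameters, then scan over the sequence one step at a time.
    deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))
    deltaB_u = einsum(delta, B, x, 'b l d_in, b l n, b l d_in -> b l d_in n')

    b, l, _ = x.shape
    h = torch.zeros(b, d_in, n, device=x.device)
    ys = []
    for i in range(l):                             # recurrent state update, one step per position
        h = deltaA[:, i] * h + deltaB_u[:, i]
        ys.append(einsum(h, C[:, i], 'b d_in n, b n -> b d_in'))
    y = torch.stack(ys, dim=1)                     # (b, l, d_inner)

    return y + x * D                               # skip connection scaled by D

In the forward pass above, y = self.ssm(x) plays the role that y = ssm_sketch(self, x) would play here.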