mamba - Mamba 是一种新的状态空间模型架构,在信息密集型数据(例如语言建模)上显示出良好的性能,而之前的二次模型在 Transformers 方面存在不足。它基于结构化状态空间模型的进展,并本着FlashAttention的精神进行高效的硬件感知设计和实现。

Created at: 2023-12-01 09:17:39
Language: Python
License: Apache-2.0

曼巴

曼巴

Mamba: Linear-Time Sequence Modeling with Selective State Spaces
Albert Gu*, Tri Dao*
论文: https://arxiv.org/abs/2312.00752

大约

Mamba 是一种新的状态空间模型架构,在信息密集型数据(如语言建模)上显示出良好的性能,而以前的次二次模型不及 Transformers。 它基于结构化状态空间模型的进展路线, 本着 FlashAttention 的精神,采用高效的硬件感知设计和实现。

安装

  • pip install causal-conv1d>=1.1.0,<1.2.0
    :在 Mamba 块内使用的简单因果 Conv1d 层的有效实现。
  • pip install mamba-ssm
    :核心 Mamba 包。

它也可以从源代码和此存储库构建。

pip install .

如果抱怨 PyTorch 版本,请尝试传递给 .

pip
--no-build-isolation
pip

其他要求:

  • Linux操作系统
  • 英伟达 GPU
  • PyTorch 1.12+
  • CUDA 11.6+

用法

我们公开了与 Mamba 模型的多个接口级别。

选择性 SSM

Mamba 基于选择性 SSM 层,这是本文的重点(第 3 节;算法 2)。

资料来源:ops/selective_scan_interface.py

曼巴街区

此存储库的主模块是包装选择性 SSM 的 Mamba 架构块。

资料来源:modules/mamba_simple.py

用法:

import torch
from mamba_ssm import Mamba

batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=16,  # SSM state expansion factor
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to("cuda")
y = model(x)
assert y.shape == x.shape

Mamba 语言模型

最后,我们提供了一个完整的语言模型示例:一个深度序列模型骨干(带有重复的 Mamba 块)+ 语言模型头。

资料来源:models/mixer_seq_simple.py

这是如何将 Mamba 集成到端到端神经网络中的示例。 此示例用于下面的生成脚本中。

预训练模型

预训练模型上传到 Hugging Face: , , , , , , , 在 Pile 上的 300B 代币上训练,以及(在 SlimPajama 数据集上训练的 600B 代币上)。

mamba-130m
mamba-370m
mamba-790m
mamba-1.4b
mamba-2.8b
mamba-2.8b-slimpj

模型将由下面的生成脚本自动下载。

这些模型是在 Pile 上训练的,并遵循 GPT-3 描述的标准模型尺寸,并遵循许多开源模型:

参数 型号昏暗。
130 米 24 768
370米 48 1024
790米 48 1536
1.4字节 48 2048
2.8乙 64 2560

(Mamba 的层数是具有类似大小的 Transformer 的两倍,因为 Transformer 的每个“层”(MHA 块 + MLP 块)需要两个 Mamba 块。

注意:这些是仅针对 300B 代币训练的基础模型,没有任何形式的下游修改(指令调整等)。 预计性能将与在类似数据上训练的其他架构相当或更好,但不能与更大或微调的模型相匹配。

评估

要对模型进行零样本评估(对应于论文的表3), 我们使用 lm-evaluation-harness 库。

  1. 按 拉取存储库。我们使用分支。
    lm-evaluation-harness
    git submodule update --init --recursive
    big-refactor
  2. 安装:。 在 Python 3.10 上,你可能需要手动安装最新版本的 : 。
    lm-evaluation-harness
    pip install -e 3rdparty/lm-evaluation-harness
    promptsource
    pip install git+https://github.com/bigscience-workshop/promptsource.git
  3. 使用以下方式运行评估(更多文档在 lm-evaluation-harness 存储库):
python evals/lm_harness_eval.py --model mamba --model_args pretrained=state-spaces/mamba-130m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64
python evals/lm_harness_eval.py --model hf --model_args pretrained=EleutherAI/pythia-160m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64

要重现博客文章中报告的模型的结果,请执行以下操作:

mamba-2.8b-slimpj

python evals/lm_harness_eval.py --model mamba --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks boolq,piqa,hellaswag,winogrande,arc_easy,arc_challenge,openbookqa,race,truthfulqa_mc2 --device cuda --batch_size 64
python evals/lm_harness_eval.py --model mamba --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks mmlu --num_fewshot 5 --device cuda --batch_size 64

请注意,由于评估过程中的噪声,每个任务的结果可能与报告值相差 0.1-0.3。

推理

脚本基准测试/benchmark_generation_mamba_simple.py

  1. 从 Hugging Face Hub 自动加载模型,
  2. 生成用户指定提示的完成,
  3. 对这一代的推理速度进行基准测试。

其他可配置选项包括 top-p(原子核采样)概率和 softmax 温度。

例子

要使用不同的采样策略测试生成延迟(例如批处理大小 = 1):

python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --minp 0.05 --topk 0 --temperature 0.7 --repetition-penalty 1.2

要使用随机提示(例如大批量)测试生成吞吐量:

python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --batch 128
python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --batch 128

故障 排除

精度

我们的模型使用 PyTorch AMP 进行训练,以实现混合精度。AMP 将模型参数保留在 float32 中,并在必要时强制转换为半精度。 另一方面,其他框架(如 DeepSpeed)将参数存储在 float16 中,并在必要时向上转换(例如,用于优化器累积)。

我们观察到,可能需要对主要模型参数进行更高的精度,因为 SSM 对其循环动力学很敏感。如果你遇到不稳定, 作为第一步,请尝试在 fp32 中存储参数的框架(例如 AMP)。

初始化

模型的某些部分具有从先前的 S4 模型工作继承的初始化。 例如,$\Delta$ 参数通过初始化其线性投影的偏差来获得目标范围。 但是,某些框架可能具有初始化后钩子(例如,将模块中的所有偏差项设置为零)。 如果是这种情况,你可能需要添加自定义逻辑(例如,此在我们的训练器中关闭重新初始化,但在任何其他框架中都是无操作的) 这是特定于培训框架的。

nn.Linear

引文

如果你使用此代码库,或者发现我们的工作有价值,请引用 Mamba:

@article{mamba,
  title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
  author={Gu, Albert and Dao, Tri},
  journal={arXiv preprint arXiv:2312.00752},
  year={2023}
}