Mamba: Linear-Time Sequence Modeling with Selective State Spaces
Albert Gu*, Tri Dao*
论文: https://arxiv.org/abs/2312.00752
Mamba 是一种新的状态空间模型架构,在信息密集型数据(如语言建模)上显示出良好的性能,而以前的次二次模型不及 Transformers。 它基于结构化状态空间模型的进展路线, 本着 FlashAttention 的精神,采用高效的硬件感知设计和实现。
pip install causal-conv1d>=1.1.0,<1.2.0:在 Mamba 块内使用的简单因果 Conv1d 层的有效实现。
pip install mamba-ssm:核心 Mamba 包。
它也可以从源代码和此存储库构建。
pip install .
如果抱怨 PyTorch 版本,请尝试传递给 .
pip
--no-build-isolation
pip
其他要求:
我们公开了与 Mamba 模型的多个接口级别。
Mamba 基于选择性 SSM 层,这是本文的重点(第 3 节;算法 2)。
资料来源:ops/selective_scan_interface.py。
此存储库的主模块是包装选择性 SSM 的 Mamba 架构块。
资料来源:modules/mamba_simple.py。
用法:
import torch from mamba_ssm import Mamba batch, length, dim = 2, 64, 16 x = torch.randn(batch, length, dim).to("cuda") model = Mamba( # This module uses roughly 3 * expand * d_model^2 parameters d_model=dim, # Model dimension d_model d_state=16, # SSM state expansion factor d_conv=4, # Local convolution width expand=2, # Block expansion factor ).to("cuda") y = model(x) assert y.shape == x.shape
最后,我们提供了一个完整的语言模型示例:一个深度序列模型骨干(带有重复的 Mamba 块)+ 语言模型头。
资料来源:models/mixer_seq_simple.py。
这是如何将 Mamba 集成到端到端神经网络中的示例。 此示例用于下面的生成脚本中。
预训练模型上传到 Hugging Face: , , , , , , , 在 Pile 上的 300B 代币上训练,以及(在 SlimPajama 数据集上训练的 600B 代币上)。
mamba-130m
mamba-370m
mamba-790m
mamba-1.4b
mamba-2.8b
mamba-2.8b-slimpj
模型将由下面的生成脚本自动下载。
这些模型是在 Pile 上训练的,并遵循 GPT-3 描述的标准模型尺寸,并遵循许多开源模型:
参数 | 层 | 型号昏暗。 |
---|---|---|
130 米 | 24 | 768 |
370米 | 48 | 1024 |
790米 | 48 | 1536 |
1.4字节 | 48 | 2048 |
2.8乙 | 64 | 2560 |
(Mamba 的层数是具有类似大小的 Transformer 的两倍,因为 Transformer 的每个“层”(MHA 块 + MLP 块)需要两个 Mamba 块。
注意:这些是仅针对 300B 代币训练的基础模型,没有任何形式的下游修改(指令调整等)。 预计性能将与在类似数据上训练的其他架构相当或更好,但不能与更大或微调的模型相匹配。
要对模型进行零样本评估(对应于论文的表3), 我们使用 lm-evaluation-harness 库。
lm-evaluation-harness
git submodule update --init --recursive
big-refactor
lm-evaluation-harness
pip install -e 3rdparty/lm-evaluation-harness
promptsource
pip install git+https://github.com/bigscience-workshop/promptsource.git
python evals/lm_harness_eval.py --model mamba --model_args pretrained=state-spaces/mamba-130m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64 python evals/lm_harness_eval.py --model hf --model_args pretrained=EleutherAI/pythia-160m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64
要重现博客文章中报告的模型的结果,请执行以下操作:
mamba-2.8b-slimpj
python evals/lm_harness_eval.py --model mamba --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks boolq,piqa,hellaswag,winogrande,arc_easy,arc_challenge,openbookqa,race,truthfulqa_mc2 --device cuda --batch_size 64 python evals/lm_harness_eval.py --model mamba --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks mmlu --num_fewshot 5 --device cuda --batch_size 64
请注意,由于评估过程中的噪声,每个任务的结果可能与报告值相差 0.1-0.3。
脚本基准测试/benchmark_generation_mamba_simple.py
其他可配置选项包括 top-p(原子核采样)概率和 softmax 温度。
要使用不同的采样策略测试生成延迟(例如批处理大小 = 1):
python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2 python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2 python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --minp 0.05 --topk 0 --temperature 0.7 --repetition-penalty 1.2
要使用随机提示(例如大批量)测试生成吞吐量:
python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --batch 128 python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --batch 128
我们的模型使用 PyTorch AMP 进行训练,以实现混合精度。AMP 将模型参数保留在 float32 中,并在必要时强制转换为半精度。 另一方面,其他框架(如 DeepSpeed)将参数存储在 float16 中,并在必要时向上转换(例如,用于优化器累积)。
我们观察到,可能需要对主要模型参数进行更高的精度,因为 SSM 对其循环动力学很敏感。如果你遇到不稳定, 作为第一步,请尝试在 fp32 中存储参数的框架(例如 AMP)。
模型的某些部分具有从先前的 S4 模型工作继承的初始化。 例如,$\Delta$ 参数通过初始化其线性投影的偏差来获得目标范围。 但是,某些框架可能具有初始化后钩子(例如,将模块中的所有偏差项设置为零)。 如果是这种情况,你可能需要添加自定义逻辑(例如,此行在我们的训练器中关闭重新初始化,但在任何其他框架中都是无操作的) 这是特定于培训框架的。
nn.Linear
如果你使用此代码库,或者发现我们的工作有价值,请引用 Mamba:
@article{mamba, title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces}, author={Gu, Albert and Dao, Tri}, journal={arXiv preprint arXiv:2312.00752}, year={2023} }