Llama 3 与 Mamba 架构融合,推理速度提升 1.6 倍且性能更优
引言
在大型语言模型(LLM)的发展进程中,Transformer 架构长期占据主导地位。然而,随着模型规模的扩大,计算复杂度和推理延迟成为制约其部署的关键瓶颈。近期,Together AI 推出了一项突破性成果,通过将 Llama 3 蒸馏到 Mamba 架构中,实现了推理速度的显著提升,同时保持了甚至超越了原始模型的性能表现。
这一项目由提出 Mamba 架构的 Tri Dao 以及 FlashAttention 的作者共同参与,标志着 Transformer 与状态空间模型(SSM)混合架构的重要进展。Together AI 创始人兼 CEO 指出,Transformer 和 Mamba 的混合是未来大模型的一大发展方向。本文将深入探讨这一混合模型的构建过程、蒸馏策略及推理加速算法。
将 Transformer 蒸馏进 Mamba 的技术路径
初始化阶段:从 Transformer 到线性 RNN
在蒸馏正式开始之前,核心挑战在于如何将预训练的 Transformer 参数有效地迁移到 Mamba 模型中。作者观察到,Transformer 的注意力机制与 RNN 的计算流程之间存在潜在的数学相似性。
具体而言,Transformer 的自注意力机制可以被视为一种全局依赖建模,而 RNN 则通过隐状态传递历史信息。为了建立二者的联系,研究者尝试将 Transformer 的注意力机制进行线性化处理。这种线性化操作使得注意力矩阵能够被近似为状态空间模型中的转移矩阵,从而允许将预训练 Transformer 的参数直接复制到 Mamba 模型的对应层中。
这种参数初始化策略极大地降低了训练成本,使得 Mamba 学生模型在初始阶段就具备了接近教师模型的知识分布,而非从零开始学习。
三阶段蒸馏流程
完成参数初始化后,作者采用了一个精细化的三阶段蒸馏流程,旨在进一步提升 Mamba 模型对 Transformer 知识的吸收能力。
第一阶段:基于伪标签的蒸馏
此阶段主要利用无标签数据进行知识迁移。预训练的 Transformer 教师模型在无标签数据上运行,生成预测结果作为伪标签。随后,Mamba 学生模型在这些伪标签上进行训练。
该过程的损失函数设计非常关键,它结合了两种损失:
- KL 散度损失:用于模仿教师模型输出的概率分布,确保学生模型在输出分布上与教师保持一致。
- 交叉熵损失:用于拟合伪标签,强化学生对特定任务目标的拟合能力。
这种组合损失函数确保了模型既学到了教师的泛化能力,又保留了针对特定数据的判别力。
第二阶段:指令数据集上的监督微调
在第一阶段的基础上,模型进入有监督的微调阶段。作者使用了带标签的指令数据集(如 OpenHermes 2.5)进行训练。这一步骤至关重要,因为它将模型从通用的语言建模任务引导至具体的指令遵循任务,提升了模型在实际对话场景中的可用性。
第三阶段:基于人类反馈的优化
最后一个阶段引入了人类反馈数据,通过奖励模型进行优化。作者收集了人类对模型输出的偏好数据,据此构建一个奖励模型。随后,使用强化学习算法(如 PPO,Proximal Policy Optimization)来优化模型在该奖励模型下的表现。
这一过程模拟了 RLHF(Reinforcement Learning from Human Feedback)的标准流程,确保最终生成的混合模型不仅逻辑正确,而且符合人类的价值观和偏好。
值得注意的是,在 8 块 80G A100 GPU 上,每个混合模型的整个蒸馏过程只需不到五天的时间。这证明了该方法在计算效率上的可行性,为大规模模型的快速迭代提供了新思路。
混合模型推理加速算法
除了模型结构的优化,推理速度的提升还依赖于专门的解码算法。作者在得到 Transformer-Mamba 混合模型后,提出了推测解码(Speculative Decoding)算法来进一步加速推理过程。
推测解码的基本原理
推测解码的核心思想是利用一个轻量级的 Draft 模型来预测多个 token,然后再用验证模型(Verifier)来验证这些预测的正确性。这种方法显著提高了解码的并行性,减少了串行生成的等待时间。
在传统的自回归生成中,每个 token 必须等待前一个 token 生成完成后才能计算。而在推测解码中,Draft 模型可以根据当前上下文一次性预测出接下来的 K 个 token。
混合架构下的验证机制
对于预测出的 K 个 token,不同架构的处理方式有所不同:
- Transformer 层处理:可以直接并行地处理这 K 个 token,计算它们的隐状态。这是因为 Transformer 的注意力机制天然支持并行计算。
- Mamba 层处理:由于 Mamba 基于 SSM 结构,需要按照顺序依次处理每个 token。系统首先计算当前 token 的隐状态,并将其与之前的隐状态进行比较。
验证逻辑如下:
- 如果当前 token 是正确的,则将其添加到已接受的序列中,并更新最新的隐状态(但不保存中间状态以节省显存)。
- 如果当前 token 是错误的,则停止处理后续 token,并将最新的隐状态回退到上一个已接受的 token 处。
- 如果序列中的所有 K 个 token 都被接受,则将它们全部添加到输出序列中,并继续预测下一组 token。
- 如果有 token 被拒绝,则从第一个被拒绝的 token 处截断预测序列,并返回初始步骤从该位置开始重新预测。
这种机制充分利用了 Mamba 的顺序计算优势,同时通过并行验证减少了无效计算的时间开销。
性能评估与实验结果
对话任务表现
测试结果表明,混合模型在单轮聊天对话任务(AlpacaEval)和多轮对话任务(MT-Bench)上与 Llama-3 相当,甚至在某些指标上表现更优。这表明蒸馏过程并未丢失原始模型的关键能力,反而可能因为架构优化带来了细微的性能增益。
此外,研究还对不同混合比例的模型表现进行了测试。结果显示,按照 1:1 比例混合的模型表现最佳,这为未来的架构设计提供了参考依据。
通用 NLP 任务评测
在零样本的通用 NLP 任务评测中,混合模型的平均成绩优于同等规模的 RNN 模型。这验证了引入 Transformer 知识对纯 RNN 架构的显著提升作用。
![Performance Chart Placeholder]
少样本榜单表现
在少样本的 OpenLLM Leaderboard 榜单上,混合模型的表现与最好的开源 RNN 模型相当,并在 GSM8K(数学推理)和 CRUX(代码理解)任务上超过了对应的 Instruct 模型。这证明了混合架构在处理复杂逻辑和代码任务时的鲁棒性。
推理速度加速效果
除了模型性能,作者重点测试了推测解码算法带来的加速效果。
- 纯 Mamba 模型测试:在 2.8B 和 7B 的模型上,相比原来的解码方式,推理速度提升了 1.7 到 2.6 倍。这展示了 Mamba 架构本身结合推测解码的巨大潜力。
- 混合模型测试:在蒸馏的 Zephyr 和 Llama 混合模型上,Zephyr 混合模型的推理速度提升了 1.8 倍以上,Llama 混合模型也有 1.6 倍左右的加速。
这意味着在实际部署场景中,用户可以在保持响应质量的同时,获得更快的交互体验,这对于实时应用(如智能客服、语音助手)具有重要意义。
技术背景与行业意义
Transformer 与 Mamba 的互补性
Transformer 的优势在于强大的全局上下文建模能力,但其 O(N^2) 的复杂度限制了长序列的处理效率。Mamba 作为状态空间模型的代表,具有 O(N) 的线性复杂度,非常适合长文本处理和低延迟推理。
将两者结合,实际上是在保留 Transformer 强大语义理解能力的同时,引入了 Mamba 的高效推理特性。Tri Dao 等专家的参与也表明,学术界和工业界正在积极探索超越传统 Transformer 的新范式。
蒸馏技术的演进
传统的知识蒸馏通常关注于缩小模型体积,而此次工作展示了如何通过蒸馏改变模型架构。这种架构转换蒸馏(Architecture Conversion Distillation)为模型压缩和加速开辟了新的道路。它不再局限于同构模型之间的迁移,而是实现了异构模型间的知识传递。
未来展望
随着混合模型技术的成熟,我们可能会看到更多基于 SSM 的大模型出现。这不仅限于 LLM,还可能扩展到多模态领域。此外,推测解码算法的进一步优化,如动态调整猜测长度、自适应选择草稿模型等,将是提升推理效率的关键方向。
对于开发者而言,掌握这类混合架构的训练和部署技能,将在未来的 AI 工程化竞争中占据先机。Together AI 的这一成果提供了一个可行的技术路线图,鼓励社区进一步探索高效能大模型的可能性。
参考资料
论文地址:https://www.together.ai/blog/the-mamba-in-the-llama-distilling-and-accelerating-hybrid-models
注:本文内容基于 Together AI 官方发布的技术报告整理,旨在分享前沿技术动态。