Llama 3 与 Mamba 架构融合,推理速度提升 1.6 倍且性能更优
引言
在大型语言模型(LLM)的发展进程中,Transformer 架构长期占据主导地位。然而,随着模型规模的扩大,计算复杂度和推理延迟成为制约其部署的关键瓶颈。近期,Together AI 推出了一项突破性成果,通过将 Llama 3 蒸馏到 Mamba 架构中,实现了推理速度的显著提升,同时保持了甚至超越了原始模型的性能表现。
这一项目由提出 Mamba 架构的 Tri Dao 以及 FlashAttention 的作者共同参与,标志着 Transformer 与状态空间模型(SSM)混合架构的重要进展。Together AI 创始人兼 CEO 指出,Transformer 和 Mamba 的混合是未来大模型的一大发展方向。本文将深入探讨这一混合模型的构建过程、蒸馏策略及推理加速算法。
将 Transformer 蒸馏进 Mamba 的技术路径
初始化阶段:从 Transformer 到线性 RNN
在蒸馏正式开始之前,核心挑战在于如何将预训练的 Transformer 参数有效地迁移到 Mamba 模型中。作者观察到,Transformer 的注意力机制与 RNN 的计算流程之间存在潜在的数学相似性。
具体而言,Transformer 的自注意力机制可以被视为一种全局依赖建模,而 RNN 则通过隐状态传递历史信息。为了建立二者的联系,研究者尝试将 Transformer 的注意力机制进行线性化处理。这种线性化操作使得注意力矩阵能够被近似为状态空间模型中的转移矩阵,从而允许将预训练 Transformer 的参数直接复制到 Mamba 模型的对应层中。
这种参数初始化策略极大地降低了训练成本,使得 Mamba 学生模型在初始阶段就具备了接近教师模型的知识分布,而非从零开始学习。
三阶段蒸馏流程
完成参数初始化后,作者采用了一个精细化的三阶段蒸馏流程,旨在进一步提升 Mamba 模型对 Transformer 知识的吸收能力。
第一阶段:基于伪标签的蒸馏
此阶段主要利用无标签数据进行知识迁移。预训练的 Transformer 教师模型在无标签数据上运行,生成预测结果作为伪标签。随后,Mamba 学生模型在这些伪标签上进行训练。
该过程的损失函数设计非常关键,它结合了两种损失:
- KL 散度损失:用于模仿教师模型输出的概率分布,确保学生模型在输出分布上与教师保持一致。
- 交叉熵损失:用于拟合伪标签,强化学生对特定任务目标的拟合能力。
这种组合损失函数确保了模型既学到了教师的泛化能力,又保留了针对特定数据的判别力。
第二阶段:指令数据集上的监督微调
在第一阶段的基础上,模型进入有监督的微调阶段。作者使用了带标签的指令数据集(如 OpenHermes 2.5)进行训练。这一步骤至关重要,因为它将模型从通用的语言建模任务引导至具体的指令遵循任务,提升了模型在实际对话场景中的可用性。
第三阶段:基于人类反馈的优化
最后一个阶段引入了人类反馈数据,通过奖励模型进行优化。作者收集了人类对模型输出的偏好数据,据此构建一个奖励模型。随后,使用强化学习算法(如 PPO,Proximal Policy Optimization)来优化模型在该奖励模型下的表现。
这一过程模拟了 RLHF(Reinforcement Learning from Human Feedback)的标准流程,确保最终生成的混合模型不仅逻辑正确,而且符合人类的价值观和偏好。
值得注意的是,在 8 块 80G A100 GPU 上,每个混合模型的整个蒸馏过程只需不到五天的时间。这证明了该方法在计算效率上的可行性,为大规模模型的快速迭代提供了新思路。
混合模型推理加速算法
除了模型结构的优化,推理速度的提升还依赖于专门的解码算法。作者在得到 Transformer-Mamba 混合模型后,提出了推测解码(Speculative Decoding)算法来进一步加速推理过程。
推测解码的基本原理
推测解码的核心思想是利用一个轻量级的 Draft 模型来预测多个 token,然后再用验证模型(Verifier)来验证这些预测的正确性。这种方法显著提高了解码的并行性,减少了串行生成的等待时间。
在传统的自回归生成中,每个 token 必须等待前一个 token 生成完成后才能计算。而在推测解码中,Draft 模型可以根据当前上下文一次性预测出接下来的 K 个 token。
混合架构下的验证机制
对于预测出的 K 个 token,不同架构的处理方式有所不同:
- :可以直接并行地处理这 K 个 token,计算它们的隐状态。这是因为 Transformer 的注意力机制天然支持并行计算。


