低阶约束下实现 LLM 的全秩训练
在大型语言模型(LLM)的训练中,内存消耗一直是制约大规模部署的关键瓶颈。低秩适应(Low-Rank Adaptation, LoRA)等方法通过冻结预训练权重并引入可训练的低秩矩阵来节省显存,但这类方法本质上将训练限制在了低秩子空间内,不可避免地导致了次优性能。GaLore 等梯度投影方法虽然试图缓解这一问题,但依然无法摆脱丢弃子空间外信息的局限。
针对这一痛点,北京理工大学团队提出了一种名为 Fira 的即插即用框架。Fira 的核心目标是在保持低秩约束以提升内存效率的同时,实现全秩训练,从而避免性能损失。
核心洞察与方案
Fira 的设计基于两个关键观察和对应的解决方案:
1. 自适应优化器的缩放效应
研究人员发现,在从低秩向全秩过渡的过程中,自适应优化器(如 Adam)对梯度范数的缩放影响是相对稳定的。基于此,Fira 提出了一种基于范数的缩放方法。该方法利用低秩优化器的缩放行为作为原始全秩优化器的替代,使得我们可以在优化器内部保留低秩约束,同时利用全秩梯度进行更新。这相当于用一种'低成本'的方式模拟了全秩优化的效果。
2. 梯度范数增长限制器
在实际优化过程中,梯度的突然上升往往会导致损失函数的尖峰,影响训练稳定性。为此,Fira 引入了一个范数增长限制器。它通过调节梯度范数的相对增量来平滑梯度更新,有效避免了训练过程中的损失震荡。
实验效果与对比
在广泛的预训练和微调实验中,Fira 展现了显著优势。与主流的 LoRA 和 GaLore 相比,Fira 不仅实现了更优的性能,还保持了极高的内存效率。
- 内存效率:在 LLaMA 1B 架构上,Fira 将优化器状态的内存使用量减少了 61.1%。
- 性能表现:对于 LLaMA 7B 架构的预训练,Fira 使用的秩比 GaLore 小 8 倍,但性能却远远优于 GaLore。即使在极低的秩设置(如 4 或 16)下,Fira 仍能实现与满秩训练相当甚至更好的结果。
总结
Fira 框架通过巧妙的范数缩放策略和梯度平滑机制,成功打破了低秩训练与全秩性能之间的权衡。它证明了在不牺牲内存效率的前提下,完全有可能实现接近全秩的训练效果。这对于资源受限场景下的 LLM 开发具有重要的参考价值。







