北航发布 LLaMA-Factory:零代码大模型微调与高效训练框架
在大模型技术飞速发展的今天,训练和微调大型语言模型(LLM)对于大多数普通工程师而言仍面临较高的门槛。为了降低大模型训练与微调的难度,北京航空航天大学发布了 LLaMA-Factory。这是一个旨在普及 LLMs 微调的开源框架,通过可扩展的模块统一了多种高效微调方法,使得数百种语言模型能够在资源有限的情况下进行高吞吐量的微调。
核心特性
LLaMA-Factory 简化了常用的训练方法,包括生成式预训练、监督式微调(SFT)、基于人类反馈的强化学习(RLHF)以及直接偏好优化(DPO)。用户可以通过命令行或 Web 界面,以最小或无需编码的方式自定义和微调他们的语言模型。该框架遵循 Apache-2.0 许可证开源,已在 GitHub 上获得大量关注,并在 Hugging Face Hub 上构建了数百个开源模型。
高效的微调技术
高效的 LLM 微调技术主要分为两大类:专注于优化的方法和旨在计算的方法。
高效优化
LLaMA-Factory 集成了多种高效的优化技术,旨在保持成本最低的同时调整 LLM 的参数。主要方法包括:
- Freeze-tuning:冻结大部分模型参数,仅更新特定层,大幅减少显存占用。
- Gradient Low-rank Projection (GaLore):通过低秩投影梯度来更新权重,显著降低内存需求。
- Low-rank Adaptation (LoRA):在预训练模型旁路中注入可训练的低秩矩阵,是目前最主流的高效微调方案。
- Weight-decomposed Low-rank Adaptation (DoRA):对 LoRA 进行改进,分解权重为幅度和角度,提升微调效果。
- LoRA+:进一步优化 LoRA 的学习率策略,提高收敛速度和最终性能。
这些方法可以显著提高模型的训练效率和内存使用效率,使得在消费级显卡上进行大模型微调成为可能。
高效计算
LLaMA-Factory 整合了多种高效计算技术,寻求减少 LLM 中所需的计算时间或空间。关键技术包括:
- 混合精度训练:结合 FP16 和 BF16 等格式,平衡精度与速度。
- 激活检查点(Activation Checkpointing):用计算换内存,减少前向传播时的中间激活值存储。
- Flash Attention:优化注意力机制的计算过程,减少 IO 开销。
- S2 Attention:针对长序列优化的注意力机制。
- 量化策略:支持将模型动态量化为 8 位或 4 位(如 QLoRA),将内存占用从每个参数 18 字节或 8 字节降低到仅为 0.6 字节左右。
- 适配器技术:灵活附加轻量级适配器层。
通过这些技术的结合,LLaMA-Factory 能够显著提高 LLM 的效率。
LLaMA-Factory 架构设计
LLaMA-Factory 由三个主要模块组成:Model Loader、Data Worker 和 Trainer。这种模块化设计最小化了模块对特定模型和数据集的依赖,使框架可以灵活扩展到数百个模型和数据集。
Model Loader
Model Loader 负责加载和初始化模型参数,包含四个组件:
- 模型初始化:使用 Transformers 库的 AutoModel API 来加载和初始化模型参数,支持 100 多个 LLM 架构。
- 模型修补:通过替换模型的前向计算来实现 Flash Attention 和 S2 Attention,提升推理和训练速度。
- 模型量化:支持多种后训练量化方法,可根据设备能力处理预训练模型的浮点精度。
- 适配器附加:根据模型注册表自动识别适配器应该附加的层,并使用 PEFT 库来附加适配器。
Data Worker
Data Worker 是一个数据处理管道,包括数据集加载、对齐、合并和预处理。它将不同任务的数据集标准化为统一格式,使我们能够在各种格式的数据集上微调模型。具体功能包括:
- 数据集加载:使用 datasets 库加载数据。
- 数据描述规范:设计统一规范来对齐相应的列,收集数据集。
- 聊天模板:提供多个聊天模板以适应不同的对话场景。
- 分层序列打包:优化输入数据的序列长度分布。
Trainer
Trainer 统一了高效的微调方法,以适应不同的任务和数据集,提供了四种主要的训练模式:
- 预训练与 SFT:利用 Transformer 的 Trainer 进行基础预训练和监督式微调。
- RLHF 与 DPO:采用 TRL(Transformer Reinforcement Learning)的 Trainer 进行基于人类反馈的强化学习和直接偏好优化。
- 模型共享 RLHF:提出了一种创新的模型共享 RLHF 方法,使整个 RLHF 训练不需要多于一个预训练模型。通过 PEFT 的 set_adapter 和 disable_adapter API 动态切换适配器和 value head,使预训练模型同时作为策略、值、参考和奖励模型。据我们所知,这是第一个在消费设备上支持 RLHF 训练的方法。
- 分布式训练:可与 DeepSpeed 结合进行分布式训练,利用 DeepSpeed ZeRO 优化器,通过分区或卸载进一步减少内存消耗。
LLaMA Board:统一用户界面
LLaMA Board 是基于 Gradio 的统一用户界面,允许用户在不编写任何代码的情况下自定义 LLM 的微调。它提供了简化的模型微调和推理服务,使用户可以轻松地利用 100 多个 LLM 和 50 多个数据集。
LLaMA Board 具有易于配置、可监控的训练、灵活的评估和多语言支持等特点。用户可以通过与 Web 界面交互来自定义微调参数,并实时监视训练进度。此外,LLaMA Board 支持自动评估模型的文本相似度分数或通过与模型聊天进行人工评估。目前,LLaMA Board 支持英语、俄语和中文三种语言。
实证研究
本文从两个角度对 LLaMA-Factory 进行评估:训练效率(内存使用、吞吐量和困惑度)以及适应下游任务的效果。
训练效率
实验利用 PubMed 数据集进行,提取了约 400,000 个标记用于构建训练样本。通过 fine-tune Gemma-2B、Llama2-7B 和 Llama2-13B 模型,使用不同的 fine-tuning 方法进行比较。结果表明:
- QLoRA 具有最低的内存占用,适合显存受限的场景。
- LoRA 具有更高的吞吐量,适合需要快速迭代的场景。
- GaLore 在大型模型上具有更低的困惑度(PPL)。
- LoRA 在小型模型上表现出优势。
下游任务微调
通过在下游任务上微调各种模型并比较它们的性能来进行评估。使用来自 CNN/DM、XSum 和 AdGen 三个代表性文本生成任务的示例构建训练集和测试集。选择几个经过指令微调的模型,并使用不同的微调方法进行微调,包括全微调(FT)、GaLore、LoRA 和 4 位 QLoRA。
微调后,计算每个任务的测试集上的 ROUGE 分数。实验结果表明,LoRA 和 QLoRA 在大多数情况下表现最佳,除了部分模型在特定数据集上的表现差异。这一现象突出了这些高效微调方法在适应特定任务中的有效性。此外,观察到 Mistral-7B 模型在英语数据集上表现更好,而 Qwen1.5-7B 模型在中文数据集上获得更高的分数。这些结果表明,微调模型的性能也与其在特定语言上的内在能力有关。
总结与未来工作
LLaMA-Factory 是一个统一的框架,可用于高效微调超过 100 个 LLM 模型。通过模块化设计,最小化模型、数据集和训练方法之间的依赖关系,并提供一种集成的方法来进行微调。此外,LLaMA Board 提供了一个灵活的 Web UI,可以在不需要编码的情况下进行自定义微调和评估 LLM。该框架在语言建模和文本生成任务上得到了实证验证。
未来,LLaMA-Factory 将与最先进的模型和高效微调技术保持同步,并探索更高级的并行训练策略和多模态高效微调 LLM 的可能性。项目地址:https://github.com/hiyouga/LLaMA-Factory,论文地址:https://arxiv.org/pdf/2403.13372.pdf。