LLM 继续预训练(Continue Pretrain)实践指南与经验总结
1. 背景
近年来,大模型领域的商业化路径逐渐清晰。对于国内头部玩家而言,数据资源通常较为充足,只要显卡算力允许,往往会尽可能地对现有数据进行全量训练。然而,底座模型的通用能力存在天然上限,无法在所有领域同时达到最优。当试图加强某个特定领域的能力时,其他领域的性能往往会出现不同程度的下降,这种现象被称为"Alignment Tax"。
从 2022 年 OpenAI 发表的《Training language models to follow instructions with human feedback》论文开始,这一现象便引起了广泛关注。在许多实际场景中,例如教育、代码生成等垂直领域,用户需求相对集中。因此,在保证通用能力不显著下降的前提下,努力提升特定领域(Domain)的效果成为主流方案,即进行 Continue Pretrain(领域大模型增训)。
根据反馈,如果在 Continue Pretrain 后发现领域效果和通用效果同时提升,通常意味着底座模型的通用域训练尚不充分。此外,常见的操作还包括将英文模型增训为中文(如 Llama 系列),以及针对长上下文(Long Context)的 Continue Pretrain。
近期关于 Continue Pretrain 的分享中,涉及了多语言增训及长上下文处理的踩坑经验。本文基于相关论文及团队实践经验,整理出核心步骤与注意事项。
2. 核心步骤
Continue Pretrain 的整体流程主要包含三个关键阶段:词表扩展、领域继续预训练、领域对齐。
2.1 扩词表策略
不建议轻易扩大词表,仅在满足以下两个条件时可尝试:
- 底座模型的词表分布与目标领域的词表分布差距较大。
- 待增训的领域语料足够丰富。
大多数基础词表已包含常用字。例如,原词表中'北京'可能对应 ID [12, 15],扩词后可能变为 [10233]。这种变动涉及高频词,即使采用各种 Warmup 或 Frozen 策略,也需要更长的训练时间才能收敛。在多语言 Continue Pretrain 场景下,小语种语料有限,尚未看到正向收益样本即用尽,因此风险较高。
另一种可行的情况是仅扩充低频词,避免影响原有高频字/词。总体而言,选择一个词表质量较好的底座模型进行 Continue Pretrain,比面对不充分的底座训练更为稳妥,词表调整带来的坑往往更大。
2.2 Domain Continue Pretrain
此阶段参考了 Sailor 论文及相关工作,重点在于数据配比与超参数调优。
2.2.1 Replay(重放数据)
需要采样预训练(Pretrain)阶段的数据进行混合。当前开源 Base 模型通常在最后阶段混入了部分 SFT 数据以提升特定领域效果。由于开源样本比例限制,具体混入哪些数据需靠经验反推。若 Continue Pretrain 后对比原 Base 模型掉点严重,可能是缺少了部分关键的 SFT 数据补充。
2.2.2 学习率(Learning Rate)权衡
Sailor 论文指出,在保持 Continual Pre-training 总 Token 数一致的情况下,原有领域和新领域的 Loss 几乎可预测,本质上是学习率与重放比例(Replay Ratio)的权衡。学得越快,遗忘越多,即使增加 Replay 比例亦然。
实验表明,固定总 Token 数时,英语和马来语的 Validation Loss 随学习率和比例变化呈现规律性。英语 Loss 更具可预测性,可用二次项函数拟合,相关系数达 99.36%。关键指标为 log(English Proportion) - log(Learning Rate)。建议学习率设为 1e-4,相比 Qwen 原始设置(假设遵循 Llama 类似策略为 4e-4)有所降低。
在实践中,如果计算资源和数据资源充足,应尽量减小学习率以保留原模型效果。但若追求计算效率,建议在确定合适学习率后,放弃追求原领域无损的想法。观察发现:
- 学习率越小,新领域 Loss 下降越慢,但原领域遗忘也越慢。
- 学习率越大,新领域 Loss 下降加快,知识习得快。
- 学习率过高会导致 Loss 震荡,因知识学习速度有上限且数据分布差异限制了过快学习。
- 学习率极大时,Attention 分布剧烈调整会对原领域性能造成破坏性损失。
综上,1e-4 是一个有利于新领域学习且对原领域破坏较小的平衡点。推荐在正式训练前,使用少量 Token、固定重放比例(如 0.5)及多组随机学习率进行小规模实验,确认平衡点后再精细调整数据配比。
2.2.3 比例控制
领域数据占比过高可能导致 Loss 崩溃,占比过低则学习效率低,最终提升不明显。张舸和浩然的论文发现,随着领域数据占比提升,通用 Loss 和领域 Loss 呈此消彼长并趋于稳定的过程。
假设通用数据占比为 r,领域数据占比为 1-r,Scaling Law 公式可用于预估不同比例下的 Loss。通过小规模实验获取数据点拟合公式参数,即可推算更大参数量下的表现。多语言场景可参考 RegMix 方法,优化各语言 Loss 的 LogSum,使模型达到帕累托最优。Sailor 模型使用的重放数据(含英语和中文)总比例约为 30%。


