1. 微调(Supervised Finetuning)
指令微调阶段使用了已标注数据。这个阶段训练的数据集数量不会像预训练阶段那么大,最多可以达到几千万条,最少可以达到几百条到几千条。指令微调可以将预训练的知识'涌现'出来,进行其他类型的任务,如问答类型的任务。一般指令微调阶段对于在具体行业上的应用是必要的,但指令微调阶段一般不能灌注进去新知识,而是将已有知识的能力以某类任务的形式展现出来。
指令微调任务有多种场景,比较常用的有:
- 风格化:特定的问答范式
- 自我认知:自我认知改变
- 能力增强:模型本身能力不够,对具体行业的数据理解不良
- Agent:支持 Agent 能力,比如程序编写、API 调用等
上述只是举了几个例子,一般来说距离用户最近的训练方式就是指令微调。
一般来说,LLM 中指的 base 模型是指经过了预训练(以及进行了一部分通用指令的微调)的模型。Chat 模型是经过了大量通用数据微调和人类对齐训练的模型。
如何选择 base 模型和 chat 模型进行微调呢?
- 数据量较少的时候(比如小于 1w 条)建议使用 chat 模型微调
- 数据量较多、数据较为全面的时候,建议使用 base 模型微调
当然,如果硬件允许,建议两个模型都进行尝试,选择效果较好的。需要注意的是,chat 模型有其独特的输入格式,在微调时一定要遵循。base 模型的输入格式一般比较简单(但也需要遵守该格式),而且一般该格式不支持多轮数据集。
如果需要用 base 模型训练多轮对话,一般需要使用一个支持多轮对话的 template。在 SWIFT 中,可以指定为
default,在训练时只需要指定–template_type default 即可。
- 重要概念
loss 代表模型求解的 y 和实际的 y 值的差异。该值会进行 loss.backward(),这个方法会求解梯度,并将对应梯度值记录在每个参数上
loss 可以理解为根据模型计算出来的值和正确值的偏差(也就是残差)。例如,回归任务中计算的值是 1.0,而实际的值应当为 2.0,那么 loss 为 2.0-1.0=1.0。上述 loss 类型为 MAE,除此外,还有 MSE,Hinge 等各类 loss。一般分类任务的 loss 为交叉熵(Cross-Entropy),这也是目前 LLM 最常用的 loss。
loss 计算出来后(这个过程也就是 forward,即前向推理),经过 backward 过程即可计算出梯度。
梯度:光滑的曲面上导数变化最大的方向
loss 可以经过 PyTorch 的 loss.backward() 将每个算子、每个步骤的梯度都计算出来(复杂微分方程的链式求导过程),当有了梯度后,可以将参数往负梯度方向更新,学习率(lr)就是这时候起作用的,由于直接加上负梯度太大,可能直接产生震荡,即值从一个点瞬间跑到了曲线上的另一个点,导致在这两点反复震荡不收敛,因此乘以一个 lr,让 loss 一点点下降。
epoch 代表对数据集训练多少轮次
iter 对输入数据的每次 forward+backward 代表一个 iter
batch_size 批处理大小。在一次前向推理中,同时处理多少行数据。由于同一批数据会并行求解梯度,因此 batch_size 越大,梯度越稳定。在 SFT 时较为合适的梯度一般选择为 16/32/64 等值
- batch_size 越大,并行计算消耗的显存越高。因此在低显存情况下,可以选用 batch_size=1,gradient_accumulation_steps=16。训练会在 iter%gradient_accumulation_steps==0 时集中进行一次参数更新。在 iter%gradient_accumulation_steps!=0 时,会将梯度值不断累加到参数上,这样就相当于将 batch_size 扩大了 gradient_accumulation_steps 倍
learning_rate 学习率 训练将负梯度值乘以该值加到原参数上。换句话说,每次只将参数更新一个小幅度,避免向错误的更新方向移动太多。
一般 LoRA 的学习率可以比全参数训练的学习率稍高一点,因为全参数训练会完全重置所有参数,训练时需要学习率更低。LLM 训练的学习率一般设置在 1e-4~1e-5 不等
max_length 输入句子的最大长度。比如设置为 4096,那么句子加答案转换为 token 后最大长度为 max_length。这个值会影响显存占用,需要按照自己的实际需求设置。
- 当 batch_size 大于 1 时,意味着不同句子的长度可能不同。data_collator 的作用就是按照固定 max_length 或者 batch 中的最大长度对其他句子的 token 进行补齐。补齐的部分不参与模型的 loss 计算,但仍然会占用计算量
flash_attention flash attention 是一种针对 attention 结构高效计算的组件,该组件主要原理利用了显卡的高速缓存。flash attention 会节省约 20%~40% 训练显存并提高训练速度,对训练精度没有不良影响。在显卡支持的情况下建议开启。














