引言
当前大型语言模型(LMs)面临一个核心挑战:规模扩大并不等同于更好地遵循用户意图。由于训练目标仅是预测网页上的下一个 token,模型常生成不真实、有毒或无用的输出。本文旨在让模型在'有用性'、'诚实性'和'无害性'上与用户意图对齐。
为此,研究提出使用人类反馈强化学习(RLHF)微调 GPT-3。整个过程分为三步:首先通过标注数据监督微调;其次基于模型输出的排名数据训练奖励模型;最后利用 PPO 算法根据奖励反馈进一步优化策略。
值得注意的是,RLHF 可能导致模型在公共 NLP 数据集上性能下降。作者发现将 PPO 更新与预训练分布的对数似然更新混合(即 PPO-ptx),可显著缓解这一问题。经过 RLHF 的模型不仅符合标注者偏好,还能泛化到未参与训练的标注者偏好,甚至扩展到代码和非英语任务。
方法与实验细节
从预训练模型到用户意图对齐,主要经历三个阶段。
首先是监督微调(SFT)。收集人类标注者针对输入 Prompt 提供的期望输出,用这些数据对预训练 GPT-3 进行微调。
其次是训练奖励模型(RM)。对于同一 Prompt,模型生成多个响应,由标注者按优劣排名。利用这些排名数据训练 RM,使其能预测人类更偏好哪个输出。为提升效率,RM 使用 6B 参数模型,一次性训练所有可能的比较对。
最后是强化学习(RL)。以 RM 的输出作为标量奖励,指导 SFT 模型微调。RL 环境类似老虎机,给定 Prompt 生成响应并获得奖励。
关键机制解析
在实际训练中,有几个关键点值得注意。
一是 KL 散度的作用。为防止模型过度优化奖励模型而偏离原始分布,我们在每个 token 上增加了 KL 散度惩罚。具体而言,最终奖励 $R(x, y)$ 不仅是 RM 给出的分数 $r_ heta(x, y)$,还需减去 KL 惩罚项:
$$R(x, y) = r_ heta(x, y) - \beta \log \left( \frac{\pi^{RL}(y|x)}{\pi^{SFT}(y|x)} \right)$$
其中 $\pi^{RL}$ 是当前强化学习模型的输出概率,$\pi^{SFT}$ 是原始监督微调模型的输出概率。计算时,RL 模型生成完整回复序列,再将其输入 SFT 模型对比概率,因此不存在长度不一致的问题。
二是如何混合预训练梯度。单纯优化人类偏好会导致'对齐税',即在问答、阅读理解等公共数据集上性能下降。为此,作者在 PPO 更新中混合了预训练梯度的更新。总目标函数变为既要最大化人类偏好奖励,又要最大化预训练数据分布的对数似然:
$$\text{Objective} = \text{Objective}{PPO} + \gamma \cdot \mathbb{E}{x \sim D_{pretrain}} [\log \pi(x)]$$
简单来说,就是在训练 PPO 的同时,随机抽取原始预训练文本让模型填空,并将这部分损失纳入优化指标,从而保留基础语言能力。

