大模型 RLHF 强化学习微调过程详解
引言
在大型语言模型(LLM)的训练流程中,通常包含预训练、有监督微调(SFT)和人类反馈强化学习(RLHF)三个阶段。虽然网上关于 PPO(Proximal Policy Optimization)的文章不少,但很多内容浅尝辄止。本文结合微软 DeepSpeed 的 RLHF 实现代码,深入讲解奖励模型训练和强化学习微调的核心逻辑。
在强化学习微调语言模型时,Prompt 对应状态(State),输出一串单词(Action),得到一个 Reward。与传统游戏场景不同,这里的状态转移蕴含在 Transformer 内部,每个动作后并不立即获得新的外部 State,而是通过模型内部机制更新隐状态。对于回答的第二个词,可以将 Prompt+ 第一个词当作新的 State。
奖励(Reward)模型训练
1. 模型初始化说明
在强化学习阶段,使用的 Reward Model 和 Critic Model 使用同一个模型初始化。因此,在训练 Reward 模型的过程中,实际上也在训练 Critic Model。
符号说明:大模型中间隐藏层的参数维度为 (B, L, D),其中 B 为 Batch Size,L 为句子长度,D 为 Embedding 维度。
2. 奖励模型的作用
RLHF 需要一个奖励模型来评估语言大模型(Actor Model)回答的好坏。该模型通常比被评估的语言大模型小一些(例如 Actor 66B,Reward 350M)。输入为 Prompt + Answer,输出用于打分。
奖励模型最后一层隐藏层的输出维度为 (B, L, D),通过一个 D×1 的全连接层将维度变为 (B, L)。在 L 维度上,第 i 个位置的数据表示从第 i 个位置到最后一个位置输出所能获得的奖励分值的累加和(类似 DQN 中的 Q 值)。这种形式的输出满足了 Critic Model 的输出要求。
hidden_states = transformer_outputs[0]
rewards = self.v_head(hidden_states).squeeze(-1)
对于一个奖励模型来说,目标是给一个句子进行打分。按理说每个句子对应一个分值即可,但目前长度为 L 的句子,奖励模型输出了 L 个值。我们取 L 维度上的最后一个位置的值作为本句话的奖励得分。
3. Pairwise Loss 训练
奖励模型训练优化采用 Pairwise Loss。即同时输入模型关于同一个问题的两个回答,让模型学会这两个句子哪个分高哪个分低。这是因为在给奖励模型进行数据标注的过程中,给同一个问题的不同回答量化具体分值比较难,但对它们进行排序相对简单。
c_truncated_reward = chosen_reward[divergence_ind:end_ind]
r_truncated_reward = rejected_reward[divergence_ind:end_ind]
Pairwise Loss 代码如下,如果给 pair 里好的句子打分高(c_truncated_reward),坏的句子(r_truncated_reward)打分低,loss 就会小:
loss += -torch.log(torch.sigmoid(c_truncated_reward - r_truncated_reward)).mean()
4. 推理过程中的 Critic Model
在训练强化学习的过程中,会用到 Reward Model(Critic Model,两者是同一模型的两个副本)的推理过程。通过调用 forward_value 实现,返回的值中有两种值:
values:表示每个位置 i,从第 i 个位置到最后一个位置的奖励累加值,供强化学习过程中 Critic Model 使用。
chosen_end_scores:指的是对每个 Prompt + Answer 的打分,供 Reward Model 使用。
def forward_value(...):
...
if return_value_only:
return values
else:
...
return {
"values": values,
"chosen_end_scores": torch.stack(chosen_end_scores),
}
强化学习微调
1. 四个核心模型
强化学习微调阶段,会用到 4 个模型:Actor Model、Ref_Model、Reward Model 和 Critic Model。
- Actor Model 和 Ref_Model 是 RLHF 第一阶段有监督微调模型的两个副本。
- Reward Model 和 Critic Model 是本文第一部分训练出来的模型的两个副本。
2. Actor Model 的训练模式与推理模式
- 训练模式:用 Teacher Force 的方式,将整句话输入到模型中,并通过 Mask 机制在保证不泄漏未来的单词情况下预测下一个单词。
- 推理模式:真正的自回归,预测出下一个单词之后,当作下一步输入再预测下下个单词。
首先用 Actor Model 在 推理模式 下根据 Prompt 生成一个 Answer(Prompt 对应强化学习里的 State,Answer 对应一系列 Action):
with torch.no_grad():
seq = self.actor_model.module.generate(prompts,
max_length=max_min_length,
min_length=max_min_length)
然后利用 Reward Model 和 Critic Model 对输出的 Prompt + Answer 进行打分(PPO 训练时使用的奖励值并不单单是 Reward Model 的输出,还要考虑 KL 散度):
reward_score = self.reward_model.forward_value(
seq, attention_mask,
prompt_length=self.prompt_length)['chosen_end_scores'].detach()
values = self.critic_model.forward_value(
seq, attention_mask, return_value_only=True).detach()[:, :-1]
3. KL 散度约束
Actor Model 是我们想通过强化学习微调的大模型,但是强化学习过程很容易把模型训练'坏',因此需要另外一个 不会参数更新 的 Ref_Model 来当作标的,别让 Actor Model 跑偏太远。
我们在 训练模式 下,将 Prompt + Answer 分别输入到 Actor Model 和 Ref Model,用 KL 散度来衡量 Ref Model 和 Actor Model 输出的差别。同时将 KL 散度纳入损失函数(KL 散度本质是纳入到奖励值里的,奖励值被纳入到了损失函数),进而来约束 Ref_Model 和 Actor Model 的输出分布别差距太大。
output = self.actor_model(seq, attention_mask=attention_mask)
output_ref = self.ref_model(seq, attention_mask=attention_mask)
logits = output.logits
logits_ref = output_ref.logits
...
return {
...
'logprobs': gather_log_probs(logits[:, :-1, :], seq[:, 1:]),
'ref_logprobs': gather_log_probs(logits_ref[:, :-1, :], seq[:,1:]),
...
}
...
kl_divergence_estimate = -self.kl_ctl * (log_probs - ref_logprobs)
PPO 训练时候的奖励值综合考虑 KL 散度和 Reward 模型的输出,只考虑 Answer 部分的 KL 散度,将 Reward Model 的输出加到 KL 散度 L 维度的最后一个位置上,得到最终的奖励值:
rewards = kl_divergence_estimate
start = prompts.shape[1] - 1
ends = start + action_mask[:, start:].sum(1)
reward_clip = torch.clamp(reward_score, -self.clip_reward_value,
self.clip_reward_value)
batch_size = log_probs.shape[0]
for j in range(batch_size):
rewards[j, start:ends[j]][-1] += reward_clip[j]
4. PPO 训练核心:Advantage 计算
接下来的内容是 PPO 训练过程的核心,目标是计算 PPO 更新公式里的 Advantage,具体公式如下,V 就是 Critic Model 的输出。
def get_advantages_and_returns(self, values, rewards, start):
lastgaelam = 0
advantages_reversed = []
length = rewards.size()[-1]
for t in reversed(range(start, length)):
nextvalues = values[:, t + 1] if t < length - 1 else 0.0
delta = (rewards[:, t] + self.gamma * nextvalues) - values[:, t]
lastgaelam = delta + self.gamma * self.lam * lastgaelam
advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1], dim=1)
returns = advantages + values[:, start:]
return advantages.detach(), returns
5. Actor 与 Critic 模型更新
以上过程,我们已经拿到了 PPO 训练所需要的 Advantage 以及 Actor Model 的输出,现在可以对 Actor Model 进行训练啦。
def actor_loss_fn(self, logprobs, old_logprobs, advantages, mask):
log_ratio = (logprobs - old_logprobs) * mask
ratio = torch.exp(log_ratio)
pg_loss1 = -advantages * ratio
pg_loss2 = -advantages * torch.clamp(ratio, 1.0 - self.cliprange,
1.0 + self.cliprange)
pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / mask.sum()
return pg_loss
同样的,我们也要对 Critic Model 进行训练,更新,Loss 就是 MSE Loss。
def critic_loss_fn(self, values, old_values, returns, mask):
values_clipped = torch.clamp(
values,
old_values - self.cliprange_value,
old_values + self.cliprange_value,
)
vf_loss1 = (values - returns)**2
vf_loss2 = (values_clipped - returns)**2
vf_loss = 0.5 * torch.sum(
torch.max(vf_loss1, vf_loss2) * mask) / mask.sum()
return vf_loss
至此,我们的 RLHF 训练流程就结束了。第二部分开头我们说过,共涉及 Actor Model、Ref_Model、Reward Model 和 Critic Model 这四个模型,其实更新参数的模型只有 Actor Model 和 Critic Model。
总结与展望
RLHF 技术使得大模型能够更好地对齐人类价值观,提升回答质量。整个流程主要包括三个步骤:
- 有监督微调(SFT):构建高质量指令数据集,训练模型遵循指令。
- 奖励模型(RM)训练:基于偏好数据训练 Reward Model,用于评估回答优劣。
- 强化学习(PPO):利用 Reward Model 和 Critic Model 指导 Actor Model 更新策略,同时通过 KL 散度防止模型偏离原始分布过远。
在实际应用中,RLHF 面临诸多挑战,如奖励黑客(Reward Hacking)、训练不稳定、计算资源消耗巨大等。未来研究方向包括改进采样效率、探索更高效的 PPO 变体、以及减少对人标注数据的依赖。理解上述代码逻辑有助于开发者更好地调试和优化自己的 RLHF 系统。
建议参考 HuggingFace 官方博客了解更详细的算法原理,并结合实际业务场景调整超参数。