强化学习微调阶段,会用到 4 个模型:Actor Model、Ref_Model、Reward Model 和 Critic Model。其中 Actor Model 和 Ref_Model 是 RLHF 第一阶段有监督微调模型的两个副本,Reward Model 和 Critic Model 是本文第一部分训练出来的模型的两个副本。
首先说明 Actor Model 的训练模式和推理模式的区别。训练模式是用 Teacher Force 的方式,将整句话输入到模型中,并通过 mask 机制在保证不泄漏未来的单词情况下预测下一个单词。推理模式是真正的自回归,预测出下一个单词之后,当作下一步输入再预测下下个单词。
Actor Model 是我们想通过强化学习微调的大模型,但是强化学习过程很容易把模型训练坏,因此需要另外一个不会参数更新的 Ref_Model 来当作标的,别让 Actor Model 跑偏太远。我们在训练模式下,将 prompt+answer 分别输入到 Actor Model 和 Ref Model,用 KL 散度来衡量 Ref Model 和 Actor Model 输出的差别。同时将 KL 散度纳入损失函数,进而来约束 Ref_Model 和 Actor Model 的输出分布别差距太大。
接下来的内容是 PPO 训练过程的比较核心内容了,目标是计算 PPO 更新公式里的 Advantage,具体公式如下,V 就是 Critic Model 的输出。
defget_advantages_and_returns(self, values, rewards, start):
# values(B,L)critic model 输出# rewards(B,)reward model 输出# start answer 开始的位置
lastgaelam = 0
advantages_reversed = []
length = rewards.size()[-1]
# 计算每个时刻(序列位置)的 critic model 预测误差for t inreversed(range(start, length)):
nextvalues = values[:, t + 1] if t < length - 1else0.0# critic model 预测的是 t 到到最后一个时刻的奖励和,所以变化量 delta 可以用如下公式表示
delta = (rewards[:, t] + self.gamma * nextvalues) - values[:, t]
# self.gamma=1,self.lam=0.95 是衰减因子,表示之前计算的 delta 对现在影响越来越小
lastgaelam = delta + self.gamma * self.lam * lastgaelam
advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1], dim=1)
# 后续用来更新 critic model 用
returns = advantages + values[:, start:]
return advantages.detach(), returns
以上过程,我们已经拿到了 PPO 训练所需要的 Advantage 以及 Actor Model 的输出,现在可以对 Actor Model 进行训练啦。Logprobs 和 Old_Logprobs 这两个参数分别是老 Actor 和新 Actor 在正确单词上出现的概率,这块是 PPO Importance Sampling 相关的知识。
至此,我们的 RLHF 训练流程就结束了。第二部分开头我们说过,共涉及 Actor Model、Ref_Model、Reward Model 和 Critic Model 这四个模型,其实更新参数的模型只有 Actor Model 和 Critic Model。通过这种方式,大模型能够在保持原有知识的同时,更好地遵循人类指令,提升回答质量与安全性。