一直对大模型的强化学习微调(RLHF)过程感到好奇,网上相关文章虽多,但往往浅尝辄止。本文结合微软 DeepSpeed 的 RLHF 实现代码,深入讲解奖励模型训练和强化学习微调部分。
背景与原理
在语言模型中使用强化学习微调时,Prompt 对应状态(State),输出的一系列单词对应动作(Action)。与游戏场景不同,这里通常只输入一次 Prompt,得到一串回答后获得一个 Reward。不过,对于回答的第二个词,可以将 Prompt+ 第一个词作为新的 State,状态转移蕴含在 Transformer 内部。
大模型中的 PPO 涉及 Critic Model 的概念,用于预测从 t 时刻到最后一个时刻的累加奖励值,而非通过实际累加得到,这样可以降低奖励方差。RLHF 主要包含两个阶段:有监督微调(SFT)和强化学习微调。
奖励(Reward)模型训练
在强化学习阶段,使用的 Reward Model 和 Critic Model 通常使用同一个模型初始化。Reward Model 的输入是 prompt+answer 的形式,让模型学会对 prompt+answer 进行打分。
Reward Model 最后一层隐藏层的输出维度为 (B, L, D),通过一个 D×1 的全连接层将维度变为 (B, L)。L 维度上第 i 个位置的数据表示从第 i 个位置到最后一个位置输出所能获得的奖励分值的累加和。这种形式的输出满足了 Critic Model 的输出要求。
# huggingface 模型返回值是个 list,第 0 位是模型最后输出的 hidden state
hidden_states = transformer_outputs[0]
# v_head 为 Dx1 的全连接网络对最后一维压缩
rewards = self.v_head(hidden_states).squeeze(-1)
对于一个 Reward Model 来说,目标是给一个句子进行打分。目前对于长度为 L 的句子,Reward Model 输出了 L 个值。我们用 L 维度上的最后一个位置的值作为本句话的奖励得分。
Reward Model 训练优化采用 Pairwise Loss,即同时输入模型关于同一个问题的两个回答,让模型学会这两个句子哪个分高哪个分低。这是因为在给 Reward Model 进行数据标注的过程中,给同一个问题的不同回答量化具体分值比较难,但是对他们进行排序相对简单。
# 同一个 batch 里的句子需要等长,短句后边会被 padding
# [divergence_ind:end_ind] 索引了 padding 前一个位置的输出分值
# chosen_reward 是同一个句子 pair 里分数高的句子,r_truncated_reward 是句子 pair 里分数低的句子
c_truncated_reward = chosen_reward[divergence_ind:end_ind]
r_truncated_reward = rejected_reward[divergence_ind:end_ind]
Pairwise Loss 代码如下,如果给 pair 里边好的句子打分高,坏的句子打分低,loss 就会小:
loss += -torch.log(torch.sigmoid(c_truncated_reward - r_truncated_reward)).mean()
在训练强化学习的过程中,会用到 Reward Model(Critic Model)的推理过程,通过调用 forward_value 实现。返回的值中有两种值:values 表示每个位置 i,从第 i 个位置到最后一个位置的奖励累加值;chosen_end_scores 指的是对每个 prompt+answer 的打分。
def forward_value(...):
...
if return_value_only:
# (B,L)
return values
else:
...
return {
: values,
: torch.stack(chosen_end_scores),
}


