本文分享在学习和实践 RLHF 时曾陷入的一些思维误区。这些误区的产生大多和强化基础知识理解不到位有关,建议非强化背景的同学耐心阅读以下内容。
RLHF 代码流程(背景知识)
Input: prompt,sft_model,reward_model
Initialize: actor_model / ref_model = sft_model,critic_model / reward_model = reward_model
Training:
step1:生产计算 loss 的中间数据(每次 rollout_batch_size 条数据):
actor_model generate,得到 prompt + response
reward_model predict,得到 reward,需要 clip 到一个区间内;
reference_model / actor_model forward,得到两个 log_probs,分别是 $\log \pi_{\theta}(a|s)$ 和 $\log \pi_{ref}(a|s)$,计算 KL_penalty;
critic_model forward,得到 values,此处记作 $V(s)$;
reward - KL_penalty:得到修正后的 reward(此处是 reference_model 的生效位置);
reward 和 values 反传,利用 PPO 论文中的计算公式,得到 advantages 和 returns(此处是 GAE 的生效位置)。
step2:更新 loss(每次 train_batch_size 条数据,反复调用,直到 step1 的数据用尽):
actor_model forward,又得到一个 log_probs,这里的 actor_model 是 $\pi_{\theta}$,与 step1 中得到的 $\pi_{old}$ 和 advantages 一起计算 loss(此处是 Importance Sampling 和 CLIP 的生效位置);
actor_model backward;
critic_model forward,又得到一个 values,此处记作 v,与 step1 中得到的 V(s) 和 returns 一起计算 loss,引入 clip 防止 critic_model 的更新幅度太大;
critic_model backward。
RLHF 训练流程的难点主要集中在以下几个问题:
- PPO 的处理技巧:CLIP,GAE,Importance Sampling
- advantages 和 returns 是怎么计算得到的?
- 凭什么 advantages 和 returns 可以作为 actor_model 和 critic_model 的优化目标?
- 整个训练过程为什么存在三个 policy 模型?
回答不上来的同学建议好好读下相关基础文章,然后深入 OpenRLHF 的源码。
RLHF 不等于 PPO
RLHF 的含义是'通过强化学习的训练方式,利用人类反馈来优化语言模型',其和 PPO 是不能完全划等号的,二者的区别主要在于:
- PPO 仅仅是 OpenAI 最喜欢的强化学习训练方法,其他强化学习训练方法也可以;
- Critic_model 是 PPO 需要的模块,并非 RLHF 必备模块,如果 RL 算法是 actor-critic 系列,那么就会引入 critic_model,如果是 REINFORCE 系列,那么就可以省掉 critic_model;
- Reference_model 是 RLHF 提出的概念,目标是防止语言模型在训练中崩溃,和 PPO 没有任何关系;
- Reward_model 也是 RLHF 提出的概念,目标是自动生产训练数据,同样和 PPO 没有关系,其他生产数据的方法(比如 verifier)也可以。
这也就是说,虽然我们熟知的 RLHF 是 4 个模型组成的,但实际上只需要准备好一个微调后的 sft_model 即可启动,毕竟 reference_model 与 sft_model 是同一个模型,critic_model 和 reward_model 并非必备模块。
一言以概之,RLHF = LLM + 任意 RL 算法 + 数据打分工具。
RL 的常用技巧并非 RLHF 必备
强化学习和传统监督学习一个很大的区别就是'训练数据是当场采集出来的',一边训模型,一边造数据。在传统的强化学习任务中,训练数据的生产是很困难的,比如下完一盘围棋、打完一盘马里奥游戏……往往是十几分钟产生一条训练数据(trajectory),但不到一秒就训完了。然而,在 RLHF 的场景下,训练数据还真就不难生产,生产 1 条 response 和训练 1 条 response 还真不一定谁更快(不过 1 次生产 N 条 response 确实快于 N 次生产 1 条 response)。
下面我就围绕重要性采样这一技巧来展开讲讲。RL 算法引入重要性采样这一概念,其目的是:可以一次性生产多条数据,或者说让生产的数据可以反复使用。


