跳到主要内容
极客日志极客日志面向AI+效率的开发者社区
首页博客GitHub 精选镜像工具UI配色美学隐私政策关于联系
搜索内容 / 工具 / 仓库 / 镜像...⌘K搜索
注册
博客列表
PythonAI算法

大模型 RLHF 强化学习微调过程详解与代码分析

综述由AI生成详细解析了大模型 RLHF 强化学习微调的全过程,涵盖奖励模型训练与 PPO 强化学习微调两个核心阶段。文章基于微软 DeepSpeed 的实现代码,解释了 Reward Model 与 Critic Model 的初始化关系、Pairwise Loss 训练方法、Actor 与 Ref Model 的 KL 散度约束机制,以及 PPO 算法中 Advantage 的计算与策略更新逻辑。通过代码片段展示了如何计算奖励值、处理 Padding、执行 Actor 和 Critic 的损失函数更新。最后总结了 RLHF 的三个主要步骤及面临的挑战,为开发者提供了完整的理论框架与工程实践参考。

Elasticer发布于 2025/2/7更新于 2026/5/3020 浏览
大模型 RLHF 强化学习微调过程详解与代码分析

大模型 RLHF 强化学习微调过程详解

引言

在大型语言模型(LLM)的训练流程中,通常包含预训练、有监督微调(SFT)和人类反馈强化学习(RLHF)三个阶段。虽然网上关于 PPO(Proximal Policy Optimization)的文章不少,但很多内容浅尝辄止。本文结合微软 DeepSpeed 的 RLHF 实现代码,深入讲解奖励模型训练和强化学习微调的核心逻辑。

在强化学习微调语言模型时,Prompt 对应状态(State),输出一串单词(Action),得到一个 Reward。与传统游戏场景不同,这里的状态转移蕴含在 Transformer 内部,每个动作后并不立即获得新的外部 State,而是通过模型内部机制更新隐状态。对于回答的第二个词,可以将 Prompt+ 第一个词当作新的 State。

奖励(Reward)模型训练

1. 模型初始化说明

在强化学习阶段,使用的 Reward Model 和 Critic Model 使用同一个模型初始化。因此,在训练 Reward 模型的过程中,实际上也在训练 Critic Model。

符号说明:大模型中间隐藏层的参数维度为 (B, L, D),其中 B 为 Batch Size,L 为句子长度,D 为 Embedding 维度。

2. 奖励模型的作用

RLHF 需要一个奖励模型来评估语言大模型(Actor Model)回答的好坏。该模型通常比被评估的语言大模型小一些(例如 Actor 66B,Reward 350M)。输入为 Prompt + Answer,输出用于打分。

奖励模型最后一层隐藏层的输出维度为 (B, L, D),通过一个 D×1 的全连接层将维度变为 (B, L)。在 L 维度上,第 i 个位置的数据表示从第 i 个位置到最后一个位置输出所能获得的奖励分值的累加和(类似 DQN 中的 Q 值)。这种形式的输出满足了 Critic Model 的输出要求。

# Huggingface 模型返回值是个 list,第 0 位是模型最后输出的 hidden state
hidden_states = transformer_outputs[0]
# v_head 为 Dx1 的全连接网络对最后一维压缩
rewards = self.v_head(hidden_states).squeeze(-1)

对于一个奖励模型来说,目标是给一个句子进行打分。按理说每个句子对应一个分值即可,但目前长度为 L 的句子,奖励模型输出了 L 个值。我们取 L 维度上的最后一个位置的值作为本句话的奖励得分。

3. Pairwise Loss 训练

奖励模型训练优化采用 Pairwise Loss。即同时输入模型关于同一个问题的两个回答,让模型学会这两个句子哪个分高哪个分低。这是因为在给奖励模型进行数据标注的过程中,给同一个问题的不同回答量化具体分值比较难,但对它们进行排序相对简单。

# 同一个 batch 里的句子需要等长,短句后边会被 padding
# [divergence_ind:end_ind] 索引了 padding 前一个位置的输出分值
# chosen_reward 是同一个句子 pair 里分数高的句子,r_truncated_reward 是句子 pair 里分数低的句子
c_truncated_reward = chosen_reward[divergence_ind:end_ind]
r_truncated_reward = rejected_reward[divergence_ind:end_ind]

Pairwise Loss 代码如下,如果给 pair 里好的句子打分高(c_truncated_reward),坏的句子(r_truncated_reward)打分低,loss 就会小:

loss += -torch.log(torch.sigmoid(c_truncated_reward - r_truncated_reward)).mean()

4. 推理过程中的 Critic Model

在训练强化学习的过程中,会用到 Reward Model(Critic Model,两者是同一模型的两个副本)的推理过程。通过调用 forward_value 实现,返回的值中有两种值:

  • values:表示每个位置 i,从第 i 个位置到最后一个位置的奖励累加值,供强化学习过程中 Critic Model 使用。
  • chosen_end_scores:指的是对每个 Prompt + Answer 的打分,供 Reward Model 使用。
def forward_value(...):
    ...
    if return_value_only:
        # (B,L)
        return values
    else:
        ...
        return {
            "values": values,
            # (B,)
            "chosen_end_scores": torch.stack(chosen_end_scores),
        }

强化学习微调

1. 四个核心模型

强化学习微调阶段,会用到 4 个模型:Actor Model、Ref_Model、Reward Model 和 Critic Model。

  • Actor Model 和 Ref_Model 是 RLHF 第一阶段有监督微调模型的两个副本。
  • Reward Model 和 Critic Model 是本文第一部分训练出来的模型的两个副本。

2. Actor Model 的训练模式与推理模式

  • 训练模式:用 Teacher Force 的方式,将整句话输入到模型中,并通过 Mask 机制在保证不泄漏未来的单词情况下预测下一个单词。
  • 推理模式:真正的自回归,预测出下一个单词之后,当作下一步输入再预测下下个单词。

首先用 Actor Model 在 推理模式 下根据 Prompt 生成一个 Answer(Prompt 对应强化学习里的 State,Answer 对应一系列 Action):

# 保证不触发反向传播
with torch.no_grad():
    seq = self.actor_model.module.generate(prompts,
    max_length=max_min_length,
    min_length=max_min_length)

然后利用 Reward Model 和 Critic Model 对输出的 Prompt + Answer 进行打分(PPO 训练时使用的奖励值并不单单是 Reward Model 的输出,还要考虑 KL 散度):

# 奖励模型返回的是个字典,key 为 chosen_end_scores 位置存储数据维度为 (B,),表示对于 prompt+answer 的打分
reward_score = self.reward_model.forward_value(
                seq, attention_mask,
                prompt_length=self.prompt_length)['chosen_end_scores'].detach()
# Critic model 返回的数据维度为 (B,L),L 维度上第 i 个位置代表从 i 位置到最后的累积奖励
# 舍去最后一个位置是因为句子'终止符'无意义
values = self.critic_model.forward_value(
                seq, attention_mask, return_value_only=True).detach()[:, :-1]

3. KL 散度约束

Actor Model 是我们想通过强化学习微调的大模型,但是强化学习过程很容易把模型训练'坏',因此需要另外一个 不会参数更新 的 Ref_Model 来当作标的,别让 Actor Model 跑偏太远。

我们在 训练模式 下,将 Prompt + Answer 分别输入到 Actor Model 和 Ref Model,用 KL 散度来衡量 Ref Model 和 Actor Model 输出的差别。同时将 KL 散度纳入损失函数(KL 散度本质是纳入到奖励值里的,奖励值被纳入到了损失函数),进而来约束 Ref_Model 和 Actor Model 的输出分布别差距太大。

# 得到两个模型的输出
output = self.actor_model(seq, attention_mask=attention_mask)
output_ref = self.ref_model(seq, attention_mask=attention_mask)
logits = output.logits
logits_ref = output_ref.logits
...
return {
...
# 分别得到两个模型在真实单词上的预测概率
'logprobs': gather_log_probs(logits[:, :-1, :], seq[:, 1:]),
'ref_logprobs': gather_log_probs(logits_ref[:, :-1, :], seq[:,1:]),
...
}
...
# 计算 kl 散度,log_probs 里边存的数字经过 log 变化了,因此减法就对应除法
kl_divergence_estimate = -self.kl_ctl * (log_probs - ref_logprobs)

PPO 训练时候的奖励值综合考虑 KL 散度和 Reward 模型的输出,只考虑 Answer 部分的 KL 散度,将 Reward Model 的输出加到 KL 散度 L 维度的最后一个位置上,得到最终的奖励值:

rewards = kl_divergence_estimate
# 只考虑 answer 部分的奖励,不考虑 prompt
start = prompts.shape[1] - 1
# 不考虑 padding 部分
ends = start + action_mask[:, start:].sum(1)
reward_clip = torch.clamp(reward_score, -self.clip_reward_value,
                         self.clip_reward_value)
batch_size = log_probs.shape[0]
# 在 L 维度上,每个位置都有 KL 散度,但是只在最后一个位置加上奖励值
for j in range(batch_size):
    rewards[j, start:ends[j]][-1] += reward_clip[j]

4. PPO 训练核心:Advantage 计算

接下来的内容是 PPO 训练过程的核心,目标是计算 PPO 更新公式里的 Advantage,具体公式如下,V 就是 Critic Model 的输出。

def get_advantages_and_returns(self, values, rewards, start):
    # values(B,L)critic model 输出
    # rewards(B,)reward model 输出
    # start answer 开始的位置
    lastgaelam = 0
    advantages_reversed = []
    length = rewards.size()[-1]
    # 计算每个时刻(序列位置)的 critic model 预测误差
    for t in reversed(range(start, length)):
        nextvalues = values[:, t + 1] if t < length - 1 else 0.0
        # critic model 预测的是 t 到到最后一个时刻的奖励和,所以变化量 delta 可以用如下公式表示
        delta = (rewards[:, t] + self.gamma * nextvalues) - values[:, t]
        # self.gamma=1,self.lam=0.95 是衰减因子,表示之前计算的 delta 对现在影响越来越小
        lastgaelam = delta + self.gamma * self.lam * lastgaelam
        advantages_reversed.append(lastgaelam)
    advantages = torch.stack(advantages_reversed[::-1], dim=1)
    # 后续用来更新 critic model 用
    returns = advantages + values[:, start:]
    return advantages.detach(), returns

5. Actor 与 Critic 模型更新

以上过程,我们已经拿到了 PPO 训练所需要的 Advantage 以及 Actor Model 的输出,现在可以对 Actor Model 进行训练啦。

def actor_loss_fn(self, logprobs, old_logprobs, advantages, mask):
    ## policy gradient loss
    #logprobs, old_logprobs 都是经过 log 变化的单词概率,这里带着 log 做减法就相当于在做概率除法
    log_ratio = (logprobs - old_logprobs) * mask
    # 指数操作去掉 log
    ratio = torch.exp(log_ratio)
    pg_loss1 = -advantages * ratio
    pg_loss2 = -advantages * torch.clamp(ratio, 1.0 - self.cliprange,
                                            1.0 + self.cliprange)
    pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / mask.sum()
    return pg_loss

同样的,我们也要对 Critic Model 进行训练,更新,Loss 就是 MSE Loss。

def critic_loss_fn(self, values, old_values, returns, mask):
    ## value loss
    # 用'老 critic model'的输出约束'新 critic model'不要步子太大,裁剪一下
    values_clipped = torch.clamp(
        values,
        old_values - self.cliprange_value,
        old_values + self.cliprange_value,
    )
    vf_loss1 = (values - returns)**2
    vf_loss2 = (values_clipped - returns)**2
    vf_loss = 0.5 * torch.sum(
        torch.max(vf_loss1, vf_loss2) * mask) / mask.sum()
    return vf_loss

至此,我们的 RLHF 训练流程就结束了。第二部分开头我们说过,共涉及 Actor Model、Ref_Model、Reward Model 和 Critic Model 这四个模型,其实更新参数的模型只有 Actor Model 和 Critic Model。

总结与展望

RLHF 技术使得大模型能够更好地对齐人类价值观,提升回答质量。整个流程主要包括三个步骤:

  1. 有监督微调(SFT):构建高质量指令数据集,训练模型遵循指令。
  2. 奖励模型(RM)训练:基于偏好数据训练 Reward Model,用于评估回答优劣。
  3. 强化学习(PPO):利用 Reward Model 和 Critic Model 指导 Actor Model 更新策略,同时通过 KL 散度防止模型偏离原始分布过远。

在实际应用中,RLHF 面临诸多挑战,如奖励黑客(Reward Hacking)、训练不稳定、计算资源消耗巨大等。未来研究方向包括改进采样效率、探索更高效的 PPO 变体、以及减少对人标注数据的依赖。理解上述代码逻辑有助于开发者更好地调试和优化自己的 RLHF 系统。

建议参考 HuggingFace 官方博客了解更详细的算法原理,并结合实际业务场景调整超参数。

目录

  1. 大模型 RLHF 强化学习微调过程详解
  2. 引言
  3. 奖励(Reward)模型训练
  4. 1. 模型初始化说明
  5. 2. 奖励模型的作用
  6. Huggingface 模型返回值是个 list,第 0 位是模型最后输出的 hidden state
  7. v_head 为 Dx1 的全连接网络对最后一维压缩
  8. 3. Pairwise Loss 训练
  9. 同一个 batch 里的句子需要等长,短句后边会被 padding
  10. [divergenceind:endind] 索引了 padding 前一个位置的输出分值
  11. chosenreward 是同一个句子 pair 里分数高的句子,rtruncated_reward 是句子 pair 里分数低的句子
  12. 4. 推理过程中的 Critic Model
  13. 强化学习微调
  14. 1. 四个核心模型
  15. 2. Actor Model 的训练模式与推理模式
  16. 保证不触发反向传播
  17. 奖励模型返回的是个字典,key 为 chosenendscores 位置存储数据维度为 (B,),表示对于 prompt+answer 的打分
  18. Critic model 返回的数据维度为 (B,L),L 维度上第 i 个位置代表从 i 位置到最后的累积奖励
  19. 舍去最后一个位置是因为句子“终止符”无意义
  20. 3. KL 散度约束
  21. 得到两个模型的输出
  22. 分别得到两个模型在真实单词上的预测概率
  23. 计算 kl 散度,log_probs 里边存的数字经过 log 变化了,因此减法就对应除法
  24. 只考虑 answer 部分的奖励,不考虑 prompt
  25. 不考虑 padding 部分
  26. 在 L 维度上,每个位置都有 KL 散度,但是只在最后一个位置加上奖励值
  27. 4. PPO 训练核心:Advantage 计算
  28. 5. Actor 与 Critic 模型更新
  29. 总结与展望
  • 💰 8折买阿里云服务器限时8折了解详情
  • Magick API 一键接入全球大模型注册送1000万token查看
  • 🤖 一键搭建Deepseek满血版了解详情
  • 一键打造专属AI 智能体了解详情
极客日志微信公众号二维码

微信扫一扫,关注极客日志

微信公众号「极客日志V2」,在微信中扫描左侧二维码关注。展示文案:极客日志V2 zeeklog

更多推荐文章

查看全部
  • PostgreSQL 插件 pgvector 核心功能与版本演进总结
  • AI 代理工具全景:Claude Code 等六大产品深度解析
  • 2025 年蓝桥杯网络安全 CTF 省赛真题详解 (Web/Misc/Crypto/Reverse)
  • AI Agent 技术架构与落地实践指南
  • 鸿蒙金融理财全栈项目:上线运维、用户反馈与持续迭代优化
  • Ubuntu 24.04 GPU 服务器测试系统盘制作
  • 2026 年 AI 漫剧工具排行榜:11 款软件横向对比
  • Spring Boot 与 Leaflet 构建省级旅游口号 WebGIS 可视化平台
  • Java Map 常用方法与核心实现类深度解析
  • 数据中心网络核心架构:Clos 架构详解
  • Seedream 4.0 深度测评:AI 图像生成从个人创作到企业级应用
  • MATLAB 智能代码生成工具 Copilot_AI 功能解析
  • AirSim 无人机仿真入门:实现起飞与降落
  • Stable Diffusion 3.5 架构解析与 FP8 量化落地优化指南
  • Python3.8 图像生成应用:Stable Diffusion 轻量化部署
  • OpenCode 安装 oh-my-opencode 插件教程:AI 辅助一键安装
  • 旧电脑也能跑 AI 员工?OpenClaw 本地部署与插件开发实战
  • HomeAssistant 接入石头扫地机器人配置实战
  • Spring AI 框架入门与核心功能详解
  • 从 JDK 8 到 JDK 21:企业级 Java 升级避坑指南

相关免费在线工具

  • 加密/解密文本

    使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online

  • RSA密钥对生成器

    生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online

  • Mermaid 预览与可视化编辑

    基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online

  • 随机西班牙地址生成器

    随机生成西班牙地址(支持马德里、加泰罗尼亚、安达卢西亚、瓦伦西亚筛选),支持数量快捷选择、显示全部与下载。 在线工具,随机西班牙地址生成器在线工具,online

  • Gemini 图片去水印

    基于开源反向 Alpha 混合算法去除 Gemini/Nano Banana 图片水印,支持批量处理与下载。 在线工具,Gemini 图片去水印在线工具,online

  • curl 转代码

    解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online