跳到主要内容
极客日志极客日志面向AI+效率的开发者社区
首页博客GitHub 精选镜像工具UI配色美学隐私政策关于联系
搜索内容 / 工具 / 仓库 / 镜像...⌘K搜索
注册
博客列表
PythonAI算法

近端策略优化算法 PPO 详解与 PyTorch 实现

综述由AI生成近端策略优化(PPO)是强化学习中兼顾稳定性与效率的策略梯度算法。通过裁剪概率比率限制更新幅度,结合优势函数与熵正则化平衡探索与利用。梳理了 PPO 的数学推导、损失函数构成及与 TRPO、A3C 的对比,并提供了基于 PyTorch 的 Actor-Critic 网络完整实现,涵盖环境交互、经验回放、策略更新等核心模块,适合希望深入理解并落地 PPO 的开发者参考。

NodeJser发布于 2026/3/22更新于 2026/6/916 浏览
近端策略优化算法 PPO 详解与 PyTorch 实现

近端策略优化算法 (PPO) 详解

近端策略优化(Proximal Policy Optimization, PPO)是一种强化学习算法,旨在复杂任务中既保证性能提升,又维持训练的稳定性和高效性。它由 OpenAI 在 2017 年提出,通过限制策略更新幅度,克服了传统策略梯度方法(如 TRPO)的计算复杂性,同时保持了良好的采样效率。

背景与核心思想

在强化学习中,直接优化策略往往会导致不稳定的训练,模型可能因为过大的参数更新而崩溃。PPO 的核心在于引入概率比率来衡量新旧策略的差异,并通过裁剪机制防止策略变化过大。

概率比率

PPO 使用以下比率来评估动作选择的变化:

$$r_t(\theta) = \frac{\pi_\theta(a_t | s_t)}{\pi_{\theta_{\text{old}}}(a_t | s_t)}$$

其中 $\pi_\theta$ 是新策略,$\pi_{\theta_{\text{old}}}$ 是旧策略。这个比率反映了在当前状态下,新策略选择动作 $a_t$ 的概率相对于旧策略的变化程度。

优化目标

为了限制更新幅度,PPO 引入了剪辑目标函数:

$$L^{CLIP}(\theta) = \mathbb{E}_t \left[ \min \left( r_t(\theta) A_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) A_t \right) \right]$$

这里的 $A_t$ 是优势函数,表示某个动作相对于平均表现的优劣。clip 操作将概率比率限制在区间 $[1-\epsilon, 1+\epsilon]$ 内,避免策略更新过于激进。

为什么 PPO 很强?

  1. 简洁性:相比 TRPO,无需复杂的二次优化,实现更简单。
  2. 稳定性:剪辑机制有效防止了策略更新过度导致的性能崩塌。
  3. 高效性:支持对采样数据进行多次迭代训练,提高了样本利用率。

数学推导与损失函数

PPO 的训练过程涉及三个主要部分的损失函数组合:策略损失、值函数损失和熵正则化。

1. 策略损失 (Policy Loss)

即上述的裁剪目标函数,目的是最大化优势函数的期望,同时限制策略分布的变化范围。

2. 值函数优化 (Value Function)

Critic 网络负责估计状态价值 $V(s)$,通过最小化均方误差进行更新:

$$L^{VF}(\theta) = \mathbb{E}_t \left[ \left( V(s_t; \theta) - R_t \right)^2 \right]$$

其中 $R_t$ 是累计回报。准确的值函数估计有助于计算更精确的优势函数。

3. 策略熵正则化 (Entropy Regularization)

为了防止策略过早收敛到局部最优,鼓励探索,引入熵项:

$$L^{ENT}(\theta) = \mathbb{E}t \left[ H(\pi\theta(s_t)) \right]$$

4. 总损失函数

综合以上三项,总损失函数通常定义为:

$$L(\theta) = \mathbb{E}_t \left[ L^{CLIP}(\theta) - c_1 L^{VF}(\theta) + c_2 L^{ENT}(\theta) \right]$$

系数 $c_1$ 和 $c_2$ 用于平衡不同目标的权重。

PyTorch 代码实现

下面展示一个基于 PyTorch 的完整 PPO 实现框架。代码结构清晰,包含 Actor-Critic 网络、经验存储类以及训练循环。

1. Actor-Critic 神经网络

我们使用共享层提取特征,分别输出动作概率和状态值。

import torch
import torch.nn as nn
from torch.distributions import Categorical

class ActorCritic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(ActorCritic, self).__init__()
        # 共享特征提取层
        self.shared_layer = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU()
        )
        # Actor: 输出动作概率分布
        self.actor = nn.Sequential(
            nn.Linear(128, action_dim),
            nn.Softmax(dim=-1)
        )
        # Critic: 输出状态价值
        self.critic = nn.Linear(128, 1)

    def forward(self, state):
        shared = self.shared_layer(state)
        action_probs = self.actor(shared)
        state_value = self.critic(shared)
        return action_probs, state_value

这里 shared_layer 将状态映射到隐层表示,actor 使用 Softmax 确保输出为合法的概率分布,critic 则预测当前状态的预期回报。

2. 经验存储类 (Memory)

PPO 需要批量更新,因此需要一个缓冲区来存储 episode 的数据。

class Memory:
    def __init__(self):
        self.states = []
        self.actions = []
        self.logprobs = []
        self.rewards = []
        self.is_terminals = []

    def clear(self):
        self.states = []
        self.actions = []
        self.logprobs = []
        self.rewards = []
        self.is_terminals = []

3. PPO Agent 初始化

初始化策略网络、优化器以及超参数。

import torch.optim as optim

class PPO:
    def __init__(self, state_dim, action_dim, lr=0.002, gamma=0.99, eps_clip=0.2, K_epochs=4):
        self.policy = ActorCritic(state_dim, action_dim).to(device)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=lr)
        self.policy_old = ActorCritic(state_dim, action_dim).to(device)
        self.policy_old.load_state_dict(self.policy.state_dict())
        self.MseLoss = nn.MSELoss()
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs

注意 policy_old 用于计算概率比率,每次更新后需要同步最新参数。

4. 动作选择

根据当前策略采样动作,并记录对数概率。

def select_action(self, state, memory):
    state = torch.FloatTensor(state).to(device)
    with torch.no_grad():
        action_probs, _ = self.policy_old(state)
        dist = Categorical(action_probs)
        action = dist.sample()
        memory.states.append(state)
        memory.actions.append(action)
        memory.logprobs.append(dist.log_prob(action))
    return action.item()

5. 策略更新

这是 PPO 的核心部分,包括奖励归一化、优势计算和裁剪损失更新。

def update(self, memory):
    old_states = torch.stack(memory.states).to(device).detach()
    old_actions = torch.stack(memory.actions).to(device).detach()
    old_logprobs = torch.stack(memory.logprobs).to(device).detach()

    # 计算折扣回报
    rewards = []
    discounted_reward = 0
    for reward, is_terminal in zip(reversed(memory.rewards), reversed(memory.is_terminals)):
        if is_terminal:
            discounted_reward = 0
        discounted_reward = reward + (self.gamma * discounted_reward)
        rewards.insert(0, discounted_reward)
    
    rewards = torch.tensor(rewards, dtype=torch.float32).to(device)
    # 奖励归一化,加速收敛
    rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-7)

    # 多轮 Epoch 更新
    for _ in range(self.K_epochs):
        action_probs, state_values = self.policy(old_states)
        dist = Categorical(action_probs)
        new_logprobs = dist.log_prob(old_actions)
        entropy = dist.entropy()

        # 计算概率比率
        ratios = torch.exp(new_logprobs - old_logprobs.detach())
        advantages = rewards - state_values.detach().squeeze()

        # 裁剪损失
        surr1 = ratios * advantages
        surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages
        loss_actor = -torch.min(surr1, surr2).mean()

        # 值函数损失
        loss_critic = self.MseLoss(state_values.squeeze(), rewards)

        # 总损失
        loss = loss_actor + 0.5 * loss_critic - 0.01 * entropy.mean()

        # 反向传播
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    # 同步旧策略
    self.policy_old.load_state_dict(self.policy.state_dict())

这里有一个细节需要注意:advantages 的计算依赖于 state_values,如果 Critic 不准,Advantage 也会偏差较大,所以 Critic 的损失权重通常也需要调整。

6. 主程序流程

完整的训练循环如下:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import gym
env = gym.make("CartPole-v1")

state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

ppo = PPO(state_dim, action_dim, lr=0.002, gamma=0.99, eps_clip=0.2, K_epochs=4)
memory = Memory()

for episode in range(1, 1001):
    state = env.reset()
    total_reward = 0
    for t in range(300):
        action = ppo.select_action(state, memory)
        state, reward, done, _ = env.step(action)
        memory.rewards.append(reward)
        memory.is_terminals.append(done)
        total_reward += reward
        if done:
            break
    ppo.update(memory)
    memory.clear()
    print(f"Episode {episode}, Total Reward: {total_reward}")

算法对比总结

特性PPOTRPOA3C
核心思想裁剪概率比率限制更新信任区域约束优化异步多线程并行
优化目标引入剪辑机制KL 散度限制策略梯度
更新方式同步,支持多轮迭代同步,严格约束异步,独立线程
计算复杂度低高 (二次规划)较低
稳定性高极高中等
适用场景通用,主流选择需极高稳定性的控制资源受限或快速实验

PPO 作为 TRPO 的改进版,用简单的裁剪机制替代了复杂的二次优化,显著降低了实现难度,同时保持了良好的稳定性和效率。对于大多数强化学习任务,PPO 都是首选的默认算法。

注意事项

  1. 环境适配:代码示例基于 Gym 的 CartPole 环境,实际项目中可能需要针对特定环境调整网络结构或超参数。
  2. 超参数调优:学习率、折扣因子和裁剪阈值对训练效果影响较大,建议根据具体任务进行网格搜索。
  3. 奖励归一化:在长序列任务中,奖励归一化能有效缓解梯度消失或爆炸问题。
  4. 设备管理:确保 device 变量正确配置以利用 GPU 加速训练。

目录

  1. 近端策略优化算法 (PPO) 详解
  2. 背景与核心思想
  3. 概率比率
  4. 优化目标
  5. 为什么 PPO 很强?
  6. 数学推导与损失函数
  7. 1. 策略损失 (Policy Loss)
  8. 2. 值函数优化 (Value Function)
  9. 3. 策略熵正则化 (Entropy Regularization)
  10. 4. 总损失函数
  11. PyTorch 代码实现
  12. 1. Actor-Critic 神经网络
  13. 2. 经验存储类 (Memory)
  14. 3. PPO Agent 初始化
  15. 4. 动作选择
  16. 5. 策略更新
  17. 6. 主程序流程
  18. 算法对比总结
  19. 注意事项
  • 💰 8折买阿里云服务器限时8折了解详情
  • Magick API 一键接入全球大模型注册送1000万token查看
  • 🤖 一键搭建Deepseek满血版了解详情
  • 一键打造专属AI 智能体了解详情
极客日志微信公众号二维码

微信扫一扫,关注极客日志

微信公众号「极客日志V2」,在微信中扫描左侧二维码关注。展示文案:极客日志V2 zeeklog

更多推荐文章

查看全部
  • Figma 搭配 Claude 与 Weavy AI:从原型到素材的完整工作流
  • OpenClaw:使用 AI 直接操控电脑的工具指南
  • LoRA 指令微调核心原理与实战细节
  • Flutter 集成 BIP340 实现鸿蒙端 Schnorr 签名实战
  • Figma + Claude + Weavy AI:构建高效设计工作流
  • Llama-Factory 微调:Warmup 步数设置与线性增长策略
  • AIGC 浪潮下的 Model Context Protocol (MCP) 详解
  • OpenClaw:实现 AI 直接操控电脑的本地部署指南
  • Python AI Agent 智能体构建指南:从原理到实战
  • OpenClaw 国内 AI 大模型配置教程
  • 机器学习与数据挖掘实战:基于 K-means 和决策树的餐饮企业分析
  • 开源轻小说机翻机器人部署与架构解析
  • Discord中创建机器人的流程
  • MySQL GROUP BY 语句语法及实例演示
  • AI Skill 开发实战:网页内容抓取功能实现
  • .NET 微服务架构:从 WebAPI 到 Docker 实战
  • MolStar 分子可视化工具深度解析
  • Windows 系统安装并编译 llama.cpp 步骤详解
  • DeepSeek-Coder vs Copilot:嵌入式开发场景适配性对比实战
  • Meta 开源大模型 LLaMA2 的本地部署与运行指南

相关免费在线工具

  • 加密/解密文本

    使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online

  • RSA密钥对生成器

    生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online

  • Mermaid 预览与可视化编辑

    基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online

  • 随机西班牙地址生成器

    随机生成西班牙地址(支持马德里、加泰罗尼亚、安达卢西亚、瓦伦西亚筛选),支持数量快捷选择、显示全部与下载。 在线工具,随机西班牙地址生成器在线工具,online

  • Gemini 图片去水印

    基于开源反向 Alpha 混合算法去除 Gemini/Nano Banana 图片水印,支持批量处理与下载。 在线工具,Gemini 图片去水印在线工具,online

  • curl 转代码

    解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online