跳到主要内容
极客日志极客日志
首页博客AI提示词GitHub精选代理工具
搜索
|注册
博客列表
PythonAI算法

近端策略优化算法 (PPO) 详解与代码实现

近端策略优化(PPO)是一种强化学习算法,旨在平衡性能提升与训练稳定性。通过限制策略更新幅度,使用裁剪机制防止策略变化过大,同时利用优势函数引导改进。文章介绍了 PPO 的核心思想、数学推导、损失函数构成及 Actor-Critic 网络结构,并提供了基于 PyTorch 的完整代码实现与环境配置说明。

Ne0发布于 2025/11/13更新于 2026/4/243 浏览
近端策略优化算法 (PPO) 详解与代码实现

近端策略优化算法 (PPO) 详解

PPO 算法介绍

近端策略优化(PPO,Proximal Policy Optimization)是一种强化学习算法,设计的目的是在复杂任务中既保证性能提升,又让算法更稳定和高效。

1. 背景

PPO 是 OpenAI 在 2017 年提出的一种策略优化算法,专注于简化训练过程,克服传统策略梯度方法(如 TRPO)的计算复杂性,同时保证训练效果。

  • 问题:在强化学习中,直接优化策略会导致不稳定的训练,模型可能因为过大的参数更新而崩溃。
  • 解决方案:PPO 通过限制策略更新幅度,使得每一步训练都不会偏离当前策略太多,同时高效利用采样数据。
2. PPO 的核心思想

PPO 的目标是通过以下方式改进策略梯度优化:

  1. 限制策略更新幅度,防止策略过度偏离。

使用优势函数 $$ A(s, a) $$ 来评价某个动作的相对好坏。

优化目标

PPO 的目标函数如下:

$$ L^{CLIP}(\theta) = \mathbb{E}_t \left[ \min \left( r_t(\theta) A_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) A_t \right) \right] $$

其中:

  • 剪辑操作:将 $$ r_t(\theta) $$ 限制在区间 $$ [1-\epsilon, 1+\epsilon] $$ 内,防止策略变化过大。
  • $$ A_t $$:优势函数,通过以下公式计算: $$ A_t = Q(s_t, a_t) - V(s_t) $$ 或者用广义优势估计(GAE)的方法近似。
  • $$ r_t(\theta) $$:概率比率,表示新策略和旧策略在同一状态下选择动作的概率比值。 $$ r_t(\theta) = \frac{\pi_\theta(a_t | s_t)}{\pi_{\theta_{\text{old}}}(a_t | s_t)} $$
3. 为什么 PPO 很强?
  1. 简洁性:比 TRPO(Trust Region Policy Optimization)更简单,无需二次优化。
  2. 稳定性:使用剪辑机制防止策略更新过度。
  3. 高效性:利用采样数据多次训练,提高样本利用率。
4. PPO 的直观类比

假设你是一个篮球教练,训练球员投篮:

  • 如果每次训练完全改变投篮动作,球员可能会表现失常(类似于策略更新过度)。
  • 如果每次训练动作变化太小,可能很难进步(类似于更新不足)。
  • PPO 的剪辑机制就像一个'适度改进'的规则,告诉球员在合理范围内调整投篮动作,同时评估每次投篮的表现是否优于平均水平。

PPO 算法的流程推导及数学公式

PPO(Proximal Policy Optimization)也是一种策略优化算法,它的核心思想是对策略更新进行限制,使训练更加稳定,同时保持效率。以下是其数学公式推导和整体流程:

1. 算法目标

强化学习的核心目标是优化策略 $$ \pi_\theta $$,最大化累积奖励 R。策略梯度方法(如 REINFORCE)直接优化策略,但更新过大可能导致不稳定。为了解决这个问题,PPO 通过引入限制更新幅度的机制,保证策略的稳定性。

目标是优化以下期望:

$$ J(\theta) = \mathbb{E}{\pi\theta} \left[ R \right] $$

通过梯度上升法更新策略。

2. PPO 的概率比率

PPO 在优化过程中引入了概率比率,用于衡量新旧策略的差异:

$$ r_t(\theta) = \frac{\pi_\theta(a_t | s_t)}{\pi_{\theta_{\text{old}}}(a_t | s_t)} $$

  • $$ \pi_{\theta_{\text{old}}}(a_t | s_t) $$:旧策略对动作 $$ a_t $$ 的概率。
  • $$ \pi_\theta(a_t | s_t) $$:新策略对动作 $$ a_t $$ 的概率。
  • 这个比率表示策略变化的程度。

    3. 优化目标

    为了限制策略的更新幅度,PPO 引入了剪辑目标函数:

    $$ L^{CLIP}(\theta) = \mathbb{E}_t \left[ \min \left( r_t(\theta) A_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) A_t \right) \right] $$

    PPO 的目标是找到一个折中:在保持改进的同时防止策略变化过大。

    4. 值函数优化

    PPO 不仅优化策略,还同时更新值函数 $$ V(s_t) $$,通过最小化均方误差来更新:

    $$ L^{VF}(\theta) = \mathbb{E}_t \left[ \left( V(s_t; \theta) - R_t \right)^2 \right] $$

    • $$ R_t = \sum_{k=0}^n \gamma^k r_{t+k} $$:累计回报。
    • $$ V(s_t; \theta) $$:当前状态的值函数预测。

    这个损失函数使得 Critic 能够更准确地估计状态值。

    5. 策略熵正则化

    为了鼓励策略的探索,PPO 引入了熵正则化项:

    $$ L^{ENT}(\theta) = \mathbb{E}t \left[ H(\pi\theta(s_t)) \right] $$

    • 增加熵可以防止策略过早收敛到局部最优。
    • $$ H(\pi_\theta(s_t)) $$:策略的熵,表示策略分布的不确定性。
    6. 总损失函数

    PPO 结合策略损失、值函数损失和熵正则化项,形成总损失函数:

    $$ L(\theta) = \mathbb{E}_t \left[ L^{CLIP}(\theta) - c_1 L^{VF}(\theta) + c_2 L^{ENT}(\theta) \right] $$

    • $$ c_1 $$ 和 $$ c_2 $$:权重系数,用于平衡策略优化、值函数更新和熵正则化。
    7. PPO 算法流程

    PPO 可以简化为以下步骤:

    1. 重复以上步骤:通过多轮迭代,使策略逐步优化,直到收敛。
    2. 值函数更新:用以下损失函数优化值函数 $$ V(s_t) $$: $$ L^{VF}(\theta) = \mathbb{E}_t \left[ \left( V(s_t) - R_t \right)^2 \right] $$
    3. 策略更新:如果更新过大(超出剪辑范围 $$ 1-\epsilon $$ 到 $$ 1+\epsilon $$),会被惩罚。保证更新幅度适中,既不太保守,也不太激进。
    4. 计算概率比率:比较新策略和旧策略对动作 $$ a_t $$ 的选择概率。
    5. 计算优势函数:评估某个动作 $$ a_t $$ 在状态 $$ s_t $$ 下相对于平均表现的优劣(优势函数 $$ A_t $$)。利用 $$ A_t $$ 引导策略改进。
    6. 采样:使用当前策略 $$ \pi_\theta $$ 与环境交互,收集状态 $$ s_t $$、动作 $$ a_t $$、奖励 $$ r_t $$。
    8. PPO 算法的关键

    PPO 的关键公式和目标可以概括如下:

    • 探索与稳定性平衡:通过 $$ L^{ENT}(\theta) $$,鼓励策略探索。
    • 同时优化值函数:通过 $$ L^{VF}(\theta) $$,提高 Critic 的预测精度。
    • 限制更新幅度:通过剪辑函数 $$ \text{clip}() $$,避免策略更新过大导致不稳定。
    • 核心目标:优化策略,使 $$ r_t(\theta) A_t $$ 的改进在限制范围内。

    这种设计使得 PPO 在训练过程中高效且稳定,是目前强化学习中的常用算法之一。

    Python PPO 算法的代码实现

    以下是使用 PyTorch 实现 PPO(Proximal Policy Optimization)算法的完整代码。

    """
    PPO 算法代码示例
    环境:gym
    作者:技术编辑整理
    """
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.distributions import Categorical
    import numpy as np
    import gym
    
    # 设备配置
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Actor-Critic 神经网络定义
    class ActorCritic(nn.Module):
        def __init__(self, state_dim, action_dim):
            super(ActorCritic, self).__init__()
            # 共享层
            self.shared_layer = nn.Sequential(
                nn.Linear(state_dim, 128),
                nn.ReLU()
            )
            # Actor (策略网络)
            self.actor = nn.Sequential(
                nn.Linear(128, action_dim),
                nn.Softmax(dim=-1)
            )
            # Critic (价值网络)
            self.critic = nn.Linear(128, 1)
    
        def forward(self, state):
            shared = self.shared_layer(state)
            action_probs = self.actor(shared)
            state_value = self.critic(shared)
            return action_probs, state_value
    
    # 经验存储类
    class Memory:
        def __init__(self):
            self.states = []
            self.actions = []
            self.logprobs = []
            self.rewards = []
            self.is_terminals = []
    
        def clear(self):
            self.states = []
            self.actions = []
            self.logprobs = []
            self.rewards = []
            self.is_terminals = []
    
    # PPO Agent 类
    class PPO:
        def __init__(self, state_dim, action_dim, lr=0.002, gamma=0.99, eps_clip=0.2, K_epochs=4):
            self.policy = ActorCritic(state_dim, action_dim).to(device)
            self.optimizer = optim.Adam(self.policy.parameters(), lr=lr)
            self.policy_old = ActorCritic(state_dim, action_dim).to(device)
            self.policy_old.load_state_dict(self.policy.state_dict())
            self.MseLoss = nn.MSELoss()
            self.gamma = gamma
            self.eps_clip = eps_clip
            self.K_epochs = K_epochs
    
        def select_action(self, state, memory):
            state = torch.FloatTensor(state).to(device)
            action_probs, _ = self.policy_old(state)
            dist = Categorical(action_probs)
            action = dist.sample()
            memory.states.append(state)
            memory.actions.append(action)
            memory.logprobs.append(dist.log_prob(action))
            return action.item()
    
        def update(self, memory):
            old_states = torch.stack(memory.states).to(device).detach()
            old_actions = torch.stack(memory.actions).to(device).detach()
            old_logprobs = torch.stack(memory.logprobs).to(device).detach()
    
            # 计算折扣奖励
            rewards = []
            discounted_reward = 0
            for reward, is_terminal in zip(reversed(memory.rewards), reversed(memory.is_terminals)):
                if is_terminal:
                    discounted_reward = 0
                discounted_reward = reward + (self.gamma * discounted_reward)
                rewards.insert(0, discounted_reward)
            rewards = torch.tensor(rewards, dtype=torch.float32).to(device)
            rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-7)
    
            # 更新 K 个 epoch
            for _ in range(self.K_epochs):
                action_probs, state_values = self.policy(old_states)
                dist = Categorical(action_probs)
                new_logprobs = dist.log_prob(old_actions)
                entropy = dist.entropy()
    
                ratios = torch.exp(new_logprobs - old_logprobs.detach())
                advantages = rewards - state_values.detach().squeeze()
    
                surr1 = ratios * advantages
                surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages
                loss_actor = -torch.min(surr1, surr2).mean()
                loss_critic = self.MseLoss(state_values.squeeze(), rewards)
                loss = loss_actor + 0.5 * loss_critic - 0.01 * entropy.mean()
    
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
    
            self.policy_old.load_state_dict(self.policy.state_dict())
    
    # 主程序
    if __name__ == "__main__":
        env = gym.make("CartPole-v1")
        state_dim = env.observation_space.shape[0]
        action_dim = env.action_space.n
    
        ppo = PPO(state_dim, action_dim, lr=0.002, gamma=0.99, eps_clip=0.2, K_epochs=4)
        memory = Memory()
    
        max_episodes = 1000
        max_timesteps = 300
    
        for episode in range(1, max_episodes + 1):
            state = env.reset()
            total_reward = 0
            for t in range(max_timesteps):
                action = ppo.select_action(state, memory)
                state, reward, done, _ = env.step(action)
                memory.rewards.append(reward)
                memory.is_terminals.append(done)
                total_reward += reward
                if done:
                    break
            ppo.update(memory)
            memory.clear()
            print(f"Episode {episode}, Total Reward: {total_reward}")
    
        env.close()
    

    代码解释

    • Actor-Critic 网络结构:ActorCritic 模型通过共享层生成动作概率和状态值。
    • PPO 的优化目标:使用裁剪的目标函数限制策略更新幅度,避免过度更新。
    • 内存存储:Memory 用于存储每一轮的状态、动作、奖励和终止标志。
    • 策略更新:通过多个 epoch 更新,计算优势函数和裁剪后的策略梯度。
    • 奖励归一化:使用标准化方法对奖励进行处理,以加快收敛。
    • 训练循环:PPO 从环境中采样,更新策略,打印每一集的总奖励。

    总结

    PPO 的关键是通过限制策略的变化范围(剪辑),让优化更加稳定,同时通过优势函数引导策略改进,充分利用采样数据。这种平衡使得 PPO 成为许多强化学习任务的默认算法。

    PPO 算法、TRPO 算法和 A3C 算法对比

    特性PPO (Proximal Policy Optimization)TRPO (Trust Region Policy Optimization)A3C (Asynchronous Advantage Actor-Critic)
    核心思想使用裁剪的目标函数,限制策略更新幅度,保持稳定性和效率。限制策略更新的步幅(Trust Region),通过二次约束优化确保稳定性。通过异步多线程运行环境并行采样和训练,降低方差并加快收敛速度。
    优化目标函数引入剪辑机制通过 KL 散度限制策略更新优化策略梯度
    更新方式同步更新,支持多轮迭代更新样本数据以提高效率。同步更新,通过优化约束的目标函数严格限制更新步长。异步更新,多个线程独立采样和更新全局模型。
    计算复杂度低,计算简单,使用裁剪避免复杂的二次优化问题。高,涉及二次优化问题,计算复杂,资源需求较大。较低,依赖异步线程并行计算,资源利用率高。
    样本利用率高效,可重复利用采样数据进行多轮梯度更新。高效,严格优化目标,提升了样本效率。较低,因为每个线程独立运行,可能导致数据重复和冗余。
    实现难度中等,使用简单的裁剪方法,适合大多数场景。高,涉及复杂的约束优化和实现细节。较低,直接异步实现,简单易用。
    收敛速度快,因裁剪机制限制更新幅度,能快速稳定收敛。慢,因严格的步幅限制,收敛稳定但需要较多训练迭代。快,因多线程并行采样,能够显著减少训练时间。
    稳定性高,裁剪机制限制过大更新,避免不稳定行为。高,严格限制更新步幅,保证策略稳定改进。较低,异步更新可能导致收敛不稳定(如策略冲突)。
    应用场景广泛使用,适合大规模环境或复杂问题。适合需要极高稳定性的场景,如机器人控制等。适合资源受限的场景或需要快速实验的任务,如强化学习基准测试。
    优点简单易实现,收敛快,稳定性高,是主流强化学习算法。理论支持强,更新步幅严格受控,策略非常稳定。异步更新高效,能够充分利用多线程资源,加速训练。
    缺点理论支持弱于 TRPO,可能过于保守。实现复杂,计算资源需求高,更新速度慢。异步更新可能导致训练不稳定,样本利用率较低。
    论文来源Schulman et al., "Proximal Policy Optimization Algorithms" (2017)Schulman et al., "Trust Region Policy Optimization" (2015)Mnih et al., "Asynchronous Methods for Deep Reinforcement Learning" (2016)
    三种算法的对比总结
    1. PPO 是 TRPO 的改进版:PPO 使用简单的裁剪机制代替了 TRPO 的二次优化,显著降低了实现复杂度,同时保持了良好的稳定性和效率。
    2. A3C 的并行化设计:A3C 的核心是通过多线程异步更新提升效率,但其稳定性略低于 PPO 和 TRPO。
    3. 实用性:PPO 因其简单、稳定、高效的特点,已成为强化学习领域的主流算法;TRPO 更适合需要极高策略稳定性的任务;A3C 在资源受限的场景下表现优异。

    目录

    1. 近端策略优化算法 (PPO) 详解
    2. PPO 算法介绍
    3. 1. 背景
    4. 2. PPO 的核心思想
    5. 优化目标
    6. 3. 为什么 PPO 很强?
    7. 4. PPO 的直观类比
    8. PPO 算法的流程推导及数学公式
    9. 1. 算法目标
    10. 2. PPO 的概率比率
    11. 3. 优化目标
    12. 4. 值函数优化
    13. 5. 策略熵正则化
    14. 6. 总损失函数
    15. 7. PPO 算法流程
    16. 8. PPO 算法的关键
    17. Python PPO 算法的代码实现
    18. 设备配置
    19. Actor-Critic 神经网络定义
    20. 经验存储类
    21. PPO Agent 类
    22. 主程序
    23. 代码解释
    24. 总结
    25. PPO 算法、TRPO 算法和 A3C 算法对比
    26. 三种算法的对比总结
    • 💰 8折买阿里云服务器限时8折了解详情
    • 💰 8折买阿里云服务器限时8折购买
    • 🦞 5分钟部署阿里云小龙虾了解详情
    • 🤖 一键搭建Deepseek满血版了解详情
    • 一键打造专属AI 智能体了解详情
    极客日志微信公众号二维码

    微信扫一扫,关注极客日志

    微信公众号「极客日志」,在微信中扫描左侧二维码关注。展示文案:极客日志 zeeklog

    更多推荐文章

    查看全部
    • OpenClaw 配置指南:接入第三方 API 使用大模型
    • Meta Llama 4 Scout MoE 模型技术架构与性能解析
    • 自然语言处理在法律领域的应用与实战
    • 创建 GitHub 私人仓库并上传本地项目的完整步骤
    • OpenClaw AI 助手框架搭建与配置指南
    • Vue Print Designer 前端可视化打印设计器
    • 2025 前端复盘:框架内卷落幕,AI 重构生态与开发者破局
    • 链表算法实战:相交节点查找与回文结构判断
    • Git LFS 安装教程:Linux、macOS 与 Windows 全平台指南
    • C++ 继承机制详解
    • 使用 Java 实现简单高效的任务调度框架
    • Python 基础:错误与异常处理详解
    • Vitis 烧录 FPGA 失败排查:底层驱动、权限与硬件配置
    • Stable Diffusion 3 发布:20 亿参数 Medium 模型与 MMDiT 架构解析
    • SBUS 协议详解:从原理、硬件接口到代码实现
    • 基于 LSTM 神经网络的学生学习情况分析系统
    • 基于 Java 的百度地图驾车路线规划服务开发指南
    • 几种生成唯一序列号的常用方法
    • 通过官方 API 搭建 QQ 群聊机器人
    • 基于百度天气接口的空气质量 WebGIS 可视化实践——以湖南省为例

    相关免费在线工具

    • 加密/解密文本

      使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online

    • RSA密钥对生成器

      生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online

    • Mermaid 预览与可视化编辑

      基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online

    • 随机西班牙地址生成器

      随机生成西班牙地址(支持马德里、加泰罗尼亚、安达卢西亚、瓦伦西亚筛选),支持数量快捷选择、显示全部与下载。 在线工具,随机西班牙地址生成器在线工具,online

    • Gemini 图片去水印

      基于开源反向 Alpha 混合算法去除 Gemini/Nano Banana 图片水印,支持批量处理与下载。 在线工具,Gemini 图片去水印在线工具,online

    • curl 转代码

      解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online