跳到主要内容
近端策略优化算法 PPO 详解与代码实现 | 极客日志
Python AI 算法
近端策略优化算法 PPO 详解与代码实现 近端策略优化算法 PPO 是一种强化学习算法,旨在通过限制策略更新幅度来保证训练稳定性和效率。核心思想是利用概率比率裁剪机制防止策略变化过大,结合优势函数引导策略改进。相比 TRPO 无需二次优化,相比 A3C 更稳定。文章详细推导了数学公式,包括损失函数、值函数优化及熵正则化,并提供了基于 PyTorch 和 Gym 环境的完整代码实现,涵盖 Actor-Critic 网络结构、经验回放存储及训练循环逻辑,适合希望深入理解 PPO 原理及落地应用的开发者参考。
www 发布于 2026/3/30 更新于 2026/4/23 1 浏览近端策略优化算法 (PPO) 详解
PPO 算法介绍
近端策略优化、PPO(Proximal Policy Optimization)是一种强化学习算法,设计的目的是在复杂任务中既保证性能提升,又让算法更稳定和高效。以下用通俗易懂的方式介绍其核心概念和流程。
1. 背景
PPO 是 OpenAI 在 2017 年提出的一种策略优化算法,专注于简化训练过程,克服传统策略梯度方法(如 TRPO)的计算复杂性,同时保证训练效果。
问题:在强化学习中,直接优化策略会导致不稳定的训练,模型可能因为过大的参数更新而崩溃。
解决方案:PPO 通过限制策略更新幅度,使得每一步训练都不会偏离当前策略太多,同时高效利用采样数据。
2. PPO 的核心思想
PPO 的目标是通过以下方式改进策略梯度优化:
限制策略更新幅度,防止策略过度偏离。
使用优势函数
来评价某个动作的相对好坏。
优化目标
PPO 的目标函数如下:
其中:
剪辑操作
将
限制在区间
内,防止策略变化过大。
它表示新策略和旧策略在同一状态下选择动作的概率比值。
3. 为什么 PPO 很强?
简洁性: 比 TRPO(Trust Region Policy Optimization)更简单,无需二次优化。
稳定性: 使用剪辑机制防止策略更新过度。
高效性: 利用采样数据多次训练,提高样本利用率。
4. PPO 的直观类比
如果每次训练完全改变投篮动作,球员可能会表现失常(类似于策略更新过度)。
如果每次训练动作变化太小,可能很难进步(类似于更新不足)。
PPO 的剪辑机制就像一个'适度改进'的规则,告诉球员在合理范围内调整投篮动作,同时评估每次投篮的表现是否优于平均水平。
PPO 算法的流程推导及数学公式 PPO(Proximal Policy Optimization)也是一种策略优化算法,它的核心思想是对策略更新进行限制,使训练更加稳定,同时保持效率。以下是其数学公式推导和整体流程:
1. 算法目标 ,最大化累积奖励 R。策略梯度方法(如 REINFORCE)直接优化策略,但更新过大可能导致不稳定。为了解决这个问题,PPO 通过引入限制更新幅度的机制,保证策略的稳定性。
2. PPO 的概率比率 PPO 在优化过程中引入了概率比率 ,用于衡量新旧策略的差异:
3. 优化目标 为了限制策略的更新幅度,PPO 引入了剪辑目标函数:
PPO 的目标是找到一个折中:在保持改进的同时防止策略变化过大。
4. 值函数优化 这个损失函数使得 Critic 能够更准确地估计状态值。
5. 策略熵正则化
6. 总损失函数 PPO 结合策略损失、值函数损失和熵正则化项,形成总损失函数:
:权重系数,用于平衡策略优化、值函数更新和熵正则化。
7. PPO 算法流程
重复以上步骤: 通过多轮迭代,使策略逐步优化,直到收敛。
),会被惩罚。
保证更新幅度适中,既不太保守,也不太激进。
8. PPO 算法的关键 这种设计使得 PPO 在训练过程中高效且稳定,是目前强化学习中的常用算法之一。
Python PPO 算法的代码实现 以下是使用 PyTorch 实现 PPO(Proximal Policy Optimization)算法的完整代码
"""PPO 算法代码示例"""
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import numpy as np
import gym
逐行解释 PPO 代码和公式 以下是对实现的 PyTorch PPO 算法代码的详细解释,逐行结合公式解析:
1. Actor-Critic 神经网络
class ActorCritic (nn.Module):
def __init__ (self, state_dim, action_dim ):
super (ActorCritic, self ).__init__()
self .shared_layer = nn.Sequential(
nn.Linear(state_dim, 128 ),
nn.ReLU()
)
self .actor = nn.Sequential(
nn.Linear(128 , action_dim),
nn.Softmax(dim=-1 )
)
self .critic = nn.Linear(128 , 1 )
def forward (self, state ):
shared = self .shared_layer(state)
action_probs = self .actor(shared)
state_value = self .critic(shared)
return action_probs, state_value
shared_layer: 将状态 s 映射到一个隐层表示,使用 ReLU 激活函数。
critic: 输出状态值函数 V(s),表示在状态 s 下的预期累计奖励。
actor: 输出策略
,表示在状态 s 下选择动作 a 的概率分布。使用 Softmax 确保输出是概率。
2. Memory 类
class Memory :
def __init__ (self ):
self .states = []
self .actions = []
self .logprobs = []
self .rewards = []
self .is_terminals = []
def clear (self ):
self .states = []
self .actions = []
self .logprobs = []
self .rewards = []
self .is_terminals = []
states: 状态、actions: 动作、logprobs: 动作的对数概率、rewards: 即时奖励、is_terminals: 是否为终止状态(布尔值)
3. PPO 初始化
class PPO :
def __init__ (self, state_dim, action_dim, lr=0.002 , gamma=0.99 , eps_clip=0.2 , K_epochs=4 ):
self .policy = ActorCritic(state_dim, action_dim).to(device)
self .optimizer = optim.Adam(self .policy.parameters(), lr=lr)
self .policy_old = ActorCritic(state_dim, action_dim).to(device)
self .policy_old.load_state_dict(self .policy.state_dict())
self .MseLoss = nn.MSELoss()
self .gamma = gamma
self .eps_clip = eps_clip
self .K_epochs = K_epochs
policy: 当前策略网络,用于输出动作概率和状态值。
gamma: 折扣因子,用于奖励的时间衰减。
eps_clip: 剪辑阈值
policy_old: 旧策略网络,用于计算概率比率
4. 动作选择 def select_action (self, state, memory ):
state = torch.FloatTensor(state).to(device)
action_probs, _ = self .policy_old(state)
dist = Categorical(action_probs)
action = dist.sample()
memory.states.append(state)
memory.actions.append(action)
memory.logprobs.append(dist.log_prob(action))
return action.item()
log_prob(action): 记录动作的对数概率
dist.sample(): 按照概率分布采样动作
5. 策略更新 def update (self, memory ):
old_states = torch.stack(memory.states).to(device).detach()
old_actions = torch.stack(memory.actions).to(device).detach()
old_logprobs = torch.stack(memory.logprobs).to(device).detach()
rewards = []
discounted_reward = 0
for reward, is_terminal in zip (reversed (memory.rewards), reversed (memory.is_terminals)):
if is_terminal:
discounted_reward = 0
discounted_reward = reward + (self .gamma * discounted_reward)
rewards.insert(0 , discounted_reward)
rewards = torch.tensor(rewards, dtype=torch.float32).to(device)
rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-7 )
6. Surrogate Loss for _ in range (self .K_epochs):
action_probs, state_values = self .policy(old_states)
dist = Categorical(action_probs)
new_logprobs = dist.log_prob(old_actions)
entropy = dist.entropy()
ratios = torch.exp(new_logprobs - old_logprobs.detach())
advantages = rewards - state_values.detach().squeeze()
surr1 = ratios * advantages
surr2 = torch.clamp(ratios, 1 - self .eps_clip, 1 + self .eps_clip) * advantages
loss_actor = -torch.min (surr1, surr2).mean()
loss_critic = self .MseLoss(state_values.squeeze(), rewards)
loss = loss_actor + 0.5 * loss_critic - 0.01 * entropy.mean()
self .optimizer.zero_grad()
loss.backward()
self .optimizer.step()
self .policy_old.load_state_dict(self .policy.state_dict())
7. 主程序 device = torch.device("cuda" if torch.cuda.is_available() else "cpu" )
env = gym.make("CartPole-v1" )
state_dim = env.observation_space.shape[0 ]
action_dim = env.action_space.n
lr = 0.002
gamma = 0.99
eps_clip = 0.2
K_epochs = 4
max_episodes = 1000
max_timesteps = 300
ppo = PPO(state_dim, action_dim, lr, gamma, eps_clip, K_epochs)
memory = Memory()
for episode in range (1 , max_episodes + 1 ):
state = env.reset()
total_reward = 0
for t in range (max_timesteps):
action = ppo.select_action(state, memory)
state, reward, done, _ = env.step(action)
memory.rewards.append(reward)
memory.is_terminals.append(done)
total_reward += reward
if done:
break
ppo.update(memory)
memory.clear()
print (f"Episode {episode} , Total Reward: {total_reward} " )
env.close()
代码说明 ActorCritic 模型通过共享层生成动作概率和状态值。
使用裁剪的目标函数限制策略更新幅度,避免过度更新。
state_values 通过 Critic 网络提供状态价值的估计。
Memory 用于存储每一轮的状态、动作、奖励和终止标志。
通过多个 epoch 更新,计算优势函数 和裁剪后的策略梯度 。
PPO 从环境中采样,更新策略,打印每一集的总奖励。
Python 3.11 .5
torch 2.1 .0
torchvision 0.16 .0
gym 0.26 .2
由于博文主要为了介绍相关算法的原理和应用的方法,缺乏对于实际效果的关注,算法可能在上述环境中的效果不佳或者无法运行,一是算法不适配上述环境,二是算法未调参和优化,三是没有呈现完整的代码,四是等等。上述代码用于了解和学习算法足够了,但若是想直接将上面代码应用于实际项目中,还需要进行修改。
总结 PPO 的关键是通过限制策略的变化范围(剪辑),让优化更加稳定,同时通过优势函数引导策略改进,充分利用采样数据。这种平衡使得 PPO 成为许多强化学习任务的默认算法。
PPO 算法、TRPO 算法 和 A3C 算法对比 以下是 PPO 算法、TRPO 算法 和 A3C 算法的区别分析:
特性 PPO (Proximal Policy Optimization) TRPO (Trust Region Policy Optimization) A3C (Asynchronous Advantage Actor-Critic) 核心思想 使用裁剪的目标函数,限制策略更新幅度,保持稳定性和效率。 限制策略更新的步幅(Trust Region),通过二次约束优化确保稳定性。 通过异步多线程运行环境并行采样和训练,降低方差并加快收敛速度。 优化目标函数 引入剪辑机制 通过 KL 散度限制策略更新 优化策略梯度 更新方式 同步更新,支持多轮迭代更新样本数据以提高效率。 同步更新,通过优化约束的目标函数严格限制更新步长。 异步更新,多个线程独立采样和更新全局模型。 计算复杂度 低,计算简单,使用裁剪避免复杂的二次优化问题。 高,涉及二次优化问题,计算复杂,资源需求较大。 较低,依赖异步线程并行计算,资源利用率高。 样本利用率 高效,可重复利用采样数据进行多轮梯度更新。 高效,严格优化目标,提升了样本效率。 较低,因为每个线程独立运行,可能导致数据重复和冗余。 实现难度 中等,使用简单的裁剪方法,适合大多数场景。 高,涉及复杂的约束优化和实现细节。 较低,直接异步实现,简单易用。 收敛速度 快,因裁剪机制限制更新幅度,能快速稳定收敛。 慢,因严格的步幅限制,收敛稳定但需要较多训练迭代。 快,因多线程并行采样,能够显著减少训练时间。 稳定性 高,裁剪机制限制过大更新,避免不稳定行为。 高,严格限制更新步幅,保证策略稳定改进。 较低,异步更新可能导致收敛不稳定(如策略冲突)。 应用场景 广泛使用,适合大规模环境或复杂问题。 适合需要极高稳定性的场景,如机器人控制等。 适合资源受限的场景或需要快速实验的任务,如强化学习基准测试。 优点 简单易实现,收敛快,稳定性高,是主流强化学习算法。 理论支持强,更新步幅严格受控,策略非常稳定。 异步更新高效,能够充分利用多线程资源,加速训练。 缺点 理论支持弱于 TRPO,可能过于保守。 实现复杂,计算资源需求高,更新速度慢。 异步更新可能导致训练不稳定,样本利用率较低。 论文来源 Schulman et al., "Proximal Policy Optimization Algorithms" (2017) Schulman et al., "Trust Region Policy Optimization" (2015) Mnih et al., "Asynchronous Methods for Deep Reinforcement Learning" (2016)
三种算法的对比总结:
PPO 是 TRPO 的改进版: PPO 使用简单的裁剪机制代替了 TRPO 的二次优化,显著降低了实现复杂度,同时保持了良好的稳定性和效率。
A3C 的并行化设计: A3C 的核心是通过多线程异步更新提升效率,但其稳定性略低于 PPO 和 TRPO。
实用性: PPO 因其简单、稳定、高效的特点,已成为强化学习领域的主流算法;TRPO 更适合需要极高策略稳定性的任务;A3C 在资源受限的场景下表现优异。
相关免费在线工具 加密/解密文本 使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
RSA密钥对生成器 生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
Mermaid 预览与可视化编辑 基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
curl 转代码 解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online
Base64 字符串编码/解码 将字符串编码和解码为其 Base64 格式表示形式即可。 在线工具,Base64 字符串编码/解码在线工具,online
Base64 文件转换器 将字符串、文件或图像转换为其 Base64 表示形式。 在线工具,Base64 文件转换器在线工具,online