近端策略优化算法 (PPO) 详解
近端策略优化(Proximal Policy Optimization, PPO)是一种强化学习算法,旨在复杂任务中既保证性能提升,又维持训练的稳定性和高效性。它由 OpenAI 在 2017 年提出,通过限制策略更新幅度,克服了传统策略梯度方法(如 TRPO)的计算复杂性,同时保持了良好的采样效率。
背景与核心思想
在强化学习中,直接优化策略往往会导致不稳定的训练,模型可能因为过大的参数更新而崩溃。PPO 的核心在于引入概率比率来衡量新旧策略的差异,并通过裁剪机制防止策略变化过大。
概率比率
PPO 使用以下比率来评估动作选择的变化:
$$r_t(\theta) = \frac{\pi_\theta(a_t | s_t)}{\pi_{\theta_{\text{old}}}(a_t | s_t)}$$
其中 $\pi_\theta$ 是新策略,$\pi_{\theta_{\text{old}}}$ 是旧策略。这个比率反映了在当前状态下,新策略选择动作 $a_t$ 的概率相对于旧策略的变化程度。
优化目标
为了限制更新幅度,PPO 引入了剪辑目标函数:
$$L^{CLIP}(\theta) = \mathbb{E}_t \left[ \min \left( r_t(\theta) A_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) A_t \right) \right]$$
这里的 $A_t$ 是优势函数,表示某个动作相对于平均表现的优劣。clip 操作将概率比率限制在区间 $[1-\epsilon, 1+\epsilon]$ 内,避免策略更新过于激进。
为什么 PPO 很强?
- 简洁性:相比 TRPO,无需复杂的二次优化,实现更简单。
- 稳定性:剪辑机制有效防止了策略更新过度导致的性能崩塌。
- 高效性:支持对采样数据进行多次迭代训练,提高了样本利用率。
数学推导与损失函数
PPO 的训练过程涉及三个主要部分的损失函数组合:策略损失、值函数损失和熵正则化。
1. 策略损失 (Policy Loss)
即上述的裁剪目标函数,目的是最大化优势函数的期望,同时限制策略分布的变化范围。
2. 值函数优化 (Value Function)
Critic 网络负责估计状态价值 $V(s)$,通过最小化均方误差进行更新:
$$L^{VF}(\theta) = \mathbb{E}_t \left[ \left( V(s_t; \theta) - R_t \right)^2 \right]$$
其中 $R_t$ 是累计回报。准确的值函数估计有助于计算更精确的优势函数。
3. 策略熵正则化 (Entropy Regularization)
为了防止策略过早收敛到局部最优,鼓励探索,引入熵项:
$$L^{ENT}(\theta) = \mathbb{E}t \left[ H(\pi\theta(s_t)) \right]$$
4. 总损失函数
综合以上三项,总损失函数通常定义为:
$$L(\theta) = \mathbb{E}_t \left[ L^{CLIP}(\theta) - c_1 L^{VF}(\theta) + c_2 L^{ENT}(\theta) \right]$$
系数 $c_1$ 和 $c_2$ 用于平衡不同目标的权重。
PyTorch 代码实现
下面展示一个基于 PyTorch 的完整 PPO 实现框架。代码结构清晰,包含 Actor-Critic 网络、经验存储类以及训练循环。
1. Actor-Critic 神经网络
我们使用共享层提取特征,分别输出动作概率和状态值。
import torch
import torch.nn as nn
from torch.distributions import Categorical
class ActorCritic(nn.Module):
def __init__(self, state_dim, action_dim):
super(ActorCritic, self).__init__()
# 共享特征提取层
self.shared_layer = nn.Sequential(
nn.Linear(state_dim, 128),
nn.ReLU()
)
# Actor: 输出动作概率分布
self.actor = nn.Sequential(
nn.Linear(128, action_dim),
nn.Softmax(dim=-1)
)
# Critic: 输出状态价值
self.critic = nn.Linear(128, 1)
def forward(self, state):
shared = self.shared_layer(state)
action_probs = self.actor(shared)
state_value = self.critic(shared)
return action_probs, state_value
这里 shared_layer 将状态映射到隐层表示,actor 使用 Softmax 确保输出为合法的概率分布,critic 则预测当前状态的预期回报。
2. 经验存储类 (Memory)
PPO 需要批量更新,因此需要一个缓冲区来存储 episode 的数据。
class Memory:
def __init__(self):
self.states = []
self.actions = []
self.logprobs = []
self.rewards = []
self.is_terminals = []
def clear(self):
self.states = []
self.actions = []
self.logprobs = []
self.rewards = []
self.is_terminals = []
3. PPO Agent 初始化
初始化策略网络、优化器以及超参数。
import torch.optim as optim
class PPO:
def __init__(self, state_dim, action_dim, lr=0.002, gamma=0.99, eps_clip=0.2, K_epochs=4):
self.policy = ActorCritic(state_dim, action_dim).to(device)
self.optimizer = optim.Adam(self.policy.parameters(), lr=lr)
self.policy_old = ActorCritic(state_dim, action_dim).to(device)
self.policy_old.load_state_dict(self.policy.state_dict())
self.MseLoss = nn.MSELoss()
self.gamma = gamma
self.eps_clip = eps_clip
self.K_epochs = K_epochs
注意 policy_old 用于计算概率比率,每次更新后需要同步最新参数。
4. 动作选择
根据当前策略采样动作,并记录对数概率。
def select_action(self, state, memory):
state = torch.FloatTensor(state).to(device)
with torch.no_grad():
action_probs, _ = self.policy_old(state)
dist = Categorical(action_probs)
action = dist.sample()
memory.states.append(state)
memory.actions.append(action)
memory.logprobs.append(dist.log_prob(action))
return action.item()
5. 策略更新
这是 PPO 的核心部分,包括奖励归一化、优势计算和裁剪损失更新。
def update(self, memory):
old_states = torch.stack(memory.states).to(device).detach()
old_actions = torch.stack(memory.actions).to(device).detach()
old_logprobs = torch.stack(memory.logprobs).to(device).detach()
# 计算折扣回报
rewards = []
discounted_reward = 0
for reward, is_terminal in zip(reversed(memory.rewards), reversed(memory.is_terminals)):
if is_terminal:
discounted_reward = 0
discounted_reward = reward + (self.gamma * discounted_reward)
rewards.insert(0, discounted_reward)
rewards = torch.tensor(rewards, dtype=torch.float32).to(device)
# 奖励归一化,加速收敛
rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-7)
# 多轮 Epoch 更新
for _ in range(self.K_epochs):
action_probs, state_values = self.policy(old_states)
dist = Categorical(action_probs)
new_logprobs = dist.log_prob(old_actions)
entropy = dist.entropy()
# 计算概率比率
ratios = torch.exp(new_logprobs - old_logprobs.detach())
advantages = rewards - state_values.detach().squeeze()
# 裁剪损失
surr1 = ratios * advantages
surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages
loss_actor = -torch.min(surr1, surr2).mean()
# 值函数损失
loss_critic = self.MseLoss(state_values.squeeze(), rewards)
# 总损失
loss = loss_actor + 0.5 * loss_critic - 0.01 * entropy.mean()
# 反向传播
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
# 同步旧策略
self.policy_old.load_state_dict(self.policy.state_dict())
这里有一个细节需要注意:advantages 的计算依赖于 state_values,如果 Critic 不准,Advantage 也会偏差较大,所以 Critic 的损失权重通常也需要调整。
6. 主程序流程
完整的训练循环如下:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import gym
env = gym.make("CartPole-v1")
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
ppo = PPO(state_dim, action_dim, lr=0.002, gamma=0.99, eps_clip=0.2, K_epochs=4)
memory = Memory()
for episode in range(1, 1001):
state = env.reset()
total_reward = 0
for t in range(300):
action = ppo.select_action(state, memory)
state, reward, done, _ = env.step(action)
memory.rewards.append(reward)
memory.is_terminals.append(done)
total_reward += reward
if done:
break
ppo.update(memory)
memory.clear()
print(f"Episode {episode}, Total Reward: {total_reward}")
算法对比总结
| 特性 | PPO | TRPO | A3C |
|---|---|---|---|
| 核心思想 | 裁剪概率比率限制更新 | 信任区域约束优化 | 异步多线程并行 |
| 优化目标 | 引入剪辑机制 | KL 散度限制 | 策略梯度 |
| 更新方式 | 同步,支持多轮迭代 | 同步,严格约束 | 异步,独立线程 |
| 计算复杂度 | 低 | 高 (二次规划) | 较低 |
| 稳定性 | 高 | 极高 | 中等 |
| 适用场景 | 通用,主流选择 | 需极高稳定性的控制 | 资源受限或快速实验 |
PPO 作为 TRPO 的改进版,用简单的裁剪机制替代了复杂的二次优化,显著降低了实现难度,同时保持了良好的稳定性和效率。对于大多数强化学习任务,PPO 都是首选的默认算法。
注意事项
- 环境适配:代码示例基于 Gym 的 CartPole 环境,实际项目中可能需要针对特定环境调整网络结构或超参数。
- 超参数调优:学习率、折扣因子和裁剪阈值对训练效果影响较大,建议根据具体任务进行网格搜索。
- 奖励归一化:在长序列任务中,奖励归一化能有效缓解梯度消失或爆炸问题。
- 设备管理:确保
device变量正确配置以利用 GPU 加速训练。


