Proximal Policy Optimization (PPO) Explained in Detail
Proximal Policy Optimization (PPO) is a policy-optimization algorithm introduced by OpenAI in 2017. Its central goal is to improve performance on complex tasks while keeping training stable and efficient. Compared with traditional policy gradient methods, PPO limits how far the policy can move in a single update, which effectively prevents the model from collapsing after an overly large parameter update.
Core Idea and Background
In reinforcement learning, optimizing the policy directly often leads to unstable training. PPO was designed to simplify the training procedure and avoid the computational complexity of TRPO (Trust Region Policy Optimization), while retaining efficient use of sampled data.
Key Mechanism: Probability Ratio and Clipping
The core of PPO is a probability ratio that measures how far the new policy has moved from the old one:
$$ r_t(\theta) = \frac{\pi_\theta(a_t | s_t)}{\pi_{\theta_{\text{old}}}(a_t | s_t)} $$
where $\pi_\theta$ is the new policy and $\pi_{\theta_{\text{old}}}$ is the old policy; the ratio reflects how much the policy has changed. To keep updates from becoming too large, PPO applies a **clipping** operation that restricts the ratio to the interval $[1-\epsilon, 1+\epsilon]$.
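In practice the ratio is usually computed from stored log-probabilities, since exponentiating a difference of log-probs is numerically more stable than dividing probabilities; this is the same pattern used in the full implementation further below. A minimal sketch with purely illustrative values:

import torch

# Log-probabilities of the same actions under the new and old policies (illustrative values)
logprob_new = torch.tensor([-0.9, -1.2, -0.4])
logprob_old = torch.tensor([-1.0, -1.0, -1.0])

# r_t(theta) = pi_theta(a_t|s_t) / pi_theta_old(a_t|s_t), computed in log space
ratios = torch.exp(logprob_new - logprob_old)

# Clipping restricts the ratio to [1 - eps, 1 + eps]
eps_clip = 0.2
clipped = torch.clamp(ratios, 1 - eps_clip, 1 + eps_clip)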
The Optimization Objective
PPO's clipped surrogate loss is defined as:
$$ L^{CLIP}(\theta) = \mathbb{E}_t \left[ \min \left( r_t(\theta) A_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) A_t \right) \right] $$
Here $A_t$ is the advantage function, which measures how much better or worse an action is than the average behavior in that state. Taking the minimum makes the objective a pessimistic bound: while the policy improves, the objective can still increase, but once the ratio leaves $[1-\epsilon, 1+\epsilon]$ the clipped term takes over and removes any incentive to push the update further.
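As a sketch of how this objective is evaluated (the `ratios` and `advantages` tensors here are illustrative placeholders, not real rollout data), note that $L^{CLIP}$ is an objective to maximize, so the corresponding training loss is its negative:

import torch

ratios = torch.tensor([1.3, 0.7, 1.05])        # r_t(theta), illustrative
advantages = torch.tensor([0.5, -0.3, 1.2])    # A_t, illustrative
eps_clip = 0.2

surr1 = ratios * advantages                                            # unclipped term
surr2 = torch.clamp(ratios, 1 - eps_clip, 1 + eps_clip) * advantages   # clipped term
# Element-wise minimum, averaged over timesteps; negate to obtain a loss to minimize
loss_clip = -torch.min(surr1, surr2).mean()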
Mathematical Formulation and the Total Loss
Besides the policy objective, PPO also updates a value function (the critic) and adds an entropy regularizer to encourage exploration.
1. Value Function Optimization
The critic network is trained to minimize the mean squared error between its value predictions and the observed returns:
$$ L^{VF}(\theta) = \mathbb{E}_t \left[ \left( V(s_t; \theta) - R_t \right)^2 \right] $$
where $R_t$ is the cumulative (discounted) return.
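A minimal sketch of this critic loss, where `state_values` stands for the critic's predictions and `returns` for the computed discounted returns (both illustrative placeholders):

import torch
import torch.nn as nn

state_values = torch.tensor([0.8, 0.2, 1.1])   # V(s_t; theta), illustrative
returns = torch.tensor([1.0, 0.0, 1.5])        # R_t, illustrative

# Mean squared error between predicted values and observed returns
value_loss = nn.MSELoss()(state_values, returns)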
2. Policy Entropy Regularization
To encourage exploration and prevent premature convergence to a local optimum, an entropy term is added:
$$ L^{ENT}(\theta) = \mathbb{E} \left[ H(\pi_\theta(s_t)) \right] $$
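For a discrete policy represented by a `Categorical` distribution, the entropy term can be computed directly from the action probabilities (the values below are illustrative):

import torch
from torch.distributions import Categorical

action_probs = torch.tensor([[0.7, 0.2, 0.1],
                             [0.4, 0.3, 0.3]])  # illustrative policy outputs
dist = Categorical(action_probs)
# Average entropy over the batch; higher entropy means a more exploratory policy
entropy_bonus = dist.entropy().mean()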
3. Total Loss Function
The final objective combines the three parts above:
$$ L(\theta) = \mathbb{E}_t \left[ L^{CLIP}(\theta) - c_1 L^{VF}(\theta) + c_2 L^{ENT}(\theta) \right] $$
The coefficients $c_1$ and $c_2$ balance the policy objective, the value-function loss, and the entropy regularizer. Note that $L(\theta)$ is written as an objective to maximize; in code it is common to minimize its negative instead.
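Putting the three terms together with typical coefficient values (e.g. $c_1 = 0.5$, $c_2 = 0.01$; these are common choices, not values fixed by the algorithm), and negating the maximization objective to obtain a loss:

import torch

# Illustrative per-batch loss terms (placeholders for the quantities sketched above)
loss_clip = torch.tensor(0.12)       # -L^CLIP
value_loss = torch.tensor(0.30)      # L^VF
entropy_bonus = torch.tensor(1.05)   # L^ENT

c1, c2 = 0.5, 0.01                   # common coefficient choices
total_loss = loss_clip + c1 * value_loss - c2 * entropy_bonus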
PyTorch Implementation
Below is a complete PPO implementation based on PyTorch. The implementation focuses on the Actor-Critic architecture, experience storage, and performing multiple update epochs on each batch of collected experience.
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import numpy as np
import gym

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class ActorCritic(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=64):  # hidden_dim=64 is a typical choice
        super(ActorCritic, self).__init__()
        # Shared feature extractor
        self.shared_layer = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU()
        )
        # Actor head: probability distribution over discrete actions
        self.actor = nn.Sequential(
            nn.Linear(hidden_dim, action_dim),
            nn.Softmax(dim=-1)
        )
        # Critic head: scalar state-value estimate
        self.critic = nn.Linear(hidden_dim, 1)

    def forward(self, state):
        shared = self.shared_layer(state)
        action_probs = self.actor(shared)
        state_value = self.critic(shared)
        return action_probs, state_value


class Memory:
    """Rollout buffer for one batch of data collected with the old policy."""
    def __init__(self):
        self.states = []
        self.actions = []
        self.logprobs = []
        self.rewards = []
        self.is_terminals = []

    def clear(self):
        self.states = []
        self.actions = []
        self.logprobs = []
        self.rewards = []
        self.is_terminals = []


class PPO:
    def __init__(self, state_dim, action_dim, lr, gamma, eps_clip, K_epochs):
        self.policy = ActorCritic(state_dim, action_dim).to(device)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=lr)
        self.policy_old = ActorCritic(state_dim, action_dim).to(device)
        self.policy_old.load_state_dict(self.policy.state_dict())
        self.MseLoss = nn.MSELoss()
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs

    def select_action(self, state, memory):
        state = torch.FloatTensor(state).to(device)
        with torch.no_grad():
            action_probs, _ = self.policy_old(state)
        dist = Categorical(action_probs)
        action = dist.sample()
        # Store the transition data needed later for the update
        memory.states.append(state)
        memory.actions.append(action)
        memory.logprobs.append(dist.log_prob(action))
        return action.item()

    def update(self, memory):
        old_states = torch.stack(memory.states).to(device).detach()
        old_actions = torch.stack(memory.actions).to(device).detach()
        old_logprobs = torch.stack(memory.logprobs).to(device).detach()

        # Monte Carlo estimate of the discounted returns, computed backwards
        rewards = []
        discounted_reward = 0
        for reward, is_terminal in zip(reversed(memory.rewards), reversed(memory.is_terminals)):
            if is_terminal:
                discounted_reward = 0
            discounted_reward = reward + (self.gamma * discounted_reward)
            rewards.insert(0, discounted_reward)
        rewards = torch.tensor(rewards, dtype=torch.float32).to(device)
        # Normalize the returns to stabilize training
        rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-7)

        # Optimize the policy for K epochs on the same batch of data
        for _ in range(self.K_epochs):
            action_probs, state_values = self.policy(old_states)
            dist = Categorical(action_probs)
            new_logprobs = dist.log_prob(old_actions)
            entropy = dist.entropy()

            # Probability ratio r_t(theta), computed in log space
            ratios = torch.exp(new_logprobs - old_logprobs.detach())
            # Advantage estimate: normalized return minus the (detached) value prediction
            advantages = rewards - state_values.detach().squeeze()

            # Clipped surrogate objective
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages
            loss_actor = -torch.min(surr1, surr2).mean()
            loss_critic = self.MseLoss(state_values.squeeze(), rewards)
            # c1 = 0.5 and c2 = 0.01 are typical coefficient values
            loss = loss_actor + 0.5 * loss_critic - 0.01 * entropy.mean()

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        # Sync the old policy with the updated one
        self.policy_old.load_state_dict(self.policy.state_dict())


# Training loop; CartPole-v1 is used here as an example environment
# (the classic Gym API is assumed: reset() returns the state, step() returns 4 values)
env = gym.make("CartPole-v1")
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
# Typical hyperparameter values for a small discrete-control task
ppo = PPO(state_dim, action_dim, lr=0.002, gamma=0.99, eps_clip=0.2, K_epochs=4)
memory = Memory()

for episode in range(1, 501):          # number of episodes is illustrative
    state = env.reset()
    total_reward = 0
    for t in range(500):               # max steps per episode is illustrative
        action = ppo.select_action(state, memory)
        state, reward, done, _ = env.step(action)
        memory.rewards.append(reward)
        memory.is_terminals.append(done)
        total_reward += reward
        if done:
            break
    # Update the policy at the end of each episode, then clear the buffer
    ppo.update(memory)
    memory.clear()
    print(f"Episode {episode}\tReward: {total_reward}")
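Once training finishes, the learned policy can be evaluated greedily by taking the most probable action instead of sampling. A minimal sketch, assuming the `ppo`, `env`, and `device` objects from the script above:

# Greedy evaluation of the trained policy
state = env.reset()
done = False
eval_reward = 0
while not done:
    with torch.no_grad():
        action_probs, _ = ppo.policy(torch.FloatTensor(state).to(device))
    action = torch.argmax(action_probs).item()   # pick the most probable action
    state, reward, done, _ = env.step(action)
    eval_reward += reward
print(f"Evaluation reward: {eval_reward}")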


