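Proximal Policy Optimization (PPO): Principles and a PyTorch Implementation
Background and Motivation
Proximal Policy Optimization (PPO) is a policy optimization algorithm proposed by OpenAI in 2017. Its goal is to keep improving performance on complex tasks while making training more stable and efficient.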
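In reinforcement learning, optimizing a policy directly often makes training unstable. If a parameter update is too large, the new policy can move too far from the current one and performance can collapse. If the update is too small, learning is slow. PPO limits how far each update can move the policy, so no single step strays too far from the current policy, and it reuses each batch of samples for several epochs of updates. This gives a stability benefit similar to TRPO (Trust Region Policy Optimization) while avoiding TRPO's expensive constrained second-order optimization: PPO needs only ordinary first-order gradient steps.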
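Core Idea
PPO relies on a probability ratio combined with a clipping operation. Together they trade off how fast the policy improves against how stable training remains.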
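Probability Ratio
The probability ratio $r_t(\theta)$ measures how much the new policy differs from the old one:
$$r_t(\theta) = \frac{\pi_\theta(a_t | s_t)}{\pi_{\theta_{\text{old}}}(a_t | s_t)}$$
Here $\pi_{\theta_{\text{old}}}$ is the old policy (the one that collected the data) and $\pi_\theta$ is the new policy being optimized. The ratio shows how much the probability of choosing action $a_t$ in state $s_t$ has changed: $r_t > 1$ means the action has become more likely, and $r_t < 1$ means it has become less likely.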
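In practice the ratio is rarely computed by dividing two probabilities directly. A common approach, sketched below with made-up log-probabilities for three sampled actions, is to store $\log \pi_{\theta_{\text{old}}}(a_t|s_t)$ when the data is collected and exponentiate the difference of log-probabilities, which avoids dividing by very small numbers.

```python
import torch

# Hypothetical log-probabilities of three sampled actions a_t:
# recorded under the old policy at collection time, and under the current policy.
old_logprobs = torch.log(torch.tensor([0.50, 0.20, 0.10]))
new_logprobs = torch.log(torch.tensor([0.60, 0.20, 0.05]))

# r_t(theta) = pi_theta(a_t|s_t) / pi_theta_old(a_t|s_t), computed as
# exp(log pi_theta - log pi_theta_old) for numerical safety.
ratios = torch.exp(new_logprobs - old_logprobs)
print(ratios)  # tensor([1.2000, 1.0000, 0.5000])
```

The first action became more likely, the second is unchanged, and the third became half as likely.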
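Clipped Objective
To prevent overly large policy updates, PPO clips the probability ratio to the interval $[1-\epsilon, 1+\epsilon]$. The optimization objective is:
$$L^{CLIP}(\theta) = \mathbb{E}_t \left[ \min \left( r_t(\theta) A_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) A_t \right) \right]$$
Here $A_t$ is the advantage function. It measures how much better or worse action $a_t$ is than the policy's average behavior in state $s_t$ (commonly $A_t = Q(s_t, a_t) - V(s_t)$). Taking the minimum makes the objective a pessimistic bound. Once the ratio leaves $[1-\epsilon, 1+\epsilon]$ in the direction that would increase the objective (above $1+\epsilon$ when $A_t > 0$, below $1-\epsilon$ when $A_t < 0$), the clipped term is selected and its gradient with respect to $\theta$ is zero, so the update cannot push the policy further in that direction.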
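The effect of the min and the clip can be checked numerically. The sketch below uses a hypothetical helper `clipped_surrogate`, illustrative sample values, and $\epsilon = 0.2$. It evaluates the clipped term for three samples and backpropagates through it. The two samples whose ratio has left the interval in the advantageous direction receive zero gradient, while the in-range sample receives the ordinary gradient $A_t$.

```python
import torch

def clipped_surrogate(ratios, advantages, eps=0.2):
    """Per-sample PPO-Clip term: min(r * A, clip(r, 1 - eps, 1 + eps) * A)."""
    unclipped = ratios * advantages
    clipped = torch.clamp(ratios, 1 - eps, 1 + eps) * advantages
    return torch.min(unclipped, clipped)

# Hypothetical samples:
#   1) A > 0 and r above 1 + eps  -> clipped
#   2) A < 0 and r below 1 - eps  -> clipped
#   3) r inside [1 - eps, 1 + eps] -> ordinary surrogate
ratios = torch.tensor([1.5, 0.5, 1.1], requires_grad=True)
advantages = torch.tensor([2.0, -1.0, 1.0])

obj = clipped_surrogate(ratios, advantages)
obj.sum().backward()
print(obj.detach())  # approximately [2.4, -0.8, 1.1]
print(ratios.grad)   # [0., 0., 1.]: no gradient for the two clipped samples
```

With a positive advantage the objective stops rewarding increases of $r_t$ beyond $1+\epsilon$. With a negative advantage it stops rewarding decreases below $1-\epsilon$.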
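Mathematical Formulation and Loss Functions
PPO training combines three terms: the policy loss, the value function loss, and an entropy regularizer.
1. Policy Loss
This is the clipped objective above. It maximizes expected return while limiting how much the policy changes per update.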
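2. Value Function Loss
The critic network estimates the state value $V(s_t)$ and is updated by minimizing the mean squared error:
$$L^{VF}(\theta) = \mathbb{E}_t \left[ \left( V(s_t; \theta) - R_t \right)^2 \right]$$
Here $R_t$ is the discounted return from step $t$. A more accurate critic gives better advantage estimates, which in turn improve the actor's updates.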
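The return $R_t$ is often a Monte Carlo discounted return, $R_t = r_t + \gamma R_{t+1}$, reset to zero at episode boundaries. Below is a minimal pure-Python sketch. The helper `discounted_returns` is hypothetical; it mirrors the backward loop in the `update` method of the implementation in this article. It uses two short episodes and $\gamma = 0.9$.

```python
def discounted_returns(rewards, is_terminals, gamma=0.99):
    """Compute R_t = r_t + gamma * R_{t+1}, restarting at episode ends."""
    returns = []
    running = 0.0
    # Walk backwards so R_{t+1} is already known when R_t is computed.
    for reward, is_terminal in zip(reversed(rewards), reversed(is_terminals)):
        if is_terminal:
            running = 0.0  # no future reward leaks across an episode boundary
        running = reward + gamma * running
        returns.insert(0, running)
    return returns

# Two hypothetical episodes of length 3 and 2, reward 1 per step.
rets = discounted_returns([1, 1, 1, 1, 1], [False, False, True, False, True], gamma=0.9)
print([round(r, 3) for r in rets])  # [2.71, 1.9, 1.0, 1.9, 1.0]
```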
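3. Entropy Regularization
To encourage exploration and keep the policy from converging prematurely to a local optimum, an entropy term is added:
$$L^{ENT}(\theta) = \mathbb{E}_t \left[ H\left(\pi_\theta(\cdot \mid s_t)\right) \right]$$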
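For a discrete policy, $H(\pi_\theta(\cdot \mid s_t)) = -\sum_a \pi_\theta(a \mid s_t) \log \pi_\theta(a \mid s_t)$, which is largest for a uniform distribution. The sketch below uses `torch.distributions.Categorical.entropy` on two hypothetical four-action distributions. It shows that a near-deterministic policy has much lower entropy, which is why a positive weight on this term discourages premature convergence.

```python
import math
import torch
from torch.distributions import Categorical

# Hypothetical action distributions over 4 discrete actions in a single state.
uniform = Categorical(probs=torch.tensor([0.25, 0.25, 0.25, 0.25]))
peaked = Categorical(probs=torch.tensor([0.97, 0.01, 0.01, 0.01]))

# The uniform policy attains the maximum entropy log(4).
print(uniform.entropy().item(), math.log(4))  # both approximately 1.3863
print(peaked.entropy().item())                # approximately 0.1677
```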
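4. Total Objective
Combining the three terms gives the total objective, which is maximized:
$$L(\theta) = \mathbb{E}_t \left[ L^{CLIP}(\theta) - c_1 L^{VF}(\theta) + c_2 L^{ENT}(\theta) \right]$$
The coefficients $c_1$ and $c_2$ balance the importance of the terms. The value term enters with a minus sign because the value error should shrink, and the entropy term enters with a plus sign because higher entropy is rewarded. Since optimizers minimize, implementations minimize the negative of this objective.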
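Because PyTorch optimizers minimize, the total objective is usually implemented as its negative. The sketch below assumes $c_1 = 0.5$ and $c_2 = 0.01$. These are common choices, not values stated in this article. It also uses a hypothetical helper `ppo_total_loss` that takes already-computed per-sample terms.

```python
import torch

def ppo_total_loss(surrogate, values, returns, entropy, c1=0.5, c2=0.01):
    """Negated PPO objective: minimizing it maximizes L^CLIP - c1 * L^VF + c2 * L^ENT.

    surrogate: per-sample clipped terms min(r * A, clip(r) * A)
    values:    critic predictions V(s_t)
    returns:   targets R_t
    entropy:   per-sample policy entropy H(pi(.|s_t))
    c1, c2:    assumed weighting coefficients
    """
    l_clip = surrogate.mean()
    l_vf = ((values - returns) ** 2).mean()
    l_ent = entropy.mean()
    return -(l_clip - c1 * l_vf + c2 * l_ent)

loss = ppo_total_loss(
    surrogate=torch.tensor([1.0, 0.5]),
    values=torch.tensor([0.0, 1.0]),
    returns=torch.tensor([1.0, 1.0]),
    entropy=torch.tensor([1.0, 1.0]),
)
print(loss.item())  # approximately -0.51, i.e. -(0.75 - 0.5 * 0.5 + 0.01 * 1.0)
```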
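PyTorch Implementation
Below is a complete PPO example in PyTorch for a discrete-action environment. It includes the Actor-Critic network, an experience buffer, and the policy update logic.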
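```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import numpy as np
import gym  # assumes gym >= 0.26 (or gymnasium), whose reset()/step() API is used below

class ActorCritic(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=64):  # hidden size 64 is an assumed value
        super(ActorCritic, self).__init__()
        # Feature extractor shared by the actor and the critic
        self.shared_layer = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU()
        )
        # Actor head: probability distribution over the discrete actions
        self.actor = nn.Sequential(
            nn.Linear(hidden_dim, action_dim),
            nn.Softmax(dim=-1)
        )
        # Critic head: scalar state value V(s)
        self.critic = nn.Linear(hidden_dim, 1)

    def forward(self, state):
        shared = self.shared_layer(state)
        action_probs = self.actor(shared)
        state_value = self.critic(shared)
        return action_probs, state_value

class Memory:
    """Stores one batch of on-policy transitions."""
    def __init__(self):
        self.states = []
        self.actions = []
        self.logprobs = []
        self.rewards = []
        self.is_terminals = []

    def clear(self):
        self.states = []
        self.actions = []
        self.logprobs = []
        self.rewards = []
        self.is_terminals = []

class PPO:
    # Default hyperparameters are common choices (assumed values)
    def __init__(self, state_dim, action_dim, lr=0.002, gamma=0.99, eps_clip=0.2, K_epochs=4):
        self.policy = ActorCritic(state_dim, action_dim).to(device)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=lr)
        # Copy of the policy that collects data and provides pi_theta_old
        self.policy_old = ActorCritic(state_dim, action_dim).to(device)
        self.policy_old.load_state_dict(self.policy.state_dict())
        self.MseLoss = nn.MSELoss()
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs

    def select_action(self, state, memory):
        state = torch.FloatTensor(state).to(device)
        with torch.no_grad():
            action_probs, _ = self.policy_old(state)
            dist = Categorical(action_probs)
            action = dist.sample()
        memory.states.append(state)
        memory.actions.append(action)
        memory.logprobs.append(dist.log_prob(action))
        return action.item()

    def update(self, memory):
        # Turn the stored lists into batched tensors
        old_states = torch.stack(memory.states).to(device).detach()
        old_actions = torch.stack(memory.actions).to(device).detach()
        old_logprobs = torch.stack(memory.logprobs).to(device).detach()

        # Discounted returns, computed backwards and reset at episode ends
        rewards = []
        discounted_reward = 0
        for reward, is_terminal in zip(reversed(memory.rewards), reversed(memory.is_terminals)):
            if is_terminal:
                discounted_reward = 0
            discounted_reward = reward + (self.gamma * discounted_reward)
            rewards.insert(0, discounted_reward)

        # Normalize returns; 1e-5 guards against division by zero
        rewards = torch.tensor(rewards, dtype=torch.float32).to(device)
        rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-5)

        # Reuse the same batch for K_epochs gradient steps
        for _ in range(self.K_epochs):
            action_probs, state_values = self.policy(old_states)
            state_values = state_values.squeeze(-1)
            dist = Categorical(action_probs)
            new_logprobs = dist.log_prob(old_actions)
            entropy = dist.entropy()

            # r_t(theta) = exp(log pi_theta - log pi_theta_old)
            ratios = torch.exp(new_logprobs - old_logprobs)
            # Advantage estimate A_t = R_t - V(s_t), treated as a constant
            advantages = rewards - state_values.detach()

            # Clipped surrogate, negated because the optimizer minimizes
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages
            loss_actor = -torch.min(surr1, surr2).mean()
            loss_critic = self.MseLoss(state_values, rewards)
            # c1 = 0.5 and c2 = 0.01 are assumed coefficient values
            loss = loss_actor + 0.5 * loss_critic - 0.01 * entropy.mean()

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        # The updated policy collects the next batch
        self.policy_old.load_state_dict(self.policy.state_dict())

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
env = gym.make("CartPole-v1")  # environment name is an assumption; any discrete-action env works
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
ppo = PPO(state_dim, action_dim)
memory = Memory()

for episode in range(500):  # episode and step limits are assumed values
    state, _ = env.reset()
    total_reward = 0
    for t in range(1000):
        action = ppo.select_action(state, memory)
        state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        memory.rewards.append(reward)
        memory.is_terminals.append(done)
        total_reward += reward
        if done:
            break
    # One PPO update per episode, then discard the on-policy data
    ppo.update(memory)
    memory.clear()
    print(f"Episode {episode}, Total Reward: {total_reward}")
```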


