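Proximal Policy Optimization (PPO) Explained, with a PyTorch Implementation
Background and Core Idea
Proximal Policy Optimization (PPO) is a reinforcement learning algorithm proposed by OpenAI in 2017. Its goal is to improve performance on complex tasks while keeping training stable and efficient.
Earlier policy-gradient methods such as TRPO are theoretically rigorous, but they are computationally expensive and hard to implement. PPO limits how far each update can move the policy away from the current one, and it uses sampled data efficiently by reusing each batch for several gradient steps. In short, it works like a 'moderate improvement' rule: the policy is adjusted within a reasonable range, so that no single step is large enough to make training collapse, while each batch of data still yields meaningful progress.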
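Mathematical Derivation
Probability Ratio and Advantage Function
At the core of PPO is a probability ratio that measures how much the new policy differs from the old one:
$$r_t(\theta) = \frac{\pi_\theta(a_t | s_t)}{\pi_{\theta_{\text{old}}}(a_t | s_t)}$$
Here $\pi_{\theta_{\text{old}}}$ is the old policy (the one that collected the data) and $\pi_\theta$ is the new policy being optimized. The ratio compares the probability that each policy assigns to the same action $a_t$ in the same state $s_t$: a value above 1 means the new policy makes that action more likely, and a value below 1 means it makes it less likely.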
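As a small illustration of the ratio defined above, the sketch below computes it from log-probabilities, which is how it is usually evaluated in practice to avoid dividing two very small numbers. The function name `probability_ratio` and the probabilities 0.20 and 0.25 are made up for this example.

```python
import math


def probability_ratio(new_logprob: float, old_logprob: float) -> float:
    """Return pi_theta(a|s) / pi_theta_old(a|s) given the two log-probabilities."""
    # exp(log p_new - log p_old) equals p_new / p_old
    return math.exp(new_logprob - old_logprob)


# Hypothetical numbers: the old policy gave the action probability 0.20,
# the updated policy gives it probability 0.25.
ratio = probability_ratio(math.log(0.25), math.log(0.20))
print(f"ratio = {ratio:.2f}")
```

With these numbers the ratio is 1.25, meaning the updated policy is 25% more likely to pick that action than the policy that collected the data.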
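To judge how good an action is relative to the alternatives, we use the advantage function $A_t$:
$$A_t = Q(s_t, a_t) - V(s_t)$$
In practice $A_t$ is not known exactly and is estimated, for example with generalized advantage estimation (GAE). The advantage tells us how much better a particular action is than the policy's average behavior in the current state.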
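One way GAE could be computed for a single rollout is sketched below in plain Python. The function name `compute_gae`, its argument layout, and the defaults `gamma=0.99` and `lam=0.95` are assumptions chosen for illustration, not values taken from this article.

```python
def compute_gae(rewards, values, dones, last_value=0.0, gamma=0.99, lam=0.95):
    """Generalized advantage estimates for one rollout.

    rewards[t], values[t] and dones[t] describe step t. last_value is the
    critic's estimate for the state reached after the final step (use 0.0
    when the rollout ended in a terminal state).
    """
    advantages = [0.0] * len(rewards)
    gae = 0.0
    next_value = last_value
    for t in reversed(range(len(rewards))):
        not_done = 0.0 if dones[t] else 1.0
        # One-step TD error: r_t + gamma * V(s_{t+1}) - V(s_t)
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        # Discounted sum of later TD errors, reset at episode boundaries
        gae = delta + gamma * lam * not_done * gae
        advantages[t] = gae
        next_value = values[t]
    return advantages


# Hypothetical three-step episode with a critic that predicts 0 everywhere
example = compute_gae([1.0, 1.0, 1.0], [0.0, 0.0, 0.0], [False, False, True],
                      gamma=1.0, lam=1.0)
print(example)
```

With `lam=1.0` the estimate reduces to the Monte Carlo return minus the value estimate, and with `lam=0.0` it reduces to the one-step TD error; values in between trade variance against bias.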
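Optimization Objective
To keep policy updates from becoming too large, PPO introduces a clipping mechanism. Its objective is:
$$L^{CLIP}(\theta) = \mathbb{E}_t \left[ \min \left( r_t(\theta) A_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) A_t \right) \right]$$
The clip operation restricts the probability ratio $r_t(\theta)$ to the interval $[1-\epsilon, 1+\epsilon]$, and taking the minimum with the unclipped term makes the objective a pessimistic bound. Once the ratio has moved outside the interval in the direction that would increase the objective ($r_t > 1+\epsilon$ with $A_t > 0$, or $r_t < 1-\epsilon$ with $A_t < 0$), the clipped term is selected, its gradient with respect to $\theta$ is zero, and that sample no longer pushes the policy further. If the ratio has moved in the direction that lowers the objective, the unclipped term is kept, so the update can still correct it. This is what keeps training stable.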
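The four combinations of ratio and advantage sign can be checked with a few lines of plain Python. This is a per-sample sketch of the clipped term only; the function name `clipped_surrogate` and the value `eps=0.2` are assumptions for illustration.

```python
def clipped_surrogate(ratio: float, advantage: float, eps: float = 0.2) -> float:
    """Per-sample value of min(r * A, clip(r, 1 - eps, 1 + eps) * A)."""
    clipped_ratio = max(1.0 - eps, min(ratio, 1.0 + eps))
    return min(ratio * advantage, clipped_ratio * advantage)


# Hypothetical (ratio, advantage) pairs covering the four cases
cases = [(1.5, 1.0), (0.5, 1.0), (1.5, -1.0), (0.5, -1.0)]
for ratio, advantage in cases:
    value = clipped_surrogate(ratio, advantage)
    print(f"r={ratio}, A={advantage:+.1f} -> {value:.2f}")
```

For `r=1.5, A=+1` and `r=0.5, A=-1` the result is capped at `1.2` and `-0.8`: the value no longer depends on the ratio, so such a sample contributes no gradient. For the other two cases the unclipped product is returned, so the objective still penalizes a ratio that moved the wrong way.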
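Total Loss Function
In addition to the policy term, PPO also optimizes a value-function loss and an entropy regularization term:
- Value-function loss: minimizes the mean squared error between the predicted value and the observed return, which helps the critic estimate state values more accurately. $$L^{VF}(\theta) = \mathbb{E}_t \left[ \left( V(s_t; \theta) - R_t \right)^2 \right]$$
- Entropy regularization: encourages exploration and discourages premature convergence to a local optimum. $$L^{ENT}(\theta) = \mathbb{E}_t \left[ H(\pi_\theta(\cdot | s_t)) \right]$$
The three terms are combined into a single objective to be maximized: $$L(\theta) = L^{CLIP}(\theta) - c_1 L^{VF}(\theta) + c_2 L^{ENT}(\theta)$$ where $c_1$ and $c_2$ are positive weighting coefficients. Because optimizers in frameworks such as PyTorch minimize, the loss used in code is $-L(\theta)$.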
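The sign convention is easy to get wrong, so here is a minimal plain-Python sketch of how the three terms might be turned into a loss to minimize. The helper names `entropy` and `ppo_loss`, and the coefficients `c1=0.5` and `c2=0.01`, are assumptions for illustration only.

```python
import math


def entropy(probs):
    """Shannon entropy H(p) = -sum_i p_i * log(p_i) of a discrete distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def ppo_loss(clip_objective, value_loss, mean_entropy, c1=0.5, c2=0.01):
    """Negative of L^CLIP - c1 * L^VF + c2 * L^ENT, suitable for a minimizer."""
    return -clip_objective + c1 * value_loss - c2 * mean_entropy


# A uniform policy over two actions has the largest possible entropy, log(2)
uniform_entropy = entropy([0.5, 0.5])
# Hypothetical batch statistics for the other two terms
loss = ppo_loss(clip_objective=1.0, value_loss=2.0, mean_entropy=uniform_entropy)
print(f"entropy = {uniform_entropy:.4f}, loss = {loss:.4f}")
```

Minimizing this loss raises the clipped surrogate, lowers the value error, and raises the entropy, which matches maximizing $L(\theta)$.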
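PyTorch Implementation
The code below shows how to implement a complete PPO agent in PyTorch. It covers the network definition, a rollout buffer that stores on-policy experience between updates, and the training loop.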
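import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import numpy as np
import gym

# Select the device: use the GPU when available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class ActorCritic(nn.Module):
    """Actor and critic heads on top of one shared hidden layer."""

    def __init__(self, state_dim, action_dim, hidden_dim=64):
        # hidden_dim=64 is an assumed width; adjust it to the task
        super(ActorCritic, self).__init__()
        self.shared_layer = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU()
        )
        # Actor: probabilities over the discrete actions
        self.actor = nn.Sequential(
            nn.Linear(hidden_dim, action_dim),
            nn.Softmax(dim=-1)
        )
        # Critic: a single scalar state value
        self.critic = nn.Linear(hidden_dim, 1)

    def forward(self, state):
        shared = self.shared_layer(state)
        action_probs = self.actor(shared)
        state_value = self.critic(shared)
        return action_probs, state_value


class Memory:
    """Rollout buffer for the on-policy data collected since the last update."""

    def __init__(self):
        self.states = []
        self.actions = []
        self.logprobs = []
        self.rewards = []
        self.is_terminals = []

    def clear(self):
        self.states = []
        self.actions = []
        self.logprobs = []
        self.rewards = []
        self.is_terminals = []


class PPO:
    # The default hyperparameters below are assumed, commonly used values
    def __init__(self, state_dim, action_dim, lr=2e-3, gamma=0.99,
                 eps_clip=0.2, K_epochs=4):
        self.policy = ActorCritic(state_dim, action_dim).to(device)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=lr)
        # policy_old collects data and stays fixed during an update
        self.policy_old = ActorCritic(state_dim, action_dim).to(device)
        self.policy_old.load_state_dict(self.policy.state_dict())
        self.MseLoss = nn.MSELoss()
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs

    def select_action(self, state, memory):
        state = torch.FloatTensor(state).to(device)
        with torch.no_grad():
            action_probs, _ = self.policy_old(state)
        dist = Categorical(action_probs)
        action = dist.sample()
        # Store what the update step needs to form the probability ratio
        memory.states.append(state)
        memory.actions.append(action)
        memory.logprobs.append(dist.log_prob(action))
        return action.item()

    def update(self, memory):
        old_states = torch.stack(memory.states).to(device).detach()
        old_actions = torch.stack(memory.actions).to(device).detach()
        old_logprobs = torch.stack(memory.logprobs).to(device).detach()

        # Discounted Monte Carlo returns, accumulated backwards and reset
        # whenever an episode boundary is crossed
        rewards = []
        discounted_reward = 0
        for reward, is_terminal in zip(reversed(memory.rewards),
                                       reversed(memory.is_terminals)):
            if is_terminal:
                discounted_reward = 0
            discounted_reward = reward + (self.gamma * discounted_reward)
            rewards.insert(0, discounted_reward)

        # Normalize the returns; the small constant avoids division by zero
        rewards = torch.tensor(rewards, dtype=torch.float32).to(device)
        rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-7)

        # Reuse the same batch for K_epochs gradient steps
        for _ in range(self.K_epochs):
            action_probs, state_values = self.policy(old_states)
            state_values = state_values.squeeze(-1)
            dist = Categorical(action_probs)
            new_logprobs = dist.log_prob(old_actions)
            entropy = dist.entropy()

            # r_t(theta), computed in log space
            ratios = torch.exp(new_logprobs - old_logprobs.detach())
            advantages = rewards - state_values.detach()
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages

            loss_actor = -torch.min(surr1, surr2).mean()
            loss_critic = self.MseLoss(state_values, rewards)
            # c1 = 0.5 and c2 = 0.01 are assumed coefficients
            loss = loss_actor + 0.5 * loss_critic - 0.01 * entropy.mean()

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        # Sync the data-collecting policy with the updated weights
        self.policy_old.load_state_dict(self.policy.state_dict())


# The environment name and loop lengths are assumptions; any environment with
# a discrete action space and a flat observation vector should work.
# The loop uses the classic Gym API (gym < 0.26): reset() returns only the
# observation and step() returns four values.
env = gym.make("CartPole-v1")
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
ppo = PPO(state_dim, action_dim)
memory = Memory()
max_episodes = 1000
max_timesteps = 500  # matches the CartPole-v1 episode limit

for episode in range(1, max_episodes + 1):
    state = env.reset()
    total_reward = 0
    for t in range(max_timesteps):
        action = ppo.select_action(state, memory)
        state, reward, done, _ = env.step(action)
        memory.rewards.append(reward)
        memory.is_terminals.append(done)
        total_reward += reward
        if done:
            # Update once per finished episode, then discard the stale data
            ppo.update(memory)
            memory.clear()
            break
    print(f"Episode {episode}, total reward: {total_reward}")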


