近端策略优化算法 (PPO) 详解
近端策略优化(Proximal Policy Optimization,简称 PPO)是一种强化学习算法,旨在复杂任务中既保证性能提升,又维持训练的稳定性和高效性。它通过限制策略更新幅度,避免了传统方法中因参数更新过大导致的崩溃问题。
1. 背景与核心思想
PPO 由 OpenAI 在 2017 年提出,专注于简化训练过程,克服传统策略梯度方法(如 TRPO)的计算复杂性。在强化学习中,直接优化策略往往会导致不稳定的训练,模型可能因为过大的参数更新而崩溃。PPO 的核心在于限制策略更新幅度,使得每一步训练都不会偏离当前策略太多,同时高效利用采样数据。
为什么 PPO 很强?
- 简洁性:比 TRPO 更简单,无需复杂的二次优化。
- 稳定性:使用剪辑机制防止策略更新过度。
- 高效性:利用采样数据多次训练,提高样本利用率。
直观类比
假设你是一个篮球教练,训练球员投篮:
- 如果每次训练完全改变投篮动作,球员可能会表现失常(类似于策略更新过度)。
- 如果每次训练动作变化太小,可能很难进步(类似于更新不足)。
- PPO 的剪辑机制就像一个'适度改进'的规则,告诉球员在合理范围内调整投篮动作,同时评估每次投篮的表现是否优于平均水平。
2. 数学推导与公式
PPO 的核心是对策略更新进行限制,使训练更加稳定,同时保持效率。
概率比率
PPO 引入了概率比率,用于衡量新旧策略的差异:
$$r_t(\theta) = \frac{\pi_\theta(a_t | s_t)}{\pi_{\theta_{\text{old}}}(a_t | s_t)}$$
其中 $\pi_{\theta_{\text{old}}}$ 是旧策略对动作 $a_t$ 的概率,$\pi_\theta$ 是新策略对动作 $a_t$ 的概率。这个比率表示策略变化的程度。
优化目标
为了限制策略的更新幅度,PPO 引入了剪辑目标函数:
$$L^{CLIP}(\theta) = \mathbb{E}_t \left[ \min \left( r_t(\theta) A_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) A_t \right) \right]$$
PPO 的目标是找到一个折中:在保持改进的同时防止策略变化过大。这里的 $A_t$ 是优势函数,通过以下公式计算:
$$A_t = Q(s_t, a_t) - V(s_t)$$
或者用广义优势估计(GAE)的方法近似。
值函数优化
PPO 不仅优化策略,还同时更新值函数 $V(s_t)$,通过最小化均方误差来更新:
$$L^{VF}(\theta) = \mathbb{E}_t \left[ \left( V(s_t; \theta) - R_t \right)^2 \right]$$
其中 $R_t = \sum_{k=0}^n \gamma^k r_{t+k}$ 为累计回报。这个损失函数使得 Critic 能够更准确地估计状态值。
策略熵正则化
为了鼓励策略的探索,PPO 引入了熵正则化项:
$$L^{ENT}(\theta) = \mathbb{E}t \left[ H(\pi\theta(s_t)) \right]$$
增加熵可以防止策略过早收敛到局部最优。$H(\pi_\theta(s_t))$ 表示策略分布的不确定性。
总损失函数
PPO 结合策略损失、值函数损失和熵正则化项,形成总损失函数:
$$L(\theta) = \mathbb{E}_t \left[ L^{CLIP}(\theta) - c_1 L^{VF}(\theta) + c_2 L^{ENT}(\theta) \right]$$
$c_1$ 和 $c_2$ 是权重系数,用于平衡策略优化、值函数更新和熵正则化。
3. Python 代码实现
以下是使用 PyTorch 实现 PPO 算法的完整代码示例。我们基于 Gym 环境进行演示。
import torch
torch.nn nn
torch.optim optim
torch.distributions Categorical
numpy np
gym
device = torch.device( torch.cuda.is_available() )
(nn.Module):
():
(ActorCritic, ).__init__()
.shared_layer = nn.Sequential(
nn.Linear(state_dim, ),
nn.ReLU()
)
.actor = nn.Sequential(
nn.Linear(, action_dim),
nn.Softmax(dim=-)
)
.critic = nn.Linear(, )
():
shared = .shared_layer(state)
action_probs = .actor(shared)
state_value = .critic(shared)
action_probs, state_value
:
():
.states = []
.actions = []
.logprobs = []
.rewards = []
.is_terminals = []
():
.states = []
.actions = []
.logprobs = []
.rewards = []
.is_terminals = []
:
():
.policy = ActorCritic(state_dim, action_dim).to(device)
.optimizer = optim.Adam(.policy.parameters(), lr=lr)
.policy_old = ActorCritic(state_dim, action_dim).to(device)
.policy_old.load_state_dict(.policy.state_dict())
.MseLoss = nn.MSELoss()
.gamma = gamma
.eps_clip = eps_clip
.K_epochs = K_epochs
():
state = torch.FloatTensor(state).to(device)
action_probs, _ = .policy_old(state)
dist = Categorical(action_probs)
action = dist.sample()
memory.states.append(state)
memory.actions.append(action)
memory.logprobs.append(dist.log_prob(action))
action.item()
():
old_states = torch.stack(memory.states).to(device).detach()
old_actions = torch.stack(memory.actions).to(device).detach()
old_logprobs = torch.stack(memory.logprobs).to(device).detach()
rewards = []
discounted_reward =
reward, is_terminal ((memory.rewards), (memory.is_terminals)):
is_terminal:
discounted_reward =
discounted_reward = reward + (.gamma * discounted_reward)
rewards.insert(, discounted_reward)
rewards = torch.tensor(rewards, dtype=torch.float32).to(device)
rewards = (rewards - rewards.mean()) / (rewards.std() + )
_ (.K_epochs):
action_probs, state_values = .policy(old_states)
dist = Categorical(action_probs)
new_logprobs = dist.log_prob(old_actions)
entropy = dist.entropy()
ratios = torch.exp(new_logprobs - old_logprobs.detach())
advantages = rewards - state_values.detach().squeeze()
surr1 = ratios * advantages
surr2 = torch.clamp(ratios, - .eps_clip, + .eps_clip) * advantages
loss_actor = -torch.(surr1, surr2).mean()
loss_critic = .MseLoss(state_values.squeeze(), rewards)
loss = loss_actor + * loss_critic - * entropy.mean()
.optimizer.zero_grad()
loss.backward()
.optimizer.step()
.policy_old.load_state_dict(.policy.state_dict())
__name__ == :
env = gym.make()
state_dim = env.observation_space.shape[]
action_dim = env.action_space.n
lr =
gamma =
eps_clip =
K_epochs =
max_episodes =
max_timesteps =
ppo = PPO(state_dim, action_dim, lr, gamma, eps_clip, K_epochs)
memory = Memory()
episode (, max_episodes + ):
state = env.reset()
total_reward =
t (max_timesteps):
action = ppo.select_action(state, memory)
state, reward, done, _ = env.step(action)
memory.rewards.append(reward)
memory.is_terminals.append(done)
total_reward += reward
done:
ppo.update(memory)
memory.clear()
()
env.close()


