PPO 算法的 Python 实现与解析
详细解析了基于 rsl_rl 库的 PPO 强化学习算法在 Python 中的实现细节。内容包括仓库结构概览、核心类初始化、经验回放缓存管理、动作采样与环境反馈处理,以及训练循环中的损失计算与梯度更新。重点阐述了概率比率裁剪、GAE 优势估计、价值函数裁剪及 KL 散度控制等关键机制,适用于四足机器人控制策略的训练与部署。

详细解析了基于 rsl_rl 库的 PPO 强化学习算法在 Python 中的实现细节。内容包括仓库结构概览、核心类初始化、经验回放缓存管理、动作采样与环境反馈处理,以及训练循环中的损失计算与梯度更新。重点阐述了概率比率裁剪、GAE 优势估计、价值函数裁剪及 KL 散度控制等关键机制,适用于四足机器人控制策略的训练与部署。

本文旨在解析 rsl_rl 仓库中 PPO 算法的核心代码与训练逻辑。默认读者拥有一定的强化学习基础和代码基础,对强化学习基础感兴趣的读者可参考相关入门教程。
阅读本系列的前置知识:
python 语法,明白面向对象的封装pytorch 基础使用Policy Gradient、Actor-Critic 和 PPO本期将讲解 rsl_rl 仓库的 PPO 算法的 python 实现。
Unitree RL GYM 是一个开源的 基于 Unitree 机器人强化学习(Reinforcement Learning, RL)控制示例项目,用于训练、测试和部署四足机器人控制策略。该仓库支持多种 Unitree 机器人型号,包括 Go2、H1、H1_2 和 G1。仓库地址
[图片]
关于仓库的安装和环境配置官方的文档已经非常清楚了,这里就不在赘述。 通过下述指令可以快速获取仓库代码。
[图片]
git clone https://github.com/leggedrobotics/rsl_rl.git
cd rsl_rl
git checkout v1.0.2
姑且这里回顾一下 PPO 的核心公式。PPO 的目标函数是:
$$L^{clip}(\theta) = \mathbb{E}[\min(r(\theta)A, \text{clip}(r(\theta), 1-\epsilon, 1+\epsilon)A)]$$
其中:
概率比率(Probability Ratio) $r(\theta) = \frac{\pi_\theta(a|s)}{\pi_{\theta_{old}}(a|s)}$ 它表示:
通过上述公式,PPO 会限制 $r(\theta)$ 的取值范围 $[1-\epsilon, 1+\epsilon]$。如果超过这个范围,梯度就会被裁剪,不再继续增大。
GAE(Generalized Advantage Estimation):
策略熵的公式为:$H(\pi) = -\sum \pi(a|s) \log \pi(a|s)$
拉取完仓库以后,我们可以简单的使用 tree 指令看一下整个项目的结构。
rsl_rl 目录结构
rsl_rl/
├── algorithms/
├── env/
├── modules/
├── runners/
├── storage/
└── utils/
algorithms/ 目录algorithms/
├── __init__.py
└── ppo.py
ppo.py 实现了 PPO(Proximal Policy Optimization) 算法。env/ 目录env/
├── __init__.py
└── vec_env.py
vec_env.py 实现 Vectorized Environment,支持多环境并行训练。modules/ 目录modules/
├── actor_critic.py
├── actor_critic_recurrent.py
runners/ 目录runners/
└── on_policy_runner.py
on_policy_runner.py 负责 按策略采样数据并执行训练循环。storage/ 目录storage/
└── rollout_storage.py
utils/ 目录utils/
└── utils.py
python 实现代码的路径在 algorithms/ppo.py。代码整体在这,别急我们一部分一部分进行分析。
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Copyright (c) 2021 ETH Zurich, Nikita Rudin
import torch
import torch.nn nn
torch.optim optim
rsl_rl.modules ActorCritic
rsl_rl.storage RolloutStorage
:
actor_critic: ActorCritic
():
.device = device
.desired_kl = desired_kl
.schedule = schedule
.learning_rate = learning_rate
.actor_critic = actor_critic
.actor_critic.to(.device)
.storage =
.optimizer = optim.Adam(.actor_critic.parameters(), lr=learning_rate)
.transition = RolloutStorage.Transition()
.clip_param = clip_param
.num_learning_epochs = num_learning_epochs
.num_mini_batches = num_mini_batches
.value_loss_coef = value_loss_coef
.entropy_coef = entropy_coef
.gamma = gamma
.lam = lam
.max_grad_norm = max_grad_norm
.use_clipped_value_loss = use_clipped_value_loss
():
.storage = RolloutStorage(num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, action_shape, .device)
():
.actor_critic.test()
():
.actor_critic.train()
():
.actor_critic.is_recurrent:
.transition.hidden_states = .actor_critic.get_hidden_states()
.transition.actions = .actor_critic.act(obs).detach()
.transition.values = .actor_critic.evaluate(critic_obs).detach()
.transition.actions_log_prob = .actor_critic.get_actions_log_prob(.transition.actions).detach()
.transition.action_mean = .actor_critic.action_mean.detach()
.transition.action_sigma = .actor_critic.action_std.detach()
.transition.observations = obs
.transition.critic_observations = critic_obs
.transition.actions
():
.transition.rewards = rewards.clone()
.transition.dones = dones
infos:
.transition.rewards += .gamma * torch.squeeze(.transition.values * infos[].unsqueeze().to(.device), )
.storage.add_transitions(.transition)
.transition.clear()
.actor_critic.reset(dones)
():
last_values = .actor_critic.evaluate(last_critic_obs).detach()
.storage.compute_returns(last_values, .gamma, .lam)
():
mean_value_loss =
mean_surrogate_loss =
.actor_critic.is_recurrent:
generator = .storage.reccurent_mini_batch_generator(.num_mini_batches, .num_learning_epochs)
:
generator = .storage.mini_batch_generator(.num_mini_batches, .num_learning_epochs)
obs_batch, critic_obs_batch, actions_batch, target_values_batch, advantages_batch, returns_batch, old_actions_log_prob_batch, \
old_mu_batch, old_sigma_batch, hid_states_batch, masks_batch generator:
.actor_critic.act(obs_batch, masks=masks_batch, hidden_states=hid_states_batch[])
actions_log_prob_batch = .actor_critic.get_actions_log_prob(actions_batch)
value_batch = .actor_critic.evaluate(critic_obs_batch, masks=masks_batch, hidden_states=hid_states_batch[])
mu_batch = .actor_critic.action_mean
sigma_batch = .actor_critic.action_std
entropy_batch = .actor_critic.entropy
.desired_kl != .schedule == :
torch.inference_mode():
kl = torch.(
torch.log(sigma_batch / old_sigma_batch + ) + (torch.square(old_sigma_batch) + torch.square(old_mu_batch - mu_batch)) / ( * torch.square(sigma_batch)) - ,
axis=-
)
kl_mean = torch.mean(kl)
kl_mean > .desired_kl * :
.learning_rate = (, .learning_rate / )
kl_mean < .desired_kl / kl_mean > :
.learning_rate = (, .learning_rate * )
param_group .optimizer.param_groups:
param_group[] = .learning_rate
ratio = torch.exp(actions_log_prob_batch - torch.squeeze(old_actions_log_prob_batch))
surrogate = -torch.squeeze(advantages_batch) * ratio
surrogate_clipped = -torch.squeeze(advantages_batch) * torch.clamp(ratio, - .clip_param, + .clip_param)
surrogate_loss = torch.(surrogate, surrogate_clipped).mean()
.use_clipped_value_loss:
value_clipped = target_values_batch + (value_batch - target_values_batch).clamp(-.clip_param, .clip_param)
value_losses = (value_batch - returns_batch).()
value_losses_clipped = (value_clipped - returns_batch).()
value_loss = torch.(value_losses, value_losses_clipped).mean()
:
value_loss = (returns_batch - value_batch).().mean()
loss = surrogate_loss + .value_loss_coef * value_loss - .entropy_coef * entropy_batch.mean()
.optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(.actor_critic.parameters(), .max_grad_norm)
.optimizer.step()
mean_value_loss += value_loss.item()
mean_surrogate_loss += surrogate_loss.item()
num_updates = .num_learning_epochs * .num_mini_batches
mean_value_loss /= num_updates
mean_surrogate_loss /= num_updates
.storage.clear()
mean_value_loss, mean_surrogate_loss
我们来看看这个类初始化部分:
初始化传入了大量 PPO 的超参数:
actor_critic:这里传入的是 PPO 算法必须的 Actor-Critic 网络(这个网络的定义在 modules/actor_critic.py,这个我们后面几期会进行解析)num_learning_epochs=1:每一批 rollout 数据 重复训练多少轮num_mini_batches=1:把 rollout 数据分成多少 mini-batch(以提高样本利用率)clip_param=0.2:这个是 PPO 的 $\epsilon$ 核心参数,用于对策略进行裁切gamma=0.998:奖励折扣因子,用于控制控制 长期奖励权重lam=0.95:GAE 的 $\lambda$ 参数,用于在计算优势函数的时候降低方差value_loss_coef=1.0:损失函数权重,越高越关注 value 网络entropy_coef=0.0:策略熵,鼓励策略保持一定随机性,用于设置额外探索奖励learning_rate=1e-3:神经网络学习率max_grad_norm=1.0:梯度裁剪,大于此值的梯度值会被裁切,防止梯度爆炸use_clipped_value_loss=True:是否使用 Value Clipping,防止 Critic 更新过大schedule="fixed":表示 训练过程中学习率保持固定,不根据 KL 或训练情况动态调整desired_kl=0.01:目标 KL 散度,表示 期望新旧策略之间的 KL 距离大约为 0.01,用于在自适应学习率策略中控制策略更新幅度。device='cpu':运行设备同时还定义了一些变量
self.storage = None # initialized later
self.optimizer = optim.Adam(self.actor_critic.parameters(), lr=learning_rate)
self.transition = RolloutStorage.Transition()
self.storage:经验回放缓存(Rollout Buffer)占位符self.optimizer:Adam 优化器 来更新 Actor-Critic 网络参数self.transition:临时数据结构(step buffer)init_storage()def init_storage(self, num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, action_shape):
self.storage = RolloutStorage(
num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, action_shape, self.device
)
这个函数用于初始化经验回放缓存(Rollout Buffer)机制的 数据缓存(Rollout Buffer)。
定义在 storage/rollout_storage.py,我们之后也会解析
def test_mode(self):
self.actor_critic.test()
def train_mode(self):
self.actor_critic.train()
这两都是启用 actor_critic 的模式。
这个网络的定义在 modules/actor_critic.py,我们之后也会解析
act()这个函数的作用:根据当前状态,计算动作并返回,同时并记录训练数据
def act(self, obs, critic_obs):
if self.actor_critic.is_recurrent:
self.transition.hidden_states = self.actor_critic.get_hidden_states()
# Compute the actions and values
self.transition.actions = self.actor_critic.act(obs).detach()
self.transition.values = self.actor_critic.evaluate(critic_obs).detach()
self.transition.actions_log_prob = self.actor_critic.get_actions_log_prob(self.transition.actions).detach()
self.transition.action_mean = self.actor_critic.action_mean.detach()
self.transition.action_sigma = self.actor_critic.action_std.detach()
# need to record obs and critic_obs before env.step()
self.transition.observations = obs
self.transition.critic_observations = critic_obs
return self.transition.actions
我们一步步看,首先我们来看函数的输入
def act(self, obs, critic_obs):
obs 是 策略网络的输入critic_obs 是 价值网络输入if self.actor_critic.is_recurrent:
self.transition.hidden_states = self.actor_critic.get_hidden_states()
这里判断是否需要使用 RNN / LSTM 网络,如果是,需要保存 hidden_state,否则后面训练无法恢复序列状态。
self.transition.actions = self.actor_critic.act(obs).detach()
self.transition.values = self.actor_critic.evaluate(critic_obs).detach()
Actor 网络 根据 策略网络的输入 来计算动作,.detach() 表示不参与梯度计算,只是进行采样。Critic 网络 计算价值,.detach() 表示不参与梯度计算,只是进行采样。self.transition.actions_log_prob = self.actor_critic.get_actions_log_prob(self.transition.actions).detach()
这里计算动作概率 $\log \pi_\theta(a|s)$,用于后面计算 概率比率(Probability Ratio) 的时候使用
self.transition.action_mean = self.actor_critic.action_mean.detach()
self.transition.action_sigma = self.actor_critic.action_std.detach()
这里保存策略分布,用于 KL 散度计算。其中动作通常来自 高斯分布:$a \sim N(\mu, \sigma)$
# need to record obs and critic_obs before env.step()
self.transition.observations = obs
self.transition.critic_observations = critic_obs
return self.transition.actions
保存状态并返回动作,这里需要在 env.step() 之前保存,否则状态就改变了
process_env_step()这个函数用于处理环境返回结果,并存储数据
def process_env_step(self, rewards, dones, infos):
self.transition.rewards = rewards.clone()
self.transition.dones = dones
# Bootstrapping on time outs
if 'time_outs' in infos:
self.transition.rewards += self.gamma * torch.squeeze(self.transition.values * infos['time_outs'].unsqueeze(1).to(self.device), 1)
# Record the transition
self.storage.add_transitions(self.transition)
self.transition.clear()
self.actor_critic.reset(dones)
self.transition.rewards = rewards.clone()
self.transition.dones = dones
保存奖励 $r_t$,同时保存终止信号
# Bootstrapping on time outs
if 'time_outs' in infos:
self.transition.rewards += self.gamma * torch.squeeze(self.transition.values * infos['time_outs'].unsqueeze(1).to(self.device), 1)
有些 episode 结束不是因为失败,而是达到最大步数,那就不能把未来价值 $V(s)$ 当成 0。
这时候修正 value 的计算 $r = r + \gamma V(s)$
# Record the transition
self.storage.add_transitions(self.transition)
self.transition.clear()
self.actor_critic.reset(dones)
在经验池里头储存数据,每一步的数据包含 $(s,a,r,V,\log prob)$
清空 transition 并重置 RNN
compute_returns计算 PPO 训练需要的奖励 returns 和 优势函数 advantage
def compute_returns(self, last_critic_obs):
last_values = self.actor_critic.evaluate(last_critic_obs).detach()
self.storage.compute_returns(last_values, self.gamma, self.lam)
第一行会计算最后状态价值 $V(s_T)$
第二行就是 GAE 计算,将 多步 TD 误差进行加权平均,从而得到更加稳定的 Advantage 估计。
Return:$R_t = r_t + \gamma R_{t+1}$
优势函数 Advantage(GAE):$\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$,$A_t = \delta_t + \gamma\lambda\delta_{t+1} + (\gamma\lambda)^2\delta_{t+2}+...$
前面的代码主要完成 数据采样与优势计算,而 PPO 的核心训练逻辑全部在 update() 函数中完成。这个函数负责:
我们一步步来看:
mean_value_loss = 0
mean_surrogate_loss = 0
if self.actor_critic.is_recurrent:
generator = self.storage.reccurent_mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
else:
generator = self.storage.mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
mean_value_loss 和 mean_surrogate_loss:统计整个训练过程中的 平均 loss,用于日志打印。mini-batch 迭代器for obs_batch, critic_obs_batch, actions_batch, target_values_batch, advantages_batch, returns_batch, old_actions_log_prob_batch, \
old_mu_batch, old_sigma_batch, hid_states_batch, masks_batch in generator:
然后我们在每个 batch 取出这些变量:
obs_batch:Actor 网络输入critic_obs_batch:Critic 网络输入actions_batch:采样动作target_values_batch:旧价值函数 V(s)advantages_batch:GAE 优势函数returns_batch:目标价值old_actions_log_prob_batch:旧策略概率 $\log \pi_\theta(a|s)$old_mu_batch:旧策略均值 $\mu$old_sigma_batch:旧策略方差 $\sigma$self.actor_critic.act(obs_batch, masks=masks_batch, hidden_states=hid_states_batch[0])
首先调用 act 函数进行 Actor 前向计算,重新计算 当前策略的动作分布 $\pi_\theta(a|s) = \mathcal{N}(\mu_\theta(s), \sigma_\theta(s))$
actions_log_prob_batch = self.actor_critic.get_actions_log_prob(actions_batch)
然后获取当前动作概率 log prob $\log \pi_\theta(a_t|s_t)$,用于一会计算概率比率
value_batch = self.actor_critic.evaluate(critic_obs_batch)
Critic 计算价值函数 value
mu_batch = self.actor_critic.action_mean
sigma_batch = self.actor_critic.action_std
entropy_batch = self.actor_critic.entropy
mu_batch:$\mu$ 策略均值sigma_batch:$\sigma$ 策略方差entropy:策略熵KL 散度控制就干一件事:如果 KL 太大,降低学习率。 这是一种简单的 Trust Region 近似实现,用于防止策略更新过大导致训练不稳定。
# KL
if self.desired_kl != None and self.schedule == 'adaptive':
with torch.inference_mode():
kl = torch.sum(
torch.log(sigma_batch / old_sigma_batch + 1.e-5) + (torch.square(old_sigma_batch) + torch.square(old_mu_batch - mu_batch)) / (2.0 * torch.square(sigma_batch)) - 0.5,
axis=-1
)
kl_mean = torch.mean(kl)
if kl_mean > self.desired_kl * 2.0:
self.learning_rate = max(1e-5, self.learning_rate / 1.5)
elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0:
self.learning_rate = min(1e-2, self.learning_rate * 1.5)
for param_group in self.optimizer.param_groups:
param_group['lr'] = self.learning_rate
其中的 kl 对应 高斯分布 KL 公式 $KL(\pi_{old}||\pi_{new}) = \log\frac{\sigma}{\sigma_{old}} + \frac{\sigma_{old}^2 + (\mu_{old}-\mu)^2}{2\sigma^2} - \frac12$
自适应学习率并更新
| KL | 说明 |
|---|---|
| 太大 | 更新太猛 |
| 太小 | 更新太慢 |
ratio = torch.exp(actions_log_prob_batch - old_actions_log_prob_batch)
这就是 PPO 的核心公式:$r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}$
它表示:
surrogate = -torch.squeeze(advantages_batch) * ratio
surrogate_clipped = -torch.squeeze(advantages_batch) * torch.clamp(ratio, 1.0 - self.clip_param, 1.0 + self.clip_param)
surrogate_loss = torch.max(surrogate, surrogate_clipped).mean()
也就是对应的 $L^{CLIP} = \mathbb{E}[\min(r_t(\theta)A_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon)A_t)]$ 第一行是原始策略梯度,也就是公式中的 $L^{PG} = \mathbb{E}[r_t(\theta)A_t]$ 通过上述公式,PPO 会限制 $r(\theta)$ 的取值范围 $[1-\epsilon, 1+\epsilon]$。如果超过这个范围,梯度就会被裁剪,不再继续增大。 注意:这里加负号是因为 PyTorch 默认最小化 loss
# Value function loss
if self.use_clipped_value_loss:
value_clipped = target_values_batch + (value_batch - target_values_batch).clamp(-self.clip_param, self.clip_param)
value_losses = (value_batch - returns_batch).pow(2)
value_losses_clipped = (value_clipped - returns_batch).pow(2)
value_loss = torch.max(value_losses, value_losses_clipped).mean()
else:
value_loss = (returns_batch - value_batch).pow(2).mean()
loss = surrogate_loss + self.value_loss_coef * value_loss - self.entropy_coef * entropy_batch.mean()
这里计算完整的损失函数公式 $L = L_{policy} + c_1 L_{value} - c_2 H(\pi)$ 其中:
self.value_loss_coef * value_loss:价值网络损失surrogate_loss:策略网络损失self.entropy_coef * entropy_batch.mean():策略熵函数损失这里根据是否使用 PPO value clip 分为两种计算 value_loss 的方式
如果使用 PPO value clip 可以防止 critic 更新过大
# Gradient step
self.optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(self.actor_critic.parameters(), self.max_grad_norm)
self.optimizer.step()
mean_value_loss += value_loss.item()
mean_surrogate_loss += surrogate_loss.item()
剩下的就是
计算平均 loss,清空 rollout buffer
num_updates = self.num_learning_epochs * self.num_mini_batches
mean_value_loss /= num_updates
mean_surrogate_loss /= num_updates
self.storage.clear()
PPO 训练循环:
数学目标:$L = \mathbb{E}[\min(r_t A_t, \text{clip}(r_t, 1-\epsilon, 1+\epsilon)A_t)] + c_1(V-R)^2 - c_2H(\pi)$
[图片]
本期我们对 rsl_rl 仓库中 PPO 算法的 Python 实现进行了全面解析:从初始化超参数、经验回放缓存、动作采样、环境反馈处理,到优势函数计算与策略更新的完整流程。核心机制包括概率比率裁剪 (clip)、GAE 优势估计、价值函数裁剪、防止梯度爆炸、以及可选的自适应学习率和 KL 控制,最终通过组合策略损失、价值损失和策略熵形成完整优化目标,实现对四足机器人稳定且高效的强化学习训练。

微信公众号「极客日志」,在微信中扫描左侧二维码关注。展示文案:极客日志 zeeklog
使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online
将字符串编码和解码为其 Base64 格式表示形式即可。 在线工具,Base64 字符串编码/解码在线工具,online
将字符串、文件或图像转换为其 Base64 表示形式。 在线工具,Base64 文件转换器在线工具,online