Reinforcement Learning for Quadruped Robots: A Python Implementation and Walkthrough of the PPO Algorithm
A detailed walkthrough of the Python implementation of the PPO algorithm in the rsl_rl repository. It covers the repository layout, a review of PPO's core formulas (probability ratio, GAE, entropy), an analysis of the key code modules (initialization, rollout storage, action sampling, environment feedback, return computation), and the central update training loop. The focus is on policy clipping, value clipping, KL-divergence control, and the adaptive learning-rate mechanism, aiming to help developers understand the underlying logic and optimization objectives of RL-based quadruped control.


Prerequisites: Python syntax, an understanding of object-oriented encapsulation, basic PyTorch usage, and familiarity with Policy Gradient, Actor-Critic, and PPO. This article walks through the Python implementation of the PPO algorithm in the rsl_rl repository. Unitree RL GYM is an open-source reinforcement learning (RL) control example project based on Unitree robots, used to train, test, and deploy quadruped control policies. It supports several Unitree robot models, including Go2, H1, H1_2, and G1.

To get the code:
git clone https://github.com/leggedrobotics/rsl_rl.git
cd rsl_rl
git checkout v1.0.2
First, a quick review of PPO's core formula. The PPO objective function is:
$$L^{clip}(\theta)=\mathbb{E}[\min(r(\theta)A,\mathrm{clip}(r(\theta),1-\epsilon,1+\epsilon)A)]$$
where $r(\theta)=\frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}$ is the probability ratio and $A$ is the advantage estimate:
- r ≈ 1: the new and old policies are nearly the same
- r >> 1 or r << 1: the policy has changed too much

Use the `tree` command to inspect the overall project structure. rsl_rl directory layout:

rsl_rl/
├── algorithms/
├── env/
├── modules/
├── runners/
├── storage/
└── utils/
The algorithms/ directory:

algorithms/
├── __init__.py
└── ppo.py
ppo.py implements the PPO (Proximal Policy Optimization) algorithm.

The env/ directory:

env/
├── __init__.py
└── vec_env.py
vec_env.py implements a Vectorized Environment, which supports parallel training across many environments.

The modules/ directory:

modules/
├── actor_critic.py
├── actor_critic_recurrent.py
The runners/ directory:

runners/
└── on_policy_runner.py
on_policy_runner.py samples data with the current policy and runs the training loop.

The storage/ directory:

storage/
└── rollout_storage.py
The utils/ directory:

utils/
└── utils.py
Now for the Python implementation, starting with:

algorithms/
├── __init__.py
└── ppo.py
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Copyright (c) 2021 ETH Zurich, Nikita Rudin
import torch
import torch.nn as nn
import torch.optim as optim

from rsl_rl.modules import ActorCritic
from rsl_rl.storage import RolloutStorage


class PPO:
    actor_critic: ActorCritic

    def __init__(self,
                 actor_critic,
                 num_learning_epochs=1,
                 num_mini_batches=1,
                 clip_param=0.2,
                 gamma=0.998,
                 lam=0.95,
                 value_loss_coef=1.0,
                 entropy_coef=0.0,
                 learning_rate=1e-3,
                 max_grad_norm=1.0,
                 use_clipped_value_loss=True,
                 schedule="fixed",
                 desired_kl=0.01,
                 device='cpu',
                 ):

        self.device = device

        self.desired_kl = desired_kl
        self.schedule = schedule
        self.learning_rate = learning_rate

        # PPO components
        self.actor_critic = actor_critic
        self.actor_critic.to(self.device)
        self.storage = None  # initialized later
        self.optimizer = optim.Adam(self.actor_critic.parameters(), lr=learning_rate)
        self.transition = RolloutStorage.Transition()

        # PPO parameters
        self.clip_param = clip_param
        self.num_learning_epochs = num_learning_epochs
        self.num_mini_batches = num_mini_batches
        self.value_loss_coef = value_loss_coef
        self.entropy_coef = entropy_coef
        self.gamma = gamma
        self.lam = lam
        self.max_grad_norm = max_grad_norm
        self.use_clipped_value_loss = use_clipped_value_loss

    def init_storage(self, num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, action_shape):
        self.storage = RolloutStorage(num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, action_shape, self.device)

    def test_mode(self):
        self.actor_critic.test()

    def train_mode(self):
        self.actor_critic.train()

    def act(self, obs, critic_obs):
        if self.actor_critic.is_recurrent:
            self.transition.hidden_states = self.actor_critic.get_hidden_states()
        # Compute the actions and values
        self.transition.actions = self.actor_critic.act(obs).detach()
        self.transition.values = self.actor_critic.evaluate(critic_obs).detach()
        self.transition.actions_log_prob = self.actor_critic.get_actions_log_prob(self.transition.actions).detach()
        self.transition.action_mean = self.actor_critic.action_mean.detach()
        self.transition.action_sigma = self.actor_critic.action_std.detach()
        # need to record obs and critic_obs before env.step()
        self.transition.observations = obs
        self.transition.critic_observations = critic_obs
        return self.transition.actions

    def process_env_step(self, rewards, dones, infos):
        self.transition.rewards = rewards.clone()
        self.transition.dones = dones
        # Bootstrapping on time outs
        if 'time_outs' in infos:
            self.transition.rewards += self.gamma * torch.squeeze(self.transition.values * infos['time_outs'].unsqueeze(1).to(self.device), 1)

        # Record the transition
        self.storage.add_transitions(self.transition)
        self.transition.clear()
        self.actor_critic.reset(dones)

    def compute_returns(self, last_critic_obs):
        last_values = self.actor_critic.evaluate(last_critic_obs).detach()
        self.storage.compute_returns(last_values, self.gamma, self.lam)

    def update(self):
        mean_value_loss = 0
        mean_surrogate_loss = 0
        if self.actor_critic.is_recurrent:
            generator = self.storage.reccurent_mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
        else:
            generator = self.storage.mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
        for obs_batch, critic_obs_batch, actions_batch, target_values_batch, advantages_batch, returns_batch, old_actions_log_prob_batch, \
                old_mu_batch, old_sigma_batch, hid_states_batch, masks_batch in generator:

            self.actor_critic.act(obs_batch, masks=masks_batch, hidden_states=hid_states_batch[0])
            actions_log_prob_batch = self.actor_critic.get_actions_log_prob(actions_batch)
            value_batch = self.actor_critic.evaluate(critic_obs_batch, masks=masks_batch, hidden_states=hid_states_batch[1])
            mu_batch = self.actor_critic.action_mean
            sigma_batch = self.actor_critic.action_std
            entropy_batch = self.actor_critic.entropy

            # KL
            if self.desired_kl is not None and self.schedule == 'adaptive':
                with torch.inference_mode():
                    kl = torch.sum(
                        torch.log(sigma_batch / old_sigma_batch + 1.e-5) + (torch.square(old_sigma_batch) + torch.square(old_mu_batch - mu_batch)) / (2.0 * torch.square(sigma_batch)) - 0.5, axis=-1)
                    kl_mean = torch.mean(kl)

                    if kl_mean > self.desired_kl * 2.0:
                        self.learning_rate = max(1e-5, self.learning_rate / 1.5)
                    elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0:
                        self.learning_rate = min(1e-2, self.learning_rate * 1.5)

                    for param_group in self.optimizer.param_groups:
                        param_group['lr'] = self.learning_rate

            # Surrogate loss
            ratio = torch.exp(actions_log_prob_batch - torch.squeeze(old_actions_log_prob_batch))
            surrogate = -torch.squeeze(advantages_batch) * ratio
            surrogate_clipped = -torch.squeeze(advantages_batch) * torch.clamp(ratio, 1.0 - self.clip_param, 1.0 + self.clip_param)
            surrogate_loss = torch.max(surrogate, surrogate_clipped).mean()

            # Value function loss
            if self.use_clipped_value_loss:
                value_clipped = target_values_batch + (value_batch - target_values_batch).clamp(-self.clip_param, self.clip_param)
                value_losses = (value_batch - returns_batch).pow(2)
                value_losses_clipped = (value_clipped - returns_batch).pow(2)
                value_loss = torch.max(value_losses, value_losses_clipped).mean()
            else:
                value_loss = (returns_batch - value_batch).pow(2).mean()

            loss = surrogate_loss + self.value_loss_coef * value_loss - self.entropy_coef * entropy_batch.mean()

            # Gradient step
            self.optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(self.actor_critic.parameters(), self.max_grad_norm)
            self.optimizer.step()

            mean_value_loss += value_loss.item()
            mean_surrogate_loss += surrogate_loss.item()

        num_updates = self.num_learning_epochs * self.num_mini_batches
        mean_value_loss /= num_updates
        mean_surrogate_loss /= num_updates
        self.storage.clear()

        return mean_value_loss, mean_surrogate_loss
class PPO:
actor_critic: ActorCritic
def __init__(self, actor_critic, num_learning_epochs=1, num_mini_batches=1, clip_param=0.2, gamma=0.998, lam=0.95, value_loss_coef=1.0, entropy_coef=0.0, learning_rate=1e-3, max_grad_norm=1.0, use_clipped_value_loss=True, schedule="fixed", desired_kl=0.01, device='cpu',):
...
The PPO hyperparameters:
- actor_critic: the Actor-Critic network required by PPO (defined in modules/actor_critic.py, which we will analyze in a later installment)
- num_learning_epochs=1: how many epochs to train on each batch of rollout data
- num_mini_batches=1: how many mini-batches to split the rollout data into (to improve sample efficiency)
- clip_param=0.2: PPO's core $\epsilon$ parameter, used to clip the policy update
- gamma=0.998: the reward discount factor, controlling the weight of long-term rewards
- lam=0.95: the GAE $\lambda$ parameter, used to reduce variance when computing the advantage function
- value_loss_coef=1.0: the value-loss weight; higher values put more emphasis on the value network
- entropy_coef=0.0: the entropy coefficient, which rewards the policy for staying stochastic as an extra exploration bonus
- learning_rate=1e-3: the network learning rate
- max_grad_norm=1.0: gradient clipping; gradient norms above this value are clipped to prevent gradient explosion
- use_clipped_value_loss=True: whether to use value clipping, preventing overly large critic updates
- schedule="fixed": the learning rate stays fixed during training instead of adapting to the KL or training progress
- desired_kl=0.01: the target KL divergence; the desired KL distance between old and new policies is about 0.01, used by the adaptive learning-rate schedule to control the update size
- device='cpu': the device to run on

self.storage = None # initialized later
self.optimizer = optim.Adam(self.actor_critic.parameters(), lr=learning_rate)
self.transition = RolloutStorage.Transition()
- self.storage: placeholder for the rollout buffer (experience storage)
- self.optimizer: the Adam optimizer that updates the Actor-Critic network parameters
- self.transition: a temporary per-step data structure (step buffer)

init_storage():

def init_storage(self, num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, action_shape):
    self.storage = RolloutStorage(
        num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, action_shape, self.device
    )
The RolloutStorage class lives in storage/rollout_storage.py, which we will also analyze later.

def test_mode(self):
    self.actor_critic.test()

def train_mode(self):
    self.actor_critic.train()
These switch the mode of actor_critic, which is defined in modules/actor_critic.py (also covered in a later installment).

act():

def act(self, obs, critic_obs):
    if self.actor_critic.is_recurrent:
        self.transition.hidden_states = self.actor_critic.get_hidden_states()
    # Compute the actions and values
    self.transition.actions = self.actor_critic.act(obs).detach()
    self.transition.values = self.actor_critic.evaluate(critic_obs).detach()
    self.transition.actions_log_prob = self.actor_critic.get_actions_log_prob(self.transition.actions).detach()
    self.transition.action_mean = self.actor_critic.action_mean.detach()
    self.transition.action_sigma = self.actor_critic.action_std.detach()
    # need to record obs and critic_obs before env.step()
    self.transition.observations = obs
    self.transition.critic_observations = critic_obs
    return self.transition.actions
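Before stepping through act() line by line, it helps to see what a diagonal-Gaussian log-probability boils down to. This is a plain-Python sketch under the assumption that the policy is an independent Normal per action dimension (as the action_mean/action_std bookkeeping suggests); the numbers are made up:

```python
import math

def gaussian_log_prob(action, mu, sigma):
    """Log-density of a diagonal Gaussian policy: sum over action dimensions
    of -0.5*((a-mu)/sigma)^2 - log(sigma) - 0.5*log(2*pi)."""
    total = 0.0
    for a, m, s in zip(action, mu, sigma):
        total += -0.5 * ((a - m) / s) ** 2 - math.log(s) - 0.5 * math.log(2 * math.pi)
    return total

# Made-up 2-D action, policy mean, and policy std
log_prob = gaussian_log_prob([0.1, -0.2], [0.0, 0.0], [0.5, 0.5])
```

In the real code this is done by get_actions_log_prob on batched tensors; the sketch only shows the per-sample arithmetic.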
def act(self, obs, critic_obs):
- obs: the input to the policy network
- critic_obs: the input to the value network

if self.actor_critic.is_recurrent:
    self.transition.hidden_states = self.actor_critic.get_hidden_states()
For a recurrent policy, the hidden state must be recorded here; otherwise the sequence state cannot be restored during training.

self.transition.actions = self.actor_critic.act(obs).detach()
self.transition.values = self.actor_critic.evaluate(critic_obs).detach()
The Actor network computes the action from the policy input; .detach() excludes the result from gradient computation, since this is pure sampling. Likewise, the Critic network computes the value estimate, also detached.

self.transition.actions_log_prob = self.actor_critic.get_actions_log_prob(self.transition.actions).detach()
The action log-probability is stored for later use when computing the probability ratio.

self.transition.action_mean = self.actor_critic.action_mean.detach()
self.transition.action_sigma = self.actor_critic.action_std.detach()
# need to record obs and critic_obs before env.step()
self.transition.observations = obs
self.transition.critic_observations = critic_obs
return self.transition.actions
The observations must be recorded before env.step(); otherwise the state will already have changed.

process_env_step():

def process_env_step(self, rewards, dones, infos):
    self.transition.rewards = rewards.clone()
    self.transition.dones = dones
    # Bootstrapping on time outs
    if 'time_outs' in infos:
        self.transition.rewards += self.gamma * torch.squeeze(self.transition.values * infos['time_outs'].unsqueeze(1).to(self.device), 1)

    # Record the transition
    self.storage.add_transitions(self.transition)
    self.transition.clear()
    self.actor_critic.reset(dones)
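The time-out bootstrapping line is easier to read with concrete numbers. A plain-Python sketch (the real code does the same thing with batched tensors; the values here are made up):

```python
gamma = 0.998

# One entry per environment: step reward, value estimate V(s), and a
# 1.0/0.0 time-out flag (infos['time_outs'] in the real code)
rewards   = [1.0, 1.0, 1.0]
values    = [5.0, 5.0, 5.0]
time_outs = [0.0, 1.0, 0.0]

# r <- r + gamma * V(s), but only where the episode was cut off by a time
# limit rather than a real failure
boot = [r + gamma * v * t for r, v, t in zip(rewards, values, time_outs)]
```

Only the second environment was truncated, so only its reward is augmented by the discounted value estimate.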
self.transition.rewards = rewards.clone()
self.transition.dones = dones
# Bootstrapping on time outs
if 'time_outs' in infos:
    self.transition.rewards += self.gamma * torch.squeeze(self.transition.values * infos['time_outs'].unsqueeze(1).to(self.device), 1)
This is the time-out bootstrap: when an episode is cut off by a time limit rather than a true failure, the value estimate is folded back into the reward, $r \leftarrow r + \gamma V(s)$.

# Record the transition
self.storage.add_transitions(self.transition)
self.transition.clear()
self.actor_critic.reset(dones)
These lines store the transition, clear the step buffer, and reset the RNN state of environments that finished.

compute_returns() computes the returns and the advantage estimates:

def compute_returns(self, last_critic_obs):
    last_values = self.actor_critic.evaluate(last_critic_obs).detach()
    self.storage.compute_returns(last_values, self.gamma, self.lam)
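compute_returns delegates the actual work to RolloutStorage. What that GAE pass does can be sketched in plain Python (a simplified, loop-based version with made-up rewards and values; the real implementation is batched and also handles done flags):

```python
def gae(rewards, values, last_value, gamma=0.998, lam=0.95):
    """Backward GAE pass: delta_t = r_t + gamma*V(s_{t+1}) - V(s_t),
    A_t = delta_t + gamma*lam*A_{t+1}, and R_t = A_t + V(s_t)."""
    T = len(rewards)
    advantages = [0.0] * T
    next_value, next_adv = last_value, 0.0
    for t in reversed(range(T)):
        delta = rewards[t] + gamma * next_value - values[t]
        next_adv = delta + gamma * lam * next_adv
        advantages[t] = next_adv
        next_value = values[t]
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns

# Made-up 3-step rollout with constant reward and value estimates
adv, ret = gae([1.0, 1.0, 1.0], [0.5, 0.5, 0.5], last_value=0.5)
```

Earlier steps accumulate more discounted TD errors, so the advantages shrink toward the end of the rollout in this example.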
This triggers the GAE computation, which takes a weighted average of multi-step TD errors to obtain a much more stable advantage estimate.
- Return: $R_t = r_t + \gamma R_{t+1}$
- TD error: $\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$
- Advantage (GAE): $A_t = \delta_t + \gamma\lambda\delta_{t+1} + (\gamma\lambda)^2\delta_{t+2}+\dots$

The core training step is completed in the update() function:
def update(self):
    mean_value_loss = 0
    mean_surrogate_loss = 0
    if self.actor_critic.is_recurrent:
        generator = self.storage.reccurent_mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
    else:
        generator = self.storage.mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
    for obs_batch, critic_obs_batch, actions_batch, target_values_batch, advantages_batch, returns_batch, old_actions_log_prob_batch, \
            old_mu_batch, old_sigma_batch, hid_states_batch, masks_batch in generator:

        self.actor_critic.act(obs_batch, masks=masks_batch, hidden_states=hid_states_batch[0])
        actions_log_prob_batch = self.actor_critic.get_actions_log_prob(actions_batch)
        value_batch = self.actor_critic.evaluate(critic_obs_batch, masks=masks_batch, hidden_states=hid_states_batch[1])
        mu_batch = self.actor_critic.action_mean
        sigma_batch = self.actor_critic.action_std
        entropy_batch = self.actor_critic.entropy

        # KL
        if self.desired_kl is not None and self.schedule == 'adaptive':
            with torch.inference_mode():
                kl = torch.sum(
                    torch.log(sigma_batch / old_sigma_batch + 1.e-5) + (torch.square(old_sigma_batch) + torch.square(old_mu_batch - mu_batch)) / (2.0 * torch.square(sigma_batch)) - 0.5, axis=-1)
                kl_mean = torch.mean(kl)

                if kl_mean > self.desired_kl * 2.0:
                    self.learning_rate = max(1e-5, self.learning_rate / 1.5)
                elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0:
                    self.learning_rate = min(1e-2, self.learning_rate * 1.5)

                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = self.learning_rate

        # Surrogate loss
        ratio = torch.exp(actions_log_prob_batch - torch.squeeze(old_actions_log_prob_batch))
        surrogate = -torch.squeeze(advantages_batch) * ratio
        surrogate_clipped = -torch.squeeze(advantages_batch) * torch.clamp(ratio, 1.0 - self.clip_param, 1.0 + self.clip_param)
        surrogate_loss = torch.max(surrogate, surrogate_clipped).mean()

        # Value function loss
        if self.use_clipped_value_loss:
            value_clipped = target_values_batch + (value_batch - target_values_batch).clamp(-self.clip_param, self.clip_param)
            value_losses = (value_batch - returns_batch).pow(2)
            value_losses_clipped = (value_clipped - returns_batch).pow(2)
            value_loss = torch.max(value_losses, value_losses_clipped).mean()
        else:
            value_loss = (returns_batch - value_batch).pow(2).mean()

        loss = surrogate_loss + self.value_loss_coef * value_loss - self.entropy_coef * entropy_batch.mean()

        # Gradient step
        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.actor_critic.parameters(), self.max_grad_norm)
        self.optimizer.step()

        mean_value_loss += value_loss.item()
        mean_surrogate_loss += surrogate_loss.item()

    num_updates = self.num_learning_epochs * self.num_mini_batches
    mean_value_loss /= num_updates
    mean_surrogate_loss /= num_updates
    self.storage.clear()

    return mean_value_loss, mean_surrogate_loss
mean_value_loss = 0
mean_surrogate_loss = 0
if self.actor_critic.is_recurrent:
    generator = self.storage.reccurent_mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
else:
    generator = self.storage.mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
mean_value_loss and mean_surrogate_loss track the average losses over the whole update, for logging.

The mini-batch iterator yields:

for obs_batch, critic_obs_batch, actions_batch, target_values_batch, advantages_batch, returns_batch, old_actions_log_prob_batch, \
        old_mu_batch, old_sigma_batch, hid_states_batch, masks_batch in generator:

- obs_batch: Actor network input
- critic_obs_batch: Critic network input
- actions_batch: the sampled actions
- target_values_batch: the old value estimates V(s)
- advantages_batch: the GAE advantages
- returns_batch: the target values (returns)
- old_actions_log_prob_batch: old policy log-probabilities $\log \pi_{\theta_{old}}(a|s)$
- old_mu_batch: old policy mean $\mu$
- old_sigma_batch: old policy standard deviation $\sigma$

self.actor_critic.act(obs_batch, masks=masks_batch, hidden_states=hid_states_batch[0])
The act call runs the Actor forward pass, recomputing the current policy's action distribution $\pi_\theta(a|s) = \mathcal{N}(\mu_\theta(s), \sigma_\theta(s))$.

actions_log_prob_batch = self.actor_critic.get_actions_log_prob(actions_batch)

This recomputes the log-probability $\log \pi_\theta(a_t|s_t)$ of the stored actions under the current policy, used shortly to form the probability ratio.

value_batch = self.actor_critic.evaluate(critic_obs_batch)

The Critic recomputes the value estimate.

mu_batch = self.actor_critic.action_mean
sigma_batch = self.actor_critic.action_std
entropy_batch = self.actor_critic.entropy

- mu_batch: the policy mean $\mu$
- sigma_batch: the policy standard deviation $\sigma$
- entropy_batch: the policy entropy

# KL
if self.desired_kl is not None and self.schedule == 'adaptive':
    with torch.inference_mode():
        kl = torch.sum(
            torch.log(sigma_batch / old_sigma_batch + 1.e-5) + (torch.square(old_sigma_batch) + torch.square(old_mu_batch - mu_batch)) / (2.0 * torch.square(sigma_batch)) - 0.5, axis=-1)
        kl_mean = torch.mean(kl)

        if kl_mean > self.desired_kl * 2.0:
            self.learning_rate = max(1e-5, self.learning_rate / 1.5)
        elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0:
            self.learning_rate = min(1e-2, self.learning_rate * 1.5)

        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.learning_rate
if kl_mean > self.desired_kl * 2.0:
    self.learning_rate = max(1e-5, self.learning_rate / 1.5)
elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0:
    self.learning_rate = min(1e-2, self.learning_rate * 1.5)

for param_group in self.optimizer.param_groups:
    param_group['lr'] = self.learning_rate
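The adaptive schedule shrinks or grows the learning rate by a factor of 1.5 and keeps it within [1e-5, 1e-2]. The same logic as a standalone sketch (the function name is ours, not rsl_rl's):

```python
def adapt_lr(lr, kl_mean, desired_kl=0.01):
    """Shrink the LR when the measured KL overshoots the target band,
    grow it when the KL undershoots, and clamp to [1e-5, 1e-2]."""
    if kl_mean > desired_kl * 2.0:
        lr = max(1e-5, lr / 1.5)
    elif 0.0 < kl_mean < desired_kl / 2.0:
        lr = min(1e-2, lr * 1.5)
    return lr

lr_after_big_kl   = adapt_lr(1e-3, kl_mean=0.05)   # KL too large -> shrink
lr_after_small_kl = adapt_lr(1e-3, kl_mean=0.001)  # KL too small -> grow
lr_in_band        = adapt_lr(1e-3, kl_mean=0.01)   # inside the band -> unchanged
```

The band [desired_kl/2, desired_kl*2] acts as a dead zone so the learning rate only moves when the update size is clearly off target.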
| KL | Meaning |
|---|---|
| Too large | updates are too aggressive |
| Too small | updates are too conservative |
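The expression summed per action dimension above is the closed-form KL divergence between two univariate Gaussians, $\mathrm{KL}\big(\mathcal{N}(\mu_{old},\sigma_{old})\,\|\,\mathcal{N}(\mu,\sigma)\big) = \log\frac{\sigma}{\sigma_{old}} + \frac{\sigma_{old}^2 + (\mu_{old}-\mu)^2}{2\sigma^2} - \frac{1}{2}$ (the code adds 1e-5 inside the log for numerical safety). A plain-Python check:

```python
import math

def gaussian_kl(mu_old, sigma_old, mu_new, sigma_new):
    """KL(N(mu_old, sigma_old) || N(mu_new, sigma_new)) for one dimension."""
    return (math.log(sigma_new / sigma_old)
            + (sigma_old**2 + (mu_old - mu_new)**2) / (2.0 * sigma_new**2)
            - 0.5)

kl_same  = gaussian_kl(0.0, 1.0, 0.0, 1.0)   # identical Gaussians -> 0
kl_moved = gaussian_kl(0.0, 1.0, 1.0, 1.0)   # mean shifted by one sigma -> 0.5
```

For a diagonal Gaussian policy the total KL is just the sum of these per-dimension terms, which is exactly what torch.sum(..., axis=-1) computes.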
ratio = torch.exp(actions_log_prob_batch - old_actions_log_prob_batch)
This is PPO's core quantity, the probability ratio $r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}$:
- r ≈ 1: the new and old policies are nearly the same
- r >> 1 or r << 1: the policy has changed too much

surrogate = -torch.squeeze(advantages_batch) * ratio
surrogate_clipped = -torch.squeeze(advantages_batch) * torch.clamp(ratio, 1.0 - self.clip_param, 1.0 + self.clip_param)
surrogate_loss = torch.max(surrogate, surrogate_clipped).mean()
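The sign convention is worth tracing with one made-up sample: the code minimizes the negated objective, so taking torch.max over the negated terms is the same as taking the pessimistic min in the original PPO objective. A plain-Python sketch:

```python
clip_param = 0.2
advantage = 2.0
ratio = 1.5    # the new policy made this action 50% more likely

surrogate = -advantage * ratio                                       # -3.0
clipped_ratio = min(max(ratio, 1.0 - clip_param), 1.0 + clip_param)  # clamp to [0.8, 1.2]
surrogate_clipped = -advantage * clipped_ratio                       # -2.4
loss = max(surrogate, surrogate_clipped)                             # clipping wins
```

Because the ratio 1.5 exceeds 1 + clip_param, the clipped term dominates: the loss stops rewarding further movement in this direction, which is exactly the trust-region effect PPO is after.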
# Value function loss
if self.use_clipped_value_loss:
value_clipped = target_values_batch + (value_batch - target_values_batch).clamp(-self.clip_param, self.clip_param)
value_losses = (value_batch - returns_batch).pow(2)
value_losses_clipped = (value_clipped - returns_batch).pow(2)
value_loss = torch.max(value_losses, value_losses_clipped).mean()
else:
value_loss = (returns_batch - value_batch).pow(2).mean()
loss = surrogate_loss + self.value_loss_coef * value_loss - self.entropy_coef * entropy_batch.mean()
- surrogate_loss: the policy (surrogate) loss
- self.value_loss_coef * value_loss: the value-network loss
- self.entropy_coef * entropy_batch.mean(): the entropy bonus (subtracted, so higher entropy lowers the loss)

PPO offers two ways to compute value_loss; the clipped variant prevents the critic from changing too much in a single update.

# Gradient step
self.optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(self.actor_critic.parameters(), self.max_grad_norm)
self.optimizer.step()
mean_value_loss += value_loss.item()
mean_surrogate_loss += surrogate_loss.item()
Finally, the losses are averaged over all updates and the rollout buffer is cleared:

num_updates = self.num_learning_epochs * self.num_mini_batches
mean_value_loss /= num_updates
mean_surrogate_loss /= num_updates
self.storage.clear()
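The clipped value loss can likewise be traced with made-up numbers (plain Python; the real code does this elementwise on tensors):

```python
clip_param = 0.2
target_value = 1.0   # old V(s) recorded during the rollout
new_value = 2.0      # the critic's new prediction
ret = 1.5            # the return (regression target)

# Clip the new prediction so it stays within clip_param of the old one
value_clipped = target_value + max(-clip_param, min(clip_param, new_value - target_value))

loss_unclipped = (new_value - ret) ** 2      # 0.25
loss_clipped = (value_clipped - ret) ** 2    # ~0.09
value_loss = max(loss_unclipped, loss_clipped)
```

Taking the max keeps the larger (more pessimistic) error, so the critic cannot escape a penalty simply by jumping far from its previous estimate.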
This article walked through the Python implementation of the PPO algorithm in the rsl_rl repository: from hyperparameter initialization, the rollout buffer, action sampling, and environment-feedback handling, to advantage computation and the full policy-update loop. The core mechanisms include probability-ratio clipping, GAE advantage estimation, value-function clipping, gradient clipping to prevent exploding gradients, and an optional adaptive learning rate with KL control. Combining the policy loss, value loss, and entropy bonus yields the complete optimization objective, enabling stable and efficient reinforcement learning control for quadruped robots.