跳到主要内容深度确定性策略梯度算法 (DDPG) 详解与 Python 实现 | 极客日志PythonAI算法
深度确定性策略梯度算法 (DDPG) 详解与 Python 实现
综述由AI生成深度确定性策略梯度 (DDPG) 是一种适用于连续动作空间的强化学习算法,结合 Actor-Critic 框架与经验回放机制。本文详细解析了 DDPG 的核心原理、目标网络更新及策略梯度优化公式,并基于 PyTorch 提供了完整的 Python 实现代码。通过 Pendulum-v1 环境验证,展示了智能体如何通过学习曲线收敛,适合机器人控制等场景。
星落6 浏览 深度确定性策略梯度算法 (DDPG) 详解
深度确定性策略梯度(Deep Deterministic Policy Gradient,DDPG)是一种基于深度强化学习的算法,专门用于解决连续动作空间的问题,比如机器人控制中的连续运动。它结合了确定性策略和深度神经网络,属于 Actor-Critic 框架,同时利用了 DQN 和 PG(Policy Gradient)的优点。

算法特点
- 适用于连续动作空间: DDPG 直接输出连续值动作,无需对动作进行离散化。
- 利用确定性策略: 与随机策略不同,DDPG 输出的是每个状态下一个确定的最优动作。
- 结合目标网络: 使用延迟更新的目标网络,稳定了训练过程,避免了过大的参数波动。
- 经验回放机制: 通过经验回放缓解数据相关性,提升样本利用率。
- 高效学习: 使用 Critic 网络评估动作的质量,使得策略优化过程更加高效。
核心改进点
- 从 DQN 继承的目标网络: 避免 Q 值的估计震荡问题,提高算法的训练稳定性。
- 从 PG 继承的策略梯度优化: 通过 Actor 网络直接优化策略,适应连续动作问题。
- 经验回放(Replay Buffer): 将交互环境中的经验(状态、动作、奖励、下一状态)存储起来,训练时从中随机采样,减少数据相关性和样本浪费。
- 双网络架构: Actor 网络负责生成动作;Critic 网络评估动作的质量。
算法公式推导
1. Q 值函数更新
DDPG 使用 Bellman 方程更新 Critic 网络的目标 Q 值:

其中 $s'$ 是下一状态,$\mu'(s')$ 是目标动作,$\gamma$ 是折扣因子,$\mu'$ 是目标 Actor 网络,$Q'$ 是目标 Critic 网络。
Critic 网络的优化目标是最小化以下损失函数:

其中 $y_i$ 是目标值,$\theta^Q$ 是 Critic 网络的参数。
2. 策略更新(Actor 网络)
Actor 网络通过最大化 Critic 网络的 Q 值来优化策略,其目标函数为:

3. 目标网络更新
![\theta^{Q'} \leftarrow \tau \theta^Q + (1 - \tau) \theta^{Q'}]

其中 $\tau \in (0, 1)$ 是软更新系数。
算法流程
- 初始化: 初始化 Actor、Critic 网络和它们对应的目标网络,初始化经验回放池。
- 交互环境: 在状态 $s_t$ 下,通过 Actor 网络生成动作 $a_t$,执行动作获取奖励 $r_t$ 和下一状态 $s_{t+1}$。
- 存储经验: 将 $(s_t, a_t, r_t, s_{t+1})$ 存储到经验回放池。
- 采样训练: 从经验池中随机采样小批量数据 $(s_i, a_i, r_i, s'_i)$。
- 更新 Critic 网络: 计算目标 Q 值,最小化 Critic 的损失函数。
- 更新 Actor 网络: 使用 Critic 网络的梯度来调整 Actor 网络的参数。
- 目标网络更新: 按照软更新公式更新目标网络的参数。
- 重复以上步骤,直到达到学习目标。
[Python] DDPG 算法实现
下面给出了 DDPG 算法的完整 Python 实现。该实现包括 Actor-Critic 架构、缓冲区和目标网络等。
1. 导入必要库
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import random
这里主要用到 gym 创建环境,numpy 处理数组,torch 构建模型,deque 实现经验回放池。
2. 定义 Actor 网络
class Actor(nn.Module):
def __init__(self, state_dim, action_dim, max_action):
super(Actor, self).__init__()
self.layer1 = nn.Linear(state_dim, 256)
self.layer2 = nn.Linear(256, 256)
self.layer3 = nn.Linear(256, action_dim)
self.max_action = max_action
def forward(self, state):
x = torch.relu(self.layer1(state))
x = torch.relu(self.layer2(x))
x = torch.tanh(self.layer3(x)) * self.max_action
return x
解析:Actor 网络的作用是生成给定状态下的最优动作。输入包含状态维度、动作维度和最大动作值。网络结构包含两层隐藏层(256 个神经元),输出层使用 tanh 激活函数将动作限制在 [-1, 1],再乘以 max_action 缩放到实际动作范围。
3. 定义 Critic 网络
class Critic(nn.Module):
def __init__(self, state_dim, action_dim):
super(Critic, self).__init__()
self.layer1 = nn.Linear(state_dim + action_dim, 256)
self.layer2 = nn.Linear(256, 256)
self.layer3 = nn.Linear(256, 1)
def forward(self, state, action):
x = torch.cat([state, action], dim=1)
x = torch.relu(self.layer1(x))
x = torch.relu(self.layer2(x))
x = self.layer3(x)
return x
解析:Critic 网络评估给定状态和动作的质量(即 Q 值)。输入将状态和动作拼接,经过两层隐藏层后输出一个标量 Q 值。
4. 定义经验回放池
class ReplayBuffer:
def __init__(self, max_size):
self.buffer = deque(maxlen=max_size)
def add(self, state, action, reward, next_state, done):
self.buffer.append((state, action, reward, next_state, done))
def sample(self, batch_size):
batch = random.sample(self.buffer, batch_size)
states, actions, rewards, next_states, dones = zip(*batch)
return (np.array(states), np.array(actions), np.array(rewards),
np.array(next_states), np.array(dones))
def size(self):
return len(self.buffer)
解析:作用是将智能体与环境交互的经验数据存储起来,打破样本间的时间相关性。add 方法存入经验,sample 方法随机采样,size 返回当前样本数量。
5. 定义 DDPG 智能体
class DDPGAgent:
def __init__(self, state_dim, action_dim, max_action, gamma=0.99, tau=0.005,
buffer_size=100000, batch_size=64):
self.actor = Actor(state_dim, action_dim, max_action)
self.actor_target = Actor(state_dim, action_dim, max_action)
self.actor_target.load_state_dict(self.actor.state_dict())
self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=1e-4)
self.critic = Critic(state_dim, action_dim)
self.critic_target = Critic(state_dim, action_dim)
self.critic_target.load_state_dict(self.critic.state_dict())
self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=1e-3)
self.max_action = max_action
self.gamma = gamma
self.tau = tau
self.replay_buffer = ReplayBuffer(buffer_size)
self.batch_size = batch_size
解析:初始化 Actor 和 Critic 网络及其目标网络,并复制权重。使用 Adam 优化器分别优化两个网络。超参数包括折扣因子 gamma、软更新系数 tau、经验池容量和批量大小。
6. 动作选择方法
def select_action(self, state):
state = torch.FloatTensor(state.reshape(1, -1))
action = self.actor(state).detach().cpu().numpy().flatten()
return action
解析:根据当前状态生成一个连续动作。将输入状态转换为 Torch 张量,用 Actor 网络预测动作,并转为 NumPy 数组返回。
7. 训练方法
def train(self):
if self.replay_buffer.size() < self.batch_size:
return
states, actions, rewards, next_states, dones = self.replay_buffer.sample(self.batch_size)
states = torch.FloatTensor(states)
actions = torch.FloatTensor(actions)
rewards = torch.FloatTensor(rewards).unsqueeze(1)
next_states = torch.FloatTensor(next_states)
dones = torch.FloatTensor(dones).unsqueeze(1)
with torch.no_grad():
next_actions = self.actor_target(next_states)
target_q = self.critic_target(next_states, next_actions)
target_q = rewards + (1 - dones) * self.gamma * target_q
current_q = self.critic(states, actions)
critic_loss = nn.MSELoss()(current_q, target_q)
self.critic_optimizer.zero_grad()
critic_loss.backward()
self.critic_optimizer.step()
actor_loss = -self.critic(states, self.actor(states)).mean()
self.actor_optimizer.zero_grad()
actor_loss.backward()
self.actor_optimizer.step()
for target_param, param in zip(self.critic_target.parameters(), self.critic.parameters()):
target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
for target_param, param in zip(self.actor_target.parameters(), self.actor.parameters()):
target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
def add_to_replay_buffer(self, state, action, reward, next_state, done):
self.replay_buffer.add(state, action, reward, next_state, done)
解析:这是核心训练逻辑。首先检查回放池是否足够。然后采样数据,计算 Critic 的 MSE 损失并反向传播。接着计算 Actor 的损失(最大化 Critic 的 Q 值)并更新。最后使用软更新公式平滑更新目标网络参数。
8. 训练智能体
def train_ddpg(env_name, episodes=1000, max_steps=200):
env = gym.make(env_name)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])
agent = DDPGAgent(state_dim, action_dim, max_action)
rewards = []
for episode in range(episodes):
state, _ = env.reset()
episode_reward = 0
for step in range(max_steps):
action = agent.select_action(state)
next_state, reward, done, _, _ = env.step(action)
agent.add_to_replay_buffer(state, action, reward, next_state, done)
agent.train()
state = next_state
episode_reward += reward
if done:
break
rewards.append(episode_reward)
print(f"Episode: {episode + 1}, Reward: {episode_reward}")
plt.plot(rewards)
plt.title("Learning Curve")
plt.xlabel("Episodes")
plt.ylabel("Cumulative Reward")
plt.show()
env.close()
解析:主训练循环。初始化环境和智能体,在每个 episode 中与环境交互,存储经验并更新网络,最后绘制学习曲线。
9. 可视化学习曲线
import matplotlib.pyplot as plt
plt.plot(rewards)
plt.title("Learning Curve")
plt.xlabel("Episodes")
plt.ylabel("Cumulative Reward")
plt.show()
运行结果
代码说明
- 演员和评论家网络:演员网络预测给定当前状态的动作,批评家网络评估状态 - 行为对的 q 值。
- Replay Buffer:存储过去的经验,使有效的采样训练成为可能。
- 训练:Critic 使用 Bellman 方程更新,Actor 被更新以最大化期望 q 值。
- 目标网络:平滑更新以稳定训练。
- 环境:代理在 Pendulum-v1 环境中进行训练作为演示。
注意:上述代码主要用于了解和学习算法原理。若应用于实际项目,建议根据具体任务调整超参数并进行进一步优化。
优势
- 解决连续动作问题: 可以直接输出一个连续值动作,而不像传统的离散强化学习算法需要动作离散化。
- 样本效率高: 使用了经验回放和目标网络,减少了样本相关性问题,提高了学习效率和稳定性。
通俗类比
可以把 DDPG 算法想象成一个赛车手(Actor)和他的教练(Critic):
- **赛车手(Actor)**决定转弯的角度、加速的力度,直接控制赛车。
- **教练(Critic)**则通过观察赛车手的表现,告诉他哪些动作是好的,哪些是需要改进的。
- 经验回放池就是赛车手在训练中不断回看他之前的比赛录像,找到改进的地方。
- 目标网络则类似于赛车手的长期目标,比如平稳驾驶,而不是今天开得快、明天开得慢。
应用场景
- 机器人运动控制(机械臂、无人机)
- 自动驾驶中的连续控制任务
- 游戏中的复杂策略设计
相关免费在线工具
- 加密/解密文本
使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
- RSA密钥对生成器
生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
- Mermaid 预览与可视化编辑
基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
- 随机西班牙地址生成器
随机生成西班牙地址(支持马德里、加泰罗尼亚、安达卢西亚、瓦伦西亚筛选),支持数量快捷选择、显示全部与下载。 在线工具,随机西班牙地址生成器在线工具,online
- Gemini 图片去水印
基于开源反向 Alpha 混合算法去除 Gemini/Nano Banana 图片水印,支持批量处理与下载。 在线工具,Gemini 图片去水印在线工具,online
- curl 转代码
解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online