跳到主要内容深度确定性策略梯度算法 (DDPG) 详解与实现 | 极客日志PythonAI算法
深度确定性策略梯度算法 (DDPG) 详解与实现
综述由AI生成深度确定性策略梯度算法 (DDPG) 专为解决连续动作空间问题设计,融合了确定性策略与深度神经网络的优势。该算法采用 Actor-Critic 架构,引入经验回放池打破数据相关性,并利用目标网络提升训练稳定性。核心在于通过 Critic 评估动作质量指导 Actor 优化,最终实现高效策略学习。文中详细解析了数学原理,并提供了基于 PyTorch 的完整代码实现,涵盖网络构建、训练循环及可视化分析,适合用于机器人控制等实际任务参考。
微码行者1 浏览 深度确定性策略梯度算法 (DDPG) 详解与实现
深度确定性策略梯度(Deep Deterministic Policy Gradient,简称 DDPG)是一种基于深度强化学习的算法,专门用于解决连续动作空间的问题,例如机器人控制中的连续运动。它结合了确定性策略和深度神经网络,属于 Actor-Critic 框架,同时利用了 DQN 和 PG(Policy Gradient)的优点。

算法特点
- 适用于连续动作空间: DDPG 直接输出连续值动作,无需对动作进行离散化。
- 利用确定性策略: 与随机策略不同,DDPG 输出的是每个状态下一个确定的最优动作。
- 结合目标网络: 使用延迟更新的目标网络,稳定了训练过程,避免了过大的参数波动。
- 经验回放机制: 通过经验回放缓解数据相关性,提升样本利用率。
- 高效学习: 使用 Critic 网络评估动作的质量,使得策略优化过程更加高效。
核心改进点
- 从 DQN 继承的目标网络: 避免 Q 值的估计震荡问题,提高算法的训练稳定性。
- 从 PG 继承的策略梯度优化: 通过 Actor 网络直接优化策略,适应连续动作问题。
- 经验回放(Replay Buffer): 将交互环境中的经验(状态、动作、奖励、下一状态)存储起来,训练时从中随机采样,减少数据相关性和样本浪费。
- 双网络架构: Actor 网络负责生成动作;Critic 网络评估动作的质量。
算法公式推导
1. Q 值函数更新
DDPG 使用 Bellman 方程更新 Critic 网络的目标 Q 值:
$$y = r + \gamma Q'(s', \mu'(s'; \theta^{\mu'}); \theta^{Q'})$$
其中 $s'$ 是下一状态,$\mu'(s')$ 是目标动作,$\gamma$ 是折扣因子,$\mu'$ 是目标 Actor 网络,$Q'$ 是目标 Critic 网络。
Critic 网络的优化目标是最小化以下损失函数:
$$L(\theta^Q) = \frac{1}{N} \sum_{i} \left( Q(s_i, a_i; \theta^Q) - y_i \right)^2$$
其中 $y_i$ 是目标值,$\theta^Q$ 是 Critic 网络的参数。
2. 策略更新(Actor 网络)
Actor 网络通过最大化 Critic 网络的 Q 值来优化策略,其目标函数为:
$$J(\theta^\mu) = \frac{1}{N} \sum_{i} Q(s_i, \mu(s_i; \theta^\mu); \theta^Q)$$
使用梯度上升法更新 Actor 网络:
$$\nabla_{\theta^\mu} J \approx \frac{1}{N} \sum_{i} \nabla_a Q(s, a; \theta^Q) \big|{a=\mu(s)} \nabla{\theta^\mu} \mu(s; \theta^\mu)$$
3. 目标网络更新
目标网络采用软更新方式,缓慢地向当前网络靠近:
$$\theta^{Q'} \leftarrow \tau \theta^Q + (1 - \tau) \theta^{Q'}$$
$$\theta^{\mu'} \leftarrow \tau \theta^\mu + (1 - \tau) \theta^{\mu'}$$
其中 $\tau \in (0, 1)$ 是软更新系数。
算法流程
- : 初始化 Actor、Critic 网络和它们对应的目标网络,初始化经验回放池。
初始化
交互环境: 在状态 $s_t$ 下,通过 Actor 网络生成动作 $a_t$,执行动作获取奖励 $r_t$ 和下一状态 $s_{t+1}$。存储经验: 将 $(s_t, a_t, r_t, s_{t+1})$ 存储到经验回放池。采样训练: 从经验池中随机采样小批量数据 $(s_i, a_i, r_i, s'_i)$。更新 Critic 网络: 计算目标值 $y_i$,最小化 Critic 的损失函数。更新 Actor 网络: 使用 Critic 网络的梯度来调整 Actor 网络的参数。目标网络更新: 按照软更新公式更新目标网络的参数。重复以上步骤, 直到达到学习目标。Python 实现
下面给出 DDPG 的完整 Python 实现。该实现包括 Actor-Critic 架构、缓冲区和目标网络等,基于 PyTorch 构建。
1. 导入必要库
首先引入 Gym 库创建环境,NumPy 处理数组,PyTorch 构建模型,以及 deque 实现经验回放池。
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import random
2. 定义 Actor 网络
Actor 网络的作用是生成给定状态下的最优动作。输入包含状态维度、动作维度和最大动作值。网络结构包含两层隐藏层(256 个神经元),输出层使用 tanh 激活函数将动作限制在 [-1, 1],再乘以 max_action 缩放到实际范围。
class Actor(nn.Module):
def __init__(self, state_dim, action_dim, max_action):
super(Actor, self).__init__()
self.layer1 = nn.Linear(state_dim, 256)
self.layer2 = nn.Linear(256, 256)
self.layer3 = nn.Linear(256, action_dim)
self.max_action = max_action
def forward(self, state):
x = torch.relu(self.layer1(state))
x = torch.relu(self.layer2(x))
x = torch.tanh(self.layer3(x)) * self.max_action
return x
3. 定义 Critic 网络
Critic 网络评估给定状态和动作的质量(即 Q 值)。输入将状态和动作拼接,经过两层隐藏层后输出标量 Q 值。
class Critic(nn.Module):
def __init__(self, state_dim, action_dim):
super(Critic, self).__init__()
self.layer1 = nn.Linear(state_dim + action_dim, 256)
self.layer2 = nn.Linear(256, 256)
self.layer3 = nn.Linear(256, 1)
def forward(self, state, action):
x = torch.cat([state, action], dim=1)
x = torch.relu(self.layer1(x))
x = torch.relu(self.layer2(x))
x = self.layer3(x)
return x
4. 定义经验回放池
经验回放池存储智能体与环境交互的经验数据,打破样本间的时间相关性。使用双端队列 deque 管理固定容量的缓冲区。
class ReplayBuffer:
def __init__(self, max_size):
self.buffer = deque(maxlen=max_size)
def add(self, state, action, reward, next_state, done):
self.buffer.append((state, action, reward, next_state, done))
def sample(self, batch_size):
batch = random.sample(self.buffer, batch_size)
states, actions, rewards, next_states, dones = zip(*batch)
return (np.array(states), np.array(actions), np.array(rewards),
np.array(next_states), np.array(dones))
def size(self):
return len(self.buffer)
5. 定义 DDPG 智能体
智能体类封装了 Actor、Critic 及其目标网络,以及优化器和超参数。这里使用了 Adam 优化器,并设置了软更新系数 tau。
class DDPGAgent:
def __init__(self, state_dim, action_dim, max_action, gamma=0.99, tau=0.005, buffer_size=100000, batch_size=64):
self.actor = Actor(state_dim, action_dim, max_action)
self.actor_target = Actor(state_dim, action_dim, max_action)
self.actor_target.load_state_dict(self.actor.state_dict())
self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=1e-4)
self.critic = Critic(state_dim, action_dim)
self.critic_target = Critic(state_dim, action_dim)
self.critic_target.load_state_dict(self.critic.state_dict())
self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=1e-3)
self.max_action = max_action
self.gamma = gamma
self.tau = tau
self.replay_buffer = ReplayBuffer(buffer_size)
self.batch_size = batch_size
def select_action(self, state):
state = torch.FloatTensor(state.reshape(1, -1))
action = self.actor(state).detach().cpu().numpy().flatten()
return action
def train(self):
if self.replay_buffer.size() < self.batch_size:
return
states, actions, rewards, next_states, dones = self.replay_buffer.sample(self.batch_size)
states = torch.FloatTensor(states)
actions = torch.FloatTensor(actions)
rewards = torch.FloatTensor(rewards).unsqueeze(1)
next_states = torch.FloatTensor(next_states)
dones = torch.FloatTensor(dones).unsqueeze(1)
with torch.no_grad():
next_actions = self.actor_target(next_states)
target_q = self.critic_target(next_states, next_actions)
target_q = rewards + (1 - dones) * self.gamma * target_q
current_q = self.critic(states, actions)
critic_loss = nn.MSELoss()(current_q, target_q)
self.critic_optimizer.zero_grad()
critic_loss.backward()
self.critic_optimizer.step()
actor_loss = -self.critic(states, self.actor(states)).mean()
self.actor_optimizer.zero_grad()
actor_loss.backward()
self.actor_optimizer.step()
for target_param, param in zip(self.critic_target.parameters(), self.critic.parameters()):
target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
for target_param, param in zip(self.actor_target.parameters(), self.actor.parameters()):
target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
def add_to_replay_buffer(self, state, action, reward, next_state, done):
self.replay_buffer.add(state, action, reward, next_state, done)
6. 训练智能体
主训练循环负责创建环境、重置状态、选择动作、执行动作并更新网络。这里以 Pendulum-v1 环境为例。
def train_ddpg(env_name, episodes=1000, max_steps=200):
env = gym.make(env_name)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])
agent = DDPGAgent(state_dim, action_dim, max_action)
rewards = []
for episode in range(episodes):
state, _ = env.reset()
episode_reward = 0
for step in range(max_steps):
action = agent.select_action(state)
next_state, reward, done, _, _ = env.step(action)
agent.add_to_replay_buffer(state, action, reward, next_state, done)
agent.train()
state = next_state
episode_reward += reward
if done:
break
rewards.append(episode_reward)
print(f"Episode: {episode + 1}, Reward: {episode_reward}")
plt.plot(rewards)
plt.title("Learning Curve")
plt.xlabel("Episodes")
plt.ylabel("Cumulative Reward")
plt.show()
env.close()
if __name__ == "__main__":
env_name = "Pendulum-v1"
episodes = 500
train_ddpg(env_name, episodes=episodes)
运行环境与配置
- Python 3.11.5
- Torch 2.1.0
- Gym 0.26.2
注意:上述代码主要用于理解和学习算法原理。若应用于实际项目,通常需要根据具体任务调整超参数和网络结构。
优势与应用场景
优势
- 解决连续动作问题: 可以直接输出一个连续值动作,而不像传统的离散强化学习算法需要动作离散化。
- 样本效率高: 使用了经验回放和目标网络,减少了样本相关性问题,提高了学习效率和稳定性。
应用场景
- 机器人运动控制(机械臂、无人机)
- 自动驾驶中的连续控制任务
- 游戏中的复杂策略设计
通俗类比
可以把 DDPG 算法想象成一个赛车手(Actor)和他的教练(Critic):
- 赛车手(Actor): 决定转弯的角度、加速的力度,直接控制赛车。
- 教练(Critic): 观察赛车手的表现,告诉他哪些动作是好的,哪些是需要改进的。
- 经验回放池: 赛车手在训练中不断回看他之前的比赛录像,找到改进的地方。
- 目标网络: 类似于赛车手的长期目标,比如平稳驾驶,而不是今天开得快、明天开得慢。
相关免费在线工具
- 加密/解密文本
使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
- RSA密钥对生成器
生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
- Mermaid 预览与可视化编辑
基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
- 随机西班牙地址生成器
随机生成西班牙地址(支持马德里、加泰罗尼亚、安达卢西亚、瓦伦西亚筛选),支持数量快捷选择、显示全部与下载。 在线工具,随机西班牙地址生成器在线工具,online
- Gemini 图片去水印
基于开源反向 Alpha 混合算法去除 Gemini/Nano Banana 图片水印,支持批量处理与下载。 在线工具,Gemini 图片去水印在线工具,online
- curl 转代码
解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online