跳到主要内容
极客日志极客日志
首页博客AI提示词GitHub精选代理工具
搜索
|注册
博客列表
PythonAI算法

深度确定性策略梯度算法 (DDPG) 详解与 PyTorch 实现

综述由AI生成深度确定性策略梯度(DDPG)是一种适用于连续动作空间的强化学习算法,结合了 Actor-Critic 框架、目标网络和经验回放机制。本文详细解析了 DDPG 的核心原理,包括 Q 值函数更新、策略梯度优化及目标网络软更新公式,并提供了基于 PyTorch 和 Gym 环境的完整 Python 实现。通过 Pendulum-v1 环境训练,展示了如何构建 Actor 与 Critic 网络、管理经验池以及可视化学习曲线,帮助读者理解算法在实际控制任务中的应用流程。

乱七八糟发布于 2026/3/24更新于 2026/5/54 浏览
深度确定性策略梯度算法 (DDPG) 详解与 PyTorch 实现

深度确定性策略梯度算法 (DDPG) 详解

深度确定性策略梯度(Deep Deterministic Policy Gradient,简称 DDPG)是一种基于深度强化学习的算法,专门用于解决连续动作空间的问题,例如机器人控制中的连续运动。它结合了确定性策略和深度神经网络,属于 Actor-Critic 框架,同时吸收了 DQN 和 PG(Policy Gradient)的优点。

核心机制与改进点

DDPG 之所以在连续控制任务中表现优异,主要得益于以下几个关键设计:

  • 适用于连续动作空间:直接输出连续值动作,无需像传统方法那样对动作进行离散化。
  • 确定性策略:与随机策略不同,DDPG 输出的是每个状态下一个确定的最优动作,这简化了梯度计算。
  • 目标网络机制:使用延迟更新的目标网络,有效稳定了训练过程,避免了参数波动过大导致的震荡。
  • 经验回放(Replay Buffer):通过存储交互经验并随机采样,打破了数据间的时间相关性,显著提升了样本利用率。
  • 双网络架构:Actor 网络负责生成动作,Critic 网络评估动作质量,两者协同优化。

算法公式推导

理解 DDPG 的核心在于掌握其更新逻辑。Critic 网络利用 Bellman 方程更新目标 Q 值:

y = r + \gamma Q'(s', \mu'(s'; \theta^{\mu'}); \theta^{Q'})

其中 $s'$ 是下一状态,$\mu'$ 是目标 Actor 网络,$Q'$ 是目标 Critic 网络,$\gamma$ 为折扣因子。 Critic 网络的优化目标是最小化以下损失函数:

L(\theta^Q) = \frac{1}{N} \sum_{i} \left( Q(s_i, a_i; \theta^Q) - y_i \right)^2

对于策略更新(Actor 网络),目标是最大化 Critic 网络给出的 Q 值:

J(\theta^\mu) = \frac{1}{N} \sum_{i} Q(s_i, \mu(s_i; \theta^\mu); \theta^Q)

使用梯度上升法更新 Actor 网络参数:

\nabla_{\theta^\mu} J \approx \frac{1}{N} \sum_{i} \nabla_a Q(s, a; \theta^Q) \big|{a=\mu(s)} \nabla{\theta^\mu} \mu(s; \theta^\mu)

此外,目标网络采用软更新方式缓慢向当前网络靠近:

![\theta^{Q'} \leftarrow \tau \theta^Q + (1 - \tau) \theta^{Q'} ]

\theta^{\mu'} \leftarrow \tau \theta^\mu + (1 - \tau) \theta^{\mu'}

其中 $\tau \in (0, 1)$ 是软更新系数。

代码实现

下面给出基于 PyTorch 和 Gym 环境的完整 Python 实现。我们将分模块构建 Actor、Critic、经验回放池以及智能体类。

环境准备

首先导入必要的库,包括 Gym 用于创建环境,PyTorch 构建模型,NumPy 处理数据。

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import random

构建 Actor 网络

Actor 网络负责根据状态输出动作。输入层接收状态维度,经过两层隐藏层(每层 256 个神经元,ReLU 激活),输出层使用 Tanh 激活将动作限制在 [-1, 1] 之间,再乘以最大动作值缩放到实际范围。

class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, max_action):
        super(Actor, self).__init__()
        self.layer1 = nn.Linear(state_dim, 256)
        self.layer2 = nn.Linear(256, 256)
        self.layer3 = nn.Linear(256, action_dim)
        self.max_action = max_action

    def forward(self, state):
        x = torch.relu(self.layer1(state))
        x = torch.relu(self.layer2(x))
        x = torch.tanh(self.layer3(x)) * self.max_action
        return x

构建 Critic 网络

Critic 网络评估状态 - 动作对的质量(Q 值)。它将状态和动作拼接后作为输入,同样经过两层隐藏层,最后输出一个标量 Q 值。

class Critic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Critic, self).__init__()
        self.layer1 = nn.Linear(state_dim + action_dim, 256)
        self.layer2 = nn.Linear(256, 256)
        self.layer3 = nn.Linear(256, 1)

    def forward(self, state, action):
        x = torch.cat([state, action], dim=1)
        x = torch.relu(self.layer1(x))
        x = torch.relu(self.layer2(x))
        x = self.layer3(x)
        return x

定义经验回放池

经验回放池用于存储历史交互数据,打破样本相关性。我们使用双端队列 deque 来管理固定容量的缓冲区。

class ReplayBuffer:
    def __init__(self, max_size):
        self.buffer = deque(maxlen=max_size)

    def add(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return (
            np.array(states),
            np.array(actions),
            np.array(rewards),
            np.array(next_states),
            np.array(dones)
        )

    def size(self):
        return len(self.buffer)

定义 DDPG 智能体

智能体类整合了 Actor、Critic 及其目标网络,并配置了优化器和超参数。

class DDPGAgent:
    def __init__(self, state_dim, action_dim, max_action, gamma=0.99, tau=0.005, buffer_size=100000, batch_size=64):
        self.actor = Actor(state_dim, action_dim, max_action)
        self.actor_target = Actor(state_dim, action_dim, max_action)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=1e-4)

        self.critic = Critic(state_dim, action_dim)
        self.critic_target = Critic(state_dim, action_dim)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=1e-3)

        self.max_action = max_action
        self.gamma = gamma
        self.tau = tau
        self.replay_buffer = ReplayBuffer(buffer_size)
        self.batch_size = batch_size

    def select_action(self, state):
        state = torch.FloatTensor(state.reshape(1, -1))
        action = self.actor(state).detach().cpu().numpy().flatten()
        return action

    def train(self):
        if self.replay_buffer.size() < self.batch_size:
            return

        states, actions, rewards, next_states, dones = self.replay_buffer.sample(self.batch_size)
        states = torch.FloatTensor(states)
        actions = torch.FloatTensor(actions)
        rewards = torch.FloatTensor(rewards).unsqueeze(1)
        next_states = torch.FloatTensor(next_states)
        dones = torch.FloatTensor(dones).unsqueeze(1)

        # 更新 Critic
        with torch.no_grad():
            next_actions = self.actor_target(next_states)
            target_q = self.critic_target(next_states, next_actions)
            target_q = rewards + (1 - dones) * self.gamma * target_q

        current_q = self.critic(states, actions)
        critic_loss = nn.MSELoss()(current_q, target_q)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # 更新 Actor
        actor_loss = -self.critic(states, self.actor(states)).mean()
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # 软更新目标网络
        for target_param, param in zip(self.critic_target.parameters(), self.critic.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
        for target_param, param in zip(self.actor_target.parameters(), self.actor.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

    def add_to_replay_buffer(self, state, action, reward, next_state, done):
        self.replay_buffer.add(state, action, reward, next_state, done)

训练流程与可视化

主训练循环负责与环境交互、存储经验并定期更新网络。这里以 Pendulum-v1 环境为例。

import matplotlib.pyplot as plt

def train_ddpg(env_name, episodes=1000, max_steps=200):
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    max_action = float(env.action_space.high[0])

    agent = DDPGAgent(state_dim, action_dim, max_action)
    rewards = []

    for episode in range(episodes):
        state, _ = env.reset()
        episode_reward = 0
        for step in range(max_steps):
            action = agent.select_action(state)
            next_state, reward, done, _, _ = env.step(action)
            agent.add_to_replay_buffer(state, action, reward, next_state, done)
            agent.train()
            state = next_state
            episode_reward += reward
            if done:
                break
        rewards.append(episode_reward)
        print(f"Episode: {episode + 1}, Reward: {episode_reward}")

    plt.plot(rewards)
    plt.title("Learning Curve")
    plt.xlabel("Episodes")
    plt.ylabel("Cumulative Reward")
    plt.show()
    env.close()

if __name__ == "__main__":
    env_name = "Pendulum-v1"
    episodes = 500
    train_ddpg(env_name, episodes=episodes)

训练效果与注意事项

运行上述代码后,你将看到学习曲线图,展示累计奖励随回合数的变化趋势。通常情况下,随着训练进行,智能体应能逐渐学会控制 pendulum 保持直立或获得更高奖励。

在实际应用中,请注意以下几点:

  1. 超参数调优:不同的环境可能需要调整 gamma、tau、学习率等参数以获得最佳效果。
  2. 探索噪声:在训练初期,可以在动作中加入高斯噪声以增加探索能力,避免陷入局部最优。
  3. 环境适配:虽然本例使用 Pendulum 环境,但 DDPG 同样适用于机械臂、无人机等更复杂的连续控制场景。

DDPG 通过结合确定性策略与深度神经网络,为连续控制问题提供了一种高效且稳定的解决方案。希望这份详解和代码能帮助你快速上手该算法。

目录

  1. 深度确定性策略梯度算法 (DDPG) 详解
  2. 核心机制与改进点
  3. 算法公式推导
  4. 代码实现
  5. 环境准备
  6. 构建 Actor 网络
  7. 构建 Critic 网络
  8. 定义经验回放池
  9. 定义 DDPG 智能体
  10. 训练流程与可视化
  11. 训练效果与注意事项
  • 💰 8折买阿里云服务器限时8折了解详情
  • GPT-5.5 超高智商模型1元抵1刀ChatGPT中转购买
  • 代充Chatgpt Plus/pro 帐号了解详情
  • 🤖 一键搭建Deepseek满血版了解详情
  • 一键打造专属AI 智能体了解详情
极客日志微信公众号二维码

微信扫一扫,关注极客日志

微信公众号「极客日志V2」,在微信中扫描左侧二维码关注。展示文案:极客日志V2 zeeklog

更多推荐文章

查看全部
  • Android RxJava3 核心使用场景与实战指南
  • 知网 AIGC 检测标准与论文降重实战指南
  • VLA 机器人革命:解析 10 篇关键视觉 - 语言 - 动作模型论文
  • OpenClaw 接入飞书机器人配置指南
  • Python 列表内存存储本质:差异原因与优化建议
  • MySQL 数据类型全面解析与实战选型
  • RoboBrain 2.5:解决具身智能空间与时间维度的落地难题
  • OpenClaw 权限配置与安全指南
  • Bodymovin 开源动画转换工具跨平台集成方案
  • GitHub Copilot AI 编程使用教程:从入门到精通
  • Spring Boot 实战:基于 WebSocket 的前后端实时匹配系统实现
  • 从零开始搭建个人知识库的实践指南
  • 深度评测 GLM-5:代码生成实战体验
  • Discord 机器人创建与配置完整流程
  • LightRAG 框架介绍及 WebUI 本地部署指南
  • MIT 室内场景识别数据集详解与 YOLOv8 实战
  • RNN 与序列数据处理实战:从原理到 LSTM 文本分类
  • Google GenAI Toolbox:企业级 AI 数据库中间件与 LLM-SQL 安全互联实践
  • Mac Studio 1.5TB 显存集群:基于雷电 5 的 RDMA 技术测试
  • OpenClaw 生态主流厂商产品深度横评与选型指南

相关免费在线工具

  • 加密/解密文本

    使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online

  • RSA密钥对生成器

    生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online

  • Mermaid 预览与可视化编辑

    基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online

  • 随机西班牙地址生成器

    随机生成西班牙地址(支持马德里、加泰罗尼亚、安达卢西亚、瓦伦西亚筛选),支持数量快捷选择、显示全部与下载。 在线工具,随机西班牙地址生成器在线工具,online

  • Gemini 图片去水印

    基于开源反向 Alpha 混合算法去除 Gemini/Nano Banana 图片水印,支持批量处理与下载。 在线工具,Gemini 图片去水印在线工具,online

  • curl 转代码

    解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online