跳到主要内容
极客日志极客日志
首页博客AI提示词GitHub精选代理工具
搜索
|注册
博客列表
PythonAI算法

Double DQN 算法详解:原理、流程与 PyTorch 实现

综述由AI生成Double DQN 算法针对 DQN 存在的过估计偏差问题,通过分离动作选择和目标 Q 值计算来提升稳定性。该方案利用在线网络选择动作,目标网络评估价值,有效降低了 Q 值估计的高估风险。文章详细阐述了算法背景、核心思想及数学推导,并提供了基于 PyTorch 的完整 Python 代码实现,涵盖网络构建、经验回放、训练循环等关键环节,适合希望深入理解强化学习算法原理及落地实现的开发者参考。

山野来信发布于 2026/3/24更新于 2026/5/74 浏览
Double DQN 算法详解:原理、流程与 PyTorch 实现

Double DQN 算法详解

强化学习中的深度 Q 网络(DQN)将深度学习与 Q 学习结合,通过神经网络逼近 Q 函数来解决高维状态问题。然而,标准 DQN 存在过估计偏差(Overestimation Bias),即在更新 Q 值时,由于同一个网络既负责选择动作又负责评估价值,容易导致 Q 值估计偏高。

Double DQN(DDQN)引入了双网络机制来缓解这一问题,显著提高了算法的稳定性和收敛性。

算法背景与提出

在强化学习早期,Q 学习依赖 Q 值表描述状态 - 动作对的长期累积奖励。当空间巨大或连续时,传统方法难以扩展。DQN 引入神经网络取得了如 Atari 游戏的成果,但实际应用中暴露出过估计偏差问题。

过估计偏差问题

DQN 的 Q 值更新公式如下:

$$y_t^{DQN} = r_t + \gamma \max_a Q_{\theta^-}(s_{t+1}, a)$$

其中 $Q_{\theta^-}$ 是目标网络的 Q 值,$\gamma$ 是折扣因子,$r_t$ 是即时奖励。

DQN 使用最大值操作选择动作并估计未来价值,这可能导致过高估计。根本原因在于:

  1. 同一个网络(目标网络)既负责选择动作,又负责评估这些动作的价值。
  2. 神经网络的逼近误差会放大估计值,加剧过估计。

这种偏差会导致策略过于激进、学习过程不稳定甚至无法收敛。

Double Q-Learning 的灵感

Double Q-Learning 通过分离动作选择和价值估计来减少过估计。它使用两个独立的 Q 值表:一个用于选择动作,另一个用于计算目标值。

其目标值公式为:

$$y_t^{DoubleQ} = r_t + \gamma Q_2(s_{t+1}, \arg\max_a Q_1(s_{t+1}, a))$$

通过分离计算,动作选择的误差不会直接影响目标值计算,从而降低了风险。

Double DQN 的提出

Double DQN 受此启发,将其扩展到深度强化学习领域。主要区别在于:

  1. 使用在线网络(Online Network)来选择动作。
  2. 使用目标网络(Target Network)来估计动作的价值。

Double DQN 的目标值公式为:

$$y_t^{DDQN} = r_t + \gamma Q_{\theta^-}(s_{t+1}, \arg\max_a Q_{\theta}(s_{t+1}, a))$$

其中 $Q_{\theta^-}$ 是目标网络,用于估计目标 Q 值;$Q_{\theta}$ 是在线网络,用于选择动作。这种方法成功解决了 DQN 的过估计问题,并在多个任务中表现出更好的性能。

Double DQN 的核心思想

核心在于分离动作选择和目标 Q 值计算:

  1. 使用在线网络选择动作。
  2. 使用目标网络计算目标 Q 值。

这种分离使得目标 Q 值的计算更加可靠,有助于减少估计偏差。

算法流程

初始化阶段需要构建两个神经网络:在线网络 $Q_{\theta}$ 和目标网络 $Q_{\theta^-}$。目标网络的参数会定期从在线网络同步。

在执行动作时,当前状态 $s_t$ 下利用在线网络选择动作 $a_t = \arg\max_a Q_{\theta}(s_t, a)$。随后将转移样本 $(s_t, a_t, r_t, s_{t+1})$ 存入经验回放池。

训练时从池中随机采样小批量数据。关键步骤在于目标值计算:使用在线网络选择下一个状态的最佳动作 $a' = \arg\max_a Q_{\theta}(s_{i+1}, a)$,再使用目标网络计算目标 Q 值 $y_i = r_i + \gamma Q_{\theta^-}(s_{i+1}, a')$。

最后使用均方误差作为损失函数对在线网络进行梯度下降,并每隔一定步数将在线网络参数复制到目标网络。

公式推导

Double DQN 通过分离动作选择和目标计算来减小过估计。Q 值由目标网络 $Q_{\theta^-}$ 计算,而动作 $a$ 由在线网络 $Q_{\theta}$ 选择。

DDQN 的目标值为:

$$y_t^{DDQN} = r_t + \gamma Q_{\theta^-}(s_{t+1}, \arg\max_a Q_{\theta}(s_{t+1}, a))$$

相比之下,传统 DQN 的目标值是:

$$y_t^{DQN} = r_t + \gamma \max_a Q_{\theta^-}(s_{t+1}, a)$$

这里的 max 操作是导致过估计问题的根源。

Python 实现

下面给出基于 PyTorch 框架的 Double DQN 完整实现,包含核心的在线网络和目标网络更新机制。

导入必要库

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque

超参数设置

# Hyperparameters
GAMMA = 0.99          # 折扣因子
LR = 0.001            # 学习率
BATCH_SIZE = 64       # 批量大小
MEMORY_CAPACITY = 10000  # 经验回放池容量
TARGET_UPDATE = 10    # 目标网络更新周期

定义网络结构

class QNetwork(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(QNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, action_dim)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

这里定义了三层全连接网络,中间层使用 ReLU 激活函数,输出层直接输出 Q 值。

经验回放池

class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return (np.array(states), np.array(actions), np.array(rewards), 
                np.array(next_states), np.array(dones))

    def __len__(self):
        return len(self.buffer)

Double DQN 智能体

class DoubleDQNAgent:
    def __init__(self, state_dim, action_dim):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.online_net = QNetwork(state_dim, action_dim)
        self.target_net = QNetwork(state_dim, action_dim)
        self.target_net.load_state_dict(self.online_net.state_dict())
        self.target_net.eval()
        self.optimizer = optim.Adam(self.online_net.parameters(), lr=LR)
        self.memory = ReplayBuffer(MEMORY_CAPACITY)
        self.steps_done = 0

    def select_action(self, state, epsilon):
        if random.random() < epsilon:
            return random.randint(0, self.action_dim - 1)
        else:
            state = torch.FloatTensor(state).unsqueeze(0)
            with torch.no_grad():
                q_values = self.online_net(state)
                return q_values.argmax().item()

    def store_transition(self, state, action, reward, next_state, done):
        self.memory.push(state, action, reward, next_state, done)

    def update(self):
        if len(self.memory) < BATCH_SIZE:
            return
        states, actions, rewards, next_states, dones = self.memory.sample(BATCH_SIZE)
        states = torch.FloatTensor(states)
        actions = torch.LongTensor(actions).unsqueeze(1)
        rewards = torch.FloatTensor(rewards).unsqueeze(1)
        next_states = torch.FloatTensor(next_states)
        dones = torch.FloatTensor(dones).unsqueeze(1)

        q_values = self.online_net(states).gather(1, actions)

        with torch.no_grad():
            next_actions = self.online_net(next_states).argmax(dim=1, keepdim=True)
            next_q_values = self.target_net(next_states).gather(1, next_actions)
            target_q_values = rewards + (1 - dones) * GAMMA * next_q_values

        loss = nn.MSELoss()(q_values, target_q_values)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def update_target_network(self):
        self.target_net.load_state_dict(self.online_net.state_dict())

注意 update 方法中,next_actions 是由在线网络选出的,而 next_q_values 是由目标网络评估的,这正是 DDQN 的关键。

训练循环

import gym
env = gym.make('CartPole-v1')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = DoubleDQNAgent(state_dim, action_dim)

num_episodes = 500
epsilon_start = 1.0
epsilon_end = 0.01
epsilon_decay = 500

for episode in range(num_episodes):
    state = env.reset()
    done = False
    total_reward = 0
    while not done:
        epsilon = epsilon_end + (epsilon_start - epsilon_end) * np.exp(-1. * agent.steps_done / epsilon_decay)
        action = agent.select_action(state, epsilon)
        next_state, reward, done, _ = env.step(action)
        total_reward += reward
        agent.store_transition(state, action, reward, next_state, done)
        agent.update()
        state = next_state
        agent.steps_done += 1

    if episode % TARGET_UPDATE == 0:
        agent.update_target_network()
    print(f"Episode {episode}, Total Reward: {total_reward}")

env.close()

环境配置建议使用 Python 3.11.5, torch 2.1.0, gym 0.26.2。需要注意的是,本文代码主要用于展示算法逻辑,若应用于实际项目,通常还需要针对具体环境进行超参数调优。

优势与特点

特性DQNDouble DQN
目标值计算动作选择和评估使用同一网络分离动作选择和目标评估
过估计偏差明显存在显著减小
训练稳定性容易震荡更加稳定
算法复杂度较低略微增加
  1. 减小过估计偏差:分离动作选择和目标计算后,有效减少了过高估计的风险。
  2. 更稳定的训练过程:估计值更准确,训练平滑,收敛速度更快。
  3. 简单易实现:在 DQN 基础上仅需引入动作选择的分离逻辑。

总结

Double DQN 算法的提出主要是为了解决 DQN 中的'过估计偏差'问题。通过引入双网络,让动作选择和价值评估分离,大大提高了算法的稳定性和准确性。在实际工程中,这是处理离散动作空间强化学习任务的经典基线模型之一。

目录

  1. Double DQN 算法详解
  2. 算法背景与提出
  3. 过估计偏差问题
  4. Double Q-Learning 的灵感
  5. Double DQN 的提出
  6. Double DQN 的核心思想
  7. 算法流程
  8. 公式推导
  9. Python 实现
  10. 导入必要库
  11. 超参数设置
  12. Hyperparameters
  13. 定义网络结构
  14. 经验回放池
  15. Double DQN 智能体
  16. 训练循环
  17. 优势与特点
  18. 总结
  • 💰 8折买阿里云服务器限时8折了解详情
  • GPT-5.5 超高智商模型1元抵1刀ChatGPT中转购买
  • 代充Chatgpt Plus/pro 帐号了解详情
  • 🤖 一键搭建Deepseek满血版了解详情
  • 一键打造专属AI 智能体了解详情
极客日志微信公众号二维码

微信扫一扫,关注极客日志

微信公众号「极客日志V2」,在微信中扫描左侧二维码关注。展示文案:极客日志V2 zeeklog

更多推荐文章

查看全部
  • GitHub Copilot 调用第三方模型 API 配置指南
  • crypto-js JavaScript 加密标准库使用指南
  • OpenClaw 集成 Telegram 机器人实战指南
  • Spring 事务管理核心:@Transactional 注解与传播机制详解
  • GitHub 日榜精选:AI 智能体与开发工具趋势
  • 2026 年主流 AI 工具对比:豆包、DeepSeek、元宝、ChatGPT、Cursor
  • 阿里开源 iFlow CLI:终端级 AI 智能体功能与使用指南
  • Spring Web 模块核心概念与 RESTful API 调用
  • 双指针算法详解:三数之和与四数之和
  • 基于 IsaacLab 从零训练机器人行走
  • 轮腿机器人代码调试与软硬件配置说明
  • 智谱 AI 免费大模型 API 调用教程
  • VSCode 远程连接 Copilot 显示脱机状态修复方案
  • Claude Code 本地部署与 Copilot 反向代理配置指南
  • 基于.NET的Web API控制器及方法注解属性
  • C++ 类大小计算:内存对齐、虚函数与继承详解
  • OpenClaw 飞书通信端机器人配置指南
  • Gaussian Grouping:在三维场景中分割与编辑任意物体
  • 如何利用 AI 大模型解决实际问题:实战案例与操作指南
  • 机器人自主导航避障全栈方案(涵盖ROS2实现与实车测试数据)

相关免费在线工具

  • 加密/解密文本

    使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online

  • RSA密钥对生成器

    生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online

  • Mermaid 预览与可视化编辑

    基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online

  • 随机西班牙地址生成器

    随机生成西班牙地址(支持马德里、加泰罗尼亚、安达卢西亚、瓦伦西亚筛选),支持数量快捷选择、显示全部与下载。 在线工具,随机西班牙地址生成器在线工具,online

  • Gemini 图片去水印

    基于开源反向 Alpha 混合算法去除 Gemini/Nano Banana 图片水印,支持批量处理与下载。 在线工具,Gemini 图片去水印在线工具,online

  • curl 转代码

    解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online