跳到主要内容Double DQN 算法详解与 Python 实现 | 极客日志PythonAI算法
Double DQN 算法详解与 Python 实现
Double DQN 算法通过分离动作选择和目标价值计算来缓解 DQN 中的过估计偏差问题。该算法使用在线网络选择动作,目标网络评估价值,从而提升训练稳定性和收敛性。文中详细阐述了算法背景、核心思想、流程及公式推导,并提供了基于 PyTorch 的 Python 完整实现代码,适用于解决高维状态空间下的强化学习任务。
游戏玩家20 浏览 Double DQN(Double Deep Q-Network) 算法
一、Double DQN 算法详解
强化学习中的深度 Q 网络(DQN)是一种将深度学习与 Q 学习结合的算法,它通过神经网络逼近 Q 函数以解决复杂的高维状态问题。然而,DQN 存在过估计问题(Overestimation Bias),即在更新 Q 值时,由于同时使用同一个网络选择动作和计算目标 Q 值,可能导致 Q 值的估计偏高。
Double DQN(DDQN)引入了'双网络'机制来缓解这个问题,从而提高了算法的稳定性和收敛性。
二、算法背景和提出
在强化学习的早期研究中,Q 学习是一种经典算法,它通过构建 Q 值表来描述每个状态 - 动作对的长期累积奖励。然而,当状态和动作空间变得巨大甚至连续时,Q 学习方法难以扩展。为此,深度 Q 网络(Deep Q-Network, DQN)引入了神经网络来逼近 Q 函数,并取得了显著的成果,如成功应用于 Atari 游戏。但 DQN 算法在实际应用中暴露出了一些问题,其中过估计偏差(Overestimation Bias)尤为突出。
2.1 过估计偏差问题
在 DQN 算法中,Q 值更新公式如下:

其中:

是目标网络的 Q 值。

是折扣因子;

是当前的即时奖励;
DQN 使用的是'最大值'max 操作来选择动作并估计未来的价值,这种方式可能导致过高估计。其根本原因在于:
- 同一个网络(目标网络)既负责选择动作(动作选择偏好),又负责评估这些动作的价值(动作的价值计算)。
- 神经网络的逼近误差会放大估计值,从而进一步加剧过估计问题。
这种偏差会导致:
- 策略变得过于激进;
- 学习过程变得不稳定;
- 收敛速度减慢甚至无法收敛。
2.2 Double Q-Learning 的灵感
Double Q-Learning 是一种用于减少过估计问题的经典方法。其基本思想是分离动作选择和价值估计。它使用两个独立的 Q 值表:
- 一个表用于选择动作;
- 另一个表用于计算目标值。
Double Q-Learning 的目标值公式为:

通过这种分离计算,动作选择的误差不会直接影响到目标值计算,从而减少了过估计的风险。
2.3 Double DQN 的提出
Double DQN(DDQN)受 Double Q-Learning 启发,将其思想扩展到深度强化学习领域。主要区别在于:
- 使用在线网络(Online Network)来选择动作;
- 使用目标网络(Target Network)来估计动作的价值。
这种方法成功地解决了 DQN 的过估计问题,并在多个强化学习任务中表现出了更好的性能和稳定性。
三、Double DQN 的核心思想
Double DQN 通过分离动作选择和目标 Q 值计算来减小过估计问题:
- 使用在线网络(Online Network)选择动作。
- 使用目标网络(Target Network)计算目标 Q 值。
这种分离使得目标 Q 值的计算更加可靠,有助于减少估计偏差。
四、算法流程
五、公式推导
- 通过在线网络选择动作,可以更准确地反映当前策略的动作价值。
- 目标网络仅用来计算 Q 值,减少了目标计算时的估计偏差。
Double DQN 目标:
DDQN 通过分离动作选择和目标计算,目标值改为:
[Python] Double DQN 算法实现
下面给出 Double DQN 算法的完整 Python 实现代码,它通过 PyTorch 框架实现,并包含了核心的在线网络和目标网络的更新机制:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque
GAMMA = 0.99
LR = 0.001
BATCH_SIZE = 64
MEMORY_CAPACITY = 10000
TARGET_UPDATE = 10
class QNetwork(nn.Module):
def __init__(self, state_dim, action_dim):
super(QNetwork, self).__init__()
self.fc1 = nn.Linear(state_dim, 128)
self.fc2 = nn.Linear(128, 128)
self.fc3 = nn.Linear(128, action_dim)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
x = self.fc3(x)
return x
class ReplayBuffer:
def __init__(self, capacity):
self.buffer = deque(maxlen=capacity)
def push(self, state, action, reward, next_state, done):
self.buffer.append((state, action, reward, next_state, done))
def sample(self, batch_size):
batch = random.sample(self.buffer, batch_size)
states, actions, rewards, next_states, dones = zip(*batch)
return (np.array(states), np.array(actions), np.array(rewards), np.array(next_states), np.array(dones))
def __len__(self):
return len(self.buffer)
class DoubleDQNAgent:
def __init__(self, state_dim, action_dim):
self.state_dim = state_dim
self.action_dim = action_dim
self.online_net = QNetwork(state_dim, action_dim)
self.target_net = QNetwork(state_dim, action_dim)
self.target_net.load_state_dict(self.online_net.state_dict())
self.target_net.eval()
self.optimizer = optim.Adam(self.online_net.parameters(), lr=LR)
self.memory = ReplayBuffer(MEMORY_CAPACITY)
self.steps_done = 0
def select_action(self, state, epsilon):
if random.random() < epsilon:
return random.randint(0, self.action_dim - 1)
else:
state = torch.FloatTensor(state).unsqueeze(0)
with torch.no_grad():
q_values = self.online_net(state)
return q_values.argmax().item()
def store_transition(self, state, action, reward, next_state, done):
self.memory.push(state, action, reward, next_state, done)
def update(self):
if len(self.memory) < BATCH_SIZE:
return
states, actions, rewards, next_states, dones = self.memory.sample(BATCH_SIZE)
states = torch.FloatTensor(states)
actions = torch.LongTensor(actions).unsqueeze(1)
rewards = torch.FloatTensor(rewards).unsqueeze(1)
next_states = torch.FloatTensor(next_states)
dones = torch.FloatTensor(dones).unsqueeze(1)
q_values = self.online_net(states).gather(1, actions)
with torch.no_grad():
next_actions = self.online_net(next_states).argmax(dim=1, keepdim=True)
next_q_values = self.target_net(next_states).gather(1, next_actions)
target_q_values = rewards + (1 - dones) * GAMMA * next_q_values
loss = nn.MSELoss()(q_values, target_q_values)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
def update_target_network(self):
self.target_net.load_state_dict(self.online_net.state_dict())
import gym
env = gym.make('CartPole-v1')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = DoubleDQNAgent(state_dim, action_dim)
num_episodes = 500
epsilon_start = 1.0
epsilon_end = 0.01
epsilon_decay = 500
for episode in range(num_episodes):
state = env.reset()
done = False
total_reward = 0
while not done:
epsilon = epsilon_end + (epsilon_start - epsilon_end) * np.exp(-1. * agent.steps_done / epsilon_decay)
action = agent.select_action(state, epsilon)
next_state, reward, done, _ = env.step(action)
total_reward += reward
agent.store_transition(state, action, reward, next_state, done)
agent.update()
state = next_state
agent.steps_done += 1
if episode % TARGET_UPDATE == 0:
agent.update_target_network()
print(f"Episode {episode}, Total Reward: {total_reward}")
env.close()
[Notice] 代码说明
- ReplayBuffer:经验回放池,用于存储状态、动作、奖励、下一个状态和是否结束标志。
- QNetwork:定义深度 Q 网络,包含 3 个全连接层。
- DoubleDQNAgent:
- 维护在线网络(Online Network)和目标网络(Target Network)。
- 使用在线网络选择动作,用目标网络计算目标值。
- 训练流程:
- 在每个时间步,使用 ( \epsilon )-贪婪策略选择动作。
- 与环境交互,存储数据到经验回放池。
- 采样小批量数据进行训练,通过Double DQN 公式计算目标 Q 值。
- 定期更新目标网络。
Python 3.11.5
torch 2.1.0
torchvision 0.16.0
gym 0.26.2
由于博文主要为了介绍相关算法的原理和应用的方法,缺乏对于实际效果的关注,算法可能在上述环境中的效果不佳或者无法运行,一是算法不适配上述环境,二是算法未调参和优化,三是没有呈现完整的代码,四是等等。上述代码用于了解和学习算法足够了,但若是想直接将上面代码应用于实际项目中,还需要进行修改。
六、优势与特点
| 特性 | DQN | Double DQN |
|---|
| 目标值计算 | 动作选择和评估使用同一网络 | 分离动作选择和目标评估 |
| 过估计偏差 | 明显存在 | 显著减小 |
| 训练稳定性 | 容易震荡 | 更加稳定 |
| 算法复杂度 | 较低 | 略微增加(多一次网络前向计算) |
分离动作选择和目标计算后,Double DQN 有效减少了过高估计的风险。
在 DQN 的基础上,仅需额外引入动作选择的分离逻辑,容易实现。
七、总结
Double DQN 算法的提出,主要是为了解决 DQN 中的'过估计偏差'问题。通过引入双网络,Double DQN 让动作选择和价值评估分离,大大提高了算法的稳定性和准确性。
相关免费在线工具
- 加密/解密文本
使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
- RSA密钥对生成器
生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
- Mermaid 预览与可视化编辑
基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
- 随机西班牙地址生成器
随机生成西班牙地址(支持马德里、加泰罗尼亚、安达卢西亚、瓦伦西亚筛选),支持数量快捷选择、显示全部与下载。 在线工具,随机西班牙地址生成器在线工具,online
- Gemini 图片去水印
基于开源反向 Alpha 混合算法去除 Gemini/Nano Banana 图片水印,支持批量处理与下载。 在线工具,Gemini 图片去水印在线工具,online
- curl 转代码
解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online