双延迟深度确定性策略梯度算法 (TD3) 详解
背景与动机
在连续动作空间的强化学习任务中,深度确定性策略梯度(DDPG)算法曾表现出色。然而,实际应用中我们发现 DDPG 存在几个显著缺陷:Q 值估计容易高估、策略更新不稳定以及训练过程对噪声敏感。
2018 年,Fujimoto 等人提出了双延迟深度确定性策略梯度算法(Twin Delayed Deep Deterministic Policy Gradient, TD3)。它并非推翻 DDPG,而是在其 Actor-Critic 架构基础上,通过三项关键改进解决了上述痛点,成为连续控制任务中的基准算法之一。
核心思想
TD3 的改进主要集中在三个方面,它们共同作用以提升训练的鲁棒性。
1. 双 Critic 网络(Twin Critics)
DDPG 使用单个 Critic 网络,容易因函数近似误差导致 Q 值系统性高估。TD3 引入了两个独立的 Critic 网络 $Q_{\theta_1}$ 和 $Q_{\theta_2}$。在计算目标值时,取两者预测的最小值:
$$y = r + \gamma \min \big( Q_{\theta_1'}(s', \pi_{\phi'}(s')), Q_{\theta_2'}(s', \pi_{\phi'}(s')) \big)$$
这种保守策略有效抑制了过估计偏差,让策略学习更稳健。
2. 延迟更新(Delayed Policy Updates)
在 DDPG 中,Actor 和 Critic 每步都同步更新,这可能导致 Actor 基于尚未收敛的 Critic 进行优化。TD3 降低了 Actor 的更新频率,通常每 $d$ 次 Critic 更新才更新一次 Actor(默认 $d=2$)。这意味着 Critic 有更多时间拟合数据,为 Actor 提供更准确的梯度信号。
3. 目标策略平滑(Target Policy Smoothing)
为了防止策略过拟合到特定的极端动作点,TD3 在计算目标动作时加入轻微的高斯噪声并裁剪:
$$a' = \pi_{\phi'}(s') + \text{clip}(\epsilon, -c, c), \quad \epsilon \sim \mathcal{N}(0, \sigma)$$
这相当于一种正则化手段,提高了算法对扰动和目标值波动的适应能力。
算法流程与数学细节
TD3 依然遵循 Actor-Critic 框架,但损失函数和更新逻辑有所调整。
Critic 损失函数
Critic 的目标是最小化均方误差(MSE),利用双网络输出最小值作为目标:
$$L(\theta_i) = \mathbb{E}{(s, a, r, s')} \big[ (Q{\theta_i}(s, a) - y)^2 \big]$$
其中 $y$ 的计算如上所述,包含了双 Critic 取最小值的操作。
Actor 策略梯度
Actor 旨在最大化 Critic 的 Q 值估计。由于使用了双 Critic,我们通常只依赖其中一个(如 $Q_{\theta_1}$)来计算梯度:
$$\nabla_\phi J(\phi) = \mathbb{E}{s} \big[ \nabla_a Q{\theta_1}(s, a) \big|{a=\pi\phi(s)} \nabla_\phi \pi_\phi(s) \big]$$
软更新机制
为了保持目标网络的稳定性,参数采用软更新(Soft Update)而非硬复制:
$$\theta' \leftarrow \tau \theta + (1 - \tau) \theta'$$
PyTorch 实现
下面是一个完整的 TD3 实现示例,基于 OpenAI Gym 环境。代码结构清晰,包含网络定义、经验回放及训练循环。
环境与配置
import argparse
import os
import random
import numpy as np
torch
torch.nn nn
torch.optim optim
torch.distributions Normal
gym
device = torch.cuda.is_available()
parser = argparse.ArgumentParser()
parser.add_argument(, default=)
parser.add_argument(, default=, =)
parser.add_argument(, default=, =)
parser.add_argument(, default=, =)
parser.add_argument(, default=, =)
parser.add_argument(, default=, =)
parser.add_argument(, default=, =)
parser.add_argument(, default=, =)
args = parser.parse_args()


