TD3 算法详解:双延迟深度确定性策略梯度
双延迟深度确定性策略梯度算法(Twin Delayed Deep Deterministic Policy Gradient,简称 TD3)是强化学习领域针对连续动作空间问题设计的一种高效算法。它由 Fujimoto 等人在 2018 年提出,旨在解决深度确定性策略梯度(DDPG)算法在实际应用中存在的训练不稳定和 Q 值过估计问题。
一、为什么需要 TD3?
DDPG 虽然能处理连续控制任务,但在实践中暴露了几个关键缺陷,导致其在复杂任务中容易失败或表现不佳:
- Q 值过估计(Overestimation Bias):Critic 网络在训练时倾向于高估 Q 值,这种偏差会随时间累积,误导 Actor 网络学习次优策略。
- 策略噪声敏感:由于 DDPG 输出确定性动作,缺乏探索机制,容易陷入局部最优解。
- 训练震荡:Actor 和 Critic 同时更新时,相互干扰可能导致训练过程不稳定。
TD3 正是为了缓解这些问题而设计的,它通过三项核心改进显著提升了算法的鲁棒性。
二、TD3 的核心思想
1. 双 Critic 网络(Twin Critics)
为了解决 Q 值高估问题,TD3 引入了两个独立的 Critic 网络 $Q_{\theta_1}$ 和 $Q_{\theta_2}$。在计算目标 Q 值时,不再使用单个网络的预测,而是取两个网络预测值的最小值:
$$y = r + \gamma \min \big( Q_{\theta_1'}(s', \pi_{\phi'}(s')), Q_{\theta_2'}(s', \pi_{\phi'}(s')) \big)$$
这种机制类似于 Double Q-Learning 的思想,能有效降低因函数近似误差导致的系统性高估偏差。
2. 延迟更新(Delayed Policy Updates)
在 DDPG 中,每次采样后都会更新 Actor 和 Critic。TD3 则降低了 Actor 的更新频率。通常的做法是:每更新 Critic 两次,才更新一次 Actor。这样做的好处是让 Critic 有更充分的时间收敛到更准确的 Q 值估计,从而避免 Actor 基于不稳定的 Q 值进行优化。
3. 目标策略平滑(Target Policy Smoothing)
为了防止策略对特定的状态 - 动作对过拟合,TD3 在计算目标动作时加入了少量的高斯噪声,并对结果进行裁剪:
$$a' = \pi_{\phi'}(s') + \text{clip}(\epsilon, -c, c), \quad \epsilon \sim \mathcal{N}(0, \sigma)$$
这里的噪声不仅增加了探索性,还使得目标 Q 值更加平滑,提高了算法对噪声和波动的鲁棒性。
三、数学细节与算法流程
1. 损失函数定义
Critic 损失: $$L(\theta_i) = \mathbb{E}{(s, a, r, s') \sim \mathcal{D}} \left[ (Q{\theta_i}(s, a) - y)^2 \right]$$ 其中 $y$ 为上述双 Critic 计算出的目标值。
Actor 损失: $$J(\phi) = -\mathbb{E}{s \sim \mathcal{D}} \left[ Q{\theta_1}(s, \pi_\phi(s)) \right]$$ 目标是最大化 Critic 评估的动作价值。
2. 软更新(Soft Update)
为了保证训练稳定性,目标网络参数不是硬复制,而是通过软更新方式缓慢逼近主网络: $$\theta' \leftarrow \tau \theta + (1 - \tau) \theta'$$ 其中 $\tau$ 是一个较小的系数(如 0.005)。
四、PyTorch 实现要点
下面是一个基于 PyTorch 的 TD3 实现框架。代码结构清晰,涵盖了环境交互、经验回放、网络定义及训练循环。
1. 环境与配置
import argparse
gym
torch
torch.nn nn
torch.optim optim
collections deque
numpy np
device = torch.cuda.is_available()
parser = argparse.ArgumentParser()
parser.add_argument(, default=)
parser.add_argument(, default=)
parser.add_argument(, default=)
parser.add_argument(, default=)
parser.add_argument(, default=)
parser.add_argument(, default=)
parser.add_argument(, default=)
parser.add_argument(, default=)
args = parser.parse_args()
env = gym.make(args.env_name)
state_dim = env.observation_space.shape[]
action_dim = env.action_space.shape[]
max_action = (env.action_space.high[])


