TD3 Explained: Twin Delayed Deep Deterministic Policy Gradient
I. Background and Motivation
Twin Delayed Deep Deterministic Policy Gradient (TD3) is a reinforcement learning algorithm designed specifically for continuous action spaces. Proposed by Fujimoto et al. in 2018, it addresses key practical weaknesses of the Deep Deterministic Policy Gradient (DDPG) algorithm.
Limitations of DDPG
DDPG combines a policy network (Actor) with a value function (Critic) and performs well in continuous action spaces, but it suffers from the following problems:
- Q-value overestimation: the Critic tends to overestimate Q-values during training, which destabilizes the Actor's learning.
- Insufficient exploration: the policy outputs a single deterministic action per state, so exploration relies entirely on externally added noise and training can easily get stuck in local optima.
- Training instability: the Critic and Actor are updated simultaneously, and their mutual interaction can cause training to oscillate.
To address these issues, TD3 introduces three core improvements that significantly increase the algorithm's robustness.
II. Core Ideas
TD3 builds on DDPG with three key improvements:
1. Twin Critics
Motivation: the single Critic in DDPG systematically overestimates Q-values.
Method: TD3 maintains two independent Critic networks and uses the smaller of their two estimates as the target Q-value:
$$y = r + \gamma \min \big( Q_{\theta_1'}(s', \pi_{\phi'}(s')), Q_{\theta_2'}(s', \pi_{\phi'}(s')) \big)$$
Effect: this clipped double-Q estimate substantially reduces overestimation bias and keeps the policy from being misled by inflated value estimates.
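A minimal PyTorch sketch of this target computation (the Critic and Actor target networks are passed in as callables; the full implementation appears in Section IV):

```python
import torch

def clipped_double_q_target(critic_1_target, critic_2_target, actor_target,
                            reward, next_state, gamma=0.99):
    """Bootstrap from the smaller of the two target Critic estimates (clipped double-Q)."""
    with torch.no_grad():
        next_action = actor_target(next_state)              # a' = pi_phi'(s')
        q1 = critic_1_target(next_state, next_action)
        q2 = critic_2_target(next_state, next_action)
        return reward + gamma * torch.min(q1, q2)           # y = r + gamma * min(Q1', Q2')
```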
2. Delayed Policy Updates
Motivation: in DDPG the Critic and Actor are updated at the same rate, so the Actor may be optimized against Q-value estimates that have not yet stabilized.
Method: update the Actor and the target networks less frequently than the Critics, typically once for every two Critic updates.
Effect: the Critics have time to settle before each policy update, which makes the learned policy noticeably more stable.
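Schematically, the update loop looks like the sketch below (sample_batch, update_critics, and update_actor_and_targets are hypothetical callables standing in for the real update code shown in Section IV):

```python
def delayed_update_loop(num_updates, sample_batch, update_critics,
                        update_actor_and_targets, policy_delay=2):
    """Sketch of delayed policy updates: Critics every step,
    Actor and target networks only every `policy_delay` steps."""
    for step in range(num_updates):
        batch = sample_batch()
        update_critics(batch)                    # both Critics, every iteration
        if step % policy_delay == 0:             # Actor + soft target updates, less often
            update_actor_and_targets(batch)
```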
3. Target Policy Smoothing
Motivation: the target policy in DDPG outputs deterministic actions, so the target value can overfit to narrow, erroneous peaks in the learned Q-function.
Method: when computing the target value, add clipped Gaussian noise to the target action:
$$a' = \pi_{\phi'}(s') + \text{clip}(\epsilon, -c, c), \quad \epsilon \sim \mathcal{N}(0, \sigma)$$
Effect: the smoothed target acts as a regularizer, making the algorithm more robust to noise and to fluctuations in the target values.
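A minimal sketch of the smoothing step (actor_target is assumed to be the target policy network; the result is also clamped back into the valid action range, as in the implementation below):

```python
import torch

def smoothed_target_action(actor_target, next_state, max_action,
                           policy_noise=0.2, noise_clip=0.5):
    """Target policy smoothing: add clipped Gaussian noise to the target action,
    then clamp the result to the valid action range."""
    action = actor_target(next_state)
    noise = (torch.randn_like(action) * policy_noise).clamp(-noise_clip, noise_clip)
    return (action + noise).clamp(-max_action, max_action)
```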
III. Mathematical Details
1. The Actor-Critic Framework
Actor-Critic methods combine policy learning (Actor) with value estimation (Critic): the Actor generates actions, and the Critic evaluates how well the current policy performs.
(1) Policy Gradient
The Actor learns the optimal policy by maximizing the expected cumulative reward:
$$\nabla_\phi J(\pi_\phi) = \mathbb{E}_{s \sim \rho^\pi} \left[ \nabla_\phi \pi_\phi(s) \, \nabla_a Q^\pi(s, a) \big|_{a=\pi_\phi(s)} \right]$$
where $\rho^\pi$ is the state distribution induced by the policy and $Q^\pi(s, a)$ is the action-value function estimated by the Critic.
(2) Value Estimation (Critic)
The Critic learns the state-action value function by minimizing the temporal difference (TD) error:
$$L(\theta) = \mathbb{E}_{(s, a, r, s') \sim \mathcal{D}} \left[ \big( Q_\theta(s, a) - y \big)^2 \right]$$
where the target value $y$ is defined as:
$$y = r + \gamma Q_{\theta'}(s', \pi_{\phi'}(s'))$$
2. TD3 Loss Functions
Critic loss
TD3 trains two Critic networks, each with the loss:
$$L(\theta_i) = \mathbb{E}_{(s, a, r, s') \sim \mathcal{D}} \left[ \big( Q_{\theta_i}(s, a) - y \big)^2 \right]$$
where the target value uses the minimum of the two target Critics:
$$y = r + \gamma \min \big( Q_{\theta_1'}(s', a'), Q_{\theta_2'}(s', a') \big)$$
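Combining this clipped double-Q estimate with target policy smoothing, and including the termination flag $d$ used in the implementation below, the complete TD3 target can be written as:

$$\tilde{a} = \operatorname{clip}\big(\pi_{\phi'}(s') + \epsilon,\ -a_{\max},\ a_{\max}\big), \qquad \epsilon \sim \operatorname{clip}\big(\mathcal{N}(0, \sigma),\ -c,\ c\big)$$
$$y = r + \gamma\,(1 - d)\,\min_{i=1,2} Q_{\theta_i'}(s', \tilde{a})$$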
Actor policy gradient
The Actor is optimized by maximizing the output of the first Critic:
$$\nabla_\phi J(\phi) = \mathbb{E}_{s \sim \rho^\pi} \big[ \nabla_a Q_{\theta_1}(s, a) \big|_{a=\pi_\phi(s)} \nabla_\phi \pi_\phi(s) \big]$$
IV. PyTorch Implementation
Below is a complete TD3 implementation based on PyTorch and OpenAI Gym, organized into configuration, an experience replay buffer, network definitions, the TD3 agent, and the training loop. Note that it uses the classic Gym API (env.reset() returning only the state, env.step() returning a 4-tuple), so it assumes an older gym release that still provides Pendulum-v0.
1. Configuration and Argument Parsing
import argparse
import os
import random
import numpy as np
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from collections import namedtuple
from itertools import count
device = 'cuda' if torch.cuda.is_available() else 'cpu'
parser = argparse.ArgumentParser()
parser.add_argument('--mode', default='train', type=str)            # used by the main loop below
parser.add_argument('--env_name', default="Pendulum-v0")
parser.add_argument('--tau', default=0.005, type=float)             # soft update coefficient
parser.add_argument('--learning_rate', default=3e-4, type=float)
parser.add_argument('--gamma', default=0.99, type=float)            # discount factor
parser.add_argument('--capacity', default=50000, type=int)          # replay buffer size
parser.add_argument('--batch_size', default=100, type=int)
parser.add_argument('--policy_noise', default=0.2, type=float)      # std of target policy smoothing noise
parser.add_argument('--noise_clip', default=0.5, type=float)        # clip range for that noise
parser.add_argument('--policy_delay', default=2, type=int)          # Actor update period
parser.add_argument('--num_iteration', default=100000, type=int)    # number of training episodes
args = parser.parse_args()
2. Experience Replay Buffer
class Replay_buffer():
    def __init__(self, max_size=args.capacity):
        self.storage = []
        self.max_size = max_size
        self.ptr = 0

    def push(self, data):
        # data is a transition tuple: (state, next_state, action, reward, done)
        if len(self.storage) == self.max_size:
            self.storage[int(self.ptr)] = data          # overwrite the oldest entry when full
            self.ptr = (self.ptr + 1) % self.max_size
        else:
            self.storage.append(data)

    def sample(self, batch_size):
        ind = np.random.randint(0, len(self.storage), size=batch_size)
        x, y, u, r, d = [], [], [], [], []
        for i in ind:
            X, Y, U, R, D = self.storage[i]
            x.append(np.array(X, copy=False))
            y.append(np.array(Y, copy=False))
            u.append(np.array(U, copy=False))
            r.append(np.array(R, copy=False))
            d.append(np.array(D, copy=False))
        # states, next_states, actions, rewards, dones (rewards/dones as column vectors)
        return np.array(x), np.array(y), np.array(u), np.array(r).reshape(-1, 1), np.array(d).reshape(-1, 1)
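A quick usage sketch (not part of the original script) with synthetic Pendulum-v0-like transitions (3-dimensional states, 1-dimensional actions), to show the shapes returned by sample():

```python
buf = Replay_buffer(max_size=1000)
for _ in range(200):
    s, s2 = np.random.randn(3), np.random.randn(3)            # state, next state
    a = np.random.uniform(-2.0, 2.0, size=1)                  # action
    buf.push((s, s2, a, float(np.random.randn()), 0.0))       # (s, s', a, r, done)

x, y, u, r, d = buf.sample(batch_size=32)
print(x.shape, u.shape, r.shape, d.shape)   # (32, 3) (32, 1) (32, 1) (32, 1)
```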
3. Network Definitions
class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, max_action):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_dim, 400)
        self.fc2 = nn.Linear(400, 300)
        self.fc3 = nn.Linear(300, action_dim)
        self.max_action = max_action

    def forward(self, state):
        a = F.relu(self.fc1(state))
        a = F.relu(self.fc2(a))
        # tanh bounds the output to [-1, 1]; scaling maps it to the valid action range
        return self.max_action * torch.tanh(self.fc3(a))


class Critic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, 400)
        self.fc2 = nn.Linear(400, 300)
        self.fc3 = nn.Linear(300, 1)

    def forward(self, state, action):
        state_action = torch.cat([state, action], 1)    # concatenate state and action
        q = F.relu(self.fc1(state_action))
        q = F.relu(self.fc2(q))
        return self.fc3(q)                              # scalar Q-value
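As a quick sanity check (an addition, assuming Pendulum-v0 dimensions: state_dim=3, action_dim=1, max_action=2.0), the networks can be probed with random tensors:

```python
actor = Actor(state_dim=3, action_dim=1, max_action=2.0)
critic = Critic(state_dim=3, action_dim=1)

s = torch.randn(8, 3)        # batch of 8 states
a = actor(s)                 # shape (8, 1), bounded to [-2, 2] by the scaled tanh
q = critic(s, a)             # shape (8, 1), one scalar Q-value per sample
print(a.shape, q.shape)
```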
4. Algorithm Logic and Training Updates
class TD3():
    def __init__(self, state_dim, action_dim, max_action):
        self.actor = Actor(state_dim, action_dim, max_action).to(device)
        self.actor_target = Actor(state_dim, action_dim, max_action).to(device)
        self.critic_1 = Critic(state_dim, action_dim).to(device)
        self.critic_1_target = Critic(state_dim, action_dim).to(device)
        self.critic_2 = Critic(state_dim, action_dim).to(device)
        self.critic_2_target = Critic(state_dim, action_dim).to(device)

        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=args.learning_rate)
        self.critic_1_optimizer = optim.Adam(self.critic_1.parameters(), lr=args.learning_rate)
        self.critic_2_optimizer = optim.Adam(self.critic_2.parameters(), lr=args.learning_rate)

        # Target networks start as exact copies of the online networks
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.critic_1_target.load_state_dict(self.critic_1.state_dict())
        self.critic_2_target.load_state_dict(self.critic_2.state_dict())

        self.max_action = max_action            # needed to clamp the smoothed target action
        self.memory = Replay_buffer(args.capacity)
        self.num_training = 0

    def select_action(self, state):
        state = torch.tensor(state.reshape(1, -1)).float().to(device)
        return self.actor(state).cpu().data.numpy().flatten()

    def update(self, num_iteration):
        for i in range(num_iteration):
            x, y, u, r, d = self.memory.sample(args.batch_size)
            state = torch.FloatTensor(x).to(device)
            action = torch.FloatTensor(u).to(device)
            next_state = torch.FloatTensor(y).to(device)
            done = torch.FloatTensor(d).to(device)
            reward = torch.FloatTensor(r).to(device)

            # Target policy smoothing: clipped Gaussian noise on the target action
            noise = torch.ones_like(action).data.normal_(0, args.policy_noise).to(device)
            noise = noise.clamp(-args.noise_clip, args.noise_clip)
            next_action = (self.actor_target(next_state) + noise).clamp(-self.max_action, self.max_action)

            # Clipped double-Q target: minimum of the two target Critics
            target_Q1 = self.critic_1_target(next_state, next_action)
            target_Q2 = self.critic_2_target(next_state, next_action)
            target_Q = torch.min(target_Q1, target_Q2)
            target_Q = reward + ((1 - done) * args.gamma * target_Q).detach()

            # Update Critic 1
            current_Q1 = self.critic_1(state, action)
            loss_Q1 = F.mse_loss(current_Q1, target_Q)
            self.critic_1_optimizer.zero_grad()
            loss_Q1.backward()
            self.critic_1_optimizer.step()

            # Update Critic 2
            current_Q2 = self.critic_2(state, action)
            loss_Q2 = F.mse_loss(current_Q2, target_Q)
            self.critic_2_optimizer.zero_grad()
            loss_Q2.backward()
            self.critic_2_optimizer.step()

            # Delayed Actor update
            if i % args.policy_delay == 0:
                actor_loss = -self.critic_1(state, self.actor(state)).mean()
                self.actor_optimizer.zero_grad()
                actor_loss.backward()
                self.actor_optimizer.step()

                # Soft (Polyak) update of all target networks
                for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
                    target_param.data.copy_((1 - args.tau) * target_param.data + args.tau * param.data)
                for param, target_param in zip(self.critic_1.parameters(), self.critic_1_target.parameters()):
                    target_param.data.copy_((1 - args.tau) * target_param.data + args.tau * param.data)
                for param, target_param in zip(self.critic_2.parameters(), self.critic_2_target.parameters()):
                    target_param.data.copy_((1 - args.tau) * target_param.data + args.tau * param.data)

            self.num_training += 1
5. Main Program
if __name__ == '__main__':
    env = gym.make(args.env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    max_action = float(env.action_space.high[0])
    agent = TD3(state_dim, action_dim, max_action)

    if args.mode == 'train':
        print("Collecting experience...")
        for i in range(args.num_iteration):
            state = env.reset()
            ep_r = 0
            for t in range(2000):
                action = agent.select_action(state)
                # Exploration noise added to the deterministic action during data collection
                action = action + np.random.normal(0, 0.1, size=env.action_space.shape[0])
                action = action.clip(env.action_space.low, env.action_space.high)
                next_state, reward, done, info = env.step(action)
                ep_r += reward
                agent.memory.push((state, next_state, action, reward, float(done)))
                state = next_state
                if done or t == 1999:
                    break
            # Start updating once the replay buffer is (almost) full
            if len(agent.memory.storage) >= args.capacity - 1:
                agent.update(10)
            if i % 50 == 0:
                print(f"Ep_i {i}, Ep_Reward {ep_r:.2f}")
V. Summary
TD3 not only remedies DDPG's weaknesses but also serves as an important theoretical and practical reference for research on the stability of reinforcement learning. Its key contribution is taming Q-value overestimation, which makes training considerably more stable and policy updates more robust. As a milestone algorithm for continuous-action reinforcement learning, its ideas, in particular the clipped double-Q target, have also been adopted by later algorithms such as SAC.
Reference: Fujimoto, S., van Hoof, H., and Meger, D. "Addressing Function Approximation Error in Actor-Critic Methods." ICML 2018.


