跳到主要内容
双延迟深度确定性策略梯度算法 (TD3) 详解 | 极客日志
Python AI 算法
双延迟深度确定性策略梯度算法 (TD3) 详解 TD3 算法针对 DDPG 在连续动作空间中的 Q 值过估计和不稳定问题提出改进。核心包括双 Critic 网络减少偏差、延迟更新增强稳定性、目标策略平滑提升鲁棒性。通过 PyTorch 实现,适用于复杂控制任务。
片刻 发布于 2026/3/23 更新于 2026/4/28 4 浏览双延迟深度确定性策略梯度算法 (TD3) 详解
一、TD3 算法的背景
双延迟深度确定性策略梯度算法,TD3(Twin Delayed Deep Deterministic Policy Gradient)是强化学习中专为解决连续动作空间问题设计的一种算法。TD3 算法的提出是在深度确定性策略梯度(DDPG)算法的基础上改进而来,用于解决强化学习训练中存在的一些关键挑战。
二、TD3 的背景
1. TD3 的理论背景
TD3 的提出基于以下几个强化学习的理论与技术发展:
Actor-Critic 架构 :
Actor 网络负责生成动作,Critic 网络负责评估动作的价值(Q 值)。这种架构使得算法能够高效地解决高维连续动作问题。
Actor 更新目标是最大化 Critic 网络的 Q 值,而 Critic 网络优化目标是最小化 Q 值预测误差。
确定性策略梯度(Deterministic Policy Gradient, DPG) :
DPG 是强化学习中一种适用于连续动作空间的策略梯度方法,TD3 继承了 DPG 的优势,即通过学习一个确定性策略直接生成动作。
双 Q 学习(Double Q-Learning) :
TD3 借鉴了双 Q 学习的思想,使用两个独立的 Critic 网络来降低 Q 值估计的偏差。
经验回放池(Replay Buffer) :
TD3 通过从经验回放池中采样数据训练网络,打破数据相关性,提高了学习效率。
2. DDPG 的局限性
TD3 算法由 Fujimoto 等人在 2018 年提出,对深度确定性策略梯度(Deep Deterministic Policy Gradient, DDPG)算法的改进。DDPG 是一种结合策略(Actor)和价值函数(Critic)的强化学习方法,可以在连续动作空间中表现出色。然而,DDPG 存在以下问题:
这些问题会使训练结果不够鲁棒,甚至使算法在复杂任务中失败。
Q 值过估计问题 :Critic 网络在训练时容易高估 Q 值,从而导致策略网络(Actor)学习不稳定。
策略噪声问题 :由于策略直接输出确定性动作,在训练时容易陷入局部最优解。
训练不稳定性 :Critic 网络和 Actor 网络同时训练时,相互影响可能导致训练震荡。
为了解决上述问题,TD3 通过以下三点创新改进了 DDPG:
三、TD3 算法的核心思想
TD3 在 DDPG 的基础上提出了三项关键改进:
1. 双 Critic 网络(Twin Critics)
动机 :DDPG 中的 Critic 网络在估计 Q 值时存在系统性的高估问题。
方法 :TD3 使用两个独立的 Critic 网络计算 Q 值,取两者的最小值来作为目标 Q 值。
效果 :有效减少了 Q 值的高估偏差(Overestimation Bias)。
2. 延迟更新(Delayed Policy Updates)
动机 :在 DDPG 中,Critic 网络和 Actor 网络同时更新,可能导致 Actor 策略在不稳定的 Q 值估计上进行优化。
方法 :TD3 降低 Actor 和目标网络的更新频率,通常在 Critic 更新两次后才更新 Actor。
效果 :降低了 Actor 网络的更新频率,从而提高了策略的稳定性。
3. 目标策略平滑(Target Policy Smoothing)
动机 :DDPG 中的目标策略直接输出确定性动作,容易对极端动作过拟合。TD3 通过在目标策略中加入高斯噪声,对动作进行'平滑'。
方法 :在目标值计算中,对动作加入噪声并裁剪到一定范围。
效果 :提高了算法对噪声和目标值波动的鲁棒性。
四、TD3 算法详细讲解 TD3(Twin Delayed Deep Deterministic Policy Gradient)适用于连续动作空间问题,主要基于 Actor-Critic 框架和深度确定性策略梯度(DDPG)。以下是 TD3 的数学基础与推导。
1. Actor-Critic 框架的核心 Actor-Critic 方法的核心在于将策略学习(Actor)与价值评估(Critic)结合。Actor 负责生成动作,Critic 负责评估当前策略的表现。Actor 网络优化目标是通过 Critic 网络的反馈提高策略质量。
(1) 策略梯度 $\nabla_\phi J(\pi_\phi) = \mathbb{E}{s \sim \rho^\pi} \left[ \nabla \phi \pi_\phi(s) \nabla_a Q^\pi(s, a) \big|{a=\pi \phi(s)} \right]$
$Q^\pi(s, a)$ 是 Critic 估计的动作值函数。
$\pi_\phi(s)$ 是 Actor 的策略函数。
(2) 价值评估 (Critic) Critic 通过最小化时间差分(Temporal Difference, TD)误差,学习状态 - 动作值函数:
$L(\theta) = \mathbb{E}{(s, a, r, s') \sim \mathcal{D}} \left[ \big( Q \theta(s, a) - y \big)^2 \right]$
$y = r + \gamma Q_{\theta'}(s', \pi_{\phi'}(s'))$
$\mathcal{D}$ 是经验回放池中采样的数据。
$\theta'$ 和 $\phi'$ 是 Critic 和 Actor 的目标网络参数。
2. TD3 的关键改进 TD3 在 DDPG 的基础上,针对 Q 值过估计和策略训练不稳定问题,提出了三项核心改进。
(1) 双 Critic 网络 TD3 引入两个 Critic 网络 $Q_{\theta_1}$ 和 $Q_{\theta_2}$,通过取最小值来降低 Q 值的高估偏差:
$y = r + \gamma \min \big( Q_{\theta_1'}(s', \pi_{\phi'}(s')), Q_{\theta_2'}(s', \pi_{\phi'}(s')) \big)$
目标是防止策略在训练中受到错误 Q 值估计的误导。
(2) 延迟 Actor 更新 为了避免 Actor 网络频繁更新导致策略不稳定,TD3 在 Critic 更新 n 次后才更新 Actor 一次(通常 n=2)。Actor 的优化目标为:
$L(\phi) = -\mathbb{E}{s \sim \mathcal{D}} \big[ Q {\theta_1}(s, \pi_\phi(s)) \big]$
Critic 网络训练稳定后,Actor 的策略梯度才会更加准确。
$L(\phi) = -\mathbb{E}{s} \big[ Q {\theta_1}(s, \pi_\phi(s)) \big]$
(3) 目标动作平滑 在计算目标值 $y$ 时,对动作加入高斯噪声 $\epsilon \sim \mathcal{N}(0, \sigma)$ 并进行裁剪,防止策略过拟合到极端动作:
$a' = \pi_{\phi'}(s') + \text{clip}(\epsilon, -c, c)$
这样可以让目标 Q 值更加平滑,增强策略的鲁棒性。
3. 完整 TD3 算法流程
初始化 Actor 和 Critic 网络及其对应的目标网络;
重复以上步骤直到收敛。
$\theta_i' \leftarrow \tau \theta_i + (1 - \tau) \theta_i'$
$\phi' \leftarrow \tau \phi + (1 - \tau) \phi'$
延迟更新 Actor :每隔 d 步,更新 Actor 策略:
$L(\phi) = -\mathbb{E}{s} \big[ Q {\theta_1}(s, \pi_\phi(s)) \big]$
更新 Critic :通过最小化损失函数更新 $Q_{\theta_1}$ 和 $Q_{\theta_2}$:
$L(\theta_i) = \mathbb{E}{(s, a, r, s')} \big[ (Q {\theta_i}(s, a) - y)^2 \big]$
从 $\mathcal{D}$ 中随机抽取一个批量数据 $(s, a, r, s')$:
与环境交互,使用当前策略 $\pi_\phi$ 执行动作 $a_t$,存储 $(s_t, a_t, r_t, s_{t+1})$;
构建经验回放池 $\mathcal{D}$,用于存储交互数据。
4. 数学细节解析
(1) Critic 损失函数 TD3 使用两个 Critic 网络,损失函数为:
$L(\theta) = \mathbb{E} \big[ (Q_{\theta}(s, a) - y)^2 \big]$
$y = r + \gamma \min \big( Q_{\theta_1'}(s', a'), Q_{\theta_2'}(s', a') \big)$
(2) Actor 策略梯度 Actor 通过最大化 Critic 网络的输出优化策略:
$\nabla_\phi J(\phi) = \mathbb{E}{s \sim \rho^\pi} \big[ \nabla_a Q {\theta_1}(s, a) \big|{a=\pi \phi(s)} \nabla_\phi \pi_\phi(s) \big]$
(3) 延迟更新的效果 延迟更新使 Actor 网络只在 Critic 网络收敛后才更新,减少了 Actor 网络梯度被不准确 Q 值引导的风险,从而提高了稳定性。
五、TD3 算法实现
TD3 的简易版核心实现 """TD3 Algorithm Implementation"""
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random
class Actor (nn.Module):
def __init__ (self, state_dim, action_dim, max_action ):
super (Actor, self ).__init__()
self .layer1 = nn.Linear(state_dim, 256 )
self .layer2 = nn.Linear(256 , 256 )
self .layer3 = nn.Linear(256 , action_dim)
self .max_action = max_action
def forward (self, x ):
x = torch.relu(self .layer1(x))
x = torch.relu(self .layer2(x))
x = self .max_action * torch.tanh(self .layer3(x))
return x
class Critic (nn.Module):
def __init__ (self, state_dim, action_dim ):
super (Critic, self ).__init__()
self .layer1 = nn.Linear(state_dim + action_dim, 256 )
self .layer2 = nn.Linear(256 , 256 )
self .layer3 = nn.Linear(256 , 1 )
def forward (self, x, u ):
x = torch.cat([x, u], 1 )
x = torch.relu(self .layer1(x))
x = torch.relu(self .layer2(x))
return self .layer3(x)
class TD3 :
def __init__ (self, state_dim, action_dim, max_action ):
self .actor = Actor(state_dim, action_dim, max_action).to(device)
self .actor_target = Actor(state_dim, action_dim, max_action).to(device)
self .actor_optimizer = optim.Adam(self .actor.parameters(), lr=1e-3 )
self .critic1 = Critic(state_dim, action_dim).to(device)
self .critic2 = Critic(state_dim, action_dim).to(device)
self .critic1_target = Critic(state_dim, action_dim).to(device)
self .critic2_target = Critic(state_dim, action_dim).to(device)
self .critic_optimizer = optim.Adam(
list (self .critic1.parameters()) + list (self .critic2.parameters()), lr=1e-3
)
self .max_action = max_action
self .replay_buffer = deque(maxlen=1000000 )
def update (self, batch_size=100 , gamma=0.99 , tau=0.005 , policy_noise=0.2 , noise_clip=0.5 , delay=2 ):
pass
def select_action (self, state ):
state = torch.FloatTensor(state.reshape(1 , -1 )).to(device)
return self .actor(state).cpu().data.numpy().flatten()
参数配置 import argparse
from collections import namedtuple
from itertools import count
import os, sys, random
import numpy as np
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Normal
from tensorboardX import SummaryWriter
device = 'cuda' if torch.cuda.is_available() else 'cpu'
parser = argparse.ArgumentParser()
parser.add_argument('--mode' , default='train' , type =str )
parser.add_argument("--env_name" , default="Pendulum-v0" )
parser.add_argument('--tau' , default=0.005 , type =float )
parser.add_argument('--target_update_interval' , default=1 , type =int )
parser.add_argument('--iteration' , default=5 , type =int )
parser.add_argument('--learning_rate' , default=3e-4 , type =float )
parser.add_argument('--gamma' , default=0.99 , type =int )
parser.add_argument('--capacity' , default=50000 , type =int )
parser.add_argument('--num_iteration' , default=100000 , type =int )
parser.add_argument('--batch_size' , default=100 , type =int )
parser.add_argument('--seed' , default=1 , type =int )
parser.add_argument('--num_hidden_layers' , default=2 , type =int )
parser.add_argument('--sample_frequency' , default=256 , type =int )
parser.add_argument('--activation' , default='Relu' , type =str )
parser.add_argument('--render' , default=False , type =bool )
parser.add_argument('--log_interval' , default=50 , type =int )
parser.add_argument('--load' , default=False , type =bool )
parser.add_argument('--render_interval' , default=100 , type =int )
parser.add_argument('--policy_noise' , default=0.2 , type =float )
parser.add_argument('--noise_clip' , default=0.5 , type =float )
parser.add_argument('--policy_delay' , default=2 , type =int )
parser.add_argument('--exploration_noise' , default=0.1 , type =float )
parser.add_argument('--max_episode' , default=2000 , type =int )
parser.add_argument('--print_log' , default=5 , type =int )
args = parser.parse_args()
script_name = os.path.basename(__file__)
directory = './exp' + script_name + args.env_name + './'
经验回放缓冲区 class Replay_buffer ():
def __init__ (self, max_size=args.capacity ):
self .storage = []
self .max_size = max_size
self .ptr = 0
def push (self, data ):
if len (self .storage) == self .max_size:
self .storage[int (self .ptr)] = data
self .ptr = (self .ptr + 1 ) % self .max_size
else :
self .storage.append(data)
def sample (self, batch_size ):
ind = np.random.randint(0 , len (self .storage), size=batch_size)
x, y, u, r, d = [], [], [], [], []
for i in ind:
X, Y, U, R, D = self .storage[i]
x.append(np.array(X, copy=False ))
y.append(np.array(Y, copy=False ))
u.append(np.array(U, copy=False ))
r.append(np.array(R, copy=False ))
d.append(np.array(D, copy=False ))
return np.array(x), np.array(y), np.array(u), np.array(r).reshape(-1 , 1 ), np.array(d).reshape(-1 , 1 )
网络配置 class Actor (nn.Module):
def __init__ (self, state_dim, action_dim, max_action ):
super (Actor, self ).__init__()
self .fc1 = nn.Linear(state_dim, 400 )
self .fc2 = nn.Linear(400 , 300 )
self .fc3 = nn.Linear(300 , action_dim)
self .max_action = max_action
def forward (self, state ):
a = F.relu(self .fc1(state))
a = F.relu(self .fc2(a))
a = torch.tanh(self .fc3(a)) * self .max_action
return a
class Critic (nn.Module):
def __init__ (self, state_dim, action_dim ):
super (Critic, self ).__init__()
self .fc1 = nn.Linear(state_dim + action_dim, 400 )
self .fc2 = nn.Linear(400 , 300 )
self .fc3 = nn.Linear(300 , 1 )
def forward (self, state, action ):
state_action = torch.cat([state, action], 1 )
q = F.relu(self .fc1(state_action))
q = F.relu(self .fc2(q))
q = self .fc3(q)
return q
算法逻辑 class TD3 ():
def __init__ (self, state_dim, action_dim, max_action ):
self .actor = Actor(state_dim, action_dim, max_action).to(device)
self .actor_target = Actor(state_dim, action_dim, max_action).to(device)
self .critic_1 = Critic(state_dim, action_dim).to(device)
self .critic_1_target = Critic(state_dim, action_dim).to(device)
self .critic_2 = Critic(state_dim, action_dim).to(device)
self .critic_2_target = Critic(state_dim, action_dim).to(device)
self .actor_optimizer = optim.Adam(self .actor.parameters())
self .critic_1_optimizer = optim.Adam(self .critic_1.parameters())
self .critic_2_optimizer = optim.Adam(self .critic_2.parameters())
self .actor_target.load_state_dict(self .actor.state_dict())
self .critic_1_target.load_state_dict(self .critic_1.state_dict())
self .critic_2_target.load_state_dict(self .critic_2.state_dict())
self .max_action = max_action
self .memory = Replay_buffer(args.capacity)
self .writer = SummaryWriter(directory)
self .num_critic_update_iteration = 0
self .num_actor_update_iteration = 0
self .num_training = 0
def select_action (self, state ):
state = torch.tensor(state.reshape(1 , -1 )).float ().to(device)
return self .actor(state).cpu().data.numpy().flatten()
def update (self, num_iteration ):
if self .num_training % 500 == 0 :
print ("====================================" )
print ("model has been trained for {} times..." .format (self .num_training))
print ("====================================" )
for i in range (num_iteration):
x, y, u, r, d = self .memory.sample(args.batch_size)
state = torch.FloatTensor(x).to(device)
action = torch.FloatTensor(u).to(device)
next_state = torch.FloatTensor(y).to(device)
done = torch.FloatTensor(d).to(device)
reward = torch.FloatTensor(r).to(device)
noise = torch.ones_like(action).data.normal_(0 , args.policy_noise).to(device)
noise = noise.clamp(-args.noise_clip, args.noise_clip)
next_action = (self .actor_target(next_state) + noise).clamp(-self .max_action, self .max_action)
target_Q1 = self .critic_1_target(next_state, next_action)
target_Q2 = self .critic_2_target(next_state, next_action)
target_Q = torch.min (target_Q1, target_Q2)
target_Q = reward + ((1 - done) * args.gamma * target_Q).detach()
current_Q1 = self .critic_1(state, action)
loss_Q1 = F.mse_loss(current_Q1, target_Q)
self .critic_1_optimizer.zero_grad()
loss_Q1.backward()
self .critic_1_optimizer.step()
current_Q2 = self .critic_2(state, action)
loss_Q2 = F.mse_loss(current_Q2, target_Q)
self .critic_2_optimizer.zero_grad()
loss_Q2.backward()
self .critic_2_optimizer.step()
if i % args.policy_delay == 0 :
actor_loss = - self .critic_1(state, self .actor(state)).mean()
self .actor_optimizer.zero_grad()
actor_loss.backward()
self .actor_optimizer.step()
for param, target_param in zip (self .actor.parameters(), self .actor_target.parameters()):
target_param.data.copy_((1 - args.tau) * target_param.data + args.tau * param.data)
for param, target_param in zip (self .critic_1.parameters(), self .critic_1_target.parameters()):
target_param.data.copy_((1 - args.tau) * target_param.data + args.tau * param.data)
for param, target_param in zip (self .critic_2.parameters(), self .critic_2_target.parameters()):
target_param.data.copy_((1 - args.tau) * target_param.data + args.tau * param.data)
self .num_actor_update_iteration += 1
self .num_critic_update_iteration += 1
self .num_training += 1
def save (self ):
torch.save(self .actor.state_dict(), directory + 'actor.pth' )
torch.save(self .actor_target.state_dict(), directory + 'actor_target.pth' )
torch.save(self .critic_1.state_dict(), directory + 'critic_1.pth' )
torch.save(self .critic_1_target.state_dict(), directory + 'critic_1_target.pth' )
torch.save(self .critic_2.state_dict(), directory + 'critic_2.pth' )
torch.save(self .critic_2_target.state_dict(), directory + 'critic_2_target.pth' )
print ("====================================" )
print ("Model has been saved..." )
print ("====================================" )
def load (self ):
self .actor.load_state_dict(torch.load(directory + 'actor.pth' ))
self .actor_target.load_state_dict(torch.load(directory + 'actor_target.pth' ))
self .critic_1.load_state_dict(torch.load(directory + 'critic_1.pth' ))
self .critic_1_target.load_state_dict(torch.load(directory + 'critic_1_target.pth' ))
self .critic_2.load_state_dict(torch.load(directory + 'critic_2.pth' ))
self .critic_2_target.load_state_dict(torch.load(directory + 'critic_2_target.pth' ))
print ("====================================" )
print ("Model has been loaded..." )
print ("====================================" )
主程序入口 if __name__ == '__main__' :
agent = TD3(state_dim, action_dim, max_action)
ep_r = 0
if args.mode == 'test' :
agent.load()
for i in range (args.iteration):
state = env.reset()
for t in count():
action = agent.select_action(state)
next_state, reward, done, info = env.step(np.float32(action))
ep_r += reward
env.render()
if done or t == 2000 :
print ("Ep_i \t{}, the ep_r is \t{:0.2f}, the step is \t{}" .format (i, ep_r, t))
break
state = next_state
elif args.mode == 'train' :
print ("====================================" )
print ("Collection Experience..." )
print ("====================================" )
if args.load:
agent.load()
for i in range (args.num_iteration):
state = env.reset()
for t in range (2000 ):
action = agent.select_action(state)
action = action + np.random.normal(0 , args.exploration_noise, size=env.action_space.shape[0 ])
action = action.clip(env.action_space.low, env.action_space.high)
next_state, reward, done, info = env.step(action)
ep_r += reward
if args.render and i >= args.render_interval:
env.render()
agent.memory.push((state, next_state, action, reward, np.float (done)))
if (i + 1 ) % 10 == 0 :
print ('Episode {}, The memory size is {} ' .format (i, len (agent.memory.storage)))
if len (agent.memory.storage) >= args.capacity - 1 :
agent.update(10 )
state = next_state
if done or t == args.max_episode - 1 :
agent.writer.add_scalar('ep_r' , ep_r, global_step=i)
if i % args.print_log == 0 :
print ("Ep_i \t{}, the ep_r is \t{:0.2f}, the step is \t{}" .format (i, ep_r, t))
ep_r = 0
break
else :
raise NameError("mode wrong!!!" )
六、TD3 的优势
降低 Q 值高估偏差 :双 Critic 网络的最小值策略有效减少了偏差。
增强训练稳定性 :延迟更新减少了网络间的干扰。
适应复杂环境 :目标动作平滑提高了鲁棒性。
七、总结 TD3 不仅改进了 DDPG 的不足,还为强化学习的稳定性研究提供了重要的理论和实践参考。其成功之处在于:
克服了 Q 值过估计问题 ,使得训练过程更加稳定;
提升了策略更新的鲁棒性 ,能更高效地探索动作空间。
作为一个里程碑式的算法,TD3 推动了连续动作空间强化学习的发展,为后续算法(如 SAC、PPO 等)提供了宝贵的启发。
参考文献:Addressing Function Approximation Error in Actor-Critic Methods
相关免费在线工具 加密/解密文本 使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
RSA密钥对生成器 生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
Mermaid 预览与可视化编辑 基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
随机西班牙地址生成器 随机生成西班牙地址(支持马德里、加泰罗尼亚、安达卢西亚、瓦伦西亚筛选),支持数量快捷选择、显示全部与下载。 在线工具,随机西班牙地址生成器在线工具,online
Gemini 图片去水印 基于开源反向 Alpha 混合算法去除 Gemini/Nano Banana 图片水印,支持批量处理与下载。 在线工具,Gemini 图片去水印在线工具,online
curl 转代码 解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online