跳到主要内容
极客日志极客日志
首页博客AI提示词GitHub精选代理工具
搜索
|注册
博客列表
PythonAI算法

TD3 算法详解:双延迟深度确定性策略梯度

TD3 针对 DDPG 的 Q 值过估计问题提出改进,通过双 Critic 网络、延迟更新和目标策略平滑三项核心机制提升连续控制任务的稳定性与性能。本文解析其数学原理并给出 PyTorch 实现细节,涵盖环境配置、经验回放、网络定义及训练循环,适合希望深入理解并复现该算法的开发者参考。

灰度发布发布于 2026/3/23更新于 2026/5/96 浏览
TD3 算法详解:双延迟深度确定性策略梯度

TD3 算法详解:双延迟深度确定性策略梯度

一、算法背景与动机

双延迟深度确定性策略梯度算法(Twin Delayed Deep Deterministic Policy Gradient, TD3)是强化学习中专门针对连续动作空间问题设计的一种算法。它由 Fujimoto 等人在 2018 年提出,旨在解决深度确定性策略梯度(DDPG)算法在实际应用中存在的关键挑战。

DDPG 的局限性

DDPG 结合了策略(Actor)和价值函数(Critic),在连续动作空间中表现优异,但存在以下主要问题:

  • Q 值过估计问题:Critic 网络在训练时容易高估 Q 值,导致策略网络(Actor)学习不稳定。
  • 策略噪声问题:由于策略直接输出确定性动作,在训练时容易陷入局部最优解。
  • 训练不稳定性:Critic 网络和 Actor 网络同时更新时,相互影响可能导致训练震荡。

为了解决上述问题,TD3 通过三项核心改进显著提升了算法的鲁棒性。

二、核心思想

TD3 在 DDPG 的基础上提出了三项关键改进:

1. 双 Critic 网络(Twin Critics)

动机:DDPG 中的 Critic 网络在估计 Q 值时存在系统性的高估问题。

方法:TD3 使用两个独立的 Critic 网络计算 Q 值,取两者的最小值作为目标 Q 值。

$$y = r + \gamma \min \big( Q_{\theta_1'}(s', \pi_{\phi'}(s')), Q_{\theta_2'}(s', \pi_{\phi'}(s')) \big)$$

效果:有效减少了 Q 值的高估偏差(Overestimation Bias),防止策略受到错误估计的误导。

2. 延迟更新(Delayed Policy Updates)

动机:在 DDPG 中,Critic 网络和 Actor 网络同时更新,可能导致 Actor 策略在不稳定的 Q 值估计上进行优化。

方法:降低 Actor 和目标网络的更新频率,通常在 Critic 更新两次后才更新一次 Actor。

效果:降低了 Actor 网络的更新频率,从而提高了策略的稳定性。

3. 目标策略平滑(Target Policy Smoothing)

动机:DDPG 中的目标策略直接输出确定性动作,容易对极端动作过拟合。

方法:在目标值计算中,对动作加入高斯噪声并裁剪到一定范围。

$$a' = \pi_{\phi'}(s') + \text{clip}(\epsilon, -c, c), \quad \epsilon \sim \mathcal{N}(0, \sigma)$$

效果:提高了算法对噪声和目标值波动的鲁棒性。

三、数学细节解析

1. Actor-Critic 框架的核心

Actor-Critic 方法将策略学习(Actor)与价值评估(Critic)结合。Actor 负责生成动作,Critic 负责评估当前策略的表现。

(1) 策略梯度

Actor 通过最大化累计奖励学习最优策略:

$$\nabla_\phi J(\pi_\phi) = \mathbb{E}{s \sim \rho^\pi} \left[ \nabla\phi \pi_\phi(s) \nabla_a Q^\pi(s, a) \big|{a=\pi\phi(s)} \right]$$

其中 $\rho^\pi$ 是由策略生成的状态分布,$Q^\pi(s, a)$ 是 Critic 估计的动作值函数。

(2) 价值评估 (Critic)

Critic 通过最小化时间差分(Temporal Difference, TD)误差,学习状态 - 动作值函数:

$$L(\theta) = \mathbb{E}{(s, a, r, s') \sim \mathcal{D}} \left[ \big( Q\theta(s, a) - y \big)^2 \right]$$

其中目标值 $y$ 定义为:

$$y = r + \gamma Q_{\theta'}(s', \pi_{\phi'}(s'))$$

2. TD3 的损失函数

Critic 损失函数

TD3 使用两个 Critic 网络,损失函数为:

$$L(\theta_i) = \mathbb{E}{(s, a, r, s')} \left[ (Q{\theta_i}(s, a) - y)^2 \right]$$

其中目标值采用双 Critic 的最小值:

$$y = r + \gamma \min \big( Q_{\theta_1'}(s', a'), Q_{\theta_2'}(s', a') \big)$$

Actor 策略梯度

Actor 通过最大化 Critic 网络的输出优化策略:

$$\nabla_\phi J(\phi) = \mathbb{E}{s \sim \rho^\pi} \big[ \nabla_a Q{\theta_1}(s, a) \big|{a=\pi\phi(s)} \nabla_\phi \pi_\phi(s) \big]$$

四、PyTorch 实现细节

下面是一个完整的 TD3 实现示例,基于 PyTorch 和 OpenAI Gym。代码结构清晰,包含配置、经验回放、网络定义及训练循环。

1. 环境配置与参数解析

import argparse
import os
import random
import numpy as np
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from collections import namedtuple
from itertools import count

device = 'cuda' if torch.cuda.is_available() else 'cpu'

parser = argparse.ArgumentParser()
parser.add_argument('--env_name', default="Pendulum-v0")
parser.add_argument('--tau', default=0.005, type=float)
parser.add_argument('--learning_rate', default=3e-4, type=float)
parser.add_argument('--gamma', default=0.99, type=int)
parser.add_argument('--capacity', default=50000, type=int)
parser.add_argument('--batch_size', default=100, type=int)
parser.add_argument('--policy_noise', default=0.2, type=float)
parser.add_argument('--noise_clip', default=0.5, type=float)
parser.add_argument('--policy_delay', default=2, type=int)
args = parser.parse_args()

2. 经验回放缓冲区

class Replay_buffer():
    def __init__(self, max_size=args.capacity):
        self.storage = []
        self.max_size = max_size
        self.ptr = 0

    def push(self, data):
        if len(self.storage) == self.max_size:
            self.storage[int(self.ptr)] = data
            self.ptr = (self.ptr + 1) % self.max_size
        else:
            self.storage.append(data)

    def sample(self, batch_size):
        ind = np.random.randint(0, len(self.storage), size=batch_size)
        x, y, u, r, d = [], [], [], [], []
        for i in ind:
            X, Y, U, R, D = self.storage[i]
            x.append(np.array(X, copy=False))
            y.append(np.array(Y, copy=False))
            u.append(np.array(U, copy=False))
            r.append(np.array(R, copy=False))
            d.append(np.array(D, copy=False))
        return np.array(x), np.array(y), np.array(u), np.array(r).reshape(-1, 1), np.array(d).reshape(-1, 1)

3. 神经网络定义

class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, max_action):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_dim, 400)
        self.fc2 = nn.Linear(400, 300)
        self.fc3 = nn.Linear(300, action_dim)
        self.max_action = max_action

    def forward(self, state):
        a = F.relu(self.fc1(state))
        a = F.relu(self.fc2(a))
        return self.max_action * torch.tanh(self.fc3(a))

class Critic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, 400)
        self.fc2 = nn.Linear(400, 300)
        self.fc3 = nn.Linear(300, 1)

    def forward(self, state, action):
        state_action = torch.cat([state, action], 1)
        q = F.relu(self.fc1(state_action))
        q = F.relu(self.fc2(q))
        return self.fc3(q)

4. 算法逻辑与训练循环

class TD3():
    def __init__(self, state_dim, action_dim, max_action):
        self.actor = Actor(state_dim, action_dim, max_action).to(device)
        self.actor_target = Actor(state_dim, action_dim, max_action).to(device)
        self.critic_1 = Critic(state_dim, action_dim).to(device)
        self.critic_1_target = Critic(state_dim, action_dim).to(device)
        self.critic_2 = Critic(state_dim, action_dim).to(device)
        self.critic_2_target = Critic(state_dim, action_dim).to(device)

        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=args.learning_rate)
        self.critic_1_optimizer = optim.Adam(self.critic_1.parameters(), lr=args.learning_rate)
        self.critic_2_optimizer = optim.Adam(self.critic_2.parameters(), lr=args.learning_rate)

        self.actor_target.load_state_dict(self.actor.state_dict())
        self.critic_1_target.load_state_dict(self.critic_1.state_dict())
        self.critic_2_target.load_state_dict(self.critic_2.state_dict())

        self.memory = Replay_buffer(args.capacity)
        self.num_training = 0

    def select_action(self, state):
        state = torch.tensor(state.reshape(1, -1)).float().to(device)
        return self.actor(state).cpu().data.numpy().flatten()

    def update(self, num_iteration):
        for i in range(num_iteration):
            x, y, u, r, d = self.memory.sample(args.batch_size)
            state = torch.FloatTensor(x).to(device)
            action = torch.FloatTensor(u).to(device)
            next_state = torch.FloatTensor(y).to(device)
            done = torch.FloatTensor(d).to(device)
            reward = torch.FloatTensor(r).to(device)

            # Target policy smoothing
            noise = torch.ones_like(action).data.normal_(0, args.policy_noise).to(device)
            noise = noise.clamp(-args.noise_clip, args.noise_clip)
            next_action = (self.actor_target(next_state) + noise).clamp(-self.max_action, self.max_action)

            # Compute target Q values
            target_Q1 = self.critic_1_target(next_state, next_action)
            target_Q2 = self.critic_2_target(next_state, next_action)
            target_Q = torch.min(target_Q1, target_Q2)
            target_Q = reward + ((1 - done) * args.gamma * target_Q).detach()

            # Update Critic 1
            current_Q1 = self.critic_1(state, action)
            loss_Q1 = F.mse_loss(current_Q1, target_Q)
            self.critic_1_optimizer.zero_grad()
            loss_Q1.backward()
            self.critic_1_optimizer.step()

            # Update Critic 2
            current_Q2 = self.critic_2(state, action)
            loss_Q2 = F.mse_loss(current_Q2, target_Q)
            self.critic_2_optimizer.zero_grad()
            loss_Q2.backward()
            self.critic_2_optimizer.step()

            # Delayed Actor update
            if i % args.policy_delay == 0:
                actor_loss = - self.critic_1(state, self.actor(state)).mean()
                self.actor_optimizer.zero_grad()
                actor_loss.backward()
                self.actor_optimizer.step()

                # Soft update target networks
                for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
                    target_param.data.copy_((1 - args.tau) * target_param.data + args.tau * param.data)
                for param, target_param in zip(self.critic_1.parameters(), self.critic_1_target.parameters()):
                    target_param.data.copy_((1 - args.tau) * target_param.data + args.tau * param.data)
                for param, target_param in zip(self.critic_2.parameters(), self.critic_2_target.parameters()):
                    target_param.data.copy_((1 - args.tau) * target_param.data + args.tau * param.data)

            self.num_training += 1

5. 主程序入口

if __name__ == '__main__':
    env = gym.make(args.env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    max_action = float(env.action_space.high[0])

    agent = TD3(state_dim, action_dim, max_action)

    if args.mode == 'train':
        print("Collection Experience...")
        for i in range(args.num_iteration):
            state = env.reset()
            ep_r = 0
            for t in range(2000):
                action = agent.select_action(state)
                action = action + np.random.normal(0, 0.1, size=env.action_space.shape[0])
                action = action.clip(env.action_space.low, env.action_space.high)
                next_state, reward, done, info = env.step(action)
                ep_r += reward
                agent.memory.push((state, next_state, action, reward, np.float(done)))
                state = next_state
                if done or t == 1999:
                    break
            if len(agent.memory.storage) >= args.capacity - 1:
                agent.update(10)
            if i % 50 == 0:
                print(f"Ep_i {i}, Ep_Reward {ep_r:.2f}")

五、总结

TD3 不仅改进了 DDPG 的不足,还为强化学习的稳定性研究提供了重要的理论和实践参考。其成功之处在于克服了 Q 值过估计问题,使得训练过程更加稳定,并提升了策略更新的鲁棒性。作为一个里程碑式的算法,TD3 推动了连续动作空间强化学习的发展,为后续算法(如 SAC、PPO 等)提供了宝贵的启发。

参考文献:Addressing Function Approximation Error in Actor-Critic Methods (Fujimoto et al., 2018)

目录

  1. TD3 算法详解:双延迟深度确定性策略梯度
  2. 一、算法背景与动机
  3. DDPG 的局限性
  4. 二、核心思想
  5. 1. 双 Critic 网络(Twin Critics)
  6. 2. 延迟更新(Delayed Policy Updates)
  7. 3. 目标策略平滑(Target Policy Smoothing)
  8. 三、数学细节解析
  9. 1. Actor-Critic 框架的核心
  10. (1) 策略梯度
  11. (2) 价值评估 (Critic)
  12. 2. TD3 的损失函数
  13. Critic 损失函数
  14. Actor 策略梯度
  15. 四、PyTorch 实现细节
  16. 1. 环境配置与参数解析
  17. 2. 经验回放缓冲区
  18. 3. 神经网络定义
  19. 4. 算法逻辑与训练循环
  20. 5. 主程序入口
  21. 五、总结
  • 💰 8折买阿里云服务器限时8折了解详情
  • GPT-5.5 超高智商模型1元抵1刀ChatGPT中转购买
  • 代充Chatgpt Plus/pro 帐号了解详情
  • 🤖 一键搭建Deepseek满血版了解详情
  • 一键打造专属AI 智能体了解详情
极客日志微信公众号二维码

微信扫一扫,关注极客日志

微信公众号「极客日志V2」,在微信中扫描左侧二维码关注。展示文案:极客日志V2 zeeklog

更多推荐文章

查看全部
  • OpenClaw Windows10 本地 AI 智能体一键部署教程
  • ClawdBot 插件开发:为 Telegram 机器人添加快捷命令
  • 随机链表深拷贝:三步法详解与代码实现
  • 前端虚拟列表实现:优化万级数据渲染性能
  • DFS/BFS 图论基础与海岛问题实战 (C/C++)
  • 星辰 RPA 构建小红书自动发文机器人
  • VirtualBox Ubuntu 虚拟机与 Windows 主机文本复制粘贴设置指南
  • Qwen3 模型 LoRA 微调实战(基于 LLaMA-Factory)
  • 1060 爱丁顿数 Python 实现与解析
  • Eino ADK 中的 ChatModelAgent 详解与实战
  • 人工智能(AI)核心面试题与实战解析
  • Python + AI:构建智能害虫识别系统实战
  • 本地多模态 AI 搜索工具 XiaoyaoSearch 开源实践
  • Python 异步编程与协程实战指南
  • 随机森林算法原理与 Python 实战指南
  • 2026 年 3 月 13 日 AI 热点:芯片大战、Agent 爆发与安全争议
  • 2026 年 3 月 17 日 AI 行业前沿动态
  • Krita 插件配置与 AI 绘画模型部署指南:故障诊断与维护
  • KoboldAI 本地部署与配置实战指南
  • Rust Web 开发实战:Actix Web 框架全面指南

相关免费在线工具

  • 加密/解密文本

    使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online

  • RSA密钥对生成器

    生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online

  • Mermaid 预览与可视化编辑

    基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online

  • 随机西班牙地址生成器

    随机生成西班牙地址(支持马德里、加泰罗尼亚、安达卢西亚、瓦伦西亚筛选),支持数量快捷选择、显示全部与下载。 在线工具,随机西班牙地址生成器在线工具,online

  • Gemini 图片去水印

    基于开源反向 Alpha 混合算法去除 Gemini/Nano Banana 图片水印,支持批量处理与下载。 在线工具,Gemini 图片去水印在线工具,online

  • curl 转代码

    解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online