跳到主要内容
Soft Actor-Critic (SAC) 算法详解与 PyTorch 实现 | 极客日志
Python AI 算法
Soft Actor-Critic (SAC) 算法详解与 PyTorch 实现 Soft Actor-Critic (SAC) 算法是一种针对连续动作空间的先进强化学习方法,通过引入最大熵优化目标解决探索与利用的平衡问题。该算法采用 Actor-Critic 架构,结合双 Q 网络减少过估计偏差,并利用目标网络提升训练稳定性。详细阐述了 SAC 的核心思想、数学推导及算法流程,提供了基于 PyTorch 的完整代码实现,涵盖策略网络、价值网络、经验回放缓冲区及训练循环。相比传统方法,SAC 具有更高的样本效率和更强的鲁棒性,适用于机器人控制等复杂任务。
落日余晖 发布于 2026/3/27 更新于 2026/4/25 1 浏览Soft Actor-Critic (SAC) 算法详解
Soft Actor-Critic (SAC) 是一种目前最先进的强化学习算法,属于 Actor-Critic 方法的变体。它特别适合处理 连续动作空间 ,并通过引入最大熵(Maximum Entropy)强化学习的思想,有效解决了传统算法在稳定性和探索性上的痛点。
SAC 背景与核心思想
1. 强化学习的挑战
在传统强化学习中,我们常面临几个核心难题:
探索与利用的平衡 :初期需要广泛探索新策略,后期需利用已知最优策略,两者难以兼顾。
不稳定性 :连续动作空间中训练容易出现发散或收敛缓慢。
样本效率 :数据采集成本高,如何高效利用经验池中的数据至关重要。
SAC 通过以下核心机制应对这些问题:
最大熵强化学习 :在最大化累计奖励的同时,最大化策略的随机性(熵),鼓励探索。
双 Q 网络 :缓解 Q 值过估计问题。
目标网络 :使用目标网络稳定 Q 值计算。
2. 最大熵强化学习的目标
传统强化学习的目标是最大化期望累计奖励:
$$J(\pi) = \mathbb{E}{\pi} \left[ \sum {t=0}^T \gamma^t r(s_t, a_t) \right]$$
而 SAC 则通过添加一个 熵项 ,在奖励中加入策略随机性的权重,目标变为:
$$J(\pi) = \mathbb{E}{\pi} \left[ \sum {t=0}^T \gamma^t \left( r(s_t, a_t) + \alpha \mathcal{H}(\pi(\cdot|s_t)) \right) \right]$$
其中:
$\alpha$:熵系数,控制熵和奖励之间的平衡。
$\mathcal{H}(\pi(\cdot|s_t)) = -\mathbb{E}_{a \sim \pi} [\log \pi(a|s_t)]$:表示策略的熵,鼓励策略更随机化。
效果 :更好的探索能力使策略更加多样化,同时避免陷入次优策略,提升学习稳定性。
SAC 算法流程
SAC 使用了 Actor-Critic 框架,结合策略梯度和 Q 函数更新。以下是算法的关键步骤:
初始化
创建目标值函数网络 $V_{\psi'}$,并设置其参数为 $V_{\psi}$ 的初始值。
初始化策略网络 $\pi_\phi$ 和值函数网络 $V_\psi$。
初始化两组 Q 网络 $Q_{\theta_1}, Q_{\theta_2}$,用于计算 Q 值。
每一回合循环
采样动作 :根据策略网络 $\pi_\phi$ 采样动作 $a \sim \pi(a|s)$。
执行动作 :执行动作,记录 $(s, a, r, s', \text{done})$ 到经验池中。
更新 Q 网络 :最小化 TD 误差损失。
更新值函数网络 :逼近软价值目标。
更新策略网络 :最大化奖励和熵。
更新目标值函数网络 :使用软更新规则 $\psi' \gets \tau \psi + (1 - \tau) \psi'$。
公式推导
1. Q 值更新
Q 值通过 Bellman 方程更新,目标是最小化 TD 误差:
$$y = r + \gamma (1 - \text{done}) \cdot V_{\psi'}(s')$$
损失函数为:
$$J_Q = \mathbb{E}{(s, a, r, s') \sim D} \left[ \left( Q {\theta_i}(s, a) - y \right)^2 \right]$$
2. 值函数更新 值函数估计策略的长期价值,目标值为:
$$y_V = \mathbb{E}{a \sim \pi} \left[ \min {i=1,2} Q_{\theta_i}(s, a) - \alpha \log \pi_\phi(a|s) \right]$$
损失函数为:
$$J_V = \mathbb{E}{s \sim D} \left[ \left( V \psi(s) - y_V \right)^2 \right]$$
3. 策略网络更新 策略网络的目标是最大化奖励和熵,等价于最小化:
$$J_\pi = \mathbb{E}{s \sim D, a \sim \pi} \left[ \alpha \log \pi \phi(a|s) - \min_{i=1,2} Q_{\theta_i}(s, a) \right]$$
4. 目标值函数更新 目标值函数使用软更新规则:
$$\psi' \gets \tau \psi + (1 - \tau) \psi'$$
其中 $\tau \in (0, 1]$ 控制更新步长。
Python 实现 下面是基于 PyTorch 的 Soft Actor-Critic (SAC) 算法完整实现。代码结构清晰,分为参数配置、网络定义、回放缓冲区和主训练循环。
1. 参数配置 """《SAC, Soft Actor-Critic 算法》时间:2024.12"""
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import gym
import random
from collections import deque
GAMMA = 0.99
TAU = 0.005
ALPHA = 0.2
LR = 0.001
BATCH_SIZE = 256
MEMORY_CAPACITY = 100000
2. 策略网络 策略网络负责生成随机的策略动作。这里使用了重参数化技巧(Reparameterization Trick)来支持梯度回传。
class PolicyNetwork (nn.Module):
def __init__ (self, state_dim, action_dim, max_action ):
super (PolicyNetwork, self ).__init__()
self .fc1 = nn.Linear(state_dim, 256 )
self .fc2 = nn.Linear(256 , 256 )
self .mean = nn.Linear(256 , action_dim)
self .log_std = nn.Linear(256 , action_dim)
self .max_action = max_action
def forward (self, state ):
x = torch.relu(self .fc1(state))
x = torch.relu(self .fc2(x))
mean = self .mean(x)
log_std = self .log_std(x).clamp(-20 , 2 )
std = torch.exp(log_std)
return mean, std
def sample (self, state ):
mean, std = self .forward(state)
normal = torch.distributions.Normal(mean, std)
x_t = normal.rsample()
y_t = torch.tanh(x_t)
action = y_t * self .max_action
log_prob = normal.log_prob(x_t)
log_prob -= torch.log(1 - y_t.pow (2 ) + 1e-6 )
log_prob = log_prob.sum (dim=-1 , keepdim=True )
return action, log_prob
3. Q 网络 价值网络用于评估状态 - 动作对的价值。SAC 通常使用两个独立的 Q 网络以减少过估计偏差。
class QNetwork (nn.Module):
def __init__ (self, state_dim, action_dim ):
super (QNetwork, self ).__init__()
self .fc1 = nn.Linear(state_dim + action_dim, 256 )
self .fc2 = nn.Linear(256 , 256 )
self .fc3 = nn.Linear(256 , 1 )
def forward (self, state, action ):
x = torch.cat([state, action], dim=-1 )
x = torch.relu(self .fc1(x))
x = torch.relu(self .fc2(x))
x = self .fc3(x)
return x
4. 经验回放缓冲区 class ReplayBuffer :
def __init__ (self, capacity ):
self .buffer = deque(maxlen=capacity)
def push (self, state, action, reward, next_state, done ):
self .buffer.append((state, action, reward, next_state, done))
def sample (self, batch_size ):
batch = random.sample(self .buffer, batch_size)
states, actions, rewards, next_states, dones = zip (*batch)
return (np.array(states), np.array(actions), np.array(rewards),
np.array(next_states), np.array(dones))
def __len__ (self ):
return len (self .buffer)
5. SAC 智能体类 class SACAgent :
def __init__ (self, state_dim, action_dim, max_action ):
self .device = torch.device("cuda" if torch.cuda.is_available() else "cpu" )
self .actor = PolicyNetwork(state_dim, action_dim, max_action).to(self .device)
self .q1 = QNetwork(state_dim, action_dim).to(self .device)
self .q2 = QNetwork(state_dim, action_dim).to(self .device)
self .actor_optimizer = optim.Adam(self .actor.parameters(), lr=LR)
self .q1_optimizer = optim.Adam(self .q1.parameters(), lr=LR)
self .q2_optimizer = optim.Adam(self .q2.parameters(), lr=LR)
self .replay_buffer = ReplayBuffer(MEMORY_CAPACITY)
self .max_action = max_action
def select_action (self, state ):
state = torch.FloatTensor(state).to(self .device).unsqueeze(0 )
action, _ = self .actor.sample(state)
return action.cpu().detach().numpy()[0 ]
def train (self ):
if len (self .replay_buffer) < BATCH_SIZE:
return
states, actions, rewards, next_states, dones = self .replay_buffer.sample(BATCH_SIZE)
states = torch.FloatTensor(states).to(self .device)
actions = torch.FloatTensor(actions).to(self .device)
rewards = torch.FloatTensor(rewards).unsqueeze(1 ).to(self .device)
next_states = torch.FloatTensor(next_states).to(self .device)
dones = torch.FloatTensor(dones).unsqueeze(1 ).to(self .device)
with torch.no_grad():
next_actions, log_probs = self .actor.sample(next_states)
target_q1 = self .q1(next_states, next_actions)
target_q2 = self .q2(next_states, next_actions)
target_q = torch.min (target_q1, target_q2) - ALPHA * log_probs
q_target = rewards + GAMMA * (1 - dones) * target_q
q1_loss = ((self .q1(states, actions) - q_target) ** 2 ).mean()
q2_loss = ((self .q2(states, actions) - q_target) ** 2 ).mean()
self .q1_optimizer.zero_grad()
q1_loss.backward()
self .q1_optimizer.step()
self .q2_optimizer.zero_grad()
q2_loss.backward()
self .q2_optimizer.step()
new_actions, log_probs = self .actor.sample(states)
q1_new = self .q1(states, new_actions)
q2_new = self .q2(states, new_actions)
q_new = torch.min (q1_new, q2_new)
actor_loss = (ALPHA * log_probs - q_new).mean()
self .actor_optimizer.zero_grad()
actor_loss.backward()
self .actor_optimizer.step()
def update_replay_buffer (self, state, action, reward, next_state, done ):
self .replay_buffer.push(state, action, reward, next_state, done)
6. 主训练循环 env = gym.make("Pendulum-v1" )
state_dim = env.observation_space.shape[0 ]
action_dim = env.action_space.shape[0 ]
max_action = float (env.action_space.high[0 ])
agent = SACAgent(state_dim, action_dim, max_action)
num_episodes = 500
for episode in range (num_episodes):
state = env.reset()
episode_reward = 0
done = False
while not done:
action = agent.select_action(state)
next_state, reward, done, _ = env.step(action)
agent.update_replay_buffer(state, action, reward, next_state, done)
agent.train()
state = next_state
episode_reward += reward
print (f"Episode {episode} , Reward: {episode_reward} " )
环境配置与注意事项
Python 3.11.5
torch 2.1.0
torchvision 0.16.0
gym 0.26.2
注意 :算法在实际项目中应用时,通常需要针对具体任务进行调参和优化。上述代码主要用于理解和学习算法原理,直接应用于复杂场景前请检查环境适配性和超参数设置。
SAC 优势总结
样本效率高 :利用离线经验池,充分挖掘历史数据价值。
探索能力强 :通过最大化熵,鼓励智能体尝试更多样化的行为。
稳定性好 :结合双 Q 网络和目标网络,显著降低训练波动。
适用性强 :专为连续动作空间设计,适用于机器人控制等复杂任务。
Haarnoja, Tuomas, et al. "Soft actor-critic: Off-policy maximum entropy deep reinforcement learning with a stochastic actor." arXiv preprint arXiv:1801.01290 (2018).
Haarnoja, Tuomas, et al. "Soft actor-critic algorithms and applications." arXiv preprint arXiv:1812.05905 (2018).
相关免费在线工具 加密/解密文本 使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
RSA密钥对生成器 生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
Mermaid 预览与可视化编辑 基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
随机西班牙地址生成器 随机生成西班牙地址(支持马德里、加泰罗尼亚、安达卢西亚、瓦伦西亚筛选),支持数量快捷选择、显示全部与下载。 在线工具,随机西班牙地址生成器在线工具,online
Gemini 图片去水印 基于开源反向 Alpha 混合算法去除 Gemini/Nano Banana 图片水印,支持批量处理与下载。 在线工具,Gemini 图片去水印在线工具,online
curl 转代码 解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online