跳到主要内容
Soft Actor-Critic (SAC) 算法详解与 PyTorch 实现 | 极客日志
Python AI 算法
Soft Actor-Critic (SAC) 算法详解与 PyTorch 实现 Soft Actor-Critic (SAC) 是一种基于最大熵框架的离线策略强化学习算法,特别适用于连续动作空间。它通过引入熵正则化项平衡探索与利用,结合双 Q 网络缓解过估计问题,并利用目标网络提升训练稳定性。本文详细阐述了 SAC 的核心思想、数学推导及算法流程,并提供了基于 PyTorch 的完整代码实现,涵盖策略网络、Q 网络、经验回放缓冲区等关键组件,适合希望深入理解并复现该算法的开发者参考。
remedios 发布于 2026/3/28 0 浏览Soft Actor-Critic (SAC) 算法详解
Soft Actor-Critic (SAC) 是一种先进的强化学习算法,属于 Actor-Critic 方法的变体。它特别适合处理 连续动作空间 ,并通过引入最大熵(Maximum Entropy)强化学习的思想,有效解决了传统算法中的稳定性和探索问题。
SAC 背景与核心思想
1. 强化学习的挑战
在强化学习中,我们常面临几个核心难题:
探索与利用的平衡 :传统算法难以在初期充分探索新策略与后期利用已有最优策略之间取得平衡。
不稳定性 :在连续动作空间中,训练容易出现发散或收敛缓慢的情况。
样本效率 :数据采集成本高,如何高效利用经验池中的数据至关重要。
SAC 通过以下核心思想应对这些挑战:
最大熵强化学习 :在最大化累计奖励的同时,最大化策略的随机性(熵),鼓励探索。
双 Q 网络 :缓解 Q 值过估计的问题。
目标网络 :使用目标网络稳定 Q 值计算。
2. 最大熵强化学习的目标
传统强化学习的目标是最大化期望累计奖励:
$$J(\pi) = \mathbb{E}{\pi} \left[ \sum {t=0}^T \gamma^t r(s_t, a_t) \right]$$
而 SAC 则通过添加一个 熵项 ,在奖励中加入策略随机性的权重,目标变为:
$$J(\pi) = \mathbb{E}{\pi} \left[ \sum {t=0}^T \gamma^t \left( r(s_t, a_t) + \alpha \mathcal{H}(\pi(\cdot|s_t)) \right) \right]$$
其中:
$\alpha$:熵系数,控制熵和奖励之间的平衡。
$\mathcal{H}(\pi(\cdot|s_t)) = -\mathbb{E}_{a \sim \pi} [\log \pi(a|s_t)]$:表示策略的熵,鼓励策略更随机化。
效果 :
更好的探索 :熵的最大化使策略更加多样化。
更稳定的学习 :避免陷入次优策略。
SAC 算法流程
SAC 使用了 Actor-Critic 框架,结合策略梯度和 Q 函数更新。以下是算法的关键步骤:
初始化
创建目标值函数网络 $V_{\psi'}$,并设置其参数为 $V_{\psi}$ 的初始值。
初始化策略网络 $\pi_\phi$ 和值函数网络 $V_\psi$。
初始化两组 Q 网络 $Q_{\theta_1}, Q_{\theta_2}$,用于计算 Q 值。
每一回合循环
采样动作 :根据策略网络 $\pi_\phi$ 采样动作 $a \sim \pi(a|s)$。
执行动作 :执行动作,记录 $(s, a, r, s', \text{done})$ 到经验池中。
更新 Q 网络 :最小化 TD 误差,更新 Q 值。
更新值函数网络 :逼近软价值函数,更新 V 网络。
更新策略网络 :最大化奖励和熵,更新策略参数。
更新目标值函数网络 :使用软更新规则平滑参数。
使用软更新规则:
$$\psi' \gets \tau \psi + (1 - \tau) \psi'$$
策略网络的目标是最大化奖励和熵,最小化以下损失:
$$J_\pi = \mathbb{E} \left[ \alpha \log \pi_\phi(a|s) - \min_{i=1,2} Q_{\theta_i}(s, a) \right]$$
最小化值函数损失:
$$J_V = \mathbb{E} \left[ \left( V_\psi(s) - y_V \right)^2 \right]$$
值函数 $V_\psi$ 的目标是逼近以下值:
$$y_V = \mathbb{E}{a \sim \pi} \left[ \min {i=1,2} Q_{\theta_i}(s, a) - \alpha \log \pi_\phi(a|s) \right]$$
最小化以下损失函数:
$$J_Q = \mathbb{E} \left[ \left( Q_{\theta_i}(s, a) - y \right)^2 \right] \quad (i = 1, 2)$$
使用 TD 目标更新 Q 值:
$$y = r + \gamma (1 - \text{done}) \cdot V_{\psi'}(s')$$
公式推导
1. Q 值更新 Q 值通过 Bellman 方程更新,目标是最小化 TD 误差:
$$y = r + \gamma (1 - \text{done}) \cdot V_{\psi'}(s')$$
损失函数为:
$$J_Q = \mathbb{E}{(s, a, r, s') \sim D} \left[ \left( Q {\theta_i}(s, a) - y \right)^2 \right]$$
2. 值函数更新 值函数估计策略的长期价值,目标值为:
$$y_V = \mathbb{E}{a \sim \pi} \left[ \min {i=1,2} Q_{\theta_i}(s, a) - \alpha \log \pi_\phi(a|s) \right]$$
损失函数为:
$$J_V = \mathbb{E}{s \sim D} \left[ \left( V \psi(s) - y_V \right)^2 \right]$$
3. 策略网络更新 策略网络的目标是最大化奖励和熵,等价于最小化:
$$J_\pi = \mathbb{E}{s \sim D, a \sim \pi} \left[ \alpha \log \pi \phi(a|s) - \min_{i=1,2} Q_{\theta_i}(s, a) \right]$$
4. 目标值函数更新 目标值函数使用软更新规则:
$$\psi' \gets \tau \psi + (1 - \tau) \psi'$$
其中 $\tau \in (0, 1]$ 控制更新步长。
Python 实现 以下是基于 PyTorch 的 Soft Actor-Critic (SAC) 算法完整实现。代码涵盖了从环境配置、网络定义到训练循环的全过程。
1. 参数设置 """
《SAC, Soft Actor-Critic 算法》
时间:2024.12
"""
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import gym
import random
from collections import deque
GAMMA = 0.99
TAU = 0.005
ALPHA = 0.2
LR = 0.001
BATCH_SIZE = 256
MEMORY_CAPACITY = 100000
2. 策略网络 策略网络负责生成随机的策略动作。这里使用了重参数化技巧来保证梯度可导。
class PolicyNetwork (nn.Module):
def __init__ (self, state_dim, action_dim, max_action ):
super (PolicyNetwork, self ).__init__()
self .fc1 = nn.Linear(state_dim, 256 )
self .fc2 = nn.Linear(256 , 256 )
self .mean = nn.Linear(256 , action_dim)
self .log_std = nn.Linear(256 , action_dim)
self .max_action = max_action
def forward (self, state ):
x = torch.relu(self .fc1(state))
x = torch.relu(self .fc2(x))
mean = self .mean(x)
log_std = self .log_std(x).clamp(-20 , 2 )
std = torch.exp(log_std)
return mean, std
def sample (self, state ):
mean, std = self .forward(state)
normal = torch.distributions.Normal(mean, std)
x_t = normal.rsample()
y_t = torch.tanh(x_t)
action = y_t * self .max_action
log_prob = normal.log_prob(x_t)
log_prob -= torch.log(1 - y_t.pow (2 ) + 1e-6 )
log_prob = log_prob.sum (dim=-1 , keepdim=True )
return action, log_prob
3. Q 网络 Q 网络用于评估状态 - 动作对的价值。SAC 通常使用两个独立的 Q 网络来减少过估计偏差。
class QNetwork (nn.Module):
def __init__ (self, state_dim, action_dim ):
super (QNetwork, self ).__init__()
self .fc1 = nn.Linear(state_dim + action_dim, 256 )
self .fc2 = nn.Linear(256 , 256 )
self .fc3 = nn.Linear(256 , 1 )
def forward (self, state, action ):
x = torch.cat([state, action], dim=-1 )
x = torch.relu(self .fc1(x))
x = torch.relu(self .fc2(x))
x = self .fc3(x)
return x
4. 经验回放缓冲区 class ReplayBuffer :
def __init__ (self, capacity ):
self .buffer = deque(maxlen=capacity)
def push (self, state, action, reward, next_state, done ):
self .buffer.append((state, action, reward, next_state, done))
def sample (self, batch_size ):
batch = random.sample(self .buffer, batch_size)
states, actions, rewards, next_states, dones = zip (*batch)
return (
np.array(states),
np.array(actions),
np.array(rewards),
np.array(next_states),
np.array(dones)
)
def __len__ (self ):
return len (self .buffer)
5. SAC 智能体 class SACAgent :
def __init__ (self, state_dim, action_dim, max_action ):
self .device = torch.device("cuda" if torch.cuda.is_available() else "cpu" )
self .actor = PolicyNetwork(state_dim, action_dim, max_action).to(self .device)
self .q1 = QNetwork(state_dim, action_dim).to(self .device)
self .q2 = QNetwork(state_dim, action_dim).to(self .device)
self .actor_optimizer = optim.Adam(self .actor.parameters(), lr=LR)
self .q1_optimizer = optim.Adam(self .q1.parameters(), lr=LR)
self .q2_optimizer = optim.Adam(self .q2.parameters(), lr=LR)
self .replay_buffer = ReplayBuffer(MEMORY_CAPACITY)
self .max_action = max_action
def select_action (self, state ):
state = torch.FloatTensor(state).to(self .device).unsqueeze(0 )
action, _ = self .actor.sample(state)
return action.cpu().detach().numpy()[0 ]
def train (self ):
if len (self .replay_buffer) < BATCH_SIZE:
return
states, actions, rewards, next_states, dones = self .replay_buffer.sample(BATCH_SIZE)
states = torch.FloatTensor(states).to(self .device)
actions = torch.FloatTensor(actions).to(self .device)
rewards = torch.FloatTensor(rewards).unsqueeze(1 ).to(self .device)
next_states = torch.FloatTensor(next_states).to(self .device)
dones = torch.FloatTensor(dones).unsqueeze(1 ).to(self .device)
with torch.no_grad():
next_actions, log_probs = self .actor.sample(next_states)
target_q1 = self .q1(next_states, next_actions)
target_q2 = self .q2(next_states, next_actions)
target_q = torch.min (target_q1, target_q2) - ALPHA * log_probs
q_target = rewards + GAMMA * (1 - dones) * target_q
q1_loss = ((self .q1(states, actions) - q_target) ** 2 ).mean()
q2_loss = ((self .q2(states, actions) - q_target) ** 2 ).mean()
self .q1_optimizer.zero_grad()
q1_loss.backward()
self .q1_optimizer.step()
self .q2_optimizer.zero_grad()
q2_loss.backward()
self .q2_optimizer.step()
new_actions, log_probs = self .actor.sample(states)
q1_new = self .q1(states, new_actions)
q2_new = self .q2(states, new_actions)
q_new = torch.min (q1_new, q2_new)
actor_loss = (ALPHA * log_probs - q_new).mean()
self .actor_optimizer.zero_grad()
actor_loss.backward()
self .actor_optimizer.step()
def update_replay_buffer (self, state, action, reward, next_state, done ):
self .replay_buffer.push(state, action, reward, next_state, done)
6. 主函数循环 env = gym.make("Pendulum-v1" )
state_dim = env.observation_space.shape[0 ]
action_dim = env.action_space.shape[0 ]
max_action = float (env.action_space.high[0 ])
agent = SACAgent(state_dim, action_dim, max_action)
num_episodes = 500
for episode in range (num_episodes):
state = env.reset()
episode_reward = 0
done = False
while not done:
action = agent.select_action(state)
next_state, reward, done, _ = env.step(action)
agent.update_replay_buffer(state, action, reward, next_state, done)
agent.train()
state = next_state
episode_reward += reward
print (f"Episode {episode} , Reward: {episode_reward} " )
注意事项 本代码旨在展示 SAC 算法的核心逻辑与结构。实际应用中,建议根据具体任务调整超参数(如 ALPHA, LR),并进行适当的调优。此外,确保运行环境兼容(Python 3.8+, PyTorch 1.7+)。
SAC 优势
样本效率高 :使用离线经验池,充分利用历史数据。
探索能力强 :通过最大化熵,鼓励更广泛的探索。
稳定性好 :结合双 Q 网络和目标网络,降低训练波动。
适用于连续动作空间 :支持复杂控制任务,如机器人控制。
Haarnoja, Tuomas, et al. "Soft actor-critic: Off-policy maximum entropy deep reinforcement learning with a stochastic actor." arXiv preprint arXiv:1801.01290 (2018).
Haarnoja, Tuomas, et al. "Soft actor-critic algorithms and applications." arXiv preprint arXiv:1812.05905 (2018).
相关免费在线工具 加密/解密文本 使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
RSA密钥对生成器 生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
Mermaid 预览与可视化编辑 基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
curl 转代码 解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online
Base64 字符串编码/解码 将字符串编码和解码为其 Base64 格式表示形式即可。 在线工具,Base64 字符串编码/解码在线工具,online
Base64 文件转换器 将字符串、文件或图像转换为其 Base64 表示形式。 在线工具,Base64 文件转换器在线工具,online