跳到主要内容Soft Actor-Critic (SAC) 算法原理与 PyTorch 实现 | 极客日志PythonAI算法
Soft Actor-Critic (SAC) 算法原理与 PyTorch 实现
SAC 算法是一种基于最大熵原则的强化学习方法,专为连续动作空间设计。通过引入熵正则化项,它在最大化奖励的同时鼓励策略随机性,有效平衡了探索与利用。核心机制包括双 Q 网络减少过估计、目标网络稳定训练以及软更新策略。本文详细阐述了 SAC 的背景、数学推导及算法流程,并提供了完整的 PyTorch 实现代码,涵盖策略网络、价值网络、经验回放及训练循环,适合希望深入理解并复现该算法的开发者参考。
RedisGeek0 浏览 Soft Actor-Critic (SAC) 算法详解
Soft Actor-Critic (SAC) 是一种基于最大熵强化学习(Maximum Entropy RL)的先进算法,属于 Actor-Critic 方法的变体。它特别擅长处理连续动作空间,通过引入熵正则化项,有效解决了传统算法在探索与利用平衡、训练稳定性及样本效率方面的痛点。
SAC 背景与核心思想
1. 强化学习的挑战
在传统强化学习中,我们常面临以下问题:
- 探索与利用的平衡:初期难以充分探索新策略,后期又容易陷入局部最优。
- 不稳定性:连续动作空间中,训练容易出现发散或收敛缓慢。
- 样本效率:数据采集成本高,如何高效利用经验池数据是关键。
SAC 通过以下核心机制应对这些挑战:
- 最大熵强化学习:在最大化累计奖励的同时,最大化策略的随机性(熵),鼓励探索。
- 双 Q 网络:缓解 Q 值过估计问题。
- 目标网络:使用软更新的目标网络稳定 Q 值计算。
2. 最大熵强化学习的目标
传统强化学习的目标是最大化期望累计奖励:
$$J(\pi) = \mathbb{E}{\pi} \left[ \sum{t=0}^T \gamma^t r(s_t, a_t) \right]$$
SAC 则通过添加一个熵项,在奖励中加入策略随机性的权重,目标变为:
$$J(\pi) = \mathbb{E}{\pi} \left[ \sum{t=0}^T \gamma^t \left( r(s_t, a_t) + \alpha \mathcal{H}(\pi(\cdot|s_t)) \right) \right]$$
其中:
- $\alpha$:熵系数,控制熵和奖励之间的平衡。
- $\mathcal{H}(\pi(\cdot|s_t)) = -\mathbb{E}_{a \sim \pi} [\log \pi(a|s_t)]$:表示策略的熵,鼓励策略更随机化。
效果:更好的探索能力使策略更加多样化,同时避免陷入次优策略,提升学习稳定性。
SAC 算法流程
SAC 使用了 Actor-Critic 框架,结合策略梯度和 Q 函数更新。以下是算法的关键步骤:
初始化
- 创建目标值函数网络 $V_{\psi'}$,并设置其参数为 $V_{\psi}$ 的初始值。
- 初始化策略网络 $\pi_\phi$ 和值函数网络 $V_\psi$。
- 初始化两组 Q 网络 $Q_{\theta_1}, Q_{\theta_2}$,用于计算 Q 值。
每一回合循环
- 采样动作:根据策略网络 $\pi_\phi$ 采样动作 $a \sim \pi(a|s)$。
- 执行动作:执行动作,记录 $(s, a, r, s', \text{done})$ 到经验池中。
- 更新 Q 网络:最小化 TD 误差,使用双 Q 网络取最小值作为目标。
- 更新值函数网络:逼近状态价值,最小化均方误差。
- 更新策略网络:最大化奖励和熵,等价于最小化特定损失函数。
- 更新目标值函数网络:使用软更新规则 $\psi' \gets \tau \psi + (1 - \tau) \psi'$。
公式推导
1. Q 值更新
Q 值通过 Bellman 方程更新,目标是最小化 TD 误差:
$$y = r + \gamma (1 - \text{done}) \cdot V_{\psi'}(s')$$
损失函数为:
$$J_Q = \mathbb{E}{(s, a, r, s') \sim D} \left[ \left( Q{\theta_i}(s, a) - y \right)^2 \right]$$
2. 值函数更新
值函数估计策略的长期价值,目标值为:
$$y_V = \mathbb{E}{a \sim \pi} \left[ \min{i=1,2} Q_{\theta_i}(s, a) - \alpha \log \pi_\phi(a|s) \right]$$
损失函数为:
$$J_V = \mathbb{E}{s \sim D} \left[ \left( V\psi(s) - y_V \right)^2 \right]$$
3. 策略网络更新
策略网络的目标是最大化奖励和熵,等价于最小化:
$$J_\pi = \mathbb{E}{s \sim D, a \sim \pi} \left[ \alpha \log \pi\phi(a|s) - \min_{i=1,2} Q_{\theta_i}(s, a) \right]$$
4. 目标值函数更新
目标值函数使用软更新规则:
$$\psi' \gets \tau \psi + (1 - \tau) \psi'$$
其中 $\tau \in (0, 1]$ 控制更新步长。
Python 实现
以下是基于 PyTorch 的 Soft Actor-Critic (SAC) 算法完整实现示例。代码结构清晰,便于理解各模块功能。
1. 参数设置
"""
SAC, Soft Actor-Critic 算法
时间:2024.12
"""
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import gym
import random
from collections import deque
GAMMA = 0.99
TAU = 0.005
ALPHA = 0.2
LR = 0.001
BATCH_SIZE = 256
MEMORY_CAPACITY = 100000
2. 策略网络
策略网络用于生成随机的策略动作,采用重参数化技巧以支持梯度回传。
class PolicyNetwork(nn.Module):
def __init__(self, state_dim, action_dim, max_action):
super(PolicyNetwork, self).__init__()
self.fc1 = nn.Linear(state_dim, 256)
self.fc2 = nn.Linear(256, 256)
self.mean = nn.Linear(256, action_dim)
self.log_std = nn.Linear(256, action_dim)
self.max_action = max_action
def forward(self, state):
x = torch.relu(self.fc1(state))
x = torch.relu(self.fc2(x))
mean = self.mean(x)
log_std = self.log_std(x).clamp(-20, 2)
std = torch.exp(log_std)
return mean, std
def sample(self, state):
mean, std = self.forward(state)
normal = torch.distributions.Normal(mean, std)
x_t = normal.rsample()
y_t = torch.tanh(x_t)
action = y_t * self.max_action
log_prob = normal.log_prob(x_t)
log_prob -= torch.log(1 - y_t.pow(2) + 1e-6)
log_prob = log_prob.sum(dim=-1, keepdim=True)
return action, log_prob
3. Q 网络
价值函数网络用于评估状态 - 动作对的价值,通常使用两个网络以减少过估计偏差。
class QNetwork(nn.Module):
def __init__(self, state_dim, action_dim):
super(QNetwork, self).__init__()
self.fc1 = nn.Linear(state_dim + action_dim, 256)
self.fc2 = nn.Linear(256, 256)
self.fc3 = nn.Linear(256, 1)
def forward(self, state, action):
x = torch.cat([state, action], dim=-1)
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
x = self.fc3(x)
return x
4. 经验回放缓冲区
存储并采样过往经验,打破数据相关性,提升训练稳定性。
class ReplayBuffer:
def __init__(self, capacity):
self.buffer = deque(maxlen=capacity)
def push(self, state, action, reward, next_state, done):
self.buffer.append((state, action, reward, next_state, done))
def sample(self, batch_size):
batch = random.sample(self.buffer, batch_size)
states, actions, rewards, next_states, dones = zip(*batch)
return (np.array(states), np.array(actions), np.array(rewards),
np.array(next_states), np.array(dones))
def __len__(self):
return len(self.buffer)
5. SAC 算法智能体
class SACAgent:
def __init__(self, state_dim, action_dim, max_action):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.actor = PolicyNetwork(state_dim, action_dim, max_action).to(self.device)
self.q1 = QNetwork(state_dim, action_dim).to(self.device)
self.q2 = QNetwork(state_dim, action_dim).to(self.device)
self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=LR)
self.q1_optimizer = optim.Adam(self.q1.parameters(), lr=LR)
self.q2_optimizer = optim.Adam(self.q2.parameters(), lr=LR)
self.replay_buffer = ReplayBuffer(MEMORY_CAPACITY)
self.max_action = max_action
def select_action(self, state):
state = torch.FloatTensor(state).to(self.device).unsqueeze(0)
action, _ = self.actor.sample(state)
return action.cpu().detach().numpy()[0]
def train(self):
if len(self.replay_buffer) < BATCH_SIZE:
return
states, actions, rewards, next_states, dones = self.replay_buffer.sample(BATCH_SIZE)
states = torch.FloatTensor(states).to(self.device)
actions = torch.FloatTensor(actions).to(self.device)
rewards = torch.FloatTensor(rewards).unsqueeze(1).to(self.device)
next_states = torch.FloatTensor(next_states).to(self.device)
dones = torch.FloatTensor(dones).unsqueeze(1).to(self.device)
with torch.no_grad():
next_actions, log_probs = self.actor.sample(next_states)
target_q1 = self.q1(next_states, next_actions)
target_q2 = self.q2(next_states, next_actions)
target_q = torch.min(target_q1, target_q2) - ALPHA * log_probs
q_target = rewards + GAMMA * (1 - dones) * target_q
q1_loss = ((self.q1(states, actions) - q_target) ** 2).mean()
q2_loss = ((self.q2(states, actions) - q_target) ** 2).mean()
self.q1_optimizer.zero_grad()
q1_loss.backward()
self.q1_optimizer.step()
self.q2_optimizer.zero_grad()
q2_loss.backward()
self.q2_optimizer.step()
new_actions, log_probs = self.actor.sample(states)
q1_new = self.q1(states, new_actions)
q2_new = self.q2(states, new_actions)
q_new = torch.min(q1_new, q2_new)
actor_loss = (ALPHA * log_probs - q_new).mean()
self.actor_optimizer.zero_grad()
actor_loss.backward()
self.actor_optimizer.step()
def update_replay_buffer(self, state, action, reward, next_state, done):
self.replay_buffer.push(state, action, reward, next_state, done)
6. 主函数循环
env = gym.make("Pendulum-v1")
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])
agent = SACAgent(state_dim, action_dim, max_action)
num_episodes = 500
for episode in range(num_episodes):
state = env.reset()
episode_reward = 0
done = False
while not done:
action = agent.select_action(state)
next_state, reward, done, _ = env.step(action)
agent.update_replay_buffer(state, action, reward, next_state, done)
agent.train()
state = next_state
episode_reward += reward
print(f"Episode {episode}, Reward: {episode_reward}")
SAC 优势总结
- 样本效率高:利用离线经验池,充分利用历史数据。
- 探索能力强:通过最大化熵,鼓励更广泛的探索,不易陷入局部最优。
- 稳定性好:结合双 Q 网络和目标网络,显著降低训练波动。
- 适用于连续动作空间:非常适合机器人控制等复杂任务。
该算法由 Haarnoja 等人提出,是深度强化学习领域的经典之作。实际应用中需根据具体环境调整超参数以获得最佳性能。
相关免费在线工具
- 加密/解密文本
使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
- RSA密钥对生成器
生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
- Mermaid 预览与可视化编辑
基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
- curl 转代码
解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online
- Base64 字符串编码/解码
将字符串编码和解码为其 Base64 格式表示形式即可。 在线工具,Base64 字符串编码/解码在线工具,online
- Base64 文件转换器
将字符串、文件或图像转换为其 Base64 表示形式。 在线工具,Base64 文件转换器在线工具,online