"""《PPO 算法的代码》 时间:2024.12 环境:gym 作者:不去幼儿园 """
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import numpy as np
import gym
class ActorCritic(nn.Module):
    """Actor-critic network with a shared feature layer."""
    def __init__(self, state_dim, action_dim):
        super(ActorCritic, self).__init__()
        # Shared feature extractor
        self.shared_layer = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU()
        )
        # Actor head: probability distribution over discrete actions
        self.actor = nn.Sequential(
            nn.Linear(128, action_dim),
            nn.Softmax(dim=-1)
        )
        # Critic head: scalar state value V(s)
        self.critic = nn.Linear(128, 1)

    def forward(self, state):
        shared = self.shared_layer(state)
        action_probs = self.actor(shared)
        state_value = self.critic(shared)
        return action_probs, state_value
class Memory:
    """Rollout buffer storing one batch of on-policy transitions."""
    def __init__(self):
        self.states = []
        self.actions = []
        self.logprobs = []
        self.rewards = []
        self.is_terminals = []

    def clear(self):
        self.states = []
        self.actions = []
        self.logprobs = []
        self.rewards = []
        self.is_terminals = []
class PPO:
    def __init__(self, state_dim, action_dim, lr=0.002, gamma=0.99, eps_clip=0.2, K_epochs=4):
        # Current policy (optimized) and old policy (used to collect rollouts)
        self.policy = ActorCritic(state_dim, action_dim).to(device)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=lr)
        self.policy_old = ActorCritic(state_dim, action_dim).to(device)
        self.policy_old.load_state_dict(self.policy.state_dict())
        self.MseLoss = nn.MSELoss()
        self.gamma = gamma          # discount factor
        self.eps_clip = eps_clip    # PPO clipping range epsilon
        self.K_epochs = K_epochs    # optimization epochs per update

    def select_action(self, state, memory):
        # Sample an action from the old policy and record the transition
        state = torch.FloatTensor(state).to(device)
        action_probs, _ = self.policy_old(state)
        dist = Categorical(action_probs)
        action = dist.sample()
        memory.states.append(state)
        memory.actions.append(action)
        memory.logprobs.append(dist.log_prob(action))
        return action.item()
    def update(self, memory):
        old_states = torch.stack(memory.states).to(device).detach()
        old_actions = torch.stack(memory.actions).to(device).detach()
        old_logprobs = torch.stack(memory.logprobs).to(device).detach()

        # Monte Carlo estimate of discounted returns, reset at episode boundaries
        rewards = []
        discounted_reward = 0
        for reward, is_terminal in zip(reversed(memory.rewards), reversed(memory.is_terminals)):
            if is_terminal:
                discounted_reward = 0
            discounted_reward = reward + (self.gamma * discounted_reward)
            rewards.insert(0, discounted_reward)
        rewards = torch.tensor(rewards, dtype=torch.float32).to(device)
        # Normalize returns to stabilize training
        rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-7)

        # Optimize the current policy for K epochs on the collected batch
        for _ in range(self.K_epochs):
            action_probs, state_values = self.policy(old_states)
            dist = Categorical(action_probs)
            new_logprobs = dist.log_prob(old_actions)
            entropy = dist.entropy()
            # Importance sampling ratio pi_theta(a|s) / pi_theta_old(a|s)
            ratios = torch.exp(new_logprobs - old_logprobs.detach())
            advantages = rewards - state_values.detach().squeeze()
            # Clipped surrogate objective
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages
            loss_actor = -torch.min(surr1, surr2).mean()
            loss_critic = self.MseLoss(state_values.squeeze(), rewards)
            # Total loss: clipped policy loss + value loss + entropy bonus for exploration
            loss = loss_actor + 0.5 * loss_critic - 0.01 * entropy.mean()
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        # Copy the updated weights into the old policy for the next rollout
        self.policy_old.load_state_dict(self.policy.state_dict())
# Training setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
env = gym.make("CartPole-v1")
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
ppo = PPO(state_dim, action_dim, lr=0.002, gamma=0.99, eps_clip=0.2, K_epochs=4)
memory = Memory()

for episode in range(1, 1001):
    # Note: this uses the classic gym API (gym < 0.26); with gymnasium / gym >= 0.26,
    # reset() returns (state, info) and step() returns (state, reward, terminated, truncated, info).
    state = env.reset()
    total_reward = 0
    for t in range(300):
        action = ppo.select_action(state, memory)
        state, reward, done, _ = env.step(action)
        memory.rewards.append(reward)
        memory.is_terminals.append(done)
        total_reward += reward
        if done:
            break
    # Update the policy on the data collected in this episode, then clear the buffer
    ppo.update(memory)
    memory.clear()
    print(f"Episode {episode}, Total Reward: {total_reward}")

env.close()
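
# Optional: greedy evaluation of the trained policy. This is not part of the original
# script; it is a minimal sketch assuming the same classic gym API and the `ppo` object
# trained above. Instead of sampling, it takes the argmax action from policy_old.
def evaluate(ppo, episodes=5):
    eval_env = gym.make("CartPole-v1")
    for ep in range(1, episodes + 1):
        state = eval_env.reset()
        done, total = False, 0
        while not done:
            with torch.no_grad():
                probs, _ = ppo.policy_old(torch.FloatTensor(state).to(device))
            action = torch.argmax(probs).item()  # greedy action instead of sampling
            state, reward, done, _ = eval_env.step(action)
            total += reward
        print(f"Eval Episode {ep}, Total Reward: {total}")
    eval_env.close()

# evaluate(ppo)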