Add files via upload
This commit is contained in:
parent
f6e04bd8df
commit
0fee4006f7
311
MuJoCo/ppo.py
Normal file
311
MuJoCo/ppo.py
Normal file
@ -0,0 +1,311 @@
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.distributions.normal import Normal
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--exp_name', type=str, default=os.path.basename(__file__).rstrip('.py'))
|
||||
parser.add_argument('--seed', type=int, default=1)
|
||||
parser.add_argument('--torch_deterministic', type=bool, default=True)
|
||||
parser.add_argument('--cuda', type=bool, default=True)
|
||||
parser.add_argument('--env_id', type=str, default='Humanoid-v4')
|
||||
parser.add_argument('--total_time_steps', type=int, default=int(1e7))
|
||||
parser.add_argument('--learning_rate', type=float, default=3e-4)
|
||||
parser.add_argument('--num_envs', type=int, default=8)
|
||||
parser.add_argument('--num_steps', type=int, default=256)
|
||||
parser.add_argument('--anneal_lr', type=bool, default=True)
|
||||
parser.add_argument('--gamma', type=float, default=0.99)
|
||||
parser.add_argument('--gae_lambda', type=float, default=0.95)
|
||||
parser.add_argument('--num_mini_batches', type=int, default=4)
|
||||
parser.add_argument('--update_epochs', type=int, default=10)
|
||||
parser.add_argument('--norm_adv', type=bool, default=True)
|
||||
parser.add_argument('--clip_value_loss', type=bool, default=True)
|
||||
parser.add_argument('--c_1', type=float, default=0.5)
|
||||
parser.add_argument('--c_2', type=float, default=0.0)
|
||||
parser.add_argument('--max_grad_norm', type=float, default=0.5)
|
||||
parser.add_argument('--clip_epsilon', type=float, default=0.2)
|
||||
a = parser.parse_args()
|
||||
a.batch_size = int(a.num_envs * a.num_steps)
|
||||
a.minibatch_size = int(a.batch_size // a.num_mini_batches)
|
||||
return a
|
||||
|
||||
|
||||
def make_env(env_id, gamma):
|
||||
def thunk():
|
||||
env = gym.make(env_id)
|
||||
env = gym.wrappers.FlattenObservation(env)
|
||||
env = gym.wrappers.RecordEpisodeStatistics(env)
|
||||
env = gym.wrappers.ClipAction(env)
|
||||
env = gym.wrappers.NormalizeObservation(env)
|
||||
env = gym.wrappers.TransformObservation(env, lambda o: np.clip(o, -10, 10))
|
||||
env = gym.wrappers.NormalizeReward(env, gamma=gamma)
|
||||
env = gym.wrappers.TransformReward(env, lambda r: float(np.clip(r, -10, 10)))
|
||||
return env
|
||||
return thunk
|
||||
|
||||
|
||||
def layer_init(layer, s=np.sqrt(2), bias_const=0.0):
|
||||
torch.nn.init.orthogonal_(layer.weight, s)
|
||||
torch.nn.init.constant_(layer.bias, bias_const)
|
||||
return layer
|
||||
|
||||
|
||||
class Agent(nn.Module):
|
||||
def __init__(self, e):
|
||||
super().__init__()
|
||||
self.critic = nn.Sequential(
|
||||
layer_init(nn.Linear(np.array(e.single_observation_space.shape).prod(), 64)),
|
||||
nn.Tanh(),
|
||||
layer_init(nn.Linear(64, 64)),
|
||||
nn.Tanh(),
|
||||
layer_init(nn.Linear(64, 1), s=1.0),
|
||||
)
|
||||
self.actor_mean = nn.Sequential(
|
||||
layer_init(nn.Linear(np.array(e.single_observation_space.shape).prod(), 64)),
|
||||
nn.Tanh(),
|
||||
layer_init(nn.Linear(64, 64)),
|
||||
nn.Tanh(),
|
||||
layer_init(nn.Linear(64, np.array(e.single_action_space.shape).prod()), s=0.01),
|
||||
)
|
||||
self.actor_log_std = nn.Parameter(torch.zeros(1, np.array(e.single_action_space.shape).prod()))
|
||||
|
||||
def get_value(self, x):
|
||||
return self.critic(x)
|
||||
|
||||
def get_action_and_value(self, x, a=None, show_all=False):
|
||||
action_mean = self.actor_mean(x)
|
||||
action_log_std = self.actor_log_std.expand_as(action_mean)
|
||||
action_std = torch.exp(action_log_std)
|
||||
probs = Normal(action_mean, action_std)
|
||||
if a is None:
|
||||
a = probs.sample()
|
||||
if show_all:
|
||||
return a, probs.log_prob(a).sum(1), probs.entropy().sum(1), self.critic(x), probs
|
||||
return a, probs.log_prob(a).sum(1), probs.entropy().sum(1), self.critic(x)
|
||||
|
||||
|
||||
def compute_kld(mu_1, sigma_1, mu_2, sigma_2):
|
||||
return torch.log(sigma_2 / sigma_1) + ((mu_1 - mu_2) ** 2 + (sigma_1 ** 2 - sigma_2 ** 2)) / (2 * sigma_2 ** 2)
|
||||
|
||||
|
||||
def main(env_id, seed):
|
||||
args = get_args()
|
||||
args.env_id = env_id
|
||||
args.seed = seed
|
||||
run_name = (
|
||||
'ppo' +
|
||||
'_epoch_' + str(args.update_epochs) +
|
||||
'_seed_' + str(args.seed)
|
||||
)
|
||||
|
||||
# Save training logs
|
||||
path_string = str(args.env_id) + '/' + run_name
|
||||
writer = SummaryWriter(path_string)
|
||||
writer.add_text(
|
||||
'Hyperparameter',
|
||||
'|param|value|\n|-|-|\n%s' % ('\n'.join([f'|{key}|{value}|' for key, value in vars(args).items()])),
|
||||
)
|
||||
|
||||
# Random seed
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
torch.backends.cudnn.deterministic = args.torch_deterministic
|
||||
|
||||
device = torch.device('cuda' if torch.cuda.is_available() and args.cuda else 'cpu')
|
||||
|
||||
# Initialize environments
|
||||
envs = gym.vector.SyncVectorEnv(
|
||||
[make_env(args.env_id, args.gamma) for _ in range(args.num_envs)]
|
||||
)
|
||||
assert isinstance(envs.single_action_space, gym.spaces.Box), 'only continuous action space is supported'
|
||||
|
||||
agent = Agent(envs).to(device)
|
||||
optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)
|
||||
|
||||
# Initialize buffer
|
||||
obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
|
||||
actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
|
||||
log_probs = torch.zeros((args.num_steps, args.num_envs)).to(device)
|
||||
rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
|
||||
dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
|
||||
values = torch.zeros((args.num_steps, args.num_envs)).to(device)
|
||||
mean = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
|
||||
std = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
|
||||
|
||||
# Data collection
|
||||
global_step = 0
|
||||
start_time = time.time()
|
||||
next_obs, _ = envs.reset(seed=args.seed)
|
||||
next_obs = torch.Tensor(next_obs).to(device)
|
||||
next_done = torch.zeros(args.num_envs).to(device)
|
||||
num_updates = args.total_time_steps // args.batch_size
|
||||
|
||||
for update in tqdm(range(1, num_updates + 1)):
|
||||
|
||||
# Linear decay of learning rate
|
||||
if args.anneal_lr:
|
||||
frac = 1.0 - (update - 1.0) / num_updates
|
||||
lr_now = frac * args.learning_rate
|
||||
optimizer.param_groups[0]['lr'] = lr_now
|
||||
|
||||
for step in range(0, args.num_steps):
|
||||
global_step += 1 * args.num_envs
|
||||
obs[step] = next_obs
|
||||
dones[step] = next_done
|
||||
|
||||
# Compute the logarithm of the action probability output by the old policy network
|
||||
with torch.no_grad():
|
||||
action, log_prob, _, value, mean_std = agent.get_action_and_value(next_obs, show_all=True)
|
||||
values[step] = value.flatten()
|
||||
actions[step] = action
|
||||
log_probs[step] = log_prob
|
||||
|
||||
# Mean and standard deviation (mini_batch_size, num_envs, action_dim)
|
||||
mean[step] = mean_std.loc
|
||||
std[step] = mean_std.scale
|
||||
|
||||
# Update the environments
|
||||
next_obs, reward, terminations, truncations, info = envs.step(action.cpu().numpy())
|
||||
done = np.logical_or(terminations, truncations)
|
||||
rewards[step] = torch.tensor(reward).to(device).view(-1)
|
||||
next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device)
|
||||
|
||||
if 'final_info' not in info:
|
||||
continue
|
||||
|
||||
for item in info['final_info']:
|
||||
if item is None:
|
||||
continue
|
||||
writer.add_scalar('charts/episodic_return', item['episode']['r'][0], global_step)
|
||||
|
||||
# Use GAE (Generalized Advantage Estimation) technique to estimate the advantage function
|
||||
with torch.no_grad():
|
||||
next_value = agent.get_value(next_obs).reshape(1, -1)
|
||||
advantages = torch.zeros_like(rewards).to(device)
|
||||
last_gae_lam = 0
|
||||
for t in reversed(range(args.num_steps)):
|
||||
if t == args.num_steps - 1:
|
||||
next_non_terminal = 1.0 - next_done
|
||||
next_values = next_value
|
||||
else:
|
||||
next_non_terminal = 1.0 - dones[t + 1]
|
||||
next_values = values[t + 1]
|
||||
delta = rewards[t] + args.gamma * next_values * next_non_terminal - values[t]
|
||||
advantages[t] = last_gae_lam = delta + args.gamma * args.gae_lambda * next_non_terminal * last_gae_lam
|
||||
returns = advantages + values
|
||||
|
||||
# ---------------------- We have collected enough data, now let's start training ---------------------- #
|
||||
# Flatten each batch
|
||||
b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
|
||||
b_log_probs = log_probs.reshape(-1)
|
||||
b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
|
||||
b_advantages = advantages.reshape(-1)
|
||||
b_returns = returns.reshape(-1)
|
||||
b_values = values.reshape(-1)
|
||||
|
||||
# Obtain the mean and the standard deviation of a batch
|
||||
b_mean = mean.reshape(args.batch_size, -1)
|
||||
b_std = std.reshape(args.batch_size, -1)
|
||||
|
||||
# Update the policy network and value network
|
||||
b_index = np.arange(args.batch_size)
|
||||
for epoch in range(1, args.update_epochs + 1):
|
||||
np.random.shuffle(b_index)
|
||||
t = 0
|
||||
for start in range(0, args.batch_size, args.minibatch_size):
|
||||
t += 1
|
||||
end = start + args.minibatch_size
|
||||
mb_index = b_index[start:end]
|
||||
|
||||
# The latest outputs of the policy network and value network
|
||||
_, new_log_prob, entropy, new_value, new_mean_std = agent.get_action_and_value(b_obs[mb_index],
|
||||
b_actions[mb_index],
|
||||
show_all=True)
|
||||
# Compute KL divergence
|
||||
new_mean = new_mean_std.loc.reshape(args.minibatch_size, -1)
|
||||
new_std = new_mean_std.scale.reshape(args.minibatch_size, -1)
|
||||
d = compute_kld(b_mean[mb_index], b_std[mb_index], new_mean, new_std).sum(1)
|
||||
|
||||
writer.add_scalar('charts/average_kld', d.mean(), global_step)
|
||||
writer.add_scalar('others/min_kld', d.min(), global_step)
|
||||
writer.add_scalar('others/max_kld', d.max(), global_step)
|
||||
|
||||
log_ratio = new_log_prob - b_log_probs[mb_index]
|
||||
ratios = log_ratio.exp()
|
||||
mb_advantages = b_advantages[mb_index]
|
||||
|
||||
# Advantage normalization
|
||||
if args.norm_adv:
|
||||
mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-12)
|
||||
|
||||
# Policy loss
|
||||
pg_loss1 = -mb_advantages * ratios
|
||||
pg_loss2 = -mb_advantages * torch.clamp(ratios, 1 - args.clip_epsilon, 1 + args.clip_epsilon)
|
||||
pg_loss = torch.max(pg_loss1, pg_loss2).mean()
|
||||
|
||||
# Value loss
|
||||
new_value = new_value.view(-1)
|
||||
if args.clip_value_loss:
|
||||
v_loss_un_clipped = (new_value - b_returns[mb_index]) ** 2
|
||||
v_clipped = b_values[mb_index] + torch.clamp(
|
||||
new_value - b_values[mb_index],
|
||||
-args.clip_epsilon,
|
||||
args.clip_epsilon,
|
||||
)
|
||||
v_loss_clipped = (v_clipped - b_returns[mb_index]) ** 2
|
||||
v_loss_max = torch.max(v_loss_un_clipped, v_loss_clipped)
|
||||
v_loss = 0.5 * v_loss_max.mean()
|
||||
else:
|
||||
v_loss = 0.5 * ((new_value - b_returns[mb_index]) ** 2).mean()
|
||||
|
||||
# Policy entropy
|
||||
entropy_loss = entropy.mean()
|
||||
|
||||
# Total loss
|
||||
loss = pg_loss + v_loss * args.c_1 - entropy_loss * args.c_2
|
||||
|
||||
# Save the data during the training process
|
||||
writer.add_scalar('losses/policy_loss', pg_loss.item(), global_step)
|
||||
writer.add_scalar('losses/value_loss', v_loss.item(), global_step)
|
||||
writer.add_scalar('losses/entropy', entropy_loss.item(), global_step)
|
||||
|
||||
# Update network parameters
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
|
||||
optimizer.step()
|
||||
|
||||
y_pre, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
|
||||
var_y = np.var(y_true)
|
||||
explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pre) / var_y
|
||||
writer.add_scalar('others/explained_variance', explained_var, global_step)
|
||||
|
||||
# Save the data during the training process
|
||||
writer.add_scalar('charts/learning_rate', optimizer.param_groups[0]['lr'], global_step)
|
||||
writer.add_scalar('charts/SPS', int(global_step / (time.time() - start_time)), global_step)
|
||||
|
||||
envs.close()
|
||||
writer.close()
|
||||
|
||||
|
||||
def run():
|
||||
for env_id in ['Humanoid-v4']:
|
||||
for seed in range(1, 6):
|
||||
print(env_id, 'seed:', seed)
|
||||
main(env_id, seed)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
320
MuJoCo/spo.py
Normal file
320
MuJoCo/spo.py
Normal file
@ -0,0 +1,320 @@
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.distributions.normal import Normal
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--exp_name', type=str, default=os.path.basename(__file__).rstrip('.py'))
|
||||
parser.add_argument('--seed', type=int, default=1)
|
||||
parser.add_argument('--torch_deterministic', type=bool, default=True)
|
||||
parser.add_argument('--cuda', type=bool, default=True)
|
||||
parser.add_argument('--env_id', type=str, default='Humanoid-v4')
|
||||
parser.add_argument('--total_time_steps', type=int, default=int(1e7))
|
||||
parser.add_argument('--learning_rate', type=float, default=3e-4)
|
||||
parser.add_argument('--num_envs', type=int, default=8)
|
||||
parser.add_argument('--num_steps', type=int, default=256)
|
||||
parser.add_argument('--anneal_lr', type=bool, default=True)
|
||||
parser.add_argument('--gamma', type=float, default=0.99)
|
||||
parser.add_argument('--gae_lambda', type=float, default=0.95)
|
||||
parser.add_argument('--num_mini_batches', type=int, default=4)
|
||||
parser.add_argument('--update_epochs', type=int, default=10)
|
||||
parser.add_argument('--norm_adv', type=bool, default=True)
|
||||
parser.add_argument('--clip_value_loss', type=bool, default=True)
|
||||
parser.add_argument('--c_1', type=float, default=0.5)
|
||||
parser.add_argument('--c_2', type=float, default=0.0)
|
||||
parser.add_argument('--max_grad_norm', type=float, default=0.5)
|
||||
parser.add_argument('--kld_max', type=float, default=0.02)
|
||||
a = parser.parse_args()
|
||||
a.batch_size = int(a.num_envs * a.num_steps)
|
||||
a.minibatch_size = int(a.batch_size // a.num_mini_batches)
|
||||
return a
|
||||
|
||||
|
||||
def make_env(env_id, gamma):
|
||||
def thunk():
|
||||
env = gym.make(env_id)
|
||||
env = gym.wrappers.FlattenObservation(env)
|
||||
env = gym.wrappers.RecordEpisodeStatistics(env)
|
||||
env = gym.wrappers.ClipAction(env)
|
||||
env = gym.wrappers.NormalizeObservation(env)
|
||||
env = gym.wrappers.TransformObservation(env, lambda o: np.clip(o, -10, 10))
|
||||
env = gym.wrappers.NormalizeReward(env, gamma=gamma)
|
||||
env = gym.wrappers.TransformReward(env, lambda r: float(np.clip(r, -10, 10)))
|
||||
return env
|
||||
return thunk
|
||||
|
||||
|
||||
def layer_init(layer, s=np.sqrt(2), bias_const=0.0):
|
||||
torch.nn.init.orthogonal_(layer.weight, s)
|
||||
torch.nn.init.constant_(layer.bias, bias_const)
|
||||
return layer
|
||||
|
||||
|
||||
class Agent(nn.Module):
|
||||
def __init__(self, e):
|
||||
super().__init__()
|
||||
self.critic = nn.Sequential(
|
||||
layer_init(nn.Linear(np.array(e.single_observation_space.shape).prod(), 64)),
|
||||
nn.Tanh(),
|
||||
layer_init(nn.Linear(64, 64)),
|
||||
nn.Tanh(),
|
||||
layer_init(nn.Linear(64, 1), s=1.0),
|
||||
)
|
||||
self.actor_mean = nn.Sequential(
|
||||
layer_init(nn.Linear(np.array(e.single_observation_space.shape).prod(), 64)),
|
||||
nn.Tanh(),
|
||||
layer_init(nn.Linear(64, 64)),
|
||||
nn.Tanh(),
|
||||
layer_init(nn.Linear(64, np.array(e.single_action_space.shape).prod()), s=0.01),
|
||||
)
|
||||
self.actor_log_std = nn.Parameter(torch.zeros(1, np.array(e.single_action_space.shape).prod()))
|
||||
|
||||
def get_value(self, x):
|
||||
return self.critic(x)
|
||||
|
||||
def get_action_and_value(self, x, a=None, show_all=False):
|
||||
action_mean = self.actor_mean(x)
|
||||
action_log_std = self.actor_log_std.expand_as(action_mean)
|
||||
action_std = torch.exp(action_log_std)
|
||||
probs = Normal(action_mean, action_std)
|
||||
if a is None:
|
||||
a = probs.sample()
|
||||
if show_all:
|
||||
return a, probs.log_prob(a).sum(1), probs.entropy().sum(1), self.critic(x), probs
|
||||
return a, probs.log_prob(a).sum(1), probs.entropy().sum(1), self.critic(x)
|
||||
|
||||
|
||||
def compute_kld(mu_1, sigma_1, mu_2, sigma_2):
|
||||
return torch.log(sigma_2 / sigma_1) + ((mu_1 - mu_2) ** 2 + (sigma_1 ** 2 - sigma_2 ** 2)) / (2 * sigma_2 ** 2)
|
||||
|
||||
|
||||
def main(env_id, seed):
|
||||
args = get_args()
|
||||
args.env_id = env_id
|
||||
args.seed = seed
|
||||
run_name = (
|
||||
'spo' +
|
||||
'_epoch_' + str(args.update_epochs) +
|
||||
'_seed_' + str(args.seed)
|
||||
)
|
||||
|
||||
# Save training logs
|
||||
path_string = str(args.env_id) + '/' + run_name
|
||||
writer = SummaryWriter(path_string)
|
||||
writer.add_text(
|
||||
'Hyperparameter',
|
||||
'|param|value|\n|-|-|\n%s' % ('\n'.join([f'|{key}|{value}|' for key, value in vars(args).items()])),
|
||||
)
|
||||
|
||||
# Random seed
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
torch.backends.cudnn.deterministic = args.torch_deterministic
|
||||
|
||||
device = torch.device('cuda' if torch.cuda.is_available() and args.cuda else 'cpu')
|
||||
|
||||
# Initialize environments
|
||||
envs = gym.vector.SyncVectorEnv(
|
||||
[make_env(args.env_id, args.gamma) for _ in range(args.num_envs)]
|
||||
)
|
||||
assert isinstance(envs.single_action_space, gym.spaces.Box), 'only continuous action space is supported'
|
||||
|
||||
agent = Agent(envs).to(device)
|
||||
optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)
|
||||
|
||||
# Initialize buffer
|
||||
obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
|
||||
actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
|
||||
log_probs = torch.zeros((args.num_steps, args.num_envs)).to(device)
|
||||
rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
|
||||
dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
|
||||
values = torch.zeros((args.num_steps, args.num_envs)).to(device)
|
||||
mean = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
|
||||
std = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
|
||||
|
||||
# Data collection
|
||||
global_step = 0
|
||||
start_time = time.time()
|
||||
next_obs, _ = envs.reset(seed=args.seed)
|
||||
next_obs = torch.Tensor(next_obs).to(device)
|
||||
next_done = torch.zeros(args.num_envs).to(device)
|
||||
num_updates = args.total_time_steps // args.batch_size
|
||||
|
||||
for update in tqdm(range(1, num_updates + 1)):
|
||||
|
||||
# Linear decay of learning rate
|
||||
if args.anneal_lr:
|
||||
frac = 1.0 - (update - 1.0) / num_updates
|
||||
lr_now = frac * args.learning_rate
|
||||
optimizer.param_groups[0]['lr'] = lr_now
|
||||
|
||||
for step in range(0, args.num_steps):
|
||||
global_step += 1 * args.num_envs
|
||||
obs[step] = next_obs
|
||||
dones[step] = next_done
|
||||
|
||||
# Compute the logarithm of the action probability output by the old policy network
|
||||
with torch.no_grad():
|
||||
action, log_prob, _, value, mean_std = agent.get_action_and_value(next_obs, show_all=True)
|
||||
values[step] = value.flatten()
|
||||
actions[step] = action
|
||||
log_probs[step] = log_prob
|
||||
|
||||
# Mean and standard deviation (mini_batch_size, num_envs, action_dim)
|
||||
mean[step] = mean_std.loc
|
||||
std[step] = mean_std.scale
|
||||
|
||||
# Update the environments
|
||||
next_obs, reward, terminations, truncations, info = envs.step(action.cpu().numpy())
|
||||
done = np.logical_or(terminations, truncations)
|
||||
rewards[step] = torch.tensor(reward).to(device).view(-1)
|
||||
next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device)
|
||||
|
||||
if 'final_info' not in info:
|
||||
continue
|
||||
|
||||
for item in info['final_info']:
|
||||
if item is None:
|
||||
continue
|
||||
writer.add_scalar('charts/episodic_return', item['episode']['r'][0], global_step)
|
||||
|
||||
# Use GAE (Generalized Advantage Estimation) technique to estimate the advantage function
|
||||
with torch.no_grad():
|
||||
next_value = agent.get_value(next_obs).reshape(1, -1)
|
||||
advantages = torch.zeros_like(rewards).to(device)
|
||||
last_gae_lam = 0
|
||||
for t in reversed(range(args.num_steps)):
|
||||
if t == args.num_steps - 1:
|
||||
next_non_terminal = 1.0 - next_done
|
||||
next_values = next_value
|
||||
else:
|
||||
next_non_terminal = 1.0 - dones[t + 1]
|
||||
next_values = values[t + 1]
|
||||
delta = rewards[t] + args.gamma * next_values * next_non_terminal - values[t]
|
||||
advantages[t] = last_gae_lam = delta + args.gamma * args.gae_lambda * next_non_terminal * last_gae_lam
|
||||
returns = advantages + values
|
||||
|
||||
# ---------------------- We have collected enough data, now let's start training ---------------------- #
|
||||
# Flatten each batch
|
||||
b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
|
||||
b_log_probs = log_probs.reshape(-1)
|
||||
b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
|
||||
b_advantages = advantages.reshape(-1)
|
||||
b_returns = returns.reshape(-1)
|
||||
b_values = values.reshape(-1)
|
||||
|
||||
# Obtain the mean and the standard deviation of a batch
|
||||
b_mean = mean.reshape(args.batch_size, -1)
|
||||
b_std = std.reshape(args.batch_size, -1)
|
||||
|
||||
# Update the policy network and value network
|
||||
b_index = np.arange(args.batch_size)
|
||||
for epoch in range(1, args.update_epochs + 1):
|
||||
np.random.shuffle(b_index)
|
||||
t = 0
|
||||
for start in range(0, args.batch_size, args.minibatch_size):
|
||||
t += 1
|
||||
end = start + args.minibatch_size
|
||||
mb_index = b_index[start:end]
|
||||
|
||||
# The latest outputs of the policy network and value network
|
||||
_, new_log_prob, entropy, new_value, new_mean_std = agent.get_action_and_value(b_obs[mb_index],
|
||||
b_actions[mb_index],
|
||||
show_all=True)
|
||||
# Compute KL divergence
|
||||
new_mean = new_mean_std.loc.reshape(args.minibatch_size, -1)
|
||||
new_std = new_mean_std.scale.reshape(args.minibatch_size, -1)
|
||||
d = compute_kld(b_mean[mb_index], b_std[mb_index], new_mean, new_std).sum(1)
|
||||
|
||||
writer.add_scalar('charts/average_kld', d.mean(), global_step)
|
||||
writer.add_scalar('others/min_kld', d.min(), global_step)
|
||||
writer.add_scalar('others/max_kld', d.max(), global_step)
|
||||
|
||||
log_ratio = new_log_prob - b_log_probs[mb_index]
|
||||
ratios = log_ratio.exp()
|
||||
mb_advantages = b_advantages[mb_index]
|
||||
|
||||
# Advantage normalization
|
||||
if args.norm_adv:
|
||||
mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-12)
|
||||
|
||||
# Policy loss (main code of SPO)
|
||||
if epoch == 1 and t == 1:
|
||||
pg_loss = (-mb_advantages * ratios).mean()
|
||||
else:
|
||||
# d_clip
|
||||
d_clip = torch.clamp(input=d, min=0, max=args.kld_max)
|
||||
# d_clip / d
|
||||
ratio = d_clip / (d + 1e-12)
|
||||
# sign_a
|
||||
sign_a = torch.sign(mb_advantages)
|
||||
# (d_clip / d + sign_a - 1) * sign_a
|
||||
result = (ratio + sign_a - 1) * sign_a
|
||||
pg_loss = (-mb_advantages * ratios * result).mean()
|
||||
|
||||
# Value loss
|
||||
new_value = new_value.view(-1)
|
||||
if args.clip_value_loss:
|
||||
v_loss_un_clipped = (new_value - b_returns[mb_index]) ** 2
|
||||
v_clipped = b_values[mb_index] + torch.clamp(
|
||||
new_value - b_values[mb_index],
|
||||
-0.2,
|
||||
0.2,
|
||||
)
|
||||
v_loss_clipped = (v_clipped - b_returns[mb_index]) ** 2
|
||||
v_loss_max = torch.max(v_loss_un_clipped, v_loss_clipped)
|
||||
v_loss = 0.5 * v_loss_max.mean()
|
||||
else:
|
||||
v_loss = 0.5 * ((new_value - b_returns[mb_index]) ** 2).mean()
|
||||
|
||||
# Policy entropy
|
||||
entropy_loss = entropy.mean()
|
||||
|
||||
# Total loss
|
||||
loss = pg_loss + v_loss * args.c_1 - entropy_loss * args.c_2
|
||||
|
||||
# Save the data during the training process
|
||||
writer.add_scalar('losses/policy_loss', pg_loss.item(), global_step)
|
||||
writer.add_scalar('losses/value_loss', v_loss.item(), global_step)
|
||||
writer.add_scalar('losses/entropy', entropy_loss.item(), global_step)
|
||||
|
||||
# Update network parameters
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
|
||||
optimizer.step()
|
||||
|
||||
y_pre, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
|
||||
var_y = np.var(y_true)
|
||||
explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pre) / var_y
|
||||
writer.add_scalar('others/explained_variance', explained_var, global_step)
|
||||
|
||||
# Save the data during the training process
|
||||
writer.add_scalar('charts/learning_rate', optimizer.param_groups[0]['lr'], global_step)
|
||||
writer.add_scalar('charts/SPS', int(global_step / (time.time() - start_time)), global_step)
|
||||
|
||||
envs.close()
|
||||
writer.close()
|
||||
|
||||
|
||||
def run():
|
||||
for env_id in ['Humanoid-v4']:
|
||||
for seed in range(1, 6):
|
||||
print(env_id, 'seed:', seed)
|
||||
main(env_id, seed)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
Loading…
x
Reference in New Issue
Block a user