import argparse
import os
import random
import time

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.normal import Normal
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_name', type=str, default=os.path.splitext(os.path.basename(__file__))[0])
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--torch_deterministic', type=bool, default=True)
    parser.add_argument('--cuda', type=bool, default=True)
    parser.add_argument('--env_id', type=str, default='Ant-v4')
    parser.add_argument('--total_time_steps', type=int, default=int(1e7))
    parser.add_argument('--learning_rate', type=float, default=3e-4)
    parser.add_argument('--num_envs', type=int, default=8)
    parser.add_argument('--num_steps', type=int, default=256)
    parser.add_argument('--anneal_lr', type=bool, default=True)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--gae_lambda', type=float, default=0.95)
    parser.add_argument('--num_mini_batches', type=int, default=4)
    parser.add_argument('--update_epochs', type=int, default=10)
    parser.add_argument('--norm_adv', type=bool, default=True)
    parser.add_argument('--clip_value_loss', type=bool, default=True)
    parser.add_argument('--c_1', type=float, default=0.5)
    parser.add_argument('--c_2', type=float, default=0.0)
    parser.add_argument('--max_grad_norm', type=float, default=0.5)
    parser.add_argument('--kld_max', type=float, default=0.02)
    a = parser.parse_args()
    a.batch_size = int(a.num_envs * a.num_steps)
    a.minibatch_size = int(a.batch_size // a.num_mini_batches)
    return a


def make_env(env_id, gamma):
    def thunk():
        env = gym.make(env_id)
        env = gym.wrappers.FlattenObservation(env)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env = gym.wrappers.ClipAction(env)
        env = gym.wrappers.NormalizeObservation(env)
        env = gym.wrappers.TransformObservation(env, lambda o: np.clip(o, -10, 10))
        env = gym.wrappers.NormalizeReward(env, gamma=gamma)
        env = gym.wrappers.TransformReward(env, lambda r: float(np.clip(r, -10, 10)))
        return env

    return thunk


def layer_init(layer, s=np.sqrt(2), bias_const=0.0):
    torch.nn.init.orthogonal_(layer.weight, s)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer


class Agent(nn.Module):
    def __init__(self, e):
        super().__init__()
        self.critic = nn.Sequential(
            layer_init(nn.Linear(np.array(e.single_observation_space.shape).prod(), 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 1), s=1.0),
        )
        self.actor_mean = nn.Sequential(
            layer_init(nn.Linear(np.array(e.single_observation_space.shape).prod(), 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, np.array(e.single_action_space.shape).prod()), s=0.01),
        )
        self.actor_log_std = nn.Parameter(torch.zeros(1, np.array(e.single_action_space.shape).prod()))

    def get_value(self, x):
        return self.critic(x)

    def get_action_and_value(self, x, a=None, show_all=False):
        action_mean = self.actor_mean(x)
        action_log_std = self.actor_log_std.expand_as(action_mean)
        action_std = torch.exp(action_log_std)
        probs = Normal(action_mean, action_std)
        if a is None:
            a = probs.sample()
        if show_all:
            return a, probs.log_prob(a).sum(1), probs.entropy().sum(1), self.critic(x), probs
        return a, probs.log_prob(a).sum(1), probs.entropy().sum(1), self.critic(x)


def compute_kld(mu_1, sigma_1, mu_2, sigma_2):
    # Closed-form KL divergence between diagonal Gaussians, KL(N(mu_1, sigma_1) || N(mu_2, sigma_2)),
    # computed element-wise per action dimension.
    return torch.log(sigma_2 / sigma_1) + ((mu_1 - mu_2) ** 2 + (sigma_1 ** 2 - sigma_2 ** 2)) / (2 * sigma_2 ** 2)
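

# The check below is not part of the training code; it is a minimal, hypothetical
# sketch (the helper name _check_kld is made up here) verifying that the closed-form
# compute_kld above agrees with PyTorch's own KL divergence for diagonal Gaussians.
def _check_kld():
    mu_1, mu_2 = torch.randn(4, 3), torch.randn(4, 3)
    sigma_1, sigma_2 = torch.rand(4, 3) + 0.1, torch.rand(4, 3) + 0.1
    analytic = compute_kld(mu_1, sigma_1, mu_2, sigma_2)
    reference = torch.distributions.kl.kl_divergence(Normal(mu_1, sigma_1), Normal(mu_2, sigma_2))
    assert torch.allclose(analytic, reference, atol=1e-5)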


def main(env_id, seed):
    args = get_args()
    args.env_id = env_id
    args.seed = seed
    run_name = 'spo' + '_epoch_' + str(args.update_epochs) + '_seed_' + str(args.seed)

    # TensorBoard writer for the training logs
    path_string = str(args.env_id) + '/' + run_name
    writer = SummaryWriter(path_string)
    writer.add_text(
        'Hyperparameter',
        '|param|value|\n|-|-|\n%s' % ('\n'.join([f'|{key}|{value}|' for key, value in vars(args).items()])),
    )

    # Seed all sources of randomness
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic
    device = torch.device('cuda' if torch.cuda.is_available() and args.cuda else 'cpu')

    # Initialize the vectorized environments
    envs = gym.vector.SyncVectorEnv(
        [make_env(args.env_id, args.gamma) for _ in range(args.num_envs)]
    )
    assert isinstance(envs.single_action_space, gym.spaces.Box), 'only continuous action space is supported'

    agent = Agent(envs).to(device)
    optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)

    # Rollout storage buffers
    obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
    actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
    log_probs = torch.zeros((args.num_steps, args.num_envs)).to(device)
    rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
    dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
    values = torch.zeros((args.num_steps, args.num_envs)).to(device)
    mean = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
    std = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)

    # Start collecting data
    global_step = 0
    start_time = time.time()
    next_obs, _ = envs.reset(seed=args.seed)
    next_obs = torch.Tensor(next_obs).to(device)
    next_done = torch.zeros(args.num_envs).to(device)
    num_updates = args.total_time_steps // args.batch_size

    for update in tqdm(range(1, num_updates + 1)):
        # Linearly anneal the learning rate
        if args.anneal_lr:
            frac = 1.0 - (update - 1.0) / num_updates
            lr_now = frac * args.learning_rate
            optimizer.param_groups[0]['lr'] = lr_now

        for step in range(0, args.num_steps):
            global_step += 1 * args.num_envs
            obs[step] = next_obs
            dones[step] = next_done

            # Log-probabilities of the actions under the old (rollout) policy
            with torch.no_grad():
                action, log_prob, _, value, mean_std = agent.get_action_and_value(next_obs, show_all=True)
                values[step] = value.flatten()
            actions[step] = action
            log_probs[step] = log_prob
            # Mean and standard deviation of the rollout policy: (num_steps, num_envs, action_dim)
            mean[step] = mean_std.loc
            std[step] = mean_std.scale

            # Step the environments
            next_obs, reward, terminations, truncations, info = envs.step(action.cpu().numpy())
            done = np.logical_or(terminations, truncations)
            rewards[step] = torch.tensor(reward).to(device).view(-1)
            next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device)

            if 'final_info' not in info:
                continue
            for item in info['final_info']:
                if item is None:
                    continue
                # print(f"global_step={global_step}, episodic_return={float(item['episode']['r'][0])}")
                writer.add_scalar('charts/episodic_return', item['episode']['r'][0], global_step)

        # GAE advantage estimation
        with torch.no_grad():
            next_value = agent.get_value(next_obs).reshape(1, -1)
            advantages = torch.zeros_like(rewards).to(device)
            last_gae_lam = 0
            for t in reversed(range(args.num_steps)):
                if t == args.num_steps - 1:
                    next_non_terminal = 1.0 - next_done
                    next_values = next_value
                else:
                    next_non_terminal = 1.0 - dones[t + 1]
                    next_values = values[t + 1]
                delta = rewards[t] + args.gamma * next_values * next_non_terminal - values[t]
                advantages[t] = last_gae_lam = delta + args.gamma * args.gae_lambda * next_non_terminal * last_gae_lam
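            # advantages[t] now holds the GAE(gamma, lambda) estimate
            #   A_t = delta_t + gamma * lambda * (1 - done) * A_{t+1},  with
            #   delta_t = r_t + gamma * (1 - done) * V(s_{t+1}) - V(s_t).
            # Adding the value baseline back below yields the lambda-returns used
            # as regression targets for the critic.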
            returns = advantages + values

        # ---------------------- Enough data collected above; the update starts below ---------------------- #
        # Flatten the batch
        b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
        b_log_probs = log_probs.reshape(-1)
        b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
        b_advantages = advantages.reshape(-1)
        b_returns = returns.reshape(-1)
        b_values = values.reshape(-1)
        # Means and standard deviations of the rollout policy for the whole batch
        b_mean = mean.reshape(args.batch_size, -1)
        b_std = std.reshape(args.batch_size, -1)

        # Update the policy and value networks
        b_index = np.arange(args.batch_size)
        for epoch in range(1, args.update_epochs + 1):
            np.random.shuffle(b_index)
            t = 0
            for start in range(0, args.batch_size, args.minibatch_size):
                t += 1
                end = start + args.minibatch_size
                mb_index = b_index[start:end]

                # Outputs of the current policy and value networks
                _, new_log_prob, entropy, new_value, new_mean_std = agent.get_action_and_value(
                    b_obs[mb_index], b_actions[mb_index], show_all=True
                )

                # Per-sample KL divergence between the rollout policy and the current policy
                new_mean = new_mean_std.loc.reshape(args.minibatch_size, -1)
                new_std = new_mean_std.scale.reshape(args.minibatch_size, -1)
                d = compute_kld(b_mean[mb_index], b_std[mb_index], new_mean, new_std).sum(1)
                writer.add_scalar('charts/average_kld', d.mean(), global_step)
                writer.add_scalar('others/min_kld', d.min(), global_step)
                writer.add_scalar('others/max_kld', d.max(), global_step)

                log_ratio = new_log_prob - b_log_probs[mb_index]
                ratios = log_ratio.exp()

                mb_advantages = b_advantages[mb_index]
                if args.norm_adv:
                    mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-12)

                # Policy loss
                if epoch == 1 and t == 1:
                    pg_loss = (-mb_advantages * ratios).mean()
                else:
                    # Clip the per-sample KL divergence at kld_max
                    d_clip = torch.clamp(input=d, min=0, max=args.kld_max)
                    # d_clip / d: equals 1 while d <= kld_max, shrinks towards 0 as d grows
                    ratio = d_clip / (d + 1e-12)
                    # Gating factor (d_clip / d + sign(A) - 1) * sign(A)
                    sign_a = torch.sign(mb_advantages)
                    result = (ratio + sign_a - 1) * sign_a
                    pg_loss = (-mb_advantages * ratios * result).mean()

                # Value loss
                new_value = new_value.view(-1)
                if args.clip_value_loss:
                    v_loss_un_clipped = (new_value - b_returns[mb_index]) ** 2
                    v_clipped = b_values[mb_index] + torch.clamp(
                        new_value - b_values[mb_index],
                        -0.2,
                        0.2,
                    )
                    v_loss_clipped = (v_clipped - b_returns[mb_index]) ** 2
                    v_loss_max = torch.max(v_loss_un_clipped, v_loss_clipped)
                    v_loss = 0.5 * v_loss_max.mean()
                else:
                    v_loss = 0.5 * ((new_value - b_returns[mb_index]) ** 2).mean()

                # Entropy of the policy's output distribution
                entropy_loss = entropy.mean()
                # Total loss
                loss = pg_loss + v_loss * args.c_1 - entropy_loss * args.c_2

                # Log the training losses
                writer.add_scalar('losses/policy_loss', pg_loss.item(), global_step)
                writer.add_scalar('losses/value_loss', v_loss.item(), global_step)
                writer.add_scalar('losses/entropy', entropy_loss.item(), global_step)

                # Update the network parameters
                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
                optimizer.step()

        y_pre, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
        var_y = np.var(y_true)
        explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pre) / var_y
        writer.add_scalar('others/explained_variance', explained_var, global_step)

        # Log training throughput and learning rate
        writer.add_scalar('charts/learning_rate', optimizer.param_groups[0]['lr'], global_step)
        writer.add_scalar('charts/SPS', int(global_step / (time.time() - start_time)), global_step)
        # print('SPS:', int(global_step / (time.time() - start_time)))

    envs.close()
    writer.close()


def run():
    for env_id in [
        'Humanoid-v4', 'Ant-v4', 'HalfCheetah-v4', 'Hopper-v4',
        'Humanoid-v4', 'HumanoidStandup-v4', 'InvertedDoublePendulum-v4',
    ]:
        for seed in range(1, 4):
            print(env_id, 'seed:', seed)
            main(env_id, seed)
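

# Illustrative only, not called anywhere in this script: a small, hypothetical
# sketch of the KL-based gating factor used in the policy loss above. While the
# per-sample KL divergence d stays below kld_max the factor is 1 (plain surrogate);
# once d exceeds kld_max it shrinks positive-advantage terms and enlarges
# negative-advantage penalties, discouraging further divergence from the rollout policy.
def _spo_gate_demo(kld_max=0.02):
    d = torch.tensor([0.01, 0.05])      # one sample below and one above the KL threshold
    sign_a = torch.tensor([1.0, 1.0])   # positive advantages
    d_clip = torch.clamp(d, min=0, max=kld_max)
    factor = (d_clip / (d + 1e-12) + sign_a - 1) * sign_a
    print(factor)                       # tensor([1.0000, 0.4000])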


if __name__ == '__main__':
    run()
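
# Usage sketch (assuming this file is saved as, e.g., spo.py):
#   python spo.py                      # trains every (env_id, seed) pair listed in run()
#   python spo.py --update_epochs 5    # other hyperparameters can be overridden on the
#                                      # command line; run() itself overrides --env_id and --seed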