diff --git a/ppo.py b/ppo.py
deleted file mode 100644
index 44f3a0d..0000000
--- a/ppo.py
+++ /dev/null
@@ -1,338 +0,0 @@
-import argparse
-import os
-import time
-
-import gymnasium as gym
-import numpy as np
-import torch
-from stable_baselines3.common.atari_wrappers import FireResetEnv, EpisodicLifeEnv, ClipRewardEnv
-from torch import nn, optim
-from torch.distributions import Categorical
-from torch.nn.utils.clip_grad import clip_grad_norm_
-from torch.utils.tensorboard.writer import SummaryWriter
-from tqdm import tqdm
-
-from game_list import atari_list
-
-
-def get_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--exp_name', type=str, default=os.path.basename(__file__).rstrip('.py'))
-    parser.add_argument('--env_id', type=str, default='ALE/Breakout-v5')
-    parser.add_argument('--seed', type=int, default=1)
-    parser.add_argument('--use_cuda', type=bool, default=True)
-    parser.add_argument('--learning_rate', type=float, default=2.5e-4)
-    parser.add_argument('--lr_decay', type=bool, default=True)
-    parser.add_argument('--total_steps', type=int, default=int(1e7))
-    parser.add_argument('--num_envs', type=int, default=8)
-    parser.add_argument('--num_steps', type=int, default=128)
-    parser.add_argument('--update_epochs', type=int, default=8)
-    parser.add_argument('--num_mini_batches', type=int, default=4)
-    parser.add_argument('--gae_lambda', type=float, default=0.95)
-    parser.add_argument('--gamma', type=float, default=0.99)
-    parser.add_argument('--clip_value_loss', type=bool, default=True)
-    parser.add_argument('--c_1', type=float, default=1.0)
-    parser.add_argument('--c_2', type=float, default=0.01)
-    parser.add_argument('--clip_grad_norm', type=float, default=0.5)
-    parser.add_argument('--clip_epsilon', type=float, default=0.2)
-    args = parser.parse_args()
-    args.device = torch.device('cuda' if torch.cuda.is_available() and args.use_cuda else 'cpu')
-    args.batch_size = int(args.num_envs * args.num_steps)
-    args.minibatch_size = int(args.batch_size // args.num_mini_batches)
-    args.num_updates = int(args.total_steps // args.batch_size)
-    return args
-
-
-def make_env(env_id):
-    def thunk():
-        env = gym.make(env_id, frameskip=1, repeat_action_probability=0.0, full_action_space=False)
-        env = gym.wrappers.RecordEpisodeStatistics(env)
-        if 'FIRE' in env.unwrapped.get_action_meanings():
-            env = FireResetEnv(env)
-        env = EpisodicLifeEnv(env)
-        env = ClipRewardEnv(env)
-        env = gym.wrappers.AtariPreprocessing(env, scale_obs=True)
-        env = gym.wrappers.FrameStack(env, 4)
-        return env
-    return thunk
-
-
-def compute_advantages(rewards, flags, values, last_value, args):
-    advantages = torch.zeros((args.num_steps, args.num_envs)).to(args.device)
-    adv = torch.zeros(args.num_envs).to(args.device)
-    for i in reversed(range(args.num_steps)):
-        returns = rewards[i] + args.gamma * flags[i] * last_value
-        delta = returns - values[i]
-        adv = delta + args.gamma * args.gae_lambda * flags[i] * adv
-        advantages[i] = adv
-        last_value = values[i]
-    return advantages
-
-
-class Buffer:
-    def __init__(self, num_steps, num_envs, observation_shape, action_dim, device):
-        self.states = np.zeros((num_steps, num_envs, *observation_shape), dtype=np.float32)
-        self.actions = np.zeros((num_steps, num_envs), dtype=np.int64)
-        self.rewards = np.zeros((num_steps, num_envs), dtype=np.float32)
-        self.flags = np.zeros((num_steps, num_envs), dtype=np.float32)
-        self.log_probs = np.zeros((num_steps, num_envs), dtype=np.float32)
-        self.probs = np.zeros((num_steps, num_envs, action_dim), dtype=np.float32)
-        self.values = np.zeros((num_steps, num_envs), dtype=np.float32)
-        self.step = 0
-        self.num_steps = num_steps
-        self.device = device
-
-    def push(self, state, action, reward, flag, log_prob, prob, value):
-        self.states[self.step] = state
-        self.actions[self.step] = action
-        self.rewards[self.step] = reward
-        self.flags[self.step] = flag
-        self.log_probs[self.step] = log_prob
-        self.probs[self.step] = prob
-        self.values[self.step] = value
-        self.step = (self.step + 1) % self.num_steps
-
-    def get(self):
-        return (
-            torch.from_numpy(self.states).to(self.device),
-            torch.from_numpy(self.actions).to(self.device),
-            torch.from_numpy(self.rewards).to(self.device),
-            torch.from_numpy(self.flags).to(self.device),
-            torch.from_numpy(self.log_probs).to(self.device),
-            torch.from_numpy(self.values).to(self.device),
-        )
-
-    def get_probs(self):
-        return torch.from_numpy(self.probs).to(self.device)
-
-
-def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
-    torch.nn.init.orthogonal_(layer.weight, std)
-    torch.nn.init.constant_(layer.bias, bias_const)
-    return layer
-
-
-class Agent(nn.Module):
-    def __init__(self, action_dim, device):
-        super().__init__()
-        self.encoder = nn.Sequential(
-            layer_init(nn.Conv2d(4, 32, 8, stride=4)),
-            nn.ReLU(),
-            layer_init(nn.Conv2d(32, 64, 4, stride=2)),
-            nn.ReLU(),
-            layer_init(nn.Conv2d(64, 64, 3, stride=1)),
-            nn.ReLU(),
-            nn.Flatten(),
-            layer_init(nn.Linear(64 * 7 * 7, 512)),
-            nn.ReLU(),
-        )
-        self.actor_net = layer_init(nn.Linear(512, action_dim), std=0.01)
-        self.critic_net = layer_init(nn.Linear(512, 1), std=1)
-
-        if device.type == 'cuda':
-            self.cuda()
-
-    def forward(self, state):
-        hidden = self.encoder(state)
-        actor_value = self.actor_net(hidden)
-        distribution = Categorical(logits=actor_value)
-        action = distribution.sample()
-        log_prob = distribution.log_prob(action)
-        value = self.critic_net(hidden).squeeze(-1)
-        return action, log_prob, value, distribution.probs
-
-    def evaluate(self, states, actions):
-        hidden = self.encoder(states)
-        actor_values = self.actor_net(hidden)
-        distribution = Categorical(logits=actor_values)
-        log_probs = distribution.log_prob(actions)
-        entropy = distribution.entropy()
-        values = self.critic_net(hidden).squeeze(-1)
-        return log_probs, values, entropy, distribution.probs
-
-    def critic(self, state):
-        return self.critic_net(self.encoder(state)).squeeze(-1)
-
-
-def train(env_id, seed):
-    args = get_args()
-    args.env_id = env_id
-    args.seed = seed
-    run_name = (
-        'ppo' +
-        '_epoch_' + str(args.update_epochs) +
-        '_seed_' + str(args.seed)
-    )
-
-    # Save training logs
-    path_string = 'new_' + str(args.env_id)[4:] + '/' + run_name
-    writer = SummaryWriter(path_string)
-    writer.add_text(
-        'Hyperparameter',
-        '|param|value|\n|-|-|\n%s' % ('\n'.join([f'|{key}|{value}|' for key, value in vars(args).items()])),
-    )
-
-    # Initialize the parallel environments
-    envs = gym.vector.AsyncVectorEnv([make_env(args.env_id) for _ in range(args.num_envs)])
-
-    # State space and action space
-    observation_shape = envs.single_observation_space.shape
-    action_dim = envs.single_action_space.n
-
-    # Random seed
-    if args.seed:
-        numpy_rng = np.random.default_rng(args.seed)
-        torch.manual_seed(args.seed)
-        state, _ = envs.reset(seed=args.seed)
-    else:
-        numpy_rng = np.random.default_rng()
-        state, _ = envs.reset()
-
-    # Value network and policy network
-    agent = Agent(action_dim, args.device)
-
-    # Optimizer
-    optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate)
-
-    # Buffer for storing rollout data
-    rollout_buffer = Buffer(args.num_steps, args.num_envs, observation_shape, action_dim, args.device)
-    global_step = 0
-    start_time = time.time()
-
-    # Start collecting data
-    for _ in tqdm(range(args.num_updates)):
-
-        # Linearly decay the learning rate
-        if args.lr_decay:
-            optimizer.param_groups[0]['lr'] -= (args.learning_rate - 1e-12) / args.num_updates
-
-        for _ in range(args.num_steps):
-            global_step += 1 * args.num_envs
-
-            with torch.no_grad():
-                action, log_prob, value, prob = agent(torch.from_numpy(state).to(args.device).float())
-
-            action = action.cpu().numpy()
-            next_state, reward, terminated, truncated, all_info = envs.step(action)
-
-            # Store the transition
-            flag = 1.0 - np.logical_or(terminated, truncated)
-            log_prob = log_prob.cpu().numpy()
-            prob = prob.cpu().numpy()
-            value = value.cpu().numpy()
-            rollout_buffer.push(state, action, reward, flag, log_prob, prob, value)
-            state = next_state
-
-            if 'final_info' not in all_info:
-                continue
-
-            # Log training metrics
-            for info in all_info['final_info']:
-                if info is None:
-                    continue
-                if 'episode' in info.keys():
-                    writer.add_scalar('charts/episodic_return', info['episode']['r'], global_step)
-                    # print(float(info['episode']['r']))
-                    break
-
-        # ------------------------------- Enough data collected above; the update starts below ------------------------------- #
-        states, actions, rewards, flags, log_probs, values = rollout_buffer.get()
-        probs = rollout_buffer.get_probs()
-
-        with torch.no_grad():
-            last_value = agent.critic(torch.from_numpy(next_state).to(args.device).float())
-
-        # Compute advantages and TD targets
-        advantages = compute_advantages(rewards, flags, values, last_value, args)
-        td_target = advantages + values
-
-        # Flatten the batch
-        states = states.reshape(-1, *observation_shape)
-        actions = actions.reshape(-1)
-        log_probs = log_probs.reshape(-1)
-        probs = probs.reshape((-1, action_dim))
-        td_target = td_target.reshape(-1)
-        advantages = advantages.reshape(-1)
-        values = values.reshape(-1)
-        batch_indexes = np.arange(args.batch_size)
-
-        # Update the policy and value networks
-        for e in range(1, args.update_epochs + 1):
-            numpy_rng.shuffle(batch_indexes)
-            for start in range(0, args.batch_size, args.minibatch_size):
-                end = start + args.minibatch_size
-                index = batch_indexes[start:end]
-
-                # Get the latest policy and value network outputs
-                new_log_probs, td_predict, entropy, new_probs = agent.evaluate(states[index], actions[index])
-                log_ratio = new_log_probs - log_probs[index]
-                ratios = log_ratio.exp()
-
-                # Compute the KL divergence
-                d = torch.sum(
-                    probs[index] * torch.log((probs[index] + 1e-12) / (new_probs + 1e-12)), 1
-                )
-                writer.add_scalar('charts/average_kld', d.mean(), global_step)
-
-                # Normalize the advantages
-                b_advantages = advantages[index]
-                b_advantages = (b_advantages - b_advantages.mean()) / (b_advantages.std() + 1e-12)
-
-                # Policy network loss
-                policy_loss_1 = b_advantages * ratios
-                policy_loss_2 = b_advantages * torch.clamp(
-                    ratios, 1.0 - args.clip_epsilon, 1.0 + args.clip_epsilon
-                )
-                policy_loss = -torch.min(policy_loss_1, policy_loss_2).mean()
-
-                # Value network loss
-                if args.clip_value_loss:
-                    v_loss_un_clipped = (td_predict - td_target[index]) ** 2
-                    v_clipped = td_target[index] + torch.clamp(
-                        td_predict - td_target[index],
-                        -args.clip_epsilon,
-                        args.clip_epsilon,
-                    )
-                    v_loss_clipped = (v_clipped - td_target[index]) ** 2
-                    v_loss_max = torch.max(v_loss_un_clipped, v_loss_clipped)
-                    value_loss = 0.5 * v_loss_max.mean()
-                else:
-                    value_loss = 0.5 * ((td_predict - td_target[index]) ** 2).mean()
-
-                entropy_loss = entropy.mean()
-
-                # Log some training statistics
-                writer.add_scalar('losses/value_loss', value_loss.item(), global_step)
-                writer.add_scalar('losses/policy_loss', policy_loss.item(), global_step)
-                writer.add_scalar('losses/entropy', entropy_loss.item(), global_step)
-                writer.add_scalar('losses/delta', torch.abs(ratios - 1).mean().item(), global_step)
-
-                # Total loss
-                loss = policy_loss + value_loss * args.c_1 - entropy_loss * args.c_2
-
-                # Update the network parameters
-                optimizer.zero_grad()
-                loss.backward()
-                clip_grad_norm_(agent.parameters(), args.clip_grad_norm)
-                optimizer.step()
-
-        explained_var = (
-            np.nan if torch.var(td_target) == 0 else 1 - torch.var(td_target - values) / torch.var(td_target)
-        )
-        writer.add_scalar('charts/learning_rate', optimizer.param_groups[0]['lr'], global_step)
-        writer.add_scalar('charts/SPS', int(global_step / (time.time() - start_time)), global_step)
-        writer.add_scalar('others/explained_var', explained_var, global_step)
-
-    envs.close()
-    writer.close()
-
-
-def main():
-    for env_id in atari_list:
-        for seed in [1, 2, 3]:
-            print(env_id, seed)
-            train('ALE/' + env_id + '-v5', seed)
-
-
-if __name__ == '__main__':
-    main()