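"""
On-policy actor-critic training for Atari (default: BreakoutNoFrameskip-v4).

The script follows the usual PPO-style recipe (vectorized environments, GAE,
minibatch updates, clipped value loss), but the policy surrogate is scaled by a
clipped per-sample KL divergence between the old and new action distributions
instead of clipping the probability ratio; the run name labels this variant
'spo'. Metrics are written to TensorBoard under '<game>/<run_name>'.
"""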

import argparse
import os
import random
import time
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from stable_baselines3.common.atari_wrappers import (
    ClipRewardEnv,
    EpisodicLifeEnv,
    FireResetEnv,
    MaxAndSkipEnv,
    NoopResetEnv,
)
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
def get_args():
    def str_to_bool(x):
        # argparse's type=bool treats every non-empty string (including 'False') as True,
        # so boolean flags are parsed explicitly instead
        return str(x).lower() in ('true', '1', 'yes')

    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_name', type=str, default=os.path.splitext(os.path.basename(__file__))[0])
    parser.add_argument('--gym_id', type=str, default='BreakoutNoFrameskip-v4')
    parser.add_argument('--learning_rate', type=float, default=2.5e-4)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--total_steps', type=int, default=int(1e7))
    parser.add_argument('--use_cuda', type=str_to_bool, default=True)
    parser.add_argument('--num_envs', type=int, default=8)
    parser.add_argument('--num_steps', type=int, default=128)
    parser.add_argument('--lr_decay', type=str_to_bool, default=True)
    parser.add_argument('--use_gae', type=str_to_bool, default=True)
    parser.add_argument('--gae_lambda', type=float, default=0.95)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--num_mini_batches', type=int, default=4)
    parser.add_argument('--update_epochs', type=int, default=8)
    parser.add_argument('--norm_adv', type=str_to_bool, default=True)
    parser.add_argument('--clip_value_loss', type=str_to_bool, default=True)
    parser.add_argument('--c_1', type=float, default=1.0)
    parser.add_argument('--c_2', type=float, default=0.01)
    parser.add_argument('--max_grad_norm', type=float, default=0.5)
    parser.add_argument('--kld_max', type=float, default=0.02)
    a = parser.parse_args()
    a.batch_size = int(a.num_envs * a.num_steps)
    a.minibatch_size = int(a.batch_size // a.num_mini_batches)
    return a
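# make_env returns a thunk that builds one Atari environment with the standard
# DeepMind preprocessing: episode statistics recording, random no-ops on reset,
# frame skipping (4), terminal-on-life-loss, FIRE on reset where required,
# reward clipping, 84x84 grayscale observations, and a 4-frame stack.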
def make_env(gym_id, seed):
    def thunk():
        env = gym.make(gym_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env = NoopResetEnv(env, noop_max=30)
        env = MaxAndSkipEnv(env, skip=4)
        env = EpisodicLifeEnv(env)
        if 'FIRE' in env.unwrapped.get_action_meanings():
            env = FireResetEnv(env)
        env = ClipRewardEnv(env)
        env = gym.wrappers.ResizeObservation(env, (84, 84))
        env = gym.wrappers.GrayScaleObservation(env)
        env = gym.wrappers.FrameStack(env, 4)
        env.seed(seed)
        env.action_space.seed(seed)
        env.observation_space.seed(seed)
        return env
    return thunk
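# Orthogonal weight initialization with a configurable gain, as is conventional
# for PPO-style agents: sqrt(2) for hidden layers, 0.01 for the policy head and
# 1.0 for the value head (see Agent below).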
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer
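# Nature-CNN trunk (three conv layers plus one 512-unit fully connected layer)
# shared by a discrete actor head and a scalar critic head; pixel observations
# are scaled from [0, 255] to [0, 1] before the forward pass.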
class Agent(nn.Module):
    def __init__(self, e):
        super(Agent, self).__init__()
        self.network = nn.Sequential(
            layer_init(nn.Conv2d(4, 32, 8, stride=4)),
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, 4, stride=2)),
            nn.ReLU(),
            layer_init(nn.Conv2d(64, 64, 3, stride=1)),
            nn.ReLU(),
            nn.Flatten(),
            layer_init(nn.Linear(64 * 7 * 7, 512)),
            nn.ReLU(),
        )
        self.actor = layer_init(nn.Linear(512, e.single_action_space.n), std=0.01)
        self.critic = layer_init(nn.Linear(512, 1), std=1)

    def get_value(self, x):
        return self.critic(self.network(x / 255.0))

    def get_action_and_value(self, x, a=None, show_all=False):
        hidden = self.network(x / 255.0)
        log = self.actor(hidden)
        p = Categorical(logits=log)
        if a is None:
            a = p.sample()
        if show_all:
            return a, p.log_prob(a), p.entropy(), self.critic(hidden), p.probs
        return a, p.log_prob(a), p.entropy(), self.critic(hidden)
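# main() trains one agent on one game/seed: it builds the TensorBoard writer and
# the vectorized environments, then alternates between collecting num_steps
# transitions per environment and running several epochs of minibatch updates
# on the flattened batch.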
def main(env_id, seed):
    args = get_args()
    args.gym_id = env_id
    args.seed = seed
    run_name = (
        'spo_' + str(args.kld_max) +
        '_epoch_' + str(args.update_epochs) +
        '_seed_' + str(args.seed)
    )
    # save the training logs
    path_string = str(args.gym_id).split('NoFrameskip')[0] + '/' + run_name
    writer = SummaryWriter(path_string)
    writer.add_text(
        'Hyperparameter',
        '|param|value|\n|-|-|\n%s' % ('\n'.join([f'|{key}|{value}|' for key, value in vars(args).items()])),
    )
    # seed the random number generators
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    # initialize the environments
    device = torch.device('cuda' if torch.cuda.is_available() and args.use_cuda else 'cpu')
    envs = gym.vector.SyncVectorEnv(
        [make_env(args.gym_id, args.seed + i) for i in range(args.num_envs)]
    )
    assert isinstance(envs.single_action_space, gym.spaces.Discrete), 'only discrete action space is supported'
    agent = Agent(envs).to(device)
    optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)
    # initialize rollout storage
    obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
    actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
    probs = torch.zeros((args.num_steps, args.num_envs, envs.single_action_space.n)).to(device)
    log_probs = torch.zeros((args.num_steps, args.num_envs)).to(device)
    rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
    dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
    values = torch.zeros((args.num_steps, args.num_envs)).to(device)
    # start collecting data
    global_step = 0
    start_time = time.time()
    next_obs = torch.Tensor(envs.reset()).to(device)
    next_done = torch.zeros(args.num_envs).to(device)
    num_updates = int(args.total_steps // args.batch_size)
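    # with the defaults each update consumes batch_size = 8 envs * 128 steps = 1024
    # transitions, so num_updates = 1e7 // 1024 = 9765 updates in total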
    for update in tqdm(range(1, num_updates + 1)):
        # anneal the learning rate if enabled
        if args.lr_decay:
            frac = 1.0 - (update - 1.0) / num_updates
            lr_now = frac * args.learning_rate
            optimizer.param_groups[0]['lr'] = lr_now
        for step in range(0, args.num_steps):
            # each step advances all of the parallel environments at once
            global_step += 1 * args.num_envs
            obs[step] = next_obs
            dones[step] = next_done
            # action, log-probability and full action distribution under the old policy
            with torch.no_grad():
                action, log_prob, _, value, prob = agent.get_action_and_value(next_obs, show_all=True)
                values[step] = value.flatten()
            actions[step] = action
            probs[step] = prob
            log_probs[step] = log_prob
            # step the environments
            next_obs, reward, done, info = envs.step(action.cpu().numpy())
            rewards[step] = torch.tensor(reward).to(device).view(-1)
            next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device)
            # if any of the parallel environments finished an episode, log its return
            for item in info:
                if 'episode' in item.keys():
                    # print(f"global_step={global_step}, episodic_return={item['episode']['r']}")
                    writer.add_scalar('charts/episodic_return', item['episode']['r'], global_step)
                    break
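        # the rollout buffers are now full; the update works from bootstrapped advantage
        # estimates, A_t = delta_t + gamma * gae_lambda * (1 - done_{t+1}) * A_{t+1},
        # with delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_{t+1}) - V(s_t)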
        # compute GAE
        with torch.no_grad():
            next_value = agent.get_value(next_obs).reshape(1, -1)
            if args.use_gae:
                advantages = torch.zeros_like(rewards).to(device)
                last_gae_lam = 0
                for t in reversed(range(args.num_steps)):
                    if t == args.num_steps - 1:
                        next_non_terminal = 1.0 - next_done
                        next_values = next_value
                    else:
                        next_non_terminal = 1.0 - dones[t + 1]
                        next_values = values[t + 1]
                    delta = rewards[t] + args.gamma * next_values * next_non_terminal - values[t]
                    advantages[t] = last_gae_lam = (
                        delta + args.gamma * args.gae_lambda * next_non_terminal * last_gae_lam
                    )
                returns = advantages + values
            else:
                # plain discounted returns when GAE is disabled
                returns = torch.zeros_like(rewards).to(device)
                for t in reversed(range(args.num_steps)):
                    if t == args.num_steps - 1:
                        next_non_terminal = 1.0 - next_done
                        next_return = next_value
                    else:
                        next_non_terminal = 1.0 - dones[t + 1]
                        next_return = returns[t + 1]
                    returns[t] = rewards[t] + args.gamma * next_non_terminal * next_return
                advantages = returns - values
        # ------------------- enough data has been collected above; the update starts below ------------------- #
        # flatten the rollout into one batch
        b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
        b_probs = probs.reshape((-1, envs.single_action_space.n))
        b_log_probs = log_probs.reshape(-1)
        b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
        b_advantages = advantages.reshape(-1)
        b_returns = returns.reshape(-1)
        b_values = values.reshape(-1)
        # update the policy and value networks
        b_index = np.arange(args.batch_size)
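        # each epoch shuffles the batch indices and sweeps them in minibatches of
        # batch_size / num_mini_batches samples (1024 / 4 = 256 with the defaults)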
        for epoch in range(1, args.update_epochs + 1):
            np.random.shuffle(b_index)
            t = 0
            for start in range(0, args.batch_size, args.minibatch_size):
                t += 1
                end = start + args.minibatch_size
                mb_index = b_index[start:end]
                # outputs of the current policy and value networks
                _, new_log_prob, entropy, new_value, new_probs = (
                    agent.get_action_and_value(b_obs[mb_index], b_actions.long()[mb_index], show_all=True)
                )
                # KL divergence between the old and the current action distributions
                d = torch.sum(
                    b_probs[mb_index] * torch.log((b_probs[mb_index] + 1e-12) / (new_probs + 1e-12)), 1
                )
                writer.add_scalar('charts/average_kld', d.mean(), global_step)
                writer.add_scalar('others/min_kld', d.min(), global_step)
                writer.add_scalar('others/max_kld', d.max(), global_step)
                log_ratio = new_log_prob - b_log_probs[mb_index]
                ratios = log_ratio.exp()
                # normalize the advantages
                mb_advantages = b_advantages[mb_index]
                if args.norm_adv:
                    mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-12)
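                # policy surrogate: -A * ratio * (clip(d, 0, kld_max) / d + sign(A) - 1) * sign(A);
                # while the per-sample KL d stays below kld_max the scaling factor equals 1 (for
                # nonzero advantages) and this reduces to the plain importance-sampled objective,
                # otherwise the factor drops below 1 for positive advantages and rises above 1
                # for negative ones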
                # policy loss
                if epoch == 1 and t == 1:
                    # on the very first minibatch the new policy still equals the old one,
                    # so the plain surrogate is used
                    pg_loss = (-mb_advantages * ratios).mean()
                else:
                    d_clip = torch.clamp(input=d, min=0, max=args.kld_max)
                    # d_clip / d
                    ratio = d_clip / (d + 1e-12)
                    # sign_a
                    sign_a = torch.sign(mb_advantages)
                    # (d_clip / d + sign_a - 1) * sign_a
                    result = (ratio + sign_a - 1) * sign_a
                    pg_loss = (-mb_advantages * ratios * result).mean()
                # value loss
                new_value = new_value.view(-1)
                if args.clip_value_loss:
                    v_loss_un_clipped = (new_value - b_returns[mb_index]) ** 2
                    v_clipped = b_values[mb_index] + torch.clamp(
                        new_value - b_values[mb_index],
                        -0.2,
                        0.2,
                    )
                    v_loss_clipped = (v_clipped - b_returns[mb_index]) ** 2
                    v_loss_max = torch.max(v_loss_un_clipped, v_loss_clipped)
                    v_loss = 0.5 * v_loss_max.mean()
                else:
                    v_loss = 0.5 * ((new_value - b_returns[mb_index]) ** 2).mean()
                entropy_loss = entropy.mean()
                # total loss: policy term + value term - entropy bonus
                loss = pg_loss + v_loss * args.c_1 - entropy_loss * args.c_2
                # log the per-minibatch training statistics
                writer.add_scalar('losses/value_loss', v_loss.item(), global_step)
                writer.add_scalar('losses/policy_loss', pg_loss.item(), global_step)
                writer.add_scalar('losses/entropy', entropy_loss.item(), global_step)
                writer.add_scalar('losses/delta', torch.abs(ratios - 1).mean().item(), global_step)
                # gradient step with global-norm clipping
                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
                optimizer.step()
        y_pre, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
        var_y = np.var(y_true)
        explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pre) / var_y
        # log the per-update training statistics
        writer.add_scalar('charts/learning_rate', optimizer.param_groups[0]['lr'], global_step)
        writer.add_scalar('others/explained_variance', explained_var, global_step)
        writer.add_scalar('charts/SPS', int(global_step / (time.time() - start_time)), global_step)
    envs.close()
    writer.close()
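# run() sweeps the listed Atari games over seeds 1 to 3, training one agent per
# (game, seed) combination.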
def run():
    for env_id in ['Breakout']:
        for seed in [1, 2, 3]:
            print(env_id + 'NoFrameskip-v4', 'seed:', seed)
            main(env_id + 'NoFrameskip-v4', seed)


if __name__ == '__main__':
    run()