"""Train (or just watch) a QRDQN agent on an Atari task with tianshou."""

import argparse
import os
import pprint

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import QRDQNPolicy
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, VectorReplayBuffer

from atari_network import QRDQN
from atari_wrapper import wrap_deepmind


def get_args() -> argparse.Namespace:
    """Parse the command-line options of this script.

    Returns:
        The namespace produced by ``argparse`` from ``sys.argv``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--eps-test', type=float, default=0.005)
    parser.add_argument('--eps-train', type=float, default=1.)
    parser.add_argument('--eps-train-final', type=float, default=0.05)
    parser.add_argument('--buffer-size', type=int, default=100000)
    parser.add_argument('--lr', type=float, default=0.0001)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--num-quantiles', type=int, default=200)
    parser.add_argument('--n-step', type=int, default=3)
    parser.add_argument('--target-update-freq', type=int, default=500)
    parser.add_argument('--epoch', type=int, default=100)
    parser.add_argument('--step-per-epoch', type=int, default=100000)
    parser.add_argument('--step-per-collect', type=int, default=10)
    parser.add_argument('--update-per-step', type=float, default=0.1)
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--training-num', type=int, default=10)
    parser.add_argument('--test-num', type=int, default=10)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.)
    parser.add_argument(
        '--device', type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu')
    parser.add_argument('--frames-stack', type=int, default=4)
    parser.add_argument('--resume-path', type=str, default=None)
    parser.add_argument('--watch', default=False, action='store_true',
                        help='watch the play of pre-trained policy only')
    return parser.parse_args()


def make_atari_env(args):
    """Build the environment used for training.

    Wraps ``args.task`` with ``wrap_deepmind``, stacking
    ``args.frames_stack`` frames.
    """
    return wrap_deepmind(args.task, frame_stack=args.frames_stack)


def make_atari_env_watch(args):
    """Build the environment used for testing / watching.

    Same as :func:`make_atari_env` but with ``episode_life=False`` and
    ``clip_rewards=False`` (presumably so that evaluation reports raw,
    full-episode scores -- confirm against ``atari_wrapper``).
    """
    return wrap_deepmind(args.task, frame_stack=args.frames_stack,
                         episode_life=False, clip_rewards=False)


def test_qrdqn(args=None):
    """Train a QRDQN policy on ``args.task``, then run a final evaluation.

    When ``args.watch`` is set, only the evaluation is run and the process
    exits with status 0 afterwards.

    Args:
        args: Namespace as returned by :func:`get_args`. When ``None``
            (the default) the command line is parsed at call time, so that
            importing this module does not touch ``sys.argv``.
    """
    if args is None:
        args = get_args()

    env = make_atari_env(args)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.env.action_space.shape or env.env.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    # make environments
    train_envs = SubprocVectorEnv([lambda: make_atari_env(args)
                                   for _ in range(args.training_num)])
    test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args)
                                  for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # define model
    net = QRDQN(*args.state_shape, args.action_shape,
                args.num_quantiles, args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    # define policy
    policy = QRDQNPolicy(
        net, optim, args.gamma, args.num_quantiles,
        args.n_step, target_update_freq=args.target_update_freq
    ).to(args.device)
    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(
            args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)
    # replay buffer: `save_only_last_obs` and `stack_num` can be removed
    # together when you have enough RAM
    buffer = VectorReplayBuffer(
        args.buffer_size, buffer_num=len(train_envs), ignore_obs_next=True,
        save_only_last_obs=True, stack_num=args.frames_stack)
    # collector
    train_collector = Collector(policy, train_envs, buffer,
                                exploration_noise=True)
    test_collector = Collector(policy, test_envs, exploration_noise=True)
    # log
    log_path = os.path.join(args.logdir, args.task, 'qrdqn')
    writer = SummaryWriter(log_path)

    def save_fn(policy):
        """Write the policy's state dict to ``policy.pth`` in the log dir."""
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        """Return True once the mean test reward is high enough to stop."""
        # Read the threshold once so the value that is checked is the very
        # same value that is compared against.
        threshold = env.env.spec.reward_threshold
        if threshold:
            return mean_rewards >= threshold
        if 'Pong' in args.task:
            return mean_rewards >= 20
        return False

    def train_fn(epoch, env_step):
        """Set (and log) the training epsilon for the current env step."""
        # nature DQN setting, linear decay in the first 1M steps
        if env_step <= 1e6:
            eps = args.eps_train - env_step / 1e6 * \
                (args.eps_train - args.eps_train_final)
        else:
            eps = args.eps_train_final
        policy.set_eps(eps)
        writer.add_scalar('train/eps', eps, global_step=env_step)

    def test_fn(epoch, env_step):
        """Switch the policy to the test-time epsilon."""
        policy.set_eps(args.eps_test)

    # watch agent's performance
    def watch():
        """Evaluate the policy for ``args.test_num`` episodes and print it."""
        print("Testing agent ...")
        policy.eval()
        policy.set_eps(args.eps_test)
        test_envs.seed(args.seed)
        test_collector.reset()
        result = test_collector.collect(n_episode=args.test_num,
                                        render=args.render)
        pprint.pprint(result)

    if args.watch:
        watch()
        # `exit` is a helper injected by the `site` module for interactive
        # use; raising SystemExit directly has the same effect everywhere.
        raise SystemExit(0)

    # test train_collector and start filling replay buffer
    train_collector.collect(n_step=args.batch_size * args.training_num)
    # trainer
    result = offpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.step_per_collect, args.test_num,
        args.batch_size, train_fn=train_fn, test_fn=test_fn,
        stop_fn=stop_fn, save_fn=save_fn, writer=writer,
        update_per_step=args.update_per_step, test_in_train=False)

    pprint.pprint(result)
    watch()


if __name__ == '__main__':
    test_qrdqn(get_args())