import argparse
import datetime
import os
import pprint

import numpy as np
import torch
from atari_network import QRDQN
from atari_wrapper import make_atari_env
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.policy import QRDQNPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger, WandbLogger


def get_args():
    parser = argparse.ArgumentParser()
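    # Defaults match the QR-DQN Atari setup used in test_qrdqn() below: 200 quantiles,
    # 3-step returns, and epsilon-greedy exploration annealed from 1.0 to 0.05.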
    parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--scale-obs", type=int, default=0)
    parser.add_argument("--eps-test", type=float, default=0.005)
    parser.add_argument("--eps-train", type=float, default=1.)
    parser.add_argument("--eps-train-final", type=float, default=0.05)
    parser.add_argument("--buffer-size", type=int, default=100000)
    parser.add_argument("--lr", type=float, default=0.0001)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--num-quantiles", type=int, default=200)
    parser.add_argument("--n-step", type=int, default=3)
    parser.add_argument("--target-update-freq", type=int, default=500)
    parser.add_argument("--epoch", type=int, default=100)
    parser.add_argument("--step-per-epoch", type=int, default=100000)
    parser.add_argument("--step-per-collect", type=int, default=10)
    parser.add_argument("--update-per-step", type=float, default=0.1)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--training-num", type=int, default=10)
    parser.add_argument("--test-num", type=int, default=10)
    parser.add_argument("--logdir", type=str, default="log")
    parser.add_argument("--render", type=float, default=0.)
    parser.add_argument(
        "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
    )
    parser.add_argument("--frames-stack", type=int, default=4)
    parser.add_argument("--resume-path", type=str, default=None)
    parser.add_argument("--resume-id", type=str, default=None)
    parser.add_argument(
        "--logger",
        type=str,
        default="tensorboard",
        choices=["tensorboard", "wandb"],
    )
    parser.add_argument("--wandb-project", type=str, default="atari.benchmark")
    parser.add_argument(
        "--watch",
        default=False,
        action="store_true",
        help="watch the play of pre-trained policy only"
    )
    parser.add_argument("--save-buffer-name", type=str, default=None)
    return parser.parse_args()


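# Train QR-DQN on an Atari task: build vectorized envs, network, policy, replay
# buffer and collectors, then run the off-policy trainer (or only evaluate a
# pre-trained policy when --watch is given).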
def test_qrdqn(args=get_args()):
    env, train_envs, test_envs = make_atari_env(
        args.task,
        args.seed,
        args.training_num,
        args.test_num,
        scale=args.scale_obs,
        frame_stack=args.frames_stack,
    )
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    # define model
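    # QRDQN predicts args.num_quantiles quantile values per action; actions are
    # chosen greedily w.r.t. the mean over quantiles.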
    net = QRDQN(*args.state_shape, args.action_shape, args.num_quantiles, args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    # define policy
    policy = QRDQNPolicy(
        net,
        optim,
        args.gamma,
        args.num_quantiles,
        args.n_step,
        target_update_freq=args.target_update_freq
    ).to(args.device)
    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)
    # replay buffer: `save_only_last_obs` and `stack_num` can be removed together
    # when you have enough RAM
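    # save_only_last_obs keeps a single frame per transition and stack_num
    # rebuilds the frame stack at sampling time, trading CPU for memory.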
    buffer = VectorReplayBuffer(
        args.buffer_size,
        buffer_num=len(train_envs),
        ignore_obs_next=True,
        save_only_last_obs=True,
        stack_num=args.frames_stack
    )
    # collector
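    # exploration_noise=True lets the collectors apply the policy's epsilon-greedy noise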
    train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
    test_collector = Collector(policy, test_envs, exploration_noise=True)

    # log
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    args.algo_name = "qrdqn"
    log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
    log_path = os.path.join(args.logdir, log_name)

    # logger
    if args.logger == "wandb":
        logger = WandbLogger(
            save_interval=1,
            name=log_name.replace(os.path.sep, "__"),
            run_id=args.resume_id,
            config=args,
            project=args.wandb_project,
        )
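    # a TensorBoard writer is created in either case; WandbLogger wraps it via load()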
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    if args.logger == "tensorboard":
        logger = TensorboardLogger(writer)
    else:  # wandb
        logger.load(writer)

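    # called by the trainer whenever the best test reward so far improves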
    def save_best_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

    def stop_fn(mean_rewards):
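        # stop early once the env's reward threshold (or a near-perfect Pong score) is reached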
        if env.spec.reward_threshold:
            return mean_rewards >= env.spec.reward_threshold
        elif "Pong" in args.task:
            return mean_rewards >= 20
        else:
            return False

    def train_fn(epoch, env_step):
        # nature DQN setting, linear decay in the first 1M steps
        if env_step <= 1e6:
            eps = args.eps_train - env_step / 1e6 * \
                (args.eps_train - args.eps_train_final)
        else:
            eps = args.eps_train_final
        policy.set_eps(eps)
        if env_step % 1000 == 0:
            logger.write("train/env_step", env_step, {"train/eps": eps})

    def test_fn(epoch, env_step):
        policy.set_eps(args.eps_test)

    # watch agent's performance
    def watch():
        print("Setup test envs ...")
        policy.eval()
        policy.set_eps(args.eps_test)
        test_envs.seed(args.seed)
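        # either roll out the policy to fill and dump a replay buffer to HDF5,
        # or simply evaluate it over args.test_num episodes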
        if args.save_buffer_name:
            print(f"Generate buffer with size {args.buffer_size}")
            buffer = VectorReplayBuffer(
                args.buffer_size,
                buffer_num=len(test_envs),
                ignore_obs_next=True,
                save_only_last_obs=True,
                stack_num=args.frames_stack
            )
            collector = Collector(policy, test_envs, buffer, exploration_noise=True)
            result = collector.collect(n_step=args.buffer_size)
            print(f"Save buffer into {args.save_buffer_name}")
            # Unfortunately, pickle will cause oom with 1M buffer size
            buffer.save_hdf5(args.save_buffer_name)
        else:
            print("Testing agent ...")
            test_collector.reset()
            result = test_collector.collect(
                n_episode=args.test_num, render=args.render
            )
        rew = result["rews"].mean()
        print(f"Mean reward (over {result['n/ep']} episodes): {rew}")

    if args.watch:
        watch()
        exit(0)

    # test train_collector and start filling replay buffer
    train_collector.collect(n_step=args.batch_size * args.training_num)
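    # the trainer alternates collecting step_per_collect env steps with
    # update_per_step gradient updates per collected step, for args.epoch epochs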
    # trainer
    result = offpolicy_trainer(
        policy,
        train_collector,
        test_collector,
        args.epoch,
        args.step_per_epoch,
        args.step_per_collect,
        args.test_num,
        args.batch_size,
        train_fn=train_fn,
        test_fn=test_fn,
        stop_fn=stop_fn,
        save_best_fn=save_best_fn,
        logger=logger,
        update_per_step=args.update_per_step,
        test_in_train=False,
    )

    pprint.pprint(result)
    watch()


if __name__ == "__main__":
    test_qrdqn(get_args())
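
# Example usage (the script filename is assumed; adjust paths to your setup):
#   python atari_qrdqn.py --task PongNoFrameskip-v4
#   python atari_qrdqn.py --watch --resume-path log/PongNoFrameskip-v4/qrdqn/0/<timestamp>/policy.pth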