Closes #947

This removes all kwargs from all policy constructors. While doing that, I also improved several names and added a whole lot of TODOs.

## Functional changes:

1. Added the possibility to pass None as `critic2` and `critic2_optim`. In fact, the default behavior should then cover the vast majority of cases.
2. Added a function called `clone_optimizer` as a temporary measure to support passing `critic2_optim=None`.

## Breaking changes:

1. `action_space` is no longer optional. In fact, it already was non-optional, as there was a ValueError in `BasePolicy.__init__`. Several examples were fixed to reflect that.
2. `reward_normalization` was removed from DDPG and its children. It was never allowed to pass it as `True` there; an error would have been raised in `compute_n_step_reward`. It is now removed from the interface.
3. Renamed `critic1` and similar to `critic`, in order to have uniform interfaces. Note that the `critic` in DDPG was optional for the sole reason that child classes used `critic1`. I removed this optionality (DDPG can't do anything with `critic=None`).
4. Several renamings of fields (mostly private to public, so backwards compatible).

## Additional changes:

1. Removed type and default declarations from docstrings. This kind of duplication is really not necessary.
2. Policy constructors are now only called using named arguments, not a fragile mixture of positional and named arguments as before.
3. Minor beautifications in typing and code.
4. Generally shortened docstrings and made them uniform across all policies (hopefully).

## Comment:

With these changes, several problems in tianshou's inheritance hierarchy become more apparent. I tried highlighting them for future work.

---------

Co-authored-by: Dominik Jain <d.jain@appliedai.de>
		
			
				
	
	
		
			229 lines
		
	
	
		
			8.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			229 lines
		
	
	
		
			8.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import argparse
 | |
| import datetime
 | |
| import os
 | |
| import pprint
 | |
| import sys
 | |
| 
 | |
| import numpy as np
 | |
| import torch
 | |
| from env import make_vizdoom_env
 | |
| from network import C51
 | |
| from torch.utils.tensorboard import SummaryWriter
 | |
| 
 | |
| from tianshou.data import Collector, VectorReplayBuffer
 | |
| from tianshou.policy import C51Policy
 | |
| from tianshou.trainer import OffpolicyTrainer
 | |
| from tianshou.utils import TensorboardLogger, WandbLogger
 | |
| 
 | |
| 
 | |
def get_args(argv=None):
    """Parse the settings of the ViZDoom C51 example.

    :param argv: optional sequence of argument strings to parse. When ``None``
        (the default) ``argparse`` reads the process command line, which is
        the behaviour callers that pass no argument have always had.
    :return: the ``argparse.Namespace`` holding every option defined below.
    """
    parser = argparse.ArgumentParser()
    # task and reproducibility
    parser.add_argument("--task", type=str, default="D1_basic")
    parser.add_argument("--seed", type=int, default=0)
    # exploration rates
    parser.add_argument("--eps-test", type=float, default=0.005)
    parser.add_argument("--eps-train", type=float, default=1.0)
    parser.add_argument("--eps-train-final", type=float, default=0.05)
    # replay buffer, optimizer and C51 hyper-parameters
    parser.add_argument("--buffer-size", type=int, default=2000000)
    parser.add_argument("--lr", type=float, default=0.0001)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--num-atoms", type=int, default=51)
    parser.add_argument("--v-min", type=float, default=-10.0)
    parser.add_argument("--v-max", type=float, default=10.0)
    parser.add_argument("--n-step", type=int, default=3)
    parser.add_argument("--target-update-freq", type=int, default=500)
    # trainer schedule
    parser.add_argument("--epoch", type=int, default=300)
    parser.add_argument("--step-per-epoch", type=int, default=100000)
    parser.add_argument("--step-per-collect", type=int, default=10)
    parser.add_argument("--update-per-step", type=float, default=0.1)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--training-num", type=int, default=10)
    parser.add_argument("--test-num", type=int, default=100)
    # output, rendering and device
    parser.add_argument("--logdir", type=str, default="log")
    parser.add_argument("--render", type=float, default=0.0)
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
    )
    # environment pre-processing
    parser.add_argument("--frames-stack", type=int, default=4)
    parser.add_argument("--skip-num", type=int, default=4)
    # resuming and logging backend
    parser.add_argument("--resume-path", type=str, default=None)
    parser.add_argument("--resume-id", type=str, default=None)
    parser.add_argument(
        "--logger",
        type=str,
        default="tensorboard",
        choices=["tensorboard", "wandb"],
    )
    parser.add_argument("--wandb-project", type=str, default="vizdoom.benchmark")
    # evaluation-only helpers
    parser.add_argument(
        "--watch",
        default=False,
        action="store_true",
        help="watch the play of pre-trained policy only",
    )
    parser.add_argument(
        "--save-lmp",
        default=False,
        action="store_true",
        help="save lmp file for replay whole episode",
    )
    parser.add_argument("--save-buffer-name", type=str, default=None)
    # parse_args(None) falls back to sys.argv[1:], so the default is unchanged.
    return parser.parse_args(argv)
 | |
| 
 | |
| 
 | |
def test_c51(args=None):
    """Train a C51 policy on a ViZDoom task, or only watch an existing one.

    Builds the environments, the C51 network and policy, the replay buffer and
    the train/test collectors. With ``args.watch`` set, the policy is evaluated
    once and the process exits; otherwise ``OffpolicyTrainer`` is run, its
    result is printed and the policy is evaluated afterwards.

    :param args: settings namespace as returned by ``get_args``. When ``None``
        (the default) the command line is parsed at call time. The previous
        default ``get_args()`` was evaluated once at definition time, which
        parsed ``sys.argv`` as a side effect of merely importing this module.
    """
    if args is None:
        args = get_args()
    # make environments
    env, train_envs, test_envs = make_vizdoom_env(
        args.task,
        args.skip_num,
        (args.frames_stack, 84, 84),
        args.save_lmp,
        args.seed,
        args.training_num,
        args.test_num,
    )
    args.state_shape = env.observation_space.shape
    args.action_shape = env.action_space.shape or env.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    # define model
    net = C51(*args.state_shape, args.action_shape, args.num_atoms, args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    # define policy
    policy = C51Policy(
        model=net,
        optim=optim,
        discount_factor=args.gamma,
        action_space=env.action_space,
        num_atoms=args.num_atoms,
        v_min=args.v_min,
        v_max=args.v_max,
        estimation_step=args.n_step,
        target_update_freq=args.target_update_freq,
    ).to(args.device)
    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)
    # replay buffer: `save_last_obs` and `stack_num` can be removed together
    # when you have enough RAM
    buffer = VectorReplayBuffer(
        args.buffer_size,
        buffer_num=len(train_envs),
        ignore_obs_next=True,
        save_only_last_obs=True,
        stack_num=args.frames_stack,
    )
    # collector
    train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
    test_collector = Collector(policy, test_envs, exploration_noise=True)

    # log
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    args.algo_name = "c51"
    log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
    log_path = os.path.join(args.logdir, log_name)

    # logger
    # The wandb logger is created before the SummaryWriter and is handed the
    # writer afterwards via `load`; the tensorboard logger wraps the writer.
    if args.logger == "wandb":
        logger = WandbLogger(
            save_interval=1,
            name=log_name.replace(os.path.sep, "__"),
            run_id=args.resume_id,
            config=args,
            project=args.wandb_project,
        )
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    if args.logger == "tensorboard":
        logger = TensorboardLogger(writer)
    else:  # wandb
        logger.load(writer)

    def save_best_fn(policy):
        """Store the policy's state dict as ``policy.pth`` under the log path."""
        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

    def stop_fn(mean_rewards: float) -> bool:
        """Return True once the env's reward threshold (if any) is reached."""
        if env.spec.reward_threshold:
            return mean_rewards >= env.spec.reward_threshold
        return False

    def train_fn(epoch, env_step):
        """Set the training epsilon for the current environment step."""
        # nature DQN setting, linear decay in the first 1M steps
        if env_step <= 1e6:
            eps = args.eps_train - env_step / 1e6 * (args.eps_train - args.eps_train_final)
        else:
            eps = args.eps_train_final
        policy.set_eps(eps)
        if env_step % 1000 == 0:
            logger.write("train/env_step", env_step, {"train/eps": eps})

    def test_fn(epoch, env_step):
        """Switch the policy to the test-time epsilon."""
        policy.set_eps(args.eps_test)

    # watch agent's performance
    def watch():
        """Evaluate the policy on the test envs, optionally saving a buffer."""
        print("Setup test envs ...")
        policy.eval()
        policy.set_eps(args.eps_test)
        test_envs.seed(args.seed)
        if args.save_buffer_name:
            print(f"Generate buffer with size {args.buffer_size}")
            # Separate name so the training buffer above is not shadowed.
            watch_buffer = VectorReplayBuffer(
                args.buffer_size,
                buffer_num=len(test_envs),
                ignore_obs_next=True,
                save_only_last_obs=True,
                stack_num=args.frames_stack,
            )
            collector = Collector(policy, test_envs, watch_buffer, exploration_noise=True)
            result = collector.collect(n_step=args.buffer_size)
            print(f"Save buffer into {args.save_buffer_name}")
            # Unfortunately, pickle will cause oom with 1M buffer size
            watch_buffer.save_hdf5(args.save_buffer_name)
        else:
            print("Testing agent ...")
            test_collector.reset()
            result = test_collector.collect(n_episode=args.test_num, render=args.render)
        rew = result["rews"].mean()
        # lengths are scaled by the frame-skip factor
        lens = result["lens"].mean() * args.skip_num
        print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
        print(f'Mean length (over {result["n/ep"]} episodes): {lens}')

    if args.watch:
        watch()
        sys.exit(0)

    # test train_collector and start filling replay buffer
    train_collector.collect(n_step=args.batch_size * args.training_num)
    # trainer
    result = OffpolicyTrainer(
        policy=policy,
        train_collector=train_collector,
        test_collector=test_collector,
        max_epoch=args.epoch,
        step_per_epoch=args.step_per_epoch,
        step_per_collect=args.step_per_collect,
        episode_per_test=args.test_num,
        batch_size=args.batch_size,
        train_fn=train_fn,
        test_fn=test_fn,
        stop_fn=stop_fn,
        save_best_fn=save_best_fn,
        logger=logger,
        update_per_step=args.update_per_step,
        test_in_train=False,
    ).run()

    pprint.pprint(result)
    watch()
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    # Script entry point: parse the command line, then run the experiment.
    cli_args = get_args()
    test_c51(cli_args)
 |