Closes #947. This removes all kwargs from all policy constructors. While doing that, I also improved several names and added a whole lot of TODOs.

## Functional changes:

1. Added the possibility to pass None as `critic2` and `critic2_optim`. In fact, the default behavior should then cover the vast majority of cases.
2. Added a function called `clone_optimizer` as a temporary measure to support passing `critic2_optim=None`.

## Breaking changes:

1. `action_space` is no longer optional. In fact, it already was non-optional, as a ValueError was raised in `BasePolicy.__init__`. Several examples were fixed accordingly.
2. `reward_normalization` was removed from DDPG and its children. It was never allowed to pass it as `True` there; an error would have been raised in `compute_n_step_reward`. Now it is removed from the interface.
3. Renamed `critic1` and similar to `critic`, in order to have uniform interfaces. Note that the `critic` in DDPG was optional for the sole reason that child classes used `critic1`. I removed this optionality (DDPG can't do anything with `critic=None`).
4. Several fields were renamed (mostly from private to public, so this is backwards compatible).

## Additional changes:

1. Removed type and default declarations from docstrings. This kind of duplication is really not necessary.
2. Policy constructors are now only called using named arguments, not a fragile mixture of positional and named arguments as before.
3. Minor beautifications in typing and code.
4. Generally shortened docstrings and made them uniform across all policies (hopefully).

## Comment:

With these changes, several problems in tianshou's inheritance hierarchy become more apparent. I tried to highlight them for future work.

---------

Co-authored-by: Dominik Jain <d.jain@appliedai.de>
		
			
				
	
	
		
			264 lines
		
	
	
		
			9.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			264 lines
		
	
	
		
			9.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import argparse
 | 
						|
import datetime
 | 
						|
import os
 | 
						|
import pprint
 | 
						|
import sys
 | 
						|
 | 
						|
import numpy as np
 | 
						|
import torch
 | 
						|
from atari_network import DQN
 | 
						|
from atari_wrapper import make_atari_env
 | 
						|
from torch.utils.tensorboard import SummaryWriter
 | 
						|
 | 
						|
from tianshou.data import Collector, VectorReplayBuffer
 | 
						|
from tianshou.policy import DQNPolicy
 | 
						|
from tianshou.policy.modelbased.icm import ICMPolicy
 | 
						|
from tianshou.trainer import OffpolicyTrainer
 | 
						|
from tianshou.utils import TensorboardLogger, WandbLogger
 | 
						|
from tianshou.utils.net.discrete import IntrinsicCuriosityModule
 | 
						|
 | 
						|
 | 
						|
def get_args():
    """Build the command-line interface for the Atari DQN (+ optional ICM) script.

    Options are registered in a fixed order, which determines the order in
    ``--help`` output.

    Returns:
        The ``argparse.Namespace`` produced by parsing ``sys.argv``.
    """
    cli = argparse.ArgumentParser()

    # Plain typed options: (flag, type, default), registered in this order.
    typed_options = [
        ("--task", str, "PongNoFrameskip-v4"),
        ("--seed", int, 0),
        ("--scale-obs", int, 0),
        ("--eps-test", float, 0.005),
        ("--eps-train", float, 1.0),
        ("--eps-train-final", float, 0.05),
        ("--buffer-size", int, 100000),
        ("--lr", float, 0.0001),
        ("--gamma", float, 0.99),
        ("--n-step", int, 3),
        ("--target-update-freq", int, 500),
        ("--epoch", int, 100),
        ("--step-per-epoch", int, 100000),
        ("--step-per-collect", int, 10),
        ("--update-per-step", float, 0.1),
        ("--batch-size", int, 32),
        ("--training-num", int, 10),
        ("--test-num", int, 10),
        ("--logdir", str, "log"),
        ("--render", float, 0.0),
        ("--device", str, "cuda" if torch.cuda.is_available() else "cpu"),
        ("--frames-stack", int, 4),
        ("--resume-path", str, None),
        ("--resume-id", str, None),
    ]
    for flag, value_type, default_value in typed_options:
        cli.add_argument(flag, type=value_type, default=default_value)

    # Options that need extra keywords (choices / action / help).
    cli.add_argument(
        "--logger",
        type=str,
        default="tensorboard",
        choices=["tensorboard", "wandb"],
    )
    cli.add_argument("--wandb-project", type=str, default="atari.benchmark")
    cli.add_argument(
        "--watch",
        default=False,
        action="store_true",
        help="watch the play of pre-trained policy only",
    )
    cli.add_argument("--save-buffer-name", type=str, default=None)

    # Intrinsic curiosity module knobs: (flag, default, help), all floats.
    icm_options = [
        ("--icm-lr-scale", 0.0, "use intrinsic curiosity module with this lr scale"),
        ("--icm-reward-scale", 0.01, "scaling factor for intrinsic curiosity reward"),
        ("--icm-forward-loss-weight", 0.2, "weight for the forward model loss in ICM"),
    ]
    for flag, default_value, help_text in icm_options:
        cli.add_argument(flag, type=float, default=default_value, help=help_text)

    return cli.parse_args()
 | 
						|
 | 
						|
 | 
						|
def test_dqn(args=get_args()):
 | 
						|
    env, train_envs, test_envs = make_atari_env(
 | 
						|
        args.task,
 | 
						|
        args.seed,
 | 
						|
        args.training_num,
 | 
						|
        args.test_num,
 | 
						|
        scale=args.scale_obs,
 | 
						|
        frame_stack=args.frames_stack,
 | 
						|
    )
 | 
						|
    args.state_shape = env.observation_space.shape or env.observation_space.n
 | 
						|
    args.action_shape = env.action_space.shape or env.action_space.n
 | 
						|
    # should be N_FRAMES x H x W
 | 
						|
    print("Observations shape:", args.state_shape)
 | 
						|
    print("Actions shape:", args.action_shape)
 | 
						|
    # seed
 | 
						|
    np.random.seed(args.seed)
 | 
						|
    torch.manual_seed(args.seed)
 | 
						|
    # define model
 | 
						|
    net = DQN(*args.state_shape, args.action_shape, args.device).to(args.device)
 | 
						|
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
 | 
						|
    # define policy
 | 
						|
    policy = DQNPolicy(
 | 
						|
        model=net,
 | 
						|
        optim=optim,
 | 
						|
        action_space=env.action_space,
 | 
						|
        discount_factor=args.gamma,
 | 
						|
        estimation_step=args.n_step,
 | 
						|
        target_update_freq=args.target_update_freq,
 | 
						|
    )
 | 
						|
    if args.icm_lr_scale > 0:
 | 
						|
        feature_net = DQN(*args.state_shape, args.action_shape, args.device, features_only=True)
 | 
						|
        action_dim = np.prod(args.action_shape)
 | 
						|
        feature_dim = feature_net.output_dim
 | 
						|
        icm_net = IntrinsicCuriosityModule(
 | 
						|
            feature_net.net,
 | 
						|
            feature_dim,
 | 
						|
            action_dim,
 | 
						|
            hidden_sizes=[512],
 | 
						|
            device=args.device,
 | 
						|
        )
 | 
						|
        icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr)
 | 
						|
        policy = ICMPolicy(
 | 
						|
            policy=policy,
 | 
						|
            model=icm_net,
 | 
						|
            optim=icm_optim,
 | 
						|
            action_space=env.action_space,
 | 
						|
            lr_scale=args.icm_lr_scale,
 | 
						|
            reward_scale=args.icm_reward_scale,
 | 
						|
            forward_loss_weight=args.icm_forward_loss_weight,
 | 
						|
        ).to(args.device)
 | 
						|
    # load a previous policy
 | 
						|
    if args.resume_path:
 | 
						|
        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
 | 
						|
        print("Loaded agent from: ", args.resume_path)
 | 
						|
    # replay buffer: `save_last_obs` and `stack_num` can be removed together
 | 
						|
    # when you have enough RAM
 | 
						|
    buffer = VectorReplayBuffer(
 | 
						|
        args.buffer_size,
 | 
						|
        buffer_num=len(train_envs),
 | 
						|
        ignore_obs_next=True,
 | 
						|
        save_only_last_obs=True,
 | 
						|
        stack_num=args.frames_stack,
 | 
						|
    )
 | 
						|
    # collector
 | 
						|
    train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
 | 
						|
    test_collector = Collector(policy, test_envs, exploration_noise=True)
 | 
						|
 | 
						|
    # log
 | 
						|
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
 | 
						|
    args.algo_name = "dqn_icm" if args.icm_lr_scale > 0 else "dqn"
 | 
						|
    log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
 | 
						|
    log_path = os.path.join(args.logdir, log_name)
 | 
						|
 | 
						|
    # logger
 | 
						|
    if args.logger == "wandb":
 | 
						|
        logger = WandbLogger(
 | 
						|
            save_interval=1,
 | 
						|
            name=log_name.replace(os.path.sep, "__"),
 | 
						|
            run_id=args.resume_id,
 | 
						|
            config=args,
 | 
						|
            project=args.wandb_project,
 | 
						|
        )
 | 
						|
    writer = SummaryWriter(log_path)
 | 
						|
    writer.add_text("args", str(args))
 | 
						|
    if args.logger == "tensorboard":
 | 
						|
        logger = TensorboardLogger(writer)
 | 
						|
    else:  # wandb
 | 
						|
        logger.load(writer)
 | 
						|
 | 
						|
    def save_best_fn(policy):
 | 
						|
        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
 | 
						|
 | 
						|
    def stop_fn(mean_rewards: float) -> bool:
 | 
						|
        if env.spec.reward_threshold:
 | 
						|
            return mean_rewards >= env.spec.reward_threshold
 | 
						|
        if "Pong" in args.task:
 | 
						|
            return mean_rewards >= 20
 | 
						|
        return False
 | 
						|
 | 
						|
    def train_fn(epoch, env_step):
 | 
						|
        # nature DQN setting, linear decay in the first 1M steps
 | 
						|
        if env_step <= 1e6:
 | 
						|
            eps = args.eps_train - env_step / 1e6 * (args.eps_train - args.eps_train_final)
 | 
						|
        else:
 | 
						|
            eps = args.eps_train_final
 | 
						|
        policy.set_eps(eps)
 | 
						|
        if env_step % 1000 == 0:
 | 
						|
            logger.write("train/env_step", env_step, {"train/eps": eps})
 | 
						|
 | 
						|
    def test_fn(epoch, env_step):
 | 
						|
        policy.set_eps(args.eps_test)
 | 
						|
 | 
						|
    def save_checkpoint_fn(epoch, env_step, gradient_step):
 | 
						|
        # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
 | 
						|
        ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth")
 | 
						|
        torch.save({"model": policy.state_dict()}, ckpt_path)
 | 
						|
        return ckpt_path
 | 
						|
 | 
						|
    # watch agent's performance
 | 
						|
    def watch():
 | 
						|
        print("Setup test envs ...")
 | 
						|
        policy.eval()
 | 
						|
        policy.set_eps(args.eps_test)
 | 
						|
        test_envs.seed(args.seed)
 | 
						|
        if args.save_buffer_name:
 | 
						|
            print(f"Generate buffer with size {args.buffer_size}")
 | 
						|
            buffer = VectorReplayBuffer(
 | 
						|
                args.buffer_size,
 | 
						|
                buffer_num=len(test_envs),
 | 
						|
                ignore_obs_next=True,
 | 
						|
                save_only_last_obs=True,
 | 
						|
                stack_num=args.frames_stack,
 | 
						|
            )
 | 
						|
            collector = Collector(policy, test_envs, buffer, exploration_noise=True)
 | 
						|
            result = collector.collect(n_step=args.buffer_size)
 | 
						|
            print(f"Save buffer into {args.save_buffer_name}")
 | 
						|
            # Unfortunately, pickle will cause oom with 1M buffer size
 | 
						|
            buffer.save_hdf5(args.save_buffer_name)
 | 
						|
        else:
 | 
						|
            print("Testing agent ...")
 | 
						|
            test_collector.reset()
 | 
						|
            result = test_collector.collect(n_episode=args.test_num, render=args.render)
 | 
						|
        rew = result["rews"].mean()
 | 
						|
        print(f"Mean reward (over {result['n/ep']} episodes): {rew}")
 | 
						|
 | 
						|
    if args.watch:
 | 
						|
        watch()
 | 
						|
        sys.exit(0)
 | 
						|
 | 
						|
    # test train_collector and start filling replay buffer
 | 
						|
    train_collector.collect(n_step=args.batch_size * args.training_num)
 | 
						|
    # trainer
 | 
						|
    result = OffpolicyTrainer(
 | 
						|
        policy=policy,
 | 
						|
        train_collector=train_collector,
 | 
						|
        test_collector=test_collector,
 | 
						|
        max_epoch=args.epoch,
 | 
						|
        step_per_epoch=args.step_per_epoch,
 | 
						|
        step_per_collect=args.step_per_collect,
 | 
						|
        episode_per_test=args.test_num,
 | 
						|
        batch_size=args.batch_size,
 | 
						|
        train_fn=train_fn,
 | 
						|
        test_fn=test_fn,
 | 
						|
        stop_fn=stop_fn,
 | 
						|
        save_best_fn=save_best_fn,
 | 
						|
        logger=logger,
 | 
						|
        update_per_step=args.update_per_step,
 | 
						|
        test_in_train=False,
 | 
						|
        resume_from_log=args.resume_id is not None,
 | 
						|
        save_checkpoint_fn=save_checkpoint_fn,
 | 
						|
    ).run()
 | 
						|
 | 
						|
    pprint.pprint(result)
 | 
						|
    watch()
 | 
						|
 | 
						|
 | 
						|
if __name__ == "__main__":
    # Script entry point: parse the command line, then run training/evaluation.
    cli_args = get_args()
    test_dqn(cli_args)
 |