Closes #947. This removes all kwargs from all policy constructors. While doing that, I also improved several names and added a whole lot of TODOs. ## Functional changes: 1. Added the possibility to pass None as `critic2` and `critic2_optim`. The resulting default behavior should cover the vast majority of cases. 2. Added a function called `clone_optimizer` as a temporary measure to support passing `critic2_optim=None`. ## Breaking changes: 1. `action_space` is no longer optional. In fact, it effectively was non-optional already, since `BasePolicy.__init__` raised a ValueError without it, so several examples were fixed to reflect that. 2. `reward_normalization` was removed from DDPG and its children. Passing it as `True` was never allowed there (an error would have been raised in `compute_n_step_reward`), so it has now been removed from the interface. 3. Renamed `critic1` and similar to `critic`, in order to have uniform interfaces. Note that `critic` in DDPG was optional for the sole reason that child classes used `critic1`. I removed this optionality (DDPG can't do anything with `critic=None`). 4. Several fields were renamed (mostly from private to public, so this is backwards compatible). ## Additional changes: 1. Removed type and default declarations from docstrings; this kind of duplication is unnecessary. 2. Policy constructors are now only called with named arguments, rather than the fragile mixture of positional and named arguments used before. 3. Minor beautifications in typing and code. 4. Generally shortened docstrings and made them uniform across all policies (hopefully). ## Comments: With these changes, several problems in tianshou's inheritance hierarchy become more apparent. I tried to highlight them for future work. --------- Co-authored-by: Dominik Jain <d.jain@appliedai.de>
		
			
				
	
	
		
			236 lines
		
	
	
		
			8.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			236 lines
		
	
	
		
			8.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import argparse
 | 
						|
import os
 | 
						|
import pickle
 | 
						|
import pprint
 | 
						|
 | 
						|
import gymnasium as gym
 | 
						|
import numpy as np
 | 
						|
import torch
 | 
						|
from torch.utils.tensorboard import SummaryWriter
 | 
						|
 | 
						|
from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
 | 
						|
from tianshou.env import DummyVectorEnv
 | 
						|
from tianshou.policy import RainbowPolicy
 | 
						|
from tianshou.trainer import OffpolicyTrainer
 | 
						|
from tianshou.utils import TensorboardLogger
 | 
						|
from tianshou.utils.net.common import Net
 | 
						|
from tianshou.utils.net.discrete import NoisyLinear
 | 
						|
 | 
						|
 | 
						|
def get_args():
    """Build the hyperparameter namespace for the Rainbow tests.

    Options are declared in an ordered table and registered in that order.
    ``parse_known_args`` is used so that unrelated command-line arguments
    (for example those of a test runner) are ignored instead of rejected.

    :return: an ``argparse.Namespace`` holding the parsed options.
    """
    option_table = (
        ("--task", dict(type=str, default="CartPole-v0")),
        ("--reward-threshold", dict(type=float, default=None)),
        ("--seed", dict(type=int, default=1626)),
        ("--eps-test", dict(type=float, default=0.05)),
        ("--eps-train", dict(type=float, default=0.1)),
        ("--buffer-size", dict(type=int, default=20000)),
        ("--lr", dict(type=float, default=1e-3)),
        ("--gamma", dict(type=float, default=0.9)),
        ("--num-atoms", dict(type=int, default=51)),
        ("--v-min", dict(type=float, default=-10.0)),
        ("--v-max", dict(type=float, default=10.0)),
        ("--noisy-std", dict(type=float, default=0.1)),
        ("--n-step", dict(type=int, default=3)),
        ("--target-update-freq", dict(type=int, default=320)),
        ("--epoch", dict(type=int, default=10)),
        ("--step-per-epoch", dict(type=int, default=8000)),
        ("--step-per-collect", dict(type=int, default=8)),
        ("--update-per-step", dict(type=float, default=0.125)),
        ("--batch-size", dict(type=int, default=64)),
        ("--hidden-sizes", dict(type=int, nargs="*", default=[128, 128, 128, 128])),
        ("--training-num", dict(type=int, default=8)),
        ("--test-num", dict(type=int, default=100)),
        ("--logdir", dict(type=str, default="log")),
        ("--render", dict(type=float, default=0.0)),
        ("--prioritized-replay", dict(action="store_true", default=False)),
        ("--alpha", dict(type=float, default=0.6)),
        ("--beta", dict(type=float, default=0.4)),
        ("--beta-final", dict(type=float, default=1.0)),
        ("--resume", dict(action="store_true")),
        # Prefer the GPU when one is visible to torch.
        ("--device", dict(type=str, default="cuda" if torch.cuda.is_available() else "cpu")),
        ("--save-interval", dict(type=int, default=4)),
    )
    cli = argparse.ArgumentParser()
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    known_args, _ignored = cli.parse_known_args()
    return known_args
 | 
						|
 | 
						|
 | 
						|
def _linear_schedule(step, start, end, begin_step=10000, end_step=50000):
    """Return a piecewise-linearly annealed value for the given env step.

    The value equals ``start`` up to ``begin_step``, moves linearly towards
    ``end`` until ``end_step``, and equals ``end`` afterwards. Used below for
    both the epsilon schedule and the prioritized-replay beta schedule.
    """
    if step <= begin_step:
        return start
    if step <= end_step:
        progress = (step - begin_step) / (end_step - begin_step)
        return start - progress * (start - end)
    return end


def test_rainbow(args=None):
    """Train a Rainbow policy on ``args.task`` and assert it hits the reward threshold.

    :param args: hyperparameters as produced by ``get_args``. When omitted,
        the command line is parsed at call time. The previous default was
        evaluated once at import, and this function mutates ``args``
        (``state_shape``, ``action_shape``, ``reward_threshold``), so a single
        namespace would otherwise be shared across calls.
    """
    if args is None:
        args = get_args()
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    if args.reward_threshold is None:
        default_reward_threshold = {"CartPole-v0": 195}
        args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold)
    # you can also use tianshou.env.SubprocVectorEnv
    train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)])
    test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)

    # model: dueling head whose linear layers are NoisyLinear, with `num_atoms` outputs per action
    def noisy_linear(x, y):
        return NoisyLinear(x, y, args.noisy_std)

    net = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        device=args.device,
        softmax=True,
        num_atoms=args.num_atoms,
        dueling_param=({"linear_layer": noisy_linear}, {"linear_layer": noisy_linear}),
    )
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    policy = RainbowPolicy(
        model=net,
        optim=optim,
        discount_factor=args.gamma,
        action_space=env.action_space,
        num_atoms=args.num_atoms,
        v_min=args.v_min,
        v_max=args.v_max,
        estimation_step=args.n_step,
        target_update_freq=args.target_update_freq,
    ).to(args.device)
    # buffer
    if args.prioritized_replay:
        buf = PrioritizedVectorReplayBuffer(
            args.buffer_size,
            buffer_num=len(train_envs),
            alpha=args.alpha,
            beta=args.beta,
            weight_norm=True,
        )
    else:
        buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
    # collector
    train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
    test_collector = Collector(policy, test_envs, exploration_noise=True)
    # pre-fill the replay buffer before training starts
    train_collector.collect(n_step=args.batch_size * args.training_num)
    # log
    log_path = os.path.join(args.logdir, args.task, "rainbow")
    writer = SummaryWriter(log_path)
    logger = TensorboardLogger(writer, save_interval=args.save_interval)

    def save_best_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

    def stop_fn(mean_rewards):
        return mean_rewards >= args.reward_threshold

    def train_fn(epoch, env_step):
        # eps annealing, just a demo: eps_train -> 0.1 * eps_train over steps 10k..50k
        policy.set_eps(_linear_schedule(env_step, args.eps_train, 0.1 * args.eps_train))
        # beta annealing, just a demo: beta -> beta_final over steps 10k..50k
        if args.prioritized_replay:
            # Use the collector's *current* buffer. On resume it is replaced by
            # the unpickled one, and `buf` would then be a stale object that is
            # no longer sampled from.
            train_collector.buffer.set_beta(
                _linear_schedule(env_step, args.beta, args.beta_final),
            )

    def test_fn(epoch, env_step):
        policy.set_eps(args.eps_test)

    def save_checkpoint_fn(epoch, env_step, gradient_step):
        # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
        ckpt_path = os.path.join(log_path, "checkpoint.pth")
        # Example: saving by epoch num
        # ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth")
        torch.save(
            {
                "model": policy.state_dict(),
                "optim": optim.state_dict(),
            },
            ckpt_path,
        )
        buffer_path = os.path.join(log_path, "train_buffer.pkl")
        with open(buffer_path, "wb") as f:
            pickle.dump(train_collector.buffer, f)
        return ckpt_path

    if args.resume:
        # load from existing checkpoint
        print(f"Loading agent under {log_path}")
        ckpt_path = os.path.join(log_path, "checkpoint.pth")
        if os.path.exists(ckpt_path):
            checkpoint = torch.load(ckpt_path, map_location=args.device)
            policy.load_state_dict(checkpoint["model"])
            policy.optim.load_state_dict(checkpoint["optim"])
            print("Successfully restore policy and optim.")
        else:
            print("Fail to restore policy and optim.")
        buffer_path = os.path.join(log_path, "train_buffer.pkl")
        if os.path.exists(buffer_path):
            # NOTE: pickle can execute arbitrary code; only load buffers that
            # this script wrote itself (see save_checkpoint_fn), never untrusted files.
            with open(buffer_path, "rb") as f:
                train_collector.buffer = pickle.load(f)
            print("Successfully restore buffer.")
        else:
            print("Fail to restore buffer.")

    # trainer
    result = OffpolicyTrainer(
        policy=policy,
        train_collector=train_collector,
        test_collector=test_collector,
        max_epoch=args.epoch,
        step_per_epoch=args.step_per_epoch,
        step_per_collect=args.step_per_collect,
        episode_per_test=args.test_num,
        batch_size=args.batch_size,
        update_per_step=args.update_per_step,
        train_fn=train_fn,
        test_fn=test_fn,
        stop_fn=stop_fn,
        save_best_fn=save_best_fn,
        logger=logger,
        resume_from_log=args.resume,
        save_checkpoint_fn=save_checkpoint_fn,
    ).run()
    assert stop_fn(result["best_reward"])

    if __name__ == "__main__":
        pprint.pprint(result)
        # Let's watch its performance!
        env = gym.make(args.task)
        policy.eval()
        policy.set_eps(args.eps_test)
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        rews, lens = result["rews"], result["lens"]
        print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
 | 
						|
 | 
						|
 | 
						|
def test_rainbow_resume(args=get_args()):
    """Run the Rainbow test with ``resume`` switched on.

    With ``resume`` set, ``test_rainbow`` tries to reload the checkpoint and
    replay buffer stored under its log directory before training.
    """
    setattr(args, "resume", True)
    test_rainbow(args)
 | 
						|
 | 
						|
 | 
						|
def test_prainbow(args=get_args()):
    """Run the Rainbow test using a prioritized replay buffer.

    Besides enabling prioritized replay, this variant uses a higher discount
    factor (0.95) and a different seed (1).
    """
    vars(args).update(prioritized_replay=True, gamma=0.95, seed=1)
    test_rainbow(args)
 | 
						|
 | 
						|
 | 
						|
if __name__ == "__main__":
    # Script entry point: parse the command line and run the plain Rainbow test.
    cli_args = get_args()
    test_rainbow(cli_args)