Closes #947

This removes all kwargs from all policy constructors. While doing that, I also improved several names and added a whole lot of TODOs.

## Functional changes:

1. Added the possibility to pass None as `critic2` and `critic2_optim`. In fact, the default behavior should then cover the absolute majority of cases.
2. Added a function called `clone_optimizer` as a temporary measure to support passing `critic2_optim=None`.

## Breaking changes:

1. `action_space` is no longer optional. In fact, it already was non-optional, as there was a ValueError in `BasePolicy.__init__`. So now several examples were fixed to reflect that.
2. `reward_normalization` removed from DDPG and children. It was never allowed to pass it as `True` there; an error would have been raised in `compute_n_step_reward`. Now I removed it from the interface.
3. Renamed `critic1` and similar to `critic`, in order to have uniform interfaces. Note that the `critic` in DDPG was optional for the sole reason that child classes used `critic1`. I removed this optionality (DDPG can't do anything with `critic=None`).
4. Several renamings of fields (mostly private to public, so backwards compatible).

## Additional changes:

1. Removed type and default declarations from docstrings. This kind of duplication is really not necessary.
2. Policy constructors are now only called using named arguments, not a fragile mixture of positional and named as before.
3. Minor beautifications in typing and code.
4. Generally shortened docstrings and made them uniform across all policies (hopefully).

## Comment:

With these changes, several problems in tianshou's inheritance hierarchy become more apparent. I tried highlighting them for future work.

---------

Co-authored-by: Dominik Jain <d.jain@appliedai.de>
		
			
				
	
	
		
			216 lines
		
	
	
		
			7.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			216 lines
		
	
	
		
			7.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import argparse
 | 
						|
import datetime
 | 
						|
import os
 | 
						|
import pickle
 | 
						|
import pprint
 | 
						|
 | 
						|
import gymnasium as gym
 | 
						|
import numpy as np
 | 
						|
import torch
 | 
						|
from torch.utils.tensorboard import SummaryWriter
 | 
						|
 | 
						|
from tianshou.data import Collector, VectorReplayBuffer
 | 
						|
from tianshou.env import DummyVectorEnv
 | 
						|
from tianshou.policy import BCQPolicy
 | 
						|
from tianshou.trainer import OfflineTrainer
 | 
						|
from tianshou.utils import TensorboardLogger
 | 
						|
from tianshou.utils.net.common import MLP, Net
 | 
						|
from tianshou.utils.net.continuous import VAE, Critic, Perturbation
 | 
						|
 | 
						|
if __name__ == "__main__":
 | 
						|
    from gather_pendulum_data import expert_file_name, gather_data
 | 
						|
else:  # pytest
 | 
						|
    from test.offline.gather_pendulum_data import expert_file_name, gather_data
 | 
						|
 | 
						|
 | 
						|
def get_args():
    """Parse the command-line options for the BCQ Pendulum test.

    Unrecognised arguments are ignored (``parse_known_args``), so the parser
    does not fail when extra flags are present on ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` holding the recognised options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="Pendulum-v1")
    parser.add_argument("--reward-threshold", type=float, default=None)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64])
    parser.add_argument("--actor-lr", type=float, default=1e-3)
    parser.add_argument("--critic-lr", type=float, default=1e-3)
    parser.add_argument("--epoch", type=int, default=5)
    parser.add_argument("--step-per-epoch", type=int, default=500)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--test-num", type=int, default=10)
    parser.add_argument("--logdir", type=str, default="log")
    parser.add_argument("--render", type=float, default=1 / 35)

    parser.add_argument("--vae-hidden-sizes", type=int, nargs="*", default=[32, 32])
    # default to 2 * action_dim
    parser.add_argument("--latent_dim", type=int, default=None)
    # The four hyper-parameters below need an explicit float type: without it,
    # a value supplied on the command line would be kept as a string.
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--tau", type=float, default=0.005)
    # Weighting for Clipped Double Q-learning in BCQ
    parser.add_argument("--lmbda", type=float, default=0.75)
    # Max perturbation hyper-parameter for BCQ
    parser.add_argument("--phi", type=float, default=0.05)
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
    )
    parser.add_argument("--resume-path", type=str, default=None)
    parser.add_argument(
        "--watch",
        default=False,
        action="store_true",
        help="watch the play of pre-trained policy only",
    )
    parser.add_argument("--load-buffer-name", type=str, default=expert_file_name())
    parser.add_argument("--show-progress", action="store_true")
    return parser.parse_known_args()[0]
 | 
						|
 | 
						|
 | 
						|
def test_bcq(args=get_args()):
    """Train BCQ offline on a Pendulum replay buffer and check the reward threshold.

    The replay buffer is loaded from ``args.load_buffer_name`` when that file
    exists (HDF5 or pickle), otherwise it is produced by ``gather_data()``.
    A perturbation actor, a critic and a VAE are built, wrapped in a
    ``BCQPolicy`` and trained with ``OfflineTrainer``. The function asserts
    that the best test reward reaches ``args.reward_threshold``.

    When ``args.watch`` is set, no training happens: a saved policy is loaded
    and one rendered episode is played instead.

    Args:
        args: Parsed options as returned by ``get_args()``. The default is
            evaluated once, when the function is defined. The namespace is
            extended in place with derived fields (``state_shape``,
            ``action_shape``, ``max_action``, ``state_dim``, ``action_dim``,
            and ``reward_threshold`` / ``latent_dim`` when they were unset).

    Raises:
        AssertionError: If the best reward after training is below the threshold.
    """
    # `isfile` is already False for a missing path, so no separate `exists` check.
    if os.path.isfile(args.load_buffer_name):
        if args.load_buffer_name.endswith(".hdf5"):
            buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
        else:
            # NOTE: unpickling can execute arbitrary code; only load buffer
            # files that come from a trusted source.
            with open(args.load_buffer_name, "rb") as f:
                buffer = pickle.load(f)
    else:
        buffer = gather_data()
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]  # float
    if args.reward_threshold is None:
        # too low?
        default_reward_threshold = {"Pendulum-v0": -1100, "Pendulum-v1": -1100}
        args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold)

    args.state_dim = args.state_shape[0]
    args.action_dim = args.action_shape[0]
    # test_envs = gym.make(args.task)
    test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)])
    # seed (done before any network is created so that initial weights are reproducible)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    test_envs.seed(args.seed)

    # model
    # perturbation network
    net_a = MLP(
        input_dim=args.state_dim + args.action_dim,
        output_dim=args.action_dim,
        hidden_sizes=args.hidden_sizes,
        device=args.device,
    )
    actor = Perturbation(net_a, max_action=args.max_action, device=args.device, phi=args.phi).to(
        args.device,
    )
    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)

    net_c = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device,
    )
    critic = Critic(net_c, device=args.device).to(args.device)
    critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)

    # vae
    # output_dim = 0, so the last Module in the encoder is ReLU
    vae_encoder = MLP(
        input_dim=args.state_dim + args.action_dim,
        hidden_sizes=args.vae_hidden_sizes,
        device=args.device,
    )
    if not args.latent_dim:
        args.latent_dim = args.action_dim * 2
    vae_decoder = MLP(
        input_dim=args.state_dim + args.latent_dim,
        output_dim=args.action_dim,
        hidden_sizes=args.vae_hidden_sizes,
        device=args.device,
    )
    vae = VAE(
        vae_encoder,
        vae_decoder,
        hidden_dim=args.vae_hidden_sizes[-1],
        latent_dim=args.latent_dim,
        max_action=args.max_action,
        device=args.device,
    ).to(args.device)
    vae_optim = torch.optim.Adam(vae.parameters())

    policy = BCQPolicy(
        actor_perturbation=actor,
        actor_perturbation_optim=actor_optim,
        critic=critic,
        critic_optim=critic_optim,
        vae=vae,
        vae_optim=vae_optim,
        action_space=env.action_space,
        device=args.device,
        gamma=args.gamma,
        tau=args.tau,
        lmbda=args.lmbda,
    )

    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)

    # collector
    # buffer has been gathered
    # train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
    test_collector = Collector(policy, test_envs)
    # log
    t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
    log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_bcq'
    log_path = os.path.join(args.logdir, args.task, "bcq", log_file)
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    logger = TensorboardLogger(writer)

    def save_best_fn(policy):
        """Write the given policy's weights to ``policy.pth`` in the log directory."""
        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

    def stop_fn(mean_rewards):
        """Return whether the mean reward has reached the configured threshold."""
        return mean_rewards >= args.reward_threshold

    def watch():
        """Load a saved policy and play one rendered episode."""
        # The log directory carries a fresh timestamp, so it only holds a
        # checkpoint if this run saved one; prefer an explicit --resume-path.
        checkpoint = args.resume_path or os.path.join(log_path, "policy.pth")
        policy.load_state_dict(torch.load(checkpoint, map_location=torch.device("cpu")))
        policy.eval()
        collector = Collector(policy, env)
        collector.collect(n_episode=1, render=args.render)

    # --watch means "watch the play of pre-trained policy only": skip training.
    if args.watch:
        watch()
        return

    # trainer
    result = OfflineTrainer(
        policy=policy,
        buffer=buffer,
        test_collector=test_collector,
        max_epoch=args.epoch,
        step_per_epoch=args.step_per_epoch,
        episode_per_test=args.test_num,
        batch_size=args.batch_size,
        save_best_fn=save_best_fn,
        stop_fn=stop_fn,
        logger=logger,
        show_progress=args.show_progress,
    ).run()
    assert stop_fn(result["best_reward"])

    # Let's watch its performance!
    if __name__ == "__main__":
        pprint.pprint(result)
        env = gym.make(args.task)
        policy.eval()
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        rews, lens = result["rews"], result["lens"]
        print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
 | 
						|
 | 
						|
 | 
						|
# Script entry point: run the BCQ test with the options parsed by get_args().
if __name__ == "__main__":
    test_bcq()
 |