Closes #947

This removes all kwargs from all policy constructors. While doing that, I also improved several names and added many TODOs.

## Functional changes:

1. Added the possibility to pass `None` as `critic2` and `critic2_optim`. In fact, the resulting default behavior should cover the vast majority of cases.
2. Added a function called `clone_optimizer` as a temporary measure to support passing `critic2_optim=None`.

## Breaking changes:

1. `action_space` is no longer optional. In fact, it already was non-optional, as there was a `ValueError` in `BasePolicy.__init__`. Several examples were fixed to reflect that.
2. `reward_normalization` was removed from DDPG and its children. Passing it as `True` there was never allowed — an error would have been raised in `compute_n_step_reward` — so I removed it from the interface.
3. Renamed `critic1` and similar to `critic`, in order to have uniform interfaces. Note that the `critic` in DDPG was optional for the sole reason that child classes used `critic1`. I removed this optionality (DDPG can't do anything with `critic=None`).
4. Several fields were renamed (mostly from private to public, so these renamings are backwards compatible).

## Additional changes:

1. Removed type and default declarations from docstrings. This kind of duplication is unnecessary.
2. Policy constructors are now only called using named arguments, not a fragile mixture of positional and named arguments as before.
3. Minor beautifications in typing and code.
4. Generally shortened docstrings and made them uniform across all policies (hopefully).

## Comment:

With these changes, several problems in tianshou's inheritance hierarchy become more apparent. I tried to highlight them for future work.

---------

Co-authored-by: Dominik Jain <d.jain@appliedai.de>
		
			
				
	
	
		
			224 lines
		
	
	
		
			7.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			224 lines
		
	
	
		
			7.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
#!/usr/bin/env python3
 | 
						|
 | 
						|
import argparse
 | 
						|
import datetime
 | 
						|
import os
 | 
						|
import pprint
 | 
						|
 | 
						|
import gymnasium as gym
 | 
						|
import numpy as np
 | 
						|
import torch
 | 
						|
from torch.utils.tensorboard import SummaryWriter
 | 
						|
 | 
						|
from examples.offline.utils import load_buffer_d4rl, normalize_all_obs_in_replay_buffer
 | 
						|
from tianshou.data import Collector
 | 
						|
from tianshou.env import SubprocVectorEnv, VectorEnvNormObs
 | 
						|
from tianshou.exploration import GaussianNoise
 | 
						|
from tianshou.policy import TD3BCPolicy
 | 
						|
from tianshou.trainer import OfflineTrainer
 | 
						|
from tianshou.utils import TensorboardLogger, WandbLogger
 | 
						|
from tianshou.utils.net.common import Net
 | 
						|
from tianshou.utils.net.continuous import Actor, Critic
 | 
						|
 | 
						|
 | 
						|
def get_args():
    """Parse the command-line options of the TD3+BC offline benchmark script.

    Options are declared as ``(flag, add_argument keyword arguments)`` pairs
    and registered in declaration order, so ``--help`` lists them in that order.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    device_default = "cuda" if torch.cuda.is_available() else "cpu"
    option_specs = [
        # environment / dataset
        ("--task", {"type": str, "default": "HalfCheetah-v2"}),
        ("--seed", {"type": int, "default": 0}),
        ("--expert-data-task", {"type": str, "default": "halfcheetah-expert-v2"}),
        ("--buffer-size", {"type": int, "default": 1000000}),
        # networks and optimisation
        ("--hidden-sizes", {"type": int, "nargs": "*", "default": [256, 256]}),
        ("--actor-lr", {"type": float, "default": 3e-4}),
        ("--critic-lr", {"type": float, "default": 3e-4}),
        ("--epoch", {"type": int, "default": 200}),
        ("--step-per-epoch", {"type": int, "default": 5000}),
        ("--n-step", {"type": int, "default": 3}),
        ("--batch-size", {"type": int, "default": 256}),
        # TD3+BC specific hyper-parameters
        ("--alpha", {"type": float, "default": 2.5}),
        ("--exploration-noise", {"type": float, "default": 0.1}),
        ("--policy-noise", {"type": float, "default": 0.2}),
        ("--noise-clip", {"type": float, "default": 0.5}),
        ("--update-actor-freq", {"type": int, "default": 2}),
        ("--tau", {"type": float, "default": 0.005}),
        ("--gamma", {"type": float, "default": 0.99}),
        ("--norm-obs", {"type": int, "default": 1}),
        # evaluation, logging and runtime
        ("--eval-freq", {"type": int, "default": 1}),
        ("--test-num", {"type": int, "default": 10}),
        ("--logdir", {"type": str, "default": "log"}),
        ("--render", {"type": float, "default": 1 / 35}),
        ("--device", {"type": str, "default": device_default}),
        ("--resume-path", {"type": str, "default": None}),
        ("--resume-id", {"type": str, "default": None}),
        (
            "--logger",
            {"type": str, "default": "tensorboard", "choices": ["tensorboard", "wandb"]},
        ),
        ("--wandb-project", {"type": str, "default": "offline_d4rl.benchmark"}),
        (
            "--watch",
            {
                "default": False,
                "action": "store_true",
                "help": "watch the play of pre-trained policy only",
            },
        ),
    ]

    cli = argparse.ArgumentParser()
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()
 | 
						|
 | 
						|
 | 
						|
def test_td3_bc():
    """Train (or, with ``--watch``, replay) a TD3+BC policy on a D4RL dataset.

    Parses the CLI options, builds the actor and the two critics, and optionally
    restores weights from ``--resume-path``. It then either runs
    ``OfflineTrainer`` on the D4RL buffer named by ``--expert-data-task`` or
    only watches the stored policy. Finally it prints the mean reward and
    episode length over ``--test-num`` evaluation episodes.
    """
    args = get_args()
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]  # float
    print("device:", args.device)
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))

    args.state_dim = args.state_shape[0]
    args.action_dim = args.action_shape[0]
    print("Max_action", args.max_action)

    test_envs = SubprocVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)])
    if args.norm_obs:
        # Statistics are not updated from test rollouts; on the training path they
        # are injected from the offline buffer via set_obs_rms() below.
        test_envs = VectorEnvNormObs(test_envs, update_obs_rms=False)

    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    test_envs.seed(args.seed)

    # model
    # actor network
    net_a = Net(
        args.state_shape,
        hidden_sizes=args.hidden_sizes,
        device=args.device,
    )
    actor = Actor(
        net_a,
        action_shape=args.action_shape,
        max_action=args.max_action,
        device=args.device,
    ).to(args.device)
    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)

    # critic networks: two independent Q(s, a) heads (concat=True feeds state and action together)
    net_c1 = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device,
    )
    net_c2 = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device,
    )
    critic1 = Critic(net_c1, device=args.device).to(args.device)
    critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
    critic2 = Critic(net_c2, device=args.device).to(args.device)
    critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)

    policy = TD3BCPolicy(
        actor=actor,
        actor_optim=actor_optim,
        critic=critic1,
        critic_optim=critic1_optim,
        critic2=critic2,
        critic2_optim=critic2_optim,
        tau=args.tau,
        gamma=args.gamma,
        exploration_noise=GaussianNoise(sigma=args.exploration_noise),
        policy_noise=args.policy_noise,
        update_actor_freq=args.update_actor_freq,
        noise_clip=args.noise_clip,
        alpha=args.alpha,
        estimation_step=args.n_step,
        action_space=env.action_space,
    )

    # load a previous policy
    if args.resume_path:
        # torch.load unpickles the file: only resume from trusted checkpoints.
        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)

    # collector
    test_collector = Collector(policy, test_envs)

    # log
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    args.algo_name = "td3_bc"
    log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
    log_path = os.path.join(args.logdir, log_name)

    # logger
    # NOTE(review): the WandbLogger is created before the SummaryWriter, presumably
    # so wandb can hook tensorboard output; confirm before reordering.
    if args.logger == "wandb":
        logger = WandbLogger(
            save_interval=1,
            name=log_name.replace(os.path.sep, "__"),
            run_id=args.resume_id,
            config=args,
            project=args.wandb_project,
        )
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    if args.logger == "tensorboard":
        logger = TensorboardLogger(writer)
    else:  # wandb
        logger.load(writer)

    def save_best_fn(policy):
        """Persist the policy weights as ``policy.pth`` inside this run's log directory."""
        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

    def watch():
        """Load stored weights and render one episode in the single raw env."""
        if args.resume_path is None:
            args.resume_path = os.path.join(log_path, "policy.pth")

        # torch.load unpickles the file: only load trusted checkpoints.
        policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu")))
        policy.eval()
        # NOTE(review): `env` is the raw gym env without observation normalization,
        # whereas training with --norm-obs uses normalized observations; confirm.
        collector = Collector(policy, env)
        # Honor --render like the final evaluation below (previously hard-coded to 1 / 35).
        collector.collect(n_episode=1, render=args.render)

    if not args.watch:
        replay_buffer = load_buffer_d4rl(args.expert_data_task)
        if args.norm_obs:
            replay_buffer, obs_rms = normalize_all_obs_in_replay_buffer(replay_buffer)
            test_envs.set_obs_rms(obs_rms)
        # trainer
        result = OfflineTrainer(
            policy=policy,
            buffer=replay_buffer,
            test_collector=test_collector,
            max_epoch=args.epoch,
            step_per_epoch=args.step_per_epoch,
            episode_per_test=args.test_num,
            batch_size=args.batch_size,
            save_best_fn=save_best_fn,
            logger=logger,
        ).run()
        pprint.pprint(result)
    else:
        # NOTE(review): on this path set_obs_rms() is never called, so with
        # --norm-obs the evaluation below uses the wrapper's initial statistics.
        watch()

    # Let's watch its performance!
    policy.eval()
    test_envs.seed(args.seed)
    test_collector.reset()
    result = test_collector.collect(n_episode=args.test_num, render=args.render)
    print(f"Final reward: {result['rews'].mean()}, length: {result['lens'].mean()}")

    # Release the SubprocVectorEnv worker processes and the standalone env.
    test_envs.close()
    env.close()
 | 
						|
 | 
						|
 | 
						|
if __name__ == "__main__":
    # Script entry point: run training (or --watch replay) with CLI-provided options.
    test_td3_bc()
 |