#!/usr/bin/env python3 import argparse import datetime import os import pprint import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter from examples.offline.utils import load_buffer_d4rl from tianshou.data import Collector from tianshou.env import SubprocVectorEnv from tianshou.policy import BCQPolicy from tianshou.policy.base import BasePolicy from tianshou.trainer import OfflineTrainer from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import MLP, Net from tianshou.utils.net.continuous import VAE, Critic, Perturbation from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="HalfCheetah-v2") parser.add_argument("--seed", type=int, default=0) parser.add_argument("--expert-data-task", type=str, default="halfcheetah-expert-v2") parser.add_argument("--buffer-size", type=int, default=1000000) parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256]) parser.add_argument("--actor-lr", type=float, default=1e-3) parser.add_argument("--critic-lr", type=float, default=1e-3) parser.add_argument("--start-timesteps", type=int, default=10000) parser.add_argument("--epoch", type=int, default=200) parser.add_argument("--step-per-epoch", type=int, default=5000) parser.add_argument("--n-step", type=int, default=3) parser.add_argument("--batch-size", type=int, default=256) parser.add_argument("--test-num", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=1 / 35) parser.add_argument("--vae-hidden-sizes", type=int, nargs="*", default=[512, 512]) # default to 2 * action_dim parser.add_argument("--latent-dim", type=int) parser.add_argument("--gamma", default=0.99) parser.add_argument("--tau", default=0.005) # Weighting for Clipped Double Q-learning in BCQ parser.add_argument("--lmbda", default=0.75) # Max perturbation hyper-parameter for BCQ parser.add_argument("--phi", default=0.05) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) parser.add_argument("--resume-path", type=str, default=None) parser.add_argument("--resume-id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) parser.add_argument("--wandb-project", type=str, default="offline_d4rl.benchmark") parser.add_argument( "--watch", default=False, action="store_true", help="watch the play of pre-trained policy only", ) return parser.parse_args() def test_bcq() -> None: args = get_args() env = gym.make(args.task) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape args.max_action = space_info.action_info.max_action print("device:", args.device) print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) print("Action range:", args.min_action, args.max_action) args.state_dim = space_info.observation_info.obs_dim args.action_dim = space_info.action_info.action_dim print("Max_action", args.max_action) # test_envs = gym.make(args.task) test_envs = SubprocVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) test_envs.seed(args.seed) # model # perturbation network net_a = MLP( input_dim=args.state_dim + args.action_dim, output_dim=args.action_dim, hidden_sizes=args.hidden_sizes, device=args.device, ) actor = Perturbation(net_a, max_action=args.max_action, device=args.device, phi=args.phi).to( args.device, ) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) net_c1 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, ) net_c2 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, ) critic1 = Critic(net_c1, device=args.device).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) critic2 = Critic(net_c2, device=args.device).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) # vae # output_dim = 0, so the last Module in the encoder is ReLU vae_encoder = MLP( input_dim=args.state_dim + args.action_dim, hidden_sizes=args.vae_hidden_sizes, device=args.device, ) if not args.latent_dim: args.latent_dim = args.action_dim * 2 vae_decoder = MLP( input_dim=args.state_dim + args.latent_dim, output_dim=args.action_dim, hidden_sizes=args.vae_hidden_sizes, device=args.device, ) vae = VAE( vae_encoder, vae_decoder, hidden_dim=args.vae_hidden_sizes[-1], latent_dim=args.latent_dim, max_action=args.max_action, device=args.device, ).to(args.device) vae_optim = torch.optim.Adam(vae.parameters()) policy: BCQPolicy = BCQPolicy( actor_perturbation=actor, actor_perturbation_optim=actor_optim, critic=critic1, critic_optim=critic1_optim, action_space=env.action_space, critic2=critic2, critic2_optim=critic2_optim, vae=vae, vae_optim=vae_optim, device=args.device, gamma=args.gamma, tau=args.tau, lmbda=args.lmbda, ) # load a previous policy if args.resume_path: policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # collector test_collector = Collector(policy, test_envs) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") args.algo_name = "bcq" log_name = os.path.join(args.task, args.algo_name, str(args.seed), now) log_path = os.path.join(args.logdir, log_name) # logger writer = SummaryWriter(log_path) writer.add_text("args", str(args)) logger: WandbLogger | TensorboardLogger if args.logger == "tensorboard": logger = TensorboardLogger(writer) else: logger = WandbLogger( save_interval=1, name=log_name.replace(os.path.sep, "__"), run_id=args.resume_id, config=args, project=args.wandb_project, ) logger.load(writer) def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def watch() -> None: if args.resume_path is None: args.resume_path = os.path.join(log_path, "policy.pth") policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) policy.eval() collector = Collector(policy, env) collector.collect(n_episode=1, render=1 / 35) if not args.watch: replay_buffer = load_buffer_d4rl(args.expert_data_task) # trainer result = OfflineTrainer( policy=policy, buffer=replay_buffer, test_collector=test_collector, max_epoch=args.epoch, step_per_epoch=args.step_per_epoch, episode_per_test=args.test_num, batch_size=args.batch_size, save_best_fn=save_best_fn, logger=logger, ).run() pprint.pprint(result) else: watch() # Let's watch its performance! policy.eval() test_envs.seed(args.seed) test_collector.reset() collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) print(collector_stats) if __name__ == "__main__": test_bcq()