#!/usr/bin/env python3
"""Train, evaluate, or watch an ``ImitationPolicy`` agent using a D4RL dataset."""

import argparse
import datetime
import os
import pprint

import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from examples.offline.utils import load_buffer_d4rl
from tianshou.data import Collector
from tianshou.env import SubprocVectorEnv
from tianshou.policy import ImitationPolicy
from tianshou.trainer import OfflineTrainer
from tianshou.utils import TensorboardLogger, WandbLogger
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import Actor


def get_args():
    """Parse the command-line options of this script.

    Returns:
        The ``argparse.Namespace`` produced by ``parser.parse_args()``.
        ``test_il`` later attaches derived attributes (shapes, algorithm
        name, ...) to the same namespace.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="HalfCheetah-v2")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--expert-data-task", type=str, default="halfcheetah-expert-v2")
    parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256])
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--epoch", type=int, default=200)
    parser.add_argument("--step-per-epoch", type=int, default=5000)
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--test-num", type=int, default=10)
    parser.add_argument("--logdir", type=str, default="log")
    parser.add_argument("--render", type=float, default=1 / 35)
    # Without an explicit type a value given on the command line would be
    # kept as a string, unlike the float default.
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
    )
    parser.add_argument("--resume-path", type=str, default=None)
    parser.add_argument("--resume-id", type=str, default=None)
    parser.add_argument(
        "--logger",
        type=str,
        default="tensorboard",
        choices=["tensorboard", "wandb"],
    )
    parser.add_argument("--wandb-project", type=str, default="offline_d4rl.benchmark")
    parser.add_argument(
        "--watch",
        default=False,
        action="store_true",
        help="watch the play of pre-trained policy only",
    )
    return parser.parse_args()


def test_il():
    """Run the imitation-learning experiment described by the CLI arguments.

    Builds the environment, actor network and ``ImitationPolicy``, sets up
    logging, then either trains offline on the buffer loaded via
    ``load_buffer_d4rl`` (default) or replays a saved policy (``--watch``).
    In both cases the policy is finally evaluated on the test environments
    and the mean reward and episode length are printed.
    """
    args = get_args()
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]  # float
    print("device:", args.device)
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))

    args.state_dim = args.state_shape[0]
    args.action_dim = args.action_shape[0]
    print("Max_action", args.max_action)

    test_envs = SubprocVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)])

    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    test_envs.seed(args.seed)

    # model
    net = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        device=args.device,
    )
    actor = Actor(
        net,
        action_shape=args.action_shape,
        max_action=args.max_action,
        device=args.device,
    ).to(args.device)
    optim = torch.optim.Adam(actor.parameters(), lr=args.lr)

    policy = ImitationPolicy(
        actor,
        optim,
        action_space=env.action_space,
        action_scaling=True,
        action_bound_method="clip",
    )

    # load a previous policy
    if args.resume_path:
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)

    # collector
    test_collector = Collector(policy, test_envs)

    # log
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    # This script trains an ImitationPolicy, so label its runs accordingly
    # (the previous "cql" label filed them under a different algorithm).
    args.algo_name = "il"
    log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
    log_path = os.path.join(args.logdir, log_name)

    # logger
    if args.logger == "wandb":
        logger = WandbLogger(
            save_interval=1,
            name=log_name.replace(os.path.sep, "__"),
            run_id=args.resume_id,
            config=args,
            project=args.wandb_project,
        )
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    if args.logger == "tensorboard":
        logger = TensorboardLogger(writer)
    else:  # wandb
        logger.load(writer)

    def save_best_fn(policy):
        """Write the policy's state dict to ``policy.pth`` inside ``log_path``."""
        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

    def watch():
        """Load a saved policy on CPU and render one episode in ``env``."""
        if args.resume_path is None:
            # NOTE(review): log_path carries this run's timestamp, so this
            # fallback only works if a checkpoint already exists there;
            # pass --resume-path to watch a policy from an earlier run.
            args.resume_path = os.path.join(log_path, "policy.pth")

        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu")))
        policy.eval()
        collector = Collector(policy, env)
        # Honour --render (default 1 / 35) instead of a hard-coded interval.
        collector.collect(n_episode=1, render=args.render)

    if not args.watch:
        replay_buffer = load_buffer_d4rl(args.expert_data_task)
        # trainer
        result = OfflineTrainer(
            policy=policy,
            buffer=replay_buffer,
            test_collector=test_collector,
            max_epoch=args.epoch,
            step_per_epoch=args.step_per_epoch,
            episode_per_test=args.test_num,
            batch_size=args.batch_size,
            save_best_fn=save_best_fn,
            logger=logger,
        ).run()
        pprint.pprint(result)
    else:
        watch()

    # Let's watch its performance!
    policy.eval()
    test_envs.seed(args.seed)
    test_collector.reset()
    result = test_collector.collect(n_episode=args.test_num, render=args.render)
    print(f"Final reward: {result['rews'].mean()}, length: {result['lens'].mean()}")


if __name__ == "__main__":
    test_il()