#!/usr/bin/env python3

import argparse
import datetime
import os
import pprint

import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from examples.offline.utils import load_buffer_d4rl
from tianshou.data import Collector
from tianshou.env import SubprocVectorEnv
from tianshou.policy import ImitationPolicy
from tianshou.policy.base import BasePolicy
from tianshou.trainer import OfflineTrainer
from tianshou.utils import TensorboardLogger, WandbLogger
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import Actor
from tianshou.utils.space_info import SpaceInfo


def get_args() -> argparse.Namespace:
    """Parse the command-line options of this script.

    Options are read from ``sys.argv``. Every option has a default, so the
    script can be launched without arguments.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()

    # Environment and dataset.
    parser.add_argument("--task", type=str, default="HalfCheetah-v2")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--expert-data-task", type=str, default="halfcheetah-expert-v2")

    # Network and optimisation.
    parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256])
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--epoch", type=int, default=200)
    parser.add_argument("--step-per-epoch", type=int, default=5000)
    parser.add_argument("--batch-size", type=int, default=256)

    # Evaluation and output.
    parser.add_argument("--test-num", type=int, default=10)
    parser.add_argument("--logdir", type=str, default="log")
    parser.add_argument("--render", type=float, default=1 / 35)
    # type=float keeps a value given on the command line numeric; without it
    # argparse would hand back a str while the default stays a float.
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
    )

    # Checkpointing and logging back-end.
    parser.add_argument("--resume-path", type=str, default=None)
    parser.add_argument("--resume-id", type=str, default=None)
    parser.add_argument(
        "--logger",
        type=str,
        default="tensorboard",
        choices=["tensorboard", "wandb"],
    )
    parser.add_argument("--wandb-project", type=str, default="offline_d4rl.benchmark")
    parser.add_argument(
        "--watch",
        default=False,
        action="store_true",
        help="watch the play of pre-trained policy only",
    )
    return parser.parse_args()


def test_il() -> None:
    """Train an ``ImitationPolicy`` offline on a D4RL buffer, or only watch it.

    The function reads its configuration from :func:`get_args`, builds the
    actor network and policy, sets up a logger, and then either

    * runs ``OfflineTrainer`` on the buffer returned by ``load_buffer_d4rl``
      (default), saving the best policy to ``<log_path>/policy.pth``, or
    * with ``--watch``, loads a saved policy and renders one episode.

    In both cases it finishes by collecting ``--test-num`` evaluation episodes
    on the vectorised test environments and printing the collector statistics.
    """
    args = get_args()
    env = gym.make(args.task)

    # Derive shapes and bounds from the environment's spaces.
    space_info = SpaceInfo.from_env(env)
    args.state_shape = space_info.observation_info.obs_shape
    args.action_shape = space_info.action_info.action_shape
    args.max_action = space_info.action_info.max_action
    # min_action is printed below; it must be set here or the print raises
    # AttributeError on the argparse namespace.
    args.min_action = space_info.action_info.min_action
    print("device:", args.device)
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    print("Action range:", args.min_action, args.max_action)

    args.state_dim = space_info.observation_info.obs_dim
    args.action_dim = space_info.action_info.action_dim
    print("Max_action", args.max_action)

    test_envs = SubprocVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)])

    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    test_envs.seed(args.seed)

    # model
    net = Net(
        state_shape=args.state_shape,
        action_shape=args.action_shape,
        hidden_sizes=args.hidden_sizes,
        device=args.device,
    )
    actor = Actor(
        net,
        action_shape=args.action_shape,
        max_action=args.max_action,
        device=args.device,
    ).to(args.device)
    optim = torch.optim.Adam(actor.parameters(), lr=args.lr)

    policy: ImitationPolicy = ImitationPolicy(
        actor=actor,
        optim=optim,
        action_space=env.action_space,
        action_scaling=True,
        action_bound_method="clip",
    )

    # load a previous policy
    # NOTE: torch.load unpickles the file; only point --resume-path at
    # checkpoints from a trusted source.
    if args.resume_path:
        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)

    # collector
    test_collector = Collector(policy, test_envs)

    # log
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    # This script trains an imitation policy, so label its runs "il" rather
    # than filing them under another algorithm's directory.
    args.algo_name = "il"
    log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
    log_path = os.path.join(args.logdir, log_name)

    # logger
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    logger: WandbLogger | TensorboardLogger
    if args.logger == "tensorboard":
        logger = TensorboardLogger(writer)
    else:
        logger = WandbLogger(
            save_interval=1,
            name=log_name.replace(os.path.sep, "__"),
            run_id=args.resume_id,
            config=args,
            project=args.wandb_project,
        )
        # The W&B logger receives the writer via load() instead of its constructor.
        logger.load(writer)

    def save_best_fn(policy: BasePolicy) -> None:
        """Write the policy's state dict to ``policy.pth`` inside ``log_path``."""
        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

    def watch() -> None:
        """Load a saved policy and render a single episode on ``env``."""
        if args.resume_path is None:
            args.resume_path = os.path.join(log_path, "policy.pth")

        policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu")))
        collector = Collector(policy, env)
        # Honour --render here as well; its default (1 / 35) matches the value
        # that used to be hard-coded.
        collector.collect(n_episode=1, render=args.render, eval_mode=True)

    if not args.watch:
        replay_buffer = load_buffer_d4rl(args.expert_data_task)
        # trainer
        result = OfflineTrainer(
            policy=policy,
            buffer=replay_buffer,
            test_collector=test_collector,
            max_epoch=args.epoch,
            step_per_epoch=args.step_per_epoch,
            episode_per_test=args.test_num,
            batch_size=args.batch_size,
            save_best_fn=save_best_fn,
            logger=logger,
        ).run()
        pprint.pprint(result)
    else:
        watch()

    # Let's watch its performance!
    test_envs.seed(args.seed)
    test_collector.reset()
    collector_stats = test_collector.collect(
        n_episode=args.test_num,
        render=args.render,
        eval_mode=True,
    )
    print(collector_stats)


if __name__ == "__main__":
    test_il()