#!/usr/bin/env python3
import argparse
import datetime
import os
import pprint

import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from examples.offline.utils import load_buffer_d4rl, normalize_all_obs_in_replay_buffer
from tianshou.data import Collector
from tianshou.env import BaseVectorEnv, SubprocVectorEnv, VectorEnvNormObs
from tianshou.exploration import GaussianNoise
from tianshou.policy import TD3BCPolicy
from tianshou.policy.base import BasePolicy
from tianshou.trainer import OfflineTrainer
from tianshou.utils import TensorboardLogger, WandbLogger
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import Actor, Critic
from tianshou.utils.space_info import SpaceInfo


def get_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="HalfCheetah-v2")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--expert-data-task", type=str, default="halfcheetah-expert-v2")
    parser.add_argument("--buffer-size", type=int, default=1000000)
    parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256])
    parser.add_argument("--actor-lr", type=float, default=3e-4)
    parser.add_argument("--critic-lr", type=float, default=3e-4)
    parser.add_argument("--epoch", type=int, default=200)
    parser.add_argument("--step-per-epoch", type=int, default=5000)
    parser.add_argument("--n-step", type=int, default=3)
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--alpha", type=float, default=2.5)
    parser.add_argument("--exploration-noise", type=float, default=0.1)
    parser.add_argument("--policy-noise", type=float, default=0.2)
    parser.add_argument("--noise-clip", type=float, default=0.5)
    parser.add_argument("--update-actor-freq", type=int, default=2)
    parser.add_argument("--tau", type=float, default=0.005)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--norm-obs", type=int, default=1)
    parser.add_argument("--eval-freq", type=int, default=1)
    parser.add_argument("--test-num", type=int, default=10)
    parser.add_argument("--logdir", type=str, default="log")
    parser.add_argument("--render", type=float, default=1 / 35)
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
    )
    parser.add_argument("--resume-path", type=str, default=None)
    parser.add_argument("--resume-id", type=str, default=None)
    parser.add_argument(
        "--logger",
        type=str,
        default="tensorboard",
        choices=["tensorboard", "wandb"],
    )
    parser.add_argument("--wandb-project", type=str, default="offline_d4rl.benchmark")
    parser.add_argument(
        "--watch",
        default=False,
        action="store_true",
        help="watch the play of pre-trained policy only",
    )
    return parser.parse_args()


def test_td3_bc() -> None:
    args = get_args()
    env = gym.make(args.task)
    space_info = SpaceInfo.from_env(env)
    args.state_shape = space_info.observation_info.obs_shape
    args.action_shape = space_info.action_info.action_shape
    args.max_action = space_info.action_info.max_action
    args.min_action = space_info.action_info.min_action
    print("device:", args.device)
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    print("Action range:", args.min_action, args.max_action)

    args.state_dim = space_info.observation_info.obs_dim
    args.action_dim = space_info.action_info.action_dim
    print("Max_action", args.max_action)

    test_envs: BaseVectorEnv
    test_envs = SubprocVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)])
    if args.norm_obs:
        test_envs = VectorEnvNormObs(test_envs, update_obs_rms=False)

    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    test_envs.seed(args.seed)

    # model
    # actor network
    net_a = Net(
        args.state_shape,
        hidden_sizes=args.hidden_sizes,
        device=args.device,
    )
    actor = Actor(
        net_a,
        action_shape=args.action_shape,
        max_action=args.max_action,
        device=args.device,
    ).to(args.device)
    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)

    # critic network
    net_c1 = Net(
        state_shape=args.state_shape,
        action_shape=args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device,
    )
    net_c2 = Net(
        state_shape=args.state_shape,
        action_shape=args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device,
    )
    critic1 = Critic(net_c1, device=args.device).to(args.device)
    critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
    critic2 = Critic(net_c2, device=args.device).to(args.device)
    critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)

    policy: TD3BCPolicy = TD3BCPolicy(
        actor=actor,
        actor_optim=actor_optim,
        critic=critic1,
        critic_optim=critic1_optim,
        critic2=critic2,
        critic2_optim=critic2_optim,
        tau=args.tau,
        gamma=args.gamma,
        exploration_noise=GaussianNoise(sigma=args.exploration_noise),
        policy_noise=args.policy_noise,
        update_actor_freq=args.update_actor_freq,
        noise_clip=args.noise_clip,
        alpha=args.alpha,
        estimation_step=args.n_step,
        action_space=env.action_space,
    )

    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)

    # collector
    test_collector = Collector(policy, test_envs)

    # log
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    args.algo_name = "td3_bc"
    log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
    log_path = os.path.join(args.logdir, log_name)

    # logger
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    logger: WandbLogger | TensorboardLogger
    if args.logger == "tensorboard":
        logger = TensorboardLogger(writer)
    else:
        logger = WandbLogger(
            save_interval=1,
            name=log_name.replace(os.path.sep, "__"),
            run_id=args.resume_id,
            config=args,
            project=args.wandb_project,
        )
        logger.load(writer)

    def save_best_fn(policy: BasePolicy) -> None:
        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

    def watch() -> None:
        if args.resume_path is None:
            args.resume_path = os.path.join(log_path, "policy.pth")
        policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu")))
        collector = Collector(policy, env)
        collector.collect(n_episode=1, render=1 / 35)

    if not args.watch:
        replay_buffer = load_buffer_d4rl(args.expert_data_task)
        if args.norm_obs:
            replay_buffer, obs_rms = normalize_all_obs_in_replay_buffer(replay_buffer)
            test_envs.set_obs_rms(obs_rms)
        # trainer
        result = OfflineTrainer(
            policy=policy,
            buffer=replay_buffer,
            test_collector=test_collector,
            max_epoch=args.epoch,
            step_per_epoch=args.step_per_epoch,
            episode_per_test=args.test_num,
            batch_size=args.batch_size,
            save_best_fn=save_best_fn,
            logger=logger,
        ).run()
        pprint.pprint(result)
    else:
        watch()

    # Let's watch its performance!
    test_envs.seed(args.seed)
    test_collector.reset()
    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
    print(collector_stats)


if __name__ == "__main__":
    test_td3_bc()
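
# Example invocations (a sketch, not part of the original script): the script file
# name "d4rl_td3_bc.py" below is assumed, and the flags shown are simply the argparse
# defaults defined above; d4rl and a MuJoCo HalfCheetah environment must be installed.
#
#   python d4rl_td3_bc.py --task HalfCheetah-v2 --expert-data-task halfcheetah-expert-v2
#
# To only watch a previously trained policy, pass --watch together with --resume-path
# pointing at a saved policy.pth checkpoint:
#
#   python d4rl_td3_bc.py --watch --resume-path <path/to/policy.pth>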