import argparse
import datetime
import os
import pickle
import pprint

import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.policy import CQLPolicy
from tianshou.trainer import OfflineTrainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb, Critic

if __name__ == "__main__":
    from gather_pendulum_data import expert_file_name, gather_data
else:  # pytest
    from test.offline.gather_pendulum_data import expert_file_name, gather_data


def _str2bool(value):
    """Convert a command-line string such as ``"false"`` or ``"1"`` to a bool.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1", "on"):
        return True
    # The empty string is kept falsy, as it was under ``type=bool``.
    if text in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args():
    """Build the argument parser for the CQL test and parse the command line.

    Unknown command-line arguments are ignored (``parse_known_args``), so the
    function can be called while another tool (e.g. pytest) owns ``sys.argv``.

    Returns:
        argparse.Namespace: the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='Pendulum-v1')
    parser.add_argument('--reward-threshold', type=float, default=None)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
    parser.add_argument('--actor-lr', type=float, default=1e-3)
    parser.add_argument('--critic-lr', type=float, default=1e-3)
    parser.add_argument('--alpha', type=float, default=0.2)
    # ``--auto-alpha`` defaults to True, so on its own the flag is a no-op;
    # ``--no-auto-alpha`` writes to the same destination to allow disabling it.
    parser.add_argument('--auto-alpha', default=True, action='store_true')
    parser.add_argument('--no-auto-alpha', dest='auto_alpha', action='store_false')
    parser.add_argument('--alpha-lr', type=float, default=1e-3)
    parser.add_argument('--cql-alpha-lr', type=float, default=1e-3)
    parser.add_argument("--start-timesteps", type=int, default=10000)
    parser.add_argument('--epoch', type=int, default=5)
    parser.add_argument('--step-per-epoch', type=int, default=500)
    parser.add_argument('--n-step', type=int, default=3)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument("--tau", type=float, default=0.005)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--cql-weight", type=float, default=1.0)
    # Parsed with _str2bool so that "--with-lagrange False" really yields False.
    parser.add_argument("--with-lagrange", type=_str2bool, default=True)
    parser.add_argument("--lagrange-threshold", type=float, default=10.0)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--eval-freq", type=int, default=1)
    parser.add_argument('--test-num', type=int, default=10)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=1 / 35)
    parser.add_argument(
        '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
    )
    parser.add_argument('--resume-path', type=str, default=None)
    parser.add_argument(
        '--watch',
        default=False,
        action='store_true',
        help='watch the play of pre-trained policy only',
    )
    parser.add_argument("--load-buffer-name", type=str, default=expert_file_name())
    args = parser.parse_known_args()[0]
    return args


def test_cql(args=None):
    """Train a CQL policy offline on a replay buffer and check its reward.

    The buffer is loaded from ``args.load_buffer_name`` (HDF5 when the name
    ends with ``.hdf5``, pickle otherwise); when that file does not exist the
    buffer is obtained from ``gather_data()`` instead. An actor and two critics
    are built, wrapped in a ``CQLPolicy`` and trained with ``OfflineTrainer``.

    Args:
        args: options as returned by :func:`get_args`. When ``None`` (the
            default) the options are parsed on each call, so the namespace,
            which this function mutates, is never shared between calls.

    Raises:
        AssertionError: if the best test reward is below
            ``args.reward_threshold`` after training.
    """
    if args is None:
        args = get_args()

    if os.path.isfile(args.load_buffer_name):
        if args.load_buffer_name.endswith(".hdf5"):
            buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
        else:
            # NOTE: unpickling executes arbitrary code; only load buffer files
            # that come from a trusted source.
            with open(args.load_buffer_name, "rb") as buffer_file:
                buffer = pickle.load(buffer_file)
    else:
        buffer = gather_data()

    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]  # float
    if args.reward_threshold is None:
        # too low?
        default_reward_threshold = {"Pendulum-v0": -1200, "Pendulum-v1": -1200}
        args.reward_threshold = default_reward_threshold.get(
            args.task, env.spec.reward_threshold
        )

    args.state_dim = args.state_shape[0]
    args.action_dim = args.action_shape[0]
    # test_envs = gym.make(args.task)
    test_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)]
    )
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    test_envs.seed(args.seed)

    # model
    # actor network
    net_a = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        device=args.device,
    )
    actor = ActorProb(
        net_a,
        action_shape=args.action_shape,
        device=args.device,
        unbounded=True,
        conditioned_sigma=True,
    ).to(args.device)
    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)

    # critic network
    net_c1 = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device,
    )
    net_c2 = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device,
    )
    critic1 = Critic(net_c1, device=args.device).to(args.device)
    critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
    critic2 = Critic(net_c2, device=args.device).to(args.device)
    critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)

    if args.auto_alpha:
        # Replace the scalar alpha by the (target entropy, log alpha, optimizer)
        # triple that is passed to CQLPolicy as ``alpha`` below.
        target_entropy = -np.prod(env.action_space.shape)
        log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
        alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
        args.alpha = (target_entropy, log_alpha, alpha_optim)

    policy = CQLPolicy(
        actor,
        actor_optim,
        critic1,
        critic1_optim,
        critic2,
        critic2_optim,
        cql_alpha_lr=args.cql_alpha_lr,
        cql_weight=args.cql_weight,
        tau=args.tau,
        gamma=args.gamma,
        alpha=args.alpha,
        temperature=args.temperature,
        with_lagrange=args.with_lagrange,
        lagrange_threshold=args.lagrange_threshold,
        min_action=np.min(env.action_space.low),
        max_action=np.max(env.action_space.high),
        device=args.device,
    )

    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)

    # collector
    # buffer has been gathered
    # train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
    test_collector = Collector(policy, test_envs)
    # log
    t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
    log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_cql'
    log_path = os.path.join(args.logdir, args.task, 'cql', log_file)
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    logger = TensorboardLogger(writer)

    def save_best_fn(policy):
        """Save the policy weights to ``policy.pth`` inside the log directory."""
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        """Return True once the mean reward reaches the reward threshold."""
        return mean_rewards >= args.reward_threshold

    # NOTE(review): ``watch`` is never called in this file and ``args.watch``
    # is never read; kept as-is for manual use.
    def watch():
        """Reload the saved policy on CPU and render one episode."""
        policy.load_state_dict(
            torch.load(
                os.path.join(log_path, 'policy.pth'), map_location=torch.device('cpu')
            )
        )
        policy.eval()
        collector = Collector(policy, env)
        collector.collect(n_episode=1, render=1 / 35)

    # trainer
    trainer = OfflineTrainer(
        policy,
        buffer,
        test_collector,
        args.epoch,
        args.step_per_epoch,
        args.test_num,
        args.batch_size,
        save_best_fn=save_best_fn,
        stop_fn=stop_fn,
        logger=logger,
    )

    for epoch, epoch_stat, info in trainer:
        print(f"Epoch: {epoch}")
        print(epoch_stat)
        print(info)

    # ``info`` is the value yielded by the last trainer iteration.
    assert stop_fn(info["best_reward"])

    # Let's watch its performance!
    if __name__ == "__main__":
        pprint.pprint(info)
        env = gym.make(args.task)
        policy.eval()
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        rews, lens = result["rews"], result["lens"]
        print(f"Final reward: {rews.mean()}, length: {lens.mean()}")


if __name__ == '__main__':
    test_cql()