import argparse import os import pprint import numpy as np import pytest import torch from torch.utils.tensorboard import SummaryWriter from tianshou.data import Collector, VectorReplayBuffer from tianshou.policy import PSRLPolicy from tianshou.trainer import OnpolicyTrainer from tianshou.utils import LazyLogger, TensorboardLogger, WandbLogger try: import envpool except ImportError: envpool = None def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="NChain-v0") parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) parser.add_argument("--buffer-size", type=int, default=50000) parser.add_argument("--epoch", type=int, default=5) parser.add_argument("--step-per-epoch", type=int, default=1000) parser.add_argument("--episode-per-collect", type=int, default=1) parser.add_argument("--training-num", type=int, default=1) parser.add_argument("--test-num", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument("--rew-mean-prior", type=float, default=0.0) parser.add_argument("--rew-std-prior", type=float, default=1.0) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--eps", type=float, default=0.01) parser.add_argument("--add-done-loop", action="store_true", default=False) parser.add_argument( "--logger", type=str, default="none", # TODO: Change to "wandb" once wandb supports Gym >=0.26.0 choices=["wandb", "tensorboard", "none"], ) return parser.parse_known_args()[0] @pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform") def test_psrl(args: argparse.Namespace = get_args()) -> None: # if you want to use python vector env, please refer to other test scripts train_envs = env = envpool.make_gymnasium(args.task, num_envs=args.training_num, seed=args.seed) test_envs = envpool.make_gymnasium(args.task, num_envs=args.test_num, seed=args.seed) if args.reward_threshold is None: default_reward_threshold = {"NChain-v0": 3400} args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold) print("reward threshold:", args.reward_threshold) args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n # seed np.random.seed(args.seed) torch.manual_seed(args.seed) # model n_action = args.action_shape n_state = args.state_shape trans_count_prior = np.ones((n_state, n_action, n_state)) rew_mean_prior = np.full((n_state, n_action), args.rew_mean_prior) rew_std_prior = np.full((n_state, n_action), args.rew_std_prior) policy: PSRLPolicy = PSRLPolicy( trans_count_prior=trans_count_prior, rew_mean_prior=rew_mean_prior, rew_std_prior=rew_std_prior, action_space=env.action_space, discount_factor=args.gamma, epsilon=args.eps, add_done_loop=args.add_done_loop, ) # collector train_collector = Collector( policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), exploration_noise=True, ) train_collector.reset() test_collector = Collector(policy, test_envs) test_collector.reset() # Logger log_path = os.path.join(args.logdir, args.task, "psrl") writer = SummaryWriter(log_path) writer.add_text("args", str(args)) logger: WandbLogger | TensorboardLogger | LazyLogger if args.logger == "wandb": logger = WandbLogger(save_interval=1, project="psrl", name="wandb_test", config=args) logger.load(writer) elif args.logger == "tensorboard": logger = TensorboardLogger(writer) else: logger = LazyLogger() def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold train_collector.collect(n_step=args.buffer_size, random=True) # trainer, test it without logger result = OnpolicyTrainer( policy=policy, train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, step_per_epoch=args.step_per_epoch, repeat_per_collect=1, episode_per_test=args.test_num, batch_size=0, episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, logger=logger, test_in_train=False, ).run() if __name__ == "__main__": pprint.pprint(result) # Let's watch its performance! policy.eval() test_envs.seed(args.seed) test_collector.reset() stats = test_collector.collect(n_episode=args.test_num, render=args.render) stats.pprint_asdict() elif env.spec.reward_threshold: assert result.best_reward >= env.spec.reward_threshold if __name__ == "__main__": test_psrl()