import argparse
import pprint

import gym
import numpy as np
import torch

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import ContinuousToDiscrete, DummyVectorEnv
from tianshou.policy import BranchingDQNPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils.net.common import BranchingNet


def get_args() -> argparse.Namespace:
    """Build the command-line parser and return the parsed options.

    Unknown command-line arguments are ignored (``parse_known_args``), so the
    function can be called from a process whose ``sys.argv`` carries flags
    that belong to another program.
    """
    parser = argparse.ArgumentParser()
    # task
    parser.add_argument("--task", type=str, default="Pendulum-v1")
    parser.add_argument("--reward-threshold", type=float, default=None)
    # network architecture
    parser.add_argument(
        "--common-hidden-sizes", type=int, nargs="*", default=[64, 64]
    )
    parser.add_argument("--action-hidden-sizes", type=int, nargs="*", default=[64])
    parser.add_argument("--value-hidden-sizes", type=int, nargs="*", default=[64])
    parser.add_argument("--action-per-branch", type=int, default=40)
    # training hyperparameters
    parser.add_argument("--seed", type=int, default=1626)
    parser.add_argument("--eps-test", type=float, default=0.01)
    parser.add_argument("--eps-train", type=float, default=0.76)
    parser.add_argument("--eps-decay", type=float, default=1e-4)
    parser.add_argument("--buffer-size", type=int, default=20000)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--gamma", type=float, default=0.9)
    parser.add_argument("--target-update-freq", type=int, default=200)
    parser.add_argument("--epoch", type=int, default=10)
    parser.add_argument("--step-per-epoch", type=int, default=80000)
    parser.add_argument("--step-per-collect", type=int, default=10)
    parser.add_argument("--update-per-step", type=float, default=0.1)
    parser.add_argument("--batch-size", type=int, default=128)
    parser.add_argument("--training-num", type=int, default=10)
    parser.add_argument("--test-num", type=int, default=10)
    parser.add_argument("--logdir", type=str, default="log")
    parser.add_argument("--render", type=float, default=0.)
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
    )
    return parser.parse_known_args()[0]


def test_bdq(args=None):
    """Train a ``BranchingDQNPolicy`` on ``args.task`` with the off-policy trainer.

    The environment's continuous action space is discretized into
    ``args.action_per_branch`` bins per action dimension via
    ``ContinuousToDiscrete``.

    Args:
        args: Options as produced by :func:`get_args`. When ``None`` (the
            default) the options are parsed at call time. The namespace is
            extended in place with ``state_shape`` and ``num_branches``, and
            ``reward_threshold`` is filled in when it was left unset.
    """
    # Parse lazily instead of using ``args=get_args()`` as the default: a
    # default expression is evaluated once at definition time, which would
    # parse sys.argv on import and share one mutated Namespace across calls.
    if args is None:
        args = get_args()

    def make_env():
        # Factory shared by the probe env and both vectorized env pools.
        return ContinuousToDiscrete(gym.make(args.task), args.action_per_branch)

    env = make_env()
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.num_branches = env.action_space.shape[0]
    if args.reward_threshold is None:
        default_reward_threshold = {"Pendulum-v0": -250, "Pendulum-v1": -250}
        args.reward_threshold = default_reward_threshold.get(
            args.task, env.spec.reward_threshold
        )

    print("Observations shape:", args.state_shape)
    print("Num branches:", args.num_branches)
    print("Actions per branch:", args.action_per_branch)

    train_envs = DummyVectorEnv([make_env for _ in range(args.training_num)])
    test_envs = DummyVectorEnv([make_env for _ in range(args.test_num)])

    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)

    # model
    net = BranchingNet(
        args.state_shape,
        args.num_branches,
        args.action_per_branch,
        args.common_hidden_sizes,
        args.value_hidden_sizes,
        args.action_hidden_sizes,
        device=args.device,
    ).to(args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    policy = BranchingDQNPolicy(
        net, optim, args.gamma, target_update_freq=args.target_update_freq
    )

    # collector
    train_collector = Collector(
        policy,
        train_envs,
        VectorReplayBuffer(args.buffer_size, args.training_num),
        exploration_noise=True,
    )
    test_collector = Collector(policy, test_envs, exploration_noise=False)
    # policy.set_eps(1)
    # Collect an initial set of transitions before the trainer starts.
    train_collector.collect(n_step=args.batch_size * args.training_num)

    def train_fn(epoch, env_step):
        # Exponentially decay epsilon with the env step, floored at eps_test.
        eps = max(args.eps_train * (1 - args.eps_decay)**env_step, args.eps_test)
        policy.set_eps(eps)

    def test_fn(epoch, env_step):
        policy.set_eps(args.eps_test)

    def stop_fn(mean_rewards):
        # The threshold can still be None when neither the CLI, the defaults
        # table, nor the env spec supplies one; never stop early in that case
        # rather than failing on a float-vs-None comparison.
        if args.reward_threshold is None:
            return False
        return mean_rewards >= args.reward_threshold

    # trainer
    result = offpolicy_trainer(
        policy,
        train_collector,
        test_collector,
        args.epoch,
        args.step_per_epoch,
        args.step_per_collect,
        args.test_num,
        args.batch_size,
        update_per_step=args.update_per_step,
        train_fn=train_fn,
        test_fn=test_fn,
        stop_fn=stop_fn,
    )

    # assert stop_fn(result["best_reward"])

    # Only report and run the final evaluation when executed as a script,
    # not when this function is imported and called from elsewhere.
    if __name__ == "__main__":
        pprint.pprint(result)
        # Let's watch its performance!
        policy.eval()
        policy.set_eps(args.eps_test)
        test_envs.seed(args.seed)
        test_collector.reset()
        result = test_collector.collect(n_episode=args.test_num, render=args.render)
        rews, lens = result["rews"], result["lens"]
        print(f"Final reward: {rews.mean()}, length: {lens.mean()}")


if __name__ == "__main__":
    test_bdq(get_args())