import argparse
import os
import pickle

import gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.policy import SACPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb, Critic


def expert_file_name():
    return os.path.join(os.path.dirname(__file__), "expert_SAC_Pendulum-v1.pkl")


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='Pendulum-v1')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--buffer-size', type=int, default=20000)
    parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128])
    parser.add_argument('--actor-lr', type=float, default=1e-3)
    parser.add_argument('--critic-lr', type=float, default=1e-3)
    parser.add_argument('--epoch', type=int, default=7)
    parser.add_argument('--step-per-epoch', type=int, default=8000)
    parser.add_argument('--batch-size', type=int, default=256)
    parser.add_argument('--training-num', type=int, default=10)
    parser.add_argument('--test-num', type=int, default=10)
    parser.add_argument('--step-per-collect', type=int, default=10)
    parser.add_argument('--update-per-step', type=float, default=0.125)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--tau", type=float, default=0.005)
    parser.add_argument(
        '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
    )
    parser.add_argument('--resume-path', type=str, default=None)
    parser.add_argument(
        '--watch',
        default=False,
        action='store_true',
        help='watch the play of pre-trained policy only'
    )
    # sac:
    parser.add_argument('--alpha', type=float, default=0.2)
    parser.add_argument('--auto-alpha', type=int, default=1)
    parser.add_argument('--alpha-lr', type=float, default=3e-4)
    parser.add_argument('--rew-norm', action="store_true", default=False)
    parser.add_argument('--n-step', type=int, default=3)
    parser.add_argument("--save-buffer-name", type=str, default=expert_file_name())
    args = parser.parse_known_args()[0]
    return args


def gather_data():
    """Return expert buffer data."""
    args = get_args()
    env = gym.make(args.task)
    if args.task == 'Pendulum-v1':
        # gym registers no reward threshold for Pendulum-v1, so set one here;
        # it is used by stop_fn below.
        env.spec.reward_threshold = -250
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]
    # you can also use tianshou.env.SubprocVectorEnv
    # train_envs = gym.make(args.task)
    train_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)]
    )
    # test_envs = gym.make(args.task)
    test_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)]
    )
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
    actor = ActorProb(
        net,
        args.action_shape,
        max_action=args.max_action,
        device=args.device,
        unbounded=True,
    ).to(args.device)
    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
    net_c1 = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device,
    )
    critic1 = Critic(net_c1, device=args.device).to(args.device)
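    # SAC uses two independent critics (clipped double-Q) to reduce value
    # overestimation; each critic gets its own network and optimizer below.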
    critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
    net_c2 = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device,
    )
    critic2 = Critic(net_c2, device=args.device).to(args.device)
    critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)

    if args.auto_alpha:
        target_entropy = -np.prod(env.action_space.shape)
        log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
        alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
        args.alpha = (target_entropy, log_alpha, alpha_optim)

    policy = SACPolicy(
        actor,
        actor_optim,
        critic1,
        critic1_optim,
        critic2,
        critic2_optim,
        tau=args.tau,
        gamma=args.gamma,
        alpha=args.alpha,
        reward_normalization=args.rew_norm,
        estimation_step=args.n_step,
        action_space=env.action_space,
    )
    # collector
    buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
    train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
    test_collector = Collector(policy, test_envs)
    # train_collector.collect(n_step=args.buffer_size)
    # log
    log_path = os.path.join(args.logdir, args.task, 'sac')
    writer = SummaryWriter(log_path)
    logger = TensorboardLogger(writer)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        return mean_rewards >= env.spec.reward_threshold

    # trainer
    offpolicy_trainer(
        policy,
        train_collector,
        test_collector,
        args.epoch,
        args.step_per_epoch,
        args.step_per_collect,
        args.test_num,
        args.batch_size,
        update_per_step=args.update_per_step,
        save_fn=save_fn,
        stop_fn=stop_fn,
        logger=logger,
    )
    # after training, refill the buffer with trajectories from the trained
    # (expert) policy and dump it to disk
    train_collector.reset()
    result = train_collector.collect(n_step=args.buffer_size)
    rews, lens = result["rews"], result["lens"]
    print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
    if args.save_buffer_name.endswith(".hdf5"):
        buffer.save_hdf5(args.save_buffer_name)
    else:
        with open(args.save_buffer_name, "wb") as f:
            pickle.dump(buffer, f)
    return buffer
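

# Minimal usage sketch (an assumption, not part of the original helper):
# gather_data() is normally imported by offline-RL scripts, but the module can
# also be run directly to produce and inspect the expert buffer.
if __name__ == "__main__":
    expert_buffer = gather_data()
    print(f"Collected {len(expert_buffer)} expert transitions.")
    # The pickled buffer can later be reloaded for offline training, e.g.:
    # with open(expert_file_name(), "rb") as f:
    #     expert_buffer = pickle.load(f)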