"""Train a SAC agent with Tianshou and save its replay buffer as expert data.

``gather_data`` trains a Soft Actor-Critic policy on ``args.task``.
It then resets the training collector and collects ``args.buffer_size``
more steps into the training replay buffer. Finally it saves that buffer
to ``args.save_buffer_name``, as HDF5 when the name ends with ``.hdf5``
and as a pickle otherwise.
"""

import argparse
import os
import pickle

import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.policy import SACPolicy
from tianshou.trainer import OffpolicyTrainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb, Critic


def expert_file_name():
    """Return the default expert-buffer path, located next to this file."""
    return os.path.join(os.path.dirname(__file__), "expert_SAC_Pendulum-v1.pkl")


def get_args():
    """Parse the command-line options for expert training and collection.

    Unknown arguments are ignored (``parse_known_args``), so this can be
    called while another program owns ``sys.argv``.

    Returns:
        argparse.Namespace with all options below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="Pendulum-v1")
    parser.add_argument("--reward-threshold", type=float, default=None)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--buffer-size", type=int, default=20000)
    parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[128, 128])
    parser.add_argument("--actor-lr", type=float, default=1e-3)
    parser.add_argument("--critic-lr", type=float, default=1e-3)
    parser.add_argument("--epoch", type=int, default=7)
    parser.add_argument("--step-per-epoch", type=int, default=8000)
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--training-num", type=int, default=10)
    parser.add_argument("--test-num", type=int, default=10)
    parser.add_argument("--step-per-collect", type=int, default=10)
    parser.add_argument("--update-per-step", type=float, default=0.125)
    parser.add_argument("--logdir", type=str, default="log")
    parser.add_argument("--render", type=float, default=0.0)
    # ``type=float`` is needed: without it a value given on the command line
    # stays a string and would be handed to SACPolicy as a str.
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--tau", type=float, default=0.005)
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
    )
    parser.add_argument("--resume-path", type=str, default=None)
    parser.add_argument(
        "--watch",
        default=False,
        action="store_true",
        help="watch the play of pre-trained policy only",
    )
    # sac:
    parser.add_argument("--alpha", type=float, default=0.2)
    parser.add_argument("--auto-alpha", type=int, default=1)
    parser.add_argument("--alpha-lr", type=float, default=3e-4)
    parser.add_argument("--rew-norm", action="store_true", default=False)
    parser.add_argument("--n-step", type=int, default=3)
    parser.add_argument("--save-buffer-name", type=str, default=expert_file_name())
    return parser.parse_known_args()[0]


def gather_data():
    """Train SAC, refill the replay buffer with the trained policy, save it.

    Returns:
        The ``VectorReplayBuffer`` that was written to ``args.save_buffer_name``.
    """
    args = get_args()
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]
    if args.reward_threshold is None:
        default_reward_threshold = {"Pendulum-v0": -250, "Pendulum-v1": -250}
        # May still be None when the task's spec declares no threshold.
        args.reward_threshold = default_reward_threshold.get(
            args.task, env.spec.reward_threshold
        )
    # you can also use tianshou.env.SubprocVectorEnv
    train_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)]
    )
    test_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)]
    )
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
    actor = ActorProb(
        net,
        args.action_shape,
        device=args.device,
        unbounded=True,
    ).to(args.device)
    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
    net_c1 = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device,
    )
    critic1 = Critic(net_c1, device=args.device).to(args.device)
    critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
    net_c2 = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device,
    )
    critic2 = Critic(net_c2, device=args.device).to(args.device)
    critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)

    if args.auto_alpha:
        # Learnable temperature: SACPolicy receives a
        # (target_entropy, log_alpha, optimizer) tuple instead of a float.
        target_entropy = -np.prod(env.action_space.shape)
        log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
        alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
        args.alpha = (target_entropy, log_alpha, alpha_optim)

    policy = SACPolicy(
        actor,
        actor_optim,
        critic1,
        critic1_optim,
        critic2,
        critic2_optim,
        tau=args.tau,
        gamma=args.gamma,
        alpha=args.alpha,
        reward_normalization=args.rew_norm,
        estimation_step=args.n_step,
        action_space=env.action_space,
    )
    # collector
    buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
    train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
    test_collector = Collector(policy, test_envs)
    # log
    log_path = os.path.join(args.logdir, args.task, "sac")
    writer = SummaryWriter(log_path)
    logger = TensorboardLogger(writer)

    def save_best_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

    def stop_fn(mean_rewards):
        # Without a known threshold there is no early-stop criterion; comparing
        # against None would raise TypeError.
        if args.reward_threshold is None:
            return False
        return mean_rewards >= args.reward_threshold

    # trainer
    OffpolicyTrainer(
        policy=policy,
        train_collector=train_collector,
        test_collector=test_collector,
        max_epoch=args.epoch,
        step_per_epoch=args.step_per_epoch,
        step_per_collect=args.step_per_collect,
        episode_per_test=args.test_num,
        batch_size=args.batch_size,
        update_per_step=args.update_per_step,
        save_best_fn=save_best_fn,
        stop_fn=stop_fn,
        logger=logger,
    ).run()

    # Reset the collector, then collect buffer_size steps with the trained
    # policy into the same buffer; this is the data that gets saved below.
    train_collector.reset()
    result = train_collector.collect(n_step=args.buffer_size)
    rews, lens = result["rews"], result["lens"]
    print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
    if args.save_buffer_name.endswith(".hdf5"):
        buffer.save_hdf5(args.save_buffer_name)
    else:
        with open(args.save_buffer_name, "wb") as f:
            pickle.dump(buffer, f)

    # Release the TensorBoard writer (flushes pending events) and all envs.
    writer.close()
    train_envs.close()
    test_envs.close()
    env.close()
    return buffer