import argparse
import os
import pickle

import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.policy import SACPolicy
from tianshou.policy.base import BasePolicy
from tianshou.policy.modelfree.sac import SACTrainingStats
from tianshou.trainer import OffpolicyTrainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb, Critic
from tianshou.utils.space_info import SpaceInfo


def expert_file_name() -> str:
    return os.path.join(os.path.dirname(__file__), "expert_SAC_Pendulum-v1.pkl")


def get_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="Pendulum-v1")
    parser.add_argument("--reward-threshold", type=float, default=None)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--buffer-size", type=int, default=20000)
    parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[128, 128])
    parser.add_argument("--actor-lr", type=float, default=1e-3)
    parser.add_argument("--critic-lr", type=float, default=1e-3)
    parser.add_argument("--epoch", type=int, default=7)
    parser.add_argument("--step-per-epoch", type=int, default=8000)
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--training-num", type=int, default=10)
    parser.add_argument("--test-num", type=int, default=10)
    parser.add_argument("--step-per-collect", type=int, default=10)
    parser.add_argument("--update-per-step", type=float, default=0.125)
    parser.add_argument("--logdir", type=str, default="log")
    parser.add_argument("--render", type=float, default=0.0)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--tau", type=float, default=0.005)
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
    )
    parser.add_argument("--resume-path", type=str, default=None)
    parser.add_argument(
        "--watch",
        default=False,
        action="store_true",
        help="watch the play of pre-trained policy only",
    )
    # sac:
    parser.add_argument("--alpha", type=float, default=0.2)
    parser.add_argument("--auto-alpha", type=int, default=1)
    parser.add_argument("--alpha-lr", type=float, default=3e-4)
    parser.add_argument("--n-step", type=int, default=3)
    parser.add_argument("--save-buffer-name", type=str, default=expert_file_name())
    return parser.parse_known_args()[0]


def gather_data() -> VectorReplayBuffer:
    """Train a SAC expert, collect a buffer of expert transitions, save it to disk, and return it."""
    args = get_args()
    env = gym.make(args.task)
    space_info = SpaceInfo.from_env(env)
    args.state_shape = space_info.observation_info.obs_shape
    args.action_shape = space_info.action_info.action_shape
    args.max_action = space_info.action_info.max_action
    if args.reward_threshold is None:
        default_reward_threshold = {"Pendulum-v0": -250, "Pendulum-v1": -250}
        args.reward_threshold = default_reward_threshold.get(
            args.task,
            env.spec.reward_threshold if env.spec else None,
        )
    # you can also use tianshou.env.SubprocVectorEnv
    # train_envs = gym.make(args.task)
    train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)])
    # test_envs = gym.make(args.task)
    test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
    actor = ActorProb(
        net,
        args.action_shape,
        device=args.device,
        unbounded=True,
    ).to(args.device)
    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
    # the critic takes the concatenated observation and action as input
    net_c = Net(
        state_shape=args.state_shape,
        action_shape=args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device,
    )
    critic = Critic(net_c, device=args.device).to(args.device)
    critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
    action_dim = space_info.action_info.action_dim
    if args.auto_alpha:
        # automatic entropy tuning: pass (target_entropy, log_alpha, optimizer) instead of a fixed alpha
        target_entropy = -action_dim
        log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
        alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
        args.alpha = (target_entropy, log_alpha, alpha_optim)
    policy: SACPolicy[SACTrainingStats] = SACPolicy(
        actor=actor,
        actor_optim=actor_optim,
        critic=critic,
        critic_optim=critic_optim,
        tau=args.tau,
        gamma=args.gamma,
        alpha=args.alpha,
        estimation_step=args.n_step,
        action_space=env.action_space,
    )
    # collector
    buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
    train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
    test_collector = Collector(policy, test_envs)
    # train_collector.collect(n_step=args.buffer_size)
    # log
    log_path = os.path.join(args.logdir, args.task, "sac")
    writer = SummaryWriter(log_path)
    logger = TensorboardLogger(writer)

    def save_best_fn(policy: BasePolicy) -> None:
        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

    def stop_fn(mean_rewards: float) -> bool:
        return mean_rewards >= args.reward_threshold

    # trainer
    OffpolicyTrainer(
        policy=policy,
        train_collector=train_collector,
        test_collector=test_collector,
        max_epoch=args.epoch,
        step_per_epoch=args.step_per_epoch,
        step_per_collect=args.step_per_collect,
        episode_per_test=args.test_num,
        batch_size=args.batch_size,
        update_per_step=args.update_per_step,
        save_best_fn=save_best_fn,
        stop_fn=stop_fn,
        logger=logger,
    ).run()
    # collect expert data with the trained policy and save it to disk
    train_collector.reset()
    collector_stats = train_collector.collect(n_step=args.buffer_size)
    print(collector_stats)
    if args.save_buffer_name.endswith(".hdf5"):
        buffer.save_hdf5(args.save_buffer_name)
    else:
        with open(args.save_buffer_name, "wb") as f:
            pickle.dump(buffer, f)
    return buffer
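

# Minimal usage sketch, assuming the default --save-buffer-name (the .pkl path from
# expert_file_name()): train the expert, then reload the pickled buffer to
# sanity-check what was written to disk.
if __name__ == "__main__":
    expert_buffer = gather_data()
    with open(expert_file_name(), "rb") as f:
        reloaded = pickle.load(f)
    print(f"collected {len(expert_buffer)} transitions, reloaded {len(reloaded)} from disk")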