diff --git a/examples/mujoco/config/logger.yml b/examples/mujoco/config/logger.yml
new file mode 100644
index 0000000..16fe055
--- /dev/null
+++ b/examples/mujoco/config/logger.yml
@@ -0,0 +1,5 @@
+# Default logger config, keep in sync with LoggerConfig dataclass
+
+logger: tensorboard
+logdir: log
+wandb_project: mujoco.benchmark
diff --git a/examples/mujoco/config/sampling.yml b/examples/mujoco/config/sampling.yml
new file mode 100644
index 0000000..87ee291
--- /dev/null
+++ b/examples/mujoco/config/sampling.yml
@@ -0,0 +1,10 @@
+# Default config for sampling, epochs, parallelization, buffers, collectors, and batching.
+# Keep in sync with RLSamplingConfig dataclass.
+epoch: 100
+step_per_epoch: 30000
+batch_size: 64
+training_num: 64
+test_num: 10
+buffer_size: 4096
+step_per_collect: 2048
+repeat_per_collect: 10
diff --git a/examples/mujoco/default_config.yml b/examples/mujoco/default_config.yml
new file mode 100644
index 0000000..6fdb01e
--- /dev/null
+++ b/examples/mujoco/default_config.yml
@@ -0,0 +1,45 @@
+# General config
+logger: "tensorboard"
+wandb_project: "mujoco.benchmark"
+seed: 24
+logdir: "log"
+device: "cpu"
+watch: false
+render: 0.0
+resume_path: null
+resume_id: null
+
+# Training: NN
+lr: 3e-4
+hidden_sizes: [64, 64]
+lr_decay: true
+
+# Training: sampling
+training_num: 64
+test_num: 10
+repeat_per_collect: 10
+batch_size: 64
+epoch: 100
+step_per_epoch: 30000
+step_per_collect: 2048
+buffer_size: 4096
+
+# Training: RL modelling
+gamma: 0.99
+rew_norm: true
+dual_clip: null
+value_clip: false
+norm_adv: false
+recompute_adv: true
+gae_lambda: 0.95
+
+# Training: PPO specifics
+ent_coef: 0.0
+vf_coef: 0.25
+bound_action_method: "clip"
+max_grad_norm: 0.5
+eps_clip: 0.2
+
+
+# Mujoco
+task: "Ant-v3"
diff --git a/examples/mujoco/mujoco_env.py b/examples/mujoco/mujoco_env.py
index bb84629..430a14b 100644
--- a/examples/mujoco/mujoco_env.py
+++ b/examples/mujoco/mujoco_env.py
@@ -10,7 +10,9 @@ except ImportError:
     envpool = None
 
 
-def make_mujoco_env(task, seed, training_num, test_num, obs_norm):
+def make_mujoco_env(
+    task: str, seed: int, num_train_envs: int, num_test_envs: int, obs_norm: bool
+):
     """Wrapper function for Mujoco env.
 
     If EnvPool is installed, it will automatically switch to EnvPool's Mujoco env.
@@ -18,17 +20,16 @@ def make_mujoco_env(task, seed, training_num, test_num, obs_norm):
     :return: a tuple of (single env, training envs, test envs).
""" if envpool is not None: - train_envs = env = envpool.make_gymnasium(task, num_envs=training_num, seed=seed) - test_envs = envpool.make_gymnasium(task, num_envs=test_num, seed=seed) + train_envs = env = envpool.make_gymnasium(task, num_envs=num_train_envs, seed=seed) + test_envs = envpool.make_gymnasium(task, num_envs=num_test_envs, seed=seed) else: warnings.warn( "Recommend using envpool (pip install envpool) " "to run Mujoco environments more efficiently.", ) env = gym.make(task) - train_envs = ShmemVectorEnv([lambda: gym.make(task) for _ in range(training_num)]) - test_envs = ShmemVectorEnv([lambda: gym.make(task) for _ in range(test_num)]) - env.seed(seed) + train_envs = ShmemVectorEnv([lambda: gym.make(task) for _ in range(num_train_envs)]) + test_envs = ShmemVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)]) train_envs.seed(seed) test_envs.seed(seed) if obs_norm: diff --git a/examples/mujoco/mujoco_ppo_cfg.py b/examples/mujoco/mujoco_ppo_cfg.py new file mode 100644 index 0000000..bd4628a --- /dev/null +++ b/examples/mujoco/mujoco_ppo_cfg.py @@ -0,0 +1,359 @@ +#!/usr/bin/env python3 + +import argparse +import datetime +import os +import pprint +from collections.abc import Sequence +from typing import Literal, Optional, Tuple, Union + +import gymnasium as gym +import numpy as np +import torch +from jsonargparse import CLI +from torch import nn +from torch.distributions import Independent, Normal +from torch.optim.lr_scheduler import LambdaLR +from torch.utils.tensorboard import SummaryWriter + +from mujoco_env import make_mujoco_env +from tianshou.config import ( + BasicExperimentConfig, + LoggerConfig, + NNConfig, + PGConfig, + PPOConfig, + RLAgentConfig, + RLSamplingConfig, +) +from tianshou.config.utils import collect_configs +from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.env import VectorEnvNormObs +from tianshou.policy import BasePolicy, PPOPolicy +from tianshou.trainer import OnpolicyTrainer +from tianshou.utils import TensorboardLogger, WandbLogger +from tianshou.utils.net.common import ActorCritic, Net +from tianshou.utils.net.continuous import ActorProb, Critic + + +def set_seed(seed=42): + np.random.seed(seed) + torch.manual_seed(seed) + + +def get_logger_for_run( + algo_name: str, + task: str, + logger_config: LoggerConfig, + config: dict, + seed: int, + resume_id: Optional[Union[str, int]], +) -> Tuple[str, Union[WandbLogger, TensorboardLogger]]: + """ + + :param algo_name: + :param task: + :param logger_config: + :param config: the experiment config + :param seed: + :param resume_id: used as run_id by wandb, unused for tensorboard + :return: + """ + """Returns the log_path and logger.""" + now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") + log_name = os.path.join(task, algo_name, str(seed), now) + log_path = os.path.join(logger_config.logdir, log_name) + + logger = get_logger( + logger_config.logger, + log_path, + log_name=log_name, + run_id=resume_id, + config=config, + wandb_project=logger_config.wandb_project, + ) + return log_path, logger + + +def get_continuous_env_info( + env: gym.Env, +) -> Tuple[Tuple[int, ...], Tuple[int, ...], float]: + if not isinstance(env.action_space, gym.spaces.Box): + raise ValueError( + "Only environments with continuous action space are supported here. " + f"But got env with action space: {env.action_space.__class__}." 
+        )
+    state_shape = env.observation_space.shape or env.observation_space.n
+    if not state_shape:
+        raise ValueError("Observation space shape is not defined")
+    action_shape = env.action_space.shape
+    max_action = env.action_space.high[0]
+    return state_shape, action_shape, max_action
+
+
+def resume_from_checkpoint(
+    path: str,
+    policy: BasePolicy,
+    train_envs: VectorEnvNormObs | None = None,
+    test_envs: VectorEnvNormObs | None = None,
+    device: str | int | torch.device | None = None,
+):
+    ckpt = torch.load(path, map_location=device)
+    policy.load_state_dict(ckpt["model"])
+    if train_envs:
+        train_envs.set_obs_rms(ckpt["obs_rms"])
+    if test_envs:
+        test_envs.set_obs_rms(ckpt["obs_rms"])
+    print("Loaded agent and obs. running means from: ", path)
+
+
+def watch_agent(n_episode, policy: BasePolicy, test_collector: Collector, render=0.0):
+    policy.eval()
+    test_collector.reset()
+    result = test_collector.collect(n_episode=n_episode, render=render)
+    print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
+
+
+def get_train_test_collector(
+    buffer_size: int,
+    policy: BasePolicy,
+    train_envs: VectorEnvNormObs,
+    test_envs: VectorEnvNormObs,
+):
+    if len(train_envs) > 1:
+        buffer = VectorReplayBuffer(buffer_size, len(train_envs))
+    else:
+        buffer = ReplayBuffer(buffer_size)
+    train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
+    test_collector = Collector(policy, test_envs)
+    return test_collector, train_collector
+
+
+TShape = Union[int, Sequence[int]]
+
+
+def get_actor_critic(
+    state_shape: TShape,
+    hidden_sizes: Sequence[int],
+    action_shape: TShape,
+    device: str | int | torch.device = "cpu",
+):
+    net_a = Net(
+        state_shape, hidden_sizes=hidden_sizes, activation=nn.Tanh, device=device
+    )
+    actor = ActorProb(net_a, action_shape, unbounded=True, device=device).to(device)
+    net_c = Net(
+        state_shape, hidden_sizes=hidden_sizes, activation=nn.Tanh, device=device
+    )
+    # TODO: twice device?
+    critic = Critic(net_c, device=device).to(device)
+    return actor, critic
+
+
+def get_logger(
+    kind: Literal["wandb", "tensorboard"],
+    log_path: str,
+    log_name="",
+    run_id: Optional[Union[str, int]] = None,
+    config: Optional[Union[dict, argparse.Namespace]] = None,
+    wandb_project: Optional[str] = None,
+):
+    writer = SummaryWriter(log_path)
+    writer.add_text("args", str(config))
+    if kind == "wandb":
+        logger = WandbLogger(
+            save_interval=1,
+            name=log_name.replace(os.path.sep, "__"),
+            run_id=run_id,
+            config=config,
+            project=wandb_project,
+        )
+        logger.load(writer)
+    elif kind == "tensorboard":
+        logger = TensorboardLogger(writer)
+    else:
+        raise ValueError(f"Unknown logger: {kind}")
+    return logger
+
+
+def get_lr_scheduler(optim, step_per_epoch: int, step_per_collect: int, epochs: int):
+    """Decay learning rate to 0 linearly."""
+    max_update_num = np.ceil(step_per_epoch / step_per_collect) * epochs
+    lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
+    return lr_scheduler
+
+
+def init_and_get_optim(actor: nn.Module, critic: nn.Module, lr: float):
+    """Initializes layers of actor and critic.
+
+    :param actor:
+    :param critic:
+    :param lr:
+    :return:
+    """
+    actor_critic = ActorCritic(actor, critic)
+    torch.nn.init.constant_(actor.sigma_param, -0.5)
+    for m in actor_critic.modules():
+        if isinstance(m, torch.nn.Linear):
+            # orthogonal initialization
+            torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
+            torch.nn.init.zeros_(m.bias)
+    if hasattr(actor, "mu"):
+        # For continuous action spaces with Gaussian policies
+        # do last policy layer scaling, this will make initial actions have (close to)
+        # 0 mean and std, and will help boost performance,
+        # see https://arxiv.org/abs/2006.05990, Fig.24 for details
+        for m in actor.mu.modules():
+            # TODO: seems like biases are initialized twice for the actor
+            if isinstance(m, torch.nn.Linear):
+                torch.nn.init.zeros_(m.bias)
+                m.weight.data.copy_(0.01 * m.weight.data)
+    optim = torch.optim.Adam(actor_critic.parameters(), lr=lr)
+    return optim
+
+
+def main(
+    experiment_config: BasicExperimentConfig,
+    logger_config: LoggerConfig,
+    sampling_config: RLSamplingConfig,
+    general_config: RLAgentConfig,
+    pg_config: PGConfig,
+    ppo_config: PPOConfig,
+    nn_config: NNConfig,
+):
+    """
+    Run PPO training with the provided parameters.
+
+    :param experiment_config: BasicExperimentConfig - not ML or RL specific
+    :param logger_config: LoggerConfig
+    :param sampling_config: RLSamplingConfig -
+        sampling, epochs, parallelization, buffers, collectors, and batching.
+    :param general_config: RLAgentConfig - general RL agent config
+    :param pg_config: PGConfig - common to most policy gradient algorithms
+    :param ppo_config: PPOConfig - PPO specific config
+    :param nn_config: NNConfig - NN-training specific config
+
+    :return: None
+    """
+    full_config = collect_configs(*locals().values())
+    set_seed(experiment_config.seed)
+
+    # create test and train envs, add env info to config
+    env, train_envs, test_envs = make_mujoco_env(
+        task=experiment_config.task,
+        seed=experiment_config.seed,
+        num_train_envs=sampling_config.num_train_envs,
+        num_test_envs=sampling_config.num_test_envs,
+        obs_norm=True,
+    )
+
+    # adding env_info to logged config
+    state_shape, action_shape, max_action = get_continuous_env_info(env)
+    full_config["env_info"] = {
+        "state_shape": state_shape,
+        "action_shape": action_shape,
+        "max_action": max_action,
+    }
+    log_path, logger = get_logger_for_run(
+        "ppo",
+        experiment_config.task,
+        logger_config,
+        full_config,
+        experiment_config.seed,
+        experiment_config.resume_id,
+    )
+
+    # Setup NNs
+    actor, critic = get_actor_critic(
+        state_shape, nn_config.hidden_sizes, action_shape, experiment_config.device
+    )
+    optim = init_and_get_optim(actor, critic, nn_config.lr)
+
+    lr_scheduler = None
+    if nn_config.lr_decay:
+        lr_scheduler = get_lr_scheduler(
+            optim,
+            sampling_config.step_per_epoch,
+            sampling_config.step_per_collect,
+            sampling_config.num_epochs,
+        )
+
+    # Create policy
+    def dist_fn(*logits):
+        return Independent(Normal(*logits), 1)
+
+    policy = PPOPolicy(
+        # nn-stuff
+        actor,
+        critic,
+        optim,
+        dist_fn=dist_fn,
+        lr_scheduler=lr_scheduler,
+        # env-stuff
+        action_space=train_envs.action_space,
+        action_scaling=True,
+        # general_config
+        discount_factor=general_config.gamma,
+        gae_lambda=general_config.gae_lambda,
+        reward_normalization=general_config.rew_norm,
+        action_bound_method=general_config.action_bound_method,
+        # pg_config
+        max_grad_norm=pg_config.max_grad_norm,
+        vf_coef=pg_config.vf_coef,
+        ent_coef=pg_config.ent_coef,
+        # ppo_config
+        eps_clip=ppo_config.eps_clip,
+        value_clip=ppo_config.value_clip,
+        dual_clip=ppo_config.dual_clip,
+        advantage_normalization=ppo_config.norm_adv,
+        recompute_advantage=ppo_config.recompute_adv,
+    )
+
+    if experiment_config.resume_path:
+        resume_from_checkpoint(
+            experiment_config.resume_path,
+            policy,
+            train_envs=train_envs,
+            test_envs=test_envs,
+            device=experiment_config.device,
+        )
+
+    test_collector, train_collector = get_train_test_collector(
+        sampling_config.buffer_size, policy, train_envs, test_envs
+    )
+
+    # TODO: test num is the number of test envs but used as episode_per_test
+    #  here and in watch_agent
+    if not experiment_config.watch:
+        # RL training
+        def save_best_fn(pol: nn.Module):
+            state = {"model": pol.state_dict(), "obs_rms": train_envs.get_obs_rms()}
+            torch.save(state, os.path.join(log_path, "policy.pth"))
+
+        trainer = OnpolicyTrainer(
+            policy=policy,
+            train_collector=train_collector,
+            test_collector=test_collector,
+            max_epoch=sampling_config.num_epochs,
+            step_per_epoch=sampling_config.step_per_epoch,
+            repeat_per_collect=sampling_config.repeat_per_collect,
+            episode_per_test=sampling_config.num_test_envs,
+            batch_size=sampling_config.batch_size,
+            step_per_collect=sampling_config.step_per_collect,
+            save_best_fn=save_best_fn,
+            logger=logger,
+            test_in_train=False,
+        )
+        result = trainer.run()
+        pprint.pprint(result)
+
+    watch_agent(
+        sampling_config.num_test_envs,
+        policy,
+        test_collector,
+        render=experiment_config.render,
+    )
+
+
+if __name__ == "__main__":
+    CLI(main)
diff --git a/tianshou/config/__init__.py b/tianshou/config/__init__.py
new file mode 100644
index 0000000..27c9ec6
--- /dev/null
+++ b/tianshou/config/__init__.py
@@ -0,0 +1 @@
+from .config import *
diff --git a/tianshou/config/config.py b/tianshou/config/config.py
new file mode 100644
index 0000000..ed9fa26
--- /dev/null
+++ b/tianshou/config/config.py
@@ -0,0 +1,91 @@
+from dataclasses import dataclass
+from typing import Literal, Optional, Sequence
+
+import torch
+from jsonargparse import set_docstring_parse_options
+
+set_docstring_parse_options(attribute_docstrings=True)
+
+
+@dataclass
+class BasicExperimentConfig:
+    """Generic config for setting up the experiment, not RL or training specific."""
+
+    seed: int = 42
+    task: str = "Ant-v4"
+    """Mujoco specific"""
+    render: float = 0.0
+    """Milliseconds between rendered frames"""
+    device: str = "cuda" if torch.cuda.is_available() else "cpu"
+    resume_id: Optional[int] = None
+    """For restoring a model and running means of env-specifics from a checkpoint"""
+    resume_path: Optional[str] = None
+    """For restoring a model and running means of env-specifics from a checkpoint"""
+    watch: bool = False
+    """If True, will not perform training and only watch the restored policy"""
+
+
+@dataclass
+class LoggerConfig:
+    """Logging config"""
+
+    logdir: str = "log"
+    logger: Literal["tensorboard", "wandb"] = "tensorboard"
+    wandb_project: str = "mujoco.benchmark"
+    """Only used if logger is wandb."""
+
+
+@dataclass
+class RLSamplingConfig:
+    """Sampling, epochs, parallelization, buffers, collectors, and batching."""
+
+    num_epochs: int = 100
+    step_per_epoch: int = 30000
+    batch_size: int = 64
+    num_train_envs: int = 64
+    num_test_envs: int = 10
+    buffer_size: int = 4096
+    step_per_collect: int = 2048
+    repeat_per_collect: int = 10
+
+
+@dataclass
+class RLAgentConfig:
+    """Config common to most RL algorithms"""
+
+    gamma: float = 0.99
+    """Discount factor"""
+    gae_lambda: float = 0.95
+    """For Generalized Advantage Estimate (equivalent to TD(lambda))"""
+    action_bound_method: Optional[Literal["clip", "tanh"]] = "clip"
"clip" + """How to map original actions in range (-inf, inf) to [-1, 1]""" + rew_norm: bool = True + """Whether to normalize rewards""" + + +@dataclass +class PGConfig: + """Config of general policy-gradient algorithms""" + + ent_coef: float = 0.0 + vf_coef: float = 0.25 + max_grad_norm: float = 0.5 + + +@dataclass +class PPOConfig: + """PPO specific config""" + + value_clip: bool = False + norm_adv: bool = False + """Whether to normalize advantages""" + eps_clip: float = 0.2 + dual_clip: Optional[float] = None + recompute_adv: bool = True + + +@dataclass +class NNConfig: + hidden_sizes: Sequence[int] = (64, 64) + lr: float = 3e-4 + lr_decay: bool = True diff --git a/tianshou/config/utils.py b/tianshou/config/utils.py new file mode 100644 index 0000000..e3a0998 --- /dev/null +++ b/tianshou/config/utils.py @@ -0,0 +1,25 @@ +from dataclasses import asdict, is_dataclass + + +def collect_configs(*confs): + """ + Collect instances of dataclasses to a single dict mapping the + classname to the values. If any of the passed objects is not a + dataclass or if two instances of the same config class are passed, + an error will be raised. + + :param confs: dataclasses + :return: Dictionary mapping class names to their instances. + """ + result = {} + + for conf in confs: + if not is_dataclass(conf): + raise ValueError(f"Object {conf.__class__.__name__} is not a dataclass.") + + if conf.__class__.__name__ in result: + raise ValueError(f"Duplicate instance of {conf.__class__.__name__} found.") + + result[conf.__class__.__name__] = asdict(conf) + + return result