diff --git a/examples/mujoco/mujoco_env.py b/examples/mujoco/mujoco_env.py
index 430a14b..7e3c38c 100644
--- a/examples/mujoco/mujoco_env.py
+++ b/examples/mujoco/mujoco_env.py
@@ -2,7 +2,9 @@ import warnings
 
 import gymnasium as gym
 
+from tianshou.config import RLSamplingConfig, BasicExperimentConfig
 from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
+from tianshou.highlevel.env import EnvFactory, Environments, ContinuousEnvironments
 
 try:
     import envpool
@@ -38,3 +40,19 @@ def make_mujoco_env(
     test_envs = VectorEnvNormObs(test_envs, update_obs_rms=False)
     test_envs.set_obs_rms(train_envs.get_obs_rms())
     return env, train_envs, test_envs
+
+
+class MujocoEnvFactory(EnvFactory):
+    def __init__(self, experiment_config: BasicExperimentConfig, sampling_config: RLSamplingConfig):
+        self.sampling_config = sampling_config
+        self.experiment_config = experiment_config
+
+    def create_envs(self) -> ContinuousEnvironments:
+        env, train_envs, test_envs = make_mujoco_env(
+            task=self.experiment_config.task,
+            seed=self.experiment_config.seed,
+            num_train_envs=self.sampling_config.num_train_envs,
+            num_test_envs=self.sampling_config.num_test_envs,
+            obs_norm=True,
+        )
+        return ContinuousEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)
diff --git a/examples/mujoco/mujoco_ppo_hl.py b/examples/mujoco/mujoco_ppo_hl.py
new file mode 100644
index 0000000..7356afe
--- /dev/null
+++ b/examples/mujoco/mujoco_ppo_hl.py
@@ -0,0 +1,60 @@
+#!/usr/bin/env python3
+
+import datetime
+import os
+
+from jsonargparse import CLI
+from torch.distributions import Independent, Normal
+
+from examples.mujoco.mujoco_env import MujocoEnvFactory
+from tianshou.config import (
+    BasicExperimentConfig,
+    LoggerConfig,
+    NNConfig,
+    PGConfig,
+    PPOConfig,
+    RLAgentConfig,
+    RLSamplingConfig,
+)
+from tianshou.highlevel.agent import PPOAgentFactory
+from tianshou.highlevel.logger import DefaultLoggerFactory
+from tianshou.highlevel.module import ContinuousActorProbFactory, ContinuousNetCriticFactory
+from tianshou.highlevel.optim import AdamOptimizerFactory, LinearLRSchedulerFactory
+from tianshou.highlevel.experiment import RLExperiment
+
+
+def main(
+    experiment_config: BasicExperimentConfig,
+    logger_config: LoggerConfig,
+    sampling_config: RLSamplingConfig,
+    general_config: RLAgentConfig,
+    pg_config: PGConfig,
+    ppo_config: PPOConfig,
+    nn_config: NNConfig,
+):
+    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
+    log_name = os.path.join(experiment_config.task, "ppo", str(experiment_config.seed), now)
+    logger_factory = DefaultLoggerFactory(logger_config)
+
+    env_factory = MujocoEnvFactory(experiment_config, sampling_config)
+
+    def dist_fn(*logits):
+        return Independent(Normal(*logits), 1)
+
+    actor_factory = ContinuousActorProbFactory(nn_config.hidden_sizes)
+    critic_factory = ContinuousNetCriticFactory(nn_config.hidden_sizes)
+    optim_factory = AdamOptimizerFactory(lr=nn_config.lr)
+    lr_scheduler_factory = LinearLRSchedulerFactory(nn_config, sampling_config)
+    agent_factory = PPOAgentFactory(general_config, pg_config, ppo_config, sampling_config, nn_config,
+                                    actor_factory, critic_factory, optim_factory, dist_fn, lr_scheduler_factory)
+
+    experiment = RLExperiment(experiment_config, logger_config, general_config, sampling_config,
+                              env_factory,
+                              logger_factory,
+                              agent_factory)
+
+    experiment.run(log_name)
+
+
+if __name__ == "__main__":
+    CLI(main)
diff --git a/tianshou/config/config.py b/tianshou/config/config.py
index ed9fa26..c7f9a6c 100644
--- a/tianshou/config/config.py
+++ b/tianshou/config/config.py
@@ -14,8 +14,8 @@ class BasicExperimentConfig:
     seed: int = 42
     task: str = "Ant-v4"
     """Mujoco specific"""
-    render: float = 0.0
-    """Milliseconds between rendered frames"""
+    render: Optional[float] = 0.0
+    """Milliseconds between rendered frames; if None, no rendering"""
     device: str = "cuda" if torch.cuda.is_available() else "cpu"
     resume_id: Optional[int] = None
     """For restoring a model and running means of env-specifics from a checkpoint"""
@@ -23,6 +23,7 @@ class BasicExperimentConfig:
     """For restoring a model and running means of env-specifics from a checkpoint"""
     watch: bool = False
     """If True, will not perform training and only watch the restored policy"""
+    watch_num_episodes: int = 10
 
 
 @dataclass
diff --git a/tianshou/highlevel/__init__.py b/tianshou/highlevel/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py
new file mode 100644
index 0000000..8314aeb
--- /dev/null
+++ b/tianshou/highlevel/agent.py
@@ -0,0 +1,146 @@
+import os
+from abc import abstractmethod, ABC
+from typing import Callable
+
+import torch
+
+from tianshou.config import RLSamplingConfig, PGConfig, PPOConfig, RLAgentConfig, NNConfig
+from tianshou.data import VectorReplayBuffer, ReplayBuffer, Collector
+from tianshou.highlevel.env import Environments
+from tianshou.highlevel.logger import Logger
+from tianshou.highlevel.module import ActorFactory, CriticFactory, TDevice
+from tianshou.highlevel.optim import OptimizerFactory, LRSchedulerFactory
+from tianshou.policy import BasePolicy, PPOPolicy
+from tianshou.trainer import BaseTrainer, OnpolicyTrainer
+from tianshou.utils.net.common import ActorCritic
+
+
+CHECKPOINT_DICT_KEY_MODEL = "model"
+CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms"
+
+
+class AgentFactory(ABC):
+    @abstractmethod
+    def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
+        pass
+
+    @staticmethod
+    def _create_save_best_fn(envs: Environments, log_path: str) -> Callable:
+        def save_best_fn(pol: torch.nn.Module):
+            state = {"model": pol.state_dict(), "obs_rms": envs.train_envs.get_obs_rms()}
+            torch.save(state, os.path.join(log_path, "policy.pth"))
+
+        return save_best_fn
+
+    @staticmethod
+    def load_checkpoint(policy: torch.nn.Module, path, envs: Environments, device: TDevice):
+        ckpt = torch.load(path, map_location=device)
+        policy.load_state_dict(ckpt[CHECKPOINT_DICT_KEY_MODEL])
+        if envs.train_envs:
+            envs.train_envs.set_obs_rms(ckpt[CHECKPOINT_DICT_KEY_OBS_RMS])
+        if envs.test_envs:
+            envs.test_envs.set_obs_rms(ckpt[CHECKPOINT_DICT_KEY_OBS_RMS])
+        print("Loaded agent and obs. running means from: ", path)  # TODO logging
+
+    @abstractmethod
+    def create_train_test_collector(self,
+                                    policy: BasePolicy,
+                                    envs: Environments):
+        pass
+
+    @abstractmethod
+    def create_trainer(self, policy: BasePolicy, train_collector: Collector, test_collector: Collector,
+                       envs: Environments, logger: Logger) -> BaseTrainer:
+        pass
+
+
+class OnpolicyAgentFactory(AgentFactory, ABC):
+    def __init__(self, sampling_config: RLSamplingConfig):
+        self.sampling_config = sampling_config
+
+    def create_train_test_collector(self,
+                                    policy: BasePolicy,
+                                    envs: Environments):
+        buffer_size = self.sampling_config.buffer_size
+        train_envs = envs.train_envs
+        if len(train_envs) > 1:
+            buffer = VectorReplayBuffer(buffer_size, len(train_envs))
+        else:
+            buffer = ReplayBuffer(buffer_size)
+        train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
+        test_collector = Collector(policy, envs.test_envs)
+        return train_collector, test_collector
+
+    def create_trainer(self, policy: BasePolicy, train_collector: Collector, test_collector: Collector,
+                       envs: Environments, logger: Logger) -> OnpolicyTrainer:
+        sampling_config = self.sampling_config
+        return OnpolicyTrainer(
+            policy=policy,
+            train_collector=train_collector,
+            test_collector=test_collector,
+            max_epoch=sampling_config.num_epochs,
+            step_per_epoch=sampling_config.step_per_epoch,
+            repeat_per_collect=sampling_config.repeat_per_collect,
+            episode_per_test=sampling_config.num_test_envs,
+            batch_size=sampling_config.batch_size,
+            step_per_collect=sampling_config.step_per_collect,
+            save_best_fn=self._create_save_best_fn(envs, logger.log_path),
+            logger=logger.logger,
+            test_in_train=False,
+        )
+
+
+class PPOAgentFactory(OnpolicyAgentFactory):
+    def __init__(self, general_config: RLAgentConfig,
+                 pg_config: PGConfig,
+                 ppo_config: PPOConfig,
+                 sampling_config: RLSamplingConfig,
+                 nn_config: NNConfig,
+                 actor_factory: ActorFactory,
+                 critic_factory: CriticFactory,
+                 optimizer_factory: OptimizerFactory,
+                 dist_fn,
+                 lr_scheduler_factory: LRSchedulerFactory):
+        super().__init__(sampling_config)
+        self.optimizer_factory = optimizer_factory
+        self.critic_factory = critic_factory
+        self.actor_factory = actor_factory
+        self.ppo_config = ppo_config
+        self.pg_config = pg_config
+        self.general_config = general_config
+        self.lr_scheduler_factory = lr_scheduler_factory
+        self.dist_fn = dist_fn
+        self.nn_config = nn_config
+
+    def create_policy(self, envs: Environments, device: TDevice) -> PPOPolicy:
+        actor = self.actor_factory.create_module(envs, device)
+        critic = self.critic_factory.create_module(envs, device)
+        actor_critic = ActorCritic(actor, critic)
+        optim = self.optimizer_factory.create_optimizer(actor_critic)
+        lr_scheduler = self.lr_scheduler_factory.create_scheduler(optim)
+        return PPOPolicy(
+            # nn-stuff
+            actor,
+            critic,
+            optim,
+            dist_fn=self.dist_fn,
+            lr_scheduler=lr_scheduler,
+            # env-stuff
+            action_space=envs.get_action_space(),
+            action_scaling=True,
+            # general_config
+            discount_factor=self.general_config.gamma,
+            gae_lambda=self.general_config.gae_lambda,
+            reward_normalization=self.general_config.rew_norm,
+            action_bound_method=self.general_config.action_bound_method,
+            # pg_config
+            max_grad_norm=self.pg_config.max_grad_norm,
+            vf_coef=self.pg_config.vf_coef,
+            ent_coef=self.pg_config.ent_coef,
+            # ppo_config
+            eps_clip=self.ppo_config.eps_clip,
+            value_clip=self.ppo_config.value_clip,
+            dual_clip=self.ppo_config.dual_clip,
+            advantage_normalization=self.ppo_config.norm_adv,
+            recompute_advantage=self.ppo_config.recompute_adv,
+        )
diff --git a/tianshou/highlevel/env.py b/tianshou/highlevel/env.py
new file mode 100644
index 0000000..3869476
--- /dev/null
+++ b/tianshou/highlevel/env.py
@@ -0,0 +1,71 @@
+from abc import ABC, abstractmethod
+from typing import Tuple, Optional, Dict, Any, Union, Sequence
+
+import gymnasium as gym
+
+from tianshou.env import BaseVectorEnv
+
+TShape = Union[int, Sequence[int]]
+
+
+class Environments(ABC):
+    def __init__(self, env: Optional[gym.Env], train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
+        self.env = env
+        self.train_envs = train_envs
+        self.test_envs = test_envs
+
+    def info(self) -> Dict[str, Any]:
+        return {
+            "action_shape": self.get_action_shape(),
+            "state_shape": self.get_state_shape()
+        }
+
+    @abstractmethod
+    def get_action_shape(self) -> TShape:
+        pass
+
+    @abstractmethod
+    def get_state_shape(self) -> TShape:
+        pass
+
+    def get_action_space(self) -> gym.Space:
+        return self.env.action_space
+
+
+class ContinuousEnvironments(Environments):
+    def __init__(self, env: Optional[gym.Env], train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
+        super().__init__(env, train_envs, test_envs)
+        self.state_shape, self.action_shape, self.max_action = self._get_continuous_env_info(env)
+
+    def info(self):
+        d = super().info()
+        d["max_action"] = self.max_action
+        return d
+
+    @staticmethod
+    def _get_continuous_env_info(
+        env: gym.Env,
+    ) -> Tuple[Tuple[int, ...], Tuple[int, ...], float]:
+        if not isinstance(env.action_space, gym.spaces.Box):
+            raise ValueError(
+                "Only environments with continuous action space are supported here. "
+                f"But got env with action space: {env.action_space.__class__}."
+            )
+        state_shape = env.observation_space.shape or env.observation_space.n
+        if not state_shape:
+            raise ValueError("Observation space shape is not defined")
+        action_shape = env.action_space.shape
+        max_action = env.action_space.high[0]
+        return state_shape, action_shape, max_action
+
+    def get_action_shape(self) -> TShape:
+        return self.action_shape
+
+    def get_state_shape(self) -> TShape:
+        return self.state_shape
+
+
+class EnvFactory(ABC):
+    @abstractmethod
+    def create_envs(self) -> Environments:
+        pass
\ No newline at end of file
diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py
new file mode 100644
index 0000000..b175e42
--- /dev/null
+++ b/tianshou/highlevel/experiment.py
@@ -0,0 +1,76 @@
+from pprint import pprint
+from typing import Generic, TypeVar
+
+import numpy as np
+import torch
+
+from tianshou.config import BasicExperimentConfig, LoggerConfig, RLAgentConfig, RLSamplingConfig
+from tianshou.data import Collector
+from tianshou.highlevel.agent import AgentFactory
+from tianshou.highlevel.env import EnvFactory
+from tianshou.highlevel.logger import LoggerFactory
+from tianshou.policy import BasePolicy
+from tianshou.trainer import BaseTrainer
+
+TPolicy = TypeVar("TPolicy", bound=BasePolicy)
+TTrainer = TypeVar("TTrainer", bound=BaseTrainer)
+
+
+class RLExperiment(Generic[TPolicy, TTrainer]):
+    def __init__(self,
+                 config: BasicExperimentConfig,
+                 logger_config: LoggerConfig,
+                 general_config: RLAgentConfig,
+                 sampling_config: RLSamplingConfig,
+                 env_factory: EnvFactory,
+                 logger_factory: LoggerFactory,
+                 agent_factory: AgentFactory):
+        self.config = config
+        self.logger_config = logger_config
+        self.general_config = general_config
+        self.sampling_config = sampling_config
+        self.env_factory = env_factory
+        self.logger_factory = logger_factory
+        self.agent_factory = agent_factory
+
+    def _set_seed(self):
+        seed = self.config.seed
+        np.random.seed(seed)
+        torch.manual_seed(seed)
+
+    def _build_config_dict(self) -> dict:
+        return {
+            # TODO
+        }
+
+    def run(self, log_name: str):
+        self._set_seed()
+
+        envs = self.env_factory.create_envs()
+
+        full_config = self._build_config_dict()
+        full_config.update(envs.info())
+
+        run_id = self.config.resume_id
+        logger = self.logger_factory.create_logger(log_name=log_name, run_id=run_id, config_dict=full_config)
+
+        policy = self.agent_factory.create_policy(envs, self.config.device)
+        if self.config.resume_path:
+            self.agent_factory.load_checkpoint(policy, self.config.resume_path, envs, self.config.device)
+
+        train_collector, test_collector = self.agent_factory.create_train_test_collector(policy, envs)
+
+        if not self.config.watch:
+            trainer = self.agent_factory.create_trainer(policy, train_collector, test_collector, envs, logger)
+            result = trainer.run()
+            pprint(result)  # TODO logging
+
+        self._watch_agent(self.config.watch_num_episodes, policy, test_collector, self.config.render)
+
+    @staticmethod
+    def _watch_agent(num_episodes, policy: BasePolicy, test_collector: Collector, render):
+        policy.eval()
+        test_collector.reset()
+        result = test_collector.collect(n_episode=num_episodes, render=render)
+        print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
+
diff --git a/tianshou/highlevel/logger.py b/tianshou/highlevel/logger.py
new file mode 100644
index 0000000..c1c0c31
--- /dev/null
+++ b/tianshou/highlevel/logger.py
@@ -0,0 +1,49 @@
+from abc import ABC, abstractmethod
+import os
+from dataclasses import dataclass
+from typing import Union, Optional
+
+from torch.utils.tensorboard import SummaryWriter
+
+from tianshou.config import LoggerConfig
+from tianshou.utils import TensorboardLogger, WandbLogger
+
+
+TLogger = Union[TensorboardLogger, WandbLogger]
+
+
+@dataclass
+class Logger:
+    logger: TLogger
+    log_path: str
+
+
+class LoggerFactory(ABC):
+    @abstractmethod
+    def create_logger(self, log_name: str, run_id: Optional[int], config_dict: dict) -> Logger:
+        pass
+
+
+class DefaultLoggerFactory(LoggerFactory):
+    def __init__(self, config: LoggerConfig):
+        self.config = config
+
+    def create_logger(self, log_name: str, run_id: Optional[int], config_dict: dict) -> Logger:
+        writer = SummaryWriter(self.config.logdir)
+        writer.add_text("args", str(self.config))
+        if self.config.logger == "wandb":
+            logger = WandbLogger(
+                save_interval=1,
+                name=log_name.replace(os.path.sep, "__"),
+                run_id=run_id,
+                config=config_dict,
+                project=self.config.wandb_project,
+            )
+            logger.load(writer)
+        elif self.config.logger == "tensorboard":
+            logger = TensorboardLogger(writer)
+        else:
+            raise ValueError(f"Unknown logger: {self.config.logger}")
+        log_path = os.path.join(self.config.logdir, log_name)
+        os.makedirs(log_path, exist_ok=True)
+        return Logger(logger=logger, log_path=log_path)
diff --git a/tianshou/highlevel/module.py b/tianshou/highlevel/module.py
new file mode 100644
index 0000000..bf11934
--- /dev/null
+++ b/tianshou/highlevel/module.py
@@ -0,0 +1,90 @@
+from abc import abstractmethod, ABC
+from typing import Sequence
+
+import torch
+from torch import nn
+import numpy as np
+
+from tianshou.highlevel.env import Environments
+from tianshou.utils.net.common import Net
+from tianshou.utils.net.continuous import ActorProb, Critic as ContinuousCritic
+
+TDevice = str | int | torch.device
+
+
+def init_linear_orthogonal(m: torch.nn.Module):
+    """
+    Applies orthogonal initialization to linear layers of the given module and sets bias weights to 0.
+
+    :param m: the module whose submodules are to be processed
+    """
+    for m in m.modules():
+        if isinstance(m, torch.nn.Linear):
+            torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
+            torch.nn.init.zeros_(m.bias)
+
+
+class ActorFactory(ABC):
+    @abstractmethod
+    def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
+        pass
+
+    @staticmethod
+    def _init_linear(actor: torch.nn.Module):
+        """
+        Initializes linear layers of an actor module using default mechanisms.
+        :param actor: the actor module
+        """
+        init_linear_orthogonal(actor)
+        if hasattr(actor, "mu"):
+            # For continuous action spaces with Gaussian policies
+            # do last policy layer scaling, this will make initial actions have (close to)
+            # 0 mean and std, and will help boost performances,
+            # see https://arxiv.org/abs/2006.05990, Fig.24 for details
+            for m in actor.mu.modules():
+                if isinstance(m, torch.nn.Linear):
+                    m.weight.data.copy_(0.01 * m.weight.data)
+
+
+class ContinuousActorFactory(ActorFactory, ABC):
+    pass
+
+
+class ContinuousActorProbFactory(ContinuousActorFactory):
+    def __init__(self, hidden_sizes: Sequence[int]):
+        self.hidden_sizes = hidden_sizes
+
+    def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
+        net_a = Net(
+            envs.get_state_shape(), hidden_sizes=self.hidden_sizes, activation=nn.Tanh, device=device
+        )
+        actor = ActorProb(net_a, envs.get_action_shape(), unbounded=True, device=device).to(device)
+
+        # init params
+        torch.nn.init.constant_(actor.sigma_param, -0.5)
+        self._init_linear(actor)
+
+        return actor
+
+
+class CriticFactory(ABC):
+    @abstractmethod
+    def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
+        pass
+
+
+class ContinuousCriticFactory(CriticFactory, ABC):
+    pass
+
+
+class ContinuousNetCriticFactory(ContinuousCriticFactory):
+    def __init__(self, hidden_sizes: Sequence[int]):
+        self.hidden_sizes = hidden_sizes
+
+    def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
+        net_c = Net(
+            envs.get_state_shape(), hidden_sizes=self.hidden_sizes, activation=nn.Tanh, device=device
+        )
+        critic = ContinuousCritic(net_c, device=device).to(device)
+        init_linear_orthogonal(critic)
+        return critic
diff --git a/tianshou/highlevel/optim.py b/tianshou/highlevel/optim.py
new file mode 100644
index 0000000..43cbb53
--- /dev/null
+++ b/tianshou/highlevel/optim.py
@@ -0,0 +1,54 @@
+from abc import ABC, abstractmethod
+from typing import Union, Iterable, Dict, Any, Optional
+
+import numpy as np
+import torch
+from torch import Tensor
+from torch.optim import Adam
+from torch.optim.lr_scheduler import LRScheduler, LambdaLR
+
+from tianshou.config import RLSamplingConfig, NNConfig
+
+TParams = Union[Iterable[Tensor], Iterable[Dict[str, Any]]]
+
+
+class OptimizerFactory(ABC):
+    @abstractmethod
+    def create_optimizer(self, module: torch.nn.Module) -> torch.optim.Optimizer:
+        pass
+
+
+class TorchOptimizerFactory(OptimizerFactory):
+    def __init__(self, optim_class, **kwargs):
+        self.optim_class = optim_class
+        self.kwargs = kwargs
+
+    def create_optimizer(self, module: torch.nn.Module) -> torch.optim.Optimizer:
+        return self.optim_class(module.parameters(), **self.kwargs)
+
+
+class AdamOptimizerFactory(OptimizerFactory):
+    def __init__(self, lr):
+        self.lr = lr
+
+    def create_optimizer(self, module: torch.nn.Module) -> Adam:
+        return Adam(module.parameters(), lr=self.lr)
+
+
+class LRSchedulerFactory(ABC):
+    @abstractmethod
+    def create_scheduler(self, optim: torch.optim.Optimizer) -> Optional[LRScheduler]:
+        pass
+
+
+class LinearLRSchedulerFactory(LRSchedulerFactory):
+    def __init__(self, nn_config: NNConfig, sampling_config: RLSamplingConfig):
+        self.nn_config = nn_config
+        self.sampling_config = sampling_config
+
+    def create_scheduler(self, optim: torch.optim.Optimizer) -> Optional[LRScheduler]:
+        lr_scheduler = None
+        if self.nn_config.lr_decay:
+            max_update_num = np.ceil(self.sampling_config.step_per_epoch / self.sampling_config.step_per_collect) * self.sampling_config.num_epochs
+            lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
+        return lr_scheduler
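
Usage sketch (not part of the patch): the EnvFactory/ContinuousEnvironments abstractions added in tianshou/highlevel/env.py are the intended extension point for tasks other than MuJoCo. The class below is a hypothetical illustration that mirrors MujocoEnvFactory but builds plain Gymnasium environments without observation normalization; the name SimpleContinuousEnvFactory and its constructor arguments are illustrative only.

    import gymnasium as gym

    from tianshou.env import ShmemVectorEnv
    from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory


    class SimpleContinuousEnvFactory(EnvFactory):
        """Hypothetical factory: wraps a Gymnasium task in the new high-level Environments API."""

        def __init__(self, task: str, num_train_envs: int, num_test_envs: int):
            self.task = task
            self.num_train_envs = num_train_envs
            self.num_test_envs = num_test_envs

        def create_envs(self) -> ContinuousEnvironments:
            # One plain env for shape/space inspection, vectorized envs for collection.
            env = gym.make(self.task)
            train_envs = ShmemVectorEnv([lambda: gym.make(self.task) for _ in range(self.num_train_envs)])
            test_envs = ShmemVectorEnv([lambda: gym.make(self.task) for _ in range(self.num_test_envs)])
            return ContinuousEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)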
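
A second sketch (also not part of the patch): because PPOAgentFactory depends only on the OptimizerFactory interface, the generic TorchOptimizerFactory from tianshou/highlevel/optim.py can stand in for AdamOptimizerFactory to use any torch.optim class; the AdamW hyperparameters below are illustrative.

    import torch

    from tianshou.highlevel.optim import TorchOptimizerFactory

    # Drop-in replacement for AdamOptimizerFactory(lr=...) in mujoco_ppo_hl.py;
    # create_optimizer() will call torch.optim.AdamW(module.parameters(), **kwargs).
    optim_factory = TorchOptimizerFactory(torch.optim.AdamW, lr=3e-4, weight_decay=1e-2)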