diff --git a/examples/mujoco/mujoco_env.py b/examples/mujoco/mujoco_env.py
index 15d9393..b4e4d44 100644
--- a/examples/mujoco/mujoco_env.py
+++ b/examples/mujoco/mujoco_env.py
@@ -2,9 +2,9 @@ import warnings
 
 import gymnasium as gym
 
-from tianshou.config import BasicExperimentConfig, RLSamplingConfig
 from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
 from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory
+from tianshou.highlevel.experiment import RLSamplingConfig
 
 try:
     import envpool
@@ -41,14 +41,15 @@ def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: in
 
 
 class MujocoEnvFactory(EnvFactory):
-    def __init__(self, experiment_config: BasicExperimentConfig, sampling_config: RLSamplingConfig):
+    def __init__(self, task: str, seed: int, sampling_config: RLSamplingConfig):
+        self.task = task
         self.sampling_config = sampling_config
-        self.experiment_config = experiment_config
+        self.seed = seed
 
     def create_envs(self) -> ContinuousEnvironments:
         env, train_envs, test_envs = make_mujoco_env(
-            task=self.experiment_config.task,
-            seed=self.experiment_config.seed,
+            task=self.task,
+            seed=self.seed,
             num_train_envs=self.sampling_config.num_train_envs,
             num_test_envs=self.sampling_config.num_test_envs,
             obs_norm=True,
diff --git a/examples/mujoco/mujoco_ppo_hl.py b/examples/mujoco/mujoco_ppo_hl.py
index 7cba814..570dc61 100644
--- a/examples/mujoco/mujoco_ppo_hl.py
+++ b/examples/mujoco/mujoco_ppo_hl.py
@@ -9,17 +9,13 @@ from jsonargparse import CLI
 from torch.distributions import Independent, Normal
 
 from examples.mujoco.mujoco_env import MujocoEnvFactory
-from tianshou.config import (
-    BasicExperimentConfig,
-    LoggerConfig,
-    PGConfig,
-    PPOConfig,
-    RLAgentConfig,
+from tianshou.highlevel.agent import PGConfig, PPOAgentFactory, PPOConfig, RLAgentConfig
+from tianshou.highlevel.experiment import (
+    RLExperiment,
+    RLExperimentConfig,
     RLSamplingConfig,
 )
-from tianshou.highlevel.agent import PPOAgentFactory
-from tianshou.highlevel.experiment import RLExperiment
-from tianshou.highlevel.logger import DefaultLoggerFactory
+from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerConfig
 from tianshou.highlevel.module import (
     ContinuousActorProbFactory,
     ContinuousNetCriticFactory,
@@ -35,19 +31,20 @@ class NNConfig:
 
 
 def main(
-    experiment_config: BasicExperimentConfig,
+    experiment_config: RLExperimentConfig,
     logger_config: LoggerConfig,
     sampling_config: RLSamplingConfig,
     general_config: RLAgentConfig,
     pg_config: PGConfig,
     ppo_config: PPOConfig,
     nn_config: NNConfig,
+    task: str = "Ant-v4",
 ):
     now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
-    log_name = os.path.join(experiment_config.task, "ppo", str(experiment_config.seed), now)
+    log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
     logger_factory = DefaultLoggerFactory(logger_config)
 
-    env_factory = MujocoEnvFactory(experiment_config, sampling_config)
+    env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
 
     def dist_fn(*logits):
         return Independent(Normal(*logits), 1)
diff --git a/examples/mujoco/mujoco_sac_hl.py b/examples/mujoco/mujoco_sac_hl.py
index 24a8f20..0d9ef4b 100644
--- a/examples/mujoco/mujoco_sac_hl.py
+++ b/examples/mujoco/mujoco_sac_hl.py
@@ -7,14 +7,13 @@ from collections.abc import Sequence
 from jsonargparse import CLI
 
 from examples.mujoco.mujoco_env import MujocoEnvFactory
-from tianshou.config import (
-    BasicExperimentConfig,
-    LoggerConfig,
+from tianshou.highlevel.agent import SACAgentFactory, SACConfig
+from tianshou.highlevel.experiment import (
+    RLExperiment,
+    RLExperimentConfig,
     RLSamplingConfig,
 )
-from tianshou.highlevel.agent import SACAgentFactory
-from tianshou.highlevel.experiment import RLExperiment
-from tianshou.highlevel.logger import DefaultLoggerFactory
+from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerConfig
 from tianshou.highlevel.module import (
     ContinuousActorProbFactory,
     ContinuousNetCriticFactory,
@@ -23,17 +22,18 @@ from tianshou.highlevel.optim import AdamOptimizerFactory
 
 
 def main(
-    experiment_config: BasicExperimentConfig,
+    experiment_config: RLExperimentConfig,
     logger_config: LoggerConfig,
     sampling_config: RLSamplingConfig,
-    sac_config: SACAgentFactory.Config,
+    sac_config: SACConfig,
     hidden_sizes: Sequence[int] = (256, 256),
+    task: str = "Ant-v4",
 ):
     now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
-    log_name = os.path.join(experiment_config.task, "sac", str(experiment_config.seed), now)
+    log_name = os.path.join(task, "sac", str(experiment_config.seed), now)
     logger_factory = DefaultLoggerFactory(logger_config)
 
-    env_factory = MujocoEnvFactory(experiment_config, sampling_config)
+    env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
 
     actor_factory = ContinuousActorProbFactory(hidden_sizes, conditioned_sigma=True)
     critic_factory = ContinuousNetCriticFactory(hidden_sizes)
diff --git a/tianshou/config/__init__.py b/tianshou/config/__init__.py
deleted file mode 100644
index 9f9107c..0000000
--- a/tianshou/config/__init__.py
+++ /dev/null
@@ -1,10 +0,0 @@
-__all__ = ["PGConfig", "PPOConfig", "RLAgentConfig", "RLSamplingConfig", "BasicExperimentConfig", "LoggerConfig"]
-
-from .config import (
-    BasicExperimentConfig,
-    PGConfig,
-    PPOConfig,
-    RLAgentConfig,
-    RLSamplingConfig,
-    LoggerConfig,
-)
diff --git a/tianshou/config/config.py b/tianshou/config/config.py
deleted file mode 100644
index 25ea63a..0000000
--- a/tianshou/config/config.py
+++ /dev/null
@@ -1,86 +0,0 @@
-from dataclasses import dataclass
-from typing import Literal
-
-import torch
-from jsonargparse import set_docstring_parse_options
-
-set_docstring_parse_options(attribute_docstrings=True)
-
-
-@dataclass
-class BasicExperimentConfig:
-    """Generic config for setting up the experiment, not RL or training specific."""
-
-    seed: int = 42
-    task: str = "Ant-v4"
-    """Mujoco specific"""
-    render: float | None = 0.0
-    """Milliseconds between rendered frames; if None, no rendering"""
-    device: str = "cuda" if torch.cuda.is_available() else "cpu"
-    resume_id: str | None = None
-    """For restoring a model and running means of env-specifics from a checkpoint"""
-    resume_path: str | None = None
-    """For restoring a model and running means of env-specifics from a checkpoint"""
-    watch: bool = False
-    """If True, will not perform training and only watch the restored policy"""
-    watch_num_episodes = 10
-
-
-@dataclass
-class LoggerConfig:
-    """Logging config."""
-
-    logdir: str = "log"
-    logger: Literal["tensorboard", "wandb"] = "tensorboard"
-    wandb_project: str = "mujoco.benchmark"
-    """Only used if logger is wandb."""
-
-
-@dataclass
-class RLSamplingConfig:
-    """Sampling, epochs, parallelization, buffers, collectors, and batching."""
-
-    num_epochs: int = 100
-    step_per_epoch: int = 30000
-    batch_size: int = 64
-    num_train_envs: int = 64
-    num_test_envs: int = 10
-    buffer_size: int = 4096
-    step_per_collect: int = 2048
-    repeat_per_collect: int = 10
-    update_per_step: int = 1
-
-
-@dataclass
-class RLAgentConfig:
-    """Config common to most RL algorithms."""
-
-    gamma: float = 0.99
-    """Discount factor"""
-    gae_lambda: float = 0.95
-    """For Generalized Advantage Estimate (equivalent to TD(lambda))"""
-    action_bound_method: Literal["clip", "tanh"] | None = "clip"
-    """How to map original actions in range (-inf, inf) to [-1, 1]"""
-    rew_norm: bool = True
-    """Whether to normalize rewards"""
-
-
-@dataclass
-class PGConfig:
-    """Config of general policy-gradient algorithms."""
-
-    ent_coef: float = 0.0
-    vf_coef: float = 0.25
-    max_grad_norm: float = 0.5
-
-
-@dataclass
-class PPOConfig:
-    """PPO specific config."""
-
-    value_clip: bool = False
-    norm_adv: bool = False
-    """Whether to normalize advantages"""
-    eps_clip: float = 0.2
-    dual_clip: float | None = None
-    recompute_adv: bool = True
diff --git a/tianshou/highlevel/__init__.py b/tianshou/highlevel/__init__.py
index e69de29..8ce1671 100644
--- a/tianshou/highlevel/__init__.py
+++ b/tianshou/highlevel/__init__.py
@@ -0,0 +1,3 @@
+from jsonargparse import set_docstring_parse_options
+
+set_docstring_parse_options(attribute_docstrings=True)
diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py
index e95c799..7c8bede 100644
--- a/tianshou/highlevel/agent.py
+++ b/tianshou/highlevel/agent.py
@@ -2,13 +2,14 @@ import os
 from abc import ABC, abstractmethod
 from collections.abc import Callable
 from dataclasses import dataclass
+from typing import Literal
 
 import torch
 
-from tianshou.config import PGConfig, PPOConfig, RLAgentConfig, RLSamplingConfig
 from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
 from tianshou.exploration import BaseNoise
 from tianshou.highlevel.env import Environments
+from tianshou.highlevel.experiment import RLSamplingConfig
 from tianshou.highlevel.logger import Logger
 from tianshou.highlevel.module import ActorFactory, CriticFactory, TDevice
 from tianshou.highlevel.optim import LRSchedulerFactory, OptimizerFactory
@@ -124,6 +125,41 @@ class OffpolicyAgentFactory(AgentFactory, ABC):
         )
 
 
+@dataclass
+class RLAgentConfig:
+    """Config common to most RL algorithms."""
+
+    gamma: float = 0.99
+    """Discount factor"""
+    gae_lambda: float = 0.95
+    """For Generalized Advantage Estimate (equivalent to TD(lambda))"""
+    action_bound_method: Literal["clip", "tanh"] | None = "clip"
+    """How to map original actions in range (-inf, inf) to [-1, 1]"""
+    rew_norm: bool = True
+    """Whether to normalize rewards"""
+
+
+@dataclass
+class PGConfig:
+    """Config of general policy-gradient algorithms."""
+
+    ent_coef: float = 0.0
+    vf_coef: float = 0.25
+    max_grad_norm: float = 0.5
+
+
+@dataclass
+class PPOConfig:
+    """PPO specific config."""
+
+    value_clip: bool = False
+    norm_adv: bool = False
+    """Whether to normalize advantages"""
+    eps_clip: float = 0.2
+    dual_clip: float | None = None
+    recompute_adv: bool = True
+
+
 class PPOAgentFactory(OnpolicyAgentFactory):
     def __init__(
         self,
@@ -186,10 +222,22 @@ class PPOAgentFactory(OnpolicyAgentFactory):
         )
 
 
+class SACConfig:
+    tau: float = 0.005
+    gamma: float = 0.99
+    alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] = 0.2
+    reward_normalization: bool = False
+    estimation_step: int = 1
+    deterministic_eval: bool = True
+    actor_lr: float = 1e-3
+    critic1_lr: float = 1e-3
+    critic2_lr: float = 1e-3
+
+
 class SACAgentFactory(OffpolicyAgentFactory):
     def __init__(
         self,
-        config: "SACAgentFactory.Config",
+        config: SACConfig,
         sampling_config: RLSamplingConfig,
         actor_factory: ActorFactory,
         critic1_factory: CriticFactory,
@@ -227,17 +275,3 @@
             deterministic_eval=self.config.deterministic_eval,
             exploration_noise=self.exploration_noise,
         )
-
-    @dataclass
-    class Config:
-        """SAC configuration."""
-
-        tau: float = 0.005
-        gamma: float = 0.99
-        alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] = 0.2
-        reward_normalization: bool = False
-        estimation_step: int = 1
-        deterministic_eval: bool = True
-        actor_lr: float = 1e-3
-        critic1_lr: float = 1e-3
-        critic2_lr: float = 1e-3
diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py
index 19f0e2c..d4527f2 100644
--- a/tianshou/highlevel/experiment.py
+++ b/tianshou/highlevel/experiment.py
@@ -1,12 +1,10 @@
+from dataclasses import dataclass
 from pprint import pprint
 from typing import Generic, TypeVar
 
 import numpy as np
 import torch
 
-from tianshou.config import (
-    BasicExperimentConfig,
-)
 from tianshou.data import Collector
 from tianshou.highlevel.agent import AgentFactory
 from tianshou.highlevel.env import EnvFactory
@@ -18,10 +16,42 @@ TPolicy = TypeVar("TPolicy", bound=BasePolicy)
 TTrainer = TypeVar("TTrainer", bound=BaseTrainer)
 
 
+@dataclass
+class RLExperimentConfig:
+    """Generic config for setting up the experiment, not RL or training specific."""
+
+    seed: int = 42
+    render: float | None = 0.0
+    """Milliseconds between rendered frames; if None, no rendering"""
+    device: str = "cuda" if torch.cuda.is_available() else "cpu"
+    resume_id: str | None = None
+    """For restoring a model and running means of env-specifics from a checkpoint"""
+    resume_path: str | None = None
+    """For restoring a model and running means of env-specifics from a checkpoint"""
+    watch: bool = False
+    """If True, will not perform training and only watch the restored policy"""
+    watch_num_episodes = 10
+
+
+@dataclass
+class RLSamplingConfig:
+    """Sampling, epochs, parallelization, buffers, collectors, and batching."""
+
+    num_epochs: int = 100
+    step_per_epoch: int = 30000
+    batch_size: int = 64
+    num_train_envs: int = 64
+    num_test_envs: int = 10
+    buffer_size: int = 4096
+    step_per_collect: int = 2048
+    repeat_per_collect: int = 10
+    update_per_step: int = 1
+
+
 class RLExperiment(Generic[TPolicy, TTrainer]):
     def __init__(
         self,
-        config: BasicExperimentConfig,
+        config: RLExperimentConfig,
         env_factory: EnvFactory,
         logger_factory: LoggerFactory,
         agent_factory: AgentFactory,
diff --git a/tianshou/highlevel/logger.py b/tianshou/highlevel/logger.py
index 098072b..06bd195 100644
--- a/tianshou/highlevel/logger.py
+++ b/tianshou/highlevel/logger.py
@@ -1,10 +1,10 @@
 import os
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
+from typing import Literal
 
 from torch.utils.tensorboard import SummaryWriter
 
-from tianshou.config import LoggerConfig
 from tianshou.utils import TensorboardLogger, WandbLogger
 
 TLogger = TensorboardLogger | WandbLogger
@@ -22,11 +22,21 @@ class LoggerFactory(ABC):
         pass
 
 
+@dataclass
+class LoggerConfig:
+    """Logging config."""
+
+    logdir: str = "log"
+    logger: Literal["tensorboard", "wandb"] = "tensorboard"
+    wandb_project: str = "mujoco.benchmark"
+    """Only used if logger is wandb."""
+
+
 class DefaultLoggerFactory(LoggerFactory):
     def __init__(self, config: LoggerConfig):
         self.config = config
 
-    def create_logger(self, log_name: str, run_id: int | None, config_dict: dict) -> Logger:
+    def create_logger(self, log_name: str, run_id: str | None, config_dict: dict) -> Logger:
         writer = SummaryWriter(self.config.logdir)
         writer.add_text("args", str(self.config))
         if self.config.logger == "wandb":
diff --git a/tianshou/highlevel/module.py b/tianshou/highlevel/module.py
index e85a4db..8686218 100644
--- a/tianshou/highlevel/module.py
+++ b/tianshou/highlevel/module.py
@@ -66,7 +66,7 @@ class ContinuousActorProbFactory(ContinuousActorFactory):
         actor = ActorProb(
             net_a,
             envs.get_action_shape(),
-            unbounded=True,
+            unbounded=self.unbounded,
             device=device,
             conditioned_sigma=self.conditioned_sigma,
         ).to(device)
diff --git a/tianshou/highlevel/optim.py b/tianshou/highlevel/optim.py
index 5ac660a..4e68104 100644
--- a/tianshou/highlevel/optim.py
+++ b/tianshou/highlevel/optim.py
@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 from collections.abc import Iterable
-from typing import Any, Type
+from typing import Any
 
 import numpy as np
 import torch
@@ -8,7 +8,7 @@ from torch import Tensor
 from torch.optim import Adam
 from torch.optim.lr_scheduler import LambdaLR, LRScheduler
 
-from tianshou.config import RLSamplingConfig
+from tianshou.highlevel.experiment import RLSamplingConfig
 
 TParams = Iterable[Tensor] | Iterable[dict[str, Any]]
 
diff --git a/tianshou/config/utils.py b/tianshou/highlevel/utils.py
similarity index 100%
rename from tianshou/config/utils.py
rename to tianshou/highlevel/utils.py
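
Usage sketch: a minimal illustration of how the relocated classes are wired together after this refactor, mirroring the updated mujoco_sac_hl.py above. The imports, class names, and constructor signatures follow the diff; the concrete values and the trivial create_logger arguments are illustrative assumptions, not part of the change.

import datetime
import os

from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.experiment import RLExperimentConfig, RLSamplingConfig
from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerConfig

# Configs now live under tianshou.highlevel.* instead of the removed tianshou.config package.
experiment_config = RLExperimentConfig(seed=42)
sampling_config = RLSamplingConfig(num_train_envs=64, num_test_envs=10)
logger_config = LoggerConfig(logdir="log")

# The task is now an explicit argument rather than a field of the experiment config.
task = "Ant-v4"
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
log_name = os.path.join(task, "sac", str(experiment_config.seed), now)

logger_factory = DefaultLoggerFactory(logger_config)
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)

# create_logger now takes a str | None run_id, matching resume_id on RLExperimentConfig.
logger = logger_factory.create_logger(log_name, run_id=experiment_config.resume_id, config_dict={})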