diff --git a/examples/atari/atari_dqn_hl.py b/examples/atari/atari_dqn_hl.py index 47916dc..13b49d6 100644 --- a/examples/atari/atari_dqn_hl.py +++ b/examples/atari/atari_dqn_hl.py @@ -10,10 +10,10 @@ from examples.atari.atari_network import ( FeatureNetFactoryDQN, ) from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback -from tianshou.highlevel.config import RLSamplingConfig +from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.experiment import ( DQNExperimentBuilder, - RLExperimentConfig, + ExperimentConfig, ) from tianshou.highlevel.params.policy_params import DQNParams from tianshou.highlevel.params.policy_wrapper import ( @@ -25,7 +25,7 @@ from tianshou.utils import logging def main( - experiment_config: RLExperimentConfig, + experiment_config: ExperimentConfig, task: str = "PongNoFrameskip-v4", scale_obs: int = 0, eps_test: float = 0.005, @@ -52,7 +52,7 @@ def main( now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") log_name = os.path.join(task, "ppo", str(experiment_config.seed), now) - sampling_config = RLSamplingConfig( + sampling_config = SamplingConfig( num_epochs=epoch, step_per_epoch=step_per_epoch, batch_size=batch_size, diff --git a/examples/atari/atari_ppo_hl.py b/examples/atari/atari_ppo_hl.py index 3179934..1f785e8 100644 --- a/examples/atari/atari_ppo_hl.py +++ b/examples/atari/atari_ppo_hl.py @@ -11,10 +11,10 @@ from examples.atari.atari_network import ( FeatureNetFactoryDQN, ) from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback -from tianshou.highlevel.config import RLSamplingConfig +from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.experiment import ( PPOExperimentBuilder, - RLExperimentConfig, + ExperimentConfig, ) from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear from tianshou.highlevel.params.policy_params import PPOParams @@ -25,7 +25,7 @@ from tianshou.utils import logging def main( - experiment_config: RLExperimentConfig, + experiment_config: ExperimentConfig, task: str = "PongNoFrameskip-v4", scale_obs: bool = True, buffer_size: int = 100000, @@ -59,7 +59,7 @@ def main( now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") log_name = os.path.join(task, "ppo", str(experiment_config.seed), now) - sampling_config = RLSamplingConfig( + sampling_config = SamplingConfig( num_epochs=epoch, step_per_epoch=step_per_epoch, batch_size=batch_size, diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py index 13384d6..a44fbb8 100644 --- a/examples/atari/atari_wrapper.py +++ b/examples/atari/atari_wrapper.py @@ -9,7 +9,7 @@ import gymnasium as gym import numpy as np from tianshou.env import ShmemVectorEnv -from tianshou.highlevel.config import RLSamplingConfig +from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.env import DiscreteEnvironments, EnvFactory from tianshou.highlevel.trainer import TrainerStopCallback, TrainingContext @@ -379,7 +379,7 @@ class AtariEnvFactory(EnvFactory): self, task: str, seed: int, - sampling_config: RLSamplingConfig, + sampling_config: SamplingConfig, frame_stack: int, scale: int = 0, ): diff --git a/examples/mujoco/mujoco_a2c_hl.py b/examples/mujoco/mujoco_a2c_hl.py index a6f48f6..7501fae 100644 --- a/examples/mujoco/mujoco_a2c_hl.py +++ b/examples/mujoco/mujoco_a2c_hl.py @@ -8,10 +8,10 @@ from typing import Literal from jsonargparse import CLI from examples.mujoco.mujoco_env import MujocoEnvFactory -from tianshou.highlevel.config import RLSamplingConfig +from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.experiment import ( A2CExperimentBuilder, - RLExperimentConfig, + ExperimentConfig, ) from tianshou.highlevel.optim import OptimizerFactoryRMSprop from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear @@ -20,7 +20,7 @@ from tianshou.utils import logging def main( - experiment_config: RLExperimentConfig, + experiment_config: ExperimentConfig, task: str = "Ant-v3", buffer_size: int = 4096, hidden_sizes: Sequence[int] = (64, 64), @@ -44,7 +44,7 @@ def main( now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") log_name = os.path.join(task, "ppo", str(experiment_config.seed), now) - sampling_config = RLSamplingConfig( + sampling_config = SamplingConfig( num_epochs=epoch, step_per_epoch=step_per_epoch, batch_size=batch_size, diff --git a/examples/mujoco/mujoco_ddpg_hl.py b/examples/mujoco/mujoco_ddpg_hl.py index 0b173b4..5944f9b 100644 --- a/examples/mujoco/mujoco_ddpg_hl.py +++ b/examples/mujoco/mujoco_ddpg_hl.py @@ -7,10 +7,10 @@ from collections.abc import Sequence from jsonargparse import CLI from examples.mujoco.mujoco_env import MujocoEnvFactory -from tianshou.highlevel.config import RLSamplingConfig +from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.experiment import ( DDPGExperimentBuilder, - RLExperimentConfig, + ExperimentConfig, ) from tianshou.highlevel.params.noise import MaxActionScaledGaussian from tianshou.highlevel.params.policy_params import DDPGParams @@ -18,7 +18,7 @@ from tianshou.utils import logging def main( - experiment_config: RLExperimentConfig, + experiment_config: ExperimentConfig, task: str = "Ant-v3", buffer_size: int = 1000000, hidden_sizes: Sequence[int] = (256, 256), @@ -40,7 +40,7 @@ def main( now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") log_name = os.path.join(task, "ppo", str(experiment_config.seed), now) - sampling_config = RLSamplingConfig( + sampling_config = SamplingConfig( num_epochs=epoch, step_per_epoch=step_per_epoch, batch_size=batch_size, diff --git a/examples/mujoco/mujoco_env.py b/examples/mujoco/mujoco_env.py index ded54c3..620d6dd 100644 --- a/examples/mujoco/mujoco_env.py +++ b/examples/mujoco/mujoco_env.py @@ -3,7 +3,7 @@ import warnings import gymnasium as gym from tianshou.env import ShmemVectorEnv, VectorEnvNormObs -from tianshou.highlevel.config import RLSamplingConfig +from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory try: @@ -41,7 +41,7 @@ def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: in class MujocoEnvFactory(EnvFactory): - def __init__(self, task: str, seed: int, sampling_config: RLSamplingConfig): + def __init__(self, task: str, seed: int, sampling_config: SamplingConfig): self.task = task self.sampling_config = sampling_config self.seed = seed diff --git a/examples/mujoco/mujoco_ppo_hl.py b/examples/mujoco/mujoco_ppo_hl.py index 3272b0f..33d3edc 100644 --- a/examples/mujoco/mujoco_ppo_hl.py +++ b/examples/mujoco/mujoco_ppo_hl.py @@ -9,10 +9,10 @@ from jsonargparse import CLI from torch.distributions import Independent, Normal from examples.mujoco.mujoco_env import MujocoEnvFactory -from tianshou.highlevel.config import RLSamplingConfig +from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.experiment import ( PPOExperimentBuilder, - RLExperimentConfig, + ExperimentConfig, ) from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear from tianshou.highlevel.params.policy_params import PPOParams @@ -20,7 +20,7 @@ from tianshou.utils import logging def main( - experiment_config: RLExperimentConfig, + experiment_config: ExperimentConfig, task: str = "Ant-v4", buffer_size: int = 4096, hidden_sizes: Sequence[int] = (64, 64), @@ -49,7 +49,7 @@ def main( now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") log_name = os.path.join(task, "ppo", str(experiment_config.seed), now) - sampling_config = RLSamplingConfig( + sampling_config = SamplingConfig( num_epochs=epoch, step_per_epoch=step_per_epoch, batch_size=batch_size, diff --git a/examples/mujoco/mujoco_sac_hl.py b/examples/mujoco/mujoco_sac_hl.py index 1996689..852c6f5 100644 --- a/examples/mujoco/mujoco_sac_hl.py +++ b/examples/mujoco/mujoco_sac_hl.py @@ -7,9 +7,9 @@ from collections.abc import Sequence from jsonargparse import CLI from examples.mujoco.mujoco_env import MujocoEnvFactory -from tianshou.highlevel.config import RLSamplingConfig +from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.experiment import ( - RLExperimentConfig, + ExperimentConfig, SACExperimentBuilder, ) from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault @@ -18,7 +18,7 @@ from tianshou.utils import logging def main( - experiment_config: RLExperimentConfig, + experiment_config: ExperimentConfig, task: str = "Ant-v3", buffer_size: int = 1000000, hidden_sizes: Sequence[int] = (256, 256), @@ -42,7 +42,7 @@ def main( now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") log_name = os.path.join(task, "sac", str(experiment_config.seed), now) - sampling_config = RLSamplingConfig( + sampling_config = SamplingConfig( num_epochs=epoch, step_per_epoch=step_per_epoch, num_train_envs=training_num, diff --git a/examples/mujoco/mujoco_td3_hl.py b/examples/mujoco/mujoco_td3_hl.py index f2c906a..28a211a 100644 --- a/examples/mujoco/mujoco_td3_hl.py +++ b/examples/mujoco/mujoco_td3_hl.py @@ -7,9 +7,9 @@ from collections.abc import Sequence from jsonargparse import CLI from examples.mujoco.mujoco_env import MujocoEnvFactory -from tianshou.highlevel.config import RLSamplingConfig +from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.experiment import ( - RLExperimentConfig, + ExperimentConfig, TD3ExperimentBuilder, ) from tianshou.highlevel.params.env_param import MaxActionScaled @@ -21,7 +21,7 @@ from tianshou.utils import logging def main( - experiment_config: RLExperimentConfig, + experiment_config: ExperimentConfig, task: str = "Ant-v3", buffer_size: int = 1000000, hidden_sizes: Sequence[int] = (256, 256), @@ -46,7 +46,7 @@ def main( now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") log_name = os.path.join(task, "td3", str(experiment_config.seed), now) - sampling_config = RLSamplingConfig( + sampling_config = SamplingConfig( num_epochs=epoch, step_per_epoch=step_per_epoch, num_train_envs=training_num, diff --git a/test/highlevel/test_continuous.py b/test/highlevel/test_continuous.py index bc394d6..0934c45 100644 --- a/test/highlevel/test_continuous.py +++ b/test/highlevel/test_continuous.py @@ -2,12 +2,12 @@ from test.highlevel.env_factory import ContinuousTestEnvFactory import pytest -from tianshou.highlevel.config import RLSamplingConfig +from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.experiment import ( A2CExperimentBuilder, DDPGExperimentBuilder, PPOExperimentBuilder, - RLExperimentConfig, + ExperimentConfig, SACExperimentBuilder, TD3ExperimentBuilder, ) @@ -25,8 +25,8 @@ from tianshou.highlevel.experiment import ( ) def test_experiment_builder_continuous_default_params(builder_cls): env_factory = ContinuousTestEnvFactory() - sampling_config = RLSamplingConfig(num_epochs=1, step_per_epoch=100) - experiment_config = RLExperimentConfig() + sampling_config = SamplingConfig(num_epochs=1, step_per_epoch=100) + experiment_config = ExperimentConfig() builder = builder_cls( experiment_config=experiment_config, env_factory=env_factory, diff --git a/test/highlevel/test_discrete.py b/test/highlevel/test_discrete.py index b242b85..ff593df 100644 --- a/test/highlevel/test_discrete.py +++ b/test/highlevel/test_discrete.py @@ -2,12 +2,12 @@ from test.highlevel.env_factory import DiscreteTestEnvFactory import pytest -from tianshou.highlevel.config import RLSamplingConfig +from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.experiment import ( A2CExperimentBuilder, DQNExperimentBuilder, PPOExperimentBuilder, - RLExperimentConfig, + ExperimentConfig, ) @@ -17,9 +17,9 @@ from tianshou.highlevel.experiment import ( ) def test_experiment_builder_discrete_default_params(builder_cls): env_factory = DiscreteTestEnvFactory() - sampling_config = RLSamplingConfig(num_epochs=1, step_per_epoch=100) + sampling_config = SamplingConfig(num_epochs=1, step_per_epoch=100) builder = builder_cls( - experiment_config=RLExperimentConfig(), + experiment_config=ExperimentConfig(), env_factory=env_factory, sampling_config=sampling_config, ) diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py index af0fea9..6fa308c 100644 --- a/tianshou/highlevel/agent.py +++ b/tianshou/highlevel/agent.py @@ -6,7 +6,7 @@ import gymnasium import torch from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer -from tianshou.highlevel.config import RLSamplingConfig +from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.env import Environments from tianshou.highlevel.logger import Logger from tianshou.highlevel.module.actor import ( @@ -54,7 +54,7 @@ TPolicy = TypeVar("TPolicy", bound=BasePolicy) class AgentFactory(ABC, ToStringMixin): - def __init__(self, sampling_config: RLSamplingConfig, optim_factory: OptimizerFactory): + def __init__(self, sampling_config: SamplingConfig, optim_factory: OptimizerFactory): self.sampling_config = sampling_config self.optim_factory = optim_factory self.policy_wrapper_factory: PolicyWrapperFactory | None = None @@ -352,7 +352,7 @@ class ActorCriticAgentFactory( def __init__( self, params: TParams, - sampling_config: RLSamplingConfig, + sampling_config: SamplingConfig, actor_factory: ActorFactory, critic_factory: CriticFactory, optimizer_factory: OptimizerFactory, @@ -399,7 +399,7 @@ class A2CAgentFactory(ActorCriticAgentFactory[A2CParams, A2CPolicy]): def __init__( self, params: A2CParams, - sampling_config: RLSamplingConfig, + sampling_config: SamplingConfig, actor_factory: ActorFactory, critic_factory: CriticFactory, optimizer_factory: OptimizerFactory, @@ -423,7 +423,7 @@ class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]): def __init__( self, params: PPOParams, - sampling_config: RLSamplingConfig, + sampling_config: SamplingConfig, actor_factory: ActorFactory, critic_factory: CriticFactory, optimizer_factory: OptimizerFactory, @@ -447,7 +447,7 @@ class DQNAgentFactory(OffpolicyAgentFactory): def __init__( self, params: DQNParams, - sampling_config: RLSamplingConfig, + sampling_config: SamplingConfig, actor_factory: ActorFactory, optim_factory: OptimizerFactory, ): @@ -483,7 +483,7 @@ class DDPGAgentFactory(OffpolicyAgentFactory, _ActorAndCriticMixin): def __init__( self, params: DDPGParams, - sampling_config: RLSamplingConfig, + sampling_config: SamplingConfig, actor_factory: ActorFactory, critic_factory: CriticFactory, optim_factory: OptimizerFactory, @@ -526,7 +526,7 @@ class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin): def __init__( self, params: SACParams, - sampling_config: RLSamplingConfig, + sampling_config: SamplingConfig, actor_factory: ActorFactory, critic1_factory: CriticFactory, critic2_factory: CriticFactory, @@ -575,7 +575,7 @@ class TD3AgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin): def __init__( self, params: TD3Params, - sampling_config: RLSamplingConfig, + sampling_config: SamplingConfig, actor_factory: ActorFactory, critic1_factory: CriticFactory, critic2_factory: CriticFactory, diff --git a/tianshou/highlevel/config.py b/tianshou/highlevel/config.py index 6937fc2..e3f74b2 100644 --- a/tianshou/highlevel/config.py +++ b/tianshou/highlevel/config.py @@ -2,7 +2,7 @@ from dataclasses import dataclass @dataclass -class RLSamplingConfig: +class SamplingConfig: """Sampling, epochs, parallelization, buffers, collectors, and batching.""" # TODO: What are reasonable defaults? diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index aed6005..742f858 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -18,7 +18,7 @@ from tianshou.highlevel.agent import ( SACAgentFactory, TD3AgentFactory, ) -from tianshou.highlevel.config import RLSamplingConfig +from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.env import EnvFactory, Environments from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory from tianshou.highlevel.module.actor import ( @@ -53,7 +53,7 @@ TTrainer = TypeVar("TTrainer", bound=BaseTrainer) @dataclass -class RLExperimentConfig: +class ExperimentConfig: """Generic config for setting up the experiment, not RL or training specific.""" seed: int = 42 @@ -69,10 +69,10 @@ class RLExperimentConfig: watch_num_episodes = 10 -class RLExperiment(Generic[TPolicy, TTrainer], ToStringMixin): +class Experiment(Generic[TPolicy, TTrainer], ToStringMixin): def __init__( self, - config: RLExperimentConfig, + config: ExperimentConfig, env_factory: EnvFactory | Callable[[PersistableConfigProtocol | None], Environments], agent_factory: AgentFactory, logger_factory: LoggerFactory | None = None, @@ -153,12 +153,12 @@ class RLExperiment(Generic[TPolicy, TTrainer], ToStringMixin): TBuilder = TypeVar("TBuilder", bound="RLExperimentBuilder") -class RLExperimentBuilder: +class ExperimentBuilder: def __init__( self, - experiment_config: RLExperimentConfig, + experiment_config: ExperimentConfig, env_factory: EnvFactory, - sampling_config: RLSamplingConfig, + sampling_config: SamplingConfig, ): self._config = experiment_config self._env_factory = env_factory @@ -223,12 +223,12 @@ class RLExperimentBuilder: else: return self._optim_factory - def build(self) -> RLExperiment: + def build(self) -> Experiment: agent_factory = self._create_agent_factory() agent_factory.set_trainer_callbacks(self._trainer_callbacks) if self._policy_wrapper_factory: agent_factory.set_policy_wrapper_factory(self._policy_wrapper_factory) - experiment = RLExperiment( + experiment = Experiment( self._config, self._env_factory, agent_factory, @@ -394,15 +394,15 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory): class A2CExperimentBuilder( - RLExperimentBuilder, + ExperimentBuilder, _BuilderMixinActorFactory_ContinuousGaussian, _BuilderMixinSingleCriticCanUseActorFactory, ): def __init__( self, - experiment_config: RLExperimentConfig, + experiment_config: ExperimentConfig, env_factory: EnvFactory, - sampling_config: RLSamplingConfig, + sampling_config: SamplingConfig, env_config: PersistableConfigProtocol | None = None, ): super().__init__(experiment_config, env_factory, sampling_config) @@ -428,15 +428,15 @@ class A2CExperimentBuilder( class PPOExperimentBuilder( - RLExperimentBuilder, + ExperimentBuilder, _BuilderMixinActorFactory_ContinuousGaussian, _BuilderMixinSingleCriticCanUseActorFactory, ): def __init__( self, - experiment_config: RLExperimentConfig, + experiment_config: ExperimentConfig, env_factory: EnvFactory, - sampling_config: RLSamplingConfig, + sampling_config: SamplingConfig, ): super().__init__(experiment_config, env_factory, sampling_config) _BuilderMixinActorFactory_ContinuousGaussian.__init__(self) @@ -460,14 +460,14 @@ class PPOExperimentBuilder( class DQNExperimentBuilder( - RLExperimentBuilder, + ExperimentBuilder, _BuilderMixinActorFactory, ): def __init__( self, - experiment_config: RLExperimentConfig, + experiment_config: ExperimentConfig, env_factory: EnvFactory, - sampling_config: RLSamplingConfig, + sampling_config: SamplingConfig, ): super().__init__(experiment_config, env_factory, sampling_config) _BuilderMixinActorFactory.__init__(self, ContinuousActorType.UNSUPPORTED) @@ -488,15 +488,15 @@ class DQNExperimentBuilder( class DDPGExperimentBuilder( - RLExperimentBuilder, + ExperimentBuilder, _BuilderMixinActorFactory_ContinuousDeterministic, _BuilderMixinSingleCriticCanUseActorFactory, ): def __init__( self, - experiment_config: RLExperimentConfig, + experiment_config: ExperimentConfig, env_factory: EnvFactory, - sampling_config: RLSamplingConfig, + sampling_config: SamplingConfig, ): super().__init__(experiment_config, env_factory, sampling_config) _BuilderMixinActorFactory_ContinuousDeterministic.__init__(self) @@ -519,15 +519,15 @@ class DDPGExperimentBuilder( class SACExperimentBuilder( - RLExperimentBuilder, + ExperimentBuilder, _BuilderMixinActorFactory_ContinuousGaussian, _BuilderMixinDualCriticFactory, ): def __init__( self, - experiment_config: RLExperimentConfig, + experiment_config: ExperimentConfig, env_factory: EnvFactory, - sampling_config: RLSamplingConfig, + sampling_config: SamplingConfig, ): super().__init__(experiment_config, env_factory, sampling_config) _BuilderMixinActorFactory_ContinuousGaussian.__init__(self) @@ -550,15 +550,15 @@ class SACExperimentBuilder( class TD3ExperimentBuilder( - RLExperimentBuilder, + ExperimentBuilder, _BuilderMixinActorFactory_ContinuousDeterministic, _BuilderMixinDualCriticFactory, ): def __init__( self, - experiment_config: RLExperimentConfig, + experiment_config: ExperimentConfig, env_factory: EnvFactory, - sampling_config: RLSamplingConfig, + sampling_config: SamplingConfig, ): super().__init__(experiment_config, env_factory, sampling_config) _BuilderMixinActorFactory_ContinuousDeterministic.__init__(self) diff --git a/tianshou/highlevel/params/lr_scheduler.py b/tianshou/highlevel/params/lr_scheduler.py index 1082522..820be2e 100644 --- a/tianshou/highlevel/params/lr_scheduler.py +++ b/tianshou/highlevel/params/lr_scheduler.py @@ -4,7 +4,7 @@ import numpy as np import torch from torch.optim.lr_scheduler import LambdaLR, LRScheduler -from tianshou.highlevel.config import RLSamplingConfig +from tianshou.highlevel.config import SamplingConfig from tianshou.utils.string import ToStringMixin @@ -15,7 +15,7 @@ class LRSchedulerFactory(ToStringMixin, ABC): class LRSchedulerFactoryLinear(LRSchedulerFactory): - def __init__(self, sampling_config: RLSamplingConfig): + def __init__(self, sampling_config: SamplingConfig): self.sampling_config = sampling_config def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler: