From 837ff13c04357331ebd94fa9a48cc4542ddee07b Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Fri, 6 Oct 2023 13:53:45 +0200 Subject: [PATCH] Reorder ExperimentBuilder args (EnvFactory first) --- examples/atari/atari_dqn_hl.py | 2 +- examples/atari/atari_ppo_hl.py | 4 ++-- examples/mujoco/mujoco_a2c_hl.py | 2 +- examples/mujoco/mujoco_ddpg_hl.py | 2 +- examples/mujoco/mujoco_ppo_hl.py | 4 ++-- examples/mujoco/mujoco_sac_hl.py | 2 +- examples/mujoco/mujoco_td3_hl.py | 2 +- test/highlevel/test_continuous.py | 2 +- test/highlevel/test_discrete.py | 2 +- tianshou/highlevel/experiment.py | 26 +++++++++++++------------- 10 files changed, 24 insertions(+), 24 deletions(-) diff --git a/examples/atari/atari_dqn_hl.py b/examples/atari/atari_dqn_hl.py index 13b49d6..8aed0fa 100644 --- a/examples/atari/atari_dqn_hl.py +++ b/examples/atari/atari_dqn_hl.py @@ -94,7 +94,7 @@ def main( policy.set_eps(eps_test) builder = ( - DQNExperimentBuilder(experiment_config, env_factory, sampling_config) + DQNExperimentBuilder(env_factory, experiment_config, sampling_config) .with_dqn_params( DQNParams( discount_factor=gamma, diff --git a/examples/atari/atari_ppo_hl.py b/examples/atari/atari_ppo_hl.py index 1f785e8..c6cc321 100644 --- a/examples/atari/atari_ppo_hl.py +++ b/examples/atari/atari_ppo_hl.py @@ -13,8 +13,8 @@ from examples.atari.atari_network import ( from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.experiment import ( - PPOExperimentBuilder, ExperimentConfig, + PPOExperimentBuilder, ) from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear from tianshou.highlevel.params.policy_params import PPOParams @@ -76,7 +76,7 @@ def main( env_factory = AtariEnvFactory(task, experiment_config.seed, sampling_config, frames_stack) builder = ( - PPOExperimentBuilder(experiment_config, env_factory, sampling_config) + PPOExperimentBuilder(env_factory, experiment_config, sampling_config) .with_ppo_params( PPOParams( discount_factor=gamma, diff --git a/examples/mujoco/mujoco_a2c_hl.py b/examples/mujoco/mujoco_a2c_hl.py index 7501fae..badd1c5 100644 --- a/examples/mujoco/mujoco_a2c_hl.py +++ b/examples/mujoco/mujoco_a2c_hl.py @@ -58,7 +58,7 @@ def main( env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config) experiment = ( - A2CExperimentBuilder(experiment_config, env_factory, sampling_config) + A2CExperimentBuilder(env_factory, experiment_config, sampling_config) .with_a2c_params( A2CParams( discount_factor=gamma, diff --git a/examples/mujoco/mujoco_ddpg_hl.py b/examples/mujoco/mujoco_ddpg_hl.py index 5944f9b..5763d9d 100644 --- a/examples/mujoco/mujoco_ddpg_hl.py +++ b/examples/mujoco/mujoco_ddpg_hl.py @@ -57,7 +57,7 @@ def main( env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config) experiment = ( - DDPGExperimentBuilder(experiment_config, env_factory, sampling_config) + DDPGExperimentBuilder(env_factory, experiment_config, sampling_config) .with_ddpg_params( DDPGParams( actor_lr=actor_lr, diff --git a/examples/mujoco/mujoco_ppo_hl.py b/examples/mujoco/mujoco_ppo_hl.py index 33d3edc..dfaf908 100644 --- a/examples/mujoco/mujoco_ppo_hl.py +++ b/examples/mujoco/mujoco_ppo_hl.py @@ -11,8 +11,8 @@ from torch.distributions import Independent, Normal from examples.mujoco.mujoco_env import MujocoEnvFactory from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.experiment import ( - PPOExperimentBuilder, ExperimentConfig, + PPOExperimentBuilder, ) from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear from tianshou.highlevel.params.policy_params import PPOParams @@ -66,7 +66,7 @@ def main( return Independent(Normal(*logits), 1) experiment = ( - PPOExperimentBuilder(experiment_config, env_factory, sampling_config) + PPOExperimentBuilder(env_factory, experiment_config, sampling_config) .with_ppo_params( PPOParams( discount_factor=gamma, diff --git a/examples/mujoco/mujoco_sac_hl.py b/examples/mujoco/mujoco_sac_hl.py index 852c6f5..c0be24d 100644 --- a/examples/mujoco/mujoco_sac_hl.py +++ b/examples/mujoco/mujoco_sac_hl.py @@ -58,7 +58,7 @@ def main( env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config) experiment = ( - SACExperimentBuilder(experiment_config, env_factory, sampling_config) + SACExperimentBuilder(env_factory, experiment_config, sampling_config) .with_sac_params( SACParams( tau=tau, diff --git a/examples/mujoco/mujoco_td3_hl.py b/examples/mujoco/mujoco_td3_hl.py index 28a211a..2a09409 100644 --- a/examples/mujoco/mujoco_td3_hl.py +++ b/examples/mujoco/mujoco_td3_hl.py @@ -62,7 +62,7 @@ def main( env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config) experiment = ( - TD3ExperimentBuilder(experiment_config, env_factory, sampling_config) + TD3ExperimentBuilder(env_factory, experiment_config, sampling_config) .with_td3_params( TD3Params( tau=tau, diff --git a/test/highlevel/test_continuous.py b/test/highlevel/test_continuous.py index 0934c45..884e9d8 100644 --- a/test/highlevel/test_continuous.py +++ b/test/highlevel/test_continuous.py @@ -6,8 +6,8 @@ from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.experiment import ( A2CExperimentBuilder, DDPGExperimentBuilder, - PPOExperimentBuilder, ExperimentConfig, + PPOExperimentBuilder, SACExperimentBuilder, TD3ExperimentBuilder, ) diff --git a/test/highlevel/test_discrete.py b/test/highlevel/test_discrete.py index ff593df..0dd624b 100644 --- a/test/highlevel/test_discrete.py +++ b/test/highlevel/test_discrete.py @@ -6,8 +6,8 @@ from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.experiment import ( A2CExperimentBuilder, DQNExperimentBuilder, - PPOExperimentBuilder, ExperimentConfig, + PPOExperimentBuilder, ) diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index 742f858..54de984 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -156,8 +156,8 @@ TBuilder = TypeVar("TBuilder", bound="RLExperimentBuilder") class ExperimentBuilder: def __init__( self, - experiment_config: ExperimentConfig, env_factory: EnvFactory, + experiment_config: ExperimentConfig, sampling_config: SamplingConfig, ): self._config = experiment_config @@ -400,12 +400,12 @@ class A2CExperimentBuilder( ): def __init__( self, - experiment_config: ExperimentConfig, env_factory: EnvFactory, + experiment_config: ExperimentConfig, sampling_config: SamplingConfig, env_config: PersistableConfigProtocol | None = None, ): - super().__init__(experiment_config, env_factory, sampling_config) + super().__init__(env_factory, experiment_config, sampling_config) _BuilderMixinActorFactory_ContinuousGaussian.__init__(self) _BuilderMixinSingleCriticCanUseActorFactory.__init__(self) self._params: A2CParams = A2CParams() @@ -434,11 +434,11 @@ class PPOExperimentBuilder( ): def __init__( self, - experiment_config: ExperimentConfig, env_factory: EnvFactory, + experiment_config: ExperimentConfig, sampling_config: SamplingConfig, ): - super().__init__(experiment_config, env_factory, sampling_config) + super().__init__(env_factory, experiment_config, sampling_config) _BuilderMixinActorFactory_ContinuousGaussian.__init__(self) _BuilderMixinSingleCriticCanUseActorFactory.__init__(self) self._params: PPOParams = PPOParams() @@ -465,11 +465,11 @@ class DQNExperimentBuilder( ): def __init__( self, - experiment_config: ExperimentConfig, env_factory: EnvFactory, + experiment_config: ExperimentConfig, sampling_config: SamplingConfig, ): - super().__init__(experiment_config, env_factory, sampling_config) + super().__init__(env_factory, experiment_config, sampling_config) _BuilderMixinActorFactory.__init__(self, ContinuousActorType.UNSUPPORTED) self._params: DQNParams = DQNParams() @@ -494,11 +494,11 @@ class DDPGExperimentBuilder( ): def __init__( self, - experiment_config: ExperimentConfig, env_factory: EnvFactory, + experiment_config: ExperimentConfig, sampling_config: SamplingConfig, ): - super().__init__(experiment_config, env_factory, sampling_config) + super().__init__(env_factory, experiment_config, sampling_config) _BuilderMixinActorFactory_ContinuousDeterministic.__init__(self) _BuilderMixinSingleCriticCanUseActorFactory.__init__(self) self._params: DDPGParams = DDPGParams() @@ -525,11 +525,11 @@ class SACExperimentBuilder( ): def __init__( self, - experiment_config: ExperimentConfig, env_factory: EnvFactory, + experiment_config: ExperimentConfig, sampling_config: SamplingConfig, ): - super().__init__(experiment_config, env_factory, sampling_config) + super().__init__(env_factory, experiment_config, sampling_config) _BuilderMixinActorFactory_ContinuousGaussian.__init__(self) _BuilderMixinDualCriticFactory.__init__(self) self._params: SACParams = SACParams() @@ -556,11 +556,11 @@ class TD3ExperimentBuilder( ): def __init__( self, - experiment_config: ExperimentConfig, env_factory: EnvFactory, + experiment_config: ExperimentConfig, sampling_config: SamplingConfig, ): - super().__init__(experiment_config, env_factory, sampling_config) + super().__init__(env_factory, experiment_config, sampling_config) _BuilderMixinActorFactory_ContinuousDeterministic.__init__(self) _BuilderMixinDualCriticFactory.__init__(self) self._params: TD3Params = TD3Params()