From 6cbee188b832fd3706f9c7300af2c6b420faffad Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Wed, 18 Oct 2023 23:55:23 +0200 Subject: [PATCH] Change interface of EnvFactory to ensure that configuration of number of environments in SamplingConfig is used (values are now passed to factory method) This is clearer and removes the need to pass otherwise unnecessary configuration to environment factories at construction --- examples/atari/atari_dqn_hl.py | 8 +------ examples/atari/atari_iqn_hl.py | 8 +------ examples/atari/atari_ppo_hl.py | 2 +- examples/atari/atari_sac_hl.py | 8 +------ examples/atari/atari_wrapper.py | 22 ++++++++----------- examples/mujoco/mujoco_a2c_hl.py | 2 +- examples/mujoco/mujoco_ddpg_hl.py | 2 +- examples/mujoco/mujoco_env.py | 15 +++++++------ examples/mujoco/mujoco_npg_hl.py | 2 +- examples/mujoco/mujoco_ppo_hl.py | 2 +- examples/mujoco/mujoco_redq_hl.py | 2 +- examples/mujoco/mujoco_reinforce_hl.py | 2 +- examples/mujoco/mujoco_sac_hl.py | 2 +- examples/mujoco/mujoco_td3_hl.py | 2 +- examples/mujoco/mujoco_trpo_hl.py | 2 +- test/highlevel/env_factory.py | 30 ++++++++++++++------------ test/highlevel/test_continuous.py | 7 +++++- test/highlevel/test_discrete.py | 7 +++++- tianshou/highlevel/env.py | 16 +++++++++++--- tianshou/highlevel/experiment.py | 12 +++++++++-- 20 files changed, 82 insertions(+), 71 deletions(-) diff --git a/examples/atari/atari_dqn_hl.py b/examples/atari/atari_dqn_hl.py index b4ae0dd..c38d30a 100644 --- a/examples/atari/atari_dqn_hl.py +++ b/examples/atari/atari_dqn_hl.py @@ -68,13 +68,7 @@ def main( replay_buffer_save_only_last_obs=True, ) - env_factory = AtariEnvFactory( - task, - experiment_config.seed, - sampling_config, - frames_stack, - scale=scale_obs, - ) + env_factory = AtariEnvFactory(task, experiment_config.seed, frames_stack, scale=scale_obs) builder = ( DQNExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/atari/atari_iqn_hl.py b/examples/atari/atari_iqn_hl.py index a4d3dd2..b0119b0 100644 --- a/examples/atari/atari_iqn_hl.py +++ b/examples/atari/atari_iqn_hl.py @@ -67,13 +67,7 @@ def main( replay_buffer_save_only_last_obs=True, ) - env_factory = AtariEnvFactory( - task, - experiment_config.seed, - sampling_config, - frames_stack, - scale=scale_obs, - ) + env_factory = AtariEnvFactory(task, experiment_config.seed, frames_stack, scale=scale_obs) experiment = ( IQNExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/atari/atari_ppo_hl.py b/examples/atari/atari_ppo_hl.py index 25fd0a5..b039616 100644 --- a/examples/atari/atari_ppo_hl.py +++ b/examples/atari/atari_ppo_hl.py @@ -73,7 +73,7 @@ def main( replay_buffer_save_only_last_obs=True, ) - env_factory = AtariEnvFactory(task, experiment_config.seed, sampling_config, frames_stack) + env_factory = AtariEnvFactory(task, experiment_config.seed, frames_stack) builder = ( PPOExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/atari/atari_sac_hl.py b/examples/atari/atari_sac_hl.py index 9f743eb..d0fd067 100644 --- a/examples/atari/atari_sac_hl.py +++ b/examples/atari/atari_sac_hl.py @@ -67,13 +67,7 @@ def main( replay_buffer_save_only_last_obs=True, ) - env_factory = AtariEnvFactory( - task, - experiment_config.seed, - sampling_config, - frames_stack, - scale=scale_obs, - ) + env_factory = AtariEnvFactory(task, experiment_config.seed, frames_stack, scale=scale_obs) builder = ( DiscreteSACExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py index a44fbb8..510106d 100644 --- a/examples/atari/atari_wrapper.py +++ b/examples/atari/atari_wrapper.py @@ -9,7 +9,6 @@ import gymnasium as gym import numpy as np from tianshou.env import ShmemVectorEnv -from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.env import DiscreteEnvironments, EnvFactory from tianshou.highlevel.trainer import TrainerStopCallback, TrainingContext @@ -375,26 +374,23 @@ def make_atari_env(task, seed, training_num, test_num, **kwargs): class AtariEnvFactory(EnvFactory): - def __init__( - self, - task: str, - seed: int, - sampling_config: SamplingConfig, - frame_stack: int, - scale: int = 0, - ): + def __init__(self, task: str, seed: int, frame_stack: int, scale: int = 0): self.task = task - self.sampling_config = sampling_config self.seed = seed self.frame_stack = frame_stack self.scale = scale - def create_envs(self, config=None) -> DiscreteEnvironments: + def create_envs( + self, + num_training_envs: int, + num_test_envs: int, + config=None, + ) -> DiscreteEnvironments: env, train_envs, test_envs = make_atari_env( task=self.task, seed=self.seed, - training_num=self.sampling_config.num_train_envs, - test_num=self.sampling_config.num_test_envs, + training_num=num_training_envs, + test_num=num_test_envs, scale=self.scale, frame_stack=self.frame_stack, ) diff --git a/examples/mujoco/mujoco_a2c_hl.py b/examples/mujoco/mujoco_a2c_hl.py index 76de96f..6982b94 100644 --- a/examples/mujoco/mujoco_a2c_hl.py +++ b/examples/mujoco/mujoco_a2c_hl.py @@ -55,7 +55,7 @@ def main( repeat_per_collect=repeat_per_collect, ) - env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True) + env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True) experiment = ( A2CExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/mujoco/mujoco_ddpg_hl.py b/examples/mujoco/mujoco_ddpg_hl.py index f5901dc..e7c519e 100644 --- a/examples/mujoco/mujoco_ddpg_hl.py +++ b/examples/mujoco/mujoco_ddpg_hl.py @@ -53,7 +53,7 @@ def main( start_timesteps_random=True, ) - env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=False) + env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=False) experiment = ( DDPGExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/mujoco/mujoco_env.py b/examples/mujoco/mujoco_env.py index c2be4c1..3aae335 100644 --- a/examples/mujoco/mujoco_env.py +++ b/examples/mujoco/mujoco_env.py @@ -5,7 +5,6 @@ import warnings import gymnasium as gym from tianshou.env import ShmemVectorEnv, VectorEnvNormObs -from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory from tianshou.highlevel.persistence import Persistence, PersistEvent, RestoreEvent from tianshou.highlevel.world import World @@ -70,18 +69,22 @@ class MujocoEnvObsRmsPersistence(Persistence): class MujocoEnvFactory(EnvFactory): - def __init__(self, task: str, seed: int, sampling_config: SamplingConfig, obs_norm=True): + def __init__(self, task: str, seed: int, obs_norm=True): self.task = task - self.sampling_config = sampling_config self.seed = seed self.obs_norm = obs_norm - def create_envs(self, config=None) -> ContinuousEnvironments: + def create_envs( + self, + num_training_envs: int, + num_test_envs: int, + config=None, + ) -> ContinuousEnvironments: env, train_envs, test_envs = make_mujoco_env( task=self.task, seed=self.seed, - num_train_envs=self.sampling_config.num_train_envs, - num_test_envs=self.sampling_config.num_test_envs, + num_train_envs=num_training_envs, + num_test_envs=num_test_envs, obs_norm=self.obs_norm, ) envs = ContinuousEnvironments(env=env, train_envs=train_envs, test_envs=test_envs) diff --git a/examples/mujoco/mujoco_npg_hl.py b/examples/mujoco/mujoco_npg_hl.py index bcf2bc6..b6d2637 100644 --- a/examples/mujoco/mujoco_npg_hl.py +++ b/examples/mujoco/mujoco_npg_hl.py @@ -57,7 +57,7 @@ def main( repeat_per_collect=repeat_per_collect, ) - env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True) + env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True) experiment = ( NPGExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/mujoco/mujoco_ppo_hl.py b/examples/mujoco/mujoco_ppo_hl.py index 0d121aa..ba81806 100644 --- a/examples/mujoco/mujoco_ppo_hl.py +++ b/examples/mujoco/mujoco_ppo_hl.py @@ -62,7 +62,7 @@ def main( repeat_per_collect=repeat_per_collect, ) - env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True) + env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True) experiment = ( PPOExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/mujoco/mujoco_redq_hl.py b/examples/mujoco/mujoco_redq_hl.py index 52a296d..2136f33 100644 --- a/examples/mujoco/mujoco_redq_hl.py +++ b/examples/mujoco/mujoco_redq_hl.py @@ -59,7 +59,7 @@ def main( start_timesteps_random=True, ) - env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=False) + env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=False) experiment = ( REDQExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/mujoco/mujoco_reinforce_hl.py b/examples/mujoco/mujoco_reinforce_hl.py index fed5b06..5097d98 100644 --- a/examples/mujoco/mujoco_reinforce_hl.py +++ b/examples/mujoco/mujoco_reinforce_hl.py @@ -50,7 +50,7 @@ def main( repeat_per_collect=repeat_per_collect, ) - env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True) + env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True) experiment = ( PGExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/mujoco/mujoco_sac_hl.py b/examples/mujoco/mujoco_sac_hl.py index 930aacf..2edca7b 100644 --- a/examples/mujoco/mujoco_sac_hl.py +++ b/examples/mujoco/mujoco_sac_hl.py @@ -55,7 +55,7 @@ def main( start_timesteps_random=True, ) - env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=False) + env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=False) experiment = ( SACExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/mujoco/mujoco_td3_hl.py b/examples/mujoco/mujoco_td3_hl.py index 1a5a3ad..e4021d6 100644 --- a/examples/mujoco/mujoco_td3_hl.py +++ b/examples/mujoco/mujoco_td3_hl.py @@ -59,7 +59,7 @@ def main( start_timesteps_random=True, ) - env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=False) + env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=False) experiment = ( TD3ExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/mujoco/mujoco_trpo_hl.py b/examples/mujoco/mujoco_trpo_hl.py index 2e11f4e..34ed14b 100644 --- a/examples/mujoco/mujoco_trpo_hl.py +++ b/examples/mujoco/mujoco_trpo_hl.py @@ -59,7 +59,7 @@ def main( repeat_per_collect=repeat_per_collect, ) - env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True) + env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True) experiment = ( TRPOExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/test/highlevel/env_factory.py b/test/highlevel/env_factory.py index 342848d..aab96c6 100644 --- a/test/highlevel/env_factory.py +++ b/test/highlevel/env_factory.py @@ -11,26 +11,28 @@ from tianshou.highlevel.persistence import PersistableConfigProtocol class DiscreteTestEnvFactory(EnvFactory): - def __init__(self, test_num=10, train_num=10): - self.test_num = test_num - self.train_num = train_num - - def create_envs(self, config: PersistableConfigProtocol | None = None) -> Environments: + def create_envs( + self, + num_training_envs: int, + num_test_envs: int, + config: PersistableConfigProtocol | None = None, + ) -> Environments: task = "CartPole-v0" env = gym.make(task) - train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(self.train_num)]) - test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(self.test_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_training_envs)]) + test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)]) return DiscreteEnvironments(env, train_envs, test_envs) class ContinuousTestEnvFactory(EnvFactory): - def __init__(self, test_num=10, train_num=10): - self.test_num = test_num - self.train_num = train_num - - def create_envs(self, config: PersistableConfigProtocol | None = None) -> Environments: + def create_envs( + self, + num_training_envs: int, + num_test_envs: int, + config: PersistableConfigProtocol | None = None, + ) -> Environments: task = "Pendulum-v1" env = gym.make(task) - train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(self.train_num)]) - test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(self.test_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_training_envs)]) + test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)]) return ContinuousEnvironments(env, train_envs, test_envs) diff --git a/test/highlevel/test_continuous.py b/test/highlevel/test_continuous.py index 06995a7..32de680 100644 --- a/test/highlevel/test_continuous.py +++ b/test/highlevel/test_continuous.py @@ -32,7 +32,12 @@ from tianshou.highlevel.experiment import ( ) def test_experiment_builder_continuous_default_params(builder_cls): env_factory = ContinuousTestEnvFactory() - sampling_config = SamplingConfig(num_epochs=1, step_per_epoch=100) + sampling_config = SamplingConfig( + num_epochs=1, + step_per_epoch=100, + num_train_envs=2, + num_test_envs=2, + ) experiment_config = ExperimentConfig(persistence_enabled=False) builder = builder_cls( experiment_config=experiment_config, diff --git a/test/highlevel/test_discrete.py b/test/highlevel/test_discrete.py index d9592af..53f1dd8 100644 --- a/test/highlevel/test_discrete.py +++ b/test/highlevel/test_discrete.py @@ -25,7 +25,12 @@ from tianshou.highlevel.experiment import ( ) def test_experiment_builder_discrete_default_params(builder_cls): env_factory = DiscreteTestEnvFactory() - sampling_config = SamplingConfig(num_epochs=1, step_per_epoch=100) + sampling_config = SamplingConfig( + num_epochs=1, + step_per_epoch=100, + num_train_envs=2, + num_test_envs=2, + ) builder = builder_cls( experiment_config=ExperimentConfig(persistence_enabled=False), env_factory=env_factory, diff --git a/tianshou/highlevel/env.py b/tianshou/highlevel/env.py index d11eddc..a99da13 100644 --- a/tianshou/highlevel/env.py +++ b/tianshou/highlevel/env.py @@ -140,8 +140,18 @@ class DiscreteEnvironments(Environments): class EnvFactory(ToStringMixin, ABC): @abstractmethod - def create_envs(self, config: PersistableConfigProtocol | None = None) -> Environments: + def create_envs( + self, + num_training_envs: int, + num_test_envs: int, + config: PersistableConfigProtocol | None = None, + ) -> Environments: pass - def __call__(self, config: PersistableConfigProtocol | None = None) -> Environments: - return self.create_envs(config=config) + def __call__( + self, + num_training_envs: int, + num_test_envs: int, + config: PersistableConfigProtocol | None = None, + ) -> Environments: + return self.create_envs(num_training_envs, num_test_envs, config=config) diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index fa1cc0e..7d7a26c 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -136,14 +136,17 @@ class Experiment(ToStringMixin): def __init__( self, config: ExperimentConfig, - env_factory: EnvFactory | Callable[[PersistableConfigProtocol | None], Environments], + env_factory: EnvFactory + | Callable[[int, int, PersistableConfigProtocol | None], Environments], agent_factory: AgentFactory, + sampling_config: SamplingConfig, logger_factory: LoggerFactory | None = None, env_config: PersistableConfigProtocol | None = None, ): if logger_factory is None: logger_factory = LoggerFactoryDefault() self.config = config + self.sampling_config = sampling_config self.env_factory = env_factory self.agent_factory = agent_factory self.logger_factory = logger_factory @@ -214,7 +217,11 @@ class Experiment(ToStringMixin): self._set_seed() # create environments - envs = self.env_factory(self.env_config) + envs = self.env_factory( + self.sampling_config.num_train_envs, + self.sampling_config.num_test_envs, + self.env_config, + ) log.info(f"Created {envs}") # initialize persistence @@ -416,6 +423,7 @@ class ExperimentBuilder: self._config, self._env_factory, agent_factory, + self._sampling_config, self._logger_factory, env_config=self._env_config, )