diff --git a/examples/atari/atari_dqn_hl.py b/examples/atari/atari_dqn_hl.py index b4ae0dd..c38d30a 100644 --- a/examples/atari/atari_dqn_hl.py +++ b/examples/atari/atari_dqn_hl.py @@ -68,13 +68,7 @@ def main( replay_buffer_save_only_last_obs=True, ) - env_factory = AtariEnvFactory( - task, - experiment_config.seed, - sampling_config, - frames_stack, - scale=scale_obs, - ) + env_factory = AtariEnvFactory(task, experiment_config.seed, frames_stack, scale=scale_obs) builder = ( DQNExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/atari/atari_iqn_hl.py b/examples/atari/atari_iqn_hl.py index a4d3dd2..b0119b0 100644 --- a/examples/atari/atari_iqn_hl.py +++ b/examples/atari/atari_iqn_hl.py @@ -67,13 +67,7 @@ def main( replay_buffer_save_only_last_obs=True, ) - env_factory = AtariEnvFactory( - task, - experiment_config.seed, - sampling_config, - frames_stack, - scale=scale_obs, - ) + env_factory = AtariEnvFactory(task, experiment_config.seed, frames_stack, scale=scale_obs) experiment = ( IQNExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/atari/atari_ppo_hl.py b/examples/atari/atari_ppo_hl.py index 25fd0a5..b039616 100644 --- a/examples/atari/atari_ppo_hl.py +++ b/examples/atari/atari_ppo_hl.py @@ -73,7 +73,7 @@ def main( replay_buffer_save_only_last_obs=True, ) - env_factory = AtariEnvFactory(task, experiment_config.seed, sampling_config, frames_stack) + env_factory = AtariEnvFactory(task, experiment_config.seed, frames_stack) builder = ( PPOExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/atari/atari_sac_hl.py b/examples/atari/atari_sac_hl.py index 9f743eb..d0fd067 100644 --- a/examples/atari/atari_sac_hl.py +++ b/examples/atari/atari_sac_hl.py @@ -67,13 +67,7 @@ def main( replay_buffer_save_only_last_obs=True, ) - env_factory = AtariEnvFactory( - task, - experiment_config.seed, - sampling_config, - frames_stack, - scale=scale_obs, - ) + env_factory = AtariEnvFactory(task, experiment_config.seed, frames_stack, scale=scale_obs) builder = ( DiscreteSACExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py index a44fbb8..510106d 100644 --- a/examples/atari/atari_wrapper.py +++ b/examples/atari/atari_wrapper.py @@ -9,7 +9,6 @@ import gymnasium as gym import numpy as np from tianshou.env import ShmemVectorEnv -from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.env import DiscreteEnvironments, EnvFactory from tianshou.highlevel.trainer import TrainerStopCallback, TrainingContext @@ -375,26 +374,23 @@ def make_atari_env(task, seed, training_num, test_num, **kwargs): class AtariEnvFactory(EnvFactory): - def __init__( - self, - task: str, - seed: int, - sampling_config: SamplingConfig, - frame_stack: int, - scale: int = 0, - ): + def __init__(self, task: str, seed: int, frame_stack: int, scale: int = 0): self.task = task - self.sampling_config = sampling_config self.seed = seed self.frame_stack = frame_stack self.scale = scale - def create_envs(self, config=None) -> DiscreteEnvironments: + def create_envs( + self, + num_training_envs: int, + num_test_envs: int, + config=None, + ) -> DiscreteEnvironments: env, train_envs, test_envs = make_atari_env( task=self.task, seed=self.seed, - training_num=self.sampling_config.num_train_envs, - test_num=self.sampling_config.num_test_envs, + training_num=num_training_envs, + test_num=num_test_envs, scale=self.scale, frame_stack=self.frame_stack, ) diff --git a/examples/mujoco/mujoco_a2c_hl.py b/examples/mujoco/mujoco_a2c_hl.py index 76de96f..6982b94 100644 --- a/examples/mujoco/mujoco_a2c_hl.py +++ b/examples/mujoco/mujoco_a2c_hl.py @@ -55,7 +55,7 @@ def main( repeat_per_collect=repeat_per_collect, ) - env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True) + env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True) experiment = ( A2CExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/mujoco/mujoco_ddpg_hl.py b/examples/mujoco/mujoco_ddpg_hl.py index f5901dc..e7c519e 100644 --- a/examples/mujoco/mujoco_ddpg_hl.py +++ b/examples/mujoco/mujoco_ddpg_hl.py @@ -53,7 +53,7 @@ def main( start_timesteps_random=True, ) - env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=False) + env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=False) experiment = ( DDPGExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/mujoco/mujoco_env.py b/examples/mujoco/mujoco_env.py index c2be4c1..3aae335 100644 --- a/examples/mujoco/mujoco_env.py +++ b/examples/mujoco/mujoco_env.py @@ -5,7 +5,6 @@ import warnings import gymnasium as gym from tianshou.env import ShmemVectorEnv, VectorEnvNormObs -from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory from tianshou.highlevel.persistence import Persistence, PersistEvent, RestoreEvent from tianshou.highlevel.world import World @@ -70,18 +69,22 @@ class MujocoEnvObsRmsPersistence(Persistence): class MujocoEnvFactory(EnvFactory): - def __init__(self, task: str, seed: int, sampling_config: SamplingConfig, obs_norm=True): + def __init__(self, task: str, seed: int, obs_norm=True): self.task = task - self.sampling_config = sampling_config self.seed = seed self.obs_norm = obs_norm - def create_envs(self, config=None) -> ContinuousEnvironments: + def create_envs( + self, + num_training_envs: int, + num_test_envs: int, + config=None, + ) -> ContinuousEnvironments: env, train_envs, test_envs = make_mujoco_env( task=self.task, seed=self.seed, - num_train_envs=self.sampling_config.num_train_envs, - num_test_envs=self.sampling_config.num_test_envs, + num_train_envs=num_training_envs, + num_test_envs=num_test_envs, obs_norm=self.obs_norm, ) envs = ContinuousEnvironments(env=env, train_envs=train_envs, test_envs=test_envs) diff --git a/examples/mujoco/mujoco_npg_hl.py b/examples/mujoco/mujoco_npg_hl.py index bcf2bc6..b6d2637 100644 --- a/examples/mujoco/mujoco_npg_hl.py +++ b/examples/mujoco/mujoco_npg_hl.py @@ -57,7 +57,7 @@ def main( repeat_per_collect=repeat_per_collect, ) - env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True) + env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True) experiment = ( NPGExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/mujoco/mujoco_ppo_hl.py b/examples/mujoco/mujoco_ppo_hl.py index 0d121aa..ba81806 100644 --- a/examples/mujoco/mujoco_ppo_hl.py +++ b/examples/mujoco/mujoco_ppo_hl.py @@ -62,7 +62,7 @@ def main( repeat_per_collect=repeat_per_collect, ) - env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True) + env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True) experiment = ( PPOExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/mujoco/mujoco_redq_hl.py b/examples/mujoco/mujoco_redq_hl.py index 52a296d..2136f33 100644 --- a/examples/mujoco/mujoco_redq_hl.py +++ b/examples/mujoco/mujoco_redq_hl.py @@ -59,7 +59,7 @@ def main( start_timesteps_random=True, ) - env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=False) + env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=False) experiment = ( REDQExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/mujoco/mujoco_reinforce_hl.py b/examples/mujoco/mujoco_reinforce_hl.py index fed5b06..5097d98 100644 --- a/examples/mujoco/mujoco_reinforce_hl.py +++ b/examples/mujoco/mujoco_reinforce_hl.py @@ -50,7 +50,7 @@ def main( repeat_per_collect=repeat_per_collect, ) - env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True) + env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True) experiment = ( PGExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/mujoco/mujoco_sac_hl.py b/examples/mujoco/mujoco_sac_hl.py index 930aacf..2edca7b 100644 --- a/examples/mujoco/mujoco_sac_hl.py +++ b/examples/mujoco/mujoco_sac_hl.py @@ -55,7 +55,7 @@ def main( start_timesteps_random=True, ) - env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=False) + env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=False) experiment = ( SACExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/mujoco/mujoco_td3_hl.py b/examples/mujoco/mujoco_td3_hl.py index 1a5a3ad..e4021d6 100644 --- a/examples/mujoco/mujoco_td3_hl.py +++ b/examples/mujoco/mujoco_td3_hl.py @@ -59,7 +59,7 @@ def main( start_timesteps_random=True, ) - env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=False) + env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=False) experiment = ( TD3ExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/mujoco/mujoco_trpo_hl.py b/examples/mujoco/mujoco_trpo_hl.py index 2e11f4e..34ed14b 100644 --- a/examples/mujoco/mujoco_trpo_hl.py +++ b/examples/mujoco/mujoco_trpo_hl.py @@ -59,7 +59,7 @@ def main( repeat_per_collect=repeat_per_collect, ) - env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True) + env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True) experiment = ( TRPOExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/test/highlevel/env_factory.py b/test/highlevel/env_factory.py index 342848d..aab96c6 100644 --- a/test/highlevel/env_factory.py +++ b/test/highlevel/env_factory.py @@ -11,26 +11,28 @@ from tianshou.highlevel.persistence import PersistableConfigProtocol class DiscreteTestEnvFactory(EnvFactory): - def __init__(self, test_num=10, train_num=10): - self.test_num = test_num - self.train_num = train_num - - def create_envs(self, config: PersistableConfigProtocol | None = None) -> Environments: + def create_envs( + self, + num_training_envs: int, + num_test_envs: int, + config: PersistableConfigProtocol | None = None, + ) -> Environments: task = "CartPole-v0" env = gym.make(task) - train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(self.train_num)]) - test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(self.test_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_training_envs)]) + test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)]) return DiscreteEnvironments(env, train_envs, test_envs) class ContinuousTestEnvFactory(EnvFactory): - def __init__(self, test_num=10, train_num=10): - self.test_num = test_num - self.train_num = train_num - - def create_envs(self, config: PersistableConfigProtocol | None = None) -> Environments: + def create_envs( + self, + num_training_envs: int, + num_test_envs: int, + config: PersistableConfigProtocol | None = None, + ) -> Environments: task = "Pendulum-v1" env = gym.make(task) - train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(self.train_num)]) - test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(self.test_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_training_envs)]) + test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)]) return ContinuousEnvironments(env, train_envs, test_envs) diff --git a/test/highlevel/test_continuous.py b/test/highlevel/test_continuous.py index 06995a7..32de680 100644 --- a/test/highlevel/test_continuous.py +++ b/test/highlevel/test_continuous.py @@ -32,7 +32,12 @@ from tianshou.highlevel.experiment import ( ) def test_experiment_builder_continuous_default_params(builder_cls): env_factory = ContinuousTestEnvFactory() - sampling_config = SamplingConfig(num_epochs=1, step_per_epoch=100) + sampling_config = SamplingConfig( + num_epochs=1, + step_per_epoch=100, + num_train_envs=2, + num_test_envs=2, + ) experiment_config = ExperimentConfig(persistence_enabled=False) builder = builder_cls( experiment_config=experiment_config, diff --git a/test/highlevel/test_discrete.py b/test/highlevel/test_discrete.py index d9592af..53f1dd8 100644 --- a/test/highlevel/test_discrete.py +++ b/test/highlevel/test_discrete.py @@ -25,7 +25,12 @@ from tianshou.highlevel.experiment import ( ) def test_experiment_builder_discrete_default_params(builder_cls): env_factory = DiscreteTestEnvFactory() - sampling_config = SamplingConfig(num_epochs=1, step_per_epoch=100) + sampling_config = SamplingConfig( + num_epochs=1, + step_per_epoch=100, + num_train_envs=2, + num_test_envs=2, + ) builder = builder_cls( experiment_config=ExperimentConfig(persistence_enabled=False), env_factory=env_factory, diff --git a/tianshou/highlevel/env.py b/tianshou/highlevel/env.py index d11eddc..a99da13 100644 --- a/tianshou/highlevel/env.py +++ b/tianshou/highlevel/env.py @@ -140,8 +140,18 @@ class DiscreteEnvironments(Environments): class EnvFactory(ToStringMixin, ABC): @abstractmethod - def create_envs(self, config: PersistableConfigProtocol | None = None) -> Environments: + def create_envs( + self, + num_training_envs: int, + num_test_envs: int, + config: PersistableConfigProtocol | None = None, + ) -> Environments: pass - def __call__(self, config: PersistableConfigProtocol | None = None) -> Environments: - return self.create_envs(config=config) + def __call__( + self, + num_training_envs: int, + num_test_envs: int, + config: PersistableConfigProtocol | None = None, + ) -> Environments: + return self.create_envs(num_training_envs, num_test_envs, config=config) diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index fa1cc0e..7d7a26c 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -136,14 +136,17 @@ class Experiment(ToStringMixin): def __init__( self, config: ExperimentConfig, - env_factory: EnvFactory | Callable[[PersistableConfigProtocol | None], Environments], + env_factory: EnvFactory + | Callable[[int, int, PersistableConfigProtocol | None], Environments], agent_factory: AgentFactory, + sampling_config: SamplingConfig, logger_factory: LoggerFactory | None = None, env_config: PersistableConfigProtocol | None = None, ): if logger_factory is None: logger_factory = LoggerFactoryDefault() self.config = config + self.sampling_config = sampling_config self.env_factory = env_factory self.agent_factory = agent_factory self.logger_factory = logger_factory @@ -214,7 +217,11 @@ class Experiment(ToStringMixin): self._set_seed() # create environments - envs = self.env_factory(self.env_config) + envs = self.env_factory( + self.sampling_config.num_train_envs, + self.sampling_config.num_test_envs, + self.env_config, + ) log.info(f"Created {envs}") # initialize persistence @@ -416,6 +423,7 @@ class ExperimentBuilder: self._config, self._env_factory, agent_factory, + self._sampling_config, self._logger_factory, env_config=self._env_config, )