diff --git a/examples/atari/atari_dqn_hl.py b/examples/atari/atari_dqn_hl.py index 887ebc8..edba36e 100644 --- a/examples/atari/atari_dqn_hl.py +++ b/examples/atari/atari_dqn_hl.py @@ -66,7 +66,7 @@ def main( replay_buffer_save_only_last_obs=True, ) - env_factory = AtariEnvFactory(task, experiment_config.seed, frames_stack, scale=scale_obs) + env_factory = AtariEnvFactory(task, sampling_config.train_seed, sampling_config.test_seed, frames_stack, scale=scale_obs) builder = ( DQNExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/atari/atari_iqn_hl.py b/examples/atari/atari_iqn_hl.py index dcdacf2..01f4f34 100644 --- a/examples/atari/atari_iqn_hl.py +++ b/examples/atari/atari_iqn_hl.py @@ -65,7 +65,7 @@ def main( replay_buffer_save_only_last_obs=True, ) - env_factory = AtariEnvFactory(task, experiment_config.seed, frames_stack, scale=scale_obs) + env_factory = AtariEnvFactory(task, sampling_config.train_seed, sampling_config.test_seed, frames_stack, scale=scale_obs) experiment = ( IQNExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/atari/atari_sac_hl.py b/examples/atari/atari_sac_hl.py index 1271567..9bcaf05 100644 --- a/examples/atari/atari_sac_hl.py +++ b/examples/atari/atari_sac_hl.py @@ -65,7 +65,7 @@ def main( replay_buffer_save_only_last_obs=True, ) - env_factory = AtariEnvFactory(task, experiment_config.seed, frames_stack, scale=scale_obs) + env_factory = AtariEnvFactory(task, sampling_config.train_seed, sampling_config.test_seed, frames_stack, scale=scale_obs) builder = ( DiscreteSACExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py index a2fdcca..9f0d97d 100644 --- a/examples/atari/atari_wrapper.py +++ b/examples/atari/atari_wrapper.py @@ -340,7 +340,7 @@ def make_atari_env( :return: a tuple of (single env, training envs, test envs). """ - env_factory = AtariEnvFactory(task, seed, frame_stack, scale=bool(scale)) + env_factory = AtariEnvFactory(task, seed, seed + training_num, frame_stack, scale=bool(scale)) envs = env_factory.create_envs(training_num, test_num) return envs.env, envs.train_envs, envs.test_envs @@ -349,7 +349,8 @@ class AtariEnvFactory(EnvFactoryRegistered): def __init__( self, task: str, - seed: int, + train_seed: int, + test_seed: int, frame_stack: int, scale: bool = False, use_envpool_if_available: bool = True, @@ -366,7 +367,8 @@ class AtariEnvFactory(EnvFactoryRegistered): log.info("Not using envpool, because it is not available") super().__init__( task=task, - seed=seed, + train_seed=train_seed, + test_seed=test_seed, venv_type=VectorEnvType.SUBPROC_SHARED_MEM, envpool_factory=envpool_factory, ) diff --git a/examples/mujoco/mujoco_a2c_hl.py b/examples/mujoco/mujoco_a2c_hl.py index fec2e26..23ace13 100644 --- a/examples/mujoco/mujoco_a2c_hl.py +++ b/examples/mujoco/mujoco_a2c_hl.py @@ -54,7 +54,7 @@ def main( repeat_per_collect=repeat_per_collect, ) - env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True) + env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True) experiment = ( A2CExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/mujoco/mujoco_ddpg_hl.py b/examples/mujoco/mujoco_ddpg_hl.py index 2bbc669..1a7abab 100644 --- a/examples/mujoco/mujoco_ddpg_hl.py +++ b/examples/mujoco/mujoco_ddpg_hl.py @@ -51,7 +51,7 @@ def main( start_timesteps_random=True, ) - env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=False) + env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=False) experiment = ( DDPGExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/mujoco/mujoco_env.py b/examples/mujoco/mujoco_env.py index b04f243..b115450 100644 --- a/examples/mujoco/mujoco_env.py +++ b/examples/mujoco/mujoco_env.py @@ -29,7 +29,7 @@ def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: in :return: a tuple of (single env, training envs, test envs). """ - envs = MujocoEnvFactory(task, seed, obs_norm=obs_norm).create_envs( + envs = MujocoEnvFactory(task, seed, seed + num_train_envs, obs_norm=obs_norm).create_envs( num_train_envs, num_test_envs, ) @@ -62,10 +62,11 @@ class MujocoEnvObsRmsPersistence(Persistence): class MujocoEnvFactory(EnvFactoryRegistered): - def __init__(self, task: str, seed: int, obs_norm=True) -> None: + def __init__(self, task: str, train_seed: int, test_seed: int, obs_norm=True) -> None: super().__init__( task=task, - seed=seed, + train_seed=train_seed, + test_seed=test_seed, venv_type=VectorEnvType.SUBPROC_SHARED_MEM, envpool_factory=EnvPoolFactory() if envpool_is_available else None, ) diff --git a/examples/mujoco/mujoco_npg_hl.py b/examples/mujoco/mujoco_npg_hl.py index 6ab0eb8..21594d0 100644 --- a/examples/mujoco/mujoco_npg_hl.py +++ b/examples/mujoco/mujoco_npg_hl.py @@ -56,7 +56,7 @@ def main( repeat_per_collect=repeat_per_collect, ) - env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True) + env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True) experiment = ( NPGExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/mujoco/mujoco_ppo_hl.py b/examples/mujoco/mujoco_ppo_hl.py index dbc6fb5..898d448 100644 --- a/examples/mujoco/mujoco_ppo_hl.py +++ b/examples/mujoco/mujoco_ppo_hl.py @@ -61,7 +61,7 @@ def main( repeat_per_collect=repeat_per_collect, ) - env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True) + env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True) experiment = ( PPOExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/mujoco/mujoco_redq_hl.py b/examples/mujoco/mujoco_redq_hl.py index 78e0d34..7607b37 100644 --- a/examples/mujoco/mujoco_redq_hl.py +++ b/examples/mujoco/mujoco_redq_hl.py @@ -57,7 +57,7 @@ def main( start_timesteps_random=True, ) - env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=False) + env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=False) experiment = ( REDQExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/mujoco/mujoco_reinforce_hl.py b/examples/mujoco/mujoco_reinforce_hl.py index bc07e05..0ff6537 100644 --- a/examples/mujoco/mujoco_reinforce_hl.py +++ b/examples/mujoco/mujoco_reinforce_hl.py @@ -49,7 +49,7 @@ def main( repeat_per_collect=repeat_per_collect, ) - env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True) + env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True) experiment = ( PGExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/mujoco/mujoco_sac_hl.py b/examples/mujoco/mujoco_sac_hl.py index c6a6a3b..1d71fc7 100644 --- a/examples/mujoco/mujoco_sac_hl.py +++ b/examples/mujoco/mujoco_sac_hl.py @@ -52,7 +52,7 @@ def main( start_timesteps_random=True, ) - env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=False) + env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=False) experiment = ( SACExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/mujoco/mujoco_td3_hl.py b/examples/mujoco/mujoco_td3_hl.py index 73d20fe..7ea32b8 100644 --- a/examples/mujoco/mujoco_td3_hl.py +++ b/examples/mujoco/mujoco_td3_hl.py @@ -58,7 +58,7 @@ def main( start_timesteps_random=True, ) - env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=False) + env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=False) experiment = ( TD3ExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/mujoco/mujoco_trpo_hl.py b/examples/mujoco/mujoco_trpo_hl.py index 2f9a777..2e947b7 100644 --- a/examples/mujoco/mujoco_trpo_hl.py +++ b/examples/mujoco/mujoco_trpo_hl.py @@ -58,7 +58,7 @@ def main( repeat_per_collect=repeat_per_collect, ) - env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True) + env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True) experiment = ( TRPOExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/test/highlevel/env_factory.py b/test/highlevel/env_factory.py index 1dd1273..c526e2d 100644 --- a/test/highlevel/env_factory.py +++ b/test/highlevel/env_factory.py @@ -6,9 +6,9 @@ from tianshou.highlevel.env import ( class DiscreteTestEnvFactory(EnvFactoryRegistered): def __init__(self) -> None: - super().__init__(task="CartPole-v0", seed=42, venv_type=VectorEnvType.DUMMY) + super().__init__(task="CartPole-v0", train_seed=42, test_seed=1337, venv_type=VectorEnvType.DUMMY) class ContinuousTestEnvFactory(EnvFactoryRegistered): def __init__(self) -> None: - super().__init__(task="Pendulum-v1", seed=42, venv_type=VectorEnvType.DUMMY) + super().__init__(task="Pendulum-v1", train_seed=42, test_seed=1337, venv_type=VectorEnvType.DUMMY) diff --git a/tianshou/highlevel/config.py b/tianshou/highlevel/config.py index 4982472..a798401 100644 --- a/tianshou/highlevel/config.py +++ b/tianshou/highlevel/config.py @@ -50,6 +50,9 @@ class SamplingConfig(ToStringMixin): num_train_envs: int = -1 """the number of training environments to use. If set to -1, use number of CPUs/threads.""" + train_seed: int = 42 + """the seed to use for the training environments.""" + num_test_envs: int = 1 """the number of test environments to use""" @@ -126,6 +129,10 @@ class SamplingConfig(ToStringMixin): temporal aspects (e.g. velocities of moving objects for which only positions are observed). """ + @property + def test_seed(self) -> int: + return self.train_seed + self.num_train_envs + 1 + def __post_init__(self) -> None: if self.num_train_envs == -1: self.num_train_envs = multiprocessing.cpu_count() diff --git a/tianshou/highlevel/env.py b/tianshou/highlevel/env.py index 71de0f8..08acd13 100644 --- a/tianshou/highlevel/env.py +++ b/tianshou/highlevel/env.py @@ -395,7 +395,8 @@ class EnvFactoryRegistered(EnvFactory): self, *, task: str, - seed: int, + train_seed: int, + test_seed: int, venv_type: VectorEnvType, envpool_factory: EnvPoolFactory | None = None, render_mode_train: str | None = None, @@ -417,7 +418,8 @@ class EnvFactoryRegistered(EnvFactory): super().__init__(venv_type) self.task = task self.envpool_factory = envpool_factory - self.seed = seed + self.train_seed = train_seed + self.test_seed = test_seed self.render_modes = { EnvMode.TRAIN: render_mode_train, EnvMode.TEST: render_mode_test, @@ -445,15 +447,16 @@ class EnvFactoryRegistered(EnvFactory): return gymnasium.make(self.task, **kwargs) def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv: + seed = self.train_seed if mode == EnvMode.TRAIN else self.test_seed if self.envpool_factory is not None: return self.envpool_factory.create_venv( self.task, num_envs, mode, - self.seed, + seed, self._create_kwargs(mode), ) else: venv = super().create_venv(num_envs, mode) - venv.seed(self.seed) + venv.seed(seed) return venv