added explicit env seeding for train and test envs

This commit is contained in:
Maximilian Huettenrauch 2024-03-06 17:09:06 +01:00
parent 6746a80f6d
commit 95cbfe6cdf
17 changed files with 37 additions and 24 deletions

View File

@ -66,7 +66,7 @@ def main(
replay_buffer_save_only_last_obs=True,
)
env_factory = AtariEnvFactory(task, experiment_config.seed, frames_stack, scale=scale_obs)
env_factory = AtariEnvFactory(task, sampling_config.train_seed, sampling_config.test_seed, frames_stack, scale=scale_obs)
builder = (
DQNExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@ -65,7 +65,7 @@ def main(
replay_buffer_save_only_last_obs=True,
)
env_factory = AtariEnvFactory(task, experiment_config.seed, frames_stack, scale=scale_obs)
env_factory = AtariEnvFactory(task, sampling_config.train_seed, sampling_config.test_seed, frames_stack, scale=scale_obs)
experiment = (
IQNExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@ -65,7 +65,7 @@ def main(
replay_buffer_save_only_last_obs=True,
)
env_factory = AtariEnvFactory(task, experiment_config.seed, frames_stack, scale=scale_obs)
env_factory = AtariEnvFactory(task, sampling_config.train_seed, sampling_config.test_seed, frames_stack, scale=scale_obs)
builder = (
DiscreteSACExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@ -340,7 +340,7 @@ def make_atari_env(
:return: a tuple of (single env, training envs, test envs).
"""
env_factory = AtariEnvFactory(task, seed, frame_stack, scale=bool(scale))
env_factory = AtariEnvFactory(task, seed, seed + training_num, frame_stack, scale=bool(scale))
envs = env_factory.create_envs(training_num, test_num)
return envs.env, envs.train_envs, envs.test_envs
@ -349,7 +349,8 @@ class AtariEnvFactory(EnvFactoryRegistered):
def __init__(
self,
task: str,
seed: int,
train_seed: int,
test_seed: int,
frame_stack: int,
scale: bool = False,
use_envpool_if_available: bool = True,
@ -366,7 +367,8 @@ class AtariEnvFactory(EnvFactoryRegistered):
log.info("Not using envpool, because it is not available")
super().__init__(
task=task,
seed=seed,
train_seed=train_seed,
test_seed=test_seed,
venv_type=VectorEnvType.SUBPROC_SHARED_MEM,
envpool_factory=envpool_factory,
)

View File

@ -54,7 +54,7 @@ def main(
repeat_per_collect=repeat_per_collect,
)
env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True)
env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True)
experiment = (
A2CExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@ -51,7 +51,7 @@ def main(
start_timesteps_random=True,
)
env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=False)
env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=False)
experiment = (
DDPGExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@ -29,7 +29,7 @@ def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: in
:return: a tuple of (single env, training envs, test envs).
"""
envs = MujocoEnvFactory(task, seed, obs_norm=obs_norm).create_envs(
envs = MujocoEnvFactory(task, seed, seed + num_train_envs, obs_norm=obs_norm).create_envs(
num_train_envs,
num_test_envs,
)
@ -62,10 +62,11 @@ class MujocoEnvObsRmsPersistence(Persistence):
class MujocoEnvFactory(EnvFactoryRegistered):
def __init__(self, task: str, seed: int, obs_norm=True) -> None:
def __init__(self, task: str, train_seed: int, test_seed: int, obs_norm=True) -> None:
super().__init__(
task=task,
seed=seed,
train_seed=train_seed,
test_seed=test_seed,
venv_type=VectorEnvType.SUBPROC_SHARED_MEM,
envpool_factory=EnvPoolFactory() if envpool_is_available else None,
)

View File

@ -56,7 +56,7 @@ def main(
repeat_per_collect=repeat_per_collect,
)
env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True)
env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True)
experiment = (
NPGExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@ -61,7 +61,7 @@ def main(
repeat_per_collect=repeat_per_collect,
)
env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True)
env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True)
experiment = (
PPOExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@ -57,7 +57,7 @@ def main(
start_timesteps_random=True,
)
env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=False)
env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=False)
experiment = (
REDQExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@ -49,7 +49,7 @@ def main(
repeat_per_collect=repeat_per_collect,
)
env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True)
env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True)
experiment = (
PGExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@ -52,7 +52,7 @@ def main(
start_timesteps_random=True,
)
env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=False)
env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=False)
experiment = (
SACExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@ -58,7 +58,7 @@ def main(
start_timesteps_random=True,
)
env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=False)
env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=False)
experiment = (
TD3ExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@ -58,7 +58,7 @@ def main(
repeat_per_collect=repeat_per_collect,
)
env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True)
env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True)
experiment = (
TRPOExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@ -6,9 +6,9 @@ from tianshou.highlevel.env import (
# Environment factory for the "CartPole-v0" task using the DUMMY vector env type.
class DiscreteTestEnvFactory(EnvFactoryRegistered):
def __init__(self) -> None:
# Diff view: the next line is the pre-change call, which passed a single shared seed.
super().__init__(task="CartPole-v0", seed=42, venv_type=VectorEnvType.DUMMY)
# Post-change call: training envs are seeded with 42 and test envs with 1337.
super().__init__(task="CartPole-v0", train_seed=42, test_seed=1337, venv_type=VectorEnvType.DUMMY)
# Environment factory for the "Pendulum-v1" task using the DUMMY vector env type.
class ContinuousTestEnvFactory(EnvFactoryRegistered):
def __init__(self) -> None:
# Diff view: the next line is the pre-change call, which passed a single shared seed.
super().__init__(task="Pendulum-v1", seed=42, venv_type=VectorEnvType.DUMMY)
# Post-change call: training envs are seeded with 42 and test envs with 1337.
super().__init__(task="Pendulum-v1", train_seed=42, test_seed=1337, venv_type=VectorEnvType.DUMMY)

View File

@ -50,6 +50,9 @@ class SamplingConfig(ToStringMixin):
num_train_envs: int = -1
"""the number of training environments to use. If set to -1, use number of CPUs/threads."""
train_seed: int = 42
"""the seed to use for the training environments."""
num_test_envs: int = 1
"""the number of test environments to use"""
@ -126,6 +129,10 @@ class SamplingConfig(ToStringMixin):
temporal aspects (e.g. velocities of moving objects for which only positions are observed).
"""
@property
def test_seed(self) -> int:
    """Return the seed to use for the test environments.

    The value is derived rather than configured: it is computed as
    ``train_seed + num_train_envs + 1``, so it changes whenever either of
    those two settings changes.

    NOTE(review): presumably the offset is meant to keep the test seed
    clear of the seeds consumed by the training environments - confirm
    against how the vector env expands a single seed.

    NOTE: the result is only meaningful once ``num_train_envs`` holds a real
    count. While it is still the ``-1`` placeholder the expression collapses
    to ``train_seed`` itself.

    :return: the integer seed for the test environments.
    """
    return self.train_seed + self.num_train_envs + 1
def __post_init__(self) -> None:
if self.num_train_envs == -1:
self.num_train_envs = multiprocessing.cpu_count()

View File

@ -395,7 +395,8 @@ class EnvFactoryRegistered(EnvFactory):
self,
*,
task: str,
seed: int,
train_seed: int,
test_seed: int,
venv_type: VectorEnvType,
envpool_factory: EnvPoolFactory | None = None,
render_mode_train: str | None = None,
@ -417,7 +418,8 @@ class EnvFactoryRegistered(EnvFactory):
super().__init__(venv_type)
self.task = task
self.envpool_factory = envpool_factory
self.seed = seed
self.train_seed = train_seed
self.test_seed = test_seed
self.render_modes = {
EnvMode.TRAIN: render_mode_train,
EnvMode.TEST: render_mode_test,
@ -445,15 +447,16 @@ class EnvFactoryRegistered(EnvFactory):
return gymnasium.make(self.task, **kwargs)
# Create the vectorized environment for `mode`, seeded with the seed that belongs to that mode.
def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv:
# Training mode uses train_seed; every other mode uses test_seed.
seed = self.train_seed if mode == EnvMode.TRAIN else self.test_seed
if self.envpool_factory is not None:
# An envpool factory is configured: delegate creation to it and hand over the seed.
return self.envpool_factory.create_venv(
self.task,
num_envs,
mode,
# Diff view: the next line is the pre-change argument, superseded by `seed` right after it.
self.seed,
seed,
self._create_kwargs(mode),
)
else:
# No envpool factory: build the venv via the parent class, then seed it explicitly.
venv = super().create_venv(num_envs, mode)
# Diff view: the next line is the pre-change call, superseded by `venv.seed(seed)` right after it.
venv.seed(self.seed)
venv.seed(seed)
return venv