Added explicit env seeding for train and test envs (separate `train_seed` and `test_seed` replace the single `seed`)
This commit is contained in:
parent
6746a80f6d
commit
95cbfe6cdf
@ -66,7 +66,7 @@ def main(
|
||||
replay_buffer_save_only_last_obs=True,
|
||||
)
|
||||
|
||||
env_factory = AtariEnvFactory(task, experiment_config.seed, frames_stack, scale=scale_obs)
|
||||
env_factory = AtariEnvFactory(task, sampling_config.train_seed, sampling_config.test_seed, frames_stack, scale=scale_obs)
|
||||
|
||||
builder = (
|
||||
DQNExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||
|
@ -65,7 +65,7 @@ def main(
|
||||
replay_buffer_save_only_last_obs=True,
|
||||
)
|
||||
|
||||
env_factory = AtariEnvFactory(task, experiment_config.seed, frames_stack, scale=scale_obs)
|
||||
env_factory = AtariEnvFactory(task, sampling_config.train_seed, sampling_config.test_seed, frames_stack, scale=scale_obs)
|
||||
|
||||
experiment = (
|
||||
IQNExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||
|
@ -65,7 +65,7 @@ def main(
|
||||
replay_buffer_save_only_last_obs=True,
|
||||
)
|
||||
|
||||
env_factory = AtariEnvFactory(task, experiment_config.seed, frames_stack, scale=scale_obs)
|
||||
env_factory = AtariEnvFactory(task, sampling_config.train_seed, sampling_config.test_seed, frames_stack, scale=scale_obs)
|
||||
|
||||
builder = (
|
||||
DiscreteSACExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||
|
@ -340,7 +340,7 @@ def make_atari_env(
|
||||
|
||||
:return: a tuple of (single env, training envs, test envs).
|
||||
"""
|
||||
env_factory = AtariEnvFactory(task, seed, frame_stack, scale=bool(scale))
|
||||
env_factory = AtariEnvFactory(task, seed, seed + training_num, frame_stack, scale=bool(scale))
|
||||
envs = env_factory.create_envs(training_num, test_num)
|
||||
return envs.env, envs.train_envs, envs.test_envs
|
||||
|
||||
@ -349,7 +349,8 @@ class AtariEnvFactory(EnvFactoryRegistered):
|
||||
def __init__(
|
||||
self,
|
||||
task: str,
|
||||
seed: int,
|
||||
train_seed: int,
|
||||
test_seed: int,
|
||||
frame_stack: int,
|
||||
scale: bool = False,
|
||||
use_envpool_if_available: bool = True,
|
||||
@ -366,7 +367,8 @@ class AtariEnvFactory(EnvFactoryRegistered):
|
||||
log.info("Not using envpool, because it is not available")
|
||||
super().__init__(
|
||||
task=task,
|
||||
seed=seed,
|
||||
train_seed=train_seed,
|
||||
test_seed=test_seed,
|
||||
venv_type=VectorEnvType.SUBPROC_SHARED_MEM,
|
||||
envpool_factory=envpool_factory,
|
||||
)
|
||||
|
@ -54,7 +54,7 @@ def main(
|
||||
repeat_per_collect=repeat_per_collect,
|
||||
)
|
||||
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True)
|
||||
env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True)
|
||||
|
||||
experiment = (
|
||||
A2CExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||
|
@ -51,7 +51,7 @@ def main(
|
||||
start_timesteps_random=True,
|
||||
)
|
||||
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=False)
|
||||
env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=False)
|
||||
|
||||
experiment = (
|
||||
DDPGExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||
|
@ -29,7 +29,7 @@ def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: in
|
||||
|
||||
:return: a tuple of (single env, training envs, test envs).
|
||||
"""
|
||||
envs = MujocoEnvFactory(task, seed, obs_norm=obs_norm).create_envs(
|
||||
envs = MujocoEnvFactory(task, seed, seed + num_train_envs, obs_norm=obs_norm).create_envs(
|
||||
num_train_envs,
|
||||
num_test_envs,
|
||||
)
|
||||
@ -62,10 +62,11 @@ class MujocoEnvObsRmsPersistence(Persistence):
|
||||
|
||||
|
||||
class MujocoEnvFactory(EnvFactoryRegistered):
|
||||
def __init__(self, task: str, seed: int, obs_norm=True) -> None:
|
||||
def __init__(self, task: str, train_seed: int, test_seed: int, obs_norm=True) -> None:
|
||||
super().__init__(
|
||||
task=task,
|
||||
seed=seed,
|
||||
train_seed=train_seed,
|
||||
test_seed=test_seed,
|
||||
venv_type=VectorEnvType.SUBPROC_SHARED_MEM,
|
||||
envpool_factory=EnvPoolFactory() if envpool_is_available else None,
|
||||
)
|
||||
|
@ -56,7 +56,7 @@ def main(
|
||||
repeat_per_collect=repeat_per_collect,
|
||||
)
|
||||
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True)
|
||||
env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True)
|
||||
|
||||
experiment = (
|
||||
NPGExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||
|
@ -61,7 +61,7 @@ def main(
|
||||
repeat_per_collect=repeat_per_collect,
|
||||
)
|
||||
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True)
|
||||
env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True)
|
||||
|
||||
experiment = (
|
||||
PPOExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||
|
@ -57,7 +57,7 @@ def main(
|
||||
start_timesteps_random=True,
|
||||
)
|
||||
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=False)
|
||||
env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=False)
|
||||
|
||||
experiment = (
|
||||
REDQExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||
|
@ -49,7 +49,7 @@ def main(
|
||||
repeat_per_collect=repeat_per_collect,
|
||||
)
|
||||
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True)
|
||||
env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True)
|
||||
|
||||
experiment = (
|
||||
PGExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||
|
@ -52,7 +52,7 @@ def main(
|
||||
start_timesteps_random=True,
|
||||
)
|
||||
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=False)
|
||||
env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=False)
|
||||
|
||||
experiment = (
|
||||
SACExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||
|
@ -58,7 +58,7 @@ def main(
|
||||
start_timesteps_random=True,
|
||||
)
|
||||
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=False)
|
||||
env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=False)
|
||||
|
||||
experiment = (
|
||||
TD3ExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||
|
@ -58,7 +58,7 @@ def main(
|
||||
repeat_per_collect=repeat_per_collect,
|
||||
)
|
||||
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True)
|
||||
env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True)
|
||||
|
||||
experiment = (
|
||||
TRPOExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||
|
@ -6,9 +6,9 @@ from tianshou.highlevel.env import (
|
||||
|
||||
class DiscreteTestEnvFactory(EnvFactoryRegistered):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(task="CartPole-v0", seed=42, venv_type=VectorEnvType.DUMMY)
|
||||
super().__init__(task="CartPole-v0", train_seed=42, test_seed=1337, venv_type=VectorEnvType.DUMMY)
|
||||
|
||||
|
||||
class ContinuousTestEnvFactory(EnvFactoryRegistered):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(task="Pendulum-v1", seed=42, venv_type=VectorEnvType.DUMMY)
|
||||
super().__init__(task="Pendulum-v1", train_seed=42, test_seed=1337, venv_type=VectorEnvType.DUMMY)
|
||||
|
@ -50,6 +50,9 @@ class SamplingConfig(ToStringMixin):
|
||||
num_train_envs: int = -1
|
||||
"""the number of training environments to use. If set to -1, use number of CPUs/threads."""
|
||||
|
||||
train_seed: int = 42
|
||||
"""the seed to use for the training environments."""
|
||||
|
||||
num_test_envs: int = 1
|
||||
"""the number of test environments to use"""
|
||||
|
||||
@ -126,6 +129,10 @@ class SamplingConfig(ToStringMixin):
|
||||
temporal aspects (e.g. velocities of moving objects for which only positions are observed).
|
||||
"""
|
||||
|
||||
@property
|
||||
def test_seed(self) -> int:
|
||||
return self.train_seed + self.num_train_envs + 1
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.num_train_envs == -1:
|
||||
self.num_train_envs = multiprocessing.cpu_count()
|
||||
|
@ -395,7 +395,8 @@ class EnvFactoryRegistered(EnvFactory):
|
||||
self,
|
||||
*,
|
||||
task: str,
|
||||
seed: int,
|
||||
train_seed: int,
|
||||
test_seed: int,
|
||||
venv_type: VectorEnvType,
|
||||
envpool_factory: EnvPoolFactory | None = None,
|
||||
render_mode_train: str | None = None,
|
||||
@ -417,7 +418,8 @@ class EnvFactoryRegistered(EnvFactory):
|
||||
super().__init__(venv_type)
|
||||
self.task = task
|
||||
self.envpool_factory = envpool_factory
|
||||
self.seed = seed
|
||||
self.train_seed = train_seed
|
||||
self.test_seed = test_seed
|
||||
self.render_modes = {
|
||||
EnvMode.TRAIN: render_mode_train,
|
||||
EnvMode.TEST: render_mode_test,
|
||||
@ -445,15 +447,16 @@ class EnvFactoryRegistered(EnvFactory):
|
||||
return gymnasium.make(self.task, **kwargs)
|
||||
|
||||
def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv:
|
||||
seed = self.train_seed if mode == EnvMode.TRAIN else self.test_seed
|
||||
if self.envpool_factory is not None:
|
||||
return self.envpool_factory.create_venv(
|
||||
self.task,
|
||||
num_envs,
|
||||
mode,
|
||||
self.seed,
|
||||
seed,
|
||||
self._create_kwargs(mode),
|
||||
)
|
||||
else:
|
||||
venv = super().create_venv(num_envs, mode)
|
||||
venv.seed(self.seed)
|
||||
venv.seed(seed)
|
||||
return venv
|
||||
|
Loading…
x
Reference in New Issue
Block a user