Change interface of EnvFactory to ensure that configuration

of number of environments in SamplingConfig is used
(values are now passed to factory method)

This is clearer and removes the need to pass otherwise
unnecessary configuration to environment factories at
construction
This commit is contained in:
Dominik Jain 2023-10-18 23:55:23 +02:00
parent 89ce40edc0
commit 6cbee188b8
20 changed files with 82 additions and 71 deletions

View File

@ -68,13 +68,7 @@ def main(
replay_buffer_save_only_last_obs=True,
)
env_factory = AtariEnvFactory(
task,
experiment_config.seed,
sampling_config,
frames_stack,
scale=scale_obs,
)
env_factory = AtariEnvFactory(task, experiment_config.seed, frames_stack, scale=scale_obs)
builder = (
DQNExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@ -67,13 +67,7 @@ def main(
replay_buffer_save_only_last_obs=True,
)
env_factory = AtariEnvFactory(
task,
experiment_config.seed,
sampling_config,
frames_stack,
scale=scale_obs,
)
env_factory = AtariEnvFactory(task, experiment_config.seed, frames_stack, scale=scale_obs)
experiment = (
IQNExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@ -73,7 +73,7 @@ def main(
replay_buffer_save_only_last_obs=True,
)
env_factory = AtariEnvFactory(task, experiment_config.seed, sampling_config, frames_stack)
env_factory = AtariEnvFactory(task, experiment_config.seed, frames_stack)
builder = (
PPOExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@ -67,13 +67,7 @@ def main(
replay_buffer_save_only_last_obs=True,
)
env_factory = AtariEnvFactory(
task,
experiment_config.seed,
sampling_config,
frames_stack,
scale=scale_obs,
)
env_factory = AtariEnvFactory(task, experiment_config.seed, frames_stack, scale=scale_obs)
builder = (
DiscreteSACExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@ -9,7 +9,6 @@ import gymnasium as gym
import numpy as np
from tianshou.env import ShmemVectorEnv
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.env import DiscreteEnvironments, EnvFactory
from tianshou.highlevel.trainer import TrainerStopCallback, TrainingContext
@ -375,26 +374,23 @@ def make_atari_env(task, seed, training_num, test_num, **kwargs):
class AtariEnvFactory(EnvFactory):
def __init__(
self,
task: str,
seed: int,
sampling_config: SamplingConfig,
frame_stack: int,
scale: int = 0,
):
def __init__(self, task: str, seed: int, frame_stack: int, scale: int = 0):
self.task = task
self.sampling_config = sampling_config
self.seed = seed
self.frame_stack = frame_stack
self.scale = scale
def create_envs(self, config=None) -> DiscreteEnvironments:
def create_envs(
self,
num_training_envs: int,
num_test_envs: int,
config=None,
) -> DiscreteEnvironments:
env, train_envs, test_envs = make_atari_env(
task=self.task,
seed=self.seed,
training_num=self.sampling_config.num_train_envs,
test_num=self.sampling_config.num_test_envs,
training_num=num_training_envs,
test_num=num_test_envs,
scale=self.scale,
frame_stack=self.frame_stack,
)

View File

@ -55,7 +55,7 @@ def main(
repeat_per_collect=repeat_per_collect,
)
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True)
env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True)
experiment = (
A2CExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@ -53,7 +53,7 @@ def main(
start_timesteps_random=True,
)
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=False)
env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=False)
experiment = (
DDPGExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@ -5,7 +5,6 @@ import warnings
import gymnasium as gym
from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory
from tianshou.highlevel.persistence import Persistence, PersistEvent, RestoreEvent
from tianshou.highlevel.world import World
@ -70,18 +69,22 @@ class MujocoEnvObsRmsPersistence(Persistence):
class MujocoEnvFactory(EnvFactory):
def __init__(self, task: str, seed: int, sampling_config: SamplingConfig, obs_norm=True):
def __init__(self, task: str, seed: int, obs_norm=True):
self.task = task
self.sampling_config = sampling_config
self.seed = seed
self.obs_norm = obs_norm
def create_envs(self, config=None) -> ContinuousEnvironments:
def create_envs(
self,
num_training_envs: int,
num_test_envs: int,
config=None,
) -> ContinuousEnvironments:
env, train_envs, test_envs = make_mujoco_env(
task=self.task,
seed=self.seed,
num_train_envs=self.sampling_config.num_train_envs,
num_test_envs=self.sampling_config.num_test_envs,
num_train_envs=num_training_envs,
num_test_envs=num_test_envs,
obs_norm=self.obs_norm,
)
envs = ContinuousEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)

View File

@ -57,7 +57,7 @@ def main(
repeat_per_collect=repeat_per_collect,
)
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True)
env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True)
experiment = (
NPGExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@ -62,7 +62,7 @@ def main(
repeat_per_collect=repeat_per_collect,
)
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True)
env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True)
experiment = (
PPOExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@ -59,7 +59,7 @@ def main(
start_timesteps_random=True,
)
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=False)
env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=False)
experiment = (
REDQExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@ -50,7 +50,7 @@ def main(
repeat_per_collect=repeat_per_collect,
)
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True)
env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True)
experiment = (
PGExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@ -55,7 +55,7 @@ def main(
start_timesteps_random=True,
)
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=False)
env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=False)
experiment = (
SACExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@ -59,7 +59,7 @@ def main(
start_timesteps_random=True,
)
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=False)
env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=False)
experiment = (
TD3ExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@ -59,7 +59,7 @@ def main(
repeat_per_collect=repeat_per_collect,
)
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True)
env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True)
experiment = (
TRPOExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@ -11,26 +11,28 @@ from tianshou.highlevel.persistence import PersistableConfigProtocol
class DiscreteTestEnvFactory(EnvFactory):
def __init__(self, test_num=10, train_num=10):
self.test_num = test_num
self.train_num = train_num
def create_envs(self, config: PersistableConfigProtocol | None = None) -> Environments:
def create_envs(
self,
num_training_envs: int,
num_test_envs: int,
config: PersistableConfigProtocol | None = None,
) -> Environments:
task = "CartPole-v0"
env = gym.make(task)
train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(self.train_num)])
test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(self.test_num)])
train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_training_envs)])
test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)])
return DiscreteEnvironments(env, train_envs, test_envs)
class ContinuousTestEnvFactory(EnvFactory):
def __init__(self, test_num=10, train_num=10):
self.test_num = test_num
self.train_num = train_num
def create_envs(self, config: PersistableConfigProtocol | None = None) -> Environments:
def create_envs(
self,
num_training_envs: int,
num_test_envs: int,
config: PersistableConfigProtocol | None = None,
) -> Environments:
task = "Pendulum-v1"
env = gym.make(task)
train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(self.train_num)])
test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(self.test_num)])
train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_training_envs)])
test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)])
return ContinuousEnvironments(env, train_envs, test_envs)

View File

@ -32,7 +32,12 @@ from tianshou.highlevel.experiment import (
)
def test_experiment_builder_continuous_default_params(builder_cls):
env_factory = ContinuousTestEnvFactory()
sampling_config = SamplingConfig(num_epochs=1, step_per_epoch=100)
sampling_config = SamplingConfig(
num_epochs=1,
step_per_epoch=100,
num_train_envs=2,
num_test_envs=2,
)
experiment_config = ExperimentConfig(persistence_enabled=False)
builder = builder_cls(
experiment_config=experiment_config,

View File

@ -25,7 +25,12 @@ from tianshou.highlevel.experiment import (
)
def test_experiment_builder_discrete_default_params(builder_cls):
env_factory = DiscreteTestEnvFactory()
sampling_config = SamplingConfig(num_epochs=1, step_per_epoch=100)
sampling_config = SamplingConfig(
num_epochs=1,
step_per_epoch=100,
num_train_envs=2,
num_test_envs=2,
)
builder = builder_cls(
experiment_config=ExperimentConfig(persistence_enabled=False),
env_factory=env_factory,

View File

@ -140,8 +140,18 @@ class DiscreteEnvironments(Environments):
class EnvFactory(ToStringMixin, ABC):
@abstractmethod
def create_envs(self, config: PersistableConfigProtocol | None = None) -> Environments:
def create_envs(
self,
num_training_envs: int,
num_test_envs: int,
config: PersistableConfigProtocol | None = None,
) -> Environments:
pass
def __call__(self, config: PersistableConfigProtocol | None = None) -> Environments:
return self.create_envs(config=config)
def __call__(
self,
num_training_envs: int,
num_test_envs: int,
config: PersistableConfigProtocol | None = None,
) -> Environments:
return self.create_envs(num_training_envs, num_test_envs, config=config)

View File

@ -136,14 +136,17 @@ class Experiment(ToStringMixin):
def __init__(
self,
config: ExperimentConfig,
env_factory: EnvFactory | Callable[[PersistableConfigProtocol | None], Environments],
env_factory: EnvFactory
| Callable[[int, int, PersistableConfigProtocol | None], Environments],
agent_factory: AgentFactory,
sampling_config: SamplingConfig,
logger_factory: LoggerFactory | None = None,
env_config: PersistableConfigProtocol | None = None,
):
if logger_factory is None:
logger_factory = LoggerFactoryDefault()
self.config = config
self.sampling_config = sampling_config
self.env_factory = env_factory
self.agent_factory = agent_factory
self.logger_factory = logger_factory
@ -214,7 +217,11 @@ class Experiment(ToStringMixin):
self._set_seed()
# create environments
envs = self.env_factory(self.env_config)
envs = self.env_factory(
self.sampling_config.num_train_envs,
self.sampling_config.num_test_envs,
self.env_config,
)
log.info(f"Created {envs}")
# initialize persistence
@ -416,6 +423,7 @@ class ExperimentBuilder:
self._config,
self._env_factory,
agent_factory,
self._sampling_config,
self._logger_factory,
env_config=self._env_config,
)