Change interface of EnvFactory to ensure that configuration

of number of environments in SamplingConfig is used
(values are now passed to factory method)

This is clearer and removes the need to pass otherwise
unnecessary configuration to environment factories at
construction
This commit is contained in:
Dominik Jain 2023-10-18 23:55:23 +02:00
parent 89ce40edc0
commit 6cbee188b8
20 changed files with 82 additions and 71 deletions

View File

@ -68,13 +68,7 @@ def main(
replay_buffer_save_only_last_obs=True, replay_buffer_save_only_last_obs=True,
) )
env_factory = AtariEnvFactory( env_factory = AtariEnvFactory(task, experiment_config.seed, frames_stack, scale=scale_obs)
task,
experiment_config.seed,
sampling_config,
frames_stack,
scale=scale_obs,
)
builder = ( builder = (
DQNExperimentBuilder(env_factory, experiment_config, sampling_config) DQNExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@ -67,13 +67,7 @@ def main(
replay_buffer_save_only_last_obs=True, replay_buffer_save_only_last_obs=True,
) )
env_factory = AtariEnvFactory( env_factory = AtariEnvFactory(task, experiment_config.seed, frames_stack, scale=scale_obs)
task,
experiment_config.seed,
sampling_config,
frames_stack,
scale=scale_obs,
)
experiment = ( experiment = (
IQNExperimentBuilder(env_factory, experiment_config, sampling_config) IQNExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@ -73,7 +73,7 @@ def main(
replay_buffer_save_only_last_obs=True, replay_buffer_save_only_last_obs=True,
) )
env_factory = AtariEnvFactory(task, experiment_config.seed, sampling_config, frames_stack) env_factory = AtariEnvFactory(task, experiment_config.seed, frames_stack)
builder = ( builder = (
PPOExperimentBuilder(env_factory, experiment_config, sampling_config) PPOExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@ -67,13 +67,7 @@ def main(
replay_buffer_save_only_last_obs=True, replay_buffer_save_only_last_obs=True,
) )
env_factory = AtariEnvFactory( env_factory = AtariEnvFactory(task, experiment_config.seed, frames_stack, scale=scale_obs)
task,
experiment_config.seed,
sampling_config,
frames_stack,
scale=scale_obs,
)
builder = ( builder = (
DiscreteSACExperimentBuilder(env_factory, experiment_config, sampling_config) DiscreteSACExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@ -9,7 +9,6 @@ import gymnasium as gym
import numpy as np import numpy as np
from tianshou.env import ShmemVectorEnv from tianshou.env import ShmemVectorEnv
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.env import DiscreteEnvironments, EnvFactory from tianshou.highlevel.env import DiscreteEnvironments, EnvFactory
from tianshou.highlevel.trainer import TrainerStopCallback, TrainingContext from tianshou.highlevel.trainer import TrainerStopCallback, TrainingContext
@ -375,26 +374,23 @@ def make_atari_env(task, seed, training_num, test_num, **kwargs):
class AtariEnvFactory(EnvFactory): class AtariEnvFactory(EnvFactory):
def __init__( def __init__(self, task: str, seed: int, frame_stack: int, scale: int = 0):
self,
task: str,
seed: int,
sampling_config: SamplingConfig,
frame_stack: int,
scale: int = 0,
):
self.task = task self.task = task
self.sampling_config = sampling_config
self.seed = seed self.seed = seed
self.frame_stack = frame_stack self.frame_stack = frame_stack
self.scale = scale self.scale = scale
def create_envs(self, config=None) -> DiscreteEnvironments: def create_envs(
self,
num_training_envs: int,
num_test_envs: int,
config=None,
) -> DiscreteEnvironments:
env, train_envs, test_envs = make_atari_env( env, train_envs, test_envs = make_atari_env(
task=self.task, task=self.task,
seed=self.seed, seed=self.seed,
training_num=self.sampling_config.num_train_envs, training_num=num_training_envs,
test_num=self.sampling_config.num_test_envs, test_num=num_test_envs,
scale=self.scale, scale=self.scale,
frame_stack=self.frame_stack, frame_stack=self.frame_stack,
) )

View File

@ -55,7 +55,7 @@ def main(
repeat_per_collect=repeat_per_collect, repeat_per_collect=repeat_per_collect,
) )
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True) env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True)
experiment = ( experiment = (
A2CExperimentBuilder(env_factory, experiment_config, sampling_config) A2CExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@ -53,7 +53,7 @@ def main(
start_timesteps_random=True, start_timesteps_random=True,
) )
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=False) env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=False)
experiment = ( experiment = (
DDPGExperimentBuilder(env_factory, experiment_config, sampling_config) DDPGExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@ -5,7 +5,6 @@ import warnings
import gymnasium as gym import gymnasium as gym
from tianshou.env import ShmemVectorEnv, VectorEnvNormObs from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory
from tianshou.highlevel.persistence import Persistence, PersistEvent, RestoreEvent from tianshou.highlevel.persistence import Persistence, PersistEvent, RestoreEvent
from tianshou.highlevel.world import World from tianshou.highlevel.world import World
@ -70,18 +69,22 @@ class MujocoEnvObsRmsPersistence(Persistence):
class MujocoEnvFactory(EnvFactory): class MujocoEnvFactory(EnvFactory):
def __init__(self, task: str, seed: int, sampling_config: SamplingConfig, obs_norm=True): def __init__(self, task: str, seed: int, obs_norm=True):
self.task = task self.task = task
self.sampling_config = sampling_config
self.seed = seed self.seed = seed
self.obs_norm = obs_norm self.obs_norm = obs_norm
def create_envs(self, config=None) -> ContinuousEnvironments: def create_envs(
self,
num_training_envs: int,
num_test_envs: int,
config=None,
) -> ContinuousEnvironments:
env, train_envs, test_envs = make_mujoco_env( env, train_envs, test_envs = make_mujoco_env(
task=self.task, task=self.task,
seed=self.seed, seed=self.seed,
num_train_envs=self.sampling_config.num_train_envs, num_train_envs=num_training_envs,
num_test_envs=self.sampling_config.num_test_envs, num_test_envs=num_test_envs,
obs_norm=self.obs_norm, obs_norm=self.obs_norm,
) )
envs = ContinuousEnvironments(env=env, train_envs=train_envs, test_envs=test_envs) envs = ContinuousEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)

View File

@ -57,7 +57,7 @@ def main(
repeat_per_collect=repeat_per_collect, repeat_per_collect=repeat_per_collect,
) )
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True) env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True)
experiment = ( experiment = (
NPGExperimentBuilder(env_factory, experiment_config, sampling_config) NPGExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@ -62,7 +62,7 @@ def main(
repeat_per_collect=repeat_per_collect, repeat_per_collect=repeat_per_collect,
) )
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True) env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True)
experiment = ( experiment = (
PPOExperimentBuilder(env_factory, experiment_config, sampling_config) PPOExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@ -59,7 +59,7 @@ def main(
start_timesteps_random=True, start_timesteps_random=True,
) )
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=False) env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=False)
experiment = ( experiment = (
REDQExperimentBuilder(env_factory, experiment_config, sampling_config) REDQExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@ -50,7 +50,7 @@ def main(
repeat_per_collect=repeat_per_collect, repeat_per_collect=repeat_per_collect,
) )
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True) env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True)
experiment = ( experiment = (
PGExperimentBuilder(env_factory, experiment_config, sampling_config) PGExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@ -55,7 +55,7 @@ def main(
start_timesteps_random=True, start_timesteps_random=True,
) )
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=False) env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=False)
experiment = ( experiment = (
SACExperimentBuilder(env_factory, experiment_config, sampling_config) SACExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@ -59,7 +59,7 @@ def main(
start_timesteps_random=True, start_timesteps_random=True,
) )
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=False) env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=False)
experiment = ( experiment = (
TD3ExperimentBuilder(env_factory, experiment_config, sampling_config) TD3ExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@ -59,7 +59,7 @@ def main(
repeat_per_collect=repeat_per_collect, repeat_per_collect=repeat_per_collect,
) )
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True) env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True)
experiment = ( experiment = (
TRPOExperimentBuilder(env_factory, experiment_config, sampling_config) TRPOExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@ -11,26 +11,28 @@ from tianshou.highlevel.persistence import PersistableConfigProtocol
class DiscreteTestEnvFactory(EnvFactory): class DiscreteTestEnvFactory(EnvFactory):
def __init__(self, test_num=10, train_num=10): def create_envs(
self.test_num = test_num self,
self.train_num = train_num num_training_envs: int,
num_test_envs: int,
def create_envs(self, config: PersistableConfigProtocol | None = None) -> Environments: config: PersistableConfigProtocol | None = None,
) -> Environments:
task = "CartPole-v0" task = "CartPole-v0"
env = gym.make(task) env = gym.make(task)
train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(self.train_num)]) train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_training_envs)])
test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(self.test_num)]) test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)])
return DiscreteEnvironments(env, train_envs, test_envs) return DiscreteEnvironments(env, train_envs, test_envs)
class ContinuousTestEnvFactory(EnvFactory): class ContinuousTestEnvFactory(EnvFactory):
def __init__(self, test_num=10, train_num=10): def create_envs(
self.test_num = test_num self,
self.train_num = train_num num_training_envs: int,
num_test_envs: int,
def create_envs(self, config: PersistableConfigProtocol | None = None) -> Environments: config: PersistableConfigProtocol | None = None,
) -> Environments:
task = "Pendulum-v1" task = "Pendulum-v1"
env = gym.make(task) env = gym.make(task)
train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(self.train_num)]) train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_training_envs)])
test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(self.test_num)]) test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)])
return ContinuousEnvironments(env, train_envs, test_envs) return ContinuousEnvironments(env, train_envs, test_envs)

View File

@ -32,7 +32,12 @@ from tianshou.highlevel.experiment import (
) )
def test_experiment_builder_continuous_default_params(builder_cls): def test_experiment_builder_continuous_default_params(builder_cls):
env_factory = ContinuousTestEnvFactory() env_factory = ContinuousTestEnvFactory()
sampling_config = SamplingConfig(num_epochs=1, step_per_epoch=100) sampling_config = SamplingConfig(
num_epochs=1,
step_per_epoch=100,
num_train_envs=2,
num_test_envs=2,
)
experiment_config = ExperimentConfig(persistence_enabled=False) experiment_config = ExperimentConfig(persistence_enabled=False)
builder = builder_cls( builder = builder_cls(
experiment_config=experiment_config, experiment_config=experiment_config,

View File

@ -25,7 +25,12 @@ from tianshou.highlevel.experiment import (
) )
def test_experiment_builder_discrete_default_params(builder_cls): def test_experiment_builder_discrete_default_params(builder_cls):
env_factory = DiscreteTestEnvFactory() env_factory = DiscreteTestEnvFactory()
sampling_config = SamplingConfig(num_epochs=1, step_per_epoch=100) sampling_config = SamplingConfig(
num_epochs=1,
step_per_epoch=100,
num_train_envs=2,
num_test_envs=2,
)
builder = builder_cls( builder = builder_cls(
experiment_config=ExperimentConfig(persistence_enabled=False), experiment_config=ExperimentConfig(persistence_enabled=False),
env_factory=env_factory, env_factory=env_factory,

View File

@ -140,8 +140,18 @@ class DiscreteEnvironments(Environments):
class EnvFactory(ToStringMixin, ABC): class EnvFactory(ToStringMixin, ABC):
@abstractmethod @abstractmethod
def create_envs(self, config: PersistableConfigProtocol | None = None) -> Environments: def create_envs(
self,
num_training_envs: int,
num_test_envs: int,
config: PersistableConfigProtocol | None = None,
) -> Environments:
pass pass
def __call__(self, config: PersistableConfigProtocol | None = None) -> Environments: def __call__(
return self.create_envs(config=config) self,
num_training_envs: int,
num_test_envs: int,
config: PersistableConfigProtocol | None = None,
) -> Environments:
return self.create_envs(num_training_envs, num_test_envs, config=config)

View File

@ -136,14 +136,17 @@ class Experiment(ToStringMixin):
def __init__( def __init__(
self, self,
config: ExperimentConfig, config: ExperimentConfig,
env_factory: EnvFactory | Callable[[PersistableConfigProtocol | None], Environments], env_factory: EnvFactory
| Callable[[int, int, PersistableConfigProtocol | None], Environments],
agent_factory: AgentFactory, agent_factory: AgentFactory,
sampling_config: SamplingConfig,
logger_factory: LoggerFactory | None = None, logger_factory: LoggerFactory | None = None,
env_config: PersistableConfigProtocol | None = None, env_config: PersistableConfigProtocol | None = None,
): ):
if logger_factory is None: if logger_factory is None:
logger_factory = LoggerFactoryDefault() logger_factory = LoggerFactoryDefault()
self.config = config self.config = config
self.sampling_config = sampling_config
self.env_factory = env_factory self.env_factory = env_factory
self.agent_factory = agent_factory self.agent_factory = agent_factory
self.logger_factory = logger_factory self.logger_factory = logger_factory
@ -214,7 +217,11 @@ class Experiment(ToStringMixin):
self._set_seed() self._set_seed()
# create environments # create environments
envs = self.env_factory(self.env_config) envs = self.env_factory(
self.sampling_config.num_train_envs,
self.sampling_config.num_test_envs,
self.env_config,
)
log.info(f"Created {envs}") log.info(f"Created {envs}")
# initialize persistence # initialize persistence
@ -416,6 +423,7 @@ class ExperimentBuilder:
self._config, self._config,
self._env_factory, self._env_factory,
agent_factory, agent_factory,
self._sampling_config,
self._logger_factory, self._logger_factory,
env_config=self._env_config, env_config=self._env_config,
) )