Change interface of EnvFactory to ensure that configuration
of number of environments in SamplingConfig is used (values are now passed to factory method) This is clearer and removes the need to pass otherwise unnecessary configuration to environment factories at construction
This commit is contained in:
parent
89ce40edc0
commit
6cbee188b8
@ -68,13 +68,7 @@ def main(
|
|||||||
replay_buffer_save_only_last_obs=True,
|
replay_buffer_save_only_last_obs=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
env_factory = AtariEnvFactory(
|
env_factory = AtariEnvFactory(task, experiment_config.seed, frames_stack, scale=scale_obs)
|
||||||
task,
|
|
||||||
experiment_config.seed,
|
|
||||||
sampling_config,
|
|
||||||
frames_stack,
|
|
||||||
scale=scale_obs,
|
|
||||||
)
|
|
||||||
|
|
||||||
builder = (
|
builder = (
|
||||||
DQNExperimentBuilder(env_factory, experiment_config, sampling_config)
|
DQNExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||||
|
@ -67,13 +67,7 @@ def main(
|
|||||||
replay_buffer_save_only_last_obs=True,
|
replay_buffer_save_only_last_obs=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
env_factory = AtariEnvFactory(
|
env_factory = AtariEnvFactory(task, experiment_config.seed, frames_stack, scale=scale_obs)
|
||||||
task,
|
|
||||||
experiment_config.seed,
|
|
||||||
sampling_config,
|
|
||||||
frames_stack,
|
|
||||||
scale=scale_obs,
|
|
||||||
)
|
|
||||||
|
|
||||||
experiment = (
|
experiment = (
|
||||||
IQNExperimentBuilder(env_factory, experiment_config, sampling_config)
|
IQNExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||||
|
@ -73,7 +73,7 @@ def main(
|
|||||||
replay_buffer_save_only_last_obs=True,
|
replay_buffer_save_only_last_obs=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
env_factory = AtariEnvFactory(task, experiment_config.seed, sampling_config, frames_stack)
|
env_factory = AtariEnvFactory(task, experiment_config.seed, frames_stack)
|
||||||
|
|
||||||
builder = (
|
builder = (
|
||||||
PPOExperimentBuilder(env_factory, experiment_config, sampling_config)
|
PPOExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||||
|
@ -67,13 +67,7 @@ def main(
|
|||||||
replay_buffer_save_only_last_obs=True,
|
replay_buffer_save_only_last_obs=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
env_factory = AtariEnvFactory(
|
env_factory = AtariEnvFactory(task, experiment_config.seed, frames_stack, scale=scale_obs)
|
||||||
task,
|
|
||||||
experiment_config.seed,
|
|
||||||
sampling_config,
|
|
||||||
frames_stack,
|
|
||||||
scale=scale_obs,
|
|
||||||
)
|
|
||||||
|
|
||||||
builder = (
|
builder = (
|
||||||
DiscreteSACExperimentBuilder(env_factory, experiment_config, sampling_config)
|
DiscreteSACExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||||
|
@ -9,7 +9,6 @@ import gymnasium as gym
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tianshou.env import ShmemVectorEnv
|
from tianshou.env import ShmemVectorEnv
|
||||||
from tianshou.highlevel.config import SamplingConfig
|
|
||||||
from tianshou.highlevel.env import DiscreteEnvironments, EnvFactory
|
from tianshou.highlevel.env import DiscreteEnvironments, EnvFactory
|
||||||
from tianshou.highlevel.trainer import TrainerStopCallback, TrainingContext
|
from tianshou.highlevel.trainer import TrainerStopCallback, TrainingContext
|
||||||
|
|
||||||
@ -375,26 +374,23 @@ def make_atari_env(task, seed, training_num, test_num, **kwargs):
|
|||||||
|
|
||||||
|
|
||||||
class AtariEnvFactory(EnvFactory):
|
class AtariEnvFactory(EnvFactory):
|
||||||
def __init__(
|
def __init__(self, task: str, seed: int, frame_stack: int, scale: int = 0):
|
||||||
self,
|
|
||||||
task: str,
|
|
||||||
seed: int,
|
|
||||||
sampling_config: SamplingConfig,
|
|
||||||
frame_stack: int,
|
|
||||||
scale: int = 0,
|
|
||||||
):
|
|
||||||
self.task = task
|
self.task = task
|
||||||
self.sampling_config = sampling_config
|
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
self.frame_stack = frame_stack
|
self.frame_stack = frame_stack
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
|
|
||||||
def create_envs(self, config=None) -> DiscreteEnvironments:
|
def create_envs(
|
||||||
|
self,
|
||||||
|
num_training_envs: int,
|
||||||
|
num_test_envs: int,
|
||||||
|
config=None,
|
||||||
|
) -> DiscreteEnvironments:
|
||||||
env, train_envs, test_envs = make_atari_env(
|
env, train_envs, test_envs = make_atari_env(
|
||||||
task=self.task,
|
task=self.task,
|
||||||
seed=self.seed,
|
seed=self.seed,
|
||||||
training_num=self.sampling_config.num_train_envs,
|
training_num=num_training_envs,
|
||||||
test_num=self.sampling_config.num_test_envs,
|
test_num=num_test_envs,
|
||||||
scale=self.scale,
|
scale=self.scale,
|
||||||
frame_stack=self.frame_stack,
|
frame_stack=self.frame_stack,
|
||||||
)
|
)
|
||||||
|
@ -55,7 +55,7 @@ def main(
|
|||||||
repeat_per_collect=repeat_per_collect,
|
repeat_per_collect=repeat_per_collect,
|
||||||
)
|
)
|
||||||
|
|
||||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True)
|
env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True)
|
||||||
|
|
||||||
experiment = (
|
experiment = (
|
||||||
A2CExperimentBuilder(env_factory, experiment_config, sampling_config)
|
A2CExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||||
|
@ -53,7 +53,7 @@ def main(
|
|||||||
start_timesteps_random=True,
|
start_timesteps_random=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=False)
|
env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=False)
|
||||||
|
|
||||||
experiment = (
|
experiment = (
|
||||||
DDPGExperimentBuilder(env_factory, experiment_config, sampling_config)
|
DDPGExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||||
|
@ -5,7 +5,6 @@ import warnings
|
|||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
|
|
||||||
from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
|
from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
|
||||||
from tianshou.highlevel.config import SamplingConfig
|
|
||||||
from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory
|
from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory
|
||||||
from tianshou.highlevel.persistence import Persistence, PersistEvent, RestoreEvent
|
from tianshou.highlevel.persistence import Persistence, PersistEvent, RestoreEvent
|
||||||
from tianshou.highlevel.world import World
|
from tianshou.highlevel.world import World
|
||||||
@ -70,18 +69,22 @@ class MujocoEnvObsRmsPersistence(Persistence):
|
|||||||
|
|
||||||
|
|
||||||
class MujocoEnvFactory(EnvFactory):
|
class MujocoEnvFactory(EnvFactory):
|
||||||
def __init__(self, task: str, seed: int, sampling_config: SamplingConfig, obs_norm=True):
|
def __init__(self, task: str, seed: int, obs_norm=True):
|
||||||
self.task = task
|
self.task = task
|
||||||
self.sampling_config = sampling_config
|
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
self.obs_norm = obs_norm
|
self.obs_norm = obs_norm
|
||||||
|
|
||||||
def create_envs(self, config=None) -> ContinuousEnvironments:
|
def create_envs(
|
||||||
|
self,
|
||||||
|
num_training_envs: int,
|
||||||
|
num_test_envs: int,
|
||||||
|
config=None,
|
||||||
|
) -> ContinuousEnvironments:
|
||||||
env, train_envs, test_envs = make_mujoco_env(
|
env, train_envs, test_envs = make_mujoco_env(
|
||||||
task=self.task,
|
task=self.task,
|
||||||
seed=self.seed,
|
seed=self.seed,
|
||||||
num_train_envs=self.sampling_config.num_train_envs,
|
num_train_envs=num_training_envs,
|
||||||
num_test_envs=self.sampling_config.num_test_envs,
|
num_test_envs=num_test_envs,
|
||||||
obs_norm=self.obs_norm,
|
obs_norm=self.obs_norm,
|
||||||
)
|
)
|
||||||
envs = ContinuousEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)
|
envs = ContinuousEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)
|
||||||
|
@ -57,7 +57,7 @@ def main(
|
|||||||
repeat_per_collect=repeat_per_collect,
|
repeat_per_collect=repeat_per_collect,
|
||||||
)
|
)
|
||||||
|
|
||||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True)
|
env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True)
|
||||||
|
|
||||||
experiment = (
|
experiment = (
|
||||||
NPGExperimentBuilder(env_factory, experiment_config, sampling_config)
|
NPGExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||||
|
@ -62,7 +62,7 @@ def main(
|
|||||||
repeat_per_collect=repeat_per_collect,
|
repeat_per_collect=repeat_per_collect,
|
||||||
)
|
)
|
||||||
|
|
||||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True)
|
env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True)
|
||||||
|
|
||||||
experiment = (
|
experiment = (
|
||||||
PPOExperimentBuilder(env_factory, experiment_config, sampling_config)
|
PPOExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||||
|
@ -59,7 +59,7 @@ def main(
|
|||||||
start_timesteps_random=True,
|
start_timesteps_random=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=False)
|
env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=False)
|
||||||
|
|
||||||
experiment = (
|
experiment = (
|
||||||
REDQExperimentBuilder(env_factory, experiment_config, sampling_config)
|
REDQExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||||
|
@ -50,7 +50,7 @@ def main(
|
|||||||
repeat_per_collect=repeat_per_collect,
|
repeat_per_collect=repeat_per_collect,
|
||||||
)
|
)
|
||||||
|
|
||||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True)
|
env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True)
|
||||||
|
|
||||||
experiment = (
|
experiment = (
|
||||||
PGExperimentBuilder(env_factory, experiment_config, sampling_config)
|
PGExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||||
|
@ -55,7 +55,7 @@ def main(
|
|||||||
start_timesteps_random=True,
|
start_timesteps_random=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=False)
|
env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=False)
|
||||||
|
|
||||||
experiment = (
|
experiment = (
|
||||||
SACExperimentBuilder(env_factory, experiment_config, sampling_config)
|
SACExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||||
|
@ -59,7 +59,7 @@ def main(
|
|||||||
start_timesteps_random=True,
|
start_timesteps_random=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=False)
|
env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=False)
|
||||||
|
|
||||||
experiment = (
|
experiment = (
|
||||||
TD3ExperimentBuilder(env_factory, experiment_config, sampling_config)
|
TD3ExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||||
|
@ -59,7 +59,7 @@ def main(
|
|||||||
repeat_per_collect=repeat_per_collect,
|
repeat_per_collect=repeat_per_collect,
|
||||||
)
|
)
|
||||||
|
|
||||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True)
|
env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True)
|
||||||
|
|
||||||
experiment = (
|
experiment = (
|
||||||
TRPOExperimentBuilder(env_factory, experiment_config, sampling_config)
|
TRPOExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||||
|
@ -11,26 +11,28 @@ from tianshou.highlevel.persistence import PersistableConfigProtocol
|
|||||||
|
|
||||||
|
|
||||||
class DiscreteTestEnvFactory(EnvFactory):
|
class DiscreteTestEnvFactory(EnvFactory):
|
||||||
def __init__(self, test_num=10, train_num=10):
|
def create_envs(
|
||||||
self.test_num = test_num
|
self,
|
||||||
self.train_num = train_num
|
num_training_envs: int,
|
||||||
|
num_test_envs: int,
|
||||||
def create_envs(self, config: PersistableConfigProtocol | None = None) -> Environments:
|
config: PersistableConfigProtocol | None = None,
|
||||||
|
) -> Environments:
|
||||||
task = "CartPole-v0"
|
task = "CartPole-v0"
|
||||||
env = gym.make(task)
|
env = gym.make(task)
|
||||||
train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(self.train_num)])
|
train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_training_envs)])
|
||||||
test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(self.test_num)])
|
test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)])
|
||||||
return DiscreteEnvironments(env, train_envs, test_envs)
|
return DiscreteEnvironments(env, train_envs, test_envs)
|
||||||
|
|
||||||
|
|
||||||
class ContinuousTestEnvFactory(EnvFactory):
|
class ContinuousTestEnvFactory(EnvFactory):
|
||||||
def __init__(self, test_num=10, train_num=10):
|
def create_envs(
|
||||||
self.test_num = test_num
|
self,
|
||||||
self.train_num = train_num
|
num_training_envs: int,
|
||||||
|
num_test_envs: int,
|
||||||
def create_envs(self, config: PersistableConfigProtocol | None = None) -> Environments:
|
config: PersistableConfigProtocol | None = None,
|
||||||
|
) -> Environments:
|
||||||
task = "Pendulum-v1"
|
task = "Pendulum-v1"
|
||||||
env = gym.make(task)
|
env = gym.make(task)
|
||||||
train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(self.train_num)])
|
train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_training_envs)])
|
||||||
test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(self.test_num)])
|
test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)])
|
||||||
return ContinuousEnvironments(env, train_envs, test_envs)
|
return ContinuousEnvironments(env, train_envs, test_envs)
|
||||||
|
@ -32,7 +32,12 @@ from tianshou.highlevel.experiment import (
|
|||||||
)
|
)
|
||||||
def test_experiment_builder_continuous_default_params(builder_cls):
|
def test_experiment_builder_continuous_default_params(builder_cls):
|
||||||
env_factory = ContinuousTestEnvFactory()
|
env_factory = ContinuousTestEnvFactory()
|
||||||
sampling_config = SamplingConfig(num_epochs=1, step_per_epoch=100)
|
sampling_config = SamplingConfig(
|
||||||
|
num_epochs=1,
|
||||||
|
step_per_epoch=100,
|
||||||
|
num_train_envs=2,
|
||||||
|
num_test_envs=2,
|
||||||
|
)
|
||||||
experiment_config = ExperimentConfig(persistence_enabled=False)
|
experiment_config = ExperimentConfig(persistence_enabled=False)
|
||||||
builder = builder_cls(
|
builder = builder_cls(
|
||||||
experiment_config=experiment_config,
|
experiment_config=experiment_config,
|
||||||
|
@ -25,7 +25,12 @@ from tianshou.highlevel.experiment import (
|
|||||||
)
|
)
|
||||||
def test_experiment_builder_discrete_default_params(builder_cls):
|
def test_experiment_builder_discrete_default_params(builder_cls):
|
||||||
env_factory = DiscreteTestEnvFactory()
|
env_factory = DiscreteTestEnvFactory()
|
||||||
sampling_config = SamplingConfig(num_epochs=1, step_per_epoch=100)
|
sampling_config = SamplingConfig(
|
||||||
|
num_epochs=1,
|
||||||
|
step_per_epoch=100,
|
||||||
|
num_train_envs=2,
|
||||||
|
num_test_envs=2,
|
||||||
|
)
|
||||||
builder = builder_cls(
|
builder = builder_cls(
|
||||||
experiment_config=ExperimentConfig(persistence_enabled=False),
|
experiment_config=ExperimentConfig(persistence_enabled=False),
|
||||||
env_factory=env_factory,
|
env_factory=env_factory,
|
||||||
|
@ -140,8 +140,18 @@ class DiscreteEnvironments(Environments):
|
|||||||
|
|
||||||
class EnvFactory(ToStringMixin, ABC):
|
class EnvFactory(ToStringMixin, ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create_envs(self, config: PersistableConfigProtocol | None = None) -> Environments:
|
def create_envs(
|
||||||
|
self,
|
||||||
|
num_training_envs: int,
|
||||||
|
num_test_envs: int,
|
||||||
|
config: PersistableConfigProtocol | None = None,
|
||||||
|
) -> Environments:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def __call__(self, config: PersistableConfigProtocol | None = None) -> Environments:
|
def __call__(
|
||||||
return self.create_envs(config=config)
|
self,
|
||||||
|
num_training_envs: int,
|
||||||
|
num_test_envs: int,
|
||||||
|
config: PersistableConfigProtocol | None = None,
|
||||||
|
) -> Environments:
|
||||||
|
return self.create_envs(num_training_envs, num_test_envs, config=config)
|
||||||
|
@ -136,14 +136,17 @@ class Experiment(ToStringMixin):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: ExperimentConfig,
|
config: ExperimentConfig,
|
||||||
env_factory: EnvFactory | Callable[[PersistableConfigProtocol | None], Environments],
|
env_factory: EnvFactory
|
||||||
|
| Callable[[int, int, PersistableConfigProtocol | None], Environments],
|
||||||
agent_factory: AgentFactory,
|
agent_factory: AgentFactory,
|
||||||
|
sampling_config: SamplingConfig,
|
||||||
logger_factory: LoggerFactory | None = None,
|
logger_factory: LoggerFactory | None = None,
|
||||||
env_config: PersistableConfigProtocol | None = None,
|
env_config: PersistableConfigProtocol | None = None,
|
||||||
):
|
):
|
||||||
if logger_factory is None:
|
if logger_factory is None:
|
||||||
logger_factory = LoggerFactoryDefault()
|
logger_factory = LoggerFactoryDefault()
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.sampling_config = sampling_config
|
||||||
self.env_factory = env_factory
|
self.env_factory = env_factory
|
||||||
self.agent_factory = agent_factory
|
self.agent_factory = agent_factory
|
||||||
self.logger_factory = logger_factory
|
self.logger_factory = logger_factory
|
||||||
@ -214,7 +217,11 @@ class Experiment(ToStringMixin):
|
|||||||
self._set_seed()
|
self._set_seed()
|
||||||
|
|
||||||
# create environments
|
# create environments
|
||||||
envs = self.env_factory(self.env_config)
|
envs = self.env_factory(
|
||||||
|
self.sampling_config.num_train_envs,
|
||||||
|
self.sampling_config.num_test_envs,
|
||||||
|
self.env_config,
|
||||||
|
)
|
||||||
log.info(f"Created {envs}")
|
log.info(f"Created {envs}")
|
||||||
|
|
||||||
# initialize persistence
|
# initialize persistence
|
||||||
@ -416,6 +423,7 @@ class ExperimentBuilder:
|
|||||||
self._config,
|
self._config,
|
||||||
self._env_factory,
|
self._env_factory,
|
||||||
agent_factory,
|
agent_factory,
|
||||||
|
self._sampling_config,
|
||||||
self._logger_factory,
|
self._logger_factory,
|
||||||
env_config=self._env_config,
|
env_config=self._env_config,
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user