Change interface of EnvFactory to ensure that the number of environments
configured in SamplingConfig is used (the values are now passed to the factory method). This is clearer and removes the need to pass otherwise unnecessary configuration to environment factories at construction.
This commit is contained in:
parent
89ce40edc0
commit
6cbee188b8
@ -68,13 +68,7 @@ def main(
|
||||
replay_buffer_save_only_last_obs=True,
|
||||
)
|
||||
|
||||
env_factory = AtariEnvFactory(
|
||||
task,
|
||||
experiment_config.seed,
|
||||
sampling_config,
|
||||
frames_stack,
|
||||
scale=scale_obs,
|
||||
)
|
||||
env_factory = AtariEnvFactory(task, experiment_config.seed, frames_stack, scale=scale_obs)
|
||||
|
||||
builder = (
|
||||
DQNExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||
|
@ -67,13 +67,7 @@ def main(
|
||||
replay_buffer_save_only_last_obs=True,
|
||||
)
|
||||
|
||||
env_factory = AtariEnvFactory(
|
||||
task,
|
||||
experiment_config.seed,
|
||||
sampling_config,
|
||||
frames_stack,
|
||||
scale=scale_obs,
|
||||
)
|
||||
env_factory = AtariEnvFactory(task, experiment_config.seed, frames_stack, scale=scale_obs)
|
||||
|
||||
experiment = (
|
||||
IQNExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||
|
@ -73,7 +73,7 @@ def main(
|
||||
replay_buffer_save_only_last_obs=True,
|
||||
)
|
||||
|
||||
env_factory = AtariEnvFactory(task, experiment_config.seed, sampling_config, frames_stack)
|
||||
env_factory = AtariEnvFactory(task, experiment_config.seed, frames_stack)
|
||||
|
||||
builder = (
|
||||
PPOExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||
|
@ -67,13 +67,7 @@ def main(
|
||||
replay_buffer_save_only_last_obs=True,
|
||||
)
|
||||
|
||||
env_factory = AtariEnvFactory(
|
||||
task,
|
||||
experiment_config.seed,
|
||||
sampling_config,
|
||||
frames_stack,
|
||||
scale=scale_obs,
|
||||
)
|
||||
env_factory = AtariEnvFactory(task, experiment_config.seed, frames_stack, scale=scale_obs)
|
||||
|
||||
builder = (
|
||||
DiscreteSACExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||
|
@ -9,7 +9,6 @@ import gymnasium as gym
|
||||
import numpy as np
|
||||
|
||||
from tianshou.env import ShmemVectorEnv
|
||||
from tianshou.highlevel.config import SamplingConfig
|
||||
from tianshou.highlevel.env import DiscreteEnvironments, EnvFactory
|
||||
from tianshou.highlevel.trainer import TrainerStopCallback, TrainingContext
|
||||
|
||||
@ -375,26 +374,23 @@ def make_atari_env(task, seed, training_num, test_num, **kwargs):
|
||||
|
||||
|
||||
class AtariEnvFactory(EnvFactory):
|
||||
def __init__(
|
||||
self,
|
||||
task: str,
|
||||
seed: int,
|
||||
sampling_config: SamplingConfig,
|
||||
frame_stack: int,
|
||||
scale: int = 0,
|
||||
):
|
||||
def __init__(self, task: str, seed: int, frame_stack: int, scale: int = 0):
|
||||
self.task = task
|
||||
self.sampling_config = sampling_config
|
||||
self.seed = seed
|
||||
self.frame_stack = frame_stack
|
||||
self.scale = scale
|
||||
|
||||
def create_envs(self, config=None) -> DiscreteEnvironments:
|
||||
def create_envs(
|
||||
self,
|
||||
num_training_envs: int,
|
||||
num_test_envs: int,
|
||||
config=None,
|
||||
) -> DiscreteEnvironments:
|
||||
env, train_envs, test_envs = make_atari_env(
|
||||
task=self.task,
|
||||
seed=self.seed,
|
||||
training_num=self.sampling_config.num_train_envs,
|
||||
test_num=self.sampling_config.num_test_envs,
|
||||
training_num=num_training_envs,
|
||||
test_num=num_test_envs,
|
||||
scale=self.scale,
|
||||
frame_stack=self.frame_stack,
|
||||
)
|
||||
|
@ -55,7 +55,7 @@ def main(
|
||||
repeat_per_collect=repeat_per_collect,
|
||||
)
|
||||
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True)
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True)
|
||||
|
||||
experiment = (
|
||||
A2CExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||
|
@ -53,7 +53,7 @@ def main(
|
||||
start_timesteps_random=True,
|
||||
)
|
||||
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=False)
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=False)
|
||||
|
||||
experiment = (
|
||||
DDPGExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||
|
@ -5,7 +5,6 @@ import warnings
|
||||
import gymnasium as gym
|
||||
|
||||
from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
|
||||
from tianshou.highlevel.config import SamplingConfig
|
||||
from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory
|
||||
from tianshou.highlevel.persistence import Persistence, PersistEvent, RestoreEvent
|
||||
from tianshou.highlevel.world import World
|
||||
@ -70,18 +69,22 @@ class MujocoEnvObsRmsPersistence(Persistence):
|
||||
|
||||
|
||||
class MujocoEnvFactory(EnvFactory):
|
||||
def __init__(self, task: str, seed: int, sampling_config: SamplingConfig, obs_norm=True):
|
||||
def __init__(self, task: str, seed: int, obs_norm=True):
|
||||
self.task = task
|
||||
self.sampling_config = sampling_config
|
||||
self.seed = seed
|
||||
self.obs_norm = obs_norm
|
||||
|
||||
def create_envs(self, config=None) -> ContinuousEnvironments:
|
||||
def create_envs(
|
||||
self,
|
||||
num_training_envs: int,
|
||||
num_test_envs: int,
|
||||
config=None,
|
||||
) -> ContinuousEnvironments:
|
||||
env, train_envs, test_envs = make_mujoco_env(
|
||||
task=self.task,
|
||||
seed=self.seed,
|
||||
num_train_envs=self.sampling_config.num_train_envs,
|
||||
num_test_envs=self.sampling_config.num_test_envs,
|
||||
num_train_envs=num_training_envs,
|
||||
num_test_envs=num_test_envs,
|
||||
obs_norm=self.obs_norm,
|
||||
)
|
||||
envs = ContinuousEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)
|
||||
|
@ -57,7 +57,7 @@ def main(
|
||||
repeat_per_collect=repeat_per_collect,
|
||||
)
|
||||
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True)
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True)
|
||||
|
||||
experiment = (
|
||||
NPGExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||
|
@ -62,7 +62,7 @@ def main(
|
||||
repeat_per_collect=repeat_per_collect,
|
||||
)
|
||||
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True)
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True)
|
||||
|
||||
experiment = (
|
||||
PPOExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||
|
@ -59,7 +59,7 @@ def main(
|
||||
start_timesteps_random=True,
|
||||
)
|
||||
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=False)
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=False)
|
||||
|
||||
experiment = (
|
||||
REDQExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||
|
@ -50,7 +50,7 @@ def main(
|
||||
repeat_per_collect=repeat_per_collect,
|
||||
)
|
||||
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True)
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True)
|
||||
|
||||
experiment = (
|
||||
PGExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||
|
@ -55,7 +55,7 @@ def main(
|
||||
start_timesteps_random=True,
|
||||
)
|
||||
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=False)
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=False)
|
||||
|
||||
experiment = (
|
||||
SACExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||
|
@ -59,7 +59,7 @@ def main(
|
||||
start_timesteps_random=True,
|
||||
)
|
||||
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=False)
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=False)
|
||||
|
||||
experiment = (
|
||||
TD3ExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||
|
@ -59,7 +59,7 @@ def main(
|
||||
repeat_per_collect=repeat_per_collect,
|
||||
)
|
||||
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config, obs_norm=True)
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True)
|
||||
|
||||
experiment = (
|
||||
TRPOExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||
|
@ -11,26 +11,28 @@ from tianshou.highlevel.persistence import PersistableConfigProtocol
|
||||
|
||||
|
||||
class DiscreteTestEnvFactory(EnvFactory):
|
||||
def __init__(self, test_num=10, train_num=10):
|
||||
self.test_num = test_num
|
||||
self.train_num = train_num
|
||||
|
||||
def create_envs(self, config: PersistableConfigProtocol | None = None) -> Environments:
|
||||
def create_envs(
|
||||
self,
|
||||
num_training_envs: int,
|
||||
num_test_envs: int,
|
||||
config: PersistableConfigProtocol | None = None,
|
||||
) -> Environments:
|
||||
task = "CartPole-v0"
|
||||
env = gym.make(task)
|
||||
train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(self.train_num)])
|
||||
test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(self.test_num)])
|
||||
train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_training_envs)])
|
||||
test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)])
|
||||
return DiscreteEnvironments(env, train_envs, test_envs)
|
||||
|
||||
|
||||
class ContinuousTestEnvFactory(EnvFactory):
|
||||
def __init__(self, test_num=10, train_num=10):
|
||||
self.test_num = test_num
|
||||
self.train_num = train_num
|
||||
|
||||
def create_envs(self, config: PersistableConfigProtocol | None = None) -> Environments:
|
||||
def create_envs(
|
||||
self,
|
||||
num_training_envs: int,
|
||||
num_test_envs: int,
|
||||
config: PersistableConfigProtocol | None = None,
|
||||
) -> Environments:
|
||||
task = "Pendulum-v1"
|
||||
env = gym.make(task)
|
||||
train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(self.train_num)])
|
||||
test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(self.test_num)])
|
||||
train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_training_envs)])
|
||||
test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)])
|
||||
return ContinuousEnvironments(env, train_envs, test_envs)
|
||||
|
@ -32,7 +32,12 @@ from tianshou.highlevel.experiment import (
|
||||
)
|
||||
def test_experiment_builder_continuous_default_params(builder_cls):
|
||||
env_factory = ContinuousTestEnvFactory()
|
||||
sampling_config = SamplingConfig(num_epochs=1, step_per_epoch=100)
|
||||
sampling_config = SamplingConfig(
|
||||
num_epochs=1,
|
||||
step_per_epoch=100,
|
||||
num_train_envs=2,
|
||||
num_test_envs=2,
|
||||
)
|
||||
experiment_config = ExperimentConfig(persistence_enabled=False)
|
||||
builder = builder_cls(
|
||||
experiment_config=experiment_config,
|
||||
|
@ -25,7 +25,12 @@ from tianshou.highlevel.experiment import (
|
||||
)
|
||||
def test_experiment_builder_discrete_default_params(builder_cls):
|
||||
env_factory = DiscreteTestEnvFactory()
|
||||
sampling_config = SamplingConfig(num_epochs=1, step_per_epoch=100)
|
||||
sampling_config = SamplingConfig(
|
||||
num_epochs=1,
|
||||
step_per_epoch=100,
|
||||
num_train_envs=2,
|
||||
num_test_envs=2,
|
||||
)
|
||||
builder = builder_cls(
|
||||
experiment_config=ExperimentConfig(persistence_enabled=False),
|
||||
env_factory=env_factory,
|
||||
|
@ -140,8 +140,18 @@ class DiscreteEnvironments(Environments):
|
||||
|
||||
class EnvFactory(ToStringMixin, ABC):
|
||||
@abstractmethod
|
||||
def create_envs(self, config: PersistableConfigProtocol | None = None) -> Environments:
|
||||
def create_envs(
|
||||
self,
|
||||
num_training_envs: int,
|
||||
num_test_envs: int,
|
||||
config: PersistableConfigProtocol | None = None,
|
||||
) -> Environments:
|
||||
pass
|
||||
|
||||
def __call__(self, config: PersistableConfigProtocol | None = None) -> Environments:
|
||||
return self.create_envs(config=config)
|
||||
def __call__(
|
||||
self,
|
||||
num_training_envs: int,
|
||||
num_test_envs: int,
|
||||
config: PersistableConfigProtocol | None = None,
|
||||
) -> Environments:
|
||||
return self.create_envs(num_training_envs, num_test_envs, config=config)
|
||||
|
@ -136,14 +136,17 @@ class Experiment(ToStringMixin):
|
||||
def __init__(
|
||||
self,
|
||||
config: ExperimentConfig,
|
||||
env_factory: EnvFactory | Callable[[PersistableConfigProtocol | None], Environments],
|
||||
env_factory: EnvFactory
|
||||
| Callable[[int, int, PersistableConfigProtocol | None], Environments],
|
||||
agent_factory: AgentFactory,
|
||||
sampling_config: SamplingConfig,
|
||||
logger_factory: LoggerFactory | None = None,
|
||||
env_config: PersistableConfigProtocol | None = None,
|
||||
):
|
||||
if logger_factory is None:
|
||||
logger_factory = LoggerFactoryDefault()
|
||||
self.config = config
|
||||
self.sampling_config = sampling_config
|
||||
self.env_factory = env_factory
|
||||
self.agent_factory = agent_factory
|
||||
self.logger_factory = logger_factory
|
||||
@ -214,7 +217,11 @@ class Experiment(ToStringMixin):
|
||||
self._set_seed()
|
||||
|
||||
# create environments
|
||||
envs = self.env_factory(self.env_config)
|
||||
envs = self.env_factory(
|
||||
self.sampling_config.num_train_envs,
|
||||
self.sampling_config.num_test_envs,
|
||||
self.env_config,
|
||||
)
|
||||
log.info(f"Created {envs}")
|
||||
|
||||
# initialize persistence
|
||||
@ -416,6 +423,7 @@ class ExperimentBuilder:
|
||||
self._config,
|
||||
self._env_factory,
|
||||
agent_factory,
|
||||
self._sampling_config,
|
||||
self._logger_factory,
|
||||
env_config=self._env_config,
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user