Environments: Add option to use a different factory for test envs

to `from_factory` convenience construction mechanisms
This commit is contained in:
Dominik Jain 2023-12-18 12:52:05 +01:00
parent 45a1a3f259
commit e8cc80f990

View File

@ -80,6 +80,7 @@ class Environments(ToStringMixin, ABC):
venv_type: VectorEnvType,
num_training_envs: int,
num_test_envs: int,
test_factory_fn: Callable[[], gym.Env] | None = None,
) -> "Environments":
"""Creates a suitable subtype instance from a factory function that creates a single instance and the type of environment (continuous/discrete).
@ -88,10 +89,14 @@ class Environments(ToStringMixin, ABC):
:param venv_type: the vector environment type to use for parallelization
:param num_training_envs: the number of training environments to create
:param num_test_envs: the number of test environments to create
:param test_factory_fn: the factory to use for the creation of test environment instances;
if None, use `factory_fn` for all environments (train and test)
:return: the instance
"""
if test_factory_fn is None:
test_factory_fn = factory_fn
train_envs = venv_type.create_venv([factory_fn] * num_training_envs)
test_envs = venv_type.create_venv([factory_fn] * num_test_envs)
test_envs = venv_type.create_venv([test_factory_fn] * num_test_envs)
env = factory_fn()
match env_type:
case EnvType.CONTINUOUS:
@ -152,6 +157,7 @@ class ContinuousEnvironments(Environments):
venv_type: VectorEnvType,
num_training_envs: int,
num_test_envs: int,
test_factory_fn: Callable[[], gym.Env] | None = None,
) -> "ContinuousEnvironments":
"""Creates an instance from a factory function that creates a single instance.
@ -159,6 +165,8 @@ class ContinuousEnvironments(Environments):
:param venv_type: the vector environment type to use for parallelization
:param num_training_envs: the number of training environments to create
:param num_test_envs: the number of test environments to create
:param test_factory_fn: the factory to use for the creation of test environment instances;
if None, use `factory_fn` for all environments (train and test)
:return: the instance
"""
return cast(
@ -169,6 +177,7 @@ class ContinuousEnvironments(Environments):
venv_type,
num_training_envs,
num_test_envs,
test_factory_fn=test_factory_fn,
),
)
@ -217,6 +226,7 @@ class DiscreteEnvironments(Environments):
venv_type: VectorEnvType,
num_training_envs: int,
num_test_envs: int,
test_factory_fn: Callable[[], gym.Env] | None = None,
) -> "DiscreteEnvironments":
"""Creates an instance from a factory function that creates a single instance.
@ -224,6 +234,8 @@ class DiscreteEnvironments(Environments):
:param venv_type: the vector environment type to use for parallelization
:param num_training_envs: the number of training environments to create
:param num_test_envs: the number of test environments to create
:param test_factory_fn: the factory to use for the creation of test environment instances;
if None, use `factory_fn` for all environments (train and test)
:return: the instance
"""
return cast(
@ -234,6 +246,7 @@ class DiscreteEnvironments(Environments):
venv_type,
num_training_envs,
num_test_envs,
test_factory_fn=test_factory_fn,
),
)