Environments: Add option to a use a different factory for test envs
to `from_factory` convenience construction mechanisms
This commit is contained in:
parent
45a1a3f259
commit
e8cc80f990
@ -80,6 +80,7 @@ class Environments(ToStringMixin, ABC):
|
|||||||
venv_type: VectorEnvType,
|
venv_type: VectorEnvType,
|
||||||
num_training_envs: int,
|
num_training_envs: int,
|
||||||
num_test_envs: int,
|
num_test_envs: int,
|
||||||
|
test_factory_fn: Callable[[], gym.Env] | None = None,
|
||||||
) -> "Environments":
|
) -> "Environments":
|
||||||
"""Creates a suitable subtype instance from a factory function that creates a single instance and the type of environment (continuous/discrete).
|
"""Creates a suitable subtype instance from a factory function that creates a single instance and the type of environment (continuous/discrete).
|
||||||
|
|
||||||
@ -88,10 +89,14 @@ class Environments(ToStringMixin, ABC):
|
|||||||
:param venv_type: the vector environment type to use for parallelization
|
:param venv_type: the vector environment type to use for parallelization
|
||||||
:param num_training_envs: the number of training environments to create
|
:param num_training_envs: the number of training environments to create
|
||||||
:param num_test_envs: the number of test environments to create
|
:param num_test_envs: the number of test environments to create
|
||||||
|
:param test_factory_fn: the factory to use for the creation of test environment instances;
|
||||||
|
if None, use `factory_fn` for all environments (train and test)
|
||||||
:return: the instance
|
:return: the instance
|
||||||
"""
|
"""
|
||||||
|
if test_factory_fn is None:
|
||||||
|
test_factory_fn = factory_fn
|
||||||
train_envs = venv_type.create_venv([factory_fn] * num_training_envs)
|
train_envs = venv_type.create_venv([factory_fn] * num_training_envs)
|
||||||
test_envs = venv_type.create_venv([factory_fn] * num_test_envs)
|
test_envs = venv_type.create_venv([test_factory_fn] * num_test_envs)
|
||||||
env = factory_fn()
|
env = factory_fn()
|
||||||
match env_type:
|
match env_type:
|
||||||
case EnvType.CONTINUOUS:
|
case EnvType.CONTINUOUS:
|
||||||
@ -152,6 +157,7 @@ class ContinuousEnvironments(Environments):
|
|||||||
venv_type: VectorEnvType,
|
venv_type: VectorEnvType,
|
||||||
num_training_envs: int,
|
num_training_envs: int,
|
||||||
num_test_envs: int,
|
num_test_envs: int,
|
||||||
|
test_factory_fn: Callable[[], gym.Env] | None = None,
|
||||||
) -> "ContinuousEnvironments":
|
) -> "ContinuousEnvironments":
|
||||||
"""Creates an instance from a factory function that creates a single instance.
|
"""Creates an instance from a factory function that creates a single instance.
|
||||||
|
|
||||||
@ -159,6 +165,8 @@ class ContinuousEnvironments(Environments):
|
|||||||
:param venv_type: the vector environment type to use for parallelization
|
:param venv_type: the vector environment type to use for parallelization
|
||||||
:param num_training_envs: the number of training environments to create
|
:param num_training_envs: the number of training environments to create
|
||||||
:param num_test_envs: the number of test environments to create
|
:param num_test_envs: the number of test environments to create
|
||||||
|
:param test_factory_fn: the factory to use for the creation of test environment instances;
|
||||||
|
if None, use `factory_fn` for all environments (train and test)
|
||||||
:return: the instance
|
:return: the instance
|
||||||
"""
|
"""
|
||||||
return cast(
|
return cast(
|
||||||
@ -169,6 +177,7 @@ class ContinuousEnvironments(Environments):
|
|||||||
venv_type,
|
venv_type,
|
||||||
num_training_envs,
|
num_training_envs,
|
||||||
num_test_envs,
|
num_test_envs,
|
||||||
|
test_factory_fn=test_factory_fn,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -217,6 +226,7 @@ class DiscreteEnvironments(Environments):
|
|||||||
venv_type: VectorEnvType,
|
venv_type: VectorEnvType,
|
||||||
num_training_envs: int,
|
num_training_envs: int,
|
||||||
num_test_envs: int,
|
num_test_envs: int,
|
||||||
|
test_factory_fn: Callable[[], gym.Env] | None = None,
|
||||||
) -> "DiscreteEnvironments":
|
) -> "DiscreteEnvironments":
|
||||||
"""Creates an instance from a factory function that creates a single instance.
|
"""Creates an instance from a factory function that creates a single instance.
|
||||||
|
|
||||||
@ -224,6 +234,8 @@ class DiscreteEnvironments(Environments):
|
|||||||
:param venv_type: the vector environment type to use for parallelization
|
:param venv_type: the vector environment type to use for parallelization
|
||||||
:param num_training_envs: the number of training environments to create
|
:param num_training_envs: the number of training environments to create
|
||||||
:param num_test_envs: the number of test environments to create
|
:param num_test_envs: the number of test environments to create
|
||||||
|
:param test_factory_fn: the factory to use for the creation of test environment instances;
|
||||||
|
if None, use `factory_fn` for all environments (train and test)
|
||||||
:return: the instance
|
:return: the instance
|
||||||
"""
|
"""
|
||||||
return cast(
|
return cast(
|
||||||
@ -234,6 +246,7 @@ class DiscreteEnvironments(Environments):
|
|||||||
venv_type,
|
venv_type,
|
||||||
num_training_envs,
|
num_training_envs,
|
||||||
num_test_envs,
|
num_test_envs,
|
||||||
|
test_factory_fn=test_factory_fn,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user