Environments: Add option to a use a different factory for test envs
to `from_factory` convenience construction mechanisms
This commit is contained in:
parent
45a1a3f259
commit
e8cc80f990
@ -80,6 +80,7 @@ class Environments(ToStringMixin, ABC):
|
||||
venv_type: VectorEnvType,
|
||||
num_training_envs: int,
|
||||
num_test_envs: int,
|
||||
test_factory_fn: Callable[[], gym.Env] | None = None,
|
||||
) -> "Environments":
|
||||
"""Creates a suitable subtype instance from a factory function that creates a single instance and the type of environment (continuous/discrete).
|
||||
|
||||
@ -88,10 +89,14 @@ class Environments(ToStringMixin, ABC):
|
||||
:param venv_type: the vector environment type to use for parallelization
|
||||
:param num_training_envs: the number of training environments to create
|
||||
:param num_test_envs: the number of test environments to create
|
||||
:param test_factory_fn: the factory to use for the creation of test environment instances;
|
||||
if None, use `factory_fn` for all environments (train and test)
|
||||
:return: the instance
|
||||
"""
|
||||
if test_factory_fn is None:
|
||||
test_factory_fn = factory_fn
|
||||
train_envs = venv_type.create_venv([factory_fn] * num_training_envs)
|
||||
test_envs = venv_type.create_venv([factory_fn] * num_test_envs)
|
||||
test_envs = venv_type.create_venv([test_factory_fn] * num_test_envs)
|
||||
env = factory_fn()
|
||||
match env_type:
|
||||
case EnvType.CONTINUOUS:
|
||||
@ -152,6 +157,7 @@ class ContinuousEnvironments(Environments):
|
||||
venv_type: VectorEnvType,
|
||||
num_training_envs: int,
|
||||
num_test_envs: int,
|
||||
test_factory_fn: Callable[[], gym.Env] | None = None,
|
||||
) -> "ContinuousEnvironments":
|
||||
"""Creates an instance from a factory function that creates a single instance.
|
||||
|
||||
@ -159,6 +165,8 @@ class ContinuousEnvironments(Environments):
|
||||
:param venv_type: the vector environment type to use for parallelization
|
||||
:param num_training_envs: the number of training environments to create
|
||||
:param num_test_envs: the number of test environments to create
|
||||
:param test_factory_fn: the factory to use for the creation of test environment instances;
|
||||
if None, use `factory_fn` for all environments (train and test)
|
||||
:return: the instance
|
||||
"""
|
||||
return cast(
|
||||
@ -169,6 +177,7 @@ class ContinuousEnvironments(Environments):
|
||||
venv_type,
|
||||
num_training_envs,
|
||||
num_test_envs,
|
||||
test_factory_fn=test_factory_fn,
|
||||
),
|
||||
)
|
||||
|
||||
@ -217,6 +226,7 @@ class DiscreteEnvironments(Environments):
|
||||
venv_type: VectorEnvType,
|
||||
num_training_envs: int,
|
||||
num_test_envs: int,
|
||||
test_factory_fn: Callable[[], gym.Env] | None = None,
|
||||
) -> "DiscreteEnvironments":
|
||||
"""Creates an instance from a factory function that creates a single instance.
|
||||
|
||||
@ -224,6 +234,8 @@ class DiscreteEnvironments(Environments):
|
||||
:param venv_type: the vector environment type to use for parallelization
|
||||
:param num_training_envs: the number of training environments to create
|
||||
:param num_test_envs: the number of test environments to create
|
||||
:param test_factory_fn: the factory to use for the creation of test environment instances;
|
||||
if None, use `factory_fn` for all environments (train and test)
|
||||
:return: the instance
|
||||
"""
|
||||
return cast(
|
||||
@ -234,6 +246,7 @@ class DiscreteEnvironments(Environments):
|
||||
venv_type,
|
||||
num_training_envs,
|
||||
num_test_envs,
|
||||
test_factory_fn=test_factory_fn,
|
||||
),
|
||||
)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user