From 96298eafd8e5f2878f10d51cec6e40f28dc95cfe Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 24 Oct 2023 15:38:16 +0200 Subject: [PATCH] Add convenient construction mechanisms for Environments (based on factory function for a single environment) --- tianshou/highlevel/env.py | 117 +++++++++++++++++++++++++++++++++++++- 1 file changed, 114 insertions(+), 3 deletions(-) diff --git a/tianshou/highlevel/env.py b/tianshou/highlevel/env.py index e57a9b9..9ead62e 100644 --- a/tianshou/highlevel/env.py +++ b/tianshou/highlevel/env.py @@ -1,11 +1,17 @@ from abc import ABC, abstractmethod -from collections.abc import Sequence +from collections.abc import Callable, Sequence from enum import Enum -from typing import Any, TypeAlias +from typing import Any, TypeAlias, cast import gymnasium as gym -from tianshou.env import BaseVectorEnv +from tianshou.env import ( + BaseVectorEnv, + DummyVectorEnv, + RayVectorEnv, + ShmemVectorEnv, + SubprocVectorEnv, +) from tianshou.highlevel.persistence import Persistence from tianshou.utils.net.common import TActionShape from tianshou.utils.string import ToStringMixin @@ -34,6 +40,30 @@ class EnvType(Enum): raise AssertionError(f"{requiring_entity} requires discrete environments") +class VectorEnvType(Enum): + DUMMY = "dummy" + """Vectorized environment without parallelization; environments are processed sequentially""" + SUBPROC = "subproc" + """Parallelization based on `subprocess`""" + SUBPROC_SHARED_MEM = "shmem" + """Parallelization based on `subprocess` with shared memory""" + RAY = "ray" + """Parallelization based on the `ray` library""" + + def create_venv(self, factories: list[Callable[[], gym.Env]]) -> BaseVectorEnv: + match self: + case VectorEnvType.DUMMY: + return DummyVectorEnv(factories) + case VectorEnvType.SUBPROC: + return SubprocVectorEnv(factories) + case VectorEnvType.SUBPROC_SHARED_MEM: + return ShmemVectorEnv(factories) + case VectorEnvType.RAY: + return RayVectorEnv(factories) + case _: + raise NotImplementedError(self) + + class Environments(ToStringMixin, ABC): """Represents (vectorized) environments.""" @@ -43,6 +73,35 @@ class Environments(ToStringMixin, ABC): self.test_envs = test_envs self.persistence: Sequence[Persistence] = [] + @staticmethod + def from_factory_and_type( + factory_fn: Callable[[], gym.Env], + env_type: EnvType, + venv_type: VectorEnvType, + num_training_envs: int, + num_test_envs: int, + ) -> "Environments": + """Creates a suitable subtype instance from a factory function that creates a single instance and + the type of environment (continuous/discrete). + + :param factory_fn: the factory for a single environment instance + :param env_type: the type of environments created by `factory_fn` + :param venv_type: the vector environment type to use for parallelization + :param num_training_envs: the number of training environments to create + :param num_test_envs: the number of test environments to create + :return: the instance + """ + train_envs = venv_type.create_venv([factory_fn] * num_training_envs) + test_envs = venv_type.create_venv([factory_fn] * num_test_envs) + env = factory_fn() + match env_type: + case EnvType.CONTINUOUS: + return ContinuousEnvironments(env, train_envs, test_envs) + case EnvType.DISCRETE: + return DiscreteEnvironments(env, train_envs, test_envs) + case _: + raise ValueError(f"Environment type {env_type} not handled") + def _tostring_includes(self) -> list[str]: return [] @@ -89,6 +148,32 @@ class ContinuousEnvironments(Environments): super().__init__(env, train_envs, test_envs) self.state_shape, self.action_shape, self.max_action = self._get_continuous_env_info(env) + @staticmethod + def from_factory( + factory_fn: Callable[[], gym.Env], + venv_type: VectorEnvType, + num_training_envs: int, + num_test_envs: int, + ) -> "ContinuousEnvironments": + """Creates an instance from a factory function that creates a single instance. + + :param factory_fn: the factory for a single environment instance + :param venv_type: the vector environment type to use for parallelization + :param num_training_envs: the number of training environments to create + :param num_test_envs: the number of test environments to create + :return: the instance + """ + return cast( + ContinuousEnvironments, + Environments.from_factory_and_type( + factory_fn, + EnvType.CONTINUOUS, + venv_type, + num_training_envs, + num_test_envs, + ), + ) + def info(self) -> dict[str, Any]: d = super().info() d["max_action"] = self.max_action @@ -128,6 +213,32 @@ class DiscreteEnvironments(Environments): self.observation_shape = env.observation_space.shape or env.observation_space.n # type: ignore self.action_shape = env.action_space.shape or env.action_space.n # type: ignore + @staticmethod + def from_factory( + factory_fn: Callable[[], gym.Env], + venv_type: VectorEnvType, + num_training_envs: int, + num_test_envs: int, + ) -> "DiscreteEnvironments": + """Creates an instance from a factory function that creates a single instance. + + :param factory_fn: the factory for a single environment instance + :param venv_type: the vector environment type to use for parallelization + :param num_training_envs: the number of training environments to create + :param num_test_envs: the number of test environments to create + :return: the instance + """ + return cast( + DiscreteEnvironments, + Environments.from_factory_and_type( + factory_fn, + EnvType.CONTINUOUS, + venv_type, + num_training_envs, + num_test_envs, + ), + ) + def get_action_shape(self) -> TActionShape: return self.action_shape