From b5a891557f7f50df0a8cbf441038521ba8dcf578 Mon Sep 17 00:00:00 2001
From: Dominik Jain
Date: Tue, 24 Oct 2023 12:07:23 +0200
Subject: [PATCH] Revert to simplified environment factory, removing unnecessary config object (configuration shall be part of the factory instance)

---
 examples/atari/atari_wrapper.py   |  7 +------
 examples/mujoco/mujoco_env.py     |  7 +------
 test/highlevel/env_factory.py     | 15 ++-------------
 tianshou/highlevel/env.py         | 17 ++---------------
 tianshou/highlevel/experiment.py  | 19 ++++---------------
 tianshou/highlevel/persistence.py | 13 +------------
 6 files changed, 11 insertions(+), 67 deletions(-)

diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py
index 510106d..2018d1f 100644
--- a/examples/atari/atari_wrapper.py
+++ b/examples/atari/atari_wrapper.py
@@ -380,12 +380,7 @@ class AtariEnvFactory(EnvFactory):
         self.frame_stack = frame_stack
         self.scale = scale
 
-    def create_envs(
-        self,
-        num_training_envs: int,
-        num_test_envs: int,
-        config=None,
-    ) -> DiscreteEnvironments:
+    def create_envs(self, num_training_envs: int, num_test_envs: int) -> DiscreteEnvironments:
         env, train_envs, test_envs = make_atari_env(
             task=self.task,
             seed=self.seed,
diff --git a/examples/mujoco/mujoco_env.py b/examples/mujoco/mujoco_env.py
index 3aae335..27d6a39 100644
--- a/examples/mujoco/mujoco_env.py
+++ b/examples/mujoco/mujoco_env.py
@@ -74,12 +74,7 @@ class MujocoEnvFactory(EnvFactory):
         self.seed = seed
         self.obs_norm = obs_norm
 
-    def create_envs(
-        self,
-        num_training_envs: int,
-        num_test_envs: int,
-        config=None,
-    ) -> ContinuousEnvironments:
+    def create_envs(self, num_training_envs: int, num_test_envs: int) -> ContinuousEnvironments:
         env, train_envs, test_envs = make_mujoco_env(
             task=self.task,
             seed=self.seed,
diff --git a/test/highlevel/env_factory.py b/test/highlevel/env_factory.py
index aab96c6..8ed89b1 100644
--- a/test/highlevel/env_factory.py
+++ b/test/highlevel/env_factory.py
@@ -7,16 +7,10 @@ from tianshou.highlevel.env import (
     EnvFactory,
     Environments,
 )
-from tianshou.highlevel.persistence import PersistableConfigProtocol
 
 
 class DiscreteTestEnvFactory(EnvFactory):
-    def create_envs(
-        self,
-        num_training_envs: int,
-        num_test_envs: int,
-        config: PersistableConfigProtocol | None = None,
-    ) -> Environments:
+    def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
         task = "CartPole-v0"
         env = gym.make(task)
         train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_training_envs)])
@@ -25,12 +19,7 @@ class DiscreteTestEnvFactory(EnvFactory):
 
 
 class ContinuousTestEnvFactory(EnvFactory):
-    def create_envs(
-        self,
-        num_training_envs: int,
-        num_test_envs: int,
-        config: PersistableConfigProtocol | None = None,
-    ) -> Environments:
+    def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
         task = "Pendulum-v1"
         env = gym.make(task)
         train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_training_envs)])
diff --git a/tianshou/highlevel/env.py b/tianshou/highlevel/env.py
index a99da13..e57a9b9 100644
--- a/tianshou/highlevel/env.py
+++ b/tianshou/highlevel/env.py
@@ -6,7 +6,7 @@ from typing import Any, TypeAlias
 import gymnasium as gym
 
 from tianshou.env import BaseVectorEnv
-from tianshou.highlevel.persistence import PersistableConfigProtocol, Persistence
+from tianshou.highlevel.persistence import Persistence
 from tianshou.utils.net.common import TActionShape
 from tianshou.utils.string import ToStringMixin
 
@@ -140,18 +140,5 @@ class DiscreteEnvironments(Environments):
 
 class EnvFactory(ToStringMixin, ABC):
     @abstractmethod
-    def create_envs(
-        self,
-        num_training_envs: int,
-        num_test_envs: int,
-        config: PersistableConfigProtocol | None = None,
-    ) -> Environments:
+    def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
         pass
-
-    def __call__(
-        self,
-        num_training_envs: int,
-        num_test_envs: int,
-        config: PersistableConfigProtocol | None = None,
-    ) -> Environments:
-        return self.create_envs(num_training_envs, num_test_envs, config=config)
diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py
index 537ddb6..70c7be7 100644
--- a/tianshou/highlevel/experiment.py
+++ b/tianshou/highlevel/experiment.py
@@ -1,7 +1,7 @@
 import os
 import pickle
 from abc import abstractmethod
-from collections.abc import Callable, Sequence
+from collections.abc import Sequence
 from dataclasses import dataclass
 from pprint import pformat
 from typing import Any, Self
@@ -26,7 +26,7 @@ from tianshou.highlevel.agent import (
     TRPOAgentFactory,
 )
 from tianshou.highlevel.config import SamplingConfig
-from tianshou.highlevel.env import EnvFactory, Environments
+from tianshou.highlevel.env import EnvFactory
 from tianshou.highlevel.logger import LoggerFactory, LoggerFactoryDefault, TLogger
 from tianshou.highlevel.module.actor import (
     ActorFactory,
@@ -66,7 +66,6 @@ from tianshou.highlevel.params.policy_params import (
 )
 from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
 from tianshou.highlevel.persistence import (
-    PersistableConfigProtocol,
     PersistenceGroup,
     PolicyPersistence,
 )
@@ -135,12 +134,10 @@ class Experiment(ToStringMixin):
     def __init__(
         self,
         config: ExperimentConfig,
-        env_factory: EnvFactory
-        | Callable[[int, int, PersistableConfigProtocol | None], Environments],
+        env_factory: EnvFactory,
         agent_factory: AgentFactory,
         sampling_config: SamplingConfig,
         logger_factory: LoggerFactory | None = None,
-        env_config: PersistableConfigProtocol | None = None,
     ):
         if logger_factory is None:
             logger_factory = LoggerFactoryDefault()
@@ -149,7 +146,6 @@ class Experiment(ToStringMixin):
         self.env_factory = env_factory
         self.agent_factory = agent_factory
         self.logger_factory = logger_factory
-        self.env_config = env_config
 
     @classmethod
     def from_directory(cls, directory: str, restore_policy: bool = True) -> "Experiment":
@@ -216,10 +212,9 @@ class Experiment(ToStringMixin):
         self._set_seed()
 
         # create environments
-        envs = self.env_factory(
+        envs = self.env_factory.create_envs(
             self.sampling_config.num_train_envs,
             self.sampling_config.num_test_envs,
-            self.env_config,
         )
         log.info(f"Created {envs}")
 
@@ -318,14 +313,9 @@ class ExperimentBuilder:
         self._sampling_config = sampling_config
         self._logger_factory: LoggerFactory | None = None
         self._optim_factory: OptimizerFactory | None = None
-        self._env_config: PersistableConfigProtocol | None = None
         self._policy_wrapper_factory: PolicyWrapperFactory | None = None
         self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks()
 
-    def with_env_config(self, config: PersistableConfigProtocol) -> Self:
-        self._env_config = config
-        return self
-
     def with_logger_factory(self, logger_factory: LoggerFactory) -> Self:
         """Allows to customize the logger factory to use.
         If this method is not called, the default logger factory :class:`LoggerFactoryDefault` will be used.
@@ -424,7 +414,6 @@ class ExperimentBuilder:
             agent_factory,
             self._sampling_config,
             self._logger_factory,
-            env_config=self._env_config,
         )
         return experiment
 
diff --git a/tianshou/highlevel/persistence.py b/tianshou/highlevel/persistence.py
index 54f3c23..951ca08 100644
--- a/tianshou/highlevel/persistence.py
+++ b/tianshou/highlevel/persistence.py
@@ -1,9 +1,8 @@
 import logging
-import os
 from abc import ABC, abstractmethod
 from collections.abc import Callable
 from enum import Enum
-from typing import TYPE_CHECKING, Protocol, Self, runtime_checkable
+from typing import TYPE_CHECKING
 
 import torch
 
@@ -15,16 +14,6 @@ if TYPE_CHECKING:
 log = logging.getLogger(__name__)
 
 
-@runtime_checkable
-class PersistableConfigProtocol(Protocol):
-    @classmethod
-    def load(cls, path: os.PathLike[str]) -> Self:
-        pass
-
-    def save(self, path: os.PathLike[str]) -> None:
-        pass
-
-
 class PersistEvent(Enum):
     """Enumeration of persistence events that Persistence objects can react to."""
 