Revert to simplified environment factory, removing unnecessary config object

(configuration shall be part of the factory instance)
This commit is contained in:
Dominik Jain 2023-10-24 12:07:23 +02:00
parent f7f20649e3
commit b5a891557f
6 changed files with 11 additions and 67 deletions

View File

@ -380,12 +380,7 @@ class AtariEnvFactory(EnvFactory):
self.frame_stack = frame_stack
self.scale = scale
def create_envs(
self,
num_training_envs: int,
num_test_envs: int,
config=None,
) -> DiscreteEnvironments:
def create_envs(self, num_training_envs: int, num_test_envs: int) -> DiscreteEnvironments:
env, train_envs, test_envs = make_atari_env(
task=self.task,
seed=self.seed,

View File

@ -74,12 +74,7 @@ class MujocoEnvFactory(EnvFactory):
self.seed = seed
self.obs_norm = obs_norm
def create_envs(
self,
num_training_envs: int,
num_test_envs: int,
config=None,
) -> ContinuousEnvironments:
def create_envs(self, num_training_envs: int, num_test_envs: int) -> ContinuousEnvironments:
env, train_envs, test_envs = make_mujoco_env(
task=self.task,
seed=self.seed,

View File

@ -7,16 +7,10 @@ from tianshou.highlevel.env import (
EnvFactory,
Environments,
)
from tianshou.highlevel.persistence import PersistableConfigProtocol
class DiscreteTestEnvFactory(EnvFactory):
def create_envs(
self,
num_training_envs: int,
num_test_envs: int,
config: PersistableConfigProtocol | None = None,
) -> Environments:
def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
task = "CartPole-v0"
env = gym.make(task)
train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_training_envs)])
@ -25,12 +19,7 @@ class DiscreteTestEnvFactory(EnvFactory):
class ContinuousTestEnvFactory(EnvFactory):
def create_envs(
self,
num_training_envs: int,
num_test_envs: int,
config: PersistableConfigProtocol | None = None,
) -> Environments:
def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
task = "Pendulum-v1"
env = gym.make(task)
train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_training_envs)])

View File

@ -6,7 +6,7 @@ from typing import Any, TypeAlias
import gymnasium as gym
from tianshou.env import BaseVectorEnv
from tianshou.highlevel.persistence import PersistableConfigProtocol, Persistence
from tianshou.highlevel.persistence import Persistence
from tianshou.utils.net.common import TActionShape
from tianshou.utils.string import ToStringMixin
@ -140,18 +140,5 @@ class DiscreteEnvironments(Environments):
class EnvFactory(ToStringMixin, ABC):
@abstractmethod
def create_envs(
self,
num_training_envs: int,
num_test_envs: int,
config: PersistableConfigProtocol | None = None,
) -> Environments:
def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
pass
def __call__(
self,
num_training_envs: int,
num_test_envs: int,
config: PersistableConfigProtocol | None = None,
) -> Environments:
return self.create_envs(num_training_envs, num_test_envs, config=config)

View File

@ -1,7 +1,7 @@
import os
import pickle
from abc import abstractmethod
from collections.abc import Callable, Sequence
from collections.abc import Sequence
from dataclasses import dataclass
from pprint import pformat
from typing import Any, Self
@ -26,7 +26,7 @@ from tianshou.highlevel.agent import (
TRPOAgentFactory,
)
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.env import EnvFactory, Environments
from tianshou.highlevel.env import EnvFactory
from tianshou.highlevel.logger import LoggerFactory, LoggerFactoryDefault, TLogger
from tianshou.highlevel.module.actor import (
ActorFactory,
@ -66,7 +66,6 @@ from tianshou.highlevel.params.policy_params import (
)
from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
from tianshou.highlevel.persistence import (
PersistableConfigProtocol,
PersistenceGroup,
PolicyPersistence,
)
@ -135,12 +134,10 @@ class Experiment(ToStringMixin):
def __init__(
self,
config: ExperimentConfig,
env_factory: EnvFactory
| Callable[[int, int, PersistableConfigProtocol | None], Environments],
env_factory: EnvFactory,
agent_factory: AgentFactory,
sampling_config: SamplingConfig,
logger_factory: LoggerFactory | None = None,
env_config: PersistableConfigProtocol | None = None,
):
if logger_factory is None:
logger_factory = LoggerFactoryDefault()
@ -149,7 +146,6 @@ class Experiment(ToStringMixin):
self.env_factory = env_factory
self.agent_factory = agent_factory
self.logger_factory = logger_factory
self.env_config = env_config
@classmethod
def from_directory(cls, directory: str, restore_policy: bool = True) -> "Experiment":
@ -216,10 +212,9 @@ class Experiment(ToStringMixin):
self._set_seed()
# create environments
envs = self.env_factory(
envs = self.env_factory.create_envs(
self.sampling_config.num_train_envs,
self.sampling_config.num_test_envs,
self.env_config,
)
log.info(f"Created {envs}")
@ -318,14 +313,9 @@ class ExperimentBuilder:
self._sampling_config = sampling_config
self._logger_factory: LoggerFactory | None = None
self._optim_factory: OptimizerFactory | None = None
self._env_config: PersistableConfigProtocol | None = None
self._policy_wrapper_factory: PolicyWrapperFactory | None = None
self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks()
def with_env_config(self, config: PersistableConfigProtocol) -> Self:
self._env_config = config
return self
def with_logger_factory(self, logger_factory: LoggerFactory) -> Self:
"""Allows to customize the logger factory to use.
If this method is not called, the default logger factory :class:`LoggerFactoryDefault` will be used.
@ -424,7 +414,6 @@ class ExperimentBuilder:
agent_factory,
self._sampling_config,
self._logger_factory,
env_config=self._env_config,
)
return experiment

View File

@ -1,9 +1,8 @@
import logging
import os
from abc import ABC, abstractmethod
from collections.abc import Callable
from enum import Enum
from typing import TYPE_CHECKING, Protocol, Self, runtime_checkable
from typing import TYPE_CHECKING
import torch
@ -15,16 +14,6 @@ if TYPE_CHECKING:
log = logging.getLogger(__name__)
@runtime_checkable
class PersistableConfigProtocol(Protocol):
@classmethod
def load(cls, path: os.PathLike[str]) -> Self:
pass
def save(self, path: os.PathLike[str]) -> None:
pass
class PersistEvent(Enum):
"""Enumeration of persistence events that Persistence objects can react to."""