Revert to simplified environment factory, removing unnecessary config object

(configuration shall be part of the factory instance)
This commit is contained in:
Dominik Jain 2023-10-24 12:07:23 +02:00
parent f7f20649e3
commit b5a891557f
6 changed files with 11 additions and 67 deletions

View File

@ -380,12 +380,7 @@ class AtariEnvFactory(EnvFactory):
self.frame_stack = frame_stack self.frame_stack = frame_stack
self.scale = scale self.scale = scale
def create_envs( def create_envs(self, num_training_envs: int, num_test_envs: int) -> DiscreteEnvironments:
self,
num_training_envs: int,
num_test_envs: int,
config=None,
) -> DiscreteEnvironments:
env, train_envs, test_envs = make_atari_env( env, train_envs, test_envs = make_atari_env(
task=self.task, task=self.task,
seed=self.seed, seed=self.seed,

View File

@ -74,12 +74,7 @@ class MujocoEnvFactory(EnvFactory):
self.seed = seed self.seed = seed
self.obs_norm = obs_norm self.obs_norm = obs_norm
def create_envs( def create_envs(self, num_training_envs: int, num_test_envs: int) -> ContinuousEnvironments:
self,
num_training_envs: int,
num_test_envs: int,
config=None,
) -> ContinuousEnvironments:
env, train_envs, test_envs = make_mujoco_env( env, train_envs, test_envs = make_mujoco_env(
task=self.task, task=self.task,
seed=self.seed, seed=self.seed,

View File

@ -7,16 +7,10 @@ from tianshou.highlevel.env import (
EnvFactory, EnvFactory,
Environments, Environments,
) )
from tianshou.highlevel.persistence import PersistableConfigProtocol
class DiscreteTestEnvFactory(EnvFactory): class DiscreteTestEnvFactory(EnvFactory):
def create_envs( def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
self,
num_training_envs: int,
num_test_envs: int,
config: PersistableConfigProtocol | None = None,
) -> Environments:
task = "CartPole-v0" task = "CartPole-v0"
env = gym.make(task) env = gym.make(task)
train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_training_envs)]) train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_training_envs)])
@ -25,12 +19,7 @@ class DiscreteTestEnvFactory(EnvFactory):
class ContinuousTestEnvFactory(EnvFactory): class ContinuousTestEnvFactory(EnvFactory):
def create_envs( def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
self,
num_training_envs: int,
num_test_envs: int,
config: PersistableConfigProtocol | None = None,
) -> Environments:
task = "Pendulum-v1" task = "Pendulum-v1"
env = gym.make(task) env = gym.make(task)
train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_training_envs)]) train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_training_envs)])

View File

@ -6,7 +6,7 @@ from typing import Any, TypeAlias
import gymnasium as gym import gymnasium as gym
from tianshou.env import BaseVectorEnv from tianshou.env import BaseVectorEnv
from tianshou.highlevel.persistence import PersistableConfigProtocol, Persistence from tianshou.highlevel.persistence import Persistence
from tianshou.utils.net.common import TActionShape from tianshou.utils.net.common import TActionShape
from tianshou.utils.string import ToStringMixin from tianshou.utils.string import ToStringMixin
@ -140,18 +140,5 @@ class DiscreteEnvironments(Environments):
class EnvFactory(ToStringMixin, ABC): class EnvFactory(ToStringMixin, ABC):
@abstractmethod @abstractmethod
def create_envs( def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
self,
num_training_envs: int,
num_test_envs: int,
config: PersistableConfigProtocol | None = None,
) -> Environments:
pass pass
def __call__(
self,
num_training_envs: int,
num_test_envs: int,
config: PersistableConfigProtocol | None = None,
) -> Environments:
return self.create_envs(num_training_envs, num_test_envs, config=config)

View File

@ -1,7 +1,7 @@
import os import os
import pickle import pickle
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Callable, Sequence from collections.abc import Sequence
from dataclasses import dataclass from dataclasses import dataclass
from pprint import pformat from pprint import pformat
from typing import Any, Self from typing import Any, Self
@ -26,7 +26,7 @@ from tianshou.highlevel.agent import (
TRPOAgentFactory, TRPOAgentFactory,
) )
from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.env import EnvFactory, Environments from tianshou.highlevel.env import EnvFactory
from tianshou.highlevel.logger import LoggerFactory, LoggerFactoryDefault, TLogger from tianshou.highlevel.logger import LoggerFactory, LoggerFactoryDefault, TLogger
from tianshou.highlevel.module.actor import ( from tianshou.highlevel.module.actor import (
ActorFactory, ActorFactory,
@ -66,7 +66,6 @@ from tianshou.highlevel.params.policy_params import (
) )
from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
from tianshou.highlevel.persistence import ( from tianshou.highlevel.persistence import (
PersistableConfigProtocol,
PersistenceGroup, PersistenceGroup,
PolicyPersistence, PolicyPersistence,
) )
@ -135,12 +134,10 @@ class Experiment(ToStringMixin):
def __init__( def __init__(
self, self,
config: ExperimentConfig, config: ExperimentConfig,
env_factory: EnvFactory env_factory: EnvFactory,
| Callable[[int, int, PersistableConfigProtocol | None], Environments],
agent_factory: AgentFactory, agent_factory: AgentFactory,
sampling_config: SamplingConfig, sampling_config: SamplingConfig,
logger_factory: LoggerFactory | None = None, logger_factory: LoggerFactory | None = None,
env_config: PersistableConfigProtocol | None = None,
): ):
if logger_factory is None: if logger_factory is None:
logger_factory = LoggerFactoryDefault() logger_factory = LoggerFactoryDefault()
@ -149,7 +146,6 @@ class Experiment(ToStringMixin):
self.env_factory = env_factory self.env_factory = env_factory
self.agent_factory = agent_factory self.agent_factory = agent_factory
self.logger_factory = logger_factory self.logger_factory = logger_factory
self.env_config = env_config
@classmethod @classmethod
def from_directory(cls, directory: str, restore_policy: bool = True) -> "Experiment": def from_directory(cls, directory: str, restore_policy: bool = True) -> "Experiment":
@ -216,10 +212,9 @@ class Experiment(ToStringMixin):
self._set_seed() self._set_seed()
# create environments # create environments
envs = self.env_factory( envs = self.env_factory.create_envs(
self.sampling_config.num_train_envs, self.sampling_config.num_train_envs,
self.sampling_config.num_test_envs, self.sampling_config.num_test_envs,
self.env_config,
) )
log.info(f"Created {envs}") log.info(f"Created {envs}")
@ -318,14 +313,9 @@ class ExperimentBuilder:
self._sampling_config = sampling_config self._sampling_config = sampling_config
self._logger_factory: LoggerFactory | None = None self._logger_factory: LoggerFactory | None = None
self._optim_factory: OptimizerFactory | None = None self._optim_factory: OptimizerFactory | None = None
self._env_config: PersistableConfigProtocol | None = None
self._policy_wrapper_factory: PolicyWrapperFactory | None = None self._policy_wrapper_factory: PolicyWrapperFactory | None = None
self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks() self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks()
def with_env_config(self, config: PersistableConfigProtocol) -> Self:
self._env_config = config
return self
def with_logger_factory(self, logger_factory: LoggerFactory) -> Self: def with_logger_factory(self, logger_factory: LoggerFactory) -> Self:
"""Allows to customize the logger factory to use. """Allows to customize the logger factory to use.
If this method is not called, the default logger factory :class:`LoggerFactoryDefault` will be used. If this method is not called, the default logger factory :class:`LoggerFactoryDefault` will be used.
@ -424,7 +414,6 @@ class ExperimentBuilder:
agent_factory, agent_factory,
self._sampling_config, self._sampling_config,
self._logger_factory, self._logger_factory,
env_config=self._env_config,
) )
return experiment return experiment

View File

@ -1,9 +1,8 @@
import logging import logging
import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Callable from collections.abc import Callable
from enum import Enum from enum import Enum
from typing import TYPE_CHECKING, Protocol, Self, runtime_checkable from typing import TYPE_CHECKING
import torch import torch
@ -15,16 +14,6 @@ if TYPE_CHECKING:
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@runtime_checkable
class PersistableConfigProtocol(Protocol):
@classmethod
def load(cls, path: os.PathLike[str]) -> Self:
pass
def save(self, path: os.PathLike[str]) -> None:
pass
class PersistEvent(Enum): class PersistEvent(Enum):
"""Enumeration of persistence events that Persistence objects can react to.""" """Enumeration of persistence events that Persistence objects can react to."""