Revert to simplified environment factory, removing unnecessary config object
(configuration shall be part of the factory instance)
This commit is contained in:
parent
f7f20649e3
commit
b5a891557f
@ -380,12 +380,7 @@ class AtariEnvFactory(EnvFactory):
|
||||
self.frame_stack = frame_stack
|
||||
self.scale = scale
|
||||
|
||||
def create_envs(
|
||||
self,
|
||||
num_training_envs: int,
|
||||
num_test_envs: int,
|
||||
config=None,
|
||||
) -> DiscreteEnvironments:
|
||||
def create_envs(self, num_training_envs: int, num_test_envs: int) -> DiscreteEnvironments:
|
||||
env, train_envs, test_envs = make_atari_env(
|
||||
task=self.task,
|
||||
seed=self.seed,
|
||||
|
@ -74,12 +74,7 @@ class MujocoEnvFactory(EnvFactory):
|
||||
self.seed = seed
|
||||
self.obs_norm = obs_norm
|
||||
|
||||
def create_envs(
|
||||
self,
|
||||
num_training_envs: int,
|
||||
num_test_envs: int,
|
||||
config=None,
|
||||
) -> ContinuousEnvironments:
|
||||
def create_envs(self, num_training_envs: int, num_test_envs: int) -> ContinuousEnvironments:
|
||||
env, train_envs, test_envs = make_mujoco_env(
|
||||
task=self.task,
|
||||
seed=self.seed,
|
||||
|
@ -7,16 +7,10 @@ from tianshou.highlevel.env import (
|
||||
EnvFactory,
|
||||
Environments,
|
||||
)
|
||||
from tianshou.highlevel.persistence import PersistableConfigProtocol
|
||||
|
||||
|
||||
class DiscreteTestEnvFactory(EnvFactory):
|
||||
def create_envs(
|
||||
self,
|
||||
num_training_envs: int,
|
||||
num_test_envs: int,
|
||||
config: PersistableConfigProtocol | None = None,
|
||||
) -> Environments:
|
||||
def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
|
||||
task = "CartPole-v0"
|
||||
env = gym.make(task)
|
||||
train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_training_envs)])
|
||||
@ -25,12 +19,7 @@ class DiscreteTestEnvFactory(EnvFactory):
|
||||
|
||||
|
||||
class ContinuousTestEnvFactory(EnvFactory):
|
||||
def create_envs(
|
||||
self,
|
||||
num_training_envs: int,
|
||||
num_test_envs: int,
|
||||
config: PersistableConfigProtocol | None = None,
|
||||
) -> Environments:
|
||||
def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
|
||||
task = "Pendulum-v1"
|
||||
env = gym.make(task)
|
||||
train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_training_envs)])
|
||||
|
@ -6,7 +6,7 @@ from typing import Any, TypeAlias
|
||||
import gymnasium as gym
|
||||
|
||||
from tianshou.env import BaseVectorEnv
|
||||
from tianshou.highlevel.persistence import PersistableConfigProtocol, Persistence
|
||||
from tianshou.highlevel.persistence import Persistence
|
||||
from tianshou.utils.net.common import TActionShape
|
||||
from tianshou.utils.string import ToStringMixin
|
||||
|
||||
@ -140,18 +140,5 @@ class DiscreteEnvironments(Environments):
|
||||
|
||||
class EnvFactory(ToStringMixin, ABC):
|
||||
@abstractmethod
|
||||
def create_envs(
|
||||
self,
|
||||
num_training_envs: int,
|
||||
num_test_envs: int,
|
||||
config: PersistableConfigProtocol | None = None,
|
||||
) -> Environments:
|
||||
def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
|
||||
pass
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
num_training_envs: int,
|
||||
num_test_envs: int,
|
||||
config: PersistableConfigProtocol | None = None,
|
||||
) -> Environments:
|
||||
return self.create_envs(num_training_envs, num_test_envs, config=config)
|
||||
|
@ -1,7 +1,7 @@
|
||||
import os
|
||||
import pickle
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Callable, Sequence
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from pprint import pformat
|
||||
from typing import Any, Self
|
||||
@ -26,7 +26,7 @@ from tianshou.highlevel.agent import (
|
||||
TRPOAgentFactory,
|
||||
)
|
||||
from tianshou.highlevel.config import SamplingConfig
|
||||
from tianshou.highlevel.env import EnvFactory, Environments
|
||||
from tianshou.highlevel.env import EnvFactory
|
||||
from tianshou.highlevel.logger import LoggerFactory, LoggerFactoryDefault, TLogger
|
||||
from tianshou.highlevel.module.actor import (
|
||||
ActorFactory,
|
||||
@ -66,7 +66,6 @@ from tianshou.highlevel.params.policy_params import (
|
||||
)
|
||||
from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
|
||||
from tianshou.highlevel.persistence import (
|
||||
PersistableConfigProtocol,
|
||||
PersistenceGroup,
|
||||
PolicyPersistence,
|
||||
)
|
||||
@ -135,12 +134,10 @@ class Experiment(ToStringMixin):
|
||||
def __init__(
|
||||
self,
|
||||
config: ExperimentConfig,
|
||||
env_factory: EnvFactory
|
||||
| Callable[[int, int, PersistableConfigProtocol | None], Environments],
|
||||
env_factory: EnvFactory,
|
||||
agent_factory: AgentFactory,
|
||||
sampling_config: SamplingConfig,
|
||||
logger_factory: LoggerFactory | None = None,
|
||||
env_config: PersistableConfigProtocol | None = None,
|
||||
):
|
||||
if logger_factory is None:
|
||||
logger_factory = LoggerFactoryDefault()
|
||||
@ -149,7 +146,6 @@ class Experiment(ToStringMixin):
|
||||
self.env_factory = env_factory
|
||||
self.agent_factory = agent_factory
|
||||
self.logger_factory = logger_factory
|
||||
self.env_config = env_config
|
||||
|
||||
@classmethod
|
||||
def from_directory(cls, directory: str, restore_policy: bool = True) -> "Experiment":
|
||||
@ -216,10 +212,9 @@ class Experiment(ToStringMixin):
|
||||
self._set_seed()
|
||||
|
||||
# create environments
|
||||
envs = self.env_factory(
|
||||
envs = self.env_factory.create_envs(
|
||||
self.sampling_config.num_train_envs,
|
||||
self.sampling_config.num_test_envs,
|
||||
self.env_config,
|
||||
)
|
||||
log.info(f"Created {envs}")
|
||||
|
||||
@ -318,14 +313,9 @@ class ExperimentBuilder:
|
||||
self._sampling_config = sampling_config
|
||||
self._logger_factory: LoggerFactory | None = None
|
||||
self._optim_factory: OptimizerFactory | None = None
|
||||
self._env_config: PersistableConfigProtocol | None = None
|
||||
self._policy_wrapper_factory: PolicyWrapperFactory | None = None
|
||||
self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks()
|
||||
|
||||
def with_env_config(self, config: PersistableConfigProtocol) -> Self:
|
||||
self._env_config = config
|
||||
return self
|
||||
|
||||
def with_logger_factory(self, logger_factory: LoggerFactory) -> Self:
|
||||
"""Allows to customize the logger factory to use.
|
||||
If this method is not called, the default logger factory :class:`LoggerFactoryDefault` will be used.
|
||||
@ -424,7 +414,6 @@ class ExperimentBuilder:
|
||||
agent_factory,
|
||||
self._sampling_config,
|
||||
self._logger_factory,
|
||||
env_config=self._env_config,
|
||||
)
|
||||
return experiment
|
||||
|
||||
|
@ -1,9 +1,8 @@
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Protocol, Self, runtime_checkable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
@ -15,16 +14,6 @@ if TYPE_CHECKING:
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class PersistableConfigProtocol(Protocol):
|
||||
@classmethod
|
||||
def load(cls, path: os.PathLike[str]) -> Self:
|
||||
pass
|
||||
|
||||
def save(self, path: os.PathLike[str]) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class PersistEvent(Enum):
|
||||
"""Enumeration of persistence events that Persistence objects can react to."""
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user