Revert to the simplified environment factory, removing the unnecessary config object
(configuration should be part of the factory instance itself)
This commit is contained in:
parent
f7f20649e3
commit
b5a891557f
@ -380,12 +380,7 @@ class AtariEnvFactory(EnvFactory):
|
|||||||
self.frame_stack = frame_stack
|
self.frame_stack = frame_stack
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
|
|
||||||
def create_envs(
|
def create_envs(self, num_training_envs: int, num_test_envs: int) -> DiscreteEnvironments:
|
||||||
self,
|
|
||||||
num_training_envs: int,
|
|
||||||
num_test_envs: int,
|
|
||||||
config=None,
|
|
||||||
) -> DiscreteEnvironments:
|
|
||||||
env, train_envs, test_envs = make_atari_env(
|
env, train_envs, test_envs = make_atari_env(
|
||||||
task=self.task,
|
task=self.task,
|
||||||
seed=self.seed,
|
seed=self.seed,
|
||||||
|
@ -74,12 +74,7 @@ class MujocoEnvFactory(EnvFactory):
|
|||||||
self.seed = seed
|
self.seed = seed
|
||||||
self.obs_norm = obs_norm
|
self.obs_norm = obs_norm
|
||||||
|
|
||||||
def create_envs(
|
def create_envs(self, num_training_envs: int, num_test_envs: int) -> ContinuousEnvironments:
|
||||||
self,
|
|
||||||
num_training_envs: int,
|
|
||||||
num_test_envs: int,
|
|
||||||
config=None,
|
|
||||||
) -> ContinuousEnvironments:
|
|
||||||
env, train_envs, test_envs = make_mujoco_env(
|
env, train_envs, test_envs = make_mujoco_env(
|
||||||
task=self.task,
|
task=self.task,
|
||||||
seed=self.seed,
|
seed=self.seed,
|
||||||
|
@ -7,16 +7,10 @@ from tianshou.highlevel.env import (
|
|||||||
EnvFactory,
|
EnvFactory,
|
||||||
Environments,
|
Environments,
|
||||||
)
|
)
|
||||||
from tianshou.highlevel.persistence import PersistableConfigProtocol
|
|
||||||
|
|
||||||
|
|
||||||
class DiscreteTestEnvFactory(EnvFactory):
|
class DiscreteTestEnvFactory(EnvFactory):
|
||||||
def create_envs(
|
def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
|
||||||
self,
|
|
||||||
num_training_envs: int,
|
|
||||||
num_test_envs: int,
|
|
||||||
config: PersistableConfigProtocol | None = None,
|
|
||||||
) -> Environments:
|
|
||||||
task = "CartPole-v0"
|
task = "CartPole-v0"
|
||||||
env = gym.make(task)
|
env = gym.make(task)
|
||||||
train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_training_envs)])
|
train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_training_envs)])
|
||||||
@ -25,12 +19,7 @@ class DiscreteTestEnvFactory(EnvFactory):
|
|||||||
|
|
||||||
|
|
||||||
class ContinuousTestEnvFactory(EnvFactory):
|
class ContinuousTestEnvFactory(EnvFactory):
|
||||||
def create_envs(
|
def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
|
||||||
self,
|
|
||||||
num_training_envs: int,
|
|
||||||
num_test_envs: int,
|
|
||||||
config: PersistableConfigProtocol | None = None,
|
|
||||||
) -> Environments:
|
|
||||||
task = "Pendulum-v1"
|
task = "Pendulum-v1"
|
||||||
env = gym.make(task)
|
env = gym.make(task)
|
||||||
train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_training_envs)])
|
train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_training_envs)])
|
||||||
|
@ -6,7 +6,7 @@ from typing import Any, TypeAlias
|
|||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
|
|
||||||
from tianshou.env import BaseVectorEnv
|
from tianshou.env import BaseVectorEnv
|
||||||
from tianshou.highlevel.persistence import PersistableConfigProtocol, Persistence
|
from tianshou.highlevel.persistence import Persistence
|
||||||
from tianshou.utils.net.common import TActionShape
|
from tianshou.utils.net.common import TActionShape
|
||||||
from tianshou.utils.string import ToStringMixin
|
from tianshou.utils.string import ToStringMixin
|
||||||
|
|
||||||
@ -140,18 +140,5 @@ class DiscreteEnvironments(Environments):
|
|||||||
|
|
||||||
class EnvFactory(ToStringMixin, ABC):
|
class EnvFactory(ToStringMixin, ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create_envs(
|
def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
|
||||||
self,
|
|
||||||
num_training_envs: int,
|
|
||||||
num_test_envs: int,
|
|
||||||
config: PersistableConfigProtocol | None = None,
|
|
||||||
) -> Environments:
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
num_training_envs: int,
|
|
||||||
num_test_envs: int,
|
|
||||||
config: PersistableConfigProtocol | None = None,
|
|
||||||
) -> Environments:
|
|
||||||
return self.create_envs(num_training_envs, num_test_envs, config=config)
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Sequence
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
from typing import Any, Self
|
from typing import Any, Self
|
||||||
@ -26,7 +26,7 @@ from tianshou.highlevel.agent import (
|
|||||||
TRPOAgentFactory,
|
TRPOAgentFactory,
|
||||||
)
|
)
|
||||||
from tianshou.highlevel.config import SamplingConfig
|
from tianshou.highlevel.config import SamplingConfig
|
||||||
from tianshou.highlevel.env import EnvFactory, Environments
|
from tianshou.highlevel.env import EnvFactory
|
||||||
from tianshou.highlevel.logger import LoggerFactory, LoggerFactoryDefault, TLogger
|
from tianshou.highlevel.logger import LoggerFactory, LoggerFactoryDefault, TLogger
|
||||||
from tianshou.highlevel.module.actor import (
|
from tianshou.highlevel.module.actor import (
|
||||||
ActorFactory,
|
ActorFactory,
|
||||||
@ -66,7 +66,6 @@ from tianshou.highlevel.params.policy_params import (
|
|||||||
)
|
)
|
||||||
from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
|
from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
|
||||||
from tianshou.highlevel.persistence import (
|
from tianshou.highlevel.persistence import (
|
||||||
PersistableConfigProtocol,
|
|
||||||
PersistenceGroup,
|
PersistenceGroup,
|
||||||
PolicyPersistence,
|
PolicyPersistence,
|
||||||
)
|
)
|
||||||
@ -135,12 +134,10 @@ class Experiment(ToStringMixin):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: ExperimentConfig,
|
config: ExperimentConfig,
|
||||||
env_factory: EnvFactory
|
env_factory: EnvFactory,
|
||||||
| Callable[[int, int, PersistableConfigProtocol | None], Environments],
|
|
||||||
agent_factory: AgentFactory,
|
agent_factory: AgentFactory,
|
||||||
sampling_config: SamplingConfig,
|
sampling_config: SamplingConfig,
|
||||||
logger_factory: LoggerFactory | None = None,
|
logger_factory: LoggerFactory | None = None,
|
||||||
env_config: PersistableConfigProtocol | None = None,
|
|
||||||
):
|
):
|
||||||
if logger_factory is None:
|
if logger_factory is None:
|
||||||
logger_factory = LoggerFactoryDefault()
|
logger_factory = LoggerFactoryDefault()
|
||||||
@ -149,7 +146,6 @@ class Experiment(ToStringMixin):
|
|||||||
self.env_factory = env_factory
|
self.env_factory = env_factory
|
||||||
self.agent_factory = agent_factory
|
self.agent_factory = agent_factory
|
||||||
self.logger_factory = logger_factory
|
self.logger_factory = logger_factory
|
||||||
self.env_config = env_config
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_directory(cls, directory: str, restore_policy: bool = True) -> "Experiment":
|
def from_directory(cls, directory: str, restore_policy: bool = True) -> "Experiment":
|
||||||
@ -216,10 +212,9 @@ class Experiment(ToStringMixin):
|
|||||||
self._set_seed()
|
self._set_seed()
|
||||||
|
|
||||||
# create environments
|
# create environments
|
||||||
envs = self.env_factory(
|
envs = self.env_factory.create_envs(
|
||||||
self.sampling_config.num_train_envs,
|
self.sampling_config.num_train_envs,
|
||||||
self.sampling_config.num_test_envs,
|
self.sampling_config.num_test_envs,
|
||||||
self.env_config,
|
|
||||||
)
|
)
|
||||||
log.info(f"Created {envs}")
|
log.info(f"Created {envs}")
|
||||||
|
|
||||||
@ -318,14 +313,9 @@ class ExperimentBuilder:
|
|||||||
self._sampling_config = sampling_config
|
self._sampling_config = sampling_config
|
||||||
self._logger_factory: LoggerFactory | None = None
|
self._logger_factory: LoggerFactory | None = None
|
||||||
self._optim_factory: OptimizerFactory | None = None
|
self._optim_factory: OptimizerFactory | None = None
|
||||||
self._env_config: PersistableConfigProtocol | None = None
|
|
||||||
self._policy_wrapper_factory: PolicyWrapperFactory | None = None
|
self._policy_wrapper_factory: PolicyWrapperFactory | None = None
|
||||||
self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks()
|
self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks()
|
||||||
|
|
||||||
def with_env_config(self, config: PersistableConfigProtocol) -> Self:
|
|
||||||
self._env_config = config
|
|
||||||
return self
|
|
||||||
|
|
||||||
def with_logger_factory(self, logger_factory: LoggerFactory) -> Self:
|
def with_logger_factory(self, logger_factory: LoggerFactory) -> Self:
|
||||||
"""Allows to customize the logger factory to use.
|
"""Allows to customize the logger factory to use.
|
||||||
If this method is not called, the default logger factory :class:`LoggerFactoryDefault` will be used.
|
If this method is not called, the default logger factory :class:`LoggerFactoryDefault` will be used.
|
||||||
@ -424,7 +414,6 @@ class ExperimentBuilder:
|
|||||||
agent_factory,
|
agent_factory,
|
||||||
self._sampling_config,
|
self._sampling_config,
|
||||||
self._logger_factory,
|
self._logger_factory,
|
||||||
env_config=self._env_config,
|
|
||||||
)
|
)
|
||||||
return experiment
|
return experiment
|
||||||
|
|
||||||
|
@ -1,9 +1,8 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import TYPE_CHECKING, Protocol, Self, runtime_checkable
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -15,16 +14,6 @@ if TYPE_CHECKING:
|
|||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
|
||||||
class PersistableConfigProtocol(Protocol):
|
|
||||||
@classmethod
|
|
||||||
def load(cls, path: os.PathLike[str]) -> Self:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def save(self, path: os.PathLike[str]) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class PersistEvent(Enum):
|
class PersistEvent(Enum):
|
||||||
"""Enumeration of persistence events that Persistence objects can react to."""
|
"""Enumeration of persistence events that Persistence objects can react to."""
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user