Add alternative functional interface for environment creation
where a persistable configuration object is passed as an argument, as this can help to ensure persistability (making the requirement explicit)
This commit is contained in:
parent
d4e604b46e
commit
5bcf514c55
@ -46,7 +46,7 @@ class MujocoEnvFactory(EnvFactory):
|
||||
self.sampling_config = sampling_config
|
||||
self.seed = seed
|
||||
|
||||
def create_envs(self) -> ContinuousEnvironments:
|
||||
def create_envs(self, config=None) -> ContinuousEnvironments:
|
||||
env, train_envs, test_envs = make_mujoco_env(
|
||||
task=self.task,
|
||||
seed=self.seed,
|
||||
|
@ -6,6 +6,7 @@ from typing import Any
|
||||
import gymnasium as gym
|
||||
|
||||
from tianshou.env import BaseVectorEnv
|
||||
from tianshou.highlevel.persistence import PersistableConfigProtocol
|
||||
|
||||
TShape = int | Sequence[int]
|
||||
|
||||
@ -98,5 +99,8 @@ class ContinuousEnvironments(Environments):
|
||||
|
||||
class EnvFactory(ABC):
|
||||
@abstractmethod
|
||||
def create_envs(self) -> Environments:
|
||||
def create_envs(self, config: PersistableConfigProtocol | None = None) -> Environments:
|
||||
pass
|
||||
|
||||
def __call__(self, config: PersistableConfigProtocol | None = None) -> Environments:
|
||||
return self.create_envs(config=config)
|
||||
|
@ -15,7 +15,7 @@ from tianshou.highlevel.agent import (
|
||||
TD3AgentFactory,
|
||||
)
|
||||
from tianshou.highlevel.config import RLSamplingConfig
|
||||
from tianshou.highlevel.env import EnvFactory
|
||||
from tianshou.highlevel.env import EnvFactory, Environments
|
||||
from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory
|
||||
from tianshou.highlevel.module import (
|
||||
ActorFactory,
|
||||
@ -26,6 +26,7 @@ from tianshou.highlevel.module import (
|
||||
)
|
||||
from tianshou.highlevel.optim import AdamOptimizerFactory, OptimizerFactory
|
||||
from tianshou.highlevel.params.policy_params import PPOParams, SACParams, TD3Params
|
||||
from tianshou.highlevel.persistence import PersistableConfigProtocol
|
||||
from tianshou.policy import BasePolicy
|
||||
from tianshou.policy.modelfree.pg import TDistParams
|
||||
from tianshou.trainer import BaseTrainer
|
||||
@ -55,9 +56,10 @@ class RLExperiment(Generic[TPolicy, TTrainer]):
|
||||
def __init__(
|
||||
self,
|
||||
config: RLExperimentConfig,
|
||||
env_factory: EnvFactory,
|
||||
env_factory: EnvFactory | Callable[[PersistableConfigProtocol | None], Environments],
|
||||
agent_factory: AgentFactory,
|
||||
logger_factory: LoggerFactory | None = None,
|
||||
env_config: PersistableConfigProtocol | None = None,
|
||||
):
|
||||
if logger_factory is None:
|
||||
logger_factory = DefaultLoggerFactory()
|
||||
@ -65,6 +67,7 @@ class RLExperiment(Generic[TPolicy, TTrainer]):
|
||||
self.env_factory = env_factory
|
||||
self.agent_factory = agent_factory
|
||||
self.logger_factory = logger_factory
|
||||
self.env_config = env_config
|
||||
|
||||
def _set_seed(self) -> None:
|
||||
seed = self.config.seed
|
||||
@ -78,8 +81,7 @@ class RLExperiment(Generic[TPolicy, TTrainer]):
|
||||
|
||||
def run(self, log_name: str) -> None:
|
||||
self._set_seed()
|
||||
|
||||
envs = self.env_factory.create_envs()
|
||||
envs = self.env_factory(self.env_config)
|
||||
|
||||
full_config = self._build_config_dict()
|
||||
full_config.update(envs.info())
|
||||
@ -146,6 +148,11 @@ class RLExperimentBuilder:
|
||||
self._sampling_config = sampling_config
|
||||
self._logger_factory: LoggerFactory | None = None
|
||||
self._optim_factory: OptimizerFactory | None = None
|
||||
self._env_config: PersistableConfigProtocol | None = None
|
||||
|
||||
def with_env_config(self, config: PersistableConfigProtocol) -> Self:
|
||||
self._env_config = config
|
||||
return self
|
||||
|
||||
def with_logger_factory(self: TBuilder, logger_factory: LoggerFactory) -> TBuilder:
|
||||
self._logger_factory = logger_factory
|
||||
@ -187,6 +194,7 @@ class RLExperimentBuilder:
|
||||
self._env_factory,
|
||||
self._create_agent_factory(),
|
||||
self._logger_factory,
|
||||
env_config=self._env_config,
|
||||
)
|
||||
|
||||
|
||||
@ -226,7 +234,7 @@ class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
|
||||
"""Specialization of the actor mixin where, in the continuous case, the actor uses a deterministic policy."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(ContinuousActorType.GAUSSIAN)
|
||||
super().__init__(ContinuousActorType.DETERMINISTIC)
|
||||
|
||||
def with_actor_factory_default(
|
||||
self,
|
||||
@ -346,12 +354,14 @@ class PPOExperimentBuilder(
|
||||
env_factory: EnvFactory,
|
||||
sampling_config: RLSamplingConfig,
|
||||
dist_fn: Callable[[TDistParams], torch.distributions.Distribution],
|
||||
env_config: PersistableConfigProtocol | None = None,
|
||||
):
|
||||
super().__init__(experiment_config, env_factory, sampling_config)
|
||||
super().__init__(experiment_config, env_factory, sampling_config, env_config=env_config)
|
||||
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
|
||||
_BuilderMixinSingleCriticFactory.__init__(self)
|
||||
self._params: PPOParams = PPOParams()
|
||||
self._dist_fn = dist_fn
|
||||
self._env_config = env_config
|
||||
|
||||
def with_ppo_params(self, params: PPOParams) -> Self:
|
||||
self._params = params
|
||||
|
12
tianshou/highlevel/persistence.py
Normal file
12
tianshou/highlevel/persistence.py
Normal file
@ -0,0 +1,12 @@
|
||||
import os
|
||||
from typing import Protocol, Self, runtime_checkable
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class PersistableConfigProtocol(Protocol):
|
||||
@classmethod
|
||||
def load(cls, path: os.PathLike[str]) -> Self:
|
||||
pass
|
||||
|
||||
def save(self, path: os.PathLike[str]) -> None:
|
||||
pass
|
Loading…
x
Reference in New Issue
Block a user