diff --git a/examples/mujoco/mujoco_env.py b/examples/mujoco/mujoco_env.py index 0e8a283..ded54c3 100644 --- a/examples/mujoco/mujoco_env.py +++ b/examples/mujoco/mujoco_env.py @@ -46,7 +46,7 @@ class MujocoEnvFactory(EnvFactory): self.sampling_config = sampling_config self.seed = seed - def create_envs(self) -> ContinuousEnvironments: + def create_envs(self, config=None) -> ContinuousEnvironments: env, train_envs, test_envs = make_mujoco_env( task=self.task, seed=self.seed, diff --git a/tianshou/highlevel/env.py b/tianshou/highlevel/env.py index 8f4bd0c..4b7942e 100644 --- a/tianshou/highlevel/env.py +++ b/tianshou/highlevel/env.py @@ -6,6 +6,7 @@ from typing import Any import gymnasium as gym from tianshou.env import BaseVectorEnv +from tianshou.highlevel.persistence import PersistableConfigProtocol TShape = int | Sequence[int] @@ -98,5 +99,8 @@ class ContinuousEnvironments(Environments): class EnvFactory(ABC): @abstractmethod - def create_envs(self) -> Environments: + def create_envs(self, config: PersistableConfigProtocol | None = None) -> Environments: pass + + def __call__(self, config: PersistableConfigProtocol | None = None) -> Environments: + return self.create_envs(config=config) diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index c483d76..23d51f7 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -15,7 +15,7 @@ from tianshou.highlevel.agent import ( TD3AgentFactory, ) from tianshou.highlevel.config import RLSamplingConfig -from tianshou.highlevel.env import EnvFactory +from tianshou.highlevel.env import EnvFactory, Environments from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory from tianshou.highlevel.module import ( ActorFactory, @@ -26,6 +26,7 @@ from tianshou.highlevel.module import ( ) from tianshou.highlevel.optim import AdamOptimizerFactory, OptimizerFactory from tianshou.highlevel.params.policy_params import PPOParams, SACParams, TD3Params +from tianshou.highlevel.persistence import PersistableConfigProtocol from tianshou.policy import BasePolicy from tianshou.policy.modelfree.pg import TDistParams from tianshou.trainer import BaseTrainer @@ -55,9 +56,10 @@ class RLExperiment(Generic[TPolicy, TTrainer]): def __init__( self, config: RLExperimentConfig, - env_factory: EnvFactory, + env_factory: EnvFactory | Callable[[PersistableConfigProtocol | None], Environments], agent_factory: AgentFactory, logger_factory: LoggerFactory | None = None, + env_config: PersistableConfigProtocol | None = None, ): if logger_factory is None: logger_factory = DefaultLoggerFactory() @@ -65,6 +67,7 @@ class RLExperiment(Generic[TPolicy, TTrainer]): self.env_factory = env_factory self.agent_factory = agent_factory self.logger_factory = logger_factory + self.env_config = env_config def _set_seed(self) -> None: seed = self.config.seed @@ -78,8 +81,7 @@ class RLExperiment(Generic[TPolicy, TTrainer]): def run(self, log_name: str) -> None: self._set_seed() - - envs = self.env_factory.create_envs() + envs = self.env_factory(self.env_config) full_config = self._build_config_dict() full_config.update(envs.info()) @@ -146,6 +148,11 @@ class RLExperimentBuilder: self._sampling_config = sampling_config self._logger_factory: LoggerFactory | None = None self._optim_factory: OptimizerFactory | None = None + self._env_config: PersistableConfigProtocol | None = None + + def with_env_config(self, config: PersistableConfigProtocol) -> Self: + self._env_config = config + return self def with_logger_factory(self: TBuilder, logger_factory: LoggerFactory) -> TBuilder: self._logger_factory = logger_factory @@ -187,6 +194,7 @@ class RLExperimentBuilder: self._env_factory, self._create_agent_factory(), self._logger_factory, + env_config=self._env_config, ) @@ -226,7 +234,7 @@ class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory): """Specialization of the actor mixin where, in the continuous case, the actor uses a deterministic policy.""" def __init__(self): - super().__init__(ContinuousActorType.GAUSSIAN) + super().__init__(ContinuousActorType.DETERMINISTIC) def with_actor_factory_default( self, @@ -346,12 +354,14 @@ class PPOExperimentBuilder( env_factory: EnvFactory, sampling_config: RLSamplingConfig, dist_fn: Callable[[TDistParams], torch.distributions.Distribution], + env_config: PersistableConfigProtocol | None = None, ): - super().__init__(experiment_config, env_factory, sampling_config) + super().__init__(experiment_config, env_factory, sampling_config, env_config=env_config) _BuilderMixinActorFactory_ContinuousGaussian.__init__(self) _BuilderMixinSingleCriticFactory.__init__(self) self._params: PPOParams = PPOParams() self._dist_fn = dist_fn + self._env_config = env_config def with_ppo_params(self, params: PPOParams) -> Self: self._params = params diff --git a/tianshou/highlevel/persistence.py b/tianshou/highlevel/persistence.py new file mode 100644 index 0000000..71db019 --- /dev/null +++ b/tianshou/highlevel/persistence.py @@ -0,0 +1,12 @@ +import os +from typing import Protocol, Self, runtime_checkable + + +@runtime_checkable +class PersistableConfigProtocol(Protocol): + @classmethod + def load(cls, path: os.PathLike[str]) -> Self: + pass + + def save(self, path: os.PathLike[str]) -> None: + pass