Add alternative functional interface for environment creation

where a persistable configuration object is passed as an
argument. This helps to ensure persistability by making the
requirement explicit.
This commit is contained in:
Michael Panchenko 2023-09-27 14:10:45 +02:00 committed by Dominik Jain
parent d4e604b46e
commit 5bcf514c55
4 changed files with 34 additions and 8 deletions

View File

@ -46,7 +46,7 @@ class MujocoEnvFactory(EnvFactory):
self.sampling_config = sampling_config
self.seed = seed
def create_envs(self) -> ContinuousEnvironments:
def create_envs(self, config=None) -> ContinuousEnvironments:
env, train_envs, test_envs = make_mujoco_env(
task=self.task,
seed=self.seed,

View File

@ -6,6 +6,7 @@ from typing import Any
import gymnasium as gym
from tianshou.env import BaseVectorEnv
from tianshou.highlevel.persistence import PersistableConfigProtocol
TShape = int | Sequence[int]
@ -98,5 +99,8 @@ class ContinuousEnvironments(Environments):
class EnvFactory(ABC):
@abstractmethod
def create_envs(self) -> Environments:
def create_envs(self, config: PersistableConfigProtocol | None = None) -> Environments:
pass
def __call__(self, config: PersistableConfigProtocol | None = None) -> Environments:
return self.create_envs(config=config)

View File

@ -15,7 +15,7 @@ from tianshou.highlevel.agent import (
TD3AgentFactory,
)
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.env import EnvFactory
from tianshou.highlevel.env import EnvFactory, Environments
from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory
from tianshou.highlevel.module import (
ActorFactory,
@ -26,6 +26,7 @@ from tianshou.highlevel.module import (
)
from tianshou.highlevel.optim import AdamOptimizerFactory, OptimizerFactory
from tianshou.highlevel.params.policy_params import PPOParams, SACParams, TD3Params
from tianshou.highlevel.persistence import PersistableConfigProtocol
from tianshou.policy import BasePolicy
from tianshou.policy.modelfree.pg import TDistParams
from tianshou.trainer import BaseTrainer
@ -55,9 +56,10 @@ class RLExperiment(Generic[TPolicy, TTrainer]):
def __init__(
self,
config: RLExperimentConfig,
env_factory: EnvFactory,
env_factory: EnvFactory | Callable[[PersistableConfigProtocol | None], Environments],
agent_factory: AgentFactory,
logger_factory: LoggerFactory | None = None,
env_config: PersistableConfigProtocol | None = None,
):
if logger_factory is None:
logger_factory = DefaultLoggerFactory()
@ -65,6 +67,7 @@ class RLExperiment(Generic[TPolicy, TTrainer]):
self.env_factory = env_factory
self.agent_factory = agent_factory
self.logger_factory = logger_factory
self.env_config = env_config
def _set_seed(self) -> None:
seed = self.config.seed
@ -78,8 +81,7 @@ class RLExperiment(Generic[TPolicy, TTrainer]):
def run(self, log_name: str) -> None:
self._set_seed()
envs = self.env_factory.create_envs()
envs = self.env_factory(self.env_config)
full_config = self._build_config_dict()
full_config.update(envs.info())
@ -146,6 +148,11 @@ class RLExperimentBuilder:
self._sampling_config = sampling_config
self._logger_factory: LoggerFactory | None = None
self._optim_factory: OptimizerFactory | None = None
self._env_config: PersistableConfigProtocol | None = None
def with_env_config(self, config: PersistableConfigProtocol) -> Self:
self._env_config = config
return self
def with_logger_factory(self: TBuilder, logger_factory: LoggerFactory) -> TBuilder:
self._logger_factory = logger_factory
@ -187,6 +194,7 @@ class RLExperimentBuilder:
self._env_factory,
self._create_agent_factory(),
self._logger_factory,
env_config=self._env_config,
)
@ -226,7 +234,7 @@ class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
"""Specialization of the actor mixin where, in the continuous case, the actor uses a deterministic policy."""
def __init__(self):
super().__init__(ContinuousActorType.GAUSSIAN)
super().__init__(ContinuousActorType.DETERMINISTIC)
def with_actor_factory_default(
self,
@ -346,12 +354,14 @@ class PPOExperimentBuilder(
env_factory: EnvFactory,
sampling_config: RLSamplingConfig,
dist_fn: Callable[[TDistParams], torch.distributions.Distribution],
env_config: PersistableConfigProtocol | None = None,
):
super().__init__(experiment_config, env_factory, sampling_config)
super().__init__(experiment_config, env_factory, sampling_config, env_config=env_config)
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
_BuilderMixinSingleCriticFactory.__init__(self)
self._params: PPOParams = PPOParams()
self._dist_fn = dist_fn
self._env_config = env_config
def with_ppo_params(self, params: PPOParams) -> Self:
self._params = params

View File

@ -0,0 +1,12 @@
import os
from typing import Protocol, Self, runtime_checkable
@runtime_checkable
class PersistableConfigProtocol(Protocol):
@classmethod
def load(cls, path: os.PathLike[str]) -> Self:
pass
def save(self, path: os.PathLike[str]) -> None:
pass