Add alternative functional interface for environment creation

where a persistable configuration object is passed as an
argument. This helps to ensure persistability by making the
requirement explicit.
This commit is contained in:
Michael Panchenko 2023-09-27 14:10:45 +02:00 committed by Dominik Jain
parent d4e604b46e
commit 5bcf514c55
4 changed files with 34 additions and 8 deletions

View File

@ -46,7 +46,7 @@ class MujocoEnvFactory(EnvFactory):
self.sampling_config = sampling_config self.sampling_config = sampling_config
self.seed = seed self.seed = seed
def create_envs(self) -> ContinuousEnvironments: def create_envs(self, config=None) -> ContinuousEnvironments:
env, train_envs, test_envs = make_mujoco_env( env, train_envs, test_envs = make_mujoco_env(
task=self.task, task=self.task,
seed=self.seed, seed=self.seed,

View File

@ -6,6 +6,7 @@ from typing import Any
import gymnasium as gym import gymnasium as gym
from tianshou.env import BaseVectorEnv from tianshou.env import BaseVectorEnv
from tianshou.highlevel.persistence import PersistableConfigProtocol
TShape = int | Sequence[int] TShape = int | Sequence[int]
@ -98,5 +99,8 @@ class ContinuousEnvironments(Environments):
class EnvFactory(ABC): class EnvFactory(ABC):
@abstractmethod @abstractmethod
def create_envs(self) -> Environments: def create_envs(self, config: PersistableConfigProtocol | None = None) -> Environments:
pass pass
def __call__(self, config: PersistableConfigProtocol | None = None) -> Environments:
return self.create_envs(config=config)

View File

@ -15,7 +15,7 @@ from tianshou.highlevel.agent import (
TD3AgentFactory, TD3AgentFactory,
) )
from tianshou.highlevel.config import RLSamplingConfig from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.env import EnvFactory from tianshou.highlevel.env import EnvFactory, Environments
from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory
from tianshou.highlevel.module import ( from tianshou.highlevel.module import (
ActorFactory, ActorFactory,
@ -26,6 +26,7 @@ from tianshou.highlevel.module import (
) )
from tianshou.highlevel.optim import AdamOptimizerFactory, OptimizerFactory from tianshou.highlevel.optim import AdamOptimizerFactory, OptimizerFactory
from tianshou.highlevel.params.policy_params import PPOParams, SACParams, TD3Params from tianshou.highlevel.params.policy_params import PPOParams, SACParams, TD3Params
from tianshou.highlevel.persistence import PersistableConfigProtocol
from tianshou.policy import BasePolicy from tianshou.policy import BasePolicy
from tianshou.policy.modelfree.pg import TDistParams from tianshou.policy.modelfree.pg import TDistParams
from tianshou.trainer import BaseTrainer from tianshou.trainer import BaseTrainer
@ -55,9 +56,10 @@ class RLExperiment(Generic[TPolicy, TTrainer]):
def __init__( def __init__(
self, self,
config: RLExperimentConfig, config: RLExperimentConfig,
env_factory: EnvFactory, env_factory: EnvFactory | Callable[[PersistableConfigProtocol | None], Environments],
agent_factory: AgentFactory, agent_factory: AgentFactory,
logger_factory: LoggerFactory | None = None, logger_factory: LoggerFactory | None = None,
env_config: PersistableConfigProtocol | None = None,
): ):
if logger_factory is None: if logger_factory is None:
logger_factory = DefaultLoggerFactory() logger_factory = DefaultLoggerFactory()
@ -65,6 +67,7 @@ class RLExperiment(Generic[TPolicy, TTrainer]):
self.env_factory = env_factory self.env_factory = env_factory
self.agent_factory = agent_factory self.agent_factory = agent_factory
self.logger_factory = logger_factory self.logger_factory = logger_factory
self.env_config = env_config
def _set_seed(self) -> None: def _set_seed(self) -> None:
seed = self.config.seed seed = self.config.seed
@ -78,8 +81,7 @@ class RLExperiment(Generic[TPolicy, TTrainer]):
def run(self, log_name: str) -> None: def run(self, log_name: str) -> None:
self._set_seed() self._set_seed()
envs = self.env_factory(self.env_config)
envs = self.env_factory.create_envs()
full_config = self._build_config_dict() full_config = self._build_config_dict()
full_config.update(envs.info()) full_config.update(envs.info())
@ -146,6 +148,11 @@ class RLExperimentBuilder:
self._sampling_config = sampling_config self._sampling_config = sampling_config
self._logger_factory: LoggerFactory | None = None self._logger_factory: LoggerFactory | None = None
self._optim_factory: OptimizerFactory | None = None self._optim_factory: OptimizerFactory | None = None
self._env_config: PersistableConfigProtocol | None = None
def with_env_config(self, config: PersistableConfigProtocol) -> Self:
self._env_config = config
return self
def with_logger_factory(self: TBuilder, logger_factory: LoggerFactory) -> TBuilder: def with_logger_factory(self: TBuilder, logger_factory: LoggerFactory) -> TBuilder:
self._logger_factory = logger_factory self._logger_factory = logger_factory
@ -187,6 +194,7 @@ class RLExperimentBuilder:
self._env_factory, self._env_factory,
self._create_agent_factory(), self._create_agent_factory(),
self._logger_factory, self._logger_factory,
env_config=self._env_config,
) )
@ -226,7 +234,7 @@ class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
"""Specialization of the actor mixin where, in the continuous case, the actor uses a deterministic policy.""" """Specialization of the actor mixin where, in the continuous case, the actor uses a deterministic policy."""
def __init__(self): def __init__(self):
super().__init__(ContinuousActorType.GAUSSIAN) super().__init__(ContinuousActorType.DETERMINISTIC)
def with_actor_factory_default( def with_actor_factory_default(
self, self,
@ -346,12 +354,14 @@ class PPOExperimentBuilder(
env_factory: EnvFactory, env_factory: EnvFactory,
sampling_config: RLSamplingConfig, sampling_config: RLSamplingConfig,
dist_fn: Callable[[TDistParams], torch.distributions.Distribution], dist_fn: Callable[[TDistParams], torch.distributions.Distribution],
env_config: PersistableConfigProtocol | None = None,
): ):
super().__init__(experiment_config, env_factory, sampling_config) super().__init__(experiment_config, env_factory, sampling_config, env_config=env_config)
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self) _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
_BuilderMixinSingleCriticFactory.__init__(self) _BuilderMixinSingleCriticFactory.__init__(self)
self._params: PPOParams = PPOParams() self._params: PPOParams = PPOParams()
self._dist_fn = dist_fn self._dist_fn = dist_fn
self._env_config = env_config
def with_ppo_params(self, params: PPOParams) -> Self: def with_ppo_params(self, params: PPOParams) -> Self:
self._params = params self._params = params

View File

@ -0,0 +1,12 @@
import os
from typing import Protocol, Self, runtime_checkable
@runtime_checkable
class PersistableConfigProtocol(Protocol):
@classmethod
def load(cls, path: os.PathLike[str]) -> Self:
pass
def save(self, path: os.PathLike[str]) -> None:
pass