Add an alternative functional interface for environment creation
in which a persistable configuration object is passed as an argument. This helps to ensure persistability by making the requirement explicit.
This commit is contained in:
parent
d4e604b46e
commit
5bcf514c55
@ -46,7 +46,7 @@ class MujocoEnvFactory(EnvFactory):
|
|||||||
self.sampling_config = sampling_config
|
self.sampling_config = sampling_config
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
|
|
||||||
def create_envs(self) -> ContinuousEnvironments:
|
def create_envs(self, config=None) -> ContinuousEnvironments:
|
||||||
env, train_envs, test_envs = make_mujoco_env(
|
env, train_envs, test_envs = make_mujoco_env(
|
||||||
task=self.task,
|
task=self.task,
|
||||||
seed=self.seed,
|
seed=self.seed,
|
||||||
|
@ -6,6 +6,7 @@ from typing import Any
|
|||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
|
|
||||||
from tianshou.env import BaseVectorEnv
|
from tianshou.env import BaseVectorEnv
|
||||||
|
from tianshou.highlevel.persistence import PersistableConfigProtocol
|
||||||
|
|
||||||
TShape = int | Sequence[int]
|
TShape = int | Sequence[int]
|
||||||
|
|
||||||
@ -98,5 +99,8 @@ class ContinuousEnvironments(Environments):
|
|||||||
|
|
||||||
class EnvFactory(ABC):
    """Abstract factory for environments.

    Environments can be obtained either through :meth:`create_envs` or by
    calling the factory instance directly; both accept the same optional
    configuration object.
    """

    @abstractmethod
    def create_envs(self, config: PersistableConfigProtocol | None = None) -> Environments:
        """Create the environments.

        :param config: optional persistable configuration object; how (and whether)
            it is used is up to the concrete implementation.
        :return: the created environments
        """

    def __call__(self, config: PersistableConfigProtocol | None = None) -> Environments:
        """Functional interface: forward ``config`` to :meth:`create_envs`.

        :param config: optional persistable configuration object
        :return: whatever :meth:`create_envs` returns for ``config``
        """
        envs = self.create_envs(config=config)
        return envs
|
||||||
|
@ -15,7 +15,7 @@ from tianshou.highlevel.agent import (
|
|||||||
TD3AgentFactory,
|
TD3AgentFactory,
|
||||||
)
|
)
|
||||||
from tianshou.highlevel.config import RLSamplingConfig
|
from tianshou.highlevel.config import RLSamplingConfig
|
||||||
from tianshou.highlevel.env import EnvFactory
|
from tianshou.highlevel.env import EnvFactory, Environments
|
||||||
from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory
|
from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory
|
||||||
from tianshou.highlevel.module import (
|
from tianshou.highlevel.module import (
|
||||||
ActorFactory,
|
ActorFactory,
|
||||||
@ -26,6 +26,7 @@ from tianshou.highlevel.module import (
|
|||||||
)
|
)
|
||||||
from tianshou.highlevel.optim import AdamOptimizerFactory, OptimizerFactory
|
from tianshou.highlevel.optim import AdamOptimizerFactory, OptimizerFactory
|
||||||
from tianshou.highlevel.params.policy_params import PPOParams, SACParams, TD3Params
|
from tianshou.highlevel.params.policy_params import PPOParams, SACParams, TD3Params
|
||||||
|
from tianshou.highlevel.persistence import PersistableConfigProtocol
|
||||||
from tianshou.policy import BasePolicy
|
from tianshou.policy import BasePolicy
|
||||||
from tianshou.policy.modelfree.pg import TDistParams
|
from tianshou.policy.modelfree.pg import TDistParams
|
||||||
from tianshou.trainer import BaseTrainer
|
from tianshou.trainer import BaseTrainer
|
||||||
@ -55,9 +56,10 @@ class RLExperiment(Generic[TPolicy, TTrainer]):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: RLExperimentConfig,
|
config: RLExperimentConfig,
|
||||||
env_factory: EnvFactory,
|
env_factory: EnvFactory | Callable[[PersistableConfigProtocol | None], Environments],
|
||||||
agent_factory: AgentFactory,
|
agent_factory: AgentFactory,
|
||||||
logger_factory: LoggerFactory | None = None,
|
logger_factory: LoggerFactory | None = None,
|
||||||
|
env_config: PersistableConfigProtocol | None = None,
|
||||||
):
|
):
|
||||||
if logger_factory is None:
|
if logger_factory is None:
|
||||||
logger_factory = DefaultLoggerFactory()
|
logger_factory = DefaultLoggerFactory()
|
||||||
@ -65,6 +67,7 @@ class RLExperiment(Generic[TPolicy, TTrainer]):
|
|||||||
self.env_factory = env_factory
|
self.env_factory = env_factory
|
||||||
self.agent_factory = agent_factory
|
self.agent_factory = agent_factory
|
||||||
self.logger_factory = logger_factory
|
self.logger_factory = logger_factory
|
||||||
|
self.env_config = env_config
|
||||||
|
|
||||||
def _set_seed(self) -> None:
|
def _set_seed(self) -> None:
|
||||||
seed = self.config.seed
|
seed = self.config.seed
|
||||||
@ -78,8 +81,7 @@ class RLExperiment(Generic[TPolicy, TTrainer]):
|
|||||||
|
|
||||||
def run(self, log_name: str) -> None:
|
def run(self, log_name: str) -> None:
|
||||||
self._set_seed()
|
self._set_seed()
|
||||||
|
envs = self.env_factory(self.env_config)
|
||||||
envs = self.env_factory.create_envs()
|
|
||||||
|
|
||||||
full_config = self._build_config_dict()
|
full_config = self._build_config_dict()
|
||||||
full_config.update(envs.info())
|
full_config.update(envs.info())
|
||||||
@ -146,6 +148,11 @@ class RLExperimentBuilder:
|
|||||||
self._sampling_config = sampling_config
|
self._sampling_config = sampling_config
|
||||||
self._logger_factory: LoggerFactory | None = None
|
self._logger_factory: LoggerFactory | None = None
|
||||||
self._optim_factory: OptimizerFactory | None = None
|
self._optim_factory: OptimizerFactory | None = None
|
||||||
|
self._env_config: PersistableConfigProtocol | None = None
|
||||||
|
|
||||||
|
def with_env_config(self, config: PersistableConfigProtocol) -> Self:
    """Set the environment configuration object kept by this builder.

    :param config: the persistable configuration to store
    :return: the builder itself, so that calls can be chained
    """
    self._env_config = config
    return self
|
||||||
|
|
||||||
def with_logger_factory(self: TBuilder, logger_factory: LoggerFactory) -> TBuilder:
|
def with_logger_factory(self: TBuilder, logger_factory: LoggerFactory) -> TBuilder:
|
||||||
self._logger_factory = logger_factory
|
self._logger_factory = logger_factory
|
||||||
@ -187,6 +194,7 @@ class RLExperimentBuilder:
|
|||||||
self._env_factory,
|
self._env_factory,
|
||||||
self._create_agent_factory(),
|
self._create_agent_factory(),
|
||||||
self._logger_factory,
|
self._logger_factory,
|
||||||
|
env_config=self._env_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -226,7 +234,7 @@ class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
|
|||||||
"""Specialization of the actor mixin where, in the continuous case, the actor uses a deterministic policy."""
|
"""Specialization of the actor mixin where, in the continuous case, the actor uses a deterministic policy."""
|
||||||
|
|
||||||
def __init__(self):
    """Initialize the mixin with the Gaussian continuous actor type."""
    # This class is the Gaussian specialization (see its name), so it must
    # select GAUSSIAN; passing DETERMINISTIC here would configure the wrong
    # continuous actor type for every builder that uses this mixin.
    super().__init__(ContinuousActorType.GAUSSIAN)
|
||||||
|
|
||||||
def with_actor_factory_default(
|
def with_actor_factory_default(
|
||||||
self,
|
self,
|
||||||
@ -346,12 +354,14 @@ class PPOExperimentBuilder(
|
|||||||
env_factory: EnvFactory,
|
env_factory: EnvFactory,
|
||||||
sampling_config: RLSamplingConfig,
|
sampling_config: RLSamplingConfig,
|
||||||
dist_fn: Callable[[TDistParams], torch.distributions.Distribution],
|
dist_fn: Callable[[TDistParams], torch.distributions.Distribution],
|
||||||
|
env_config: PersistableConfigProtocol | None = None,
|
||||||
):
|
):
|
||||||
super().__init__(experiment_config, env_factory, sampling_config)
|
super().__init__(experiment_config, env_factory, sampling_config, env_config=env_config)
|
||||||
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
|
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
|
||||||
_BuilderMixinSingleCriticFactory.__init__(self)
|
_BuilderMixinSingleCriticFactory.__init__(self)
|
||||||
self._params: PPOParams = PPOParams()
|
self._params: PPOParams = PPOParams()
|
||||||
self._dist_fn = dist_fn
|
self._dist_fn = dist_fn
|
||||||
|
self._env_config = env_config
|
||||||
|
|
||||||
def with_ppo_params(self, params: PPOParams) -> Self:
|
def with_ppo_params(self, params: PPOParams) -> Self:
|
||||||
self._params = params
|
self._params = params
|
||||||
|
12
tianshou/highlevel/persistence.py
Normal file
12
tianshou/highlevel/persistence.py
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
import os
|
||||||
|
from typing import Protocol, Self, runtime_checkable
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class PersistableConfigProtocol(Protocol):
|
||||||
|
@classmethod
|
||||||
|
def load(cls, path: os.PathLike[str]) -> Self:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def save(self, path: os.PathLike[str]) -> None:
|
||||||
|
pass
|
Loading…
x
Reference in New Issue
Block a user