Add documentation, improve structure of 'module' package
This commit is contained in:
parent
97e21b5ddf
commit
4b270eaa2d
@ -8,9 +8,11 @@ from torch import nn
|
|||||||
from tianshou.highlevel.env import Environments
|
from tianshou.highlevel.env import Environments
|
||||||
from tianshou.highlevel.module.actor import ActorFactory
|
from tianshou.highlevel.module.actor import ActorFactory
|
||||||
from tianshou.highlevel.module.core import (
|
from tianshou.highlevel.module.core import (
|
||||||
|
TDevice,
|
||||||
|
)
|
||||||
|
from tianshou.highlevel.module.intermediate import (
|
||||||
IntermediateModule,
|
IntermediateModule,
|
||||||
IntermediateModuleFactory,
|
IntermediateModuleFactory,
|
||||||
TDevice,
|
|
||||||
)
|
)
|
||||||
from tianshou.utils.net.discrete import Actor, NoisyLinear
|
from tianshou.utils.net.discrete import Actor, NoisyLinear
|
||||||
|
|
||||||
|
@ -81,6 +81,8 @@ log = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class AgentFactory(ABC, ToStringMixin):
|
class AgentFactory(ABC, ToStringMixin):
|
||||||
|
"""Factory for the creation of an agent's policy, its trainer as well as collectors."""
|
||||||
|
|
||||||
def __init__(self, sampling_config: SamplingConfig, optim_factory: OptimizerFactory):
|
def __init__(self, sampling_config: SamplingConfig, optim_factory: OptimizerFactory):
|
||||||
self.sampling_config = sampling_config
|
self.sampling_config = sampling_config
|
||||||
self.optim_factory = optim_factory
|
self.optim_factory = optim_factory
|
||||||
|
@ -14,6 +14,8 @@ TObservationShape: TypeAlias = int | Sequence[int]
|
|||||||
|
|
||||||
|
|
||||||
class EnvType(Enum):
|
class EnvType(Enum):
|
||||||
|
"""Enumeration of environment types."""
|
||||||
|
|
||||||
CONTINUOUS = "continuous"
|
CONTINUOUS = "continuous"
|
||||||
DISCRETE = "discrete"
|
DISCRETE = "discrete"
|
||||||
|
|
||||||
@ -33,6 +35,8 @@ class EnvType(Enum):
|
|||||||
|
|
||||||
|
|
||||||
class Environments(ToStringMixin, ABC):
|
class Environments(ToStringMixin, ABC):
|
||||||
|
"""Represents (vectorized) environments."""
|
||||||
|
|
||||||
def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
|
def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
|
||||||
self.env = env
|
self.env = env
|
||||||
self.train_envs = train_envs
|
self.train_envs = train_envs
|
||||||
@ -52,6 +56,11 @@ class Environments(ToStringMixin, ABC):
|
|||||||
}
|
}
|
||||||
|
|
||||||
def set_persistence(self, *p: Persistence) -> None:
|
def set_persistence(self, *p: Persistence) -> None:
|
||||||
|
"""Associates the given persistence handlers which may persist and restore
|
||||||
|
environment-specific information.
|
||||||
|
|
||||||
|
:param p: persistence handlers
|
||||||
|
"""
|
||||||
self.persistence = p
|
self.persistence = p
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -74,6 +83,8 @@ class Environments(ToStringMixin, ABC):
|
|||||||
|
|
||||||
|
|
||||||
class ContinuousEnvironments(Environments):
|
class ContinuousEnvironments(Environments):
|
||||||
|
"""Represents (vectorized) continuous environments."""
|
||||||
|
|
||||||
def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
|
def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
|
||||||
super().__init__(env, train_envs, test_envs)
|
super().__init__(env, train_envs, test_envs)
|
||||||
self.state_shape, self.action_shape, self.max_action = self._get_continuous_env_info(env)
|
self.state_shape, self.action_shape, self.max_action = self._get_continuous_env_info(env)
|
||||||
@ -110,6 +121,8 @@ class ContinuousEnvironments(Environments):
|
|||||||
|
|
||||||
|
|
||||||
class DiscreteEnvironments(Environments):
|
class DiscreteEnvironments(Environments):
|
||||||
|
"""Represents (vectorized) discrete environments."""
|
||||||
|
|
||||||
def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
|
def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
|
||||||
super().__init__(env, train_envs, test_envs)
|
super().__init__(env, train_envs, test_envs)
|
||||||
self.observation_shape = env.observation_space.shape or env.observation_space.n # type: ignore
|
self.observation_shape = env.observation_space.shape or env.observation_space.n # type: ignore
|
||||||
|
@ -27,7 +27,7 @@ from tianshou.highlevel.agent import (
|
|||||||
)
|
)
|
||||||
from tianshou.highlevel.config import SamplingConfig
|
from tianshou.highlevel.config import SamplingConfig
|
||||||
from tianshou.highlevel.env import EnvFactory, Environments
|
from tianshou.highlevel.env import EnvFactory, Environments
|
||||||
from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory, TLogger
|
from tianshou.highlevel.logger import LoggerFactory, LoggerFactoryDefault, TLogger
|
||||||
from tianshou.highlevel.module.actor import (
|
from tianshou.highlevel.module.actor import (
|
||||||
ActorFactory,
|
ActorFactory,
|
||||||
ActorFactoryDefault,
|
ActorFactoryDefault,
|
||||||
@ -38,8 +38,6 @@ from tianshou.highlevel.module.actor import (
|
|||||||
IntermediateModuleFactoryFromActorFactory,
|
IntermediateModuleFactoryFromActorFactory,
|
||||||
)
|
)
|
||||||
from tianshou.highlevel.module.core import (
|
from tianshou.highlevel.module.core import (
|
||||||
ImplicitQuantileNetworkFactory,
|
|
||||||
IntermediateModuleFactory,
|
|
||||||
TDevice,
|
TDevice,
|
||||||
)
|
)
|
||||||
from tianshou.highlevel.module.critic import (
|
from tianshou.highlevel.module.critic import (
|
||||||
@ -49,6 +47,8 @@ from tianshou.highlevel.module.critic import (
|
|||||||
CriticFactoryDefault,
|
CriticFactoryDefault,
|
||||||
CriticFactoryReuseActor,
|
CriticFactoryReuseActor,
|
||||||
)
|
)
|
||||||
|
from tianshou.highlevel.module.intermediate import IntermediateModuleFactory
|
||||||
|
from tianshou.highlevel.module.special import ImplicitQuantileNetworkFactory
|
||||||
from tianshou.highlevel.optim import OptimizerFactory, OptimizerFactoryAdam
|
from tianshou.highlevel.optim import OptimizerFactory, OptimizerFactoryAdam
|
||||||
from tianshou.highlevel.params.policy_params import (
|
from tianshou.highlevel.params.policy_params import (
|
||||||
A2CParams,
|
A2CParams,
|
||||||
@ -116,8 +116,12 @@ class ExperimentConfig:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ExperimentResult:
|
class ExperimentResult:
|
||||||
|
"""Contains the results of an experiment."""
|
||||||
|
|
||||||
world: World
|
world: World
|
||||||
|
"""contains all the essential instances of the experiment"""
|
||||||
trainer_result: dict[str, Any] | None
|
trainer_result: dict[str, Any] | None
|
||||||
|
"""dictionary of results as returned by the trainer (if any)"""
|
||||||
|
|
||||||
|
|
||||||
class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
|
class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
|
||||||
@ -140,7 +144,7 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
|
|||||||
env_config: PersistableConfigProtocol | None = None,
|
env_config: PersistableConfigProtocol | None = None,
|
||||||
):
|
):
|
||||||
if logger_factory is None:
|
if logger_factory is None:
|
||||||
logger_factory = DefaultLoggerFactory()
|
logger_factory = LoggerFactoryDefault()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.env_factory = env_factory
|
self.env_factory = env_factory
|
||||||
self.agent_factory = agent_factory
|
self.agent_factory = agent_factory
|
||||||
@ -179,7 +183,9 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
|
|||||||
pickle.dump(self, f)
|
pickle.dump(self, f)
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
self, experiment_name: str | None = None, logger_run_id: str | None = None,
|
self,
|
||||||
|
experiment_name: str | None = None,
|
||||||
|
logger_run_id: str | None = None,
|
||||||
) -> ExperimentResult:
|
) -> ExperimentResult:
|
||||||
""":param experiment_name: the experiment name, which corresponds to the directory (within the logging
|
""":param experiment_name: the experiment name, which corresponds to the directory (within the logging
|
||||||
directory) where all results associated with the experiment will be saved.
|
directory) where all results associated with the experiment will be saved.
|
||||||
@ -317,14 +323,31 @@ class ExperimentBuilder:
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
def with_logger_factory(self, logger_factory: LoggerFactory) -> Self:
|
def with_logger_factory(self, logger_factory: LoggerFactory) -> Self:
|
||||||
|
"""Allows to customize the logger factory to use.
|
||||||
|
If this method is not called, the default logger factory :class:`LoggerFactoryDefault` will be used.
|
||||||
|
|
||||||
|
:param logger_factory: the factory to use
|
||||||
|
:return: the builder
|
||||||
|
"""
|
||||||
self._logger_factory = logger_factory
|
self._logger_factory = logger_factory
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_policy_wrapper_factory(self, policy_wrapper_factory: PolicyWrapperFactory) -> Self:
|
def with_policy_wrapper_factory(self, policy_wrapper_factory: PolicyWrapperFactory) -> Self:
|
||||||
|
"""Allows to define a wrapper around the policy that is created, extending the original policy.
|
||||||
|
|
||||||
|
:param policy_wrapper_factory: the factory for the wrapper
|
||||||
|
:return: the builder
|
||||||
|
"""
|
||||||
self._policy_wrapper_factory = policy_wrapper_factory
|
self._policy_wrapper_factory = policy_wrapper_factory
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_optim_factory(self, optim_factory: OptimizerFactory) -> Self:
|
def with_optim_factory(self, optim_factory: OptimizerFactory) -> Self:
|
||||||
|
"""Allows to customize the gradient-based optimizer to use.
|
||||||
|
By default, :class:`OptimizerFactoryAdam` will be used with default parameters.
|
||||||
|
|
||||||
|
:param optim_factory: the optimizer factory
|
||||||
|
:return: the builder
|
||||||
|
"""
|
||||||
self._optim_factory = optim_factory
|
self._optim_factory = optim_factory
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@ -345,14 +368,30 @@ class ExperimentBuilder:
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
def with_trainer_epoch_callback_train(self, callback: TrainerEpochCallbackTrain) -> Self:
|
def with_trainer_epoch_callback_train(self, callback: TrainerEpochCallbackTrain) -> Self:
|
||||||
|
"""Allows to define a callback function which is called at the beginning of every epoch during training.
|
||||||
|
|
||||||
|
:param callback: the callback
|
||||||
|
:return: the builder
|
||||||
|
"""
|
||||||
self._trainer_callbacks.epoch_callback_train = callback
|
self._trainer_callbacks.epoch_callback_train = callback
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_trainer_epoch_callback_test(self, callback: TrainerEpochCallbackTest) -> Self:
|
def with_trainer_epoch_callback_test(self, callback: TrainerEpochCallbackTest) -> Self:
|
||||||
|
"""Allows to define a callback function which is called at the beginning of testing in each epoch.
|
||||||
|
|
||||||
|
:param callback: the callback
|
||||||
|
:return: the builder
|
||||||
|
"""
|
||||||
self._trainer_callbacks.epoch_callback_test = callback
|
self._trainer_callbacks.epoch_callback_test = callback
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_trainer_stop_callback(self, callback: TrainerStopCallback) -> Self:
|
def with_trainer_stop_callback(self, callback: TrainerStopCallback) -> Self:
|
||||||
|
"""Allows to define a callback that decides whether training shall stop early.
|
||||||
|
The callback receives the undiscounted returns of the testing result.
|
||||||
|
|
||||||
|
:param callback: the callback
|
||||||
|
:return: the builder
|
||||||
|
"""
|
||||||
self._trainer_callbacks.stop_callback = callback
|
self._trainer_callbacks.stop_callback = callback
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@ -367,6 +406,10 @@ class ExperimentBuilder:
|
|||||||
return self._optim_factory
|
return self._optim_factory
|
||||||
|
|
||||||
def build(self) -> Experiment:
|
def build(self) -> Experiment:
|
||||||
|
"""Creates the experiment based on the options specified via this builder.
|
||||||
|
|
||||||
|
:return: the experiment
|
||||||
|
"""
|
||||||
agent_factory = self._create_agent_factory()
|
agent_factory = self._create_agent_factory()
|
||||||
agent_factory.set_trainer_callbacks(self._trainer_callbacks)
|
agent_factory.set_trainer_callbacks(self._trainer_callbacks)
|
||||||
if self._policy_wrapper_factory:
|
if self._policy_wrapper_factory:
|
||||||
@ -388,6 +431,12 @@ class _BuilderMixinActorFactory(ActorFutureProviderProtocol):
|
|||||||
self._actor_factory: ActorFactory | None = None
|
self._actor_factory: ActorFactory | None = None
|
||||||
|
|
||||||
def with_actor_factory(self, actor_factory: ActorFactory) -> Self:
|
def with_actor_factory(self, actor_factory: ActorFactory) -> Self:
|
||||||
|
"""Allows to customize the actor component via the specification of a factory.
|
||||||
|
If this function is not called, a default actor factory (with default parameters) will be used.
|
||||||
|
|
||||||
|
:param actor_factory: the factory to use for the creation of the actor network
|
||||||
|
:return: the builder
|
||||||
|
"""
|
||||||
self._actor_factory = actor_factory
|
self._actor_factory = actor_factory
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@ -397,6 +446,12 @@ class _BuilderMixinActorFactory(ActorFutureProviderProtocol):
|
|||||||
continuous_unbounded: bool = False,
|
continuous_unbounded: bool = False,
|
||||||
continuous_conditioned_sigma: bool = False,
|
continuous_conditioned_sigma: bool = False,
|
||||||
) -> Self:
|
) -> Self:
|
||||||
|
""":param hidden_sizes: the sequence of hidden dimensions to use in the network structure
|
||||||
|
:param continuous_unbounded: whether, for continuous action spaces, to apply tanh activation on final logits
|
||||||
|
:param continuous_conditioned_sigma: whether, for continuous action spaces, the standard deviation of continuous actions (sigma)
|
||||||
|
shall be computed from the input; if False, sigma is an independent parameter.
|
||||||
|
:return: the builder
|
||||||
|
"""
|
||||||
self._actor_factory = ActorFactoryDefault(
|
self._actor_factory = ActorFactoryDefault(
|
||||||
self._continuous_actor_type,
|
self._continuous_actor_type,
|
||||||
hidden_sizes,
|
hidden_sizes,
|
||||||
@ -406,6 +461,7 @@ class _BuilderMixinActorFactory(ActorFutureProviderProtocol):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
def get_actor_future(self) -> ActorFuture:
|
def get_actor_future(self) -> ActorFuture:
|
||||||
|
""":return: an object, which, in the future, will contain the actor instance that is created for the experiment."""
|
||||||
return self._actor_future
|
return self._actor_future
|
||||||
|
|
||||||
def _get_actor_factory(self) -> ActorFactory:
|
def _get_actor_factory(self) -> ActorFactory:
|
||||||
@ -431,6 +487,15 @@ class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
|
|||||||
continuous_unbounded: bool = False,
|
continuous_unbounded: bool = False,
|
||||||
continuous_conditioned_sigma: bool = False,
|
continuous_conditioned_sigma: bool = False,
|
||||||
) -> Self:
|
) -> Self:
|
||||||
|
"""Defines use of the default actor factory, allowing its parameters to be customized.
|
||||||
|
The default actor factory uses an MLP-style architecture.
|
||||||
|
|
||||||
|
:param hidden_sizes: dimensions of hidden layers used by the network
|
||||||
|
:param continuous_unbounded: whether, for continuous action spaces, to apply tanh activation on final logits
|
||||||
|
:param continuous_conditioned_sigma: whether, for continuous action spaces, the standard deviation of continuous actions (sigma)
|
||||||
|
shall be computed from the input; if False, sigma is an independent parameter.
|
||||||
|
:return: the builder
|
||||||
|
"""
|
||||||
return super()._with_actor_factory_default(
|
return super()._with_actor_factory_default(
|
||||||
hidden_sizes,
|
hidden_sizes,
|
||||||
continuous_unbounded=continuous_unbounded,
|
continuous_unbounded=continuous_unbounded,
|
||||||
@ -445,6 +510,12 @@ class _BuilderMixinActorFactory_ContinuousDeterministic(_BuilderMixinActorFactor
|
|||||||
super().__init__(ContinuousActorType.DETERMINISTIC)
|
super().__init__(ContinuousActorType.DETERMINISTIC)
|
||||||
|
|
||||||
def with_actor_factory_default(self, hidden_sizes: Sequence[int]) -> Self:
|
def with_actor_factory_default(self, hidden_sizes: Sequence[int]) -> Self:
|
||||||
|
"""Defines use of the default actor factory, allowing its parameters to be customized.
|
||||||
|
The default actor factory uses an MLP-style architecture.
|
||||||
|
|
||||||
|
:param hidden_sizes: dimensions of hidden layers used by the network
|
||||||
|
:return: the builder
|
||||||
|
"""
|
||||||
return super()._with_actor_factory_default(hidden_sizes)
|
return super()._with_actor_factory_default(hidden_sizes)
|
||||||
|
|
||||||
|
|
||||||
@ -480,6 +551,11 @@ class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
|
|||||||
super().__init__(1, actor_future_provider)
|
super().__init__(1, actor_future_provider)
|
||||||
|
|
||||||
def with_critic_factory(self, critic_factory: CriticFactory) -> Self:
|
def with_critic_factory(self, critic_factory: CriticFactory) -> Self:
|
||||||
|
"""Specifies that the given factory shall be used for the critic.
|
||||||
|
|
||||||
|
:param critic_factory: the critic factory
|
||||||
|
:return: the builder
|
||||||
|
"""
|
||||||
self._with_critic_factory(0, critic_factory)
|
self._with_critic_factory(0, critic_factory)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@ -487,6 +563,11 @@ class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
|
|||||||
self,
|
self,
|
||||||
hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
|
hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
|
||||||
) -> Self:
|
) -> Self:
|
||||||
|
"""Makes the critic use the default, MLP-style architecture with the given parameters.
|
||||||
|
|
||||||
|
:param hidden_sizes: the sequence of dimensions to use in hidden layers of the network
|
||||||
|
:return: the builder
|
||||||
|
"""
|
||||||
self._with_critic_factory_default(0, hidden_sizes)
|
self._with_critic_factory_default(0, hidden_sizes)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@ -496,7 +577,7 @@ class _BuilderMixinSingleCriticCanUseActorFactory(_BuilderMixinSingleCriticFacto
|
|||||||
super().__init__(actor_future_provider)
|
super().__init__(actor_future_provider)
|
||||||
|
|
||||||
def with_critic_factory_use_actor(self) -> Self:
|
def with_critic_factory_use_actor(self) -> Self:
|
||||||
"""Makes the critic use the same network as the actor."""
|
"""Makes the first critic reuse the actor's preprocessing network (parameter sharing)."""
|
||||||
return self._with_critic_factory_use_actor(0)
|
return self._with_critic_factory_use_actor(0)
|
||||||
|
|
||||||
|
|
||||||
@ -505,6 +586,11 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
|
|||||||
super().__init__(2, actor_future_provider)
|
super().__init__(2, actor_future_provider)
|
||||||
|
|
||||||
def with_common_critic_factory(self, critic_factory: CriticFactory) -> Self:
|
def with_common_critic_factory(self, critic_factory: CriticFactory) -> Self:
|
||||||
|
"""Specifies that the given factory shall be used for both critics.
|
||||||
|
|
||||||
|
:param critic_factory: the critic factory
|
||||||
|
:return: the builder
|
||||||
|
"""
|
||||||
for i in range(len(self._critic_factories)):
|
for i in range(len(self._critic_factories)):
|
||||||
self._with_critic_factory(i, critic_factory)
|
self._with_critic_factory(i, critic_factory)
|
||||||
return self
|
return self
|
||||||
@ -513,17 +599,27 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
|
|||||||
self,
|
self,
|
||||||
hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
|
hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
|
||||||
) -> Self:
|
) -> Self:
|
||||||
|
"""Makes both critics use the default, MLP-style architecture with the given parameters.
|
||||||
|
|
||||||
|
:param hidden_sizes: the sequence of dimensions to use in hidden layers of the network
|
||||||
|
:return: the builder
|
||||||
|
"""
|
||||||
for i in range(len(self._critic_factories)):
|
for i in range(len(self._critic_factories)):
|
||||||
self._with_critic_factory_default(i, hidden_sizes)
|
self._with_critic_factory_default(i, hidden_sizes)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_common_critic_factory_use_actor(self) -> Self:
|
def with_common_critic_factory_use_actor(self) -> Self:
|
||||||
"""Makes all critics use the same network as the actor."""
|
"""Makes both critics reuse the actor's preprocessing network (parameter sharing)."""
|
||||||
for i in range(len(self._critic_factories)):
|
for i in range(len(self._critic_factories)):
|
||||||
self._with_critic_factory_use_actor(i)
|
self._with_critic_factory_use_actor(i)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_critic1_factory(self, critic_factory: CriticFactory) -> Self:
|
def with_critic1_factory(self, critic_factory: CriticFactory) -> Self:
|
||||||
|
"""Specifies that the given factory shall be used for the first critic.
|
||||||
|
|
||||||
|
:param critic_factory: the critic factory
|
||||||
|
:return: the builder
|
||||||
|
"""
|
||||||
self._with_critic_factory(0, critic_factory)
|
self._with_critic_factory(0, critic_factory)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@ -531,14 +627,24 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
|
|||||||
self,
|
self,
|
||||||
hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
|
hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
|
||||||
) -> Self:
|
) -> Self:
|
||||||
|
"""Makes the first critic use the default, MLP-style architecture with the given parameters.
|
||||||
|
|
||||||
|
:param hidden_sizes: the sequence of dimensions to use in hidden layers of the network
|
||||||
|
:return: the builder
|
||||||
|
"""
|
||||||
self._with_critic_factory_default(0, hidden_sizes)
|
self._with_critic_factory_default(0, hidden_sizes)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_critic1_factory_use_actor(self) -> Self:
|
def with_critic1_factory_use_actor(self) -> Self:
|
||||||
"""Makes the critic use the same network as the actor."""
|
"""Makes the first critic reuse the actor's preprocessing network (parameter sharing)."""
|
||||||
return self._with_critic_factory_use_actor(0)
|
return self._with_critic_factory_use_actor(0)
|
||||||
|
|
||||||
def with_critic2_factory(self, critic_factory: CriticFactory) -> Self:
|
def with_critic2_factory(self, critic_factory: CriticFactory) -> Self:
|
||||||
|
"""Specifies that the given factory shall be used for the second critic.
|
||||||
|
|
||||||
|
:param critic_factory: the critic factory
|
||||||
|
:return: the builder
|
||||||
|
"""
|
||||||
self._with_critic_factory(1, critic_factory)
|
self._with_critic_factory(1, critic_factory)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@ -546,11 +652,16 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
|
|||||||
self,
|
self,
|
||||||
hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
|
hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
|
||||||
) -> Self:
|
) -> Self:
|
||||||
|
"""Makes the second critic use the default, MLP-style architecture with the given parameters.
|
||||||
|
|
||||||
|
:param hidden_sizes: the sequence of dimensions to use in hidden layers of the network
|
||||||
|
:return: the builder
|
||||||
|
"""
|
||||||
self._with_critic_factory_default(0, hidden_sizes)
|
self._with_critic_factory_default(0, hidden_sizes)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_critic2_factory_use_actor(self) -> Self:
|
def with_critic2_factory_use_actor(self) -> Self:
|
||||||
"""Makes the second critic use the same network as the actor."""
|
"""Makes the second critic reuse the actor's preprocessing network (parameter sharing)."""
|
||||||
return self._with_critic_factory_use_actor(1)
|
return self._with_critic_factory_use_actor(1)
|
||||||
|
|
||||||
|
|
||||||
@ -559,6 +670,12 @@ class _BuilderMixinCriticEnsembleFactory:
|
|||||||
self.critic_ensemble_factory: CriticEnsembleFactory | None = None
|
self.critic_ensemble_factory: CriticEnsembleFactory | None = None
|
||||||
|
|
||||||
def with_critic_ensemble_factory(self, factory: CriticEnsembleFactory) -> Self:
|
def with_critic_ensemble_factory(self, factory: CriticEnsembleFactory) -> Self:
|
||||||
|
"""Specifies that the given factory shall be used for the critic ensemble.
|
||||||
|
If unspecified, the default factory (:class:`CriticEnsembleFactoryDefault`) is used.
|
||||||
|
|
||||||
|
:param factory: the critic ensemble factory
|
||||||
|
:return: the builder
|
||||||
|
"""
|
||||||
self.critic_ensemble_factory = factory
|
self.critic_ensemble_factory = factory
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@ -566,6 +683,11 @@ class _BuilderMixinCriticEnsembleFactory:
|
|||||||
self,
|
self,
|
||||||
hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
|
hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
|
||||||
) -> Self:
|
) -> Self:
|
||||||
|
"""Allows to customize the parameters of the default critic ensemble factory.
|
||||||
|
|
||||||
|
:param hidden_sizes: the sequence of sizes of hidden layers in the network architecture
|
||||||
|
:return: the builder
|
||||||
|
"""
|
||||||
self.critic_ensemble_factory = CriticEnsembleFactoryDefault(hidden_sizes)
|
self.critic_ensemble_factory = CriticEnsembleFactoryDefault(hidden_sizes)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@ -27,7 +27,7 @@ class LoggerFactory(ToStringMixin, ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class DefaultLoggerFactory(LoggerFactory):
|
class LoggerFactoryDefault(LoggerFactory):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
logger_type: Literal["tensorboard", "wandb"] = "tensorboard",
|
logger_type: Literal["tensorboard", "wandb"] = "tensorboard",
|
||||||
|
@ -9,12 +9,14 @@ from torch import nn
|
|||||||
|
|
||||||
from tianshou.highlevel.env import Environments, EnvType
|
from tianshou.highlevel.env import Environments, EnvType
|
||||||
from tianshou.highlevel.module.core import (
|
from tianshou.highlevel.module.core import (
|
||||||
IntermediateModule,
|
|
||||||
IntermediateModuleFactory,
|
|
||||||
ModuleFactory,
|
ModuleFactory,
|
||||||
TDevice,
|
TDevice,
|
||||||
init_linear_orthogonal,
|
init_linear_orthogonal,
|
||||||
)
|
)
|
||||||
|
from tianshou.highlevel.module.intermediate import (
|
||||||
|
IntermediateModule,
|
||||||
|
IntermediateModuleFactory,
|
||||||
|
)
|
||||||
from tianshou.highlevel.module.module_opt import ModuleOpt
|
from tianshou.highlevel.module.module_opt import ModuleOpt
|
||||||
from tianshou.highlevel.optim import OptimizerFactory
|
from tianshou.highlevel.optim import OptimizerFactory
|
||||||
from tianshou.utils.net import continuous, discrete
|
from tianshou.utils.net import continuous, discrete
|
||||||
@ -157,6 +159,11 @@ class ActorFactoryContinuousGaussianNet(ActorFactoryContinuous):
|
|||||||
unbounded: bool = True,
|
unbounded: bool = True,
|
||||||
conditioned_sigma: bool = False,
|
conditioned_sigma: bool = False,
|
||||||
):
|
):
|
||||||
|
""":param hidden_sizes: the sequence of hidden dimensions to use in the network structure
|
||||||
|
:param unbounded: whether to apply tanh activation on final logits
|
||||||
|
:param conditioned_sigma: if True, the standard deviation of continuous actions (sigma) is computed from the
|
||||||
|
input; if False, sigma is an independent parameter
|
||||||
|
"""
|
||||||
self.hidden_sizes = hidden_sizes
|
self.hidden_sizes = hidden_sizes
|
||||||
self.unbounded = unbounded
|
self.unbounded = unbounded
|
||||||
self.conditioned_sigma = conditioned_sigma
|
self.conditioned_sigma = conditioned_sigma
|
||||||
|
@ -1,14 +1,10 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Sequence
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import TypeAlias
|
from typing import TypeAlias
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from tianshou.highlevel.env import Environments
|
from tianshou.highlevel.env import Environments
|
||||||
from tianshou.utils.net.discrete import ImplicitQuantileNetwork
|
|
||||||
from tianshou.utils.string import ToStringMixin
|
|
||||||
|
|
||||||
TDevice: TypeAlias = str | torch.device
|
TDevice: TypeAlias = str | torch.device
|
||||||
|
|
||||||
@ -25,44 +21,8 @@ def init_linear_orthogonal(module: torch.nn.Module) -> None:
|
|||||||
|
|
||||||
|
|
||||||
class ModuleFactory(ABC):
|
class ModuleFactory(ABC):
|
||||||
|
"""Represents a factory for the creation of a torch module given an environment and target device."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create_module(self, envs: Environments, device: TDevice) -> torch.nn.Module:
|
def create_module(self, envs: Environments, device: TDevice) -> torch.nn.Module:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class IntermediateModule:
|
|
||||||
module: torch.nn.Module
|
|
||||||
output_dim: int
|
|
||||||
|
|
||||||
|
|
||||||
class IntermediateModuleFactory(ToStringMixin, ModuleFactory, ABC):
|
|
||||||
@abstractmethod
|
|
||||||
def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def create_module(self, envs: Environments, device: TDevice) -> torch.nn.Module:
|
|
||||||
return self.create_intermediate_module(envs, device).module
|
|
||||||
|
|
||||||
|
|
||||||
class ImplicitQuantileNetworkFactory(ModuleFactory, ToStringMixin):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
preprocess_net_factory: IntermediateModuleFactory,
|
|
||||||
hidden_sizes: Sequence[int] = (),
|
|
||||||
num_cosines: int = 64,
|
|
||||||
):
|
|
||||||
self.preprocess_net_factory = preprocess_net_factory
|
|
||||||
self.hidden_sizes = hidden_sizes
|
|
||||||
self.num_cosines = num_cosines
|
|
||||||
|
|
||||||
def create_module(self, envs: Environments, device: TDevice) -> ImplicitQuantileNetwork:
|
|
||||||
preprocess_net = self.preprocess_net_factory.create_intermediate_module(envs, device)
|
|
||||||
return ImplicitQuantileNetwork(
|
|
||||||
preprocess_net=preprocess_net.module,
|
|
||||||
action_shape=envs.get_action_shape(),
|
|
||||||
hidden_sizes=self.hidden_sizes,
|
|
||||||
num_cosines=self.num_cosines,
|
|
||||||
preprocess_net_output_dim=preprocess_net.output_dim,
|
|
||||||
device=device,
|
|
||||||
).to(device)
|
|
||||||
|
@ -15,6 +15,8 @@ from tianshou.utils.string import ToStringMixin
|
|||||||
|
|
||||||
|
|
||||||
class CriticFactory(ToStringMixin, ABC):
|
class CriticFactory(ToStringMixin, ABC):
|
||||||
|
"""Represents a factory for the generation of a critic module."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create_module(
|
def create_module(
|
||||||
self,
|
self,
|
||||||
@ -23,9 +25,11 @@ class CriticFactory(ToStringMixin, ABC):
|
|||||||
use_action: bool,
|
use_action: bool,
|
||||||
discrete_last_size_use_action_shape: bool = False,
|
discrete_last_size_use_action_shape: bool = False,
|
||||||
) -> nn.Module:
|
) -> nn.Module:
|
||||||
""":param envs: the environments
|
"""Creates the critic module.
|
||||||
|
|
||||||
|
:param envs: the environments
|
||||||
:param device: the torch device
|
:param device: the torch device
|
||||||
:param use_action: whether to (additionally) expect the action as input
|
:param use_action: whether to expect the action as an additional input (in addition to the observations)
|
||||||
:param discrete_last_size_use_action_shape: whether, for the discrete case, the output dimension shall use the action shape
|
:param discrete_last_size_use_action_shape: whether, for the discrete case, the output dimension shall use the action shape
|
||||||
:return: the module
|
:return: the module
|
||||||
"""
|
"""
|
||||||
@ -39,6 +43,16 @@ class CriticFactory(ToStringMixin, ABC):
|
|||||||
lr: float,
|
lr: float,
|
||||||
discrete_last_size_use_action_shape: bool = False,
|
discrete_last_size_use_action_shape: bool = False,
|
||||||
) -> ModuleOpt:
|
) -> ModuleOpt:
|
||||||
|
"""Creates the critic module along with its optimizer for the given learning rate.
|
||||||
|
|
||||||
|
:param envs: the environments
|
||||||
|
:param device: the torch device
|
||||||
|
:param use_action: whether to expect the action as an additional input (in addition to the observations)
|
||||||
|
:param optim_factory: the optimizer factory
|
||||||
|
:param lr: the learning rate
|
||||||
|
:param discrete_last_size_use_action_shape: whether, for the discrete case, the output dimension shall use the action shape
|
||||||
|
:return: the critic module along with its optimizer
|
||||||
|
"""
|
||||||
module = self.create_module(
|
module = self.create_module(
|
||||||
envs,
|
envs,
|
||||||
device,
|
device,
|
||||||
|
27
tianshou/highlevel/module/intermediate.py
Normal file
27
tianshou/highlevel/module/intermediate.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from tianshou.highlevel.env import Environments
|
||||||
|
from tianshou.highlevel.module.core import ModuleFactory, TDevice
|
||||||
|
from tianshou.utils.string import ToStringMixin
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class IntermediateModule:
|
||||||
|
"""Container for a module which computes an intermediate representation (with a known dimension)."""
|
||||||
|
|
||||||
|
module: torch.nn.Module
|
||||||
|
output_dim: int
|
||||||
|
|
||||||
|
|
||||||
|
class IntermediateModuleFactory(ToStringMixin, ModuleFactory, ABC):
|
||||||
|
"""Factory for the generation of a module which computes an intermediate representation."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def create_module(self, envs: Environments, device: TDevice) -> torch.nn.Module:
|
||||||
|
return self.create_intermediate_module(envs, device).module
|
@ -7,12 +7,16 @@ from tianshou.utils.net.common import ActorCritic
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModuleOpt:
|
class ModuleOpt:
|
||||||
|
"""Container for a torch module along with its optimizer."""
|
||||||
|
|
||||||
module: torch.nn.Module
|
module: torch.nn.Module
|
||||||
optim: torch.optim.Optimizer
|
optim: torch.optim.Optimizer
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ActorCriticModuleOpt:
|
class ActorCriticModuleOpt:
|
||||||
|
"""Container for an :class:`ActorCritic` instance along with its optimizer."""
|
||||||
|
|
||||||
actor_critic_module: ActorCritic
|
actor_critic_module: ActorCritic
|
||||||
optim: torch.optim.Optimizer
|
optim: torch.optim.Optimizer
|
||||||
|
|
||||||
|
30
tianshou/highlevel/module/special.py
Normal file
30
tianshou/highlevel/module/special.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
from tianshou.highlevel.env import Environments
|
||||||
|
from tianshou.highlevel.module.core import ModuleFactory, TDevice
|
||||||
|
from tianshou.highlevel.module.intermediate import IntermediateModuleFactory
|
||||||
|
from tianshou.utils.net.discrete import ImplicitQuantileNetwork
|
||||||
|
from tianshou.utils.string import ToStringMixin
|
||||||
|
|
||||||
|
|
||||||
|
class ImplicitQuantileNetworkFactory(ModuleFactory, ToStringMixin):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
preprocess_net_factory: IntermediateModuleFactory,
|
||||||
|
hidden_sizes: Sequence[int] = (),
|
||||||
|
num_cosines: int = 64,
|
||||||
|
):
|
||||||
|
self.preprocess_net_factory = preprocess_net_factory
|
||||||
|
self.hidden_sizes = hidden_sizes
|
||||||
|
self.num_cosines = num_cosines
|
||||||
|
|
||||||
|
def create_module(self, envs: Environments, device: TDevice) -> ImplicitQuantileNetwork:
|
||||||
|
preprocess_net = self.preprocess_net_factory.create_intermediate_module(envs, device)
|
||||||
|
return ImplicitQuantileNetwork(
|
||||||
|
preprocess_net=preprocess_net.module,
|
||||||
|
action_shape=envs.get_action_shape(),
|
||||||
|
hidden_sizes=self.hidden_sizes,
|
||||||
|
num_cosines=self.num_cosines,
|
||||||
|
preprocess_net_output_dim=preprocess_net.output_dim,
|
||||||
|
device=device,
|
||||||
|
).to(device)
|
@ -13,10 +13,6 @@ class OptimizerWithLearningRateProtocol(Protocol):
|
|||||||
|
|
||||||
|
|
||||||
class OptimizerFactory(ABC, ToStringMixin):
|
class OptimizerFactory(ABC, ToStringMixin):
|
||||||
# TODO: Is it OK to assume that all optimizers have a learning rate argument?
|
|
||||||
# Right now, the learning rate is typically a configuration parameter.
|
|
||||||
# If we drop the assumption, we can't have that and will need to move the parameter
|
|
||||||
# to the optimizer factory, which is inconvenient for the user.
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create_optimizer(self, module: torch.nn.Module, lr: float) -> torch.optim.Optimizer:
|
def create_optimizer(self, module: torch.nn.Module, lr: float) -> torch.optim.Optimizer:
|
||||||
pass
|
pass
|
||||||
|
@ -9,6 +9,8 @@ from tianshou.utils.string import ToStringMixin
|
|||||||
|
|
||||||
|
|
||||||
class LRSchedulerFactory(ToStringMixin, ABC):
|
class LRSchedulerFactory(ToStringMixin, ABC):
|
||||||
|
"""Factory for the creation of a learning rate scheduler."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
|
def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
|
||||||
pass
|
pass
|
||||||
|
@ -18,9 +18,12 @@ class NoiseFactoryMaxActionScaledGaussian(NoiseFactory):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, std_fraction: float):
|
def __init__(self, std_fraction: float):
|
||||||
|
""":param std_fraction: fraction (between 0 and 1) of the maximum action value that shall
|
||||||
|
be used as the standard deviation
|
||||||
|
"""
|
||||||
self.std_fraction = std_fraction
|
self.std_fraction = std_fraction
|
||||||
|
|
||||||
def create_noise(self, envs: Environments) -> BaseNoise:
|
def create_noise(self, envs: Environments) -> GaussianNoise:
|
||||||
envs.get_type().assert_continuous(self)
|
envs.get_type().assert_continuous(self)
|
||||||
envs: ContinuousEnvironments
|
envs: ContinuousEnvironments
|
||||||
return GaussianNoise(sigma=envs.max_action * self.std_fraction)
|
return GaussianNoise(sigma=envs.max_action * self.std_fraction)
|
||||||
|
@ -241,7 +241,9 @@ class Params(GetParamTransformersProtocol):
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ParamsMixinLearningRateWithScheduler(GetParamTransformersProtocol):
|
class ParamsMixinLearningRateWithScheduler(GetParamTransformersProtocol):
|
||||||
lr: float = 1e-3
|
lr: float = 1e-3
|
||||||
|
"""the learning rate to use in the gradient-based optimizer"""
|
||||||
lr_scheduler_factory: LRSchedulerFactory | None = None
|
lr_scheduler_factory: LRSchedulerFactory | None = None
|
||||||
|
"""factory for the creation of a learning rate scheduler"""
|
||||||
|
|
||||||
def _get_param_transformers(self) -> list[ParamTransformer]:
|
def _get_param_transformers(self) -> list[ParamTransformer]:
|
||||||
return [
|
return [
|
||||||
|
@ -3,7 +3,8 @@ from collections.abc import Sequence
|
|||||||
from typing import Generic, TypeVar
|
from typing import Generic, TypeVar
|
||||||
|
|
||||||
from tianshou.highlevel.env import Environments
|
from tianshou.highlevel.env import Environments
|
||||||
from tianshou.highlevel.module.core import IntermediateModuleFactory, TDevice
|
from tianshou.highlevel.module.core import TDevice
|
||||||
|
from tianshou.highlevel.module.intermediate import IntermediateModuleFactory
|
||||||
from tianshou.highlevel.optim import OptimizerFactory
|
from tianshou.highlevel.optim import OptimizerFactory
|
||||||
from tianshou.policy import BasePolicy, ICMPolicy
|
from tianshou.policy import BasePolicy, ICMPolicy
|
||||||
from tianshou.utils.net.discrete import IntrinsicCuriosityModule
|
from tianshou.utils.net.discrete import IntrinsicCuriosityModule
|
||||||
|
@ -50,6 +50,8 @@ class Persistence(ABC):
|
|||||||
|
|
||||||
|
|
||||||
class PersistenceGroup(Persistence):
|
class PersistenceGroup(Persistence):
|
||||||
|
"""Groups persistence handler such that they can be applied collectively."""
|
||||||
|
|
||||||
def __init__(self, *p: Persistence, enabled: bool = True):
|
def __init__(self, *p: Persistence, enabled: bool = True):
|
||||||
self.items = p
|
self.items = p
|
||||||
self.enabled = enabled
|
self.enabled = enabled
|
||||||
@ -69,7 +71,7 @@ class PolicyPersistence:
|
|||||||
FILENAME = "policy.dat"
|
FILENAME = "policy.dat"
|
||||||
|
|
||||||
def __init__(self, additional_persistence: Persistence | None = None, enabled: bool = True):
|
def __init__(self, additional_persistence: Persistence | None = None, enabled: bool = True):
|
||||||
""":param additional_persistence: a persistence instance which is to be envoked whenever
|
""":param additional_persistence: a persistence instance which is to be invoked whenever
|
||||||
this object is used to persist/restore data
|
this object is used to persist/restore data
|
||||||
:param enabled: whether persistence is enabled (restoration is always enabled)
|
:param enabled: whether persistence is enabled (restoration is always enabled)
|
||||||
"""
|
"""
|
||||||
|
@ -52,6 +52,7 @@ class TrainerStopCallback(ToStringMixin, ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool:
|
def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool:
|
||||||
""":param mean_rewards: the average undiscounted returns of the testing result
|
""":param mean_rewards: the average undiscounted returns of the testing result
|
||||||
|
:param context: the training context
|
||||||
:return: True if the goal has been reached and training should stop, False otherwise
|
:return: True if the goal has been reached and training should stop, False otherwise
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -64,6 +65,8 @@ class TrainerStopCallback(ToStringMixin, ABC):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TrainerCallbacks:
|
class TrainerCallbacks:
|
||||||
|
"""Container for callbacks used during training."""
|
||||||
|
|
||||||
epoch_callback_train: TrainerEpochCallbackTrain | None = None
|
epoch_callback_train: TrainerEpochCallbackTrain | None = None
|
||||||
epoch_callback_test: TrainerEpochCallbackTest | None = None
|
epoch_callback_test: TrainerEpochCallbackTest | None = None
|
||||||
stop_callback: TrainerStopCallback | None = None
|
stop_callback: TrainerStopCallback | None = None
|
||||||
|
@ -12,6 +12,8 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class World:
|
class World:
|
||||||
|
"""Container for instances and configuration items that are relevant to an experiment."""
|
||||||
|
|
||||||
envs: "Environments"
|
envs: "Environments"
|
||||||
policy: "BasePolicy"
|
policy: "BasePolicy"
|
||||||
train_collector: "Collector"
|
train_collector: "Collector"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user