Add documentation, improve structure of 'module' package

Dominik Jain 2023-10-16 18:19:31 +02:00
parent 97e21b5ddf
commit 4b270eaa2d
19 changed files with 256 additions and 64 deletions

View File

@ -8,9 +8,11 @@ from torch import nn
from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.actor import ActorFactory
from tianshou.highlevel.module.core import (
TDevice,
)
from tianshou.highlevel.module.intermediate import (
IntermediateModule,
IntermediateModuleFactory,
TDevice,
)
from tianshou.utils.net.discrete import Actor, NoisyLinear

View File

@ -81,6 +81,8 @@ log = logging.getLogger(__name__)
class AgentFactory(ABC, ToStringMixin):
"""Factory for the creation of an agent's policy, its trainer as well as collectors."""
def __init__(self, sampling_config: SamplingConfig, optim_factory: OptimizerFactory):
self.sampling_config = sampling_config
self.optim_factory = optim_factory

View File

@ -14,6 +14,8 @@ TObservationShape: TypeAlias = int | Sequence[int]
class EnvType(Enum):
"""Enumeration of environment types."""
CONTINUOUS = "continuous"
DISCRETE = "discrete"
@ -33,6 +35,8 @@ class EnvType(Enum):
class Environments(ToStringMixin, ABC):
"""Represents (vectorized) environments."""
def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
self.env = env
self.train_envs = train_envs
@ -52,6 +56,11 @@ class Environments(ToStringMixin, ABC):
}
def set_persistence(self, *p: Persistence) -> None:
"""Associates the given persistence handlers which may persist and restore
environment-specific information.
:param p: persistence handlers
"""
self.persistence = p
@abstractmethod
@ -74,6 +83,8 @@ class Environments(ToStringMixin, ABC):
class ContinuousEnvironments(Environments):
"""Represents (vectorized) continuous environments."""
def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
super().__init__(env, train_envs, test_envs)
self.state_shape, self.action_shape, self.max_action = self._get_continuous_env_info(env)
@ -110,6 +121,8 @@ class ContinuousEnvironments(Environments):
class DiscreteEnvironments(Environments):
"""Represents (vectorized) discrete environments."""
def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
super().__init__(env, train_envs, test_envs)
self.observation_shape = env.observation_space.shape or env.observation_space.n # type: ignore
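Illustrative usage (not part of this commit): constructing the continuous-environment container documented above from a base Gymnasium environment and vectorized train/test environments, using tianshou's DummyVectorEnv.

import gymnasium as gym
from tianshou.env import DummyVectorEnv
from tianshou.highlevel.env import ContinuousEnvironments

env = gym.make("Pendulum-v1")
train_envs = DummyVectorEnv([lambda: gym.make("Pendulum-v1") for _ in range(4)])
test_envs = DummyVectorEnv([lambda: gym.make("Pendulum-v1") for _ in range(2)])
envs = ContinuousEnvironments(env, train_envs, test_envs)
# state_shape, action_shape and max_action are derived from the base env (see above)
print(envs.state_shape, envs.action_shape, envs.max_action)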

View File

@ -27,7 +27,7 @@ from tianshou.highlevel.agent import (
)
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.env import EnvFactory, Environments
from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory, TLogger
from tianshou.highlevel.logger import LoggerFactory, LoggerFactoryDefault, TLogger
from tianshou.highlevel.module.actor import (
ActorFactory,
ActorFactoryDefault,
@ -38,8 +38,6 @@ from tianshou.highlevel.module.actor import (
IntermediateModuleFactoryFromActorFactory,
)
from tianshou.highlevel.module.core import (
ImplicitQuantileNetworkFactory,
IntermediateModuleFactory,
TDevice,
)
from tianshou.highlevel.module.critic import (
@ -49,6 +47,8 @@ from tianshou.highlevel.module.critic import (
CriticFactoryDefault,
CriticFactoryReuseActor,
)
from tianshou.highlevel.module.intermediate import IntermediateModuleFactory
from tianshou.highlevel.module.special import ImplicitQuantileNetworkFactory
from tianshou.highlevel.optim import OptimizerFactory, OptimizerFactoryAdam
from tianshou.highlevel.params.policy_params import (
A2CParams,
@ -116,8 +116,12 @@ class ExperimentConfig:
@dataclass
class ExperimentResult:
"""Contains the results of an experiment."""
world: World
"""contains all the essential instances of the experiment"""
trainer_result: dict[str, Any] | None
"""dictionary of results as returned by the trained (if any)"""
class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
@ -140,7 +144,7 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
env_config: PersistableConfigProtocol | None = None,
):
if logger_factory is None:
logger_factory = DefaultLoggerFactory()
logger_factory = LoggerFactoryDefault()
self.config = config
self.env_factory = env_factory
self.agent_factory = agent_factory
@ -179,7 +183,9 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
pickle.dump(self, f)
def run(
self, experiment_name: str | None = None, logger_run_id: str | None = None,
self,
experiment_name: str | None = None,
logger_run_id: str | None = None,
) -> ExperimentResult:
""":param experiment_name: the experiment name, which corresponds to the directory (within the logging
directory) where all results associated with the experiment will be saved.
@ -317,14 +323,31 @@ class ExperimentBuilder:
return self
def with_logger_factory(self, logger_factory: LoggerFactory) -> Self:
"""Allows to customize the logger factory to use.
If this method is not called, the default logger factory :class:`LoggerFactoryDefault` will be used.
:param logger_factory: the factory to use
:return: the builder
"""
self._logger_factory = logger_factory
return self
def with_policy_wrapper_factory(self, policy_wrapper_factory: PolicyWrapperFactory) -> Self:
"""Allows to define a wrapper around the policy that is created, extending the original policy.
:param policy_wrapper_factory: the factory for the wrapper
:return: the builder
"""
self._policy_wrapper_factory = policy_wrapper_factory
return self
def with_optim_factory(self, optim_factory: OptimizerFactory) -> Self:
"""Allows to customize the gradient-based optimizer to use.
By default, :class:`OptimizerFactoryAdam` will be used with default parameters.
:param optim_factory: the optimizer factory
:return: the builder
"""
self._optim_factory = optim_factory
return self
@ -345,14 +368,30 @@ class ExperimentBuilder:
return self
def with_trainer_epoch_callback_train(self, callback: TrainerEpochCallbackTrain) -> Self:
"""Allows to define a callback function which is called at the beginning of every epoch during training.
:param callback: the callback
:return: the builder
"""
self._trainer_callbacks.epoch_callback_train = callback
return self
def with_trainer_epoch_callback_test(self, callback: TrainerEpochCallbackTest) -> Self:
"""Allows to define a callback function which is called at the beginning of testing in each epoch.
:param callback: the callback
:return: the builder
"""
self._trainer_callbacks.epoch_callback_test = callback
return self
def with_trainer_stop_callback(self, callback: TrainerStopCallback) -> Self:
"""Allows to define a callback that decides whether training shall stop early.
The callback receives the undiscounted returns of the testing result.
:param callback: the callback
:return: the builder
"""
self._trainer_callbacks.stop_callback = callback
return self
@ -367,6 +406,10 @@ class ExperimentBuilder:
return self._optim_factory
def build(self) -> Experiment:
"""Creates the experiment based on the options specified via this builder.
:return: the experiment
"""
agent_factory = self._create_agent_factory()
agent_factory.set_trainer_callbacks(self._trainer_callbacks)
if self._policy_wrapper_factory:
@ -388,6 +431,12 @@ class _BuilderMixinActorFactory(ActorFutureProviderProtocol):
self._actor_factory: ActorFactory | None = None
def with_actor_factory(self, actor_factory: ActorFactory) -> Self:
"""Allows to customize the actor component via the specification of a factory.
If this function is not called, a default actor factory (with default parameters) will be used.
:param actor_factory: the factory to use for the creation of the actor network
:return: the builder
"""
self._actor_factory = actor_factory
return self
@ -397,6 +446,12 @@ class _BuilderMixinActorFactory(ActorFutureProviderProtocol):
continuous_unbounded: bool = False,
continuous_conditioned_sigma: bool = False,
) -> Self:
""":param hidden_sizes: the sequence of hidden dimensions to use in the network structure
:param continuous_unbounded: whether, for continuous action spaces, to leave the action values unbounded (if False, a tanh mapping is applied to bound them)
:param continuous_conditioned_sigma: whether, for continuous action spaces, the standard deviation of continuous actions (sigma)
shall be computed from the input; if False, sigma is an independent parameter.
:return: the builder
"""
self._actor_factory = ActorFactoryDefault(
self._continuous_actor_type,
hidden_sizes,
@ -406,6 +461,7 @@ class _BuilderMixinActorFactory(ActorFutureProviderProtocol):
return self
def get_actor_future(self) -> ActorFuture:
""":return: an object, which, in the future, will contain the actor instance that is created for the experiment."""
return self._actor_future
def _get_actor_factory(self) -> ActorFactory:
@ -431,6 +487,15 @@ class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
continuous_unbounded: bool = False,
continuous_conditioned_sigma: bool = False,
) -> Self:
"""Defines use of the default actor factory, allowing its parameters it to be customized.
The default actor factory uses an MLP-style architecture.
:param hidden_sizes: dimensions of hidden layers used by the network
:param continuous_unbounded: whether, for continuous action spaces, to leave the action values unbounded (if False, a tanh mapping is applied to bound them)
:param continuous_conditioned_sigma: whether, for continuous action spaces, the standard deviation of continuous actions (sigma)
shall be computed from the input; if False, sigma is an independent parameter.
:return: the builder
"""
return super()._with_actor_factory_default(
hidden_sizes,
continuous_unbounded=continuous_unbounded,
@ -445,6 +510,12 @@ class _BuilderMixinActorFactory_ContinuousDeterministic(_BuilderMixinActorFactor
super().__init__(ContinuousActorType.DETERMINISTIC)
def with_actor_factory_default(self, hidden_sizes: Sequence[int]) -> Self:
"""Defines use of the default actor factory, allowing its parameters it to be customized.
The default actor factory uses an MLP-style architecture.
:param hidden_sizes: dimensions of hidden layers used by the network
:return: the builder
"""
return super()._with_actor_factory_default(hidden_sizes)
@ -480,6 +551,11 @@ class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
super().__init__(1, actor_future_provider)
def with_critic_factory(self, critic_factory: CriticFactory) -> Self:
"""Specifies that the given factory shall be used for the critic.
:param critic_factory: the critic factory
:return: the builder
"""
self._with_critic_factory(0, critic_factory)
return self
@ -487,6 +563,11 @@ class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
self,
hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
) -> Self:
"""Makes the critic use the default, MLP-style architecture with the given parameters.
:param hidden_sizes: the sequence of dimensions to use in hidden layers of the network
:return: the builder
"""
self._with_critic_factory_default(0, hidden_sizes)
return self
@ -496,7 +577,7 @@ class _BuilderMixinSingleCriticCanUseActorFactory(_BuilderMixinSingleCriticFacto
super().__init__(actor_future_provider)
def with_critic_factory_use_actor(self) -> Self:
"""Makes the critic use the same network as the actor."""
"""Makes the first critic reuse the actor's preprocessing network (parameter sharing)."""
return self._with_critic_factory_use_actor(0)
@ -505,6 +586,11 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
super().__init__(2, actor_future_provider)
def with_common_critic_factory(self, critic_factory: CriticFactory) -> Self:
"""Specifies that the given factory shall be used for both critics.
:param critic_factory: the critic factory
:return: the builder
"""
for i in range(len(self._critic_factories)):
self._with_critic_factory(i, critic_factory)
return self
@ -513,17 +599,27 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
self,
hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
) -> Self:
"""Makes both critics use the default, MLP-style architecture with the given parameters.
:param hidden_sizes: the sequence of dimensions to use in hidden layers of the network
:return: the builder
"""
for i in range(len(self._critic_factories)):
self._with_critic_factory_default(i, hidden_sizes)
return self
def with_common_critic_factory_use_actor(self) -> Self:
"""Makes all critics use the same network as the actor."""
"""Makes both critics reuse the actor's preprocessing network (parameter sharing)."""
for i in range(len(self._critic_factories)):
self._with_critic_factory_use_actor(i)
return self
def with_critic1_factory(self, critic_factory: CriticFactory) -> Self:
"""Specifies that the given factory shall be used for the first critic.
:param critic_factory: the critic factory
:return: the builder
"""
self._with_critic_factory(0, critic_factory)
return self
@ -531,14 +627,24 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
self,
hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
) -> Self:
"""Makes the first critic use the default, MLP-style architecture with the given parameters.
:param hidden_sizes: the sequence of dimensions to use in hidden layers of the network
:return: the builder
"""
self._with_critic_factory_default(0, hidden_sizes)
return self
def with_critic1_factory_use_actor(self) -> Self:
"""Makes the critic use the same network as the actor."""
"""Makes the first critic reuse the actor's preprocessing network (parameter sharing)."""
return self._with_critic_factory_use_actor(0)
def with_critic2_factory(self, critic_factory: CriticFactory) -> Self:
"""Specifies that the given factory shall be used for the second critic.
:param critic_factory: the critic factory
:return: the builder
"""
self._with_critic_factory(1, critic_factory)
return self
@ -546,11 +652,16 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
self,
hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
) -> Self:
"""Makes the second critic use the default, MLP-style architecture with the given parameters.
:param hidden_sizes: the sequence of dimensions to use in hidden layers of the network
:return: the builder
"""
self._with_critic_factory_default(1, hidden_sizes)
return self
def with_critic2_factory_use_actor(self) -> Self:
"""Makes the second critic use the same network as the actor."""
"""Makes the first critic reuse the actor's preprocessing network (parameter sharing)."""
return self._with_critic_factory_use_actor(1)
@ -559,6 +670,12 @@ class _BuilderMixinCriticEnsembleFactory:
self.critic_ensemble_factory: CriticEnsembleFactory | None = None
def with_critic_ensemble_factory(self, factory: CriticEnsembleFactory) -> Self:
"""Specifies that the given factory shall be used for the critic ensemble.
If unspecified, the default factory (:class:`CriticEnsembleFactoryDefault`) is used.
:param factory: the critic ensemble factory
:return: the builder
"""
self.critic_ensemble_factory = factory
return self
@ -566,6 +683,11 @@ class _BuilderMixinCriticEnsembleFactory:
self,
hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
) -> Self:
"""Allows to customize the parameters of the default critic ensemble factory.
:param hidden_sizes: the sequence of sizes of hidden layers in the network architecture
:return: the builder
"""
self.critic_ensemble_factory = CriticEnsembleFactoryDefault(hidden_sizes)
return self
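Illustrative sketch (not part of this diff) of how the builder methods documented above compose. The concrete builder instance is assumed to be one that combines the actor and single-critic mixins (e.g. an on-policy actor-critic builder); its construction is omitted here.

from tianshou.highlevel.logger import LoggerFactoryDefault
from tianshou.highlevel.optim import OptimizerFactoryAdam


def configure_and_run(builder):
    # builder: a concrete ExperimentBuilder subclass instance (assumed, not constructed here)
    experiment = (
        builder.with_logger_factory(LoggerFactoryDefault())  # optional; this is the default
        .with_optim_factory(OptimizerFactoryAdam())  # optional; Adam is the default
        .with_actor_factory_default((64, 64))  # MLP-style actor with two hidden layers
        .with_critic_factory_default((64, 64))  # MLP-style critic
        .build()
    )
    result = experiment.run(experiment_name="example_experiment")
    return result.trainer_result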

View File

@ -27,7 +27,7 @@ class LoggerFactory(ToStringMixin, ABC):
"""
class DefaultLoggerFactory(LoggerFactory):
class LoggerFactoryDefault(LoggerFactory):
def __init__(
self,
logger_type: Literal["tensorboard", "wandb"] = "tensorboard",

View File

@ -9,12 +9,14 @@ from torch import nn
from tianshou.highlevel.env import Environments, EnvType
from tianshou.highlevel.module.core import (
IntermediateModule,
IntermediateModuleFactory,
ModuleFactory,
TDevice,
init_linear_orthogonal,
)
from tianshou.highlevel.module.intermediate import (
IntermediateModule,
IntermediateModuleFactory,
)
from tianshou.highlevel.module.module_opt import ModuleOpt
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.utils.net import continuous, discrete
@ -157,6 +159,11 @@ class ActorFactoryContinuousGaussianNet(ActorFactoryContinuous):
unbounded: bool = True,
conditioned_sigma: bool = False,
):
""":param hidden_sizes: the sequence of hidden dimensions to use in the network structure
:param unbounded: whether to leave the action values unbounded (if False, a tanh mapping is applied to bound them)
:param conditioned_sigma: if True, the standard deviation of continuous actions (sigma) is computed from the
input; if False, sigma is an independent parameter
"""
self.hidden_sizes = hidden_sizes
self.unbounded = unbounded
self.conditioned_sigma = conditioned_sigma
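Illustrative sketch (not part of this diff): instantiating the factory documented above and handing it to a builder via with_actor_factory.

from tianshou.highlevel.module.actor import ActorFactoryContinuousGaussianNet

actor_factory = ActorFactoryContinuousGaussianNet(
    hidden_sizes=(64, 64),
    unbounded=True,  # leave the action values unbounded (no tanh squashing)
    conditioned_sigma=False,  # sigma is an independent learnable parameter
)
# e.g. builder.with_actor_factory(actor_factory)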

View File

@ -1,14 +1,10 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TypeAlias
import numpy as np
import torch
from tianshou.highlevel.env import Environments
from tianshou.utils.net.discrete import ImplicitQuantileNetwork
from tianshou.utils.string import ToStringMixin
TDevice: TypeAlias = str | torch.device
@ -25,44 +21,8 @@ def init_linear_orthogonal(module: torch.nn.Module) -> None:
class ModuleFactory(ABC):
"""Represents a factory for the creation of a torch module given an environment and target device."""
@abstractmethod
def create_module(self, envs: Environments, device: TDevice) -> torch.nn.Module:
pass
@dataclass
class IntermediateModule:
module: torch.nn.Module
output_dim: int
class IntermediateModuleFactory(ToStringMixin, ModuleFactory, ABC):
@abstractmethod
def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule:
pass
def create_module(self, envs: Environments, device: TDevice) -> torch.nn.Module:
return self.create_intermediate_module(envs, device).module
class ImplicitQuantileNetworkFactory(ModuleFactory, ToStringMixin):
def __init__(
self,
preprocess_net_factory: IntermediateModuleFactory,
hidden_sizes: Sequence[int] = (),
num_cosines: int = 64,
):
self.preprocess_net_factory = preprocess_net_factory
self.hidden_sizes = hidden_sizes
self.num_cosines = num_cosines
def create_module(self, envs: Environments, device: TDevice) -> ImplicitQuantileNetwork:
preprocess_net = self.preprocess_net_factory.create_intermediate_module(envs, device)
return ImplicitQuantileNetwork(
preprocess_net=preprocess_net.module,
action_shape=envs.get_action_shape(),
hidden_sizes=self.hidden_sizes,
num_cosines=self.num_cosines,
preprocess_net_output_dim=preprocess_net.output_dim,
device=device,
).to(device)
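Illustrative sketch (not part of this commit): a custom ModuleFactory that builds a plain MLP; get_observation_shape is assumed to be the Environments accessor for the observation dimensions.

import numpy as np
import torch

from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.core import ModuleFactory, TDevice


class MLPModuleFactory(ModuleFactory):  # hypothetical example class
    def __init__(self, output_dim: int, hidden_dim: int = 64):
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim

    def create_module(self, envs: Environments, device: TDevice) -> torch.nn.Module:
        input_dim = int(np.prod(envs.get_observation_shape()))  # assumed accessor
        return torch.nn.Sequential(
            torch.nn.Flatten(),
            torch.nn.Linear(input_dim, self.hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(self.hidden_dim, self.output_dim),
        ).to(device)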

View File

@ -15,6 +15,8 @@ from tianshou.utils.string import ToStringMixin
class CriticFactory(ToStringMixin, ABC):
"""Represents a factory for the generation of a critic module."""
@abstractmethod
def create_module(
self,
@ -23,9 +25,11 @@ class CriticFactory(ToStringMixin, ABC):
use_action: bool,
discrete_last_size_use_action_shape: bool = False,
) -> nn.Module:
""":param envs: the environments
"""Creates the critic module.
:param envs: the environments
:param device: the torch device
:param use_action: whether to (additionally) expect the action as input
:param use_action: whether to expect the action as an additional input (in addition to the observations)
:param discrete_last_size_use_action_shape: whether, for the discrete case, the output dimension shall use the action shape
:return: the module
"""
@ -39,6 +43,16 @@ class CriticFactory(ToStringMixin, ABC):
lr: float,
discrete_last_size_use_action_shape: bool = False,
) -> ModuleOpt:
"""Creates the critic module along with its optimizer for the given learning rate.
:param envs: the environments
:param device: the torch device
:param use_action: whether to expect the action as an additional input (in addition to the observations)
:param optim_factory: the optimizer factory
:param lr: the learning rate
:param discrete_last_size_use_action_shape: whether, for the discrete case, the output dimension shall use the action shape
:return: the critic module along with its optimizer
"""
module = self.create_module(
envs,
device,
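Illustrative sketch (not part of this diff): creating a critic together with its optimizer via the default factory; the positional hidden_sizes argument of CriticFactoryDefault is an assumption.

from tianshou.highlevel.module.critic import CriticFactoryDefault
from tianshou.highlevel.optim import OptimizerFactoryAdam


def make_critic(envs, device="cpu"):
    # envs: an Environments instance (assumed to exist)
    critic_factory = CriticFactoryDefault((64, 64))
    module_opt = critic_factory.create_module_opt(
        envs,
        device,
        use_action=True,  # critic receives (observation, action) pairs
        optim_factory=OptimizerFactoryAdam(),
        lr=1e-3,
    )
    return module_opt.module, module_opt.optim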

View File

@ -0,0 +1,27 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
import torch
from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.core import ModuleFactory, TDevice
from tianshou.utils.string import ToStringMixin
@dataclass
class IntermediateModule:
"""Container for a module which computes an intermediate representation (with a known dimension)."""
module: torch.nn.Module
output_dim: int
class IntermediateModuleFactory(ToStringMixin, ModuleFactory, ABC):
"""Factory for the generation of a module which computes an intermediate representation."""
@abstractmethod
def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule:
pass
def create_module(self, envs: Environments, device: TDevice) -> torch.nn.Module:
return self.create_intermediate_module(envs, device).module
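Illustrative sketch (not part of this commit): a custom IntermediateModuleFactory producing a 128-dimensional feature extractor; get_observation_shape is assumed to be the Environments accessor.

import numpy as np
import torch

from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.core import TDevice
from tianshou.highlevel.module.intermediate import IntermediateModule, IntermediateModuleFactory


class FeatureNetFactory(IntermediateModuleFactory):  # hypothetical example class
    def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule:
        input_dim = int(np.prod(envs.get_observation_shape()))  # assumed accessor
        module = torch.nn.Sequential(
            torch.nn.Flatten(),
            torch.nn.Linear(input_dim, 128),
            torch.nn.ReLU(),
        ).to(device)
        return IntermediateModule(module, output_dim=128)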

View File

@ -7,12 +7,16 @@ from tianshou.utils.net.common import ActorCritic
@dataclass
class ModuleOpt:
"""Container for a torch module along with its optimizer."""
module: torch.nn.Module
optim: torch.optim.Optimizer
@dataclass
class ActorCriticModuleOpt:
"""Container for an :class:`ActorCritic` instance along with its optimizer."""
actor_critic_module: ActorCritic
optim: torch.optim.Optimizer
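Illustrative sketch: pairing an arbitrary torch module with its optimizer in the container defined above.

import torch

from tianshou.highlevel.module.module_opt import ModuleOpt

net = torch.nn.Linear(4, 2)
module_opt = ModuleOpt(module=net, optim=torch.optim.Adam(net.parameters(), lr=1e-3))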

View File

@ -0,0 +1,30 @@
from collections.abc import Sequence
from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.core import ModuleFactory, TDevice
from tianshou.highlevel.module.intermediate import IntermediateModuleFactory
from tianshou.utils.net.discrete import ImplicitQuantileNetwork
from tianshou.utils.string import ToStringMixin
class ImplicitQuantileNetworkFactory(ModuleFactory, ToStringMixin):
def __init__(
self,
preprocess_net_factory: IntermediateModuleFactory,
hidden_sizes: Sequence[int] = (),
num_cosines: int = 64,
):
self.preprocess_net_factory = preprocess_net_factory
self.hidden_sizes = hidden_sizes
self.num_cosines = num_cosines
def create_module(self, envs: Environments, device: TDevice) -> ImplicitQuantileNetwork:
preprocess_net = self.preprocess_net_factory.create_intermediate_module(envs, device)
return ImplicitQuantileNetwork(
preprocess_net=preprocess_net.module,
action_shape=envs.get_action_shape(),
hidden_sizes=self.hidden_sizes,
num_cosines=self.num_cosines,
preprocess_net_output_dim=preprocess_net.output_dim,
device=device,
).to(device)
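Illustrative sketch (not part of this commit): wiring an IntermediateModuleFactory (such as the hypothetical FeatureNetFactory sketched above for intermediate.py) into the IQN factory.

from tianshou.highlevel.module.special import ImplicitQuantileNetworkFactory

iqn_factory = ImplicitQuantileNetworkFactory(
    preprocess_net_factory=FeatureNetFactory(),  # hypothetical factory from the earlier sketch
    hidden_sizes=(512,),
    num_cosines=64,
)
# module = iqn_factory.create_module(envs, device="cpu")  # envs: an Environments instance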

View File

@ -13,10 +13,6 @@ class OptimizerWithLearningRateProtocol(Protocol):
class OptimizerFactory(ABC, ToStringMixin):
# TODO: Is it OK to assume that all optimizers have a learning rate argument?
# Right now, the learning rate is typically a configuration parameter.
# If we drop the assumption, we can't have that and will need to move the parameter
# to the optimizer factory, which is inconvenient for the user.
@abstractmethod
def create_optimizer(self, module: torch.nn.Module, lr: float) -> torch.optim.Optimizer:
pass
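Illustrative sketch: creating an optimizer for a module via the Adam-based factory (already imported in experiment.py above).

import torch

from tianshou.highlevel.optim import OptimizerFactoryAdam

net = torch.nn.Linear(4, 2)
optim = OptimizerFactoryAdam().create_optimizer(net, lr=1e-3)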

View File

@ -9,6 +9,8 @@ from tianshou.utils.string import ToStringMixin
class LRSchedulerFactory(ToStringMixin, ABC):
"""Factory for the createion of a learning rate scheduler."""
@abstractmethod
def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
pass
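Illustrative sketch (not part of this commit): a custom LRSchedulerFactory producing a StepLR scheduler; the module path of LRSchedulerFactory is an assumption.

import torch
from torch.optim.lr_scheduler import LRScheduler, StepLR

from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory  # assumed module path


class LRSchedulerFactoryStep(LRSchedulerFactory):  # hypothetical example class
    def __init__(self, step_size: int, gamma: float = 0.5):
        self.step_size = step_size
        self.gamma = gamma

    def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
        return StepLR(optim, step_size=self.step_size, gamma=self.gamma)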

View File

@ -18,9 +18,12 @@ class NoiseFactoryMaxActionScaledGaussian(NoiseFactory):
"""
def __init__(self, std_fraction: float):
""":param std_fraction: fraction (between 0 and 1) of the maximum action value that shall
be used as the standard deviation
"""
self.std_fraction = std_fraction
def create_noise(self, envs: Environments) -> BaseNoise:
def create_noise(self, envs: Environments) -> GaussianNoise:
envs.get_type().assert_continuous(self)
envs: ContinuousEnvironments
return GaussianNoise(sigma=envs.max_action * self.std_fraction)
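Illustrative sketch: Gaussian exploration noise with sigma equal to one tenth of the maximum action value; the module path is an assumption.

from tianshou.highlevel.params.noise import NoiseFactoryMaxActionScaledGaussian  # assumed module path

noise_factory = NoiseFactoryMaxActionScaledGaussian(std_fraction=0.1)
# noise = noise_factory.create_noise(envs)  # envs must represent continuous environments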

View File

@ -241,7 +241,9 @@ class Params(GetParamTransformersProtocol):
@dataclass
class ParamsMixinLearningRateWithScheduler(GetParamTransformersProtocol):
lr: float = 1e-3
"""the learning rate to use in the gradient-based optimizer"""
lr_scheduler_factory: LRSchedulerFactory | None = None
"""factory for the creation of a learning rate scheduler"""
def _get_param_transformers(self) -> list[ParamTransformer]:
return [

View File

@ -3,7 +3,8 @@ from collections.abc import Sequence
from typing import Generic, TypeVar
from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.core import IntermediateModuleFactory, TDevice
from tianshou.highlevel.module.core import TDevice
from tianshou.highlevel.module.intermediate import IntermediateModuleFactory
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.policy import BasePolicy, ICMPolicy
from tianshou.utils.net.discrete import IntrinsicCuriosityModule

View File

@ -50,6 +50,8 @@ class Persistence(ABC):
class PersistenceGroup(Persistence):
"""Groups persistence handler such that they can be applied collectively."""
def __init__(self, *p: Persistence, enabled: bool = True):
self.items = p
self.enabled = enabled
@ -69,7 +71,7 @@ class PolicyPersistence:
FILENAME = "policy.dat"
def __init__(self, additional_persistence: Persistence | None = None, enabled: bool = True):
""":param additional_persistence: a persistence instance which is to be envoked whenever
""":param additional_persistence: a persistence instance which is to be invoked whenever
this object is used to persist/restore data
:param enabled: whether persistence is enabled (restoration is always enabled)
"""

View File

@ -52,6 +52,7 @@ class TrainerStopCallback(ToStringMixin, ABC):
@abstractmethod
def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool:
""":param mean_rewards: the average undiscounted returns of the testing result
:param context: the training context
:return: True if the goal has been reached and training should stop, False otherwise
"""
@ -64,6 +65,8 @@ class TrainerStopCallback(ToStringMixin, ABC):
@dataclass
class TrainerCallbacks:
"""Container for callbacks used during training."""
epoch_callback_train: TrainerEpochCallbackTrain | None = None
epoch_callback_test: TrainerEpochCallbackTest | None = None
stop_callback: TrainerStopCallback | None = None
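Illustrative sketch (not part of this commit): a stop callback that ends training once the mean test return reaches a threshold; the module path is an assumption.

from tianshou.highlevel.trainer import TrainerStopCallback, TrainingContext  # assumed module path


class StopOnReward(TrainerStopCallback):  # hypothetical example class
    def __init__(self, reward_threshold: float):
        self.reward_threshold = reward_threshold

    def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool:
        return mean_rewards >= self.reward_threshold

# usage: builder.with_trainer_stop_callback(StopOnReward(195.0))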

View File

@ -12,6 +12,8 @@ if TYPE_CHECKING:
@dataclass
class World:
"""Container for instances and configuration items that are relevant to an experiment."""
envs: "Environments"
policy: "BasePolicy"
train_collector: "Collector"