diff --git a/examples/atari/atari_network.py b/examples/atari/atari_network.py
index 0767eb7..bd16d4f 100644
--- a/examples/atari/atari_network.py
+++ b/examples/atari/atari_network.py
@@ -8,9 +8,11 @@ from torch import nn
 from tianshou.highlevel.env import Environments
 from tianshou.highlevel.module.actor import ActorFactory
 from tianshou.highlevel.module.core import (
+    TDevice,
+)
+from tianshou.highlevel.module.intermediate import (
     IntermediateModule,
     IntermediateModuleFactory,
-    TDevice,
 )
 from tianshou.utils.net.discrete import Actor, NoisyLinear
diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py
index 38407a8..7ac2ae7 100644
--- a/tianshou/highlevel/agent.py
+++ b/tianshou/highlevel/agent.py
@@ -81,6 +81,8 @@ log = logging.getLogger(__name__)
 
 
 class AgentFactory(ABC, ToStringMixin):
+    """Factory for the creation of an agent's policy, its trainer, and the collectors."""
+
     def __init__(self, sampling_config: SamplingConfig, optim_factory: OptimizerFactory):
         self.sampling_config = sampling_config
         self.optim_factory = optim_factory
diff --git a/tianshou/highlevel/env.py b/tianshou/highlevel/env.py
index f356f4d..6fdc970 100644
--- a/tianshou/highlevel/env.py
+++ b/tianshou/highlevel/env.py
@@ -14,6 +14,8 @@ TObservationShape: TypeAlias = int | Sequence[int]
 
 
 class EnvType(Enum):
+    """Enumeration of environment types."""
+
     CONTINUOUS = "continuous"
     DISCRETE = "discrete"
 
@@ -33,6 +35,8 @@ class EnvType(Enum):
 
 
 class Environments(ToStringMixin, ABC):
+    """Represents (vectorized) environments."""
+
     def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
         self.env = env
         self.train_envs = train_envs
@@ -52,6 +56,11 @@ class Environments(ToStringMixin, ABC):
         }
 
     def set_persistence(self, *p: Persistence) -> None:
+        """Associates the given persistence handlers, which may persist and restore
+        environment-specific information.
+
+        :param p: persistence handlers
+        """
         self.persistence = p
 
     @abstractmethod
@@ -74,6 +83,8 @@ class Environments(ToStringMixin, ABC):
 
 
 class ContinuousEnvironments(Environments):
+    """Represents (vectorized) continuous environments."""
+
     def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
         super().__init__(env, train_envs, test_envs)
         self.state_shape, self.action_shape, self.max_action = self._get_continuous_env_info(env)
@@ -110,6 +121,8 @@ class ContinuousEnvironments(Environments):
 
 
 class DiscreteEnvironments(Environments):
+    """Represents (vectorized) discrete environments."""
+
     def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
         super().__init__(env, train_envs, test_envs)
         self.observation_shape = env.observation_space.shape or env.observation_space.n  # type: ignore
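The containers documented above wrap a plain environment together with vectorized train/test environments. As a minimal sketch (illustrative only, not part of the patch; assumes gymnasium, tianshou's `DummyVectorEnv`, and the `__init__` signature shown in the hunks above):

```python
import gymnasium as gym

from tianshou.env import DummyVectorEnv
from tianshou.highlevel.env import DiscreteEnvironments

env = gym.make("CartPole-v1")
train_envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])
test_envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(2)])

# DiscreteEnvironments(env, train_envs, test_envs), per the constructor above
envs = DiscreteEnvironments(env, train_envs, test_envs)
```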
diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py
index ed2575d..f22ea6d 100644
--- a/tianshou/highlevel/experiment.py
+++ b/tianshou/highlevel/experiment.py
@@ -27,7 +27,7 @@ from tianshou.highlevel.agent import (
 )
 from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.env import EnvFactory, Environments
-from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory, TLogger
+from tianshou.highlevel.logger import LoggerFactory, LoggerFactoryDefault, TLogger
 from tianshou.highlevel.module.actor import (
     ActorFactory,
     ActorFactoryDefault,
@@ -38,8 +38,6 @@ from tianshou.highlevel.module.actor import (
     IntermediateModuleFactoryFromActorFactory,
 )
 from tianshou.highlevel.module.core import (
-    ImplicitQuantileNetworkFactory,
-    IntermediateModuleFactory,
     TDevice,
 )
 from tianshou.highlevel.module.critic import (
@@ -49,6 +47,8 @@ from tianshou.highlevel.module.critic import (
     CriticFactoryDefault,
     CriticFactoryReuseActor,
 )
+from tianshou.highlevel.module.intermediate import IntermediateModuleFactory
+from tianshou.highlevel.module.special import ImplicitQuantileNetworkFactory
 from tianshou.highlevel.optim import OptimizerFactory, OptimizerFactoryAdam
 from tianshou.highlevel.params.policy_params import (
     A2CParams,
@@ -116,8 +116,12 @@ class ExperimentConfig:
 
 @dataclass
 class ExperimentResult:
+    """Contains the results of an experiment."""
+
     world: World
+    """contains all the essential instances of the experiment"""
     trainer_result: dict[str, Any] | None
+    """dictionary of results as returned by the trainer (if any)"""
 
 
 class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
@@ -140,7 +144,7 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
         env_config: PersistableConfigProtocol | None = None,
     ):
         if logger_factory is None:
-            logger_factory = DefaultLoggerFactory()
+            logger_factory = LoggerFactoryDefault()
         self.config = config
         self.env_factory = env_factory
         self.agent_factory = agent_factory
@@ -179,7 +183,9 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
             pickle.dump(self, f)
 
     def run(
-        self, experiment_name: str | None = None, logger_run_id: str | None = None,
+        self,
+        experiment_name: str | None = None,
+        logger_run_id: str | None = None,
     ) -> ExperimentResult:
         """:param experiment_name: the experiment name, which corresponds to the directory (within the logging
             directory) where all results associated with the experiment will be saved.
@@ -317,14 +323,31 @@ class ExperimentBuilder:
         return self
 
     def with_logger_factory(self, logger_factory: LoggerFactory) -> Self:
+        """Allows to customize the logger factory to use.
+        If this method is not called, the default logger factory :class:`LoggerFactoryDefault` will be used.
+
+        :param logger_factory: the factory to use
+        :return: the builder
+        """
         self._logger_factory = logger_factory
         return self
 
     def with_policy_wrapper_factory(self, policy_wrapper_factory: PolicyWrapperFactory) -> Self:
+        """Allows to define a wrapper around the policy that is created, extending the original policy.
+
+        :param policy_wrapper_factory: the factory for the wrapper
+        :return: the builder
+        """
         self._policy_wrapper_factory = policy_wrapper_factory
         return self
 
     def with_optim_factory(self, optim_factory: OptimizerFactory) -> Self:
+        """Allows to customize the gradient-based optimizer to use.
+        By default, :class:`OptimizerFactoryAdam` will be used with default parameters.
+
+        :param optim_factory: the optimizer factory
+        :return: the builder
+        """
         self._optim_factory = optim_factory
         return self
 
@@ -345,14 +368,30 @@ class ExperimentBuilder:
         return self
 
     def with_trainer_epoch_callback_train(self, callback: TrainerEpochCallbackTrain) -> Self:
+        """Allows to define a callback function which is called at the beginning of every epoch during training.
+
+        :param callback: the callback
+        :return: the builder
+        """
         self._trainer_callbacks.epoch_callback_train = callback
         return self
 
     def with_trainer_epoch_callback_test(self, callback: TrainerEpochCallbackTest) -> Self:
+        """Allows to define a callback function which is called at the beginning of testing in each epoch.
+
+        :param callback: the callback
+        :return: the builder
+        """
         self._trainer_callbacks.epoch_callback_test = callback
         return self
 
     def with_trainer_stop_callback(self, callback: TrainerStopCallback) -> Self:
+        """Allows to define a callback that decides whether training shall stop early.
+        The callback receives the undiscounted returns of the testing result.
+
+        :param callback: the callback
+        :return: the builder
+        """
         self._trainer_callbacks.stop_callback = callback
         return self
 
@@ -367,6 +406,10 @@ class ExperimentBuilder:
         return self._optim_factory
 
     def build(self) -> Experiment:
+        """Creates the experiment based on the options specified via this builder.
+
+        :return: the experiment
+        """
         agent_factory = self._create_agent_factory()
         agent_factory.set_trainer_callbacks(self._trainer_callbacks)
         if self._policy_wrapper_factory:
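To illustrate `with_trainer_stop_callback`, here is a sketch of a stop callback (illustrative only, not part of the patch). The `should_stop` signature is taken from `TrainerStopCallback` in tianshou/highlevel/trainer.py further below; the import location and the threshold are assumptions:

```python
from tianshou.highlevel.trainer import TrainerStopCallback, TrainingContext


class StopOnReward(TrainerStopCallback):
    """Stops training once the mean undiscounted test return reaches a threshold."""

    def __init__(self, threshold: float):
        self.threshold = threshold

    def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool:
        # mean_rewards: the average undiscounted returns of the testing result
        return mean_rewards >= self.threshold
```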
@@ -388,6 +431,12 @@ class _BuilderMixinActorFactory(ActorFutureProviderProtocol):
         self._actor_factory: ActorFactory | None = None
 
     def with_actor_factory(self, actor_factory: ActorFactory) -> Self:
+        """Allows to customize the actor component via the specification of a factory.
+        If this function is not called, a default actor factory (with default parameters) will be used.
+
+        :param actor_factory: the factory to use for the creation of the actor network
+        :return: the builder
+        """
         self._actor_factory = actor_factory
         return self
 
@@ -397,6 +446,12 @@ def _with_actor_factory_default(
         continuous_unbounded: bool = False,
         continuous_conditioned_sigma: bool = False,
     ) -> Self:
+        """:param hidden_sizes: the sequence of hidden dimensions to use in the network structure
+        :param continuous_unbounded: whether, for continuous action spaces, the final logits shall remain unbounded (tanh activation is not applied)
+        :param continuous_conditioned_sigma: whether, for continuous action spaces, the standard deviation of continuous actions (sigma)
+            shall be computed from the input; if False, sigma is an independent parameter.
+        :return: the builder
+        """
         self._actor_factory = ActorFactoryDefault(
             self._continuous_actor_type,
             hidden_sizes,
@@ -406,6 +461,7 @@ def _with_actor_factory_default(
         )
         return self
 
     def get_actor_future(self) -> ActorFuture:
+        """:return: an object, which, in the future, will contain the actor instance that is created for the experiment."""
         return self._actor_future
 
     def _get_actor_factory(self) -> ActorFactory:
@@ -431,6 +487,15 @@ class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
         continuous_unbounded: bool = False,
         continuous_conditioned_sigma: bool = False,
     ) -> Self:
+        """Defines use of the default actor factory, allowing its parameters to be customized.
+        The default actor factory uses an MLP-style architecture.
+
+        :param hidden_sizes: dimensions of hidden layers used by the network
+        :param continuous_unbounded: whether, for continuous action spaces, the final logits shall remain unbounded (tanh activation is not applied)
+        :param continuous_conditioned_sigma: whether, for continuous action spaces, the standard deviation of continuous actions (sigma)
+            shall be computed from the input; if False, sigma is an independent parameter.
+        :return: the builder
+        """
         return super()._with_actor_factory_default(
             hidden_sizes,
             continuous_unbounded=continuous_unbounded,
@@ -445,6 +510,12 @@ class _BuilderMixinActorFactory_ContinuousDeterministic(_BuilderMixinActorFactor
         super().__init__(ContinuousActorType.DETERMINISTIC)
 
     def with_actor_factory_default(self, hidden_sizes: Sequence[int]) -> Self:
+        """Defines use of the default actor factory, allowing its parameters to be customized.
+        The default actor factory uses an MLP-style architecture.
+
+        :param hidden_sizes: dimensions of hidden layers used by the network
+        :return: the builder
+        """
         return super()._with_actor_factory_default(hidden_sizes)
@@ -480,6 +551,11 @@ class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
         super().__init__(1, actor_future_provider)
 
     def with_critic_factory(self, critic_factory: CriticFactory) -> Self:
+        """Specifies that the given factory shall be used for the critic.
+
+        :param critic_factory: the critic factory
+        :return: the builder
+        """
         self._with_critic_factory(0, critic_factory)
         return self
 
@@ -487,6 +563,11 @@ class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
         self,
         hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
     ) -> Self:
+        """Makes the critic use the default, MLP-style architecture with the given parameters.
+
+        :param hidden_sizes: the sequence of dimensions to use in hidden layers of the network
+        :return: the builder
+        """
         self._with_critic_factory_default(0, hidden_sizes)
         return self
 
@@ -496,7 +577,7 @@ class _BuilderMixinSingleCriticCanUseActorFactory(_BuilderMixinSingleCriticFacto
         super().__init__(actor_future_provider)
 
     def with_critic_factory_use_actor(self) -> Self:
-        """Makes the critic use the same network as the actor."""
+        """Makes the critic reuse the actor's preprocessing network (parameter sharing)."""
         return self._with_critic_factory_use_actor(0)
 
 
@@ -505,6 +586,11 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
         super().__init__(2, actor_future_provider)
 
     def with_common_critic_factory(self, critic_factory: CriticFactory) -> Self:
+        """Specifies that the given factory shall be used for both critics.
+
+        :param critic_factory: the critic factory
+        :return: the builder
+        """
         for i in range(len(self._critic_factories)):
             self._with_critic_factory(i, critic_factory)
         return self
 
@@ -513,17 +599,27 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
         self,
         hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
     ) -> Self:
+        """Makes both critics use the default, MLP-style architecture with the given parameters.
+
+        :param hidden_sizes: the sequence of dimensions to use in hidden layers of the network
+        :return: the builder
+        """
         for i in range(len(self._critic_factories)):
             self._with_critic_factory_default(i, hidden_sizes)
         return self
 
     def with_common_critic_factory_use_actor(self) -> Self:
-        """Makes all critics use the same network as the actor."""
+        """Makes both critics reuse the actor's preprocessing network (parameter sharing)."""
         for i in range(len(self._critic_factories)):
             self._with_critic_factory_use_actor(i)
         return self
 
     def with_critic1_factory(self, critic_factory: CriticFactory) -> Self:
+        """Specifies that the given factory shall be used for the first critic.
+
+        :param critic_factory: the critic factory
+        :return: the builder
+        """
         self._with_critic_factory(0, critic_factory)
         return self
 
@@ -531,14 +627,24 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
         self,
         hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
     ) -> Self:
+        """Makes the first critic use the default, MLP-style architecture with the given parameters.
+
+        :param hidden_sizes: the sequence of dimensions to use in hidden layers of the network
+        :return: the builder
+        """
         self._with_critic_factory_default(0, hidden_sizes)
         return self
 
     def with_critic1_factory_use_actor(self) -> Self:
-        """Makes the critic use the same network as the actor."""
+        """Makes the first critic reuse the actor's preprocessing network (parameter sharing)."""
        return self._with_critic_factory_use_actor(0)
 
     def with_critic2_factory(self, critic_factory: CriticFactory) -> Self:
+        """Specifies that the given factory shall be used for the second critic.
+
+        :param critic_factory: the critic factory
+        :return: the builder
+        """
         self._with_critic_factory(1, critic_factory)
         return self
 
@@ -546,11 +652,16 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
         self,
         hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
     ) -> Self:
+        """Makes the second critic use the default, MLP-style architecture with the given parameters.
+
+        :param hidden_sizes: the sequence of dimensions to use in hidden layers of the network
+        :return: the builder
+        """
         self._with_critic_factory_default(0, hidden_sizes)
         return self
 
     def with_critic2_factory_use_actor(self) -> Self:
-        """Makes the second critic use the same network as the actor."""
+        """Makes the second critic reuse the actor's preprocessing network (parameter sharing)."""
         return self._with_critic_factory_use_actor(1)
@@ -559,6 +670,12 @@ class _BuilderMixinCriticEnsembleFactory:
         self.critic_ensemble_factory: CriticEnsembleFactory | None = None
 
     def with_critic_ensemble_factory(self, factory: CriticEnsembleFactory) -> Self:
+        """Specifies that the given factory shall be used for the critic ensemble.
+        If unspecified, the default factory (:class:`CriticEnsembleFactoryDefault`) is used.
+
+        :param factory: the critic ensemble factory
+        :return: the builder
+        """
         self.critic_ensemble_factory = factory
         return self
 
@@ -566,6 +683,11 @@ def with_critic_ensemble_factory_default(
         self,
         hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
     ) -> Self:
+        """Allows to customize the parameters of the default critic ensemble factory.
+
+        :param hidden_sizes: the sequence of sizes of hidden layers in the network architecture
+        :return: the builder
+        """
         self.critic_ensemble_factory = CriticEnsembleFactoryDefault(hidden_sizes)
         return self
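For orientation, a sketch of how the builder methods above compose (illustrative only, not part of the patch). `PPOExperimentBuilder`, `ExperimentConfig`, and the `env_factory`/`sampling_config` instances are assumed to exist in the high-level API and do not appear in this diff:

```python
from tianshou.highlevel.experiment import ExperimentConfig, PPOExperimentBuilder
from tianshou.highlevel.logger import LoggerFactoryDefault

# env_factory and sampling_config are assumed to be created elsewhere
builder = (
    PPOExperimentBuilder(env_factory, ExperimentConfig(), sampling_config)
    .with_logger_factory(LoggerFactoryDefault())  # renamed from DefaultLoggerFactory
    .with_actor_factory_default(hidden_sizes=(64, 64))
    .with_critic_factory_default(hidden_sizes=(64, 64))
    .with_trainer_stop_callback(StopOnReward(195.0))  # see the sketch further above
)
experiment = builder.build()
result = experiment.run(experiment_name="ppo-example")
print(result.trainer_result)  # dictionary of results as returned by the trainer
```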
diff --git a/tianshou/highlevel/logger.py b/tianshou/highlevel/logger.py
index c10599a..3458b6e 100644
--- a/tianshou/highlevel/logger.py
+++ b/tianshou/highlevel/logger.py
@@ -27,7 +27,7 @@ class LoggerFactory(ToStringMixin, ABC):
     """
 
 
-class DefaultLoggerFactory(LoggerFactory):
+class LoggerFactoryDefault(LoggerFactory):
     def __init__(
         self,
         logger_type: Literal["tensorboard", "wandb"] = "tensorboard",
diff --git a/tianshou/highlevel/module/actor.py b/tianshou/highlevel/module/actor.py
index d3f31f6..b1d4792 100644
--- a/tianshou/highlevel/module/actor.py
+++ b/tianshou/highlevel/module/actor.py
@@ -9,12 +9,14 @@ from torch import nn
 from tianshou.highlevel.env import Environments, EnvType
 from tianshou.highlevel.module.core import (
-    IntermediateModule,
-    IntermediateModuleFactory,
     ModuleFactory,
     TDevice,
     init_linear_orthogonal,
 )
+from tianshou.highlevel.module.intermediate import (
+    IntermediateModule,
+    IntermediateModuleFactory,
+)
 from tianshou.highlevel.module.module_opt import ModuleOpt
 from tianshou.highlevel.optim import OptimizerFactory
 from tianshou.utils.net import continuous, discrete
@@ -157,6 +159,11 @@ class ActorFactoryContinuousGaussianNet(ActorFactoryContinuous):
         unbounded: bool = True,
         conditioned_sigma: bool = False,
     ):
+        """:param hidden_sizes: the sequence of hidden dimensions to use in the network structure
+        :param unbounded: whether the final network output shall remain unbounded (tanh activation is not applied to the final logits)
+        :param conditioned_sigma: if True, the standard deviation of continuous actions (sigma) is computed from the
+            input; if False, sigma is an independent parameter
+        """
         self.hidden_sizes = hidden_sizes
         self.unbounded = unbounded
         self.conditioned_sigma = conditioned_sigma
diff --git a/tianshou/highlevel/module/core.py b/tianshou/highlevel/module/core.py
index 08fc88e..61f4a23 100644
--- a/tianshou/highlevel/module/core.py
+++ b/tianshou/highlevel/module/core.py
@@ -1,14 +1,10 @@
 from abc import ABC, abstractmethod
-from collections.abc import Sequence
-from dataclasses import dataclass
 from typing import TypeAlias
 
 import numpy as np
 import torch
 
 from tianshou.highlevel.env import Environments
-from tianshou.utils.net.discrete import ImplicitQuantileNetwork
-from tianshou.utils.string import ToStringMixin
 
 TDevice: TypeAlias = str | torch.device
@@ -25,44 +21,8 @@ def init_linear_orthogonal(module: torch.nn.Module) -> None:
 
 
 class ModuleFactory(ABC):
+    """Represents a factory for the creation of a torch module given an environment and target device."""
+
     @abstractmethod
     def create_module(self, envs: Environments, device: TDevice) -> torch.nn.Module:
         pass
-
-
-@dataclass
-class IntermediateModule:
-    module: torch.nn.Module
-    output_dim: int
-
-
-class IntermediateModuleFactory(ToStringMixin, ModuleFactory, ABC):
-    @abstractmethod
-    def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule:
-        pass
-
-    def create_module(self, envs: Environments, device: TDevice) -> torch.nn.Module:
-        return self.create_intermediate_module(envs, device).module
-
-
-class ImplicitQuantileNetworkFactory(ModuleFactory, ToStringMixin):
-    def __init__(
-        self,
-        preprocess_net_factory: IntermediateModuleFactory,
-        hidden_sizes: Sequence[int] = (),
-        num_cosines: int = 64,
-    ):
-        self.preprocess_net_factory = preprocess_net_factory
-        self.hidden_sizes = hidden_sizes
-        self.num_cosines = num_cosines
-
-    def create_module(self, envs: Environments, device: TDevice) -> ImplicitQuantileNetwork:
-        preprocess_net = self.preprocess_net_factory.create_intermediate_module(envs, device)
-        return ImplicitQuantileNetwork(
-            preprocess_net=preprocess_net.module,
-            action_shape=envs.get_action_shape(),
-            hidden_sizes=self.hidden_sizes,
-            num_cosines=self.num_cosines,
-            preprocess_net_output_dim=preprocess_net.output_dim,
-            device=device,
-        ).to(device)
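Summary of the import changes introduced by this refactoring (illustrative; it mirrors the hunks above):

```python
# before:
# from tianshou.highlevel.logger import DefaultLoggerFactory
# from tianshou.highlevel.module.core import (
#     ImplicitQuantileNetworkFactory,
#     IntermediateModule,
#     IntermediateModuleFactory,
# )

# after:
from tianshou.highlevel.logger import LoggerFactoryDefault
from tianshou.highlevel.module.intermediate import (
    IntermediateModule,
    IntermediateModuleFactory,
)
from tianshou.highlevel.module.special import ImplicitQuantileNetworkFactory
```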
diff --git a/tianshou/highlevel/module/critic.py b/tianshou/highlevel/module/critic.py
index 54b6003..1576bc2 100644
--- a/tianshou/highlevel/module/critic.py
+++ b/tianshou/highlevel/module/critic.py
@@ -15,6 +15,8 @@ from tianshou.utils.string import ToStringMixin
 
 
 class CriticFactory(ToStringMixin, ABC):
+    """Represents a factory for the generation of a critic module."""
+
     @abstractmethod
     def create_module(
         self,
@@ -23,9 +25,11 @@ class CriticFactory(ToStringMixin, ABC):
         use_action: bool,
         discrete_last_size_use_action_shape: bool = False,
     ) -> nn.Module:
-        """:param envs: the environments
+        """Creates the critic module.
+
+        :param envs: the environments
         :param device: the torch device
-        :param use_action: whether to (additionally) expect the action as input
+        :param use_action: whether to expect the action as an additional input (in addition to the observations)
         :param discrete_last_size_use_action_shape: whether, for the discrete case, the output dimension shall use the action shape
         :return: the module
         """
@@ -39,6 +43,16 @@ class CriticFactory(ToStringMixin, ABC):
         lr: float,
         discrete_last_size_use_action_shape: bool = False,
     ) -> ModuleOpt:
+        """Creates the critic module along with its optimizer for the given learning rate.
+
+        :param envs: the environments
+        :param device: the torch device
+        :param use_action: whether to expect the action as an additional input (in addition to the observations)
+        :param optim_factory: the optimizer factory
+        :param lr: the learning rate
+        :param discrete_last_size_use_action_shape: whether, for the discrete case, the output dimension shall use the action shape
+        :return: the module along with its optimizer
+        """
         module = self.create_module(
             envs,
             device,
diff --git a/tianshou/highlevel/module/intermediate.py b/tianshou/highlevel/module/intermediate.py
new file mode 100644
index 0000000..a008935
--- /dev/null
+++ b/tianshou/highlevel/module/intermediate.py
@@ -0,0 +1,27 @@
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+
+import torch
+
+from tianshou.highlevel.env import Environments
+from tianshou.highlevel.module.core import ModuleFactory, TDevice
+from tianshou.utils.string import ToStringMixin
+
+
+@dataclass
+class IntermediateModule:
+    """Container for a module which computes an intermediate representation (with a known dimension)."""
+
+    module: torch.nn.Module
+    output_dim: int
+
+
+class IntermediateModuleFactory(ToStringMixin, ModuleFactory, ABC):
+    """Factory for the generation of a module which computes an intermediate representation."""
+
+    @abstractmethod
+    def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule:
+        pass
+
+    def create_module(self, envs: Environments, device: TDevice) -> torch.nn.Module:
+        return self.create_intermediate_module(envs, device).module
diff --git a/tianshou/highlevel/module/module_opt.py b/tianshou/highlevel/module/module_opt.py
index 43242d8..222d680 100644
--- a/tianshou/highlevel/module/module_opt.py
+++ b/tianshou/highlevel/module/module_opt.py
@@ -7,12 +7,16 @@ from tianshou.utils.net.common import ActorCritic
 
 @dataclass
 class ModuleOpt:
+    """Container for a torch module along with its optimizer."""
+
     module: torch.nn.Module
     optim: torch.optim.Optimizer
 
 
 @dataclass
 class ActorCriticModuleOpt:
+    """Container for an :class:`ActorCritic` instance along with its optimizer."""
+
     actor_critic_module: ActorCritic
     optim: torch.optim.Optimizer
diff --git a/tianshou/highlevel/module/special.py b/tianshou/highlevel/module/special.py
new file mode 100644
index 0000000..8c3c568
--- /dev/null
+++ b/tianshou/highlevel/module/special.py
@@ -0,0 +1,30 @@
+from collections.abc import Sequence
+
+from tianshou.highlevel.env import Environments
+from tianshou.highlevel.module.core import ModuleFactory, TDevice
+from tianshou.highlevel.module.intermediate import IntermediateModuleFactory
+from tianshou.utils.net.discrete import ImplicitQuantileNetwork
+from tianshou.utils.string import ToStringMixin
+
+
+class ImplicitQuantileNetworkFactory(ModuleFactory, ToStringMixin):
+    def __init__(
+        self,
+        preprocess_net_factory: IntermediateModuleFactory,
+        hidden_sizes: Sequence[int] = (),
+        num_cosines: int = 64,
+    ):
+        self.preprocess_net_factory = preprocess_net_factory
+        self.hidden_sizes = hidden_sizes
+        self.num_cosines = num_cosines
+
+    def create_module(self, envs: Environments, device: TDevice) -> ImplicitQuantileNetwork:
+        preprocess_net = self.preprocess_net_factory.create_intermediate_module(envs, device)
+        return ImplicitQuantileNetwork(
+            preprocess_net=preprocess_net.module,
+            action_shape=envs.get_action_shape(),
+            hidden_sizes=self.hidden_sizes,
+            num_cosines=self.num_cosines,
+            preprocess_net_output_dim=preprocess_net.output_dim,
+            device=device,
+        ).to(device)
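A sketch of a concrete `IntermediateModuleFactory` and how it feeds `ImplicitQuantileNetworkFactory` (illustrative only, not part of the patch; the 128-dimensional MLP and the `get_observation_shape` accessor are assumptions for the example):

```python
import numpy as np
import torch

from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.core import TDevice
from tianshou.highlevel.module.intermediate import IntermediateModule, IntermediateModuleFactory
from tianshou.highlevel.module.special import ImplicitQuantileNetworkFactory


class MLPFeatureFactory(IntermediateModuleFactory):
    def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule:
        input_dim = int(np.prod(envs.get_observation_shape()))
        output_dim = 128  # arbitrary feature dimension for the example
        module = torch.nn.Sequential(
            torch.nn.Flatten(),
            torch.nn.Linear(input_dim, output_dim),
            torch.nn.ReLU(),
        ).to(device)
        # the output dimension is recorded so that downstream factories can connect to the module
        return IntermediateModule(module, output_dim)


# the IQN factory consumes the intermediate module as its preprocessing network
iqn_factory = ImplicitQuantileNetworkFactory(MLPFeatureFactory(), hidden_sizes=(64,), num_cosines=64)
```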
diff --git a/tianshou/highlevel/optim.py b/tianshou/highlevel/optim.py
index 008321f..8697e05 100644
--- a/tianshou/highlevel/optim.py
+++ b/tianshou/highlevel/optim.py
@@ -13,10 +13,6 @@ class OptimizerWithLearningRateProtocol(Protocol):
 
 
 class OptimizerFactory(ABC, ToStringMixin):
-    # TODO: Is it OK to assume that all optimizers have a learning rate argument?
-    #  Right now, the learning rate is typically a configuration parameter.
-    #  If we drop the assumption, we can't have that and will need to move the parameter
-    #  to the optimizer factory, which is inconvenient for the user.
     @abstractmethod
     def create_optimizer(self, module: torch.nn.Module, lr: float) -> torch.optim.Optimizer:
         pass
diff --git a/tianshou/highlevel/params/lr_scheduler.py b/tianshou/highlevel/params/lr_scheduler.py
index 820be2e..93b1ce0 100644
--- a/tianshou/highlevel/params/lr_scheduler.py
+++ b/tianshou/highlevel/params/lr_scheduler.py
@@ -9,6 +9,8 @@ from tianshou.utils.string import ToStringMixin
 
 
 class LRSchedulerFactory(ToStringMixin, ABC):
+    """Factory for the creation of a learning rate scheduler."""
+
     @abstractmethod
     def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
         pass
diff --git a/tianshou/highlevel/params/noise.py b/tianshou/highlevel/params/noise.py
index 106e002..d3e4ce6 100644
--- a/tianshou/highlevel/params/noise.py
+++ b/tianshou/highlevel/params/noise.py
@@ -18,9 +18,12 @@ class NoiseFactoryMaxActionScaledGaussian(NoiseFactory):
     """
 
     def __init__(self, std_fraction: float):
+        """:param std_fraction: fraction (between 0 and 1) of the maximum action value that shall
+        be used as the standard deviation
+        """
         self.std_fraction = std_fraction
 
-    def create_noise(self, envs: Environments) -> BaseNoise:
+    def create_noise(self, envs: Environments) -> GaussianNoise:
         envs.get_type().assert_continuous(self)
         envs: ContinuousEnvironments
         return GaussianNoise(sigma=envs.max_action * self.std_fraction)
diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py
index 33c9783..9d8cd53 100644
--- a/tianshou/highlevel/params/policy_params.py
+++ b/tianshou/highlevel/params/policy_params.py
@@ -241,7 +241,9 @@ class Params(GetParamTransformersProtocol):
 @dataclass
 class ParamsMixinLearningRateWithScheduler(GetParamTransformersProtocol):
     lr: float = 1e-3
+    """the learning rate to use in the gradient-based optimizer"""
     lr_scheduler_factory: LRSchedulerFactory | None = None
+    """factory for the creation of a learning rate scheduler"""
 
     def _get_param_transformers(self) -> list[ParamTransformer]:
         return [
diff --git a/tianshou/highlevel/params/policy_wrapper.py b/tianshou/highlevel/params/policy_wrapper.py
index 44d0f82..c821f50 100644
--- a/tianshou/highlevel/params/policy_wrapper.py
+++ b/tianshou/highlevel/params/policy_wrapper.py
@@ -3,7 +3,8 @@ from collections.abc import Sequence
 from typing import Generic, TypeVar
 
 from tianshou.highlevel.env import Environments
-from tianshou.highlevel.module.core import IntermediateModuleFactory, TDevice
+from tianshou.highlevel.module.core import TDevice
+from tianshou.highlevel.module.intermediate import IntermediateModuleFactory
 from tianshou.highlevel.optim import OptimizerFactory
 from tianshou.policy import BasePolicy, ICMPolicy
 from tianshou.utils.net.discrete import IntrinsicCuriosityModule
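Behavior of the noise factory touched above, as a sketch (not part of the patch; `envs` stands for a `ContinuousEnvironments` instance created elsewhere):

```python
from tianshou.highlevel.params.noise import NoiseFactoryMaxActionScaledGaussian

factory = NoiseFactoryMaxActionScaledGaussian(std_fraction=0.1)
# For a continuous env with max_action = 2.0 this yields GaussianNoise(sigma=0.2),
# i.e. sigma = envs.max_action * std_fraction.
noise = factory.create_noise(envs)
```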
diff --git a/tianshou/highlevel/persistence.py b/tianshou/highlevel/persistence.py
index fad4bea..54f3c23 100644
--- a/tianshou/highlevel/persistence.py
+++ b/tianshou/highlevel/persistence.py
@@ -50,6 +50,8 @@ class Persistence(ABC):
 
 
 class PersistenceGroup(Persistence):
+    """Groups persistence handlers such that they can be applied collectively."""
+
     def __init__(self, *p: Persistence, enabled: bool = True):
         self.items = p
         self.enabled = enabled
@@ -69,7 +71,7 @@ class PersistenceGroup(Persistence):
     FILENAME = "policy.dat"
 
     def __init__(self, additional_persistence: Persistence | None = None, enabled: bool = True):
-        """:param additional_persistence: a persistence instance which is to be envoked whenever
+        """:param additional_persistence: a persistence instance which is to be invoked whenever
             this object is used to persist/restore data
         :param enabled: whether persistence is enabled (restoration is always enabled)
         """
diff --git a/tianshou/highlevel/trainer.py b/tianshou/highlevel/trainer.py
index 770c92b..2388e2a 100644
--- a/tianshou/highlevel/trainer.py
+++ b/tianshou/highlevel/trainer.py
@@ -52,6 +52,7 @@ class TrainerStopCallback(ToStringMixin, ABC):
     @abstractmethod
     def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool:
         """:param mean_rewards: the average undiscounted returns of the testing result
+        :param context: the training context
         :return: True if the goal has been reached and training should stop, False otherwise
         """
 
@@ -64,6 +65,8 @@ class TrainerStopCallback(ToStringMixin, ABC):
 
 @dataclass
 class TrainerCallbacks:
+    """Container for callbacks used during training."""
+
     epoch_callback_train: TrainerEpochCallbackTrain | None = None
     epoch_callback_test: TrainerEpochCallbackTest | None = None
     stop_callback: TrainerStopCallback | None = None
diff --git a/tianshou/highlevel/world.py b/tianshou/highlevel/world.py
index 6ec7c4b..9d0572b 100644
--- a/tianshou/highlevel/world.py
+++ b/tianshou/highlevel/world.py
@@ -12,6 +12,8 @@ if TYPE_CHECKING:
 
 @dataclass
 class World:
+    """Container for instances and configuration items that are relevant to an experiment."""
+
     envs: "Environments"
     policy: "BasePolicy"
     train_collector: "Collector"
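Finally, a sketch of inspecting the `World` container via `ExperimentResult` after a run (illustrative only, not part of the patch; `experiment` stands for an `Experiment` built elsewhere, and only the fields visible in this diff are accessed):

```python
result = experiment.run(experiment_name="example")
world = result.world  # contains all the essential instances of the experiment
print(world.policy)           # the trained BasePolicy
print(world.envs)             # the Environments container
print(world.train_collector)  # the Collector used during training
```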