Add documentation, improve structure of 'module' package

This commit is contained in:
Dominik Jain 2023-10-16 18:19:31 +02:00
parent 97e21b5ddf
commit 4b270eaa2d
19 changed files with 256 additions and 64 deletions

View File

@@ -8,9 +8,11 @@ from torch import nn
 from tianshou.highlevel.env import Environments
 from tianshou.highlevel.module.actor import ActorFactory
 from tianshou.highlevel.module.core import (
+    TDevice,
+)
+from tianshou.highlevel.module.intermediate import (
     IntermediateModule,
     IntermediateModuleFactory,
-    TDevice,
 )
 from tianshou.utils.net.discrete import Actor, NoisyLinear

View File

@@ -81,6 +81,8 @@ log = logging.getLogger(__name__)
 class AgentFactory(ABC, ToStringMixin):
+    """Factory for the creation of an agent's policy, its trainer as well as collectors."""
+
     def __init__(self, sampling_config: SamplingConfig, optim_factory: OptimizerFactory):
         self.sampling_config = sampling_config
         self.optim_factory = optim_factory

View File

@@ -14,6 +14,8 @@ TObservationShape: TypeAlias = int | Sequence[int]
 class EnvType(Enum):
+    """Enumeration of environment types."""
+
     CONTINUOUS = "continuous"
     DISCRETE = "discrete"
@@ -33,6 +35,8 @@ class EnvType(Enum):
 class Environments(ToStringMixin, ABC):
+    """Represents (vectorized) environments."""
+
     def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
         self.env = env
         self.train_envs = train_envs
@@ -52,6 +56,11 @@ class Environments(ToStringMixin, ABC):
         }

     def set_persistence(self, *p: Persistence) -> None:
+        """Associates the given persistence handlers which may persist and restore
+        environment-specific information.
+
+        :param p: persistence handlers
+        """
         self.persistence = p

     @abstractmethod
@@ -74,6 +83,8 @@ class Environments(ToStringMixin, ABC):
 class ContinuousEnvironments(Environments):
+    """Represents (vectorized) continuous environments."""
+
     def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
         super().__init__(env, train_envs, test_envs)
         self.state_shape, self.action_shape, self.max_action = self._get_continuous_env_info(env)
@@ -110,6 +121,8 @@ class ContinuousEnvironments(Environments):
 class DiscreteEnvironments(Environments):
+    """Represents (vectorized) discrete environments."""
+
     def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
         super().__init__(env, train_envs, test_envs)
         self.observation_shape = env.observation_space.shape or env.observation_space.n  # type: ignore

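For illustration, the classes documented above can also be assembled directly; a minimal sketch (not part of this commit, assuming a Gymnasium environment id and tianshou's DummyVectorEnv):

    import gymnasium as gym
    from tianshou.env import DummyVectorEnv
    from tianshou.highlevel.env import DiscreteEnvironments

    # one reference env plus vectorized train/test envs
    env = gym.make("CartPole-v1")
    train_envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])
    test_envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(2)])
    envs = DiscreteEnvironments(env, train_envs, test_envs)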
View File

@@ -27,7 +27,7 @@ from tianshou.highlevel.agent import (
 )
 from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.env import EnvFactory, Environments
-from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory, TLogger
+from tianshou.highlevel.logger import LoggerFactory, LoggerFactoryDefault, TLogger
 from tianshou.highlevel.module.actor import (
     ActorFactory,
     ActorFactoryDefault,
@@ -38,8 +38,6 @@ from tianshou.highlevel.module.actor import (
     IntermediateModuleFactoryFromActorFactory,
 )
 from tianshou.highlevel.module.core import (
-    ImplicitQuantileNetworkFactory,
-    IntermediateModuleFactory,
     TDevice,
 )
 from tianshou.highlevel.module.critic import (
@@ -49,6 +47,8 @@ from tianshou.highlevel.module.critic import (
     CriticFactoryDefault,
     CriticFactoryReuseActor,
 )
+from tianshou.highlevel.module.intermediate import IntermediateModuleFactory
+from tianshou.highlevel.module.special import ImplicitQuantileNetworkFactory
 from tianshou.highlevel.optim import OptimizerFactory, OptimizerFactoryAdam
 from tianshou.highlevel.params.policy_params import (
     A2CParams,
@@ -116,8 +116,12 @@ class ExperimentConfig:
 @dataclass
 class ExperimentResult:
+    """Contains the results of an experiment."""
+
     world: World
+    """contains all the essential instances of the experiment"""
     trainer_result: dict[str, Any] | None
+    """dictionary of results as returned by the trainer (if any)"""


 class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
@@ -140,7 +144,7 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
         env_config: PersistableConfigProtocol | None = None,
     ):
         if logger_factory is None:
-            logger_factory = DefaultLoggerFactory()
+            logger_factory = LoggerFactoryDefault()
         self.config = config
         self.env_factory = env_factory
         self.agent_factory = agent_factory
@@ -179,7 +183,9 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
             pickle.dump(self, f)

     def run(
-        self, experiment_name: str | None = None, logger_run_id: str | None = None,
+        self,
+        experiment_name: str | None = None,
+        logger_run_id: str | None = None,
     ) -> ExperimentResult:
         """:param experiment_name: the experiment name, which corresponds to the directory (within the logging
             directory) where all results associated with the experiment will be saved.
@@ -317,14 +323,31 @@ class ExperimentBuilder:
         return self

     def with_logger_factory(self, logger_factory: LoggerFactory) -> Self:
+        """Allows customizing the logger factory to use.
+
+        If this method is not called, the default logger factory :class:`LoggerFactoryDefault` will be used.
+
+        :param logger_factory: the factory to use
+        :return: the builder
+        """
         self._logger_factory = logger_factory
         return self

     def with_policy_wrapper_factory(self, policy_wrapper_factory: PolicyWrapperFactory) -> Self:
+        """Allows defining a wrapper around the policy that is created, extending the original policy.
+
+        :param policy_wrapper_factory: the factory for the wrapper
+        :return: the builder
+        """
         self._policy_wrapper_factory = policy_wrapper_factory
         return self

     def with_optim_factory(self, optim_factory: OptimizerFactory) -> Self:
+        """Allows customizing the gradient-based optimizer to use.
+
+        By default, :class:`OptimizerFactoryAdam` will be used with default parameters.
+
+        :param optim_factory: the optimizer factory
+        :return: the builder
+        """
         self._optim_factory = optim_factory
         return self
@@ -345,14 +368,30 @@ class ExperimentBuilder:
         return self

     def with_trainer_epoch_callback_train(self, callback: TrainerEpochCallbackTrain) -> Self:
+        """Allows defining a callback function which is called at the beginning of every epoch during training.
+
+        :param callback: the callback
+        :return: the builder
+        """
         self._trainer_callbacks.epoch_callback_train = callback
         return self

     def with_trainer_epoch_callback_test(self, callback: TrainerEpochCallbackTest) -> Self:
+        """Allows defining a callback function which is called at the beginning of testing in each epoch.
+
+        :param callback: the callback
+        :return: the builder
+        """
         self._trainer_callbacks.epoch_callback_test = callback
         return self

     def with_trainer_stop_callback(self, callback: TrainerStopCallback) -> Self:
+        """Allows defining a callback that decides whether training shall stop early.
+
+        The callback receives the undiscounted returns of the testing result.
+
+        :param callback: the callback
+        :return: the builder
+        """
         self._trainer_callbacks.stop_callback = callback
         return self
@@ -367,6 +406,10 @@ class ExperimentBuilder:
         return self._optim_factory

     def build(self) -> Experiment:
+        """Creates the experiment based on the options specified via this builder.
+
+        :return: the experiment
+        """
         agent_factory = self._create_agent_factory()
         agent_factory.set_trainer_callbacks(self._trainer_callbacks)
         if self._policy_wrapper_factory:
@@ -388,6 +431,12 @@ class _BuilderMixinActorFactory(ActorFutureProviderProtocol):
         self._actor_factory: ActorFactory | None = None

     def with_actor_factory(self, actor_factory: ActorFactory) -> Self:
+        """Allows customizing the actor component via the specification of a factory.
+
+        If this function is not called, a default actor factory (with default parameters) will be used.
+
+        :param actor_factory: the factory to use for the creation of the actor network
+        :return: the builder
+        """
         self._actor_factory = actor_factory
         return self
@@ -397,6 +446,12 @@ class _BuilderMixinActorFactory(ActorFutureProviderProtocol):
         continuous_unbounded: bool = False,
         continuous_conditioned_sigma: bool = False,
     ) -> Self:
+        """:param hidden_sizes: the sequence of hidden dimensions to use in the network structure
+        :param continuous_unbounded: whether, for continuous action spaces, to apply tanh activation on final logits
+        :param continuous_conditioned_sigma: whether, for continuous action spaces, the standard deviation of
+            continuous actions (sigma) shall be computed from the input; if False, sigma is an independent parameter
+        :return: the builder
+        """
         self._actor_factory = ActorFactoryDefault(
             self._continuous_actor_type,
             hidden_sizes,
@@ -406,6 +461,7 @@ class _BuilderMixinActorFactory(ActorFutureProviderProtocol):
         return self

     def get_actor_future(self) -> ActorFuture:
+        """:return: an object which will, in the future, contain the actor instance created for the experiment."""
         return self._actor_future

     def _get_actor_factory(self) -> ActorFactory:
@@ -431,6 +487,15 @@ class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
         continuous_unbounded: bool = False,
         continuous_conditioned_sigma: bool = False,
     ) -> Self:
+        """Defines use of the default actor factory, allowing its parameters to be customized.
+
+        The default actor factory uses an MLP-style architecture.
+
+        :param hidden_sizes: dimensions of hidden layers used by the network
+        :param continuous_unbounded: whether, for continuous action spaces, to apply tanh activation on final logits
+        :param continuous_conditioned_sigma: whether, for continuous action spaces, the standard deviation of
+            continuous actions (sigma) shall be computed from the input; if False, sigma is an independent parameter
+        :return: the builder
+        """
         return super()._with_actor_factory_default(
             hidden_sizes,
             continuous_unbounded=continuous_unbounded,
@@ -445,6 +510,12 @@ class _BuilderMixinActorFactory_ContinuousDeterministic(_BuilderMixinActorFactory):
         super().__init__(ContinuousActorType.DETERMINISTIC)

     def with_actor_factory_default(self, hidden_sizes: Sequence[int]) -> Self:
+        """Defines use of the default actor factory, allowing its parameters to be customized.
+
+        The default actor factory uses an MLP-style architecture.
+
+        :param hidden_sizes: dimensions of hidden layers used by the network
+        :return: the builder
+        """
         return super()._with_actor_factory_default(hidden_sizes)
@@ -480,6 +551,11 @@ class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
         super().__init__(1, actor_future_provider)

     def with_critic_factory(self, critic_factory: CriticFactory) -> Self:
+        """Specifies that the given factory shall be used for the critic.
+
+        :param critic_factory: the critic factory
+        :return: the builder
+        """
         self._with_critic_factory(0, critic_factory)
         return self
@@ -487,6 +563,11 @@ class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
         self,
         hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
     ) -> Self:
+        """Makes the critic use the default, MLP-style architecture with the given parameters.
+
+        :param hidden_sizes: the sequence of dimensions to use in hidden layers of the network
+        :return: the builder
+        """
         self._with_critic_factory_default(0, hidden_sizes)
         return self
@@ -496,7 +577,7 @@ class _BuilderMixinSingleCriticCanUseActorFactory(_BuilderMixinSingleCriticFactory):
         super().__init__(actor_future_provider)

     def with_critic_factory_use_actor(self) -> Self:
-        """Makes the critic use the same network as the actor."""
+        """Makes the first critic reuse the actor's preprocessing network (parameter sharing)."""
         return self._with_critic_factory_use_actor(0)
@@ -505,6 +586,11 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
         super().__init__(2, actor_future_provider)

     def with_common_critic_factory(self, critic_factory: CriticFactory) -> Self:
+        """Specifies that the given factory shall be used for both critics.
+
+        :param critic_factory: the critic factory
+        :return: the builder
+        """
         for i in range(len(self._critic_factories)):
             self._with_critic_factory(i, critic_factory)
         return self
@@ -513,17 +599,27 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
         self,
         hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
     ) -> Self:
+        """Makes both critics use the default, MLP-style architecture with the given parameters.
+
+        :param hidden_sizes: the sequence of dimensions to use in hidden layers of the network
+        :return: the builder
+        """
         for i in range(len(self._critic_factories)):
             self._with_critic_factory_default(i, hidden_sizes)
         return self

     def with_common_critic_factory_use_actor(self) -> Self:
-        """Makes all critics use the same network as the actor."""
+        """Makes both critics reuse the actor's preprocessing network (parameter sharing)."""
         for i in range(len(self._critic_factories)):
             self._with_critic_factory_use_actor(i)
         return self

     def with_critic1_factory(self, critic_factory: CriticFactory) -> Self:
+        """Specifies that the given factory shall be used for the first critic.
+
+        :param critic_factory: the critic factory
+        :return: the builder
+        """
         self._with_critic_factory(0, critic_factory)
         return self
@@ -531,14 +627,24 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
         self,
         hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
     ) -> Self:
+        """Makes the first critic use the default, MLP-style architecture with the given parameters.
+
+        :param hidden_sizes: the sequence of dimensions to use in hidden layers of the network
+        :return: the builder
+        """
         self._with_critic_factory_default(0, hidden_sizes)
         return self

     def with_critic1_factory_use_actor(self) -> Self:
-        """Makes the critic use the same network as the actor."""
+        """Makes the first critic reuse the actor's preprocessing network (parameter sharing)."""
         return self._with_critic_factory_use_actor(0)

     def with_critic2_factory(self, critic_factory: CriticFactory) -> Self:
+        """Specifies that the given factory shall be used for the second critic.
+
+        :param critic_factory: the critic factory
+        :return: the builder
+        """
         self._with_critic_factory(1, critic_factory)
         return self
@@ -546,11 +652,16 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
         self,
         hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
     ) -> Self:
+        """Makes the second critic use the default, MLP-style architecture with the given parameters.
+
+        :param hidden_sizes: the sequence of dimensions to use in hidden layers of the network
+        :return: the builder
+        """
         self._with_critic_factory_default(1, hidden_sizes)
         return self

     def with_critic2_factory_use_actor(self) -> Self:
-        """Makes the second critic use the same network as the actor."""
+        """Makes the second critic reuse the actor's preprocessing network (parameter sharing)."""
         return self._with_critic_factory_use_actor(1)
@@ -559,6 +670,12 @@ class _BuilderMixinCriticEnsembleFactory:
         self.critic_ensemble_factory: CriticEnsembleFactory | None = None

     def with_critic_ensemble_factory(self, factory: CriticEnsembleFactory) -> Self:
+        """Specifies that the given factory shall be used for the critic ensemble.
+
+        If unspecified, the default factory (:class:`CriticEnsembleFactoryDefault`) is used.
+
+        :param factory: the critic ensemble factory
+        :return: the builder
+        """
         self.critic_ensemble_factory = factory
         return self
@@ -566,6 +683,11 @@ class _BuilderMixinCriticEnsembleFactory:
         self,
         hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
     ) -> Self:
+        """Allows customizing the parameters of the default critic ensemble factory.
+
+        :param hidden_sizes: the sequence of sizes of hidden layers in the network architecture
+        :return: the builder
+        """
        self.critic_ensemble_factory = CriticEnsembleFactoryDefault(hidden_sizes)
        return self

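To show how the builder methods documented above compose, here is a hypothetical end-to-end sketch (assumes a concrete builder such as PPOExperimentBuilder and an existing env_factory; all values are illustrative):

    from tianshou.highlevel.config import SamplingConfig
    from tianshou.highlevel.experiment import ExperimentConfig, PPOExperimentBuilder
    from tianshou.highlevel.logger import LoggerFactoryDefault
    from tianshou.highlevel.optim import OptimizerFactoryAdam

    experiment = (
        PPOExperimentBuilder(env_factory, ExperimentConfig(), SamplingConfig())
        .with_logger_factory(LoggerFactoryDefault(logger_type="tensorboard"))
        .with_optim_factory(OptimizerFactoryAdam())
        .with_actor_factory_default((64, 64))
        .with_critic_factory_default((64, 64))
        .build()
    )
    result = experiment.run(experiment_name="ppo-example")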
View File

@@ -27,7 +27,7 @@ class LoggerFactory(ToStringMixin, ABC):
         """

-class DefaultLoggerFactory(LoggerFactory):
+class LoggerFactoryDefault(LoggerFactory):
     def __init__(
         self,
         logger_type: Literal["tensorboard", "wandb"] = "tensorboard",

View File

@@ -9,12 +9,14 @@ from torch import nn
 from tianshou.highlevel.env import Environments, EnvType
 from tianshou.highlevel.module.core import (
-    IntermediateModule,
-    IntermediateModuleFactory,
     ModuleFactory,
     TDevice,
     init_linear_orthogonal,
 )
+from tianshou.highlevel.module.intermediate import (
+    IntermediateModule,
+    IntermediateModuleFactory,
+)
 from tianshou.highlevel.module.module_opt import ModuleOpt
 from tianshou.highlevel.optim import OptimizerFactory
 from tianshou.utils.net import continuous, discrete
@@ -157,6 +159,11 @@ class ActorFactoryContinuousGaussianNet(ActorFactoryContinuous):
         unbounded: bool = True,
         conditioned_sigma: bool = False,
     ):
+        """:param hidden_sizes: the sequence of hidden dimensions to use in the network structure
+        :param unbounded: whether to apply tanh activation on final logits
+        :param conditioned_sigma: if True, the standard deviation of continuous actions (sigma) is computed from
+            the input; if False, sigma is an independent parameter
+        """
         self.hidden_sizes = hidden_sizes
         self.unbounded = unbounded
         self.conditioned_sigma = conditioned_sigma

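A usage sketch for the parameters documented above (illustrative values, not part of the commit):

    from tianshou.highlevel.module.actor import ActorFactoryContinuousGaussianNet

    # Gaussian policy head with sigma as an independent parameter
    actor_factory = ActorFactoryContinuousGaussianNet(
        hidden_sizes=(128, 128),
        unbounded=True,
        conditioned_sigma=False,
    )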
View File

@@ -1,14 +1,10 @@
 from abc import ABC, abstractmethod
-from collections.abc import Sequence
-from dataclasses import dataclass
 from typing import TypeAlias

 import numpy as np
 import torch

 from tianshou.highlevel.env import Environments
-from tianshou.utils.net.discrete import ImplicitQuantileNetwork
-from tianshou.utils.string import ToStringMixin

 TDevice: TypeAlias = str | torch.device
@@ -25,44 +21,8 @@ def init_linear_orthogonal(module: torch.nn.Module) -> None:
 class ModuleFactory(ABC):
+    """Represents a factory for the creation of a torch module given an environment and target device."""
+
     @abstractmethod
     def create_module(self, envs: Environments, device: TDevice) -> torch.nn.Module:
         pass
-
-
-@dataclass
-class IntermediateModule:
-    module: torch.nn.Module
-    output_dim: int
-
-
-class IntermediateModuleFactory(ToStringMixin, ModuleFactory, ABC):
-    @abstractmethod
-    def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule:
-        pass
-
-    def create_module(self, envs: Environments, device: TDevice) -> torch.nn.Module:
-        return self.create_intermediate_module(envs, device).module
-
-
-class ImplicitQuantileNetworkFactory(ModuleFactory, ToStringMixin):
-    def __init__(
-        self,
-        preprocess_net_factory: IntermediateModuleFactory,
-        hidden_sizes: Sequence[int] = (),
-        num_cosines: int = 64,
-    ):
-        self.preprocess_net_factory = preprocess_net_factory
-        self.hidden_sizes = hidden_sizes
-        self.num_cosines = num_cosines
-
-    def create_module(self, envs: Environments, device: TDevice) -> ImplicitQuantileNetwork:
-        preprocess_net = self.preprocess_net_factory.create_intermediate_module(envs, device)
-        return ImplicitQuantileNetwork(
-            preprocess_net=preprocess_net.module,
-            action_shape=envs.get_action_shape(),
-            hidden_sizes=self.hidden_sizes,
-            num_cosines=self.num_cosines,
-            preprocess_net_output_dim=preprocess_net.output_dim,
-            device=device,
-        ).to(device)

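The ModuleFactory contract that remains in this module can be satisfied with very little code; a hypothetical sketch (assumes Environments exposes an observation-shape accessor such as get_observation_shape()):

    import numpy as np
    import torch

    from tianshou.highlevel.env import Environments
    from tianshou.highlevel.module.core import ModuleFactory, TDevice

    class FlattenedMLPFactory(ModuleFactory):
        """Hypothetical example: an MLP over the flattened observation."""

        def create_module(self, envs: Environments, device: TDevice) -> torch.nn.Module:
            obs_dim = int(np.prod(envs.get_observation_shape()))  # assumed accessor
            return torch.nn.Sequential(
                torch.nn.Linear(obs_dim, 64),
                torch.nn.ReLU(),
                torch.nn.Linear(64, 64),
            ).to(device)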
View File

@@ -15,6 +15,8 @@ from tianshou.utils.string import ToStringMixin
 class CriticFactory(ToStringMixin, ABC):
+    """Represents a factory for the generation of a critic module."""
+
     @abstractmethod
     def create_module(
         self,
@@ -23,9 +25,11 @@ class CriticFactory(ToStringMixin, ABC):
         use_action: bool,
         discrete_last_size_use_action_shape: bool = False,
     ) -> nn.Module:
-        """:param envs: the environments
+        """Creates the critic module.
+
+        :param envs: the environments
         :param device: the torch device
-        :param use_action: whether to (additionally) expect the action as input
+        :param use_action: whether to expect the action as an additional input (in addition to the observations)
         :param discrete_last_size_use_action_shape: whether, for the discrete case, the output dimension shall use the action shape
         :return: the module
         """
@@ -39,6 +43,16 @@ class CriticFactory(ToStringMixin, ABC):
         lr: float,
         discrete_last_size_use_action_shape: bool = False,
     ) -> ModuleOpt:
+        """Creates the critic module along with its optimizer for the given learning rate.
+
+        :param envs: the environments
+        :param device: the torch device
+        :param use_action: whether to expect the action as an additional input (in addition to the observations)
+        :param optim_factory: the optimizer factory
+        :param lr: the learning rate
+        :param discrete_last_size_use_action_shape: whether, for the discrete case, the output dimension shall use the action shape
+        :return: the module along with its optimizer
+        """
         module = self.create_module(
             envs,
             device,

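A usage sketch for the create_module_opt convenience documented above (assumes CriticFactoryDefault accepts the hidden sizes and that envs was created beforehand):

    from tianshou.highlevel.module.critic import CriticFactoryDefault
    from tianshou.highlevel.optim import OptimizerFactoryAdam

    critic_factory = CriticFactoryDefault(hidden_sizes=(64, 64))
    module_opt = critic_factory.create_module_opt(
        envs,
        device="cpu",
        use_action=True,  # critic receives observations and actions
        optim_factory=OptimizerFactoryAdam(),
        lr=1e-3,
    )
    critic, optim = module_opt.module, module_opt.optim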
View File

@@ -0,0 +1,27 @@
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+
+import torch
+
+from tianshou.highlevel.env import Environments
+from tianshou.highlevel.module.core import ModuleFactory, TDevice
+from tianshou.utils.string import ToStringMixin
+
+
+@dataclass
+class IntermediateModule:
+    """Container for a module which computes an intermediate representation (with a known dimension)."""
+
+    module: torch.nn.Module
+    output_dim: int
+
+
+class IntermediateModuleFactory(ToStringMixin, ModuleFactory, ABC):
+    """Factory for the generation of a module which computes an intermediate representation."""
+
+    @abstractmethod
+    def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule:
+        pass
+
+    def create_module(self, envs: Environments, device: TDevice) -> torch.nn.Module:
+        return self.create_intermediate_module(envs, device).module

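A hypothetical subclass sketch showing the contract of the new module (the encoder architecture and the observation-shape accessor are illustrative assumptions):

    import numpy as np
    import torch

    from tianshou.highlevel.env import Environments
    from tianshou.highlevel.module.core import TDevice
    from tianshou.highlevel.module.intermediate import (
        IntermediateModule,
        IntermediateModuleFactory,
    )

    class MyEncoderFactory(IntermediateModuleFactory):
        """Produces a fixed-width encoder as the intermediate representation."""

        def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule:
            output_dim = 128
            obs_dim = int(np.prod(envs.get_observation_shape()))  # assumed accessor
            encoder = torch.nn.Sequential(
                torch.nn.Linear(obs_dim, output_dim),
                torch.nn.ReLU(),
            )
            return IntermediateModule(encoder.to(device), output_dim)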
View File

@@ -7,12 +7,16 @@ from tianshou.utils.net.common import ActorCritic
 @dataclass
 class ModuleOpt:
+    """Container for a torch module along with its optimizer."""
+
     module: torch.nn.Module
     optim: torch.optim.Optimizer

 @dataclass
 class ActorCriticModuleOpt:
+    """Container for an :class:`ActorCritic` instance along with its optimizer."""
+
     actor_critic_module: ActorCritic
     optim: torch.optim.Optimizer

View File

@@ -0,0 +1,30 @@
+from collections.abc import Sequence
+
+from tianshou.highlevel.env import Environments
+from tianshou.highlevel.module.core import ModuleFactory, TDevice
+from tianshou.highlevel.module.intermediate import IntermediateModuleFactory
+from tianshou.utils.net.discrete import ImplicitQuantileNetwork
+from tianshou.utils.string import ToStringMixin
+
+
+class ImplicitQuantileNetworkFactory(ModuleFactory, ToStringMixin):
+    def __init__(
+        self,
+        preprocess_net_factory: IntermediateModuleFactory,
+        hidden_sizes: Sequence[int] = (),
+        num_cosines: int = 64,
+    ):
+        self.preprocess_net_factory = preprocess_net_factory
+        self.hidden_sizes = hidden_sizes
+        self.num_cosines = num_cosines
+
+    def create_module(self, envs: Environments, device: TDevice) -> ImplicitQuantileNetwork:
+        preprocess_net = self.preprocess_net_factory.create_intermediate_module(envs, device)
+        return ImplicitQuantileNetwork(
+            preprocess_net=preprocess_net.module,
+            action_shape=envs.get_action_shape(),
+            hidden_sizes=self.hidden_sizes,
+            num_cosines=self.num_cosines,
+            preprocess_net_output_dim=preprocess_net.output_dim,
+            device=device,
+        ).to(device)

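A wiring sketch for the relocated factory, reusing the hypothetical MyEncoderFactory from the intermediate-module sketch above (envs is assumed to exist):

    from tianshou.highlevel.module.special import ImplicitQuantileNetworkFactory

    iqn_factory = ImplicitQuantileNetworkFactory(
        preprocess_net_factory=MyEncoderFactory(),
        hidden_sizes=(512,),
        num_cosines=64,
    )
    module = iqn_factory.create_module(envs, device="cpu")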
View File

@@ -13,10 +13,6 @@ class OptimizerWithLearningRateProtocol(Protocol):
 class OptimizerFactory(ABC, ToStringMixin):
-    # TODO: Is it OK to assume that all optimizers have a learning rate argument?
-    #   Right now, the learning rate is typically a configuration parameter.
-    #   If we drop the assumption, we can't have that and will need to move the parameter
-    #   to the optimizer factory, which is inconvenient for the user.
     @abstractmethod
     def create_optimizer(self, module: torch.nn.Module, lr: float) -> torch.optim.Optimizer:
         pass

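With the TODO removed, the contract stays minimal: implementations receive the module and a learning rate. A hypothetical factory for a different optimizer:

    import torch

    from tianshou.highlevel.optim import OptimizerFactory

    class OptimizerFactoryRMSprop(OptimizerFactory):
        """Illustrative example, not part of the commit."""

        def create_optimizer(self, module: torch.nn.Module, lr: float) -> torch.optim.Optimizer:
            return torch.optim.RMSprop(module.parameters(), lr=lr)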
View File

@@ -9,6 +9,8 @@ from tianshou.utils.string import ToStringMixin
 class LRSchedulerFactory(ToStringMixin, ABC):
+    """Factory for the creation of a learning rate scheduler."""
+
     @abstractmethod
     def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
         pass

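A hypothetical factory sketch implementing the documented interface with a standard torch scheduler (the module path of LRSchedulerFactory is assumed):

    import torch
    from torch.optim.lr_scheduler import LRScheduler, StepLR

    from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory

    class LRSchedulerFactoryStepwise(LRSchedulerFactory):
        """Illustrative example: decay the learning rate every step_size steps."""

        def __init__(self, step_size: int, gamma: float = 0.1):
            self.step_size = step_size
            self.gamma = gamma

        def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
            return StepLR(optim, step_size=self.step_size, gamma=self.gamma)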
View File

@@ -18,9 +18,12 @@ class NoiseFactoryMaxActionScaledGaussian(NoiseFactory):
     """

     def __init__(self, std_fraction: float):
+        """:param std_fraction: fraction (between 0 and 1) of the maximum action value that shall
+        be used as the standard deviation
+        """
         self.std_fraction = std_fraction

-    def create_noise(self, envs: Environments) -> BaseNoise:
+    def create_noise(self, envs: Environments) -> GaussianNoise:
         envs.get_type().assert_continuous(self)
         envs: ContinuousEnvironments
         return GaussianNoise(sigma=envs.max_action * self.std_fraction)

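A usage sketch (the factory's module path is assumed; envs must represent continuous environments, which the factory itself asserts):

    from tianshou.highlevel.params.noise import NoiseFactoryMaxActionScaledGaussian

    # Gaussian exploration noise with sigma = 10% of the environment's max action
    noise_factory = NoiseFactoryMaxActionScaledGaussian(std_fraction=0.1)
    noise = noise_factory.create_noise(envs)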
View File

@@ -241,7 +241,9 @@ class Params(GetParamTransformersProtocol):
 @dataclass
 class ParamsMixinLearningRateWithScheduler(GetParamTransformersProtocol):
     lr: float = 1e-3
+    """the learning rate to use in the gradient-based optimizer"""
     lr_scheduler_factory: LRSchedulerFactory | None = None
+    """factory for the creation of a learning rate scheduler"""

     def _get_param_transformers(self) -> list[ParamTransformer]:
         return [

View File

@@ -3,7 +3,8 @@ from collections.abc import Sequence
 from typing import Generic, TypeVar

 from tianshou.highlevel.env import Environments
-from tianshou.highlevel.module.core import IntermediateModuleFactory, TDevice
+from tianshou.highlevel.module.core import TDevice
+from tianshou.highlevel.module.intermediate import IntermediateModuleFactory
 from tianshou.highlevel.optim import OptimizerFactory
 from tianshou.policy import BasePolicy, ICMPolicy
 from tianshou.utils.net.discrete import IntrinsicCuriosityModule

View File

@@ -50,6 +50,8 @@ class Persistence(ABC):
 class PersistenceGroup(Persistence):
+    """Groups persistence handlers such that they can be applied collectively."""
+
     def __init__(self, *p: Persistence, enabled: bool = True):
         self.items = p
         self.enabled = enabled
@@ -69,7 +71,7 @@ class PolicyPersistence:
     FILENAME = "policy.dat"

     def __init__(self, additional_persistence: Persistence | None = None, enabled: bool = True):
-        """:param additional_persistence: a persistence instance which is to be envoked whenever
+        """:param additional_persistence: a persistence instance which is to be invoked whenever
             this object is used to persist/restore data
         :param enabled: whether persistence is enabled (restoration is always enabled)
         """

View File

@@ -52,6 +52,7 @@ class TrainerStopCallback(ToStringMixin, ABC):
     @abstractmethod
     def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool:
         """:param mean_rewards: the average undiscounted returns of the testing result
+        :param context: the training context
         :return: True if the goal has been reached and training should stop, False otherwise
         """
@@ -64,6 +65,8 @@ class TrainerStopCallback(ToStringMixin, ABC):
 @dataclass
 class TrainerCallbacks:
+    """Container for callbacks used during training."""
+
     epoch_callback_train: TrainerEpochCallbackTrain | None = None
     epoch_callback_test: TrainerEpochCallbackTest | None = None
     stop_callback: TrainerStopCallback | None = None

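A hypothetical stop callback implementing the documented signature (the import path is assumed):

    from tianshou.highlevel.trainer import TrainerStopCallback, TrainingContext

    class StopOnRewardThreshold(TrainerStopCallback):
        """Illustrative example: stop once the mean test return reaches a threshold."""

        def __init__(self, threshold: float):
            self.threshold = threshold

        def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool:
            return mean_rewards >= self.threshold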
View File

@@ -12,6 +12,8 @@ if TYPE_CHECKING:
 @dataclass
 class World:
+    """Container for instances and configuration items that are relevant to an experiment."""
+
     envs: "Environments"
     policy: "BasePolicy"
     train_collector: "Collector"