Add documentation, improve structure of 'module' package
This commit is contained in:
parent
97e21b5ddf
commit
4b270eaa2d
@ -8,9 +8,11 @@ from torch import nn
|
||||
from tianshou.highlevel.env import Environments
|
||||
from tianshou.highlevel.module.actor import ActorFactory
|
||||
from tianshou.highlevel.module.core import (
|
||||
TDevice,
|
||||
)
|
||||
from tianshou.highlevel.module.intermediate import (
|
||||
IntermediateModule,
|
||||
IntermediateModuleFactory,
|
||||
TDevice,
|
||||
)
|
||||
from tianshou.utils.net.discrete import Actor, NoisyLinear
|
||||
|
||||
|
@ -81,6 +81,8 @@ log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentFactory(ABC, ToStringMixin):
|
||||
"""Factory for the creation of an agent's policy, its trainer as well as collectors."""
|
||||
|
||||
def __init__(self, sampling_config: SamplingConfig, optim_factory: OptimizerFactory):
|
||||
self.sampling_config = sampling_config
|
||||
self.optim_factory = optim_factory
|
||||
|
@ -14,6 +14,8 @@ TObservationShape: TypeAlias = int | Sequence[int]
|
||||
|
||||
|
||||
class EnvType(Enum):
|
||||
"""Enumeration of environment types."""
|
||||
|
||||
CONTINUOUS = "continuous"
|
||||
DISCRETE = "discrete"
|
||||
|
||||
@ -33,6 +35,8 @@ class EnvType(Enum):
|
||||
|
||||
|
||||
class Environments(ToStringMixin, ABC):
|
||||
"""Represents (vectorized) environments."""
|
||||
|
||||
def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
|
||||
self.env = env
|
||||
self.train_envs = train_envs
|
||||
@ -52,6 +56,11 @@ class Environments(ToStringMixin, ABC):
|
||||
}
|
||||
|
||||
def set_persistence(self, *p: Persistence) -> None:
|
||||
"""Associates the given persistence handlers which may persist and restore
|
||||
environment-specific information.
|
||||
|
||||
:param p: persistence handlers
|
||||
"""
|
||||
self.persistence = p
|
||||
|
||||
@abstractmethod
|
||||
@ -74,6 +83,8 @@ class Environments(ToStringMixin, ABC):
|
||||
|
||||
|
||||
class ContinuousEnvironments(Environments):
|
||||
"""Represents (vectorized) continuous environments."""
|
||||
|
||||
def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
|
||||
super().__init__(env, train_envs, test_envs)
|
||||
self.state_shape, self.action_shape, self.max_action = self._get_continuous_env_info(env)
|
||||
@ -110,6 +121,8 @@ class ContinuousEnvironments(Environments):
|
||||
|
||||
|
||||
class DiscreteEnvironments(Environments):
|
||||
"""Represents (vectorized) discrete environments."""
|
||||
|
||||
def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
|
||||
super().__init__(env, train_envs, test_envs)
|
||||
self.observation_shape = env.observation_space.shape or env.observation_space.n # type: ignore
|
||||
|
@ -27,7 +27,7 @@ from tianshou.highlevel.agent import (
|
||||
)
|
||||
from tianshou.highlevel.config import SamplingConfig
|
||||
from tianshou.highlevel.env import EnvFactory, Environments
|
||||
from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory, TLogger
|
||||
from tianshou.highlevel.logger import LoggerFactory, LoggerFactoryDefault, TLogger
|
||||
from tianshou.highlevel.module.actor import (
|
||||
ActorFactory,
|
||||
ActorFactoryDefault,
|
||||
@ -38,8 +38,6 @@ from tianshou.highlevel.module.actor import (
|
||||
IntermediateModuleFactoryFromActorFactory,
|
||||
)
|
||||
from tianshou.highlevel.module.core import (
|
||||
ImplicitQuantileNetworkFactory,
|
||||
IntermediateModuleFactory,
|
||||
TDevice,
|
||||
)
|
||||
from tianshou.highlevel.module.critic import (
|
||||
@ -49,6 +47,8 @@ from tianshou.highlevel.module.critic import (
|
||||
CriticFactoryDefault,
|
||||
CriticFactoryReuseActor,
|
||||
)
|
||||
from tianshou.highlevel.module.intermediate import IntermediateModuleFactory
|
||||
from tianshou.highlevel.module.special import ImplicitQuantileNetworkFactory
|
||||
from tianshou.highlevel.optim import OptimizerFactory, OptimizerFactoryAdam
|
||||
from tianshou.highlevel.params.policy_params import (
|
||||
A2CParams,
|
||||
@ -116,8 +116,12 @@ class ExperimentConfig:
|
||||
|
||||
@dataclass
|
||||
class ExperimentResult:
|
||||
"""Contains the results of an experiment."""
|
||||
|
||||
world: World
|
||||
"""contains all the essential instances of the experiment"""
|
||||
trainer_result: dict[str, Any] | None
|
||||
"""dictionary of results as returned by the trainer (if any)"""
|
||||
|
||||
|
||||
class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
|
||||
@ -140,7 +144,7 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
|
||||
env_config: PersistableConfigProtocol | None = None,
|
||||
):
|
||||
if logger_factory is None:
|
||||
logger_factory = DefaultLoggerFactory()
|
||||
logger_factory = LoggerFactoryDefault()
|
||||
self.config = config
|
||||
self.env_factory = env_factory
|
||||
self.agent_factory = agent_factory
|
||||
@ -179,7 +183,9 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
|
||||
pickle.dump(self, f)
|
||||
|
||||
def run(
|
||||
self, experiment_name: str | None = None, logger_run_id: str | None = None,
|
||||
self,
|
||||
experiment_name: str | None = None,
|
||||
logger_run_id: str | None = None,
|
||||
) -> ExperimentResult:
|
||||
""":param experiment_name: the experiment name, which corresponds to the directory (within the logging
|
||||
directory) where all results associated with the experiment will be saved.
|
||||
@ -317,14 +323,31 @@ class ExperimentBuilder:
|
||||
return self
|
||||
|
||||
def with_logger_factory(self, logger_factory: LoggerFactory) -> Self:
|
||||
"""Allows to customize the logger factory to use.
|
||||
If this method is not called, the default logger factory :class:`LoggerFactoryDefault` will be used.
|
||||
|
||||
:param logger_factory: the factory to use
|
||||
:return: the builder
|
||||
"""
|
||||
self._logger_factory = logger_factory
|
||||
return self
|
||||
|
||||
def with_policy_wrapper_factory(self, policy_wrapper_factory: PolicyWrapperFactory) -> Self:
|
||||
"""Allows to define a wrapper around the policy that is created, extending the original policy.
|
||||
|
||||
:param policy_wrapper_factory: the factory for the wrapper
|
||||
:return: the builder
|
||||
"""
|
||||
self._policy_wrapper_factory = policy_wrapper_factory
|
||||
return self
|
||||
|
||||
def with_optim_factory(self, optim_factory: OptimizerFactory) -> Self:
|
||||
"""Allows to customize the gradient-based optimizer to use.
|
||||
By default, :class:`OptimizerFactoryAdam` will be used with default parameters.
|
||||
|
||||
:param optim_factory: the optimizer factory
|
||||
:return: the builder
|
||||
"""
|
||||
self._optim_factory = optim_factory
|
||||
return self
|
||||
|
||||
@ -345,14 +368,30 @@ class ExperimentBuilder:
|
||||
return self
|
||||
|
||||
def with_trainer_epoch_callback_train(self, callback: TrainerEpochCallbackTrain) -> Self:
|
||||
"""Allows to define a callback function which is called at the beginning of every epoch during training.
|
||||
|
||||
:param callback: the callback
|
||||
:return: the builder
|
||||
"""
|
||||
self._trainer_callbacks.epoch_callback_train = callback
|
||||
return self
|
||||
|
||||
def with_trainer_epoch_callback_test(self, callback: TrainerEpochCallbackTest) -> Self:
|
||||
"""Allows to define a callback function which is called at the beginning of testing in each epoch.
|
||||
|
||||
:param callback: the callback
|
||||
:return: the builder
|
||||
"""
|
||||
self._trainer_callbacks.epoch_callback_test = callback
|
||||
return self
|
||||
|
||||
def with_trainer_stop_callback(self, callback: TrainerStopCallback) -> Self:
|
||||
"""Allows to define a callback that decides whether training shall stop early.
|
||||
The callback receives the undiscounted returns of the testing result.
|
||||
|
||||
:param callback: the callback
|
||||
:return: the builder
|
||||
"""
|
||||
self._trainer_callbacks.stop_callback = callback
|
||||
return self
|
||||
|
||||
@ -367,6 +406,10 @@ class ExperimentBuilder:
|
||||
return self._optim_factory
|
||||
|
||||
def build(self) -> Experiment:
|
||||
"""Creates the experiment based on the options specified via this builder.
|
||||
|
||||
:return: the experiment
|
||||
"""
|
||||
agent_factory = self._create_agent_factory()
|
||||
agent_factory.set_trainer_callbacks(self._trainer_callbacks)
|
||||
if self._policy_wrapper_factory:
|
||||
@ -388,6 +431,12 @@ class _BuilderMixinActorFactory(ActorFutureProviderProtocol):
|
||||
self._actor_factory: ActorFactory | None = None
|
||||
|
||||
def with_actor_factory(self, actor_factory: ActorFactory) -> Self:
|
||||
"""Allows to customize the actor component via the specification of a factory.
|
||||
If this function is not called, a default actor factory (with default parameters) will be used.
|
||||
|
||||
:param actor_factory: the factory to use for the creation of the actor network
|
||||
:return: the builder
|
||||
"""
|
||||
self._actor_factory = actor_factory
|
||||
return self
|
||||
|
||||
@ -397,6 +446,12 @@ class _BuilderMixinActorFactory(ActorFutureProviderProtocol):
|
||||
continuous_unbounded: bool = False,
|
||||
continuous_conditioned_sigma: bool = False,
|
||||
) -> Self:
|
||||
""":param hidden_sizes: the sequence of hidden dimensions to use in the network structure
|
||||
:param continuous_unbounded: whether, for continuous action spaces, to apply tanh activation on final logits
|
||||
:param continuous_conditioned_sigma: whether, for continuous action spaces, the standard deviation of continuous actions (sigma)
|
||||
shall be computed from the input; if False, sigma is an independent parameter.
|
||||
:return: the builder
|
||||
"""
|
||||
self._actor_factory = ActorFactoryDefault(
|
||||
self._continuous_actor_type,
|
||||
hidden_sizes,
|
||||
@ -406,6 +461,7 @@ class _BuilderMixinActorFactory(ActorFutureProviderProtocol):
|
||||
return self
|
||||
|
||||
def get_actor_future(self) -> ActorFuture:
|
||||
""":return: an object, which, in the future, will contain the actor instance that is created for the experiment."""
|
||||
return self._actor_future
|
||||
|
||||
def _get_actor_factory(self) -> ActorFactory:
|
||||
@ -431,6 +487,15 @@ class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
|
||||
continuous_unbounded: bool = False,
|
||||
continuous_conditioned_sigma: bool = False,
|
||||
) -> Self:
|
||||
"""Defines use of the default actor factory, allowing its parameters to be customized.
|
||||
The default actor factory uses an MLP-style architecture.
|
||||
|
||||
:param hidden_sizes: dimensions of hidden layers used by the network
|
||||
:param continuous_unbounded: whether, for continuous action spaces, to apply tanh activation on final logits
|
||||
:param continuous_conditioned_sigma: whether, for continuous action spaces, the standard deviation of continuous actions (sigma)
|
||||
shall be computed from the input; if False, sigma is an independent parameter.
|
||||
:return: the builder
|
||||
"""
|
||||
return super()._with_actor_factory_default(
|
||||
hidden_sizes,
|
||||
continuous_unbounded=continuous_unbounded,
|
||||
@ -445,6 +510,12 @@ class _BuilderMixinActorFactory_ContinuousDeterministic(_BuilderMixinActorFactor
|
||||
super().__init__(ContinuousActorType.DETERMINISTIC)
|
||||
|
||||
def with_actor_factory_default(self, hidden_sizes: Sequence[int]) -> Self:
|
||||
"""Defines use of the default actor factory, allowing its parameters to be customized.
|
||||
The default actor factory uses an MLP-style architecture.
|
||||
|
||||
:param hidden_sizes: dimensions of hidden layers used by the network
|
||||
:return: the builder
|
||||
"""
|
||||
return super()._with_actor_factory_default(hidden_sizes)
|
||||
|
||||
|
||||
@ -480,6 +551,11 @@ class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
|
||||
super().__init__(1, actor_future_provider)
|
||||
|
||||
def with_critic_factory(self, critic_factory: CriticFactory) -> Self:
|
||||
"""Specifies that the given factory shall be used for the critic.
|
||||
|
||||
:param critic_factory: the critic factory
|
||||
:return: the builder
|
||||
"""
|
||||
self._with_critic_factory(0, critic_factory)
|
||||
return self
|
||||
|
||||
@ -487,6 +563,11 @@ class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
|
||||
self,
|
||||
hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
|
||||
) -> Self:
|
||||
"""Makes the critic use the default, MLP-style architecture with the given parameters.
|
||||
|
||||
:param hidden_sizes: the sequence of dimensions to use in hidden layers of the network
|
||||
:return: the builder
|
||||
"""
|
||||
self._with_critic_factory_default(0, hidden_sizes)
|
||||
return self
|
||||
|
||||
@ -496,7 +577,7 @@ class _BuilderMixinSingleCriticCanUseActorFactory(_BuilderMixinSingleCriticFacto
|
||||
super().__init__(actor_future_provider)
|
||||
|
||||
def with_critic_factory_use_actor(self) -> Self:
|
||||
"""Makes the critic use the same network as the actor."""
|
||||
"""Makes the first critic reuse the actor's preprocessing network (parameter sharing)."""
|
||||
return self._with_critic_factory_use_actor(0)
|
||||
|
||||
|
||||
@ -505,6 +586,11 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
|
||||
super().__init__(2, actor_future_provider)
|
||||
|
||||
def with_common_critic_factory(self, critic_factory: CriticFactory) -> Self:
|
||||
"""Specifies that the given factory shall be used for both critics.
|
||||
|
||||
:param critic_factory: the critic factory
|
||||
:return: the builder
|
||||
"""
|
||||
for i in range(len(self._critic_factories)):
|
||||
self._with_critic_factory(i, critic_factory)
|
||||
return self
|
||||
@ -513,17 +599,27 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
|
||||
self,
|
||||
hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
|
||||
) -> Self:
|
||||
"""Makes both critics use the default, MLP-style architecture with the given parameters.
|
||||
|
||||
:param hidden_sizes: the sequence of dimensions to use in hidden layers of the network
|
||||
:return: the builder
|
||||
"""
|
||||
for i in range(len(self._critic_factories)):
|
||||
self._with_critic_factory_default(i, hidden_sizes)
|
||||
return self
|
||||
|
||||
def with_common_critic_factory_use_actor(self) -> Self:
|
||||
"""Makes all critics use the same network as the actor."""
|
||||
"""Makes both critics reuse the actor's preprocessing network (parameter sharing)."""
|
||||
for i in range(len(self._critic_factories)):
|
||||
self._with_critic_factory_use_actor(i)
|
||||
return self
|
||||
|
||||
def with_critic1_factory(self, critic_factory: CriticFactory) -> Self:
|
||||
"""Specifies that the given factory shall be used for the first critic.
|
||||
|
||||
:param critic_factory: the critic factory
|
||||
:return: the builder
|
||||
"""
|
||||
self._with_critic_factory(0, critic_factory)
|
||||
return self
|
||||
|
||||
@ -531,14 +627,24 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
|
||||
self,
|
||||
hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
|
||||
) -> Self:
|
||||
"""Makes the first critic use the default, MLP-style architecture with the given parameters.
|
||||
|
||||
:param hidden_sizes: the sequence of dimensions to use in hidden layers of the network
|
||||
:return: the builder
|
||||
"""
|
||||
self._with_critic_factory_default(0, hidden_sizes)
|
||||
return self
|
||||
|
||||
def with_critic1_factory_use_actor(self) -> Self:
|
||||
"""Makes the critic use the same network as the actor."""
|
||||
"""Makes the first critic reuse the actor's preprocessing network (parameter sharing)."""
|
||||
return self._with_critic_factory_use_actor(0)
|
||||
|
||||
def with_critic2_factory(self, critic_factory: CriticFactory) -> Self:
|
||||
"""Specifies that the given factory shall be used for the second critic.
|
||||
|
||||
:param critic_factory: the critic factory
|
||||
:return: the builder
|
||||
"""
|
||||
self._with_critic_factory(1, critic_factory)
|
||||
return self
|
||||
|
||||
@ -546,11 +652,16 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
|
||||
self,
|
||||
hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
|
||||
) -> Self:
|
||||
"""Makes the second critic use the default, MLP-style architecture with the given parameters.
|
||||
|
||||
:param hidden_sizes: the sequence of dimensions to use in hidden layers of the network
|
||||
:return: the builder
|
||||
"""
|
||||
# Index 1 addresses the second critic (index 0 is the first critic).
self._with_critic_factory_default(1, hidden_sizes)
|
||||
return self
|
||||
|
||||
def with_critic2_factory_use_actor(self) -> Self:
|
||||
"""Makes the second critic use the same network as the actor."""
|
||||
"""Makes the second critic reuse the actor's preprocessing network (parameter sharing)."""
|
||||
return self._with_critic_factory_use_actor(1)
|
||||
|
||||
|
||||
@ -559,6 +670,12 @@ class _BuilderMixinCriticEnsembleFactory:
|
||||
self.critic_ensemble_factory: CriticEnsembleFactory | None = None
|
||||
|
||||
def with_critic_ensemble_factory(self, factory: CriticEnsembleFactory) -> Self:
|
||||
"""Specifies that the given factory shall be used for the critic ensemble.
|
||||
If unspecified, the default factory (:class:`CriticEnsembleFactoryDefault`) is used.
|
||||
|
||||
:param factory: the critic ensemble factory
|
||||
:return: the builder
|
||||
"""
|
||||
self.critic_ensemble_factory = factory
|
||||
return self
|
||||
|
||||
@ -566,6 +683,11 @@ class _BuilderMixinCriticEnsembleFactory:
|
||||
self,
|
||||
hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
|
||||
) -> Self:
|
||||
"""Allows to customize the parameters of the default critic ensemble factory.
|
||||
|
||||
:param hidden_sizes: the sequence of sizes of hidden layers in the network architecture
|
||||
:return: the builder
|
||||
"""
|
||||
self.critic_ensemble_factory = CriticEnsembleFactoryDefault(hidden_sizes)
|
||||
return self
|
||||
|
||||
|
@ -27,7 +27,7 @@ class LoggerFactory(ToStringMixin, ABC):
|
||||
"""
|
||||
|
||||
|
||||
class DefaultLoggerFactory(LoggerFactory):
|
||||
class LoggerFactoryDefault(LoggerFactory):
|
||||
def __init__(
|
||||
self,
|
||||
logger_type: Literal["tensorboard", "wandb"] = "tensorboard",
|
||||
|
@ -9,12 +9,14 @@ from torch import nn
|
||||
|
||||
from tianshou.highlevel.env import Environments, EnvType
|
||||
from tianshou.highlevel.module.core import (
|
||||
IntermediateModule,
|
||||
IntermediateModuleFactory,
|
||||
ModuleFactory,
|
||||
TDevice,
|
||||
init_linear_orthogonal,
|
||||
)
|
||||
from tianshou.highlevel.module.intermediate import (
|
||||
IntermediateModule,
|
||||
IntermediateModuleFactory,
|
||||
)
|
||||
from tianshou.highlevel.module.module_opt import ModuleOpt
|
||||
from tianshou.highlevel.optim import OptimizerFactory
|
||||
from tianshou.utils.net import continuous, discrete
|
||||
@ -157,6 +159,11 @@ class ActorFactoryContinuousGaussianNet(ActorFactoryContinuous):
|
||||
unbounded: bool = True,
|
||||
conditioned_sigma: bool = False,
|
||||
):
|
||||
""":param hidden_sizes: the sequence of hidden dimensions to use in the network structure
|
||||
:param unbounded: whether to apply tanh activation on final logits
|
||||
:param conditioned_sigma: if True, the standard deviation of continuous actions (sigma) is computed from the
|
||||
input; if False, sigma is an independent parameter
|
||||
"""
|
||||
self.hidden_sizes = hidden_sizes
|
||||
self.unbounded = unbounded
|
||||
self.conditioned_sigma = conditioned_sigma
|
||||
|
@ -1,14 +1,10 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import TypeAlias
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from tianshou.highlevel.env import Environments
|
||||
from tianshou.utils.net.discrete import ImplicitQuantileNetwork
|
||||
from tianshou.utils.string import ToStringMixin
|
||||
|
||||
TDevice: TypeAlias = str | torch.device
|
||||
|
||||
@ -25,44 +21,8 @@ def init_linear_orthogonal(module: torch.nn.Module) -> None:
|
||||
|
||||
|
||||
class ModuleFactory(ABC):
|
||||
"""Represents a factory for the creation of a torch module given an environment and target device."""
|
||||
|
||||
@abstractmethod
|
||||
def create_module(self, envs: Environments, device: TDevice) -> torch.nn.Module:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class IntermediateModule:
|
||||
module: torch.nn.Module
|
||||
output_dim: int
|
||||
|
||||
|
||||
class IntermediateModuleFactory(ToStringMixin, ModuleFactory, ABC):
|
||||
@abstractmethod
|
||||
def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule:
|
||||
pass
|
||||
|
||||
def create_module(self, envs: Environments, device: TDevice) -> torch.nn.Module:
|
||||
return self.create_intermediate_module(envs, device).module
|
||||
|
||||
|
||||
class ImplicitQuantileNetworkFactory(ModuleFactory, ToStringMixin):
|
||||
def __init__(
|
||||
self,
|
||||
preprocess_net_factory: IntermediateModuleFactory,
|
||||
hidden_sizes: Sequence[int] = (),
|
||||
num_cosines: int = 64,
|
||||
):
|
||||
self.preprocess_net_factory = preprocess_net_factory
|
||||
self.hidden_sizes = hidden_sizes
|
||||
self.num_cosines = num_cosines
|
||||
|
||||
def create_module(self, envs: Environments, device: TDevice) -> ImplicitQuantileNetwork:
|
||||
preprocess_net = self.preprocess_net_factory.create_intermediate_module(envs, device)
|
||||
return ImplicitQuantileNetwork(
|
||||
preprocess_net=preprocess_net.module,
|
||||
action_shape=envs.get_action_shape(),
|
||||
hidden_sizes=self.hidden_sizes,
|
||||
num_cosines=self.num_cosines,
|
||||
preprocess_net_output_dim=preprocess_net.output_dim,
|
||||
device=device,
|
||||
).to(device)
|
||||
|
@ -15,6 +15,8 @@ from tianshou.utils.string import ToStringMixin
|
||||
|
||||
|
||||
class CriticFactory(ToStringMixin, ABC):
|
||||
"""Represents a factory for the generation of a critic module."""
|
||||
|
||||
@abstractmethod
|
||||
def create_module(
|
||||
self,
|
||||
@ -23,9 +25,11 @@ class CriticFactory(ToStringMixin, ABC):
|
||||
use_action: bool,
|
||||
discrete_last_size_use_action_shape: bool = False,
|
||||
) -> nn.Module:
|
||||
""":param envs: the environments
|
||||
"""Creates the critic module.
|
||||
|
||||
:param envs: the environments
|
||||
:param device: the torch device
|
||||
:param use_action: whether to (additionally) expect the action as input
|
||||
:param use_action: whether to expect the action as an additional input (in addition to the observations)
|
||||
:param discrete_last_size_use_action_shape: whether, for the discrete case, the output dimension shall use the action shape
|
||||
:return: the module
|
||||
"""
|
||||
@ -39,6 +43,16 @@ class CriticFactory(ToStringMixin, ABC):
|
||||
lr: float,
|
||||
discrete_last_size_use_action_shape: bool = False,
|
||||
) -> ModuleOpt:
|
||||
"""Creates the critic module along with its optimizer for the given learning rate.
|
||||
|
||||
:param envs: the environments
|
||||
:param device: the torch device
|
||||
:param use_action: whether to expect the action as an additional input (in addition to the observations)
|
||||
:param optim_factory: the optimizer factory
|
||||
:param lr: the learning rate
|
||||
:param discrete_last_size_use_action_shape: whether, for the discrete case, the output dimension shall use the action shape
|
||||
:return: a container holding the critic module along with its optimizer
|
||||
"""
|
||||
module = self.create_module(
|
||||
envs,
|
||||
device,
|
||||
|
27
tianshou/highlevel/module/intermediate.py
Normal file
27
tianshou/highlevel/module/intermediate.py
Normal file
@ -0,0 +1,27 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from tianshou.highlevel.env import Environments
|
||||
from tianshou.highlevel.module.core import ModuleFactory, TDevice
|
||||
from tianshou.utils.string import ToStringMixin
|
||||
|
||||
|
||||
@dataclass
|
||||
class IntermediateModule:
|
||||
"""Container for a module which computes an intermediate representation (with a known dimension)."""
|
||||
|
||||
module: torch.nn.Module
|
||||
output_dim: int
|
||||
|
||||
|
||||
class IntermediateModuleFactory(ToStringMixin, ModuleFactory, ABC):
|
||||
"""Factory for the generation of a module which computes an intermediate representation."""
|
||||
|
||||
@abstractmethod
|
||||
def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule:
|
||||
pass
|
||||
|
||||
def create_module(self, envs: Environments, device: TDevice) -> torch.nn.Module:
|
||||
return self.create_intermediate_module(envs, device).module
|
@ -7,12 +7,16 @@ from tianshou.utils.net.common import ActorCritic
|
||||
|
||||
@dataclass
|
||||
class ModuleOpt:
|
||||
"""Container for a torch module along with its optimizer."""
|
||||
|
||||
module: torch.nn.Module
|
||||
optim: torch.optim.Optimizer
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActorCriticModuleOpt:
|
||||
"""Container for an :class:`ActorCritic` instance along with its optimizer."""
|
||||
|
||||
actor_critic_module: ActorCritic
|
||||
optim: torch.optim.Optimizer
|
||||
|
||||
|
30
tianshou/highlevel/module/special.py
Normal file
30
tianshou/highlevel/module/special.py
Normal file
@ -0,0 +1,30 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
from tianshou.highlevel.env import Environments
|
||||
from tianshou.highlevel.module.core import ModuleFactory, TDevice
|
||||
from tianshou.highlevel.module.intermediate import IntermediateModuleFactory
|
||||
from tianshou.utils.net.discrete import ImplicitQuantileNetwork
|
||||
from tianshou.utils.string import ToStringMixin
|
||||
|
||||
|
||||
class ImplicitQuantileNetworkFactory(ModuleFactory, ToStringMixin):
|
||||
def __init__(
|
||||
self,
|
||||
preprocess_net_factory: IntermediateModuleFactory,
|
||||
hidden_sizes: Sequence[int] = (),
|
||||
num_cosines: int = 64,
|
||||
):
|
||||
self.preprocess_net_factory = preprocess_net_factory
|
||||
self.hidden_sizes = hidden_sizes
|
||||
self.num_cosines = num_cosines
|
||||
|
||||
def create_module(self, envs: Environments, device: TDevice) -> ImplicitQuantileNetwork:
|
||||
preprocess_net = self.preprocess_net_factory.create_intermediate_module(envs, device)
|
||||
return ImplicitQuantileNetwork(
|
||||
preprocess_net=preprocess_net.module,
|
||||
action_shape=envs.get_action_shape(),
|
||||
hidden_sizes=self.hidden_sizes,
|
||||
num_cosines=self.num_cosines,
|
||||
preprocess_net_output_dim=preprocess_net.output_dim,
|
||||
device=device,
|
||||
).to(device)
|
@ -13,10 +13,6 @@ class OptimizerWithLearningRateProtocol(Protocol):
|
||||
|
||||
|
||||
class OptimizerFactory(ABC, ToStringMixin):
|
||||
# TODO: Is it OK to assume that all optimizers have a learning rate argument?
|
||||
# Right now, the learning rate is typically a configuration parameter.
|
||||
# If we drop the assumption, we can't have that and will need to move the parameter
|
||||
# to the optimizer factory, which is inconvenient for the user.
|
||||
@abstractmethod
|
||||
def create_optimizer(self, module: torch.nn.Module, lr: float) -> torch.optim.Optimizer:
|
||||
pass
|
||||
|
@ -9,6 +9,8 @@ from tianshou.utils.string import ToStringMixin
|
||||
|
||||
|
||||
class LRSchedulerFactory(ToStringMixin, ABC):
|
||||
"""Factory for the creation of a learning rate scheduler."""
|
||||
|
||||
@abstractmethod
|
||||
def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
|
||||
pass
|
||||
|
@ -18,9 +18,12 @@ class NoiseFactoryMaxActionScaledGaussian(NoiseFactory):
|
||||
"""
|
||||
|
||||
def __init__(self, std_fraction: float):
|
||||
""":param std_fraction: fraction (between 0 and 1) of the maximum action value that shall
|
||||
be used as the standard deviation
|
||||
"""
|
||||
self.std_fraction = std_fraction
|
||||
|
||||
def create_noise(self, envs: Environments) -> BaseNoise:
|
||||
def create_noise(self, envs: Environments) -> GaussianNoise:
|
||||
envs.get_type().assert_continuous(self)
|
||||
envs: ContinuousEnvironments
|
||||
return GaussianNoise(sigma=envs.max_action * self.std_fraction)
|
||||
|
@ -241,7 +241,9 @@ class Params(GetParamTransformersProtocol):
|
||||
@dataclass
|
||||
class ParamsMixinLearningRateWithScheduler(GetParamTransformersProtocol):
|
||||
lr: float = 1e-3
|
||||
"""the learning rate to use in the gradient-based optimizer"""
|
||||
lr_scheduler_factory: LRSchedulerFactory | None = None
|
||||
"""factory for the creation of a learning rate scheduler"""
|
||||
|
||||
def _get_param_transformers(self) -> list[ParamTransformer]:
|
||||
return [
|
||||
|
@ -3,7 +3,8 @@ from collections.abc import Sequence
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from tianshou.highlevel.env import Environments
|
||||
from tianshou.highlevel.module.core import IntermediateModuleFactory, TDevice
|
||||
from tianshou.highlevel.module.core import TDevice
|
||||
from tianshou.highlevel.module.intermediate import IntermediateModuleFactory
|
||||
from tianshou.highlevel.optim import OptimizerFactory
|
||||
from tianshou.policy import BasePolicy, ICMPolicy
|
||||
from tianshou.utils.net.discrete import IntrinsicCuriosityModule
|
||||
|
@ -50,6 +50,8 @@ class Persistence(ABC):
|
||||
|
||||
|
||||
class PersistenceGroup(Persistence):
|
||||
"""Groups persistence handlers such that they can be applied collectively."""
|
||||
|
||||
def __init__(self, *p: Persistence, enabled: bool = True):
|
||||
self.items = p
|
||||
self.enabled = enabled
|
||||
@ -69,7 +71,7 @@ class PolicyPersistence:
|
||||
FILENAME = "policy.dat"
|
||||
|
||||
def __init__(self, additional_persistence: Persistence | None = None, enabled: bool = True):
|
||||
""":param additional_persistence: a persistence instance which is to be envoked whenever
|
||||
""":param additional_persistence: a persistence instance which is to be invoked whenever
|
||||
this object is used to persist/restore data
|
||||
:param enabled: whether persistence is enabled (restoration is always enabled)
|
||||
"""
|
||||
|
@ -52,6 +52,7 @@ class TrainerStopCallback(ToStringMixin, ABC):
|
||||
@abstractmethod
|
||||
def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool:
|
||||
""":param mean_rewards: the average undiscounted returns of the testing result
|
||||
:param context: the training context
|
||||
:return: True if the goal has been reached and training should stop, False otherwise
|
||||
"""
|
||||
|
||||
@ -64,6 +65,8 @@ class TrainerStopCallback(ToStringMixin, ABC):
|
||||
|
||||
@dataclass
|
||||
class TrainerCallbacks:
|
||||
"""Container for callbacks used during training."""
|
||||
|
||||
epoch_callback_train: TrainerEpochCallbackTrain | None = None
|
||||
epoch_callback_test: TrainerEpochCallbackTest | None = None
|
||||
stop_callback: TrainerStopCallback | None = None
|
||||
|
@ -12,6 +12,8 @@ if TYPE_CHECKING:
|
||||
|
||||
@dataclass
|
||||
class World:
|
||||
"""Container for instances and configuration items that are relevant to an experiment."""
|
||||
|
||||
envs: "Environments"
|
||||
policy: "BasePolicy"
|
||||
train_collector: "Collector"
|
||||
|
Loading…
x
Reference in New Issue
Block a user