Docstring: minor changes to let ruff pass
This commit is contained in:
parent
28fda00b27
commit
2e39a252e3
@ -39,6 +39,7 @@ def write_to_file(content: str, path: str):
|
|||||||
|
|
||||||
def make_rst(src_root, rst_root, clean=False, overwrite=False, package_prefix=""):
|
def make_rst(src_root, rst_root, clean=False, overwrite=False, package_prefix=""):
|
||||||
"""Creates/updates documentation in form of rst files for modules and packages.
|
"""Creates/updates documentation in form of rst files for modules and packages.
|
||||||
|
|
||||||
Does not delete any existing rst files. Thus, rst files for packages or modules that have been removed or renamed
|
Does not delete any existing rst files. Thus, rst files for packages or modules that have been removed or renamed
|
||||||
should be deleted by hand.
|
should be deleted by hand.
|
||||||
|
|
||||||
@ -116,6 +117,6 @@ if __name__ == "__main__":
|
|||||||
docs_root = Path(__file__).parent
|
docs_root = Path(__file__).parent
|
||||||
make_rst(
|
make_rst(
|
||||||
docs_root / ".." / "tianshou",
|
docs_root / ".." / "tianshou",
|
||||||
docs_root / "api",
|
docs_root / "03_api",
|
||||||
clean=True,
|
clean=True,
|
||||||
)
|
)
|
||||||
|
@ -18,7 +18,7 @@ class ReplayBuffer:
|
|||||||
stores all the data in a batch with circular-queue style.
|
stores all the data in a batch with circular-queue style.
|
||||||
|
|
||||||
For the example usage of ReplayBuffer, please check out Section Buffer in
|
For the example usage of ReplayBuffer, please check out Section Buffer in
|
||||||
:doc:`/tutorials/01_concepts`.
|
:doc:`/01_tutorials/01_concepts`.
|
||||||
|
|
||||||
:param size: the maximum size of replay buffer.
|
:param size: the maximum size of replay buffer.
|
||||||
:param stack_num: the frame-stack sampling argument, should be greater than or
|
:param stack_num: the frame-stack sampling argument, should be greater than or
|
||||||
|
@ -81,8 +81,7 @@ class Environments(ToStringMixin, ABC):
|
|||||||
num_training_envs: int,
|
num_training_envs: int,
|
||||||
num_test_envs: int,
|
num_test_envs: int,
|
||||||
) -> "Environments":
|
) -> "Environments":
|
||||||
"""Creates a suitable subtype instance from a factory function that creates a single instance and
|
"""Creates a suitable subtype instance from a factory function that creates a single instance and the type of environment (continuous/discrete).
|
||||||
the type of environment (continuous/discrete).
|
|
||||||
|
|
||||||
:param factory_fn: the factory for a single environment instance
|
:param factory_fn: the factory for a single environment instance
|
||||||
:param env_type: the type of environments created by `factory_fn`
|
:param env_type: the type of environments created by `factory_fn`
|
||||||
@ -115,8 +114,7 @@ class Environments(ToStringMixin, ABC):
|
|||||||
}
|
}
|
||||||
|
|
||||||
def set_persistence(self, *p: Persistence) -> None:
|
def set_persistence(self, *p: Persistence) -> None:
|
||||||
"""Associates the given persistence handlers which may persist and restore
|
"""Associates the given persistence handlers which may persist and restore environment-specific information.
|
||||||
environment-specific information.
|
|
||||||
|
|
||||||
:param p: persistence handlers
|
:param p: persistence handlers
|
||||||
"""
|
"""
|
||||||
|
@ -188,7 +188,9 @@ class Experiment(ToStringMixin):
|
|||||||
experiment_name: str | None = None,
|
experiment_name: str | None = None,
|
||||||
logger_run_id: str | None = None,
|
logger_run_id: str | None = None,
|
||||||
) -> ExperimentResult:
|
) -> ExperimentResult:
|
||||||
""":param experiment_name: the experiment name, which corresponds to the directory (within the logging
|
"""Run the experiment and return the results.
|
||||||
|
|
||||||
|
:param experiment_name: the experiment name, which corresponds to the directory (within the logging
|
||||||
directory) where all results associated with the experiment will be saved.
|
directory) where all results associated with the experiment will be saved.
|
||||||
The name may contain path separators (i.e. `os.path.sep`, as used by `os.path.join`), in which case
|
The name may contain path separators (i.e. `os.path.sep`, as used by `os.path.join`), in which case
|
||||||
a nested directory structure will be created.
|
a nested directory structure will be created.
|
||||||
@ -327,6 +329,7 @@ class ExperimentBuilder:
|
|||||||
|
|
||||||
def with_logger_factory(self, logger_factory: LoggerFactory) -> Self:
|
def with_logger_factory(self, logger_factory: LoggerFactory) -> Self:
|
||||||
"""Allows to customize the logger factory to use.
|
"""Allows to customize the logger factory to use.
|
||||||
|
|
||||||
If this method is not called, the default logger factory :class:`LoggerFactoryDefault` will be used.
|
If this method is not called, the default logger factory :class:`LoggerFactoryDefault` will be used.
|
||||||
|
|
||||||
:param logger_factory: the factory to use
|
:param logger_factory: the factory to use
|
||||||
@ -346,6 +349,7 @@ class ExperimentBuilder:
|
|||||||
|
|
||||||
def with_optim_factory(self, optim_factory: OptimizerFactory) -> Self:
|
def with_optim_factory(self, optim_factory: OptimizerFactory) -> Self:
|
||||||
"""Allows to customize the gradient-based optimizer to use.
|
"""Allows to customize the gradient-based optimizer to use.
|
||||||
|
|
||||||
By default, :class:`OptimizerFactoryAdam` will be used with default parameters.
|
By default, :class:`OptimizerFactoryAdam` will be used with default parameters.
|
||||||
|
|
||||||
:param optim_factory: the optimizer factory
|
:param optim_factory: the optimizer factory
|
||||||
@ -390,6 +394,7 @@ class ExperimentBuilder:
|
|||||||
|
|
||||||
def with_trainer_stop_callback(self, callback: TrainerStopCallback) -> Self:
|
def with_trainer_stop_callback(self, callback: TrainerStopCallback) -> Self:
|
||||||
"""Allows to define a callback that decides whether training shall stop early.
|
"""Allows to define a callback that decides whether training shall stop early.
|
||||||
|
|
||||||
The callback receives the undiscounted returns of the testing result.
|
The callback receives the undiscounted returns of the testing result.
|
||||||
|
|
||||||
:param callback: the callback
|
:param callback: the callback
|
||||||
@ -435,6 +440,7 @@ class _BuilderMixinActorFactory(ActorFutureProviderProtocol):
|
|||||||
|
|
||||||
def with_actor_factory(self, actor_factory: ActorFactory) -> Self:
|
def with_actor_factory(self, actor_factory: ActorFactory) -> Self:
|
||||||
"""Allows to customize the actor component via the specification of a factory.
|
"""Allows to customize the actor component via the specification of a factory.
|
||||||
|
|
||||||
If this function is not called, a default actor factory (with default parameters) will be used.
|
If this function is not called, a default actor factory (with default parameters) will be used.
|
||||||
|
|
||||||
:param actor_factory: the factory to use for the creation of the actor network
|
:param actor_factory: the factory to use for the creation of the actor network
|
||||||
@ -450,7 +456,9 @@ class _BuilderMixinActorFactory(ActorFutureProviderProtocol):
|
|||||||
continuous_unbounded: bool = False,
|
continuous_unbounded: bool = False,
|
||||||
continuous_conditioned_sigma: bool = False,
|
continuous_conditioned_sigma: bool = False,
|
||||||
) -> Self:
|
) -> Self:
|
||||||
""":param hidden_sizes: the sequence of hidden dimensions to use in the network structure
|
"""Adds a default actor factory with the given parameters.
|
||||||
|
|
||||||
|
:param hidden_sizes: the sequence of hidden dimensions to use in the network structure
|
||||||
:param continuous_unbounded: whether, for continuous action spaces, to apply tanh activation on final logits
|
:param continuous_unbounded: whether, for continuous action spaces, to apply tanh activation on final logits
|
||||||
:param continuous_conditioned_sigma: whether, for continuous action spaces, the standard deviation of continuous actions (sigma)
|
:param continuous_conditioned_sigma: whether, for continuous action spaces, the standard deviation of continuous actions (sigma)
|
||||||
shall be computed from the input; if False, sigma is an independent parameter.
|
shall be computed from the input; if False, sigma is an independent parameter.
|
||||||
@ -479,9 +487,7 @@ class _BuilderMixinActorFactory(ActorFutureProviderProtocol):
|
|||||||
|
|
||||||
|
|
||||||
class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
|
class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
|
||||||
"""Specialization of the actor mixin where, in the continuous case, the actor component outputs
|
"""Specialization of the actor mixin where, in the continuous case, the actor component outputs Gaussian distribution parameters."""
|
||||||
Gaussian distribution parameters.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__(ContinuousActorType.GAUSSIAN)
|
super().__init__(ContinuousActorType.GAUSSIAN)
|
||||||
@ -494,6 +500,7 @@ class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
|
|||||||
continuous_conditioned_sigma: bool = False,
|
continuous_conditioned_sigma: bool = False,
|
||||||
) -> Self:
|
) -> Self:
|
||||||
"""Defines use of the default actor factory, allowing its parameters it to be customized.
|
"""Defines use of the default actor factory, allowing its parameters it to be customized.
|
||||||
|
|
||||||
The default actor factory uses an MLP-style architecture.
|
The default actor factory uses an MLP-style architecture.
|
||||||
|
|
||||||
:param hidden_sizes: dimensions of hidden layers used by the network
|
:param hidden_sizes: dimensions of hidden layers used by the network
|
||||||
@ -523,6 +530,7 @@ class _BuilderMixinActorFactory_ContinuousDeterministic(_BuilderMixinActorFactor
|
|||||||
hidden_activation: ModuleType = torch.nn.ReLU,
|
hidden_activation: ModuleType = torch.nn.ReLU,
|
||||||
) -> Self:
|
) -> Self:
|
||||||
"""Defines use of the default actor factory, allowing its parameters it to be customized.
|
"""Defines use of the default actor factory, allowing its parameters it to be customized.
|
||||||
|
|
||||||
The default actor factory uses an MLP-style architecture.
|
The default actor factory uses an MLP-style architecture.
|
||||||
|
|
||||||
:param hidden_sizes: dimensions of hidden layers used by the network
|
:param hidden_sizes: dimensions of hidden layers used by the network
|
||||||
@ -700,6 +708,7 @@ class _BuilderMixinCriticEnsembleFactory:
|
|||||||
|
|
||||||
def with_critic_ensemble_factory(self, factory: CriticEnsembleFactory) -> Self:
|
def with_critic_ensemble_factory(self, factory: CriticEnsembleFactory) -> Self:
|
||||||
"""Specifies that the given factory shall be used for the critic ensemble.
|
"""Specifies that the given factory shall be used for the critic ensemble.
|
||||||
|
|
||||||
If unspecified, the default factory (:class:`CriticEnsembleFactoryDefault`) is used.
|
If unspecified, the default factory (:class:`CriticEnsembleFactoryDefault`) is used.
|
||||||
|
|
||||||
:param factory: the critic ensemble factory
|
:param factory: the critic ensemble factory
|
||||||
|
@ -19,7 +19,9 @@ class LoggerFactory(ToStringMixin, ABC):
|
|||||||
run_id: str | None,
|
run_id: str | None,
|
||||||
config_dict: dict,
|
config_dict: dict,
|
||||||
) -> TLogger:
|
) -> TLogger:
|
||||||
""":param log_dir: path to the directory in which log data is to be stored
|
"""Creates the logger.
|
||||||
|
|
||||||
|
:param log_dir: path to the directory in which log data is to be stored
|
||||||
:param experiment_name: the name of the job, which may contain `os.path.sep`
|
:param experiment_name: the name of the job, which may contain `os.path.sep`
|
||||||
:param run_id: a unique name, which, depending on the logging framework, may be used to identify the logger
|
:param run_id: a unique name, which, depending on the logging framework, may be used to identify the logger
|
||||||
:param config_dict: a dictionary with data that is to be logged
|
:param config_dict: a dictionary with data that is to be logged
|
||||||
|
@ -168,7 +168,9 @@ class ActorFactoryContinuousGaussianNet(ActorFactoryContinuous):
|
|||||||
conditioned_sigma: bool = False,
|
conditioned_sigma: bool = False,
|
||||||
activation: ModuleType = nn.ReLU,
|
activation: ModuleType = nn.ReLU,
|
||||||
):
|
):
|
||||||
""":param hidden_sizes: the sequence of hidden dimensions to use in the network structure
|
"""For actors with Gaussian policies.
|
||||||
|
|
||||||
|
:param hidden_sizes: the sequence of hidden dimensions to use in the network structure
|
||||||
:param unbounded: whether to apply tanh activation on final logits
|
:param unbounded: whether to apply tanh activation on final logits
|
||||||
:param conditioned_sigma: if True, the standard deviation of continuous actions (sigma) is computed from the
|
:param conditioned_sigma: if True, the standard deviation of continuous actions (sigma) is computed from the
|
||||||
input; if False, sigma is an independent parameter
|
input; if False, sigma is an independent parameter
|
||||||
@ -229,9 +231,7 @@ class ActorFactoryDiscreteNet(ActorFactory):
|
|||||||
|
|
||||||
|
|
||||||
class ActorFactoryTransientStorageDecorator(ActorFactory):
|
class ActorFactoryTransientStorageDecorator(ActorFactory):
|
||||||
"""Wraps an actor factory, storing the most recently created actor instance such that it
|
"""Wraps an actor factory, storing the most recently created actor instance such that it can be retrieved."""
|
||||||
can be retrieved.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, actor_factory: ActorFactory, actor_future: ActorFuture):
|
def __init__(self, actor_factory: ActorFactory, actor_future: ActorFuture):
|
||||||
self.actor_factory = actor_factory
|
self.actor_factory = actor_factory
|
||||||
|
@ -20,7 +20,9 @@ class OptimizerFactory(ABC, ToStringMixin):
|
|||||||
|
|
||||||
class OptimizerFactoryTorch(OptimizerFactory):
|
class OptimizerFactoryTorch(OptimizerFactory):
|
||||||
def __init__(self, optim_class: OptimizerWithLearningRateProtocol, **kwargs: Any):
|
def __init__(self, optim_class: OptimizerWithLearningRateProtocol, **kwargs: Any):
|
||||||
""":param optim_class: the optimizer class (e.g. subclass of `torch.optim.Optimizer`),
|
"""Factory for torch optimizers.
|
||||||
|
|
||||||
|
:param optim_class: the optimizer class (e.g. subclass of `torch.optim.Optimizer`),
|
||||||
which will be passed the module parameters, the learning rate as `lr` and the
|
which will be passed the module parameters, the learning rate as `lr` and the
|
||||||
kwargs provided.
|
kwargs provided.
|
||||||
:param kwargs: keyword arguments to provide at optimizer construction
|
:param kwargs: keyword arguments to provide at optimizer construction
|
||||||
|
@ -12,13 +12,12 @@ class NoiseFactory(ToStringMixin, ABC):
|
|||||||
|
|
||||||
|
|
||||||
class NoiseFactoryMaxActionScaledGaussian(NoiseFactory):
|
class NoiseFactoryMaxActionScaledGaussian(NoiseFactory):
|
||||||
"""Factory for Gaussian noise where the standard deviation is a fraction of the maximum action value.
|
|
||||||
|
|
||||||
This factory can only be applied to continuous action spaces.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, std_fraction: float):
|
def __init__(self, std_fraction: float):
|
||||||
""":param std_fraction: fraction (between 0 and 1) of the maximum action value that shall
|
"""Factory for Gaussian noise where the standard deviation is a fraction of the maximum action value.
|
||||||
|
|
||||||
|
This factory can only be applied to continuous action spaces.
|
||||||
|
|
||||||
|
:param std_fraction: fraction (between 0 and 1) of the maximum action value that shall
|
||||||
be used as the standard deviation
|
be used as the standard deviation
|
||||||
"""
|
"""
|
||||||
self.std_fraction = std_fraction
|
self.std_fraction = std_fraction
|
||||||
|
@ -43,7 +43,9 @@ class ParamTransformerData:
|
|||||||
|
|
||||||
|
|
||||||
class ParamTransformer(ABC):
|
class ParamTransformer(ABC):
|
||||||
"""Transforms one or more parameters from the representation used by the high-level API
|
"""Base class for parameter transformations from high to low-level API.
|
||||||
|
|
||||||
|
Transforms one or more parameters from the representation used by the high-level API
|
||||||
to the representation required by the (low-level) policy implementation.
|
to the representation required by the (low-level) policy implementation.
|
||||||
It operates directly on a dictionary of keyword arguments, which is initially
|
It operates directly on a dictionary of keyword arguments, which is initially
|
||||||
generated from the parameter dataclass (subclass of `Params`).
|
generated from the parameter dataclass (subclass of `Params`).
|
||||||
@ -83,7 +85,9 @@ class ParamTransformerChangeValue(ParamTransformer):
|
|||||||
|
|
||||||
|
|
||||||
class ParamTransformerLRScheduler(ParamTransformer):
|
class ParamTransformerLRScheduler(ParamTransformer):
|
||||||
"""Transforms a key containing a learning rate scheduler factory (removed) into a key containing
|
"""Transformer for learning rate scheduler params.
|
||||||
|
|
||||||
|
Transforms a key containing a learning rate scheduler factory (removed) into a key containing
|
||||||
a learning rate scheduler (added) for the data member `optim`.
|
a learning rate scheduler (added) for the data member `optim`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -100,12 +104,12 @@ class ParamTransformerLRScheduler(ParamTransformer):
|
|||||||
|
|
||||||
|
|
||||||
class ParamTransformerMultiLRScheduler(ParamTransformer):
|
class ParamTransformerMultiLRScheduler(ParamTransformer):
|
||||||
"""Transforms several scheduler factories into a single scheduler, which may be a MultipleLRSchedulers instance
|
|
||||||
if more than one factory is indeed given.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, optim_key_list: list[tuple[torch.optim.Optimizer, str]], key_scheduler: str):
|
def __init__(self, optim_key_list: list[tuple[torch.optim.Optimizer, str]], key_scheduler: str):
|
||||||
""":param optim_key_list: a list of tuples (optimizer, key of learning rate factory)
|
"""Transforms several scheduler factories into a single scheduler.
|
||||||
|
|
||||||
|
The result may be a `MultipleLRSchedulers` instance if more than one factory is indeed given.
|
||||||
|
|
||||||
|
:param optim_key_list: a list of tuples (optimizer, key of learning rate factory)
|
||||||
:param key_scheduler: the key under which to store the resulting learning rate scheduler
|
:param key_scheduler: the key under which to store the resulting learning rate scheduler
|
||||||
"""
|
"""
|
||||||
self.optim_key_list = optim_key_list
|
self.optim_key_list = optim_key_list
|
||||||
|
@ -57,9 +57,9 @@ class PersistenceGroup(Persistence):
|
|||||||
|
|
||||||
|
|
||||||
class PolicyPersistence:
|
class PolicyPersistence:
|
||||||
"""Handles persistence of the policy."""
|
|
||||||
|
|
||||||
class Mode(Enum):
|
class Mode(Enum):
|
||||||
|
"""Mode of persistence."""
|
||||||
|
|
||||||
POLICY_STATE_DICT = "policy_state_dict"
|
POLICY_STATE_DICT = "policy_state_dict"
|
||||||
"""Persist only the policy's state dictionary. Note that for a policy to be restored from
|
"""Persist only the policy's state dictionary. Note that for a policy to be restored from
|
||||||
such a dictionary, it is necessary to first create a structurally equivalent object which can
|
such a dictionary, it is necessary to first create a structurally equivalent object which can
|
||||||
@ -81,7 +81,9 @@ class PolicyPersistence:
|
|||||||
enabled: bool = True,
|
enabled: bool = True,
|
||||||
mode: Mode = Mode.POLICY,
|
mode: Mode = Mode.POLICY,
|
||||||
):
|
):
|
||||||
""":param additional_persistence: a persistence instance which is to be invoked whenever
|
"""Handles persistence of the policy.
|
||||||
|
|
||||||
|
:param additional_persistence: a persistence instance which is to be invoked whenever
|
||||||
this object is used to persist/restore data
|
this object is used to persist/restore data
|
||||||
:param enabled: whether persistence is enabled (restoration is always enabled)
|
:param enabled: whether persistence is enabled (restoration is always enabled)
|
||||||
:param mode: the persistence mode
|
:param mode: the persistence mode
|
||||||
|
@ -51,7 +51,9 @@ class TrainerStopCallback(ToStringMixin, ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool:
|
def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool:
|
||||||
""":param mean_rewards: the average undiscounted returns of the testing result
|
"""Determines whether training should stop.
|
||||||
|
|
||||||
|
:param mean_rewards: the average undiscounted returns of the testing result
|
||||||
:param context: the training context
|
:param context: the training context
|
||||||
:return: True if the goal has been reached and training should stop, False otherwise
|
:return: True if the goal has been reached and training should stop, False otherwise
|
||||||
"""
|
"""
|
||||||
|
@ -159,8 +159,7 @@ class BasePolicy(ABC, nn.Module):
|
|||||||
info: dict[str, Any] | None = None,
|
info: dict[str, Any] | None = None,
|
||||||
state: dict | BatchProtocol | np.ndarray | None = None,
|
state: dict | BatchProtocol | np.ndarray | None = None,
|
||||||
) -> np.ndarray | int:
|
) -> np.ndarray | int:
|
||||||
"""Get action as int (for discrete env's) or array (for continuous ones) from
|
"""Get action as int (for discrete env's) or array (for continuous ones) from an env's observation and info.
|
||||||
an env's observation and info.
|
|
||||||
|
|
||||||
:param obs: observation from the gym's env.
|
:param obs: observation from the gym's env.
|
||||||
:param info: information given by the gym's env.
|
:param info: information given by the gym's env.
|
||||||
|
@ -16,6 +16,8 @@ reCommaWhitespacePotentiallyBreaks = re.compile(r",\s+")
|
|||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ruff: noqa
|
||||||
|
|
||||||
|
|
||||||
class StringConverter(ABC):
|
class StringConverter(ABC):
|
||||||
"""Abstraction for a string conversion mechanism."""
|
"""Abstraction for a string conversion mechanism."""
|
||||||
|
Loading…
x
Reference in New Issue
Block a user