From 2e39a252e3c203c6ac86d82e5db2aed30823b222 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Mon, 4 Dec 2023 13:52:10 +0100 Subject: [PATCH] Docstring: minor changes to let ruff pass --- docs/autogen_rst.py | 3 ++- tianshou/data/buffer/base.py | 2 +- tianshou/highlevel/env.py | 6 ++---- tianshou/highlevel/experiment.py | 19 ++++++++++++++----- tianshou/highlevel/logger.py | 4 +++- tianshou/highlevel/module/actor.py | 8 ++++---- tianshou/highlevel/optim.py | 4 +++- tianshou/highlevel/params/noise.py | 11 +++++------ tianshou/highlevel/params/policy_params.py | 18 +++++++++++------- tianshou/highlevel/persistence.py | 8 +++++--- tianshou/highlevel/trainer.py | 4 +++- tianshou/policy/base.py | 3 +-- tianshou/utils/string.py | 2 ++ 13 files changed, 56 insertions(+), 36 deletions(-) diff --git a/docs/autogen_rst.py b/docs/autogen_rst.py index 3f2d5ea..26a414f 100644 --- a/docs/autogen_rst.py +++ b/docs/autogen_rst.py @@ -39,6 +39,7 @@ def write_to_file(content: str, path: str): def make_rst(src_root, rst_root, clean=False, overwrite=False, package_prefix=""): """Creates/updates documentation in form of rst files for modules and packages. + Does not delete any existing rst files. Thus, rst files for packages or modules that have been removed or renamed should be deleted by hand. @@ -116,6 +117,6 @@ if __name__ == "__main__": docs_root = Path(__file__).parent make_rst( docs_root / ".." / "tianshou", - docs_root / "api", + docs_root / "03_api", clean=True, ) diff --git a/tianshou/data/buffer/base.py b/tianshou/data/buffer/base.py index 1594697..53fe77f 100644 --- a/tianshou/data/buffer/base.py +++ b/tianshou/data/buffer/base.py @@ -18,7 +18,7 @@ class ReplayBuffer: stores all the data in a batch with circular-queue style. For the example usage of ReplayBuffer, please check out Section Buffer in - :doc:`/tutorials/01_concepts`. + :doc:`/01_tutorials/01_concepts`. :param size: the maximum size of replay buffer. :param stack_num: the frame-stack sampling argument, should be greater than or diff --git a/tianshou/highlevel/env.py b/tianshou/highlevel/env.py index 9ead62e..f9c9639 100644 --- a/tianshou/highlevel/env.py +++ b/tianshou/highlevel/env.py @@ -81,8 +81,7 @@ class Environments(ToStringMixin, ABC): num_training_envs: int, num_test_envs: int, ) -> "Environments": - """Creates a suitable subtype instance from a factory function that creates a single instance and - the type of environment (continuous/discrete). + """Creates a suitable subtype instance from a factory function that creates a single instance and the type of environment (continuous/discrete). :param factory_fn: the factory for a single environment instance :param env_type: the type of environments created by `factory_fn` @@ -115,8 +114,7 @@ class Environments(ToStringMixin, ABC): } def set_persistence(self, *p: Persistence) -> None: - """Associates the given persistence handlers which may persist and restore - environment-specific information. + """Associates the given persistence handlers which may persist and restore environment-specific information. :param p: persistence handlers """ diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index 88a54d6..5dd5179 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -188,7 +188,9 @@ class Experiment(ToStringMixin): experiment_name: str | None = None, logger_run_id: str | None = None, ) -> ExperimentResult: - """:param experiment_name: the experiment name, which corresponds to the directory (within the logging + """Run the experiment and return the results. + + :param experiment_name: the experiment name, which corresponds to the directory (within the logging directory) where all results associated with the experiment will be saved. The name may contain path separators (i.e. `os.path.sep`, as used by `os.path.join`), in which case a nested directory structure will be created. @@ -327,6 +329,7 @@ class ExperimentBuilder: def with_logger_factory(self, logger_factory: LoggerFactory) -> Self: """Allows to customize the logger factory to use. + If this method is not called, the default logger factory :class:`LoggerFactoryDefault` will be used. :param logger_factory: the factory to use @@ -346,6 +349,7 @@ class ExperimentBuilder: def with_optim_factory(self, optim_factory: OptimizerFactory) -> Self: """Allows to customize the gradient-based optimizer to use. + By default, :class:`OptimizerFactoryAdam` will be used with default parameters. :param optim_factory: the optimizer factory @@ -390,6 +394,7 @@ class ExperimentBuilder: def with_trainer_stop_callback(self, callback: TrainerStopCallback) -> Self: """Allows to define a callback that decides whether training shall stop early. + The callback receives the undiscounted returns of the testing result. :param callback: the callback @@ -435,6 +440,7 @@ class _BuilderMixinActorFactory(ActorFutureProviderProtocol): def with_actor_factory(self, actor_factory: ActorFactory) -> Self: """Allows to customize the actor component via the specification of a factory. + If this function is not called, a default actor factory (with default parameters) will be used. :param actor_factory: the factory to use for the creation of the actor network @@ -450,7 +456,9 @@ class _BuilderMixinActorFactory(ActorFutureProviderProtocol): continuous_unbounded: bool = False, continuous_conditioned_sigma: bool = False, ) -> Self: - """:param hidden_sizes: the sequence of hidden dimensions to use in the network structure + """Adds a default actor factory with the given parameters. + + :param hidden_sizes: the sequence of hidden dimensions to use in the network structure :param continuous_unbounded: whether, for continuous action spaces, to apply tanh activation on final logits :param continuous_conditioned_sigma: whether, for continuous action spaces, the standard deviation of continuous actions (sigma) shall be computed from the input; if False, sigma is an independent parameter. @@ -479,9 +487,7 @@ class _BuilderMixinActorFactory(ActorFutureProviderProtocol): class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory): - """Specialization of the actor mixin where, in the continuous case, the actor component outputs - Gaussian distribution parameters. - """ + """Specialization of the actor mixin where, in the continuous case, the actor component outputs Gaussian distribution parameters.""" def __init__(self) -> None: super().__init__(ContinuousActorType.GAUSSIAN) @@ -494,6 +500,7 @@ class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory): continuous_conditioned_sigma: bool = False, ) -> Self: """Defines use of the default actor factory, allowing its parameters it to be customized. + The default actor factory uses an MLP-style architecture. :param hidden_sizes: dimensions of hidden layers used by the network @@ -523,6 +530,7 @@ class _BuilderMixinActorFactory_ContinuousDeterministic(_BuilderMixinActorFactor hidden_activation: ModuleType = torch.nn.ReLU, ) -> Self: """Defines use of the default actor factory, allowing its parameters it to be customized. + The default actor factory uses an MLP-style architecture. :param hidden_sizes: dimensions of hidden layers used by the network @@ -700,6 +708,7 @@ class _BuilderMixinCriticEnsembleFactory: def with_critic_ensemble_factory(self, factory: CriticEnsembleFactory) -> Self: """Specifies that the given factory shall be used for the critic ensemble. + If unspecified, the default factory (:class:`CriticEnsembleFactoryDefault`) is used. :param factory: the critic ensemble factory diff --git a/tianshou/highlevel/logger.py b/tianshou/highlevel/logger.py index 396613a..a4fe772 100644 --- a/tianshou/highlevel/logger.py +++ b/tianshou/highlevel/logger.py @@ -19,7 +19,9 @@ class LoggerFactory(ToStringMixin, ABC): run_id: str | None, config_dict: dict, ) -> TLogger: - """:param log_dir: path to the directory in which log data is to be stored + """Creates the logger. + + :param log_dir: path to the directory in which log data is to be stored :param experiment_name: the name of the job, which may contain `os.path.sep` :param run_id: a unique name, which, depending on the logging framework, may be used to identify the logger :param config_dict: a dictionary with data that is to be logged diff --git a/tianshou/highlevel/module/actor.py b/tianshou/highlevel/module/actor.py index 5a65329..dac3b79 100644 --- a/tianshou/highlevel/module/actor.py +++ b/tianshou/highlevel/module/actor.py @@ -168,7 +168,9 @@ class ActorFactoryContinuousGaussianNet(ActorFactoryContinuous): conditioned_sigma: bool = False, activation: ModuleType = nn.ReLU, ): - """:param hidden_sizes: the sequence of hidden dimensions to use in the network structure + """For actors with Gaussian policies. + + :param hidden_sizes: the sequence of hidden dimensions to use in the network structure :param unbounded: whether to apply tanh activation on final logits :param conditioned_sigma: if True, the standard deviation of continuous actions (sigma) is computed from the input; if False, sigma is an independent parameter @@ -229,9 +231,7 @@ class ActorFactoryDiscreteNet(ActorFactory): class ActorFactoryTransientStorageDecorator(ActorFactory): - """Wraps an actor factory, storing the most recently created actor instance such that it - can be retrieved. - """ + """Wraps an actor factory, storing the most recently created actor instance such that it can be retrieved.""" def __init__(self, actor_factory: ActorFactory, actor_future: ActorFuture): self.actor_factory = actor_factory diff --git a/tianshou/highlevel/optim.py b/tianshou/highlevel/optim.py index 8697e05..0e754b1 100644 --- a/tianshou/highlevel/optim.py +++ b/tianshou/highlevel/optim.py @@ -20,7 +20,9 @@ class OptimizerFactory(ABC, ToStringMixin): class OptimizerFactoryTorch(OptimizerFactory): def __init__(self, optim_class: OptimizerWithLearningRateProtocol, **kwargs: Any): - """:param optim_class: the optimizer class (e.g. subclass of `torch.optim.Optimizer`), + """Factory for torch optimizers. + + :param optim_class: the optimizer class (e.g. subclass of `torch.optim.Optimizer`), which will be passed the module parameters, the learning rate as `lr` and the kwargs provided. :param kwargs: keyword arguments to provide at optimizer construction diff --git a/tianshou/highlevel/params/noise.py b/tianshou/highlevel/params/noise.py index d3e4ce6..66e0c53 100644 --- a/tianshou/highlevel/params/noise.py +++ b/tianshou/highlevel/params/noise.py @@ -12,13 +12,12 @@ class NoiseFactory(ToStringMixin, ABC): class NoiseFactoryMaxActionScaledGaussian(NoiseFactory): - """Factory for Gaussian noise where the standard deviation is a fraction of the maximum action value. - - This factory can only be applied to continuous action spaces. - """ - def __init__(self, std_fraction: float): - """:param std_fraction: fraction (between 0 and 1) of the maximum action value that shall + """Factory for Gaussian noise where the standard deviation is a fraction of the maximum action value. + + This factory can only be applied to continuous action spaces. + + :param std_fraction: fraction (between 0 and 1) of the maximum action value that shall be used as the standard deviation """ self.std_fraction = std_fraction diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py index 68ac34f..373a413 100644 --- a/tianshou/highlevel/params/policy_params.py +++ b/tianshou/highlevel/params/policy_params.py @@ -43,7 +43,9 @@ class ParamTransformerData: class ParamTransformer(ABC): - """Transforms one or more parameters from the representation used by the high-level API + """Base class for parameter transformations from high to low-level API. + + Transforms one or more parameters from the representation used by the high-level API to the representation required by the (low-level) policy implementation. It operates directly on a dictionary of keyword arguments, which is initially generated from the parameter dataclass (subclass of `Params`). @@ -83,7 +85,9 @@ class ParamTransformerChangeValue(ParamTransformer): class ParamTransformerLRScheduler(ParamTransformer): - """Transforms a key containing a learning rate scheduler factory (removed) into a key containing + """Transformer for learning rate scheduler params. + + Transforms a key containing a learning rate scheduler factory (removed) into a key containing a learning rate scheduler (added) for the data member `optim`. """ @@ -100,12 +104,12 @@ class ParamTransformerLRScheduler(ParamTransformer): class ParamTransformerMultiLRScheduler(ParamTransformer): - """Transforms several scheduler factories into a single scheduler, which may be a MultipleLRSchedulers instance - if more than one factory is indeed given. - """ - def __init__(self, optim_key_list: list[tuple[torch.optim.Optimizer, str]], key_scheduler: str): - """:param optim_key_list: a list of tuples (optimizer, key of learning rate factory) + """Transforms several scheduler factories into a single scheduler. + + The result may be a `MultipleLRSchedulers` instance if more than one factory is indeed given. + + :param optim_key_list: a list of tuples (optimizer, key of learning rate factory) :param key_scheduler: the key under which to store the resulting learning rate scheduler """ self.optim_key_list = optim_key_list diff --git a/tianshou/highlevel/persistence.py b/tianshou/highlevel/persistence.py index 4fc4f9c..52b18d1 100644 --- a/tianshou/highlevel/persistence.py +++ b/tianshou/highlevel/persistence.py @@ -57,9 +57,9 @@ class PersistenceGroup(Persistence): class PolicyPersistence: - """Handles persistence of the policy.""" - class Mode(Enum): + """Mode of persistence.""" + POLICY_STATE_DICT = "policy_state_dict" """Persist only the policy's state dictionary. Note that for a policy to be restored from such a dictionary, it is necessary to first create a structurally equivalent object which can @@ -81,7 +81,9 @@ class PolicyPersistence: enabled: bool = True, mode: Mode = Mode.POLICY, ): - """:param additional_persistence: a persistence instance which is to be invoked whenever + """Handles persistence of the policy. + + :param additional_persistence: a persistence instance which is to be invoked whenever this object is used to persist/restore data :param enabled: whether persistence is enabled (restoration is always enabled) :param mode: the persistence mode diff --git a/tianshou/highlevel/trainer.py b/tianshou/highlevel/trainer.py index 2388e2a..7f70563 100644 --- a/tianshou/highlevel/trainer.py +++ b/tianshou/highlevel/trainer.py @@ -51,7 +51,9 @@ class TrainerStopCallback(ToStringMixin, ABC): @abstractmethod def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool: - """:param mean_rewards: the average undiscounted returns of the testing result + """Determines whether training should stop. + + :param mean_rewards: the average undiscounted returns of the testing result :param context: the training context :return: True if the goal has been reached and training should stop, False otherwise """ diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 3c8b8ad..3de0d0f 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -159,8 +159,7 @@ class BasePolicy(ABC, nn.Module): info: dict[str, Any] | None = None, state: dict | BatchProtocol | np.ndarray | None = None, ) -> np.ndarray | int: - """Get action as int (for discrete env's) or array (for continuous ones) from - an env's observation and info. + """Get action as int (for discrete env's) or array (for continuous ones) from an env's observation and info. :param obs: observation from the gym's env. :param info: information given by the gym's env. diff --git a/tianshou/utils/string.py b/tianshou/utils/string.py index 1d599a3..445f2e9 100644 --- a/tianshou/utils/string.py +++ b/tianshou/utils/string.py @@ -16,6 +16,8 @@ reCommaWhitespacePotentiallyBreaks = re.compile(r",\s+") log = logging.getLogger(__name__) +# ruff: noqa + class StringConverter(ABC): """Abstraction for a string conversion mechanism."""