Docstring: minor changes to let ruff pass

This commit is contained in:
Michael Panchenko 2023-12-04 13:52:10 +01:00
parent 28fda00b27
commit 2e39a252e3
13 changed files with 56 additions and 36 deletions

View File

@ -39,6 +39,7 @@ def write_to_file(content: str, path: str):
def make_rst(src_root, rst_root, clean=False, overwrite=False, package_prefix=""):
"""Creates/updates documentation in form of rst files for modules and packages.
Does not delete any existing rst files. Thus, rst files for packages or modules that have been removed or renamed
should be deleted by hand.
@ -116,6 +117,6 @@ if __name__ == "__main__":
docs_root = Path(__file__).parent
make_rst(
docs_root / ".." / "tianshou",
docs_root / "api",
docs_root / "03_api",
clean=True,
)

View File

@ -18,7 +18,7 @@ class ReplayBuffer:
stores all the data in a batch with circular-queue style.
For the example usage of ReplayBuffer, please check out Section Buffer in
:doc:`/tutorials/01_concepts`.
:doc:`/01_tutorials/01_concepts`.
:param size: the maximum size of replay buffer.
:param stack_num: the frame-stack sampling argument, should be greater than or

View File

@ -81,8 +81,7 @@ class Environments(ToStringMixin, ABC):
num_training_envs: int,
num_test_envs: int,
) -> "Environments":
"""Creates a suitable subtype instance from a factory function that creates a single instance and
the type of environment (continuous/discrete).
"""Creates a suitable subtype instance from a factory function that creates a single instance and the type of environment (continuous/discrete).
:param factory_fn: the factory for a single environment instance
:param env_type: the type of environments created by `factory_fn`
@ -115,8 +114,7 @@ class Environments(ToStringMixin, ABC):
}
def set_persistence(self, *p: Persistence) -> None:
"""Associates the given persistence handlers which may persist and restore
environment-specific information.
"""Associates the given persistence handlers which may persist and restore environment-specific information.
:param p: persistence handlers
"""

View File

@ -188,7 +188,9 @@ class Experiment(ToStringMixin):
experiment_name: str | None = None,
logger_run_id: str | None = None,
) -> ExperimentResult:
""":param experiment_name: the experiment name, which corresponds to the directory (within the logging
"""Run the experiment and return the results.
:param experiment_name: the experiment name, which corresponds to the directory (within the logging
directory) where all results associated with the experiment will be saved.
The name may contain path separators (i.e. `os.path.sep`, as used by `os.path.join`), in which case
a nested directory structure will be created.
@ -327,6 +329,7 @@ class ExperimentBuilder:
def with_logger_factory(self, logger_factory: LoggerFactory) -> Self:
"""Allows to customize the logger factory to use.
If this method is not called, the default logger factory :class:`LoggerFactoryDefault` will be used.
:param logger_factory: the factory to use
@ -346,6 +349,7 @@ class ExperimentBuilder:
def with_optim_factory(self, optim_factory: OptimizerFactory) -> Self:
"""Allows to customize the gradient-based optimizer to use.
By default, :class:`OptimizerFactoryAdam` will be used with default parameters.
:param optim_factory: the optimizer factory
@ -390,6 +394,7 @@ class ExperimentBuilder:
def with_trainer_stop_callback(self, callback: TrainerStopCallback) -> Self:
"""Allows to define a callback that decides whether training shall stop early.
The callback receives the undiscounted returns of the testing result.
:param callback: the callback
@ -435,6 +440,7 @@ class _BuilderMixinActorFactory(ActorFutureProviderProtocol):
def with_actor_factory(self, actor_factory: ActorFactory) -> Self:
"""Allows to customize the actor component via the specification of a factory.
If this function is not called, a default actor factory (with default parameters) will be used.
:param actor_factory: the factory to use for the creation of the actor network
@ -450,7 +456,9 @@ class _BuilderMixinActorFactory(ActorFutureProviderProtocol):
continuous_unbounded: bool = False,
continuous_conditioned_sigma: bool = False,
) -> Self:
""":param hidden_sizes: the sequence of hidden dimensions to use in the network structure
"""Adds a default actor factory with the given parameters.
:param hidden_sizes: the sequence of hidden dimensions to use in the network structure
:param continuous_unbounded: whether, for continuous action spaces, to apply tanh activation on final logits
:param continuous_conditioned_sigma: whether, for continuous action spaces, the standard deviation of continuous actions (sigma)
shall be computed from the input; if False, sigma is an independent parameter.
@ -479,9 +487,7 @@ class _BuilderMixinActorFactory(ActorFutureProviderProtocol):
class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
"""Specialization of the actor mixin where, in the continuous case, the actor component outputs
Gaussian distribution parameters.
"""
"""Specialization of the actor mixin where, in the continuous case, the actor component outputs Gaussian distribution parameters."""
def __init__(self) -> None:
super().__init__(ContinuousActorType.GAUSSIAN)
@ -494,6 +500,7 @@ class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
continuous_conditioned_sigma: bool = False,
) -> Self:
"""Defines use of the default actor factory, allowing its parameters it to be customized.
The default actor factory uses an MLP-style architecture.
:param hidden_sizes: dimensions of hidden layers used by the network
@ -523,6 +530,7 @@ class _BuilderMixinActorFactory_ContinuousDeterministic(_BuilderMixinActorFactor
hidden_activation: ModuleType = torch.nn.ReLU,
) -> Self:
"""Defines use of the default actor factory, allowing its parameters it to be customized.
The default actor factory uses an MLP-style architecture.
:param hidden_sizes: dimensions of hidden layers used by the network
@ -700,6 +708,7 @@ class _BuilderMixinCriticEnsembleFactory:
def with_critic_ensemble_factory(self, factory: CriticEnsembleFactory) -> Self:
"""Specifies that the given factory shall be used for the critic ensemble.
If unspecified, the default factory (:class:`CriticEnsembleFactoryDefault`) is used.
:param factory: the critic ensemble factory

View File

@ -19,7 +19,9 @@ class LoggerFactory(ToStringMixin, ABC):
run_id: str | None,
config_dict: dict,
) -> TLogger:
""":param log_dir: path to the directory in which log data is to be stored
"""Creates the logger.
:param log_dir: path to the directory in which log data is to be stored
:param experiment_name: the name of the job, which may contain `os.path.sep`
:param run_id: a unique name, which, depending on the logging framework, may be used to identify the logger
:param config_dict: a dictionary with data that is to be logged

View File

@ -168,7 +168,9 @@ class ActorFactoryContinuousGaussianNet(ActorFactoryContinuous):
conditioned_sigma: bool = False,
activation: ModuleType = nn.ReLU,
):
""":param hidden_sizes: the sequence of hidden dimensions to use in the network structure
"""For actors with Gaussian policies.
:param hidden_sizes: the sequence of hidden dimensions to use in the network structure
:param unbounded: whether to apply tanh activation on final logits
:param conditioned_sigma: if True, the standard deviation of continuous actions (sigma) is computed from the
input; if False, sigma is an independent parameter
@ -229,9 +231,7 @@ class ActorFactoryDiscreteNet(ActorFactory):
class ActorFactoryTransientStorageDecorator(ActorFactory):
"""Wraps an actor factory, storing the most recently created actor instance such that it
can be retrieved.
"""
"""Wraps an actor factory, storing the most recently created actor instance such that it can be retrieved."""
def __init__(self, actor_factory: ActorFactory, actor_future: ActorFuture):
self.actor_factory = actor_factory

View File

@ -20,7 +20,9 @@ class OptimizerFactory(ABC, ToStringMixin):
class OptimizerFactoryTorch(OptimizerFactory):
def __init__(self, optim_class: OptimizerWithLearningRateProtocol, **kwargs: Any):
""":param optim_class: the optimizer class (e.g. subclass of `torch.optim.Optimizer`),
"""Factory for torch optimizers.
:param optim_class: the optimizer class (e.g. subclass of `torch.optim.Optimizer`),
which will be passed the module parameters, the learning rate as `lr` and the
kwargs provided.
:param kwargs: keyword arguments to provide at optimizer construction

View File

@ -12,13 +12,12 @@ class NoiseFactory(ToStringMixin, ABC):
class NoiseFactoryMaxActionScaledGaussian(NoiseFactory):
"""Factory for Gaussian noise where the standard deviation is a fraction of the maximum action value.
This factory can only be applied to continuous action spaces.
"""
def __init__(self, std_fraction: float):
""":param std_fraction: fraction (between 0 and 1) of the maximum action value that shall
"""Factory for Gaussian noise where the standard deviation is a fraction of the maximum action value.
This factory can only be applied to continuous action spaces.
:param std_fraction: fraction (between 0 and 1) of the maximum action value that shall
be used as the standard deviation
"""
self.std_fraction = std_fraction

View File

@ -43,7 +43,9 @@ class ParamTransformerData:
class ParamTransformer(ABC):
"""Transforms one or more parameters from the representation used by the high-level API
"""Base class for parameter transformations from high to low-level API.
Transforms one or more parameters from the representation used by the high-level API
to the representation required by the (low-level) policy implementation.
It operates directly on a dictionary of keyword arguments, which is initially
generated from the parameter dataclass (subclass of `Params`).
@ -83,7 +85,9 @@ class ParamTransformerChangeValue(ParamTransformer):
class ParamTransformerLRScheduler(ParamTransformer):
"""Transforms a key containing a learning rate scheduler factory (removed) into a key containing
"""Transformer for learning rate scheduler params.
Transforms a key containing a learning rate scheduler factory (removed) into a key containing
a learning rate scheduler (added) for the data member `optim`.
"""
@ -100,12 +104,12 @@ class ParamTransformerLRScheduler(ParamTransformer):
class ParamTransformerMultiLRScheduler(ParamTransformer):
"""Transforms several scheduler factories into a single scheduler, which may be a MultipleLRSchedulers instance
if more than one factory is indeed given.
"""
def __init__(self, optim_key_list: list[tuple[torch.optim.Optimizer, str]], key_scheduler: str):
""":param optim_key_list: a list of tuples (optimizer, key of learning rate factory)
"""Transforms several scheduler factories into a single scheduler.
The result may be a `MultipleLRSchedulers` instance if more than one factory is indeed given.
:param optim_key_list: a list of tuples (optimizer, key of learning rate factory)
:param key_scheduler: the key under which to store the resulting learning rate scheduler
"""
self.optim_key_list = optim_key_list

View File

@ -57,9 +57,9 @@ class PersistenceGroup(Persistence):
class PolicyPersistence:
"""Handles persistence of the policy."""
class Mode(Enum):
"""Mode of persistence."""
POLICY_STATE_DICT = "policy_state_dict"
"""Persist only the policy's state dictionary. Note that for a policy to be restored from
such a dictionary, it is necessary to first create a structurally equivalent object which can
@ -81,7 +81,9 @@ class PolicyPersistence:
enabled: bool = True,
mode: Mode = Mode.POLICY,
):
""":param additional_persistence: a persistence instance which is to be invoked whenever
"""Handles persistence of the policy.
:param additional_persistence: a persistence instance which is to be invoked whenever
this object is used to persist/restore data
:param enabled: whether persistence is enabled (restoration is always enabled)
:param mode: the persistence mode

View File

@ -51,7 +51,9 @@ class TrainerStopCallback(ToStringMixin, ABC):
@abstractmethod
def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool:
""":param mean_rewards: the average undiscounted returns of the testing result
"""Determines whether training should stop.
:param mean_rewards: the average undiscounted returns of the testing result
:param context: the training context
:return: True if the goal has been reached and training should stop, False otherwise
"""

View File

@ -159,8 +159,7 @@ class BasePolicy(ABC, nn.Module):
info: dict[str, Any] | None = None,
state: dict | BatchProtocol | np.ndarray | None = None,
) -> np.ndarray | int:
"""Get action as int (for discrete env's) or array (for continuous ones) from
an env's observation and info.
"""Get action as int (for discrete env's) or array (for continuous ones) from an env's observation and info.
:param obs: observation from the gym's env.
:param info: information given by the gym's env.

View File

@ -16,6 +16,8 @@ reCommaWhitespacePotentiallyBreaks = re.compile(r",\s+")
log = logging.getLogger(__name__)
# ruff: noqa
class StringConverter(ABC):
"""Abstraction for a string conversion mechanism."""