Docstring: minor changes to let ruff pass
This commit is contained in:
parent
28fda00b27
commit
2e39a252e3
@ -39,6 +39,7 @@ def write_to_file(content: str, path: str):
|
||||
|
||||
def make_rst(src_root, rst_root, clean=False, overwrite=False, package_prefix=""):
|
||||
"""Creates/updates documentation in form of rst files for modules and packages.
|
||||
|
||||
Does not delete any existing rst files. Thus, rst files for packages or modules that have been removed or renamed
|
||||
should be deleted by hand.
|
||||
|
||||
@ -116,6 +117,6 @@ if __name__ == "__main__":
|
||||
docs_root = Path(__file__).parent
|
||||
make_rst(
|
||||
docs_root / ".." / "tianshou",
|
||||
docs_root / "api",
|
||||
docs_root / "03_api",
|
||||
clean=True,
|
||||
)
|
||||
|
@ -18,7 +18,7 @@ class ReplayBuffer:
|
||||
stores all the data in a batch with circular-queue style.
|
||||
|
||||
For the example usage of ReplayBuffer, please check out Section Buffer in
|
||||
:doc:`/tutorials/01_concepts`.
|
||||
:doc:`/01_tutorials/01_concepts`.
|
||||
|
||||
:param size: the maximum size of replay buffer.
|
||||
:param stack_num: the frame-stack sampling argument, should be greater than or
|
||||
|
@ -81,8 +81,7 @@ class Environments(ToStringMixin, ABC):
|
||||
num_training_envs: int,
|
||||
num_test_envs: int,
|
||||
) -> "Environments":
|
||||
"""Creates a suitable subtype instance from a factory function that creates a single instance and
|
||||
the type of environment (continuous/discrete).
|
||||
"""Creates a suitable subtype instance from a factory function that creates a single instance and the type of environment (continuous/discrete).
|
||||
|
||||
:param factory_fn: the factory for a single environment instance
|
||||
:param env_type: the type of environments created by `factory_fn`
|
||||
@ -115,8 +114,7 @@ class Environments(ToStringMixin, ABC):
|
||||
}
|
||||
|
||||
def set_persistence(self, *p: Persistence) -> None:
|
||||
"""Associates the given persistence handlers which may persist and restore
|
||||
environment-specific information.
|
||||
"""Associates the given persistence handlers which may persist and restore environment-specific information.
|
||||
|
||||
:param p: persistence handlers
|
||||
"""
|
||||
|
@ -188,7 +188,9 @@ class Experiment(ToStringMixin):
|
||||
experiment_name: str | None = None,
|
||||
logger_run_id: str | None = None,
|
||||
) -> ExperimentResult:
|
||||
""":param experiment_name: the experiment name, which corresponds to the directory (within the logging
|
||||
"""Run the experiment and return the results.
|
||||
|
||||
:param experiment_name: the experiment name, which corresponds to the directory (within the logging
|
||||
directory) where all results associated with the experiment will be saved.
|
||||
The name may contain path separators (i.e. `os.path.sep`, as used by `os.path.join`), in which case
|
||||
a nested directory structure will be created.
|
||||
@ -327,6 +329,7 @@ class ExperimentBuilder:
|
||||
|
||||
def with_logger_factory(self, logger_factory: LoggerFactory) -> Self:
|
||||
"""Allows to customize the logger factory to use.
|
||||
|
||||
If this method is not called, the default logger factory :class:`LoggerFactoryDefault` will be used.
|
||||
|
||||
:param logger_factory: the factory to use
|
||||
@ -346,6 +349,7 @@ class ExperimentBuilder:
|
||||
|
||||
def with_optim_factory(self, optim_factory: OptimizerFactory) -> Self:
|
||||
"""Allows to customize the gradient-based optimizer to use.
|
||||
|
||||
By default, :class:`OptimizerFactoryAdam` will be used with default parameters.
|
||||
|
||||
:param optim_factory: the optimizer factory
|
||||
@ -390,6 +394,7 @@ class ExperimentBuilder:
|
||||
|
||||
def with_trainer_stop_callback(self, callback: TrainerStopCallback) -> Self:
|
||||
"""Allows to define a callback that decides whether training shall stop early.
|
||||
|
||||
The callback receives the undiscounted returns of the testing result.
|
||||
|
||||
:param callback: the callback
|
||||
@ -435,6 +440,7 @@ class _BuilderMixinActorFactory(ActorFutureProviderProtocol):
|
||||
|
||||
def with_actor_factory(self, actor_factory: ActorFactory) -> Self:
|
||||
"""Allows to customize the actor component via the specification of a factory.
|
||||
|
||||
If this function is not called, a default actor factory (with default parameters) will be used.
|
||||
|
||||
:param actor_factory: the factory to use for the creation of the actor network
|
||||
@ -450,7 +456,9 @@ class _BuilderMixinActorFactory(ActorFutureProviderProtocol):
|
||||
continuous_unbounded: bool = False,
|
||||
continuous_conditioned_sigma: bool = False,
|
||||
) -> Self:
|
||||
""":param hidden_sizes: the sequence of hidden dimensions to use in the network structure
|
||||
"""Adds a default actor factory with the given parameters.
|
||||
|
||||
:param hidden_sizes: the sequence of hidden dimensions to use in the network structure
|
||||
:param continuous_unbounded: whether, for continuous action spaces, to apply tanh activation on final logits
|
||||
:param continuous_conditioned_sigma: whether, for continuous action spaces, the standard deviation of continuous actions (sigma)
|
||||
shall be computed from the input; if False, sigma is an independent parameter.
|
||||
@ -479,9 +487,7 @@ class _BuilderMixinActorFactory(ActorFutureProviderProtocol):
|
||||
|
||||
|
||||
class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
|
||||
"""Specialization of the actor mixin where, in the continuous case, the actor component outputs
|
||||
Gaussian distribution parameters.
|
||||
"""
|
||||
"""Specialization of the actor mixin where, in the continuous case, the actor component outputs Gaussian distribution parameters."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(ContinuousActorType.GAUSSIAN)
|
||||
@ -494,6 +500,7 @@ class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
|
||||
continuous_conditioned_sigma: bool = False,
|
||||
) -> Self:
|
||||
"""Defines use of the default actor factory, allowing its parameters it to be customized.
|
||||
|
||||
The default actor factory uses an MLP-style architecture.
|
||||
|
||||
:param hidden_sizes: dimensions of hidden layers used by the network
|
||||
@ -523,6 +530,7 @@ class _BuilderMixinActorFactory_ContinuousDeterministic(_BuilderMixinActorFactor
|
||||
hidden_activation: ModuleType = torch.nn.ReLU,
|
||||
) -> Self:
|
||||
"""Defines use of the default actor factory, allowing its parameters it to be customized.
|
||||
|
||||
The default actor factory uses an MLP-style architecture.
|
||||
|
||||
:param hidden_sizes: dimensions of hidden layers used by the network
|
||||
@ -700,6 +708,7 @@ class _BuilderMixinCriticEnsembleFactory:
|
||||
|
||||
def with_critic_ensemble_factory(self, factory: CriticEnsembleFactory) -> Self:
|
||||
"""Specifies that the given factory shall be used for the critic ensemble.
|
||||
|
||||
If unspecified, the default factory (:class:`CriticEnsembleFactoryDefault`) is used.
|
||||
|
||||
:param factory: the critic ensemble factory
|
||||
|
@ -19,7 +19,9 @@ class LoggerFactory(ToStringMixin, ABC):
|
||||
run_id: str | None,
|
||||
config_dict: dict,
|
||||
) -> TLogger:
|
||||
""":param log_dir: path to the directory in which log data is to be stored
|
||||
"""Creates the logger.
|
||||
|
||||
:param log_dir: path to the directory in which log data is to be stored
|
||||
:param experiment_name: the name of the job, which may contain `os.path.sep`
|
||||
:param run_id: a unique name, which, depending on the logging framework, may be used to identify the logger
|
||||
:param config_dict: a dictionary with data that is to be logged
|
||||
|
@ -168,7 +168,9 @@ class ActorFactoryContinuousGaussianNet(ActorFactoryContinuous):
|
||||
conditioned_sigma: bool = False,
|
||||
activation: ModuleType = nn.ReLU,
|
||||
):
|
||||
""":param hidden_sizes: the sequence of hidden dimensions to use in the network structure
|
||||
"""For actors with Gaussian policies.
|
||||
|
||||
:param hidden_sizes: the sequence of hidden dimensions to use in the network structure
|
||||
:param unbounded: whether to apply tanh activation on final logits
|
||||
:param conditioned_sigma: if True, the standard deviation of continuous actions (sigma) is computed from the
|
||||
input; if False, sigma is an independent parameter
|
||||
@ -229,9 +231,7 @@ class ActorFactoryDiscreteNet(ActorFactory):
|
||||
|
||||
|
||||
class ActorFactoryTransientStorageDecorator(ActorFactory):
|
||||
"""Wraps an actor factory, storing the most recently created actor instance such that it
|
||||
can be retrieved.
|
||||
"""
|
||||
"""Wraps an actor factory, storing the most recently created actor instance such that it can be retrieved."""
|
||||
|
||||
def __init__(self, actor_factory: ActorFactory, actor_future: ActorFuture):
|
||||
self.actor_factory = actor_factory
|
||||
|
@ -20,7 +20,9 @@ class OptimizerFactory(ABC, ToStringMixin):
|
||||
|
||||
class OptimizerFactoryTorch(OptimizerFactory):
|
||||
def __init__(self, optim_class: OptimizerWithLearningRateProtocol, **kwargs: Any):
|
||||
""":param optim_class: the optimizer class (e.g. subclass of `torch.optim.Optimizer`),
|
||||
"""Factory for torch optimizers.
|
||||
|
||||
:param optim_class: the optimizer class (e.g. subclass of `torch.optim.Optimizer`),
|
||||
which will be passed the module parameters, the learning rate as `lr` and the
|
||||
kwargs provided.
|
||||
:param kwargs: keyword arguments to provide at optimizer construction
|
||||
|
@ -12,13 +12,12 @@ class NoiseFactory(ToStringMixin, ABC):
|
||||
|
||||
|
||||
class NoiseFactoryMaxActionScaledGaussian(NoiseFactory):
|
||||
"""Factory for Gaussian noise where the standard deviation is a fraction of the maximum action value.
|
||||
|
||||
This factory can only be applied to continuous action spaces.
|
||||
"""
|
||||
|
||||
def __init__(self, std_fraction: float):
|
||||
""":param std_fraction: fraction (between 0 and 1) of the maximum action value that shall
|
||||
"""Factory for Gaussian noise where the standard deviation is a fraction of the maximum action value.
|
||||
|
||||
This factory can only be applied to continuous action spaces.
|
||||
|
||||
:param std_fraction: fraction (between 0 and 1) of the maximum action value that shall
|
||||
be used as the standard deviation
|
||||
"""
|
||||
self.std_fraction = std_fraction
|
||||
|
@ -43,7 +43,9 @@ class ParamTransformerData:
|
||||
|
||||
|
||||
class ParamTransformer(ABC):
|
||||
"""Transforms one or more parameters from the representation used by the high-level API
|
||||
"""Base class for parameter transformations from high to low-level API.
|
||||
|
||||
Transforms one or more parameters from the representation used by the high-level API
|
||||
to the representation required by the (low-level) policy implementation.
|
||||
It operates directly on a dictionary of keyword arguments, which is initially
|
||||
generated from the parameter dataclass (subclass of `Params`).
|
||||
@ -83,7 +85,9 @@ class ParamTransformerChangeValue(ParamTransformer):
|
||||
|
||||
|
||||
class ParamTransformerLRScheduler(ParamTransformer):
|
||||
"""Transforms a key containing a learning rate scheduler factory (removed) into a key containing
|
||||
"""Transformer for learning rate scheduler params.
|
||||
|
||||
Transforms a key containing a learning rate scheduler factory (removed) into a key containing
|
||||
a learning rate scheduler (added) for the data member `optim`.
|
||||
"""
|
||||
|
||||
@ -100,12 +104,12 @@ class ParamTransformerLRScheduler(ParamTransformer):
|
||||
|
||||
|
||||
class ParamTransformerMultiLRScheduler(ParamTransformer):
|
||||
"""Transforms several scheduler factories into a single scheduler, which may be a MultipleLRSchedulers instance
|
||||
if more than one factory is indeed given.
|
||||
"""
|
||||
|
||||
def __init__(self, optim_key_list: list[tuple[torch.optim.Optimizer, str]], key_scheduler: str):
|
||||
""":param optim_key_list: a list of tuples (optimizer, key of learning rate factory)
|
||||
"""Transforms several scheduler factories into a single scheduler.
|
||||
|
||||
The result may be a `MultipleLRSchedulers` instance if more than one factory is indeed given.
|
||||
|
||||
:param optim_key_list: a list of tuples (optimizer, key of learning rate factory)
|
||||
:param key_scheduler: the key under which to store the resulting learning rate scheduler
|
||||
"""
|
||||
self.optim_key_list = optim_key_list
|
||||
|
@ -57,9 +57,9 @@ class PersistenceGroup(Persistence):
|
||||
|
||||
|
||||
class PolicyPersistence:
|
||||
"""Handles persistence of the policy."""
|
||||
|
||||
class Mode(Enum):
|
||||
"""Mode of persistence."""
|
||||
|
||||
POLICY_STATE_DICT = "policy_state_dict"
|
||||
"""Persist only the policy's state dictionary. Note that for a policy to be restored from
|
||||
such a dictionary, it is necessary to first create a structurally equivalent object which can
|
||||
@ -81,7 +81,9 @@ class PolicyPersistence:
|
||||
enabled: bool = True,
|
||||
mode: Mode = Mode.POLICY,
|
||||
):
|
||||
""":param additional_persistence: a persistence instance which is to be invoked whenever
|
||||
"""Handles persistence of the policy.
|
||||
|
||||
:param additional_persistence: a persistence instance which is to be invoked whenever
|
||||
this object is used to persist/restore data
|
||||
:param enabled: whether persistence is enabled (restoration is always enabled)
|
||||
:param mode: the persistence mode
|
||||
|
@ -51,7 +51,9 @@ class TrainerStopCallback(ToStringMixin, ABC):
|
||||
|
||||
@abstractmethod
|
||||
def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool:
|
||||
""":param mean_rewards: the average undiscounted returns of the testing result
|
||||
"""Determines whether training should stop.
|
||||
|
||||
:param mean_rewards: the average undiscounted returns of the testing result
|
||||
:param context: the training context
|
||||
:return: True if the goal has been reached and training should stop, False otherwise
|
||||
"""
|
||||
|
@ -159,8 +159,7 @@ class BasePolicy(ABC, nn.Module):
|
||||
info: dict[str, Any] | None = None,
|
||||
state: dict | BatchProtocol | np.ndarray | None = None,
|
||||
) -> np.ndarray | int:
|
||||
"""Get action as int (for discrete env's) or array (for continuous ones) from
|
||||
an env's observation and info.
|
||||
"""Get action as int (for discrete env's) or array (for continuous ones) from an env's observation and info.
|
||||
|
||||
:param obs: observation from the gym's env.
|
||||
:param info: information given by the gym's env.
|
||||
|
@ -16,6 +16,8 @@ reCommaWhitespacePotentiallyBreaks = re.compile(r",\s+")
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# ruff: noqa
|
||||
|
||||
|
||||
class StringConverter(ABC):
|
||||
"""Abstraction for a string conversion mechanism."""
|
||||
|
Loading…
x
Reference in New Issue
Block a user