Adapt class naming scheme

* Use a prefix convention (subclasses have the superclass name as a prefix) to
  facilitate discoverability of relevant classes via IDE autocompletion
* Use dual naming, adding an alternative concise name that omits the
  precise OO semantics and retains only the essential part of the name
  (which can be more pleasing to users not accustomed to
  convoluted OO naming)
This commit is contained in:
Dominik Jain 2023-09-27 17:20:35 +02:00
parent 5bcf514c55
commit 78b6dd1f49
11 changed files with 73 additions and 49 deletions

View File

@@ -14,7 +14,7 @@ from tianshou.highlevel.experiment import (
PPOExperimentBuilder, PPOExperimentBuilder,
RLExperimentConfig, RLExperimentConfig,
) )
from tianshou.highlevel.params.lr_scheduler import LinearLRSchedulerFactory from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
from tianshou.highlevel.params.policy_params import PPOParams from tianshou.highlevel.params.policy_params import PPOParams
@@ -81,7 +81,7 @@ def main(
dual_clip=dual_clip, dual_clip=dual_clip,
recompute_advantage=recompute_adv, recompute_advantage=recompute_adv,
lr=lr, lr=lr,
lr_scheduler_factory=LinearLRSchedulerFactory(sampling_config) lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config)
if lr_decay if lr_decay
else None, else None,
), ),

View File

@@ -12,7 +12,7 @@ from tianshou.highlevel.experiment import (
RLExperimentConfig, RLExperimentConfig,
SACExperimentBuilder, SACExperimentBuilder,
) )
from tianshou.highlevel.params.alpha import DefaultAutoAlphaFactory from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault
from tianshou.highlevel.params.policy_params import SACParams from tianshou.highlevel.params.policy_params import SACParams
@@ -62,7 +62,7 @@ def main(
SACParams( SACParams(
tau=tau, tau=tau,
gamma=gamma, gamma=gamma,
alpha=DefaultAutoAlphaFactory(lr=alpha_lr) if auto_alpha else alpha, alpha=AutoAlphaFactoryDefault(lr=alpha_lr) if auto_alpha else alpha,
estimation_step=n_step, estimation_step=n_step,
actor_lr=actor_lr, actor_lr=actor_lr,
critic1_lr=critic_lr, critic1_lr=critic_lr,

View File

@@ -12,8 +12,10 @@ from tianshou.highlevel.experiment import (
RLExperimentConfig, RLExperimentConfig,
TD3ExperimentBuilder, TD3ExperimentBuilder,
) )
from tianshou.highlevel.params.env_param import MaxActionScaledFloatEnvParamFactory from tianshou.highlevel.params.env_param import MaxActionScaled
from tianshou.highlevel.params.noise import MaxActionScaledGaussianNoiseFactory from tianshou.highlevel.params.noise import (
MaxActionScaledGaussian,
)
from tianshou.highlevel.params.policy_params import TD3Params from tianshou.highlevel.params.policy_params import TD3Params
@@ -66,9 +68,9 @@ def main(
gamma=gamma, gamma=gamma,
estimation_step=n_step, estimation_step=n_step,
update_actor_freq=update_actor_freq, update_actor_freq=update_actor_freq,
noise_clip=MaxActionScaledFloatEnvParamFactory(noise_clip), noise_clip=MaxActionScaled(noise_clip),
policy_noise=MaxActionScaledFloatEnvParamFactory(policy_noise), policy_noise=MaxActionScaled(policy_noise),
exploration_noise=MaxActionScaledGaussianNoiseFactory(exploration_noise), exploration_noise=MaxActionScaledGaussian(exploration_noise),
actor_lr=actor_lr, actor_lr=actor_lr,
critic1_lr=critic_lr, critic1_lr=critic_lr,
critic2_lr=critic_lr, critic2_lr=critic_lr,

View File

@@ -19,12 +19,12 @@ from tianshou.highlevel.env import EnvFactory, Environments
from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory
from tianshou.highlevel.module import ( from tianshou.highlevel.module import (
ActorFactory, ActorFactory,
ActorFactoryDefault,
ContinuousActorType, ContinuousActorType,
CriticFactory, CriticFactory,
DefaultActorFactory, CriticFactoryDefault,
DefaultCriticFactory,
) )
from tianshou.highlevel.optim import AdamOptimizerFactory, OptimizerFactory from tianshou.highlevel.optim import OptimizerFactory, OptimizerFactoryAdam
from tianshou.highlevel.params.policy_params import PPOParams, SACParams, TD3Params from tianshou.highlevel.params.policy_params import PPOParams, SACParams, TD3Params
from tianshou.highlevel.persistence import PersistableConfigProtocol from tianshou.highlevel.persistence import PersistableConfigProtocol
from tianshou.policy import BasePolicy from tianshou.policy import BasePolicy
@@ -175,7 +175,7 @@ class RLExperimentBuilder:
:param weight_decay: weight decay (L2 penalty) :param weight_decay: weight decay (L2 penalty)
:return: the builder :return: the builder
""" """
self._optim_factory = AdamOptimizerFactory(betas=betas, eps=eps, weight_decay=weight_decay) self._optim_factory = OptimizerFactoryAdam(betas=betas, eps=eps, weight_decay=weight_decay)
return self return self
@abstractmethod @abstractmethod
@@ -184,7 +184,7 @@ class RLExperimentBuilder:
def _get_optim_factory(self) -> OptimizerFactory: def _get_optim_factory(self) -> OptimizerFactory:
if self._optim_factory is None: if self._optim_factory is None:
return AdamOptimizerFactory() return OptimizerFactoryAdam()
else: else:
return self._optim_factory return self._optim_factory
@@ -215,7 +215,7 @@ class _BuilderMixinActorFactory:
continuous_conditioned_sigma=False, continuous_conditioned_sigma=False,
) -> TBuilder: ) -> TBuilder:
self: TBuilder | _BuilderMixinActorFactory self: TBuilder | _BuilderMixinActorFactory
self._actor_factory = DefaultActorFactory( self._actor_factory = ActorFactoryDefault(
self._continuous_actor_type, self._continuous_actor_type,
hidden_sizes, hidden_sizes,
continuous_unbounded=continuous_unbounded, continuous_unbounded=continuous_unbounded,
@@ -225,7 +225,7 @@ class _BuilderMixinActorFactory:
def _get_actor_factory(self): def _get_actor_factory(self):
if self._actor_factory is None: if self._actor_factory is None:
return DefaultActorFactory(self._continuous_actor_type) return ActorFactoryDefault(self._continuous_actor_type)
else: else:
return self._actor_factory return self._actor_factory
@@ -268,13 +268,13 @@ class _BuilderMixinCriticsFactory:
return self return self
def _with_critic_factory_default(self, idx: int, hidden_sizes: Sequence[int]): def _with_critic_factory_default(self, idx: int, hidden_sizes: Sequence[int]):
self._critic_factories[idx] = DefaultCriticFactory(hidden_sizes) self._critic_factories[idx] = CriticFactoryDefault(hidden_sizes)
return self return self
def _get_critic_factory(self, idx: int): def _get_critic_factory(self, idx: int):
factory = self._critic_factories[idx] factory = self._critic_factories[idx]
if factory is None: if factory is None:
return DefaultCriticFactory() return CriticFactoryDefault()
else: else:
return factory return factory
@@ -290,7 +290,7 @@ class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
def with_critic_factory_default( def with_critic_factory_default(
self: TBuilder, self: TBuilder,
hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES, hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
) -> TBuilder: ) -> TBuilder:
self: TBuilder | "_BuilderMixinSingleCriticFactory" self: TBuilder | "_BuilderMixinSingleCriticFactory"
self._with_critic_factory_default(0, hidden_sizes) self._with_critic_factory_default(0, hidden_sizes)
@@ -309,7 +309,7 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
def with_common_critic_factory_default( def with_common_critic_factory_default(
self, self,
hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES, hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
) -> TBuilder: ) -> TBuilder:
self: TBuilder | "_BuilderMixinDualCriticFactory" self: TBuilder | "_BuilderMixinDualCriticFactory"
for i in range(len(self._critic_factories)): for i in range(len(self._critic_factories)):
@@ -323,7 +323,7 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
def with_critic1_factory_default( def with_critic1_factory_default(
self, self,
hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES, hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
) -> TBuilder: ) -> TBuilder:
self: TBuilder | "_BuilderMixinDualCriticFactory" self: TBuilder | "_BuilderMixinDualCriticFactory"
self._with_critic_factory_default(0, hidden_sizes) self._with_critic_factory_default(0, hidden_sizes)
@@ -336,7 +336,7 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
def with_critic2_factory_default( def with_critic2_factory_default(
self, self,
hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES, hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
) -> TBuilder: ) -> TBuilder:
self: TBuilder | "_BuilderMixinDualCriticFactory" self: TBuilder | "_BuilderMixinDualCriticFactory"
self._with_critic_factory_default(0, hidden_sizes) self._with_critic_factory_default(0, hidden_sizes)

View File

@@ -53,7 +53,7 @@ class ActorFactory(ABC):
m.weight.data.copy_(0.01 * m.weight.data) m.weight.data.copy_(0.01 * m.weight.data)
class DefaultActorFactory(ActorFactory): class ActorFactoryDefault(ActorFactory):
"""An actor factory which, depending on the type of environment, creates a suitable MLP-based policy.""" """An actor factory which, depending on the type of environment, creates a suitable MLP-based policy."""
DEFAULT_HIDDEN_SIZES = (64, 64) DEFAULT_HIDDEN_SIZES = (64, 64)
@@ -75,13 +75,13 @@ class DefaultActorFactory(ActorFactory):
if env_type == EnvType.CONTINUOUS: if env_type == EnvType.CONTINUOUS:
match self.continuous_actor_type: match self.continuous_actor_type:
case ContinuousActorType.GAUSSIAN: case ContinuousActorType.GAUSSIAN:
factory = ContinuousActorFactoryGaussian( factory = ActorFactoryContinuousGaussian(
self.hidden_sizes, self.hidden_sizes,
unbounded=self.continuous_unbounded, unbounded=self.continuous_unbounded,
conditioned_sigma=self.continuous_conditioned_sigma, conditioned_sigma=self.continuous_conditioned_sigma,
) )
case ContinuousActorType.DETERMINISTIC: case ContinuousActorType.DETERMINISTIC:
factory = ContinuousActorFactoryDeterministic(self.hidden_sizes) factory = ActorFactoryContinuousDeterministic(self.hidden_sizes)
case _: case _:
raise ValueError(self.continuous_actor_type) raise ValueError(self.continuous_actor_type)
return factory.create_module(envs, device) return factory.create_module(envs, device)
@@ -91,11 +91,11 @@ class DefaultActorFactory(ActorFactory):
raise ValueError(f"{env_type} not supported") raise ValueError(f"{env_type} not supported")
class ContinuousActorFactory(ActorFactory, ABC): class ActorFactoryContinuous(ActorFactory, ABC):
"""Serves as a type bound for actor factories that are suitable for continuous action spaces.""" """Serves as a type bound for actor factories that are suitable for continuous action spaces."""
class ContinuousActorFactoryDeterministic(ContinuousActorFactory): class ActorFactoryContinuousDeterministic(ActorFactoryContinuous):
def __init__(self, hidden_sizes: Sequence[int]): def __init__(self, hidden_sizes: Sequence[int]):
self.hidden_sizes = hidden_sizes self.hidden_sizes = hidden_sizes
@@ -113,7 +113,7 @@ class ContinuousActorFactoryDeterministic(ContinuousActorFactory):
).to(device) ).to(device)
class ContinuousActorFactoryGaussian(ContinuousActorFactory): class ActorFactoryContinuousGaussian(ActorFactoryContinuous):
def __init__(self, hidden_sizes: Sequence[int], unbounded=True, conditioned_sigma=False): def __init__(self, hidden_sizes: Sequence[int], unbounded=True, conditioned_sigma=False):
self.hidden_sizes = hidden_sizes self.hidden_sizes = hidden_sizes
self.unbounded = unbounded self.unbounded = unbounded
@@ -148,7 +148,7 @@ class CriticFactory(ABC):
pass pass
class DefaultCriticFactory(CriticFactory): class CriticFactoryDefault(CriticFactory):
"""A critic factory which, depending on the type of environment, creates a suitable MLP-based critic.""" """A critic factory which, depending on the type of environment, creates a suitable MLP-based critic."""
DEFAULT_HIDDEN_SIZES = (64, 64) DEFAULT_HIDDEN_SIZES = (64, 64)
@@ -159,7 +159,7 @@ class DefaultCriticFactory(CriticFactory):
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module: def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
env_type = envs.get_type() env_type = envs.get_type()
if env_type == EnvType.CONTINUOUS: if env_type == EnvType.CONTINUOUS:
factory = ContinuousNetCriticFactory(self.hidden_sizes) factory = CriticFactoryContinuousNet(self.hidden_sizes)
return factory.create_module(envs, device, use_action) return factory.create_module(envs, device, use_action)
elif env_type == EnvType.DISCRETE: elif env_type == EnvType.DISCRETE:
raise NotImplementedError raise NotImplementedError
@@ -167,11 +167,11 @@ class DefaultCriticFactory(CriticFactory):
raise ValueError(f"{env_type} not supported") raise ValueError(f"{env_type} not supported")
class ContinuousCriticFactory(CriticFactory, ABC): class CriticFactoryContinuous(CriticFactory, ABC):
pass pass
class ContinuousNetCriticFactory(ContinuousCriticFactory): class CriticFactoryContinuousNet(CriticFactoryContinuous):
def __init__(self, hidden_sizes: Sequence[int]): def __init__(self, hidden_sizes: Sequence[int]):
self.hidden_sizes = hidden_sizes self.hidden_sizes = hidden_sizes

View File

@@ -6,13 +6,22 @@ from torch.optim import Adam
class OptimizerFactory(ABC): class OptimizerFactory(ABC):
# TODO: Is it OK to assume that all optimizers have a learning rate argument?
# Right now, the learning rate is typically a configuration parameter.
# If we drop the assumption, we can't have that and will need to move the parameter
# to the optimizer factory, which is inconvenient for the user.
@abstractmethod @abstractmethod
def create_optimizer(self, module: torch.nn.Module, lr: float) -> torch.optim.Optimizer: def create_optimizer(self, module: torch.nn.Module, lr: float) -> torch.optim.Optimizer:
pass pass
class TorchOptimizerFactory(OptimizerFactory): class OptimizerFactoryTorch(OptimizerFactory):
def __init__(self, optim_class: Any, **kwargs): def __init__(self, optim_class: Any, **kwargs):
""":param optim_class: the optimizer class (e.g. subclass of `torch.optim.Optimizer`),
which will be passed the module parameters, the learning rate as `lr` and the
kwargs provided.
:param kwargs: keyword arguments to provide at optimizer construction
"""
self.optim_class = optim_class self.optim_class = optim_class
self.kwargs = kwargs self.kwargs = kwargs
@@ -20,7 +29,7 @@ class TorchOptimizerFactory(OptimizerFactory):
return self.optim_class(module.parameters(), lr=lr, **self.kwargs) return self.optim_class(module.parameters(), lr=lr, **self.kwargs)
class AdamOptimizerFactory(OptimizerFactory): class OptimizerFactoryAdam(OptimizerFactory):
def __init__(self, betas=(0.9, 0.999), eps=1e-08, weight_decay=0): def __init__(self, betas=(0.9, 0.999), eps=1e-08, weight_decay=0):
self.weight_decay = weight_decay self.weight_decay = weight_decay
self.eps = eps self.eps = eps

View File

@@ -19,7 +19,7 @@ class AutoAlphaFactory(ABC):
pass pass
class DefaultAutoAlphaFactory(AutoAlphaFactory): # TODO better name? class AutoAlphaFactoryDefault(AutoAlphaFactory): # TODO better name?
def __init__(self, lr: float = 3e-4): def __init__(self, lr: float = 3e-4):
self.lr = lr self.lr = lr

View File

@@ -1,24 +1,32 @@
"""Factories for the generation of environment-dependent parameters.""" """Factories for the generation of environment-dependent parameters."""
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TypeVar from typing import Generic, TypeVar
from tianshou.highlevel.env import ContinuousEnvironments, Environments from tianshou.highlevel.env import ContinuousEnvironments, Environments
T = TypeVar("T") TValue = TypeVar("TValue")
TEnvs = TypeVar("TEnvs", bound=Environments)
class FloatEnvParamFactory(ABC): class EnvValueFactory(Generic[TValue, TEnvs], ABC):
@abstractmethod @abstractmethod
def create_param(self, envs: Environments) -> float: def create_value(self, envs: TEnvs) -> TValue:
pass pass
class MaxActionScaledFloatEnvParamFactory(FloatEnvParamFactory): class FloatEnvValueFactory(EnvValueFactory[float, TEnvs], Generic[TEnvs], ABC):
"""Serves as a type bound for float value factories."""
class FloatEnvValueFactoryMaxActionScaled(FloatEnvValueFactory[ContinuousEnvironments]):
def __init__(self, value: float): def __init__(self, value: float):
""":param value: value with which to scale the max action value""" """:param value: value with which to scale the max action value"""
self.value = value self.value = value
def create_param(self, envs: Environments) -> float: def create_value(self, envs: ContinuousEnvironments) -> float:
envs.get_type().assert_continuous(self) envs.get_type().assert_continuous(self)
envs: ContinuousEnvironments
return envs.max_action * self.value return envs.max_action * self.value
class MaxActionScaled(FloatEnvValueFactoryMaxActionScaled):
pass

View File

@@ -13,7 +13,7 @@ class LRSchedulerFactory(ABC):
pass pass
class LinearLRSchedulerFactory(LRSchedulerFactory): class LRSchedulerFactoryLinear(LRSchedulerFactory):
def __init__(self, sampling_config: RLSamplingConfig): def __init__(self, sampling_config: RLSamplingConfig):
self.sampling_config = sampling_config self.sampling_config = sampling_config

View File

@@ -10,7 +10,7 @@ class NoiseFactory(ABC):
pass pass
class MaxActionScaledGaussianNoiseFactory(NoiseFactory): class NoiseFactoryMaxActionScaledGaussian(NoiseFactory):
"""Factory for Gaussian noise where the standard deviation is a fraction of the maximum action value. """Factory for Gaussian noise where the standard deviation is a fraction of the maximum action value.
This factory can only be applied to continuous action spaces. This factory can only be applied to continuous action spaces.
@@ -23,3 +23,7 @@ class MaxActionScaledGaussianNoiseFactory(NoiseFactory):
envs.get_type().assert_continuous(self) envs.get_type().assert_continuous(self)
envs: ContinuousEnvironments envs: ContinuousEnvironments
return GaussianNoise(sigma=envs.max_action * self.std_fraction) return GaussianNoise(sigma=envs.max_action * self.std_fraction)
class MaxActionScaledGaussian(NoiseFactoryMaxActionScaledGaussian):
pass

View File

@@ -9,7 +9,7 @@ from tianshou.highlevel.env import Environments
from tianshou.highlevel.module import ModuleOpt, TDevice from tianshou.highlevel.module import ModuleOpt, TDevice
from tianshou.highlevel.optim import OptimizerFactory from tianshou.highlevel.optim import OptimizerFactory
from tianshou.highlevel.params.alpha import AutoAlphaFactory from tianshou.highlevel.params.alpha import AutoAlphaFactory
from tianshou.highlevel.params.env_param import FloatEnvParamFactory from tianshou.highlevel.params.env_param import EnvValueFactory, FloatEnvValueFactory
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
from tianshou.highlevel.params.noise import NoiseFactory from tianshou.highlevel.params.noise import NoiseFactory
from tianshou.utils import MultipleLRSchedulers from tianshou.utils import MultipleLRSchedulers
@@ -155,8 +155,8 @@ class ParamTransformerFloatEnvParamFactory(ParamTransformer):
def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None: def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None:
value = kwargs[self.key] value = kwargs[self.key]
if isinstance(value, FloatEnvParamFactory): if isinstance(value, EnvValueFactory):
kwargs[self.key] = value.create_param(data.envs) kwargs[self.key] = value.create_value(data.envs)
class ITransformableParams(ABC): class ITransformableParams(ABC):
@@ -268,13 +268,14 @@ class TD3Params(Params, ParamsMixinActorAndDualCritics):
tau: float = 0.005 tau: float = 0.005
gamma: float = 0.99 gamma: float = 0.99
exploration_noise: BaseNoise | Literal["default"] | NoiseFactory | None = "default" exploration_noise: BaseNoise | Literal["default"] | NoiseFactory | None = "default"
policy_noise: float | FloatEnvParamFactory = 0.2 policy_noise: float | FloatEnvValueFactory = 0.2
noise_clip: float | FloatEnvParamFactory = 0.5 noise_clip: float | FloatEnvValueFactory = 0.5
update_actor_freq: int = 2 update_actor_freq: int = 2
estimation_step: int = 1 estimation_step: int = 1
action_scaling: bool = True action_scaling: bool = True
action_bound_method: Literal["clip"] | None = "clip" action_bound_method: Literal["clip"] | None = "clip"
# TODO change to stateless variant
def __post_init__(self): def __post_init__(self):
ParamsMixinActorAndDualCritics.__post_init__(self) ParamsMixinActorAndDualCritics.__post_init__(self)
self._add_transformer(ParamTransformerNoiseFactory("exploration_noise")) self._add_transformer(ParamTransformerNoiseFactory("exploration_noise"))