Adapt class naming scheme
* Use prefix convention (subclasses have superclass names as prefix) to facilitate discoverability of relevant classes via IDE autocompletion * Use dual naming, adding an alternative concise name that omits the precise OO semantics and retains only the essential part of the name (which can be more pleasing to users not accustomed to convoluted OO naming)
This commit is contained in:
parent
5bcf514c55
commit
78b6dd1f49
@ -14,7 +14,7 @@ from tianshou.highlevel.experiment import (
|
||||
PPOExperimentBuilder,
|
||||
RLExperimentConfig,
|
||||
)
|
||||
from tianshou.highlevel.params.lr_scheduler import LinearLRSchedulerFactory
|
||||
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
|
||||
from tianshou.highlevel.params.policy_params import PPOParams
|
||||
|
||||
|
||||
@ -81,7 +81,7 @@ def main(
|
||||
dual_clip=dual_clip,
|
||||
recompute_advantage=recompute_adv,
|
||||
lr=lr,
|
||||
lr_scheduler_factory=LinearLRSchedulerFactory(sampling_config)
|
||||
lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config)
|
||||
if lr_decay
|
||||
else None,
|
||||
),
|
||||
|
@ -12,7 +12,7 @@ from tianshou.highlevel.experiment import (
|
||||
RLExperimentConfig,
|
||||
SACExperimentBuilder,
|
||||
)
|
||||
from tianshou.highlevel.params.alpha import DefaultAutoAlphaFactory
|
||||
from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault
|
||||
from tianshou.highlevel.params.policy_params import SACParams
|
||||
|
||||
|
||||
@ -62,7 +62,7 @@ def main(
|
||||
SACParams(
|
||||
tau=tau,
|
||||
gamma=gamma,
|
||||
alpha=DefaultAutoAlphaFactory(lr=alpha_lr) if auto_alpha else alpha,
|
||||
alpha=AutoAlphaFactoryDefault(lr=alpha_lr) if auto_alpha else alpha,
|
||||
estimation_step=n_step,
|
||||
actor_lr=actor_lr,
|
||||
critic1_lr=critic_lr,
|
||||
|
@ -12,8 +12,10 @@ from tianshou.highlevel.experiment import (
|
||||
RLExperimentConfig,
|
||||
TD3ExperimentBuilder,
|
||||
)
|
||||
from tianshou.highlevel.params.env_param import MaxActionScaledFloatEnvParamFactory
|
||||
from tianshou.highlevel.params.noise import MaxActionScaledGaussianNoiseFactory
|
||||
from tianshou.highlevel.params.env_param import MaxActionScaled
|
||||
from tianshou.highlevel.params.noise import (
|
||||
MaxActionScaledGaussian,
|
||||
)
|
||||
from tianshou.highlevel.params.policy_params import TD3Params
|
||||
|
||||
|
||||
@ -66,9 +68,9 @@ def main(
|
||||
gamma=gamma,
|
||||
estimation_step=n_step,
|
||||
update_actor_freq=update_actor_freq,
|
||||
noise_clip=MaxActionScaledFloatEnvParamFactory(noise_clip),
|
||||
policy_noise=MaxActionScaledFloatEnvParamFactory(policy_noise),
|
||||
exploration_noise=MaxActionScaledGaussianNoiseFactory(exploration_noise),
|
||||
noise_clip=MaxActionScaled(noise_clip),
|
||||
policy_noise=MaxActionScaled(policy_noise),
|
||||
exploration_noise=MaxActionScaledGaussian(exploration_noise),
|
||||
actor_lr=actor_lr,
|
||||
critic1_lr=critic_lr,
|
||||
critic2_lr=critic_lr,
|
||||
|
@ -19,12 +19,12 @@ from tianshou.highlevel.env import EnvFactory, Environments
|
||||
from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory
|
||||
from tianshou.highlevel.module import (
|
||||
ActorFactory,
|
||||
ActorFactoryDefault,
|
||||
ContinuousActorType,
|
||||
CriticFactory,
|
||||
DefaultActorFactory,
|
||||
DefaultCriticFactory,
|
||||
CriticFactoryDefault,
|
||||
)
|
||||
from tianshou.highlevel.optim import AdamOptimizerFactory, OptimizerFactory
|
||||
from tianshou.highlevel.optim import OptimizerFactory, OptimizerFactoryAdam
|
||||
from tianshou.highlevel.params.policy_params import PPOParams, SACParams, TD3Params
|
||||
from tianshou.highlevel.persistence import PersistableConfigProtocol
|
||||
from tianshou.policy import BasePolicy
|
||||
@ -175,7 +175,7 @@ class RLExperimentBuilder:
|
||||
:param weight_decay: weight decay (L2 penalty)
|
||||
:return: the builder
|
||||
"""
|
||||
self._optim_factory = AdamOptimizerFactory(betas=betas, eps=eps, weight_decay=weight_decay)
|
||||
self._optim_factory = OptimizerFactoryAdam(betas=betas, eps=eps, weight_decay=weight_decay)
|
||||
return self
|
||||
|
||||
@abstractmethod
|
||||
@ -184,7 +184,7 @@ class RLExperimentBuilder:
|
||||
|
||||
def _get_optim_factory(self) -> OptimizerFactory:
|
||||
if self._optim_factory is None:
|
||||
return AdamOptimizerFactory()
|
||||
return OptimizerFactoryAdam()
|
||||
else:
|
||||
return self._optim_factory
|
||||
|
||||
@ -215,7 +215,7 @@ class _BuilderMixinActorFactory:
|
||||
continuous_conditioned_sigma=False,
|
||||
) -> TBuilder:
|
||||
self: TBuilder | _BuilderMixinActorFactory
|
||||
self._actor_factory = DefaultActorFactory(
|
||||
self._actor_factory = ActorFactoryDefault(
|
||||
self._continuous_actor_type,
|
||||
hidden_sizes,
|
||||
continuous_unbounded=continuous_unbounded,
|
||||
@ -225,7 +225,7 @@ class _BuilderMixinActorFactory:
|
||||
|
||||
def _get_actor_factory(self):
|
||||
if self._actor_factory is None:
|
||||
return DefaultActorFactory(self._continuous_actor_type)
|
||||
return ActorFactoryDefault(self._continuous_actor_type)
|
||||
else:
|
||||
return self._actor_factory
|
||||
|
||||
@ -268,13 +268,13 @@ class _BuilderMixinCriticsFactory:
|
||||
return self
|
||||
|
||||
def _with_critic_factory_default(self, idx: int, hidden_sizes: Sequence[int]):
|
||||
self._critic_factories[idx] = DefaultCriticFactory(hidden_sizes)
|
||||
self._critic_factories[idx] = CriticFactoryDefault(hidden_sizes)
|
||||
return self
|
||||
|
||||
def _get_critic_factory(self, idx: int):
|
||||
factory = self._critic_factories[idx]
|
||||
if factory is None:
|
||||
return DefaultCriticFactory()
|
||||
return CriticFactoryDefault()
|
||||
else:
|
||||
return factory
|
||||
|
||||
@ -290,7 +290,7 @@ class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
|
||||
|
||||
def with_critic_factory_default(
|
||||
self: TBuilder,
|
||||
hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
|
||||
hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
|
||||
) -> TBuilder:
|
||||
self: TBuilder | "_BuilderMixinSingleCriticFactory"
|
||||
self._with_critic_factory_default(0, hidden_sizes)
|
||||
@ -309,7 +309,7 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
|
||||
|
||||
def with_common_critic_factory_default(
|
||||
self,
|
||||
hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
|
||||
hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
|
||||
) -> TBuilder:
|
||||
self: TBuilder | "_BuilderMixinDualCriticFactory"
|
||||
for i in range(len(self._critic_factories)):
|
||||
@ -323,7 +323,7 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
|
||||
|
||||
def with_critic1_factory_default(
|
||||
self,
|
||||
hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
|
||||
hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
|
||||
) -> TBuilder:
|
||||
self: TBuilder | "_BuilderMixinDualCriticFactory"
|
||||
self._with_critic_factory_default(0, hidden_sizes)
|
||||
@ -336,7 +336,7 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
|
||||
|
||||
def with_critic2_factory_default(
|
||||
self,
|
||||
hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
|
||||
hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
|
||||
) -> TBuilder:
|
||||
self: TBuilder | "_BuilderMixinDualCriticFactory"
|
||||
self._with_critic_factory_default(0, hidden_sizes)
|
||||
|
@ -53,7 +53,7 @@ class ActorFactory(ABC):
|
||||
m.weight.data.copy_(0.01 * m.weight.data)
|
||||
|
||||
|
||||
class DefaultActorFactory(ActorFactory):
|
||||
class ActorFactoryDefault(ActorFactory):
|
||||
"""An actor factory which, depending on the type of environment, creates a suitable MLP-based policy."""
|
||||
|
||||
DEFAULT_HIDDEN_SIZES = (64, 64)
|
||||
@ -75,13 +75,13 @@ class DefaultActorFactory(ActorFactory):
|
||||
if env_type == EnvType.CONTINUOUS:
|
||||
match self.continuous_actor_type:
|
||||
case ContinuousActorType.GAUSSIAN:
|
||||
factory = ContinuousActorFactoryGaussian(
|
||||
factory = ActorFactoryContinuousGaussian(
|
||||
self.hidden_sizes,
|
||||
unbounded=self.continuous_unbounded,
|
||||
conditioned_sigma=self.continuous_conditioned_sigma,
|
||||
)
|
||||
case ContinuousActorType.DETERMINISTIC:
|
||||
factory = ContinuousActorFactoryDeterministic(self.hidden_sizes)
|
||||
factory = ActorFactoryContinuousDeterministic(self.hidden_sizes)
|
||||
case _:
|
||||
raise ValueError(self.continuous_actor_type)
|
||||
return factory.create_module(envs, device)
|
||||
@ -91,11 +91,11 @@ class DefaultActorFactory(ActorFactory):
|
||||
raise ValueError(f"{env_type} not supported")
|
||||
|
||||
|
||||
class ContinuousActorFactory(ActorFactory, ABC):
|
||||
class ActorFactoryContinuous(ActorFactory, ABC):
|
||||
"""Serves as a type bound for actor factories that are suitable for continuous action spaces."""
|
||||
|
||||
|
||||
class ContinuousActorFactoryDeterministic(ContinuousActorFactory):
|
||||
class ActorFactoryContinuousDeterministic(ActorFactoryContinuous):
|
||||
def __init__(self, hidden_sizes: Sequence[int]):
|
||||
self.hidden_sizes = hidden_sizes
|
||||
|
||||
@ -113,7 +113,7 @@ class ContinuousActorFactoryDeterministic(ContinuousActorFactory):
|
||||
).to(device)
|
||||
|
||||
|
||||
class ContinuousActorFactoryGaussian(ContinuousActorFactory):
|
||||
class ActorFactoryContinuousGaussian(ActorFactoryContinuous):
|
||||
def __init__(self, hidden_sizes: Sequence[int], unbounded=True, conditioned_sigma=False):
|
||||
self.hidden_sizes = hidden_sizes
|
||||
self.unbounded = unbounded
|
||||
@ -148,7 +148,7 @@ class CriticFactory(ABC):
|
||||
pass
|
||||
|
||||
|
||||
class DefaultCriticFactory(CriticFactory):
|
||||
class CriticFactoryDefault(CriticFactory):
|
||||
"""A critic factory which, depending on the type of environment, creates a suitable MLP-based critic."""
|
||||
|
||||
DEFAULT_HIDDEN_SIZES = (64, 64)
|
||||
@ -159,7 +159,7 @@ class DefaultCriticFactory(CriticFactory):
|
||||
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
|
||||
env_type = envs.get_type()
|
||||
if env_type == EnvType.CONTINUOUS:
|
||||
factory = ContinuousNetCriticFactory(self.hidden_sizes)
|
||||
factory = CriticFactoryContinuousNet(self.hidden_sizes)
|
||||
return factory.create_module(envs, device, use_action)
|
||||
elif env_type == EnvType.DISCRETE:
|
||||
raise NotImplementedError
|
||||
@ -167,11 +167,11 @@ class DefaultCriticFactory(CriticFactory):
|
||||
raise ValueError(f"{env_type} not supported")
|
||||
|
||||
|
||||
class ContinuousCriticFactory(CriticFactory, ABC):
|
||||
class CriticFactoryContinuous(CriticFactory, ABC):
|
||||
pass
|
||||
|
||||
|
||||
class ContinuousNetCriticFactory(ContinuousCriticFactory):
|
||||
class CriticFactoryContinuousNet(CriticFactoryContinuous):
|
||||
def __init__(self, hidden_sizes: Sequence[int]):
|
||||
self.hidden_sizes = hidden_sizes
|
||||
|
||||
|
@ -6,13 +6,22 @@ from torch.optim import Adam
|
||||
|
||||
|
||||
class OptimizerFactory(ABC):
|
||||
# TODO: Is it OK to assume that all optimizers have a learning rate argument?
|
||||
# Right now, the learning rate is typically a configuration parameter.
|
||||
# If we drop the assumption, we can't have that and will need to move the parameter
|
||||
# to the optimizer factory, which is inconvenient for the user.
|
||||
@abstractmethod
|
||||
def create_optimizer(self, module: torch.nn.Module, lr: float) -> torch.optim.Optimizer:
|
||||
pass
|
||||
|
||||
|
||||
class TorchOptimizerFactory(OptimizerFactory):
|
||||
class OptimizerFactoryTorch(OptimizerFactory):
|
||||
def __init__(self, optim_class: Any, **kwargs):
|
||||
""":param optim_class: the optimizer class (e.g. subclass of `torch.optim.Optimizer`),
|
||||
which will be passed the module parameters, the learning rate as `lr` and the
|
||||
kwargs provided.
|
||||
:param kwargs: keyword arguments to provide at optimizer construction
|
||||
"""
|
||||
self.optim_class = optim_class
|
||||
self.kwargs = kwargs
|
||||
|
||||
@ -20,7 +29,7 @@ class TorchOptimizerFactory(OptimizerFactory):
|
||||
return self.optim_class(module.parameters(), lr=lr, **self.kwargs)
|
||||
|
||||
|
||||
class AdamOptimizerFactory(OptimizerFactory):
|
||||
class OptimizerFactoryAdam(OptimizerFactory):
|
||||
def __init__(self, betas=(0.9, 0.999), eps=1e-08, weight_decay=0):
|
||||
self.weight_decay = weight_decay
|
||||
self.eps = eps
|
||||
|
@ -19,7 +19,7 @@ class AutoAlphaFactory(ABC):
|
||||
pass
|
||||
|
||||
|
||||
class DefaultAutoAlphaFactory(AutoAlphaFactory): # TODO better name?
|
||||
class AutoAlphaFactoryDefault(AutoAlphaFactory): # TODO better name?
|
||||
def __init__(self, lr: float = 3e-4):
|
||||
self.lr = lr
|
||||
|
||||
|
@ -1,24 +1,32 @@
|
||||
"""Factories for the generation of environment-dependent parameters."""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TypeVar
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from tianshou.highlevel.env import ContinuousEnvironments, Environments
|
||||
|
||||
T = TypeVar("T")
|
||||
TValue = TypeVar("TValue")
|
||||
TEnvs = TypeVar("TEnvs", bound=Environments)
|
||||
|
||||
|
||||
class FloatEnvParamFactory(ABC):
|
||||
class EnvValueFactory(Generic[TValue, TEnvs], ABC):
|
||||
@abstractmethod
|
||||
def create_param(self, envs: Environments) -> float:
|
||||
def create_value(self, envs: TEnvs) -> TValue:
|
||||
pass
|
||||
|
||||
|
||||
class MaxActionScaledFloatEnvParamFactory(FloatEnvParamFactory):
|
||||
class FloatEnvValueFactory(EnvValueFactory[float, TEnvs], Generic[TEnvs], ABC):
|
||||
"""Serves as a type bound for float value factories."""
|
||||
|
||||
|
||||
class FloatEnvValueFactoryMaxActionScaled(FloatEnvValueFactory[ContinuousEnvironments]):
|
||||
def __init__(self, value: float):
|
||||
""":param value: value with which to scale the max action value"""
|
||||
self.value = value
|
||||
|
||||
def create_param(self, envs: Environments) -> float:
|
||||
def create_value(self, envs: ContinuousEnvironments) -> float:
|
||||
envs.get_type().assert_continuous(self)
|
||||
envs: ContinuousEnvironments
|
||||
return envs.max_action * self.value
|
||||
|
||||
|
||||
class MaxActionScaled(FloatEnvValueFactoryMaxActionScaled):
|
||||
pass
|
||||
|
@ -13,7 +13,7 @@ class LRSchedulerFactory(ABC):
|
||||
pass
|
||||
|
||||
|
||||
class LinearLRSchedulerFactory(LRSchedulerFactory):
|
||||
class LRSchedulerFactoryLinear(LRSchedulerFactory):
|
||||
def __init__(self, sampling_config: RLSamplingConfig):
|
||||
self.sampling_config = sampling_config
|
||||
|
||||
|
@ -10,7 +10,7 @@ class NoiseFactory(ABC):
|
||||
pass
|
||||
|
||||
|
||||
class MaxActionScaledGaussianNoiseFactory(NoiseFactory):
|
||||
class NoiseFactoryMaxActionScaledGaussian(NoiseFactory):
|
||||
"""Factory for Gaussian noise where the standard deviation is a fraction of the maximum action value.
|
||||
|
||||
This factory can only be applied to continuous action spaces.
|
||||
@ -23,3 +23,7 @@ class MaxActionScaledGaussianNoiseFactory(NoiseFactory):
|
||||
envs.get_type().assert_continuous(self)
|
||||
envs: ContinuousEnvironments
|
||||
return GaussianNoise(sigma=envs.max_action * self.std_fraction)
|
||||
|
||||
|
||||
class MaxActionScaledGaussian(NoiseFactoryMaxActionScaledGaussian):
|
||||
pass
|
||||
|
@ -9,7 +9,7 @@ from tianshou.highlevel.env import Environments
|
||||
from tianshou.highlevel.module import ModuleOpt, TDevice
|
||||
from tianshou.highlevel.optim import OptimizerFactory
|
||||
from tianshou.highlevel.params.alpha import AutoAlphaFactory
|
||||
from tianshou.highlevel.params.env_param import FloatEnvParamFactory
|
||||
from tianshou.highlevel.params.env_param import EnvValueFactory, FloatEnvValueFactory
|
||||
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
|
||||
from tianshou.highlevel.params.noise import NoiseFactory
|
||||
from tianshou.utils import MultipleLRSchedulers
|
||||
@ -155,8 +155,8 @@ class ParamTransformerFloatEnvParamFactory(ParamTransformer):
|
||||
|
||||
def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None:
|
||||
value = kwargs[self.key]
|
||||
if isinstance(value, FloatEnvParamFactory):
|
||||
kwargs[self.key] = value.create_param(data.envs)
|
||||
if isinstance(value, EnvValueFactory):
|
||||
kwargs[self.key] = value.create_value(data.envs)
|
||||
|
||||
|
||||
class ITransformableParams(ABC):
|
||||
@ -268,13 +268,14 @@ class TD3Params(Params, ParamsMixinActorAndDualCritics):
|
||||
tau: float = 0.005
|
||||
gamma: float = 0.99
|
||||
exploration_noise: BaseNoise | Literal["default"] | NoiseFactory | None = "default"
|
||||
policy_noise: float | FloatEnvParamFactory = 0.2
|
||||
noise_clip: float | FloatEnvParamFactory = 0.5
|
||||
policy_noise: float | FloatEnvValueFactory = 0.2
|
||||
noise_clip: float | FloatEnvValueFactory = 0.5
|
||||
update_actor_freq: int = 2
|
||||
estimation_step: int = 1
|
||||
action_scaling: bool = True
|
||||
action_bound_method: Literal["clip"] | None = "clip"
|
||||
|
||||
# TODO change to stateless variant
|
||||
def __post_init__(self):
|
||||
ParamsMixinActorAndDualCritics.__post_init__(self)
|
||||
self._add_transformer(ParamTransformerNoiseFactory("exploration_noise"))
|
||||
|
Loading…
x
Reference in New Issue
Block a user