"""Naming conventions:

* Use the prefix convention (subclasses have superclass names as a prefix) to
  facilitate the discoverability of relevant classes via IDE autocompletion.
* Use dual naming, adding an alternative concise name that omits the precise OO
  semantics and retains only the essential part of the name (which can be more
  pleasing to users not accustomed to convoluted OO naming).
"""
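# Example of these conventions (the alias is hypothetical, for illustration only):
# `ParamTransformerDrop` below carries the `ParamTransformer` prefix, so typing the
# superclass name in an IDE surfaces all transformers; a concise dual name could be
# provided via a module-level alias such as `Drop = ParamTransformerDrop`.
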
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass, field
from typing import Any, Literal

import torch

from tianshou.exploration import BaseNoise
from tianshou.highlevel.env import Environments
from tianshou.highlevel.module import ModuleOpt, TDevice
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.highlevel.params.alpha import AutoAlphaFactory
from tianshou.highlevel.params.env_param import EnvValueFactory, FloatEnvValueFactory
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
from tianshou.highlevel.params.noise import NoiseFactory
from tianshou.utils import MultipleLRSchedulers


@dataclass(kw_only=True)
class ParamTransformerData:
    """Holds data that can be used by `ParamTransformer` instances to perform their transformation.

    The representation contains the superset of all data items that are required by
    the different types of agent factories. An agent factory is expected to set only
    the attributes that are relevant to its parameters.
    """

    envs: Environments
    device: TDevice
    optim_factory: OptimizerFactory
    optim: torch.optim.Optimizer | None = None
    """the single optimizer for the case where there is just one"""
    actor: ModuleOpt | None = None
    critic1: ModuleOpt | None = None
    critic2: ModuleOpt | None = None


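# Usage sketch (illustrative, not part of the original module): an agent factory
# constructs this object once, setting only what its parameters need, and passes it
# to `Params.create_kwargs` (defined below). The names `envs`, `device`,
# `optim_factory` and `optim` are assumed to exist in the factory's context:
#
#     data = ParamTransformerData(
#         envs=envs,
#         device=device,
#         optim_factory=optim_factory,
#         optim=optim,  # set only when there is a single optimizer
#     )

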
class ParamTransformer(ABC):
    """Base class for transformations applied to a dictionary of parameters in place,
    using the contextual data provided in a `ParamTransformerData` instance.
    """

    @abstractmethod
    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
        pass

    @staticmethod
    def get(d: dict[str, Any], key: str, drop: bool = False) -> Any:
        """Returns the value stored under the given key, removing it from the dictionary if `drop` is True."""
        value = d[key]
        if drop:
            del d[key]
        return value


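# Illustrative sketch (hypothetical transformer, not part of the original module):
# the simplest useful transformer renames a key, consuming the old key via
# `get(..., drop=True)`:
#
#     class ParamTransformerRename(ParamTransformer):
#         def __init__(self, key_old: str, key_new: str):
#             self.key_old = key_old
#             self.key_new = key_new
#
#         def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
#             params[self.key_new] = self.get(params, self.key_old, drop=True)

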
class ParamTransformerDrop(ParamTransformer):
    """Removes the given keys from the dictionary of parameters."""

    def __init__(self, *keys: str):
        self.keys = keys

    def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None:
        for k in self.keys:
            del kwargs[k]


class ParamTransformerLRScheduler(ParamTransformer):
    """Transforms a key containing a learning rate scheduler factory (removed) into a key
    containing a learning rate scheduler (added) for the data member `optim`.
    """

    def __init__(self, key_scheduler_factory: str, key_scheduler: str):
        self.key_scheduler_factory = key_scheduler_factory
        self.key_scheduler = key_scheduler

    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
        assert data.optim is not None
        factory: LRSchedulerFactory | None = self.get(params, self.key_scheduler_factory, drop=True)
        params[self.key_scheduler] = (
            factory.create_scheduler(data.optim) if factory is not None else None
        )


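# Sketch of the effect (assuming `factory` is an `LRSchedulerFactory` instance and
# `data.optim` is set):
#
#     params = {"lr_scheduler_factory": factory}
#     ParamTransformerLRScheduler("lr_scheduler_factory", "lr_scheduler").transform(params, data)
#     # params is now {"lr_scheduler": factory.create_scheduler(data.optim)}

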
class ParamTransformerMultiLRScheduler(ParamTransformer):
    """Transforms several scheduler factories into a single scheduler, which may be a
    `MultipleLRSchedulers` instance if more than one factory is indeed given.
    """

    def __init__(self, optim_key_list: list[tuple[torch.optim.Optimizer, str]], key_scheduler: str):
        """:param optim_key_list: a list of tuples (optimizer, key of learning rate factory)
        :param key_scheduler: the key under which to store the resulting learning rate scheduler
        """
        self.optim_key_list = optim_key_list
        self.key_scheduler = key_scheduler

    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
        lr_schedulers = []
        for optim, lr_scheduler_factory_key in self.optim_key_list:
            lr_scheduler_factory: LRSchedulerFactory | None = self.get(
                params,
                lr_scheduler_factory_key,
                drop=True,
            )
            if lr_scheduler_factory is not None:
                lr_schedulers.append(lr_scheduler_factory.create_scheduler(optim))
        match len(lr_schedulers):
            case 0:
                lr_scheduler = None
            case 1:
                lr_scheduler = lr_schedulers[0]
            case _:
                lr_scheduler = MultipleLRSchedulers(*lr_schedulers)
        params[self.key_scheduler] = lr_scheduler


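# The value stored under `key_scheduler` is thus None (no factories given), a single
# scheduler, or an aggregate of all created schedulers, e.g. (illustrative):
#
#     MultipleLRSchedulers(actor_scheduler, critic1_scheduler, critic2_scheduler)

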
class ParamTransformerActorDualCriticsLRScheduler(ParamTransformer):
    """Combines the scheduler factories for an actor and two critics into a single
    scheduler stored under the given key (see `ParamTransformerMultiLRScheduler`).
    """

    def __init__(
        self,
        key_scheduler_factory_actor: str,
        key_scheduler_factory_critic1: str,
        key_scheduler_factory_critic2: str,
        key_scheduler: str,
    ):
        self.key_factory_actor = key_scheduler_factory_actor
        self.key_factory_critic1 = key_scheduler_factory_critic1
        self.key_factory_critic2 = key_scheduler_factory_critic2
        self.key_scheduler = key_scheduler

    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
        assert data.actor is not None and data.critic1 is not None and data.critic2 is not None
        transformer = ParamTransformerMultiLRScheduler(
            [
                (data.actor.optim, self.key_factory_actor),
                (data.critic1.optim, self.key_factory_critic1),
                (data.critic2.optim, self.key_factory_critic2),
            ],
            self.key_scheduler,
        )
        transformer.transform(params, data)


class ParamTransformerAutoAlpha(ParamTransformer):
    """Transforms an `AutoAlphaFactory` value stored under the given key into the
    auto-tuned alpha it creates; plain float values are left untouched.
    """

    def __init__(self, key: str):
        self.key = key

    def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None:
        alpha = self.get(kwargs, self.key)
        if isinstance(alpha, AutoAlphaFactory):
            kwargs[self.key] = alpha.create_auto_alpha(data.envs, data.optim_factory, data.device)


class ParamTransformerNoiseFactory(ParamTransformer):
    """Transforms a `NoiseFactory` value stored under the given key into the noise
    object it creates for the given environments; other values are left untouched.
    """

    def __init__(self, key: str):
        self.key = key

    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
        value = params[self.key]
        if isinstance(value, NoiseFactory):
            params[self.key] = value.create_noise(data.envs)


class ParamTransformerFloatEnvParamFactory(ParamTransformer):
    """Transforms an `EnvValueFactory` value stored under the given key into the
    concrete value it creates for the given environments; plain values are left untouched.
    """

    def __init__(self, key: str):
        self.key = key

    def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None:
        value = kwargs[self.key]
        if isinstance(value, EnvValueFactory):
            kwargs[self.key] = value.create_value(data.envs)


class ITransformableParams(ABC):
    """Interface for parameter objects to which `ParamTransformer` instances can be added."""

    @abstractmethod
    def _add_transformer(self, transformer: ParamTransformer) -> None:
        pass


@dataclass
class Params(ITransformableParams):
    """Base class for all parameter objects, which transform themselves into a dictionary
    of keyword arguments by applying the registered `ParamTransformer` instances.
    """

    _param_transformers: list[ParamTransformer] = field(
        init=False,
        default_factory=list,
        repr=False,
    )

    def _add_transformer(self, transformer: ParamTransformer) -> None:
        self._param_transformers.append(transformer)

    def create_kwargs(self, data: ParamTransformerData) -> dict[str, Any]:
        """Creates the keyword arguments by converting the dataclass fields to a dictionary
        and applying all registered transformers to it.
        """
        params = asdict(self)
        for transformer in self._param_transformers:
            transformer.transform(params, data)
        del params["_param_transformers"]
        return params


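# Usage sketch (hypothetical subclass, for illustration only): fields of a `Params`
# subclass become the keyword arguments for the policy being constructed, after all
# registered transformers have run.
#
#     @dataclass
#     class MyParams(Params):
#         lr: float = 1e-3
#
#     kwargs = MyParams().create_kwargs(data)  # -> {"lr": 0.001}

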
@dataclass
class ParamsMixinLearningRateWithScheduler(ITransformableParams, ABC):
    """Mixin for parameter objects with a single learning rate and optional scheduler factory."""

    lr: float = 1e-3
    lr_scheduler_factory: LRSchedulerFactory | None = None

    def __post_init__(self):
        self._add_transformer(ParamTransformerDrop("lr"))
        self._add_transformer(ParamTransformerLRScheduler("lr_scheduler_factory", "lr_scheduler"))


@dataclass
class PGParams(Params):
    """Config of general policy-gradient algorithms."""

    discount_factor: float = 0.99
    reward_normalization: bool = False
    deterministic_eval: bool = False
    action_scaling: bool = True
    action_bound_method: Literal["clip", "tanh"] | None = "clip"


@dataclass
class A2CParams(PGParams):
    """A2C-specific config."""

    vf_coef: float = 0.5
    ent_coef: float = 0.01
    max_grad_norm: float | None = None
    gae_lambda: float = 0.95
    max_batchsize: int = 256


@dataclass
class PPOParams(A2CParams, ParamsMixinLearningRateWithScheduler):
    """PPO-specific config."""

    eps_clip: float = 0.2
    dual_clip: float | None = None
    value_clip: bool = False
    advantage_normalization: bool = True
    recompute_advantage: bool = False


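# Construction sketch (illustrative values only): `lr` and `lr_scheduler_factory` come
# from the mixin, the remaining fields from the PG/A2C/PPO hierarchy.
#
#     params = PPOParams(
#         lr=3e-4,
#         eps_clip=0.2,
#         advantage_normalization=True,
#     )

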
@dataclass
class ParamsMixinActorAndDualCritics(ITransformableParams, ABC):
    """Mixin for parameter objects with learning rates and optional scheduler factories
    for an actor and two critics.
    """

    actor_lr: float = 1e-3
    critic1_lr: float = 1e-3
    critic2_lr: float = 1e-3
    actor_lr_scheduler_factory: LRSchedulerFactory | None = None
    critic1_lr_scheduler_factory: LRSchedulerFactory | None = None
    critic2_lr_scheduler_factory: LRSchedulerFactory | None = None

    def __post_init__(self):
        self._add_transformer(ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"))
        self._add_transformer(
            ParamTransformerActorDualCriticsLRScheduler(
                "actor_lr_scheduler_factory",
                "critic1_lr_scheduler_factory",
                "critic2_lr_scheduler_factory",
                "lr_scheduler",
            ),
        )


@dataclass
class SACParams(Params, ParamsMixinActorAndDualCritics):
    """SAC-specific config."""

    tau: float = 0.005
    gamma: float = 0.99
    alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] | AutoAlphaFactory = 0.2
    estimation_step: int = 1
    exploration_noise: BaseNoise | Literal["default"] | NoiseFactory | None = None
    deterministic_eval: bool = True
    action_scaling: bool = True
    action_bound_method: Literal["clip"] | None = "clip"

    def __post_init__(self):
        ParamsMixinActorAndDualCritics.__post_init__(self)
        self._add_transformer(ParamTransformerAutoAlpha("alpha"))
        self._add_transformer(ParamTransformerNoiseFactory("exploration_noise"))


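# Construction sketch (illustrative; `auto_alpha_factory` stands for some concrete
# `AutoAlphaFactory` implementation available in the surrounding code): factory-typed
# values are replaced by concrete objects when `create_kwargs` is called.
#
#     params = SACParams(
#         gamma=0.98,
#         actor_lr=1e-3,
#         alpha=auto_alpha_factory,
#     )

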
@dataclass
class TD3Params(Params, ParamsMixinActorAndDualCritics):
    """TD3-specific config."""

    tau: float = 0.005
    gamma: float = 0.99
    exploration_noise: BaseNoise | Literal["default"] | NoiseFactory | None = "default"
    policy_noise: float | FloatEnvValueFactory = 0.2
    noise_clip: float | FloatEnvValueFactory = 0.5
    update_actor_freq: int = 2
    estimation_step: int = 1
    action_scaling: bool = True
    action_bound_method: Literal["clip"] | None = "clip"

    # TODO change to stateless variant
    def __post_init__(self):
        ParamsMixinActorAndDualCritics.__post_init__(self)
        self._add_transformer(ParamTransformerNoiseFactory("exploration_noise"))
        self._add_transformer(ParamTransformerFloatEnvParamFactory("policy_noise"))
        self._add_transformer(ParamTransformerFloatEnvParamFactory("noise_clip"))