from abc import ABC, abstractmethod from dataclasses import asdict, dataclass from typing import Any, Literal, Protocol import torch from tianshou.exploration import BaseNoise from tianshou.highlevel.env import Environments from tianshou.highlevel.module.core import TDevice from tianshou.highlevel.module.module_opt import ModuleOpt from tianshou.highlevel.optim import OptimizerFactory from tianshou.highlevel.params.alpha import AutoAlphaFactory from tianshou.highlevel.params.dist_fn import ( DistributionFunctionFactory, DistributionFunctionFactoryDefault, TDistributionFunction, ) from tianshou.highlevel.params.env_param import EnvValueFactory, FloatEnvValueFactory from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory from tianshou.highlevel.params.noise import NoiseFactory from tianshou.utils import MultipleLRSchedulers @dataclass(kw_only=True) class ParamTransformerData: """Holds data that can be used by `ParamTransformer` instances to perform their transformation. The representation contains the superset of all data items that are required by different types of agent factories. An agent factory is expected to set only the attributes that are relevant to its parameters. """ envs: Environments device: TDevice optim_factory: OptimizerFactory optim: torch.optim.Optimizer | None = None """the single optimizer for the case where there is just one""" actor: ModuleOpt | None = None critic1: ModuleOpt | None = None critic2: ModuleOpt | None = None class ParamTransformer(ABC): """Transforms one or more parameters from the representation used by the high-level API to the representation required by the (low-level) policy implementation. It operates directly on a dictionary of keyword arguments, which is initially generated from the parameter dataclass (subclass of `Params`). """ @abstractmethod def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None: pass @staticmethod def get(d: dict[str, Any], key: str, drop: bool = False) -> Any: value = d[key] if drop: del d[key] return value class ParamTransformerDrop(ParamTransformer): def __init__(self, *keys: str): self.keys = keys def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None: for k in self.keys: del kwargs[k] class ParamTransformerChangeValue(ParamTransformer): def __init__(self, key: str): self.key = key def transform(self, params: dict[str, Any], data: ParamTransformerData): params[self.key] = self.change_value(params[self.key], data) @abstractmethod def change_value(self, value: Any, data: ParamTransformerData) -> Any: pass class ParamTransformerLRScheduler(ParamTransformer): """Transforms a key containing a learning rate scheduler factory (removed) into a key containing a learning rate scheduler (added) for the data member `optim`. """ def __init__(self, key_scheduler_factory: str, key_scheduler: str): self.key_scheduler_factory = key_scheduler_factory self.key_scheduler = key_scheduler def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None: assert data.optim is not None factory: LRSchedulerFactory | None = self.get(params, self.key_scheduler_factory, drop=True) params[self.key_scheduler] = ( factory.create_scheduler(data.optim) if factory is not None else None ) class ParamTransformerMultiLRScheduler(ParamTransformer): """Transforms several scheduler factories into a single scheduler, which may be a MultipleLRSchedulers instance if more than one factory is indeed given. """ def __init__(self, optim_key_list: list[tuple[torch.optim.Optimizer, str]], key_scheduler: str): """:param optim_key_list: a list of tuples (optimizer, key of learning rate factory) :param key_scheduler: the key under which to store the resulting learning rate scheduler """ self.optim_key_list = optim_key_list self.key_scheduler = key_scheduler def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None: lr_schedulers = [] for optim, lr_scheduler_factory_key in self.optim_key_list: lr_scheduler_factory: LRSchedulerFactory | None = self.get( params, lr_scheduler_factory_key, drop=True, ) if lr_scheduler_factory is not None: lr_schedulers.append(lr_scheduler_factory.create_scheduler(optim)) match len(lr_schedulers): case 0: lr_scheduler = None case 1: lr_scheduler = lr_schedulers[0] case _: lr_scheduler = MultipleLRSchedulers(*lr_schedulers) params[self.key_scheduler] = lr_scheduler class ParamTransformerActorAndCriticLRScheduler(ParamTransformer): def __init__( self, key_scheduler_factory_actor: str, key_scheduler_factory_critic: str, key_scheduler: str, ): self.key_factory_actor = key_scheduler_factory_actor self.key_factory_critic = key_scheduler_factory_critic self.key_scheduler = key_scheduler def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None: transformer = ParamTransformerMultiLRScheduler( [ (data.actor.optim, self.key_factory_actor), (data.critic1.optim, self.key_factory_critic), ], self.key_scheduler, ) transformer.transform(params, data) class ParamTransformerActorDualCriticsLRScheduler(ParamTransformer): def __init__( self, key_scheduler_factory_actor: str, key_scheduler_factory_critic1: str, key_scheduler_factory_critic2: str, key_scheduler: str, ): self.key_factory_actor = key_scheduler_factory_actor self.key_factory_critic1 = key_scheduler_factory_critic1 self.key_factory_critic2 = key_scheduler_factory_critic2 self.key_scheduler = key_scheduler def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None: transformer = ParamTransformerMultiLRScheduler( [ (data.actor.optim, self.key_factory_actor), (data.critic1.optim, self.key_factory_critic1), (data.critic2.optim, self.key_factory_critic2), ], self.key_scheduler, ) transformer.transform(params, data) class ParamTransformerAutoAlpha(ParamTransformer): def __init__(self, key: str): self.key = key def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None: alpha = self.get(kwargs, self.key) if isinstance(alpha, AutoAlphaFactory): kwargs[self.key] = alpha.create_auto_alpha(data.envs, data.optim_factory, data.device) class ParamTransformerNoiseFactory(ParamTransformer): def __init__(self, key: str): self.key = key def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None: value = params[self.key] if isinstance(value, NoiseFactory): params[self.key] = value.create_noise(data.envs) class ParamTransformerFloatEnvParamFactory(ParamTransformer): def __init__(self, key: str): self.key = key def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None: value = kwargs[self.key] if isinstance(value, EnvValueFactory): kwargs[self.key] = value.create_value(data.envs) class ParamTransformerDistributionFunction(ParamTransformer): def __init__(self, key: str): self.key = key def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None: value = kwargs[self.key] if value == "default": kwargs[self.key] = DistributionFunctionFactoryDefault().create_dist_fn(data.envs) elif isinstance(value, DistributionFunctionFactory): kwargs[self.key] = value.create_dist_fn(data.envs) class ParamTransformerActionScaling(ParamTransformerChangeValue): def change_value(self, value: Any, data: ParamTransformerData) -> Any: if value == "default": return data.envs.get_type().is_continuous() else: return value class GetParamTransformersProtocol(Protocol): def _get_param_transformers(self) -> list[ParamTransformer]: pass @dataclass class Params(GetParamTransformersProtocol): def create_kwargs(self, data: ParamTransformerData) -> dict[str, Any]: params = asdict(self) for transformer in self._get_param_transformers(): transformer.transform(params, data) return params def _get_param_transformers(self) -> list[ParamTransformer]: return [] @dataclass class ParamsMixinLearningRateWithScheduler(GetParamTransformersProtocol): lr: float = 1e-3 lr_scheduler_factory: LRSchedulerFactory | None = None def _get_param_transformers(self): return [ ParamTransformerDrop("lr"), ParamTransformerLRScheduler("lr_scheduler_factory", "lr_scheduler"), ] @dataclass class ParamsMixinActorAndCritic(GetParamTransformersProtocol): actor_lr: float = 1e-3 critic_lr: float = 1e-3 actor_lr_scheduler_factory: LRSchedulerFactory | None = None critic_lr_scheduler_factory: LRSchedulerFactory | None = None def _get_param_transformers(self): return [ ParamTransformerDrop("actor_lr", "critic_lr"), ParamTransformerActorAndCriticLRScheduler( "actor_lr_scheduler_factory", "critic_lr_scheduler_factory", "lr_scheduler", ), ] @dataclass class PGParams(Params): """Config of general policy-gradient algorithms.""" discount_factor: float = 0.99 reward_normalization: bool = False deterministic_eval: bool = False action_scaling: bool | Literal["default"] = "default" """whether to apply action scaling; when set to "default", it will be enabled for continuous action spaces""" action_bound_method: Literal["clip", "tanh"] | None = "clip" def _get_param_transformers(self) -> list[ParamTransformer]: transformers = super()._get_param_transformers() transformers.append(ParamTransformerActionScaling("action_scaling")) return transformers @dataclass class A2CParams(PGParams, ParamsMixinLearningRateWithScheduler): vf_coef: float = 0.5 ent_coef: float = 0.01 max_grad_norm: float | None = None gae_lambda: float = 0.95 max_batchsize: int = 256 dist_fn: TDistributionFunction | DistributionFunctionFactory | Literal["default"] = "default" def _get_param_transformers(self) -> list[ParamTransformer]: transformers = super()._get_param_transformers() transformers.extend(ParamsMixinLearningRateWithScheduler._get_param_transformers(self)) transformers.append(ParamTransformerDistributionFunction("dist_fn")) return transformers @dataclass class PPOParams(A2CParams): """PPO specific config.""" eps_clip: float = 0.2 dual_clip: float | None = None value_clip: bool = False advantage_normalization: bool = True recompute_advantage: bool = False @dataclass class ParamsMixinActorAndDualCritics(GetParamTransformersProtocol): actor_lr: float = 1e-3 critic1_lr: float = 1e-3 critic2_lr: float = 1e-3 actor_lr_scheduler_factory: LRSchedulerFactory | None = None critic1_lr_scheduler_factory: LRSchedulerFactory | None = None critic2_lr_scheduler_factory: LRSchedulerFactory | None = None def _get_param_transformers(self): return [ ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"), ParamTransformerActorDualCriticsLRScheduler( "actor_lr_scheduler_factory", "critic1_lr_scheduler_factory", "critic2_lr_scheduler_factory", "lr_scheduler", ), ] @dataclass class SACParams(Params, ParamsMixinActorAndDualCritics): tau: float = 0.005 gamma: float = 0.99 alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] | AutoAlphaFactory = 0.2 estimation_step: int = 1 exploration_noise: BaseNoise | Literal["default"] | NoiseFactory | None = None deterministic_eval: bool = True action_scaling: bool = True action_bound_method: Literal["clip"] | None = "clip" def _get_param_transformers(self): transformers = super()._get_param_transformers() transformers.extend(ParamsMixinActorAndDualCritics._get_param_transformers(self)) transformers.append(ParamTransformerAutoAlpha("alpha")) transformers.append(ParamTransformerNoiseFactory("exploration_noise")) return transformers @dataclass class DQNParams(Params, ParamsMixinLearningRateWithScheduler): discount_factor: float = 0.99 estimation_step: int = 1 target_update_freq: int = 0 reward_normalization: bool = False is_double: bool = True clip_loss_grad: bool = False def _get_param_transformers(self): transformers = super()._get_param_transformers() transformers.extend(ParamsMixinLearningRateWithScheduler._get_param_transformers(self)) return transformers @dataclass class DDPGParams(Params, ParamsMixinActorAndCritic): tau: float = 0.005 gamma: float = 0.99 exploration_noise: BaseNoise | Literal["default"] | NoiseFactory | None = "default" estimation_step: int = 1 action_scaling: bool = True action_bound_method: Literal["clip"] | None = "clip" def _get_param_transformers(self): transformers = super()._get_param_transformers() transformers.extend(ParamsMixinActorAndCritic._get_param_transformers(self)) transformers.append(ParamTransformerNoiseFactory("exploration_noise")) return transformers @dataclass class TD3Params(Params, ParamsMixinActorAndDualCritics): tau: float = 0.005 gamma: float = 0.99 exploration_noise: BaseNoise | Literal["default"] | NoiseFactory | None = "default" policy_noise: float | FloatEnvValueFactory = 0.2 noise_clip: float | FloatEnvValueFactory = 0.5 update_actor_freq: int = 2 estimation_step: int = 1 action_scaling: bool = True action_bound_method: Literal["clip"] | None = "clip" def _get_param_transformers(self): transformers = super()._get_param_transformers() transformers.extend(ParamsMixinActorAndDualCritics._get_param_transformers(self)) transformers.append(ParamTransformerNoiseFactory("exploration_noise")) transformers.append(ParamTransformerFloatEnvParamFactory("policy_noise")) transformers.append(ParamTransformerFloatEnvParamFactory("noise_clip")) return transformers