diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py
index b435f7b..f079bdb 100644
--- a/tianshou/highlevel/agent.py
+++ b/tianshou/highlevel/agent.py
@@ -1,7 +1,6 @@
 import os
 from abc import ABC, abstractmethod
 from collections.abc import Callable
-from typing import Any
 
 import torch
 
@@ -19,12 +18,8 @@ from tianshou.highlevel.module import (
     TDevice,
 )
 from tianshou.highlevel.optim import OptimizerFactory
-from tianshou.highlevel.params.alpha import AutoAlphaFactory
-from tianshou.highlevel.params.env_param import FloatEnvParamFactory
-from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
-from tianshou.highlevel.params.noise import NoiseFactory
 from tianshou.highlevel.params.policy_params import (
-    ParamTransformer,
+    ParamTransformerData,
     PPOParams,
     SACParams,
     TD3Params,
@@ -32,7 +27,6 @@ from tianshou.highlevel.params.policy_params import (
 from tianshou.policy import BasePolicy, PPOPolicy, SACPolicy, TD3Policy
 from tianshou.policy.modelfree.pg import TDistParams
 from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
-from tianshou.utils import MultipleLRSchedulers
 from tianshou.utils.net.common import ActorCritic
 
 CHECKPOINT_DICT_KEY_MODEL = "model"
@@ -145,26 +139,6 @@ class OffpolicyAgentFactory(AgentFactory, ABC):
         )
 
 
-class ParamTransformerDrop(ParamTransformer):
-    def __init__(self, *keys: str):
-        self.keys = keys
-
-    def transform(self, kwargs: dict[str, Any]) -> None:
-        for k in self.keys:
-            del kwargs[k]
-
-
-class ParamTransformerLRScheduler(ParamTransformer):
-    def __init__(self, optim: torch.optim.Optimizer):
-        self.optim = optim
-
-    def transform(self, kwargs: dict[str, Any]) -> None:
-        factory: LRSchedulerFactory | None = self.get(kwargs, "lr_scheduler_factory", drop=True)
-        kwargs["lr_scheduler"] = (
-            factory.create_scheduler(self.optim) if factory is not None else None
-        )
-
-
 class _ActorMixin:
     def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory):
         self.actor_module_opt_factory = ActorModuleOptFactory(actor_factory, optim_factory)
@@ -269,8 +243,12 @@ class PPOAgentFactory(OnpolicyAgentFactory, _ActorCriticMixin):
     def create_policy(self, envs: Environments, device: TDevice) -> PPOPolicy:
         actor_critic = self.create_actor_critic_module_opt(envs, device, self.params.lr)
         kwargs = self.params.create_kwargs(
-            ParamTransformerDrop("lr"),
-            ParamTransformerLRScheduler(actor_critic.optim),
+            ParamTransformerData(
+                envs=envs,
+                device=device,
+                optim_factory=self.optim_factory,
+                optim=actor_critic.optim,
+            ),
         )
         return PPOPolicy(
             actor=actor_critic.actor,
@@ -282,43 +260,6 @@ class PPOAgentFactory(OnpolicyAgentFactory, _ActorCriticMixin):
         )
 
 
-class ParamTransformerAlpha(ParamTransformer):
-    def __init__(self, envs: Environments, optim_factory: OptimizerFactory, device: TDevice):
-        self.envs = envs
-        self.optim_factory = optim_factory
-        self.device = device
-
-    def transform(self, kwargs: dict[str, Any]) -> None:
-        key = "alpha"
-        alpha = self.get(kwargs, key)
-        if isinstance(alpha, AutoAlphaFactory):
-            kwargs[key] = alpha.create_auto_alpha(self.envs, self.optim_factory, self.device)
-
-
-class ParamTransformerMultiLRScheduler(ParamTransformer):
-    def __init__(self, optim_key_list: list[tuple[torch.optim.Optimizer, str]]):
-        self.optim_key_list = optim_key_list
-
-    def transform(self, kwargs: dict[str, Any]) -> None:
-        lr_schedulers = []
-        for optim, lr_scheduler_factory_key in self.optim_key_list:
-            lr_scheduler_factory: LRSchedulerFactory | None = self.get(
-                kwargs,
-                lr_scheduler_factory_key,
-                drop=True,
-            )
-            if lr_scheduler_factory is not None:
-                lr_schedulers.append(lr_scheduler_factory.create_scheduler(optim))
-        match len(lr_schedulers):
-            case 0:
-                lr_scheduler = None
-            case 1:
-                lr_scheduler = lr_schedulers[0]
-            case _:
-                lr_scheduler = MultipleLRSchedulers(*lr_schedulers)
-        kwargs["lr_scheduler"] = lr_scheduler
-
-
 class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
     def __init__(
         self,
@@ -346,15 +287,14 @@ class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
         critic1 = self.create_critic_module_opt(envs, device, self.params.critic1_lr)
         critic2 = self.create_critic2_module_opt(envs, device, self.params.critic2_lr)
         kwargs = self.params.create_kwargs(
-            ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"),
-            ParamTransformerMultiLRScheduler(
-                [
-                    (actor.optim, "actor_lr_scheduler_factory"),
-                    (critic1.optim, "critic1_lr_scheduler_factory"),
-                    (critic2.optim, "critic2_lr_scheduler_factory"),
-                ],
+            ParamTransformerData(
+                envs=envs,
+                device=device,
+                optim_factory=self.optim_factory,
+                actor=actor,
+                critic1=critic1,
+                critic2=critic2,
             ),
-            ParamTransformerAlpha(envs, optim_factory=self.optim_factory, device=device),
         )
         return SACPolicy(
             actor=actor.module,
@@ -369,28 +309,6 @@ class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
         )
 
 
-class ParamTransformerNoiseFactory(ParamTransformer):
-    def __init__(self, key: str, envs: Environments):
-        self.key = key
-        self.envs = envs
-
-    def transform(self, kwargs: dict[str, Any]) -> None:
-        value = kwargs[self.key]
-        if isinstance(value, NoiseFactory):
-            kwargs[self.key] = value.create_noise(self.envs)
-
-
-class ParamTransformerFloatEnvParamFactory(ParamTransformer):
-    def __init__(self, key: str, envs: Environments):
-        self.key = key
-        self.envs = envs
-
-    def transform(self, kwargs: dict[str, Any]) -> None:
-        value = kwargs[self.key]
-        if isinstance(value, FloatEnvParamFactory):
-            kwargs[self.key] = value.create_param(self.envs)
-
-
 class TD3AgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
     def __init__(
         self,
@@ -418,17 +336,14 @@ class TD3AgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
         critic1 = self.create_critic_module_opt(envs, device, self.params.critic1_lr)
         critic2 = self.create_critic2_module_opt(envs, device, self.params.critic2_lr)
         kwargs = self.params.create_kwargs(
-            ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"),
-            ParamTransformerMultiLRScheduler(
-                [
-                    (actor.optim, "actor_lr_scheduler_factory"),
-                    (critic1.optim, "critic1_lr_scheduler_factory"),
-                    (critic2.optim, "critic2_lr_scheduler_factory"),
-                ],
+            ParamTransformerData(
+                envs=envs,
+                device=device,
+                optim_factory=self.optim_factory,
+                actor=actor,
+                critic1=critic1,
+                critic2=critic2,
             ),
-            ParamTransformerNoiseFactory("exploration_noise", envs),
-            ParamTransformerFloatEnvParamFactory("policy_noise", envs),
-            ParamTransformerFloatEnvParamFactory("noise_clip", envs),
         )
         return TD3Policy(
             actor=actor.module,
diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py
index 03a0417..62700e7 100644
--- a/tianshou/highlevel/params/policy_params.py
+++ b/tianshou/highlevel/params/policy_params.py
@@ -1,19 +1,41 @@
 from abc import ABC, abstractmethod
-from dataclasses import asdict, dataclass
+from dataclasses import asdict, dataclass, field
 from typing import Any, Literal
 
 import torch
 
 from tianshou.exploration import BaseNoise
+from tianshou.highlevel.env import Environments
+from tianshou.highlevel.module import ModuleOpt, TDevice
+from tianshou.highlevel.optim import OptimizerFactory
 from tianshou.highlevel.params.alpha import AutoAlphaFactory
 from tianshou.highlevel.params.env_param import FloatEnvParamFactory
 from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
 from tianshou.highlevel.params.noise import NoiseFactory
+from tianshou.utils import MultipleLRSchedulers
+
+
+@dataclass(kw_only=True)
+class ParamTransformerData:
+    """Holds data that can be used by `ParamTransformer` instances to perform their transformation.
+
+    The representation contains the superset of all data items that are required by different types of agent factories.
+    An agent factory is expected to set only the attributes that are relevant to its parameters.
+    """
+
+    envs: Environments
+    device: TDevice
+    optim_factory: OptimizerFactory
+    optim: torch.optim.Optimizer | None = None
+    """the single optimizer for the case where there is just one"""
+    actor: ModuleOpt | None = None
+    critic1: ModuleOpt | None = None
+    critic2: ModuleOpt | None = None
 
 
 class ParamTransformer(ABC):
     @abstractmethod
-    def transform(self, kwargs: dict[str, Any]) -> None:
+    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
         pass
 
     @staticmethod
@@ -24,13 +46,152 @@ class ParamTransformer(ABC):
         return value
 
 
+class ParamTransformerDrop(ParamTransformer):
+    def __init__(self, *keys: str):
+        self.keys = keys
+
+    def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None:
+        for k in self.keys:
+            del kwargs[k]
+
+
+class ParamTransformerLRScheduler(ParamTransformer):
+    """Transforms a key containing a learning rate scheduler factory (removed) into a key containing
+    a learning rate scheduler (added) for the data member `optim`.
+    """
+
+    def __init__(self, key_scheduler_factory: str, key_scheduler: str):
+        self.key_scheduler_factory = key_scheduler_factory
+        self.key_scheduler = key_scheduler
+
+    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
+        assert data.optim is not None
+        factory: LRSchedulerFactory | None = self.get(params, self.key_scheduler_factory, drop=True)
+        params[self.key_scheduler] = (
+            factory.create_scheduler(data.optim) if factory is not None else None
+        )
+
+
+class ParamTransformerMultiLRScheduler(ParamTransformer):
+    """Transforms several scheduler factories into a single scheduler, which may be a MultipleLRSchedulers instance
+    if more than one factory is indeed given.
+    """
+
+    def __init__(self, optim_key_list: list[tuple[torch.optim.Optimizer, str]], key_scheduler: str):
+        """:param optim_key_list: a list of tuples (optimizer, key of learning rate factory)
+        :param key_scheduler: the key under which to store the resulting learning rate scheduler
+        """
+        self.optim_key_list = optim_key_list
+        self.key_scheduler = key_scheduler
+
+    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
+        lr_schedulers = []
+        for optim, lr_scheduler_factory_key in self.optim_key_list:
+            lr_scheduler_factory: LRSchedulerFactory | None = self.get(
+                params,
+                lr_scheduler_factory_key,
+                drop=True,
+            )
+            if lr_scheduler_factory is not None:
+                lr_schedulers.append(lr_scheduler_factory.create_scheduler(optim))
+        match len(lr_schedulers):
+            case 0:
+                lr_scheduler = None
+            case 1:
+                lr_scheduler = lr_schedulers[0]
+            case _:
+                lr_scheduler = MultipleLRSchedulers(*lr_schedulers)
+        params[self.key_scheduler] = lr_scheduler
+
+
+class ParamTransformerActorDualCriticsLRScheduler(ParamTransformer):
+    def __init__(
+        self,
+        key_scheduler_factory_actor: str,
+        key_scheduler_factory_critic1: str,
+        key_scheduler_factory_critic2: str,
+        key_scheduler: str,
+    ):
+        self.key_factory_actor = key_scheduler_factory_actor
+        self.key_factory_critic1 = key_scheduler_factory_critic1
+        self.key_factory_critic2 = key_scheduler_factory_critic2
+        self.key_scheduler = key_scheduler
+
+    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
+        transformer = ParamTransformerMultiLRScheduler(
+            [
+                (data.actor.optim, self.key_factory_actor),
+                (data.critic1.optim, self.key_factory_critic1),
+                (data.critic2.optim, self.key_factory_critic2),
+            ],
+            self.key_scheduler,
+        )
+        transformer.transform(params, data)
+
+
+class ParamTransformerAutoAlpha(ParamTransformer):
+    def __init__(self, key: str):
+        self.key = key
+
+    def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None:
+        alpha = self.get(kwargs, self.key)
+        if isinstance(alpha, AutoAlphaFactory):
+            kwargs[self.key] = alpha.create_auto_alpha(data.envs, data.optim_factory, data.device)
+
+
+class ParamTransformerNoiseFactory(ParamTransformer):
+    def __init__(self, key: str):
+        self.key = key
+
+    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
+        value = params[self.key]
+        if isinstance(value, NoiseFactory):
+            params[self.key] = value.create_noise(data.envs)
+
+
+class ParamTransformerFloatEnvParamFactory(ParamTransformer):
+    def __init__(self, key: str):
+        self.key = key
+
+    def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None:
+        value = kwargs[self.key]
+        if isinstance(value, FloatEnvParamFactory):
+            kwargs[self.key] = value.create_param(data.envs)
+
+
+class ITransformableParams(ABC):
+    @abstractmethod
+    def _add_transformer(self, transformer: ParamTransformer):
+        pass
+
+
 @dataclass
-class Params:
-    def create_kwargs(self, *transformers: ParamTransformer) -> dict[str, Any]:
-        d = asdict(self)
-        for transformer in transformers:
-            transformer.transform(d)
-        return d
+class Params(ITransformableParams):
+    _param_transformers: list[ParamTransformer] = field(
+        init=False,
+        default_factory=list,
+        repr=False,
+    )
+
+    def _add_transformer(self, transformer: ParamTransformer):
+        self._param_transformers.append(transformer)
+
+    def create_kwargs(self, data: ParamTransformerData) -> dict[str, Any]:
+        params = asdict(self)
+        for transformer in self._param_transformers:
+            transformer.transform(params, data)
+        del params["_param_transformers"]
params["_param_transformers"] + return params + + +@dataclass +class ParamsMixinLearningRateWithScheduler(ITransformableParams, ABC): + lr: float = 1e-3 + lr_scheduler_factory: LRSchedulerFactory | None = None + + def __post_init__(self): + self._add_transformer(ParamTransformerDrop("lr")) + self._add_transformer(ParamTransformerLRScheduler("lr_scheduler_factory", "lr_scheduler")) @dataclass @@ -54,7 +215,7 @@ class A2CParams(PGParams): @dataclass -class PPOParams(A2CParams): +class PPOParams(A2CParams, ParamsMixinLearningRateWithScheduler): """PPO specific config.""" eps_clip: float = 0.2 @@ -62,12 +223,10 @@ class PPOParams(A2CParams): value_clip: bool = False advantage_normalization: bool = True recompute_advantage: bool = False - lr: float = 1e-3 - lr_scheduler_factory: LRSchedulerFactory | None = None @dataclass -class ActorAndDualCriticsParams(Params): +class ParamsMixinActorAndDualCritics(ITransformableParams, ABC): actor_lr: float = 1e-3 critic1_lr: float = 1e-3 critic2_lr: float = 1e-3 @@ -75,21 +234,37 @@ class ActorAndDualCriticsParams(Params): critic1_lr_scheduler_factory: LRSchedulerFactory | None = None critic2_lr_scheduler_factory: LRSchedulerFactory | None = None + def __post_init__(self): + self._add_transformer(ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr")) + self._add_transformer( + ParamTransformerActorDualCriticsLRScheduler( + "actor_lr_scheduler_factory", + "critic1_lr_scheduler_factory", + "critic2_lr_scheduler_factory", + "lr_scheduler", + ), + ) + @dataclass -class SACParams(ActorAndDualCriticsParams): +class SACParams(Params, ParamsMixinActorAndDualCritics): tau: float = 0.005 gamma: float = 0.99 alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] | AutoAlphaFactory = 0.2 estimation_step: int = 1 - exploration_noise: BaseNoise | Literal["default"] | None = None + exploration_noise: BaseNoise | Literal["default"] | NoiseFactory | None = None deterministic_eval: bool = True action_scaling: bool = True action_bound_method: Literal["clip"] | None = "clip" + def __post_init__(self): + ParamsMixinActorAndDualCritics.__post_init__(self) + self._add_transformer(ParamTransformerAutoAlpha("alpha")) + self._add_transformer(ParamTransformerNoiseFactory("exploration_noise")) + @dataclass -class TD3Params(ActorAndDualCriticsParams): +class TD3Params(Params, ParamsMixinActorAndDualCritics): tau: float = 0.005 gamma: float = 0.99 exploration_noise: BaseNoise | Literal["default"] | NoiseFactory | None = "default" @@ -99,3 +274,9 @@ class TD3Params(ActorAndDualCriticsParams): estimation_step: int = 1 action_scaling: bool = True action_bound_method: Literal["clip"] | None = "clip" + + def __post_init__(self): + ParamsMixinActorAndDualCritics.__post_init__(self) + self._add_transformer(ParamTransformerNoiseFactory("exploration_noise")) + self._add_transformer(ParamTransformerFloatEnvParamFactory("policy_noise")) + self._add_transformer(ParamTransformerFloatEnvParamFactory("noise_clip"))