diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py
index 44c569b..9fcebc6 100644
--- a/tianshou/highlevel/params/policy_params.py
+++ b/tianshou/highlevel/params/policy_params.py
@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
-from dataclasses import asdict, dataclass, field
-from typing import Any, Literal
+from dataclasses import asdict, dataclass
+from typing import Any, Literal, Protocol
 
 import torch
 
@@ -159,39 +159,33 @@ class ParamTransformerFloatEnvParamFactory(ParamTransformer):
         kwargs[self.key] = value.create_value(data.envs)
 
 
-class ITransformableParams(ABC):
-    @abstractmethod
-    def _add_transformer(self, transformer: ParamTransformer):
+class GetParamTransformersProtocol(Protocol):
+    def _get_param_transformers(self) -> list[ParamTransformer]:
         pass
 
 
 @dataclass
-class Params(ITransformableParams):
-    _param_transformers: list[ParamTransformer] = field(
-        init=False,
-        default_factory=list,
-        repr=False,
-    )
-
-    def _add_transformer(self, transformer: ParamTransformer):
-        self._param_transformers.append(transformer)
-
+class Params(GetParamTransformersProtocol):
     def create_kwargs(self, data: ParamTransformerData) -> dict[str, Any]:
         params = asdict(self)
-        for transformer in self._param_transformers:
+        for transformer in self._get_param_transformers():
             transformer.transform(params, data)
-        del params["_param_transformers"]
         return params
 
+    def _get_param_transformers(self) -> list[ParamTransformer]:
+        return []
+
 
 @dataclass
-class ParamsMixinLearningRateWithScheduler(ITransformableParams, ABC):
+class ParamsMixinLearningRateWithScheduler(GetParamTransformersProtocol):
     lr: float = 1e-3
     lr_scheduler_factory: LRSchedulerFactory | None = None
 
-    def __post_init__(self):
-        self._add_transformer(ParamTransformerDrop("lr"))
-        self._add_transformer(ParamTransformerLRScheduler("lr_scheduler_factory", "lr_scheduler"))
+    def _get_param_transformers(self):
+        return [
+            ParamTransformerDrop("lr"),
+            ParamTransformerLRScheduler("lr_scheduler_factory", "lr_scheduler"),
+        ]
 
 
 @dataclass
@@ -226,7 +220,7 @@ class PPOParams(A2CParams, ParamsMixinLearningRateWithScheduler):
 
 
 @dataclass
-class ParamsMixinActorAndDualCritics(ITransformableParams, ABC):
+class ParamsMixinActorAndDualCritics(GetParamTransformersProtocol):
     actor_lr: float = 1e-3
     critic1_lr: float = 1e-3
     critic2_lr: float = 1e-3
@@ -234,16 +228,16 @@ class ParamsMixinActorAndDualCritics(ITransformableParams, ABC):
     critic1_lr_scheduler_factory: LRSchedulerFactory | None = None
     critic2_lr_scheduler_factory: LRSchedulerFactory | None = None
 
-    def __post_init__(self):
-        self._add_transformer(ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"))
-        self._add_transformer(
+    def _get_param_transformers(self):
+        return [
+            ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"),
             ParamTransformerActorDualCriticsLRScheduler(
                 "actor_lr_scheduler_factory",
                 "critic1_lr_scheduler_factory",
                 "critic2_lr_scheduler_factory",
                 "lr_scheduler",
             ),
-        )
+        ]
 
 
 @dataclass
@@ -257,10 +251,12 @@ class SACParams(Params, ParamsMixinActorAndDualCritics):
     action_scaling: bool = True
     action_bound_method: Literal["clip"] | None = "clip"
 
-    def __post_init__(self):
-        ParamsMixinActorAndDualCritics.__post_init__(self)
-        self._add_transformer(ParamTransformerAutoAlpha("alpha"))
-        self._add_transformer(ParamTransformerNoiseFactory("exploration_noise"))
+    def _get_param_transformers(self):
+        transformers = super()._get_param_transformers()
+        transformers.extend(ParamsMixinActorAndDualCritics._get_param_transformers(self))
+        transformers.append(ParamTransformerAutoAlpha("alpha"))
+        transformers.append(ParamTransformerNoiseFactory("exploration_noise"))
+        return transformers
 
 
 @dataclass
@@ -275,9 +271,10 @@ class TD3Params(Params, ParamsMixinActorAndDualCritics):
     action_scaling: bool = True
     action_bound_method: Literal["clip"] | None = "clip"
 
-    # TODO change to stateless variant
-    def __post_init__(self):
-        ParamsMixinActorAndDualCritics.__post_init__(self)
-        self._add_transformer(ParamTransformerNoiseFactory("exploration_noise"))
-        self._add_transformer(ParamTransformerFloatEnvParamFactory("policy_noise"))
-        self._add_transformer(ParamTransformerFloatEnvParamFactory("noise_clip"))
+    def _get_param_transformers(self):
+        transformers = super()._get_param_transformers()
+        transformers.extend(ParamsMixinActorAndDualCritics._get_param_transformers(self))
+        transformers.append(ParamTransformerNoiseFactory("exploration_noise"))
+        transformers.append(ParamTransformerFloatEnvParamFactory("policy_noise"))
+        transformers.append(ParamTransformerFloatEnvParamFactory("noise_clip"))
+        return transformers