Remove parameter transformers from config object state, composing the list dynamically instead
This commit is contained in:
Dominik Jain 2023-09-27 18:20:49 +02:00
parent 78b6dd1f49
commit acd89fa3b0

View File

@ -1,6 +1,6 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass
from typing import Any, Literal from typing import Any, Literal, Protocol
import torch import torch
@ -159,39 +159,33 @@ class ParamTransformerFloatEnvParamFactory(ParamTransformer):
kwargs[self.key] = value.create_value(data.envs) kwargs[self.key] = value.create_value(data.envs)
class ITransformableParams(ABC): class GetParamTransformersProtocol(Protocol):
@abstractmethod def _get_param_transformers(self) -> list[ParamTransformer]:
def _add_transformer(self, transformer: ParamTransformer):
pass pass
@dataclass @dataclass
class Params(ITransformableParams): class Params(GetParamTransformersProtocol):
_param_transformers: list[ParamTransformer] = field(
init=False,
default_factory=list,
repr=False,
)
def _add_transformer(self, transformer: ParamTransformer):
self._param_transformers.append(transformer)
def create_kwargs(self, data: ParamTransformerData) -> dict[str, Any]: def create_kwargs(self, data: ParamTransformerData) -> dict[str, Any]:
params = asdict(self) params = asdict(self)
for transformer in self._param_transformers: for transformer in self._get_param_transformers():
transformer.transform(params, data) transformer.transform(params, data)
del params["_param_transformers"]
return params return params
def _get_param_transformers(self) -> list[ParamTransformer]:
return []
@dataclass @dataclass
class ParamsMixinLearningRateWithScheduler(ITransformableParams, ABC): class ParamsMixinLearningRateWithScheduler(GetParamTransformersProtocol):
lr: float = 1e-3 lr: float = 1e-3
lr_scheduler_factory: LRSchedulerFactory | None = None lr_scheduler_factory: LRSchedulerFactory | None = None
def __post_init__(self): def _get_param_transformers(self):
self._add_transformer(ParamTransformerDrop("lr")) return [
self._add_transformer(ParamTransformerLRScheduler("lr_scheduler_factory", "lr_scheduler")) ParamTransformerDrop("lr"),
ParamTransformerLRScheduler("lr_scheduler_factory", "lr_scheduler"),
]
@dataclass @dataclass
@ -226,7 +220,7 @@ class PPOParams(A2CParams, ParamsMixinLearningRateWithScheduler):
@dataclass @dataclass
class ParamsMixinActorAndDualCritics(ITransformableParams, ABC): class ParamsMixinActorAndDualCritics(GetParamTransformersProtocol):
actor_lr: float = 1e-3 actor_lr: float = 1e-3
critic1_lr: float = 1e-3 critic1_lr: float = 1e-3
critic2_lr: float = 1e-3 critic2_lr: float = 1e-3
@ -234,16 +228,16 @@ class ParamsMixinActorAndDualCritics(ITransformableParams, ABC):
critic1_lr_scheduler_factory: LRSchedulerFactory | None = None critic1_lr_scheduler_factory: LRSchedulerFactory | None = None
critic2_lr_scheduler_factory: LRSchedulerFactory | None = None critic2_lr_scheduler_factory: LRSchedulerFactory | None = None
def __post_init__(self): def _get_param_transformers(self):
self._add_transformer(ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr")) return [
self._add_transformer( ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"),
ParamTransformerActorDualCriticsLRScheduler( ParamTransformerActorDualCriticsLRScheduler(
"actor_lr_scheduler_factory", "actor_lr_scheduler_factory",
"critic1_lr_scheduler_factory", "critic1_lr_scheduler_factory",
"critic2_lr_scheduler_factory", "critic2_lr_scheduler_factory",
"lr_scheduler", "lr_scheduler",
), ),
) ]
@dataclass @dataclass
@ -257,10 +251,12 @@ class SACParams(Params, ParamsMixinActorAndDualCritics):
action_scaling: bool = True action_scaling: bool = True
action_bound_method: Literal["clip"] | None = "clip" action_bound_method: Literal["clip"] | None = "clip"
def __post_init__(self): def _get_param_transformers(self):
ParamsMixinActorAndDualCritics.__post_init__(self) transformers = super()._get_param_transformers()
self._add_transformer(ParamTransformerAutoAlpha("alpha")) transformers.extend(ParamsMixinActorAndDualCritics._get_param_transformers(self))
self._add_transformer(ParamTransformerNoiseFactory("exploration_noise")) transformers.append(ParamTransformerAutoAlpha("alpha"))
transformers.append(ParamTransformerNoiseFactory("exploration_noise"))
return transformers
@dataclass @dataclass
@ -275,9 +271,10 @@ class TD3Params(Params, ParamsMixinActorAndDualCritics):
action_scaling: bool = True action_scaling: bool = True
action_bound_method: Literal["clip"] | None = "clip" action_bound_method: Literal["clip"] | None = "clip"
# TODO change to stateless variant def _get_param_transformers(self):
def __post_init__(self): transformers = super()._get_param_transformers()
ParamsMixinActorAndDualCritics.__post_init__(self) transformers.extend(ParamsMixinActorAndDualCritics._get_param_transformers(self))
self._add_transformer(ParamTransformerNoiseFactory("exploration_noise")) transformers.append(ParamTransformerNoiseFactory("exploration_noise"))
self._add_transformer(ParamTransformerFloatEnvParamFactory("policy_noise")) transformers.append(ParamTransformerFloatEnvParamFactory("policy_noise"))
self._add_transformer(ParamTransformerFloatEnvParamFactory("noise_clip")) transformers.append(ParamTransformerFloatEnvParamFactory("noise_clip"))
return transformers