Remove parameter transformers from config object state,

composing the list dynamically instead
This commit is contained in:
Dominik Jain 2023-09-27 18:20:49 +02:00
parent 78b6dd1f49
commit acd89fa3b0

View File

@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass, field
from typing import Any, Literal
from dataclasses import asdict, dataclass
from typing import Any, Literal, Protocol
import torch
@ -159,39 +159,33 @@ class ParamTransformerFloatEnvParamFactory(ParamTransformer):
kwargs[self.key] = value.create_value(data.envs)
class ITransformableParams(ABC):
@abstractmethod
def _add_transformer(self, transformer: ParamTransformer):
class GetParamTransformersProtocol(Protocol):
def _get_param_transformers(self) -> list[ParamTransformer]:
pass
@dataclass
class Params(ITransformableParams):
_param_transformers: list[ParamTransformer] = field(
init=False,
default_factory=list,
repr=False,
)
def _add_transformer(self, transformer: ParamTransformer):
self._param_transformers.append(transformer)
class Params(GetParamTransformersProtocol):
def create_kwargs(self, data: ParamTransformerData) -> dict[str, Any]:
params = asdict(self)
for transformer in self._param_transformers:
for transformer in self._get_param_transformers():
transformer.transform(params, data)
del params["_param_transformers"]
return params
def _get_param_transformers(self) -> list[ParamTransformer]:
return []
@dataclass
class ParamsMixinLearningRateWithScheduler(ITransformableParams, ABC):
class ParamsMixinLearningRateWithScheduler(GetParamTransformersProtocol):
lr: float = 1e-3
lr_scheduler_factory: LRSchedulerFactory | None = None
def __post_init__(self):
self._add_transformer(ParamTransformerDrop("lr"))
self._add_transformer(ParamTransformerLRScheduler("lr_scheduler_factory", "lr_scheduler"))
def _get_param_transformers(self):
return [
ParamTransformerDrop("lr"),
ParamTransformerLRScheduler("lr_scheduler_factory", "lr_scheduler"),
]
@dataclass
@ -226,7 +220,7 @@ class PPOParams(A2CParams, ParamsMixinLearningRateWithScheduler):
@dataclass
class ParamsMixinActorAndDualCritics(ITransformableParams, ABC):
class ParamsMixinActorAndDualCritics(GetParamTransformersProtocol):
actor_lr: float = 1e-3
critic1_lr: float = 1e-3
critic2_lr: float = 1e-3
@ -234,16 +228,16 @@ class ParamsMixinActorAndDualCritics(ITransformableParams, ABC):
critic1_lr_scheduler_factory: LRSchedulerFactory | None = None
critic2_lr_scheduler_factory: LRSchedulerFactory | None = None
def __post_init__(self):
self._add_transformer(ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"))
self._add_transformer(
def _get_param_transformers(self):
return [
ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"),
ParamTransformerActorDualCriticsLRScheduler(
"actor_lr_scheduler_factory",
"critic1_lr_scheduler_factory",
"critic2_lr_scheduler_factory",
"lr_scheduler",
),
)
]
@dataclass
@ -257,10 +251,12 @@ class SACParams(Params, ParamsMixinActorAndDualCritics):
action_scaling: bool = True
action_bound_method: Literal["clip"] | None = "clip"
def __post_init__(self):
ParamsMixinActorAndDualCritics.__post_init__(self)
self._add_transformer(ParamTransformerAutoAlpha("alpha"))
self._add_transformer(ParamTransformerNoiseFactory("exploration_noise"))
def _get_param_transformers(self):
transformers = super()._get_param_transformers()
transformers.extend(ParamsMixinActorAndDualCritics._get_param_transformers(self))
transformers.append(ParamTransformerAutoAlpha("alpha"))
transformers.append(ParamTransformerNoiseFactory("exploration_noise"))
return transformers
@dataclass
@ -275,9 +271,10 @@ class TD3Params(Params, ParamsMixinActorAndDualCritics):
action_scaling: bool = True
action_bound_method: Literal["clip"] | None = "clip"
# TODO change to stateless variant
def __post_init__(self):
ParamsMixinActorAndDualCritics.__post_init__(self)
self._add_transformer(ParamTransformerNoiseFactory("exploration_noise"))
self._add_transformer(ParamTransformerFloatEnvParamFactory("policy_noise"))
self._add_transformer(ParamTransformerFloatEnvParamFactory("noise_clip"))
def _get_param_transformers(self):
transformers = super()._get_param_transformers()
transformers.extend(ParamsMixinActorAndDualCritics._get_param_transformers(self))
transformers.append(ParamTransformerNoiseFactory("exploration_noise"))
transformers.append(ParamTransformerFloatEnvParamFactory("policy_noise"))
transformers.append(ParamTransformerFloatEnvParamFactory("noise_clip"))
return transformers