Remove parameter transformers from config object state;
compose the list of transformers dynamically instead
This commit is contained in:
parent
78b6dd1f49
commit
acd89fa3b0
@ -1,6 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Any, Literal
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Any, Literal, Protocol
|
||||
|
||||
import torch
|
||||
|
||||
@ -159,39 +159,33 @@ class ParamTransformerFloatEnvParamFactory(ParamTransformer):
|
||||
kwargs[self.key] = value.create_value(data.envs)
|
||||
|
||||
|
||||
class ITransformableParams(ABC):
|
||||
@abstractmethod
|
||||
def _add_transformer(self, transformer: ParamTransformer):
|
||||
class GetParamTransformersProtocol(Protocol):
|
||||
def _get_param_transformers(self) -> list[ParamTransformer]:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Params(ITransformableParams):
|
||||
_param_transformers: list[ParamTransformer] = field(
|
||||
init=False,
|
||||
default_factory=list,
|
||||
repr=False,
|
||||
)
|
||||
|
||||
def _add_transformer(self, transformer: ParamTransformer):
|
||||
self._param_transformers.append(transformer)
|
||||
|
||||
class Params(GetParamTransformersProtocol):
|
||||
def create_kwargs(self, data: ParamTransformerData) -> dict[str, Any]:
|
||||
params = asdict(self)
|
||||
for transformer in self._param_transformers:
|
||||
for transformer in self._get_param_transformers():
|
||||
transformer.transform(params, data)
|
||||
del params["_param_transformers"]
|
||||
return params
|
||||
|
||||
def _get_param_transformers(self) -> list[ParamTransformer]:
|
||||
return []
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParamsMixinLearningRateWithScheduler(ITransformableParams, ABC):
|
||||
class ParamsMixinLearningRateWithScheduler(GetParamTransformersProtocol):
|
||||
lr: float = 1e-3
|
||||
lr_scheduler_factory: LRSchedulerFactory | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
self._add_transformer(ParamTransformerDrop("lr"))
|
||||
self._add_transformer(ParamTransformerLRScheduler("lr_scheduler_factory", "lr_scheduler"))
|
||||
def _get_param_transformers(self):
|
||||
return [
|
||||
ParamTransformerDrop("lr"),
|
||||
ParamTransformerLRScheduler("lr_scheduler_factory", "lr_scheduler"),
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -226,7 +220,7 @@ class PPOParams(A2CParams, ParamsMixinLearningRateWithScheduler):
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParamsMixinActorAndDualCritics(ITransformableParams, ABC):
|
||||
class ParamsMixinActorAndDualCritics(GetParamTransformersProtocol):
|
||||
actor_lr: float = 1e-3
|
||||
critic1_lr: float = 1e-3
|
||||
critic2_lr: float = 1e-3
|
||||
@ -234,16 +228,16 @@ class ParamsMixinActorAndDualCritics(ITransformableParams, ABC):
|
||||
critic1_lr_scheduler_factory: LRSchedulerFactory | None = None
|
||||
critic2_lr_scheduler_factory: LRSchedulerFactory | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
self._add_transformer(ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"))
|
||||
self._add_transformer(
|
||||
def _get_param_transformers(self):
|
||||
return [
|
||||
ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"),
|
||||
ParamTransformerActorDualCriticsLRScheduler(
|
||||
"actor_lr_scheduler_factory",
|
||||
"critic1_lr_scheduler_factory",
|
||||
"critic2_lr_scheduler_factory",
|
||||
"lr_scheduler",
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -257,10 +251,12 @@ class SACParams(Params, ParamsMixinActorAndDualCritics):
|
||||
action_scaling: bool = True
|
||||
action_bound_method: Literal["clip"] | None = "clip"
|
||||
|
||||
def __post_init__(self):
|
||||
ParamsMixinActorAndDualCritics.__post_init__(self)
|
||||
self._add_transformer(ParamTransformerAutoAlpha("alpha"))
|
||||
self._add_transformer(ParamTransformerNoiseFactory("exploration_noise"))
|
||||
def _get_param_transformers(self):
|
||||
transformers = super()._get_param_transformers()
|
||||
transformers.extend(ParamsMixinActorAndDualCritics._get_param_transformers(self))
|
||||
transformers.append(ParamTransformerAutoAlpha("alpha"))
|
||||
transformers.append(ParamTransformerNoiseFactory("exploration_noise"))
|
||||
return transformers
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -275,9 +271,10 @@ class TD3Params(Params, ParamsMixinActorAndDualCritics):
|
||||
action_scaling: bool = True
|
||||
action_bound_method: Literal["clip"] | None = "clip"
|
||||
|
||||
# TODO change to stateless variant
|
||||
def __post_init__(self):
|
||||
ParamsMixinActorAndDualCritics.__post_init__(self)
|
||||
self._add_transformer(ParamTransformerNoiseFactory("exploration_noise"))
|
||||
self._add_transformer(ParamTransformerFloatEnvParamFactory("policy_noise"))
|
||||
self._add_transformer(ParamTransformerFloatEnvParamFactory("noise_clip"))
|
||||
def _get_param_transformers(self):
|
||||
transformers = super()._get_param_transformers()
|
||||
transformers.extend(ParamsMixinActorAndDualCritics._get_param_transformers(self))
|
||||
transformers.append(ParamTransformerNoiseFactory("exploration_noise"))
|
||||
transformers.append(ParamTransformerFloatEnvParamFactory("policy_noise"))
|
||||
transformers.append(ParamTransformerFloatEnvParamFactory("noise_clip"))
|
||||
return transformers
|
||||
|
Loading…
x
Reference in New Issue
Block a user