Remove parameter transformers from config object state,
composing the list dynamically instead
This commit is contained in:
parent
78b6dd1f49
commit
acd89fa3b0
@ -1,6 +1,6 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import asdict, dataclass, field
|
from dataclasses import asdict, dataclass
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal, Protocol
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -159,39 +159,33 @@ class ParamTransformerFloatEnvParamFactory(ParamTransformer):
|
|||||||
kwargs[self.key] = value.create_value(data.envs)
|
kwargs[self.key] = value.create_value(data.envs)
|
||||||
|
|
||||||
|
|
||||||
class ITransformableParams(ABC):
|
class GetParamTransformersProtocol(Protocol):
|
||||||
@abstractmethod
|
def _get_param_transformers(self) -> list[ParamTransformer]:
|
||||||
def _add_transformer(self, transformer: ParamTransformer):
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Params(ITransformableParams):
|
class Params(GetParamTransformersProtocol):
|
||||||
_param_transformers: list[ParamTransformer] = field(
|
|
||||||
init=False,
|
|
||||||
default_factory=list,
|
|
||||||
repr=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _add_transformer(self, transformer: ParamTransformer):
|
|
||||||
self._param_transformers.append(transformer)
|
|
||||||
|
|
||||||
def create_kwargs(self, data: ParamTransformerData) -> dict[str, Any]:
|
def create_kwargs(self, data: ParamTransformerData) -> dict[str, Any]:
|
||||||
params = asdict(self)
|
params = asdict(self)
|
||||||
for transformer in self._param_transformers:
|
for transformer in self._get_param_transformers():
|
||||||
transformer.transform(params, data)
|
transformer.transform(params, data)
|
||||||
del params["_param_transformers"]
|
|
||||||
return params
|
return params
|
||||||
|
|
||||||
|
def _get_param_transformers(self) -> list[ParamTransformer]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ParamsMixinLearningRateWithScheduler(ITransformableParams, ABC):
|
class ParamsMixinLearningRateWithScheduler(GetParamTransformersProtocol):
|
||||||
lr: float = 1e-3
|
lr: float = 1e-3
|
||||||
lr_scheduler_factory: LRSchedulerFactory | None = None
|
lr_scheduler_factory: LRSchedulerFactory | None = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def _get_param_transformers(self):
|
||||||
self._add_transformer(ParamTransformerDrop("lr"))
|
return [
|
||||||
self._add_transformer(ParamTransformerLRScheduler("lr_scheduler_factory", "lr_scheduler"))
|
ParamTransformerDrop("lr"),
|
||||||
|
ParamTransformerLRScheduler("lr_scheduler_factory", "lr_scheduler"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -226,7 +220,7 @@ class PPOParams(A2CParams, ParamsMixinLearningRateWithScheduler):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ParamsMixinActorAndDualCritics(ITransformableParams, ABC):
|
class ParamsMixinActorAndDualCritics(GetParamTransformersProtocol):
|
||||||
actor_lr: float = 1e-3
|
actor_lr: float = 1e-3
|
||||||
critic1_lr: float = 1e-3
|
critic1_lr: float = 1e-3
|
||||||
critic2_lr: float = 1e-3
|
critic2_lr: float = 1e-3
|
||||||
@ -234,16 +228,16 @@ class ParamsMixinActorAndDualCritics(ITransformableParams, ABC):
|
|||||||
critic1_lr_scheduler_factory: LRSchedulerFactory | None = None
|
critic1_lr_scheduler_factory: LRSchedulerFactory | None = None
|
||||||
critic2_lr_scheduler_factory: LRSchedulerFactory | None = None
|
critic2_lr_scheduler_factory: LRSchedulerFactory | None = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def _get_param_transformers(self):
|
||||||
self._add_transformer(ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"))
|
return [
|
||||||
self._add_transformer(
|
ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"),
|
||||||
ParamTransformerActorDualCriticsLRScheduler(
|
ParamTransformerActorDualCriticsLRScheduler(
|
||||||
"actor_lr_scheduler_factory",
|
"actor_lr_scheduler_factory",
|
||||||
"critic1_lr_scheduler_factory",
|
"critic1_lr_scheduler_factory",
|
||||||
"critic2_lr_scheduler_factory",
|
"critic2_lr_scheduler_factory",
|
||||||
"lr_scheduler",
|
"lr_scheduler",
|
||||||
),
|
),
|
||||||
)
|
]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -257,10 +251,12 @@ class SACParams(Params, ParamsMixinActorAndDualCritics):
|
|||||||
action_scaling: bool = True
|
action_scaling: bool = True
|
||||||
action_bound_method: Literal["clip"] | None = "clip"
|
action_bound_method: Literal["clip"] | None = "clip"
|
||||||
|
|
||||||
def __post_init__(self):
|
def _get_param_transformers(self):
|
||||||
ParamsMixinActorAndDualCritics.__post_init__(self)
|
transformers = super()._get_param_transformers()
|
||||||
self._add_transformer(ParamTransformerAutoAlpha("alpha"))
|
transformers.extend(ParamsMixinActorAndDualCritics._get_param_transformers(self))
|
||||||
self._add_transformer(ParamTransformerNoiseFactory("exploration_noise"))
|
transformers.append(ParamTransformerAutoAlpha("alpha"))
|
||||||
|
transformers.append(ParamTransformerNoiseFactory("exploration_noise"))
|
||||||
|
return transformers
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -275,9 +271,10 @@ class TD3Params(Params, ParamsMixinActorAndDualCritics):
|
|||||||
action_scaling: bool = True
|
action_scaling: bool = True
|
||||||
action_bound_method: Literal["clip"] | None = "clip"
|
action_bound_method: Literal["clip"] | None = "clip"
|
||||||
|
|
||||||
# TODO change to stateless variant
|
def _get_param_transformers(self):
|
||||||
def __post_init__(self):
|
transformers = super()._get_param_transformers()
|
||||||
ParamsMixinActorAndDualCritics.__post_init__(self)
|
transformers.extend(ParamsMixinActorAndDualCritics._get_param_transformers(self))
|
||||||
self._add_transformer(ParamTransformerNoiseFactory("exploration_noise"))
|
transformers.append(ParamTransformerNoiseFactory("exploration_noise"))
|
||||||
self._add_transformer(ParamTransformerFloatEnvParamFactory("policy_noise"))
|
transformers.append(ParamTransformerFloatEnvParamFactory("policy_noise"))
|
||||||
self._add_transformer(ParamTransformerFloatEnvParamFactory("noise_clip"))
|
transformers.append(ParamTransformerFloatEnvParamFactory("noise_clip"))
|
||||||
|
return transformers
|
||||||
|
Loading…
x
Reference in New Issue
Block a user