Move parameter transformation directly into parameter objects,
achieving greater separation of concerns and improved maintainability
This commit is contained in:
parent
38cf982034
commit
d4e604b46e
@ -1,7 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -19,12 +18,8 @@ from tianshou.highlevel.module import (
|
|||||||
TDevice,
|
TDevice,
|
||||||
)
|
)
|
||||||
from tianshou.highlevel.optim import OptimizerFactory
|
from tianshou.highlevel.optim import OptimizerFactory
|
||||||
from tianshou.highlevel.params.alpha import AutoAlphaFactory
|
|
||||||
from tianshou.highlevel.params.env_param import FloatEnvParamFactory
|
|
||||||
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
|
|
||||||
from tianshou.highlevel.params.noise import NoiseFactory
|
|
||||||
from tianshou.highlevel.params.policy_params import (
|
from tianshou.highlevel.params.policy_params import (
|
||||||
ParamTransformer,
|
ParamTransformerData,
|
||||||
PPOParams,
|
PPOParams,
|
||||||
SACParams,
|
SACParams,
|
||||||
TD3Params,
|
TD3Params,
|
||||||
@ -32,7 +27,6 @@ from tianshou.highlevel.params.policy_params import (
|
|||||||
from tianshou.policy import BasePolicy, PPOPolicy, SACPolicy, TD3Policy
|
from tianshou.policy import BasePolicy, PPOPolicy, SACPolicy, TD3Policy
|
||||||
from tianshou.policy.modelfree.pg import TDistParams
|
from tianshou.policy.modelfree.pg import TDistParams
|
||||||
from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
|
from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
|
||||||
from tianshou.utils import MultipleLRSchedulers
|
|
||||||
from tianshou.utils.net.common import ActorCritic
|
from tianshou.utils.net.common import ActorCritic
|
||||||
|
|
||||||
CHECKPOINT_DICT_KEY_MODEL = "model"
|
CHECKPOINT_DICT_KEY_MODEL = "model"
|
||||||
@ -145,26 +139,6 @@ class OffpolicyAgentFactory(AgentFactory, ABC):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ParamTransformerDrop(ParamTransformer):
|
|
||||||
def __init__(self, *keys: str):
|
|
||||||
self.keys = keys
|
|
||||||
|
|
||||||
def transform(self, kwargs: dict[str, Any]) -> None:
|
|
||||||
for k in self.keys:
|
|
||||||
del kwargs[k]
|
|
||||||
|
|
||||||
|
|
||||||
class ParamTransformerLRScheduler(ParamTransformer):
|
|
||||||
def __init__(self, optim: torch.optim.Optimizer):
|
|
||||||
self.optim = optim
|
|
||||||
|
|
||||||
def transform(self, kwargs: dict[str, Any]) -> None:
|
|
||||||
factory: LRSchedulerFactory | None = self.get(kwargs, "lr_scheduler_factory", drop=True)
|
|
||||||
kwargs["lr_scheduler"] = (
|
|
||||||
factory.create_scheduler(self.optim) if factory is not None else None
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class _ActorMixin:
|
class _ActorMixin:
|
||||||
def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory):
|
def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory):
|
||||||
self.actor_module_opt_factory = ActorModuleOptFactory(actor_factory, optim_factory)
|
self.actor_module_opt_factory = ActorModuleOptFactory(actor_factory, optim_factory)
|
||||||
@ -269,8 +243,12 @@ class PPOAgentFactory(OnpolicyAgentFactory, _ActorCriticMixin):
|
|||||||
def create_policy(self, envs: Environments, device: TDevice) -> PPOPolicy:
|
def create_policy(self, envs: Environments, device: TDevice) -> PPOPolicy:
|
||||||
actor_critic = self.create_actor_critic_module_opt(envs, device, self.params.lr)
|
actor_critic = self.create_actor_critic_module_opt(envs, device, self.params.lr)
|
||||||
kwargs = self.params.create_kwargs(
|
kwargs = self.params.create_kwargs(
|
||||||
ParamTransformerDrop("lr"),
|
ParamTransformerData(
|
||||||
ParamTransformerLRScheduler(actor_critic.optim),
|
envs=envs,
|
||||||
|
device=device,
|
||||||
|
optim_factory=self.optim_factory,
|
||||||
|
optim=actor_critic.optim,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
return PPOPolicy(
|
return PPOPolicy(
|
||||||
actor=actor_critic.actor,
|
actor=actor_critic.actor,
|
||||||
@ -282,43 +260,6 @@ class PPOAgentFactory(OnpolicyAgentFactory, _ActorCriticMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ParamTransformerAlpha(ParamTransformer):
|
|
||||||
def __init__(self, envs: Environments, optim_factory: OptimizerFactory, device: TDevice):
|
|
||||||
self.envs = envs
|
|
||||||
self.optim_factory = optim_factory
|
|
||||||
self.device = device
|
|
||||||
|
|
||||||
def transform(self, kwargs: dict[str, Any]) -> None:
|
|
||||||
key = "alpha"
|
|
||||||
alpha = self.get(kwargs, key)
|
|
||||||
if isinstance(alpha, AutoAlphaFactory):
|
|
||||||
kwargs[key] = alpha.create_auto_alpha(self.envs, self.optim_factory, self.device)
|
|
||||||
|
|
||||||
|
|
||||||
class ParamTransformerMultiLRScheduler(ParamTransformer):
|
|
||||||
def __init__(self, optim_key_list: list[tuple[torch.optim.Optimizer, str]]):
|
|
||||||
self.optim_key_list = optim_key_list
|
|
||||||
|
|
||||||
def transform(self, kwargs: dict[str, Any]) -> None:
|
|
||||||
lr_schedulers = []
|
|
||||||
for optim, lr_scheduler_factory_key in self.optim_key_list:
|
|
||||||
lr_scheduler_factory: LRSchedulerFactory | None = self.get(
|
|
||||||
kwargs,
|
|
||||||
lr_scheduler_factory_key,
|
|
||||||
drop=True,
|
|
||||||
)
|
|
||||||
if lr_scheduler_factory is not None:
|
|
||||||
lr_schedulers.append(lr_scheduler_factory.create_scheduler(optim))
|
|
||||||
match len(lr_schedulers):
|
|
||||||
case 0:
|
|
||||||
lr_scheduler = None
|
|
||||||
case 1:
|
|
||||||
lr_scheduler = lr_schedulers[0]
|
|
||||||
case _:
|
|
||||||
lr_scheduler = MultipleLRSchedulers(*lr_schedulers)
|
|
||||||
kwargs["lr_scheduler"] = lr_scheduler
|
|
||||||
|
|
||||||
|
|
||||||
class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
|
class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -346,15 +287,14 @@ class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
|
|||||||
critic1 = self.create_critic_module_opt(envs, device, self.params.critic1_lr)
|
critic1 = self.create_critic_module_opt(envs, device, self.params.critic1_lr)
|
||||||
critic2 = self.create_critic2_module_opt(envs, device, self.params.critic2_lr)
|
critic2 = self.create_critic2_module_opt(envs, device, self.params.critic2_lr)
|
||||||
kwargs = self.params.create_kwargs(
|
kwargs = self.params.create_kwargs(
|
||||||
ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"),
|
ParamTransformerData(
|
||||||
ParamTransformerMultiLRScheduler(
|
envs=envs,
|
||||||
[
|
device=device,
|
||||||
(actor.optim, "actor_lr_scheduler_factory"),
|
optim_factory=self.optim_factory,
|
||||||
(critic1.optim, "critic1_lr_scheduler_factory"),
|
actor=actor,
|
||||||
(critic2.optim, "critic2_lr_scheduler_factory"),
|
critic1=critic1,
|
||||||
],
|
critic2=critic2,
|
||||||
),
|
),
|
||||||
ParamTransformerAlpha(envs, optim_factory=self.optim_factory, device=device),
|
|
||||||
)
|
)
|
||||||
return SACPolicy(
|
return SACPolicy(
|
||||||
actor=actor.module,
|
actor=actor.module,
|
||||||
@ -369,28 +309,6 @@ class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ParamTransformerNoiseFactory(ParamTransformer):
|
|
||||||
def __init__(self, key: str, envs: Environments):
|
|
||||||
self.key = key
|
|
||||||
self.envs = envs
|
|
||||||
|
|
||||||
def transform(self, kwargs: dict[str, Any]) -> None:
|
|
||||||
value = kwargs[self.key]
|
|
||||||
if isinstance(value, NoiseFactory):
|
|
||||||
kwargs[self.key] = value.create_noise(self.envs)
|
|
||||||
|
|
||||||
|
|
||||||
class ParamTransformerFloatEnvParamFactory(ParamTransformer):
|
|
||||||
def __init__(self, key: str, envs: Environments):
|
|
||||||
self.key = key
|
|
||||||
self.envs = envs
|
|
||||||
|
|
||||||
def transform(self, kwargs: dict[str, Any]) -> None:
|
|
||||||
value = kwargs[self.key]
|
|
||||||
if isinstance(value, FloatEnvParamFactory):
|
|
||||||
kwargs[self.key] = value.create_param(self.envs)
|
|
||||||
|
|
||||||
|
|
||||||
class TD3AgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
|
class TD3AgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -418,17 +336,14 @@ class TD3AgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
|
|||||||
critic1 = self.create_critic_module_opt(envs, device, self.params.critic1_lr)
|
critic1 = self.create_critic_module_opt(envs, device, self.params.critic1_lr)
|
||||||
critic2 = self.create_critic2_module_opt(envs, device, self.params.critic2_lr)
|
critic2 = self.create_critic2_module_opt(envs, device, self.params.critic2_lr)
|
||||||
kwargs = self.params.create_kwargs(
|
kwargs = self.params.create_kwargs(
|
||||||
ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"),
|
ParamTransformerData(
|
||||||
ParamTransformerMultiLRScheduler(
|
envs=envs,
|
||||||
[
|
device=device,
|
||||||
(actor.optim, "actor_lr_scheduler_factory"),
|
optim_factory=self.optim_factory,
|
||||||
(critic1.optim, "critic1_lr_scheduler_factory"),
|
actor=actor,
|
||||||
(critic2.optim, "critic2_lr_scheduler_factory"),
|
critic1=critic1,
|
||||||
],
|
critic2=critic2,
|
||||||
),
|
),
|
||||||
ParamTransformerNoiseFactory("exploration_noise", envs),
|
|
||||||
ParamTransformerFloatEnvParamFactory("policy_noise", envs),
|
|
||||||
ParamTransformerFloatEnvParamFactory("noise_clip", envs),
|
|
||||||
)
|
)
|
||||||
return TD3Policy(
|
return TD3Policy(
|
||||||
actor=actor.module,
|
actor=actor.module,
|
||||||
|
@ -1,19 +1,41 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import asdict, dataclass
|
from dataclasses import asdict, dataclass, field
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from tianshou.exploration import BaseNoise
|
from tianshou.exploration import BaseNoise
|
||||||
|
from tianshou.highlevel.env import Environments
|
||||||
|
from tianshou.highlevel.module import ModuleOpt, TDevice
|
||||||
|
from tianshou.highlevel.optim import OptimizerFactory
|
||||||
from tianshou.highlevel.params.alpha import AutoAlphaFactory
|
from tianshou.highlevel.params.alpha import AutoAlphaFactory
|
||||||
from tianshou.highlevel.params.env_param import FloatEnvParamFactory
|
from tianshou.highlevel.params.env_param import FloatEnvParamFactory
|
||||||
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
|
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
|
||||||
from tianshou.highlevel.params.noise import NoiseFactory
|
from tianshou.highlevel.params.noise import NoiseFactory
|
||||||
|
from tianshou.utils import MultipleLRSchedulers
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(kw_only=True)
|
||||||
|
class ParamTransformerData:
|
||||||
|
"""Holds data that can be used by `ParamTransformer` instances to perform their transformation.
|
||||||
|
|
||||||
|
The representation contains the superset of all data items that are required by different types of agent factories.
|
||||||
|
An agent factory is expected to set only the attributes that are relevant to its parameters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
envs: Environments
|
||||||
|
device: TDevice
|
||||||
|
optim_factory: OptimizerFactory
|
||||||
|
optim: torch.optim.Optimizer | None = None
|
||||||
|
"""the single optimizer for the case where there is just one"""
|
||||||
|
actor: ModuleOpt | None = None
|
||||||
|
critic1: ModuleOpt | None = None
|
||||||
|
critic2: ModuleOpt | None = None
|
||||||
|
|
||||||
|
|
||||||
class ParamTransformer(ABC):
|
class ParamTransformer(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def transform(self, kwargs: dict[str, Any]) -> None:
|
def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -24,13 +46,152 @@ class ParamTransformer(ABC):
|
|||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
class ParamTransformerDrop(ParamTransformer):
|
||||||
|
def __init__(self, *keys: str):
|
||||||
|
self.keys = keys
|
||||||
|
|
||||||
|
def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None:
|
||||||
|
for k in self.keys:
|
||||||
|
del kwargs[k]
|
||||||
|
|
||||||
|
|
||||||
|
class ParamTransformerLRScheduler(ParamTransformer):
|
||||||
|
"""Transforms a key containing a learning rate scheduler factory (removed) into a key containing
|
||||||
|
a learning rate scheduler (added) for the data member `optim`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, key_scheduler_factory: str, key_scheduler: str):
|
||||||
|
self.key_scheduler_factory = key_scheduler_factory
|
||||||
|
self.key_scheduler = key_scheduler
|
||||||
|
|
||||||
|
def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
|
||||||
|
assert data.optim is not None
|
||||||
|
factory: LRSchedulerFactory | None = self.get(params, self.key_scheduler_factory, drop=True)
|
||||||
|
params[self.key_scheduler] = (
|
||||||
|
factory.create_scheduler(data.optim) if factory is not None else None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ParamTransformerMultiLRScheduler(ParamTransformer):
|
||||||
|
"""Transforms several scheduler factories into a single scheduler, which may be a MultipleLRSchedulers instance
|
||||||
|
if more than one factory is indeed given.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, optim_key_list: list[tuple[torch.optim.Optimizer, str]], key_scheduler: str):
|
||||||
|
""":param optim_key_list: a list of tuples (optimizer, key of learning rate factory)
|
||||||
|
:param key_scheduler: the key under which to store the resulting learning rate scheduler
|
||||||
|
"""
|
||||||
|
self.optim_key_list = optim_key_list
|
||||||
|
self.key_scheduler = key_scheduler
|
||||||
|
|
||||||
|
def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
|
||||||
|
lr_schedulers = []
|
||||||
|
for optim, lr_scheduler_factory_key in self.optim_key_list:
|
||||||
|
lr_scheduler_factory: LRSchedulerFactory | None = self.get(
|
||||||
|
params,
|
||||||
|
lr_scheduler_factory_key,
|
||||||
|
drop=True,
|
||||||
|
)
|
||||||
|
if lr_scheduler_factory is not None:
|
||||||
|
lr_schedulers.append(lr_scheduler_factory.create_scheduler(optim))
|
||||||
|
match len(lr_schedulers):
|
||||||
|
case 0:
|
||||||
|
lr_scheduler = None
|
||||||
|
case 1:
|
||||||
|
lr_scheduler = lr_schedulers[0]
|
||||||
|
case _:
|
||||||
|
lr_scheduler = MultipleLRSchedulers(*lr_schedulers)
|
||||||
|
params[self.key_scheduler] = lr_scheduler
|
||||||
|
|
||||||
|
|
||||||
|
class ParamTransformerActorDualCriticsLRScheduler(ParamTransformer):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
key_scheduler_factory_actor: str,
|
||||||
|
key_scheduler_factory_critic1: str,
|
||||||
|
key_scheduler_factory_critic2: str,
|
||||||
|
key_scheduler: str,
|
||||||
|
):
|
||||||
|
self.key_factory_actor = key_scheduler_factory_actor
|
||||||
|
self.key_factory_critic1 = key_scheduler_factory_critic1
|
||||||
|
self.key_factory_critic2 = key_scheduler_factory_critic2
|
||||||
|
self.key_scheduler = key_scheduler
|
||||||
|
|
||||||
|
def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
|
||||||
|
transformer = ParamTransformerMultiLRScheduler(
|
||||||
|
[
|
||||||
|
(data.actor.optim, self.key_factory_actor),
|
||||||
|
(data.critic1.optim, self.key_factory_critic1),
|
||||||
|
(data.critic2.optim, self.key_factory_critic2),
|
||||||
|
],
|
||||||
|
self.key_scheduler,
|
||||||
|
)
|
||||||
|
transformer.transform(params, data)
|
||||||
|
|
||||||
|
|
||||||
|
class ParamTransformerAutoAlpha(ParamTransformer):
|
||||||
|
def __init__(self, key: str):
|
||||||
|
self.key = key
|
||||||
|
|
||||||
|
def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None:
|
||||||
|
alpha = self.get(kwargs, self.key)
|
||||||
|
if isinstance(alpha, AutoAlphaFactory):
|
||||||
|
kwargs[self.key] = alpha.create_auto_alpha(data.envs, data.optim_factory, data.device)
|
||||||
|
|
||||||
|
|
||||||
|
class ParamTransformerNoiseFactory(ParamTransformer):
|
||||||
|
def __init__(self, key: str):
|
||||||
|
self.key = key
|
||||||
|
|
||||||
|
def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
|
||||||
|
value = params[self.key]
|
||||||
|
if isinstance(value, NoiseFactory):
|
||||||
|
params[self.key] = value.create_noise(data.envs)
|
||||||
|
|
||||||
|
|
||||||
|
class ParamTransformerFloatEnvParamFactory(ParamTransformer):
|
||||||
|
def __init__(self, key: str):
|
||||||
|
self.key = key
|
||||||
|
|
||||||
|
def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None:
|
||||||
|
value = kwargs[self.key]
|
||||||
|
if isinstance(value, FloatEnvParamFactory):
|
||||||
|
kwargs[self.key] = value.create_param(data.envs)
|
||||||
|
|
||||||
|
|
||||||
|
class ITransformableParams(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def _add_transformer(self, transformer: ParamTransformer):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Params:
|
class Params(ITransformableParams):
|
||||||
def create_kwargs(self, *transformers: ParamTransformer) -> dict[str, Any]:
|
_param_transformers: list[ParamTransformer] = field(
|
||||||
d = asdict(self)
|
init=False,
|
||||||
for transformer in transformers:
|
default_factory=list,
|
||||||
transformer.transform(d)
|
repr=False,
|
||||||
return d
|
)
|
||||||
|
|
||||||
|
def _add_transformer(self, transformer: ParamTransformer):
|
||||||
|
self._param_transformers.append(transformer)
|
||||||
|
|
||||||
|
def create_kwargs(self, data: ParamTransformerData) -> dict[str, Any]:
|
||||||
|
params = asdict(self)
|
||||||
|
for transformer in self._param_transformers:
|
||||||
|
transformer.transform(params, data)
|
||||||
|
del params["_param_transformers"]
|
||||||
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ParamsMixinLearningRateWithScheduler(ITransformableParams, ABC):
|
||||||
|
lr: float = 1e-3
|
||||||
|
lr_scheduler_factory: LRSchedulerFactory | None = None
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
self._add_transformer(ParamTransformerDrop("lr"))
|
||||||
|
self._add_transformer(ParamTransformerLRScheduler("lr_scheduler_factory", "lr_scheduler"))
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -54,7 +215,7 @@ class A2CParams(PGParams):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PPOParams(A2CParams):
|
class PPOParams(A2CParams, ParamsMixinLearningRateWithScheduler):
|
||||||
"""PPO specific config."""
|
"""PPO specific config."""
|
||||||
|
|
||||||
eps_clip: float = 0.2
|
eps_clip: float = 0.2
|
||||||
@ -62,12 +223,10 @@ class PPOParams(A2CParams):
|
|||||||
value_clip: bool = False
|
value_clip: bool = False
|
||||||
advantage_normalization: bool = True
|
advantage_normalization: bool = True
|
||||||
recompute_advantage: bool = False
|
recompute_advantage: bool = False
|
||||||
lr: float = 1e-3
|
|
||||||
lr_scheduler_factory: LRSchedulerFactory | None = None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ActorAndDualCriticsParams(Params):
|
class ParamsMixinActorAndDualCritics(ITransformableParams, ABC):
|
||||||
actor_lr: float = 1e-3
|
actor_lr: float = 1e-3
|
||||||
critic1_lr: float = 1e-3
|
critic1_lr: float = 1e-3
|
||||||
critic2_lr: float = 1e-3
|
critic2_lr: float = 1e-3
|
||||||
@ -75,21 +234,37 @@ class ActorAndDualCriticsParams(Params):
|
|||||||
critic1_lr_scheduler_factory: LRSchedulerFactory | None = None
|
critic1_lr_scheduler_factory: LRSchedulerFactory | None = None
|
||||||
critic2_lr_scheduler_factory: LRSchedulerFactory | None = None
|
critic2_lr_scheduler_factory: LRSchedulerFactory | None = None
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
self._add_transformer(ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"))
|
||||||
|
self._add_transformer(
|
||||||
|
ParamTransformerActorDualCriticsLRScheduler(
|
||||||
|
"actor_lr_scheduler_factory",
|
||||||
|
"critic1_lr_scheduler_factory",
|
||||||
|
"critic2_lr_scheduler_factory",
|
||||||
|
"lr_scheduler",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SACParams(ActorAndDualCriticsParams):
|
class SACParams(Params, ParamsMixinActorAndDualCritics):
|
||||||
tau: float = 0.005
|
tau: float = 0.005
|
||||||
gamma: float = 0.99
|
gamma: float = 0.99
|
||||||
alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] | AutoAlphaFactory = 0.2
|
alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] | AutoAlphaFactory = 0.2
|
||||||
estimation_step: int = 1
|
estimation_step: int = 1
|
||||||
exploration_noise: BaseNoise | Literal["default"] | None = None
|
exploration_noise: BaseNoise | Literal["default"] | NoiseFactory | None = None
|
||||||
deterministic_eval: bool = True
|
deterministic_eval: bool = True
|
||||||
action_scaling: bool = True
|
action_scaling: bool = True
|
||||||
action_bound_method: Literal["clip"] | None = "clip"
|
action_bound_method: Literal["clip"] | None = "clip"
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
ParamsMixinActorAndDualCritics.__post_init__(self)
|
||||||
|
self._add_transformer(ParamTransformerAutoAlpha("alpha"))
|
||||||
|
self._add_transformer(ParamTransformerNoiseFactory("exploration_noise"))
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TD3Params(ActorAndDualCriticsParams):
|
class TD3Params(Params, ParamsMixinActorAndDualCritics):
|
||||||
tau: float = 0.005
|
tau: float = 0.005
|
||||||
gamma: float = 0.99
|
gamma: float = 0.99
|
||||||
exploration_noise: BaseNoise | Literal["default"] | NoiseFactory | None = "default"
|
exploration_noise: BaseNoise | Literal["default"] | NoiseFactory | None = "default"
|
||||||
@ -99,3 +274,9 @@ class TD3Params(ActorAndDualCriticsParams):
|
|||||||
estimation_step: int = 1
|
estimation_step: int = 1
|
||||||
action_scaling: bool = True
|
action_scaling: bool = True
|
||||||
action_bound_method: Literal["clip"] | None = "clip"
|
action_bound_method: Literal["clip"] | None = "clip"
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
ParamsMixinActorAndDualCritics.__post_init__(self)
|
||||||
|
self._add_transformer(ParamTransformerNoiseFactory("exploration_noise"))
|
||||||
|
self._add_transformer(ParamTransformerFloatEnvParamFactory("policy_noise"))
|
||||||
|
self._add_transformer(ParamTransformerFloatEnvParamFactory("noise_clip"))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user