Move parameter transformation directly into parameter objects,
achieving greater separation of concerns and improved maintainability
This commit is contained in:
parent
38cf982034
commit
d4e604b46e
@ -1,7 +1,6 @@
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
@ -19,12 +18,8 @@ from tianshou.highlevel.module import (
|
||||
TDevice,
|
||||
)
|
||||
from tianshou.highlevel.optim import OptimizerFactory
|
||||
from tianshou.highlevel.params.alpha import AutoAlphaFactory
|
||||
from tianshou.highlevel.params.env_param import FloatEnvParamFactory
|
||||
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
|
||||
from tianshou.highlevel.params.noise import NoiseFactory
|
||||
from tianshou.highlevel.params.policy_params import (
|
||||
ParamTransformer,
|
||||
ParamTransformerData,
|
||||
PPOParams,
|
||||
SACParams,
|
||||
TD3Params,
|
||||
@ -32,7 +27,6 @@ from tianshou.highlevel.params.policy_params import (
|
||||
from tianshou.policy import BasePolicy, PPOPolicy, SACPolicy, TD3Policy
|
||||
from tianshou.policy.modelfree.pg import TDistParams
|
||||
from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
|
||||
from tianshou.utils import MultipleLRSchedulers
|
||||
from tianshou.utils.net.common import ActorCritic
|
||||
|
||||
CHECKPOINT_DICT_KEY_MODEL = "model"
|
||||
@ -145,26 +139,6 @@ class OffpolicyAgentFactory(AgentFactory, ABC):
|
||||
)
|
||||
|
||||
|
||||
class ParamTransformerDrop(ParamTransformer):
|
||||
def __init__(self, *keys: str):
|
||||
self.keys = keys
|
||||
|
||||
def transform(self, kwargs: dict[str, Any]) -> None:
|
||||
for k in self.keys:
|
||||
del kwargs[k]
|
||||
|
||||
|
||||
class ParamTransformerLRScheduler(ParamTransformer):
|
||||
def __init__(self, optim: torch.optim.Optimizer):
|
||||
self.optim = optim
|
||||
|
||||
def transform(self, kwargs: dict[str, Any]) -> None:
|
||||
factory: LRSchedulerFactory | None = self.get(kwargs, "lr_scheduler_factory", drop=True)
|
||||
kwargs["lr_scheduler"] = (
|
||||
factory.create_scheduler(self.optim) if factory is not None else None
|
||||
)
|
||||
|
||||
|
||||
class _ActorMixin:
|
||||
def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory):
|
||||
self.actor_module_opt_factory = ActorModuleOptFactory(actor_factory, optim_factory)
|
||||
@ -269,8 +243,12 @@ class PPOAgentFactory(OnpolicyAgentFactory, _ActorCriticMixin):
|
||||
def create_policy(self, envs: Environments, device: TDevice) -> PPOPolicy:
|
||||
actor_critic = self.create_actor_critic_module_opt(envs, device, self.params.lr)
|
||||
kwargs = self.params.create_kwargs(
|
||||
ParamTransformerDrop("lr"),
|
||||
ParamTransformerLRScheduler(actor_critic.optim),
|
||||
ParamTransformerData(
|
||||
envs=envs,
|
||||
device=device,
|
||||
optim_factory=self.optim_factory,
|
||||
optim=actor_critic.optim,
|
||||
),
|
||||
)
|
||||
return PPOPolicy(
|
||||
actor=actor_critic.actor,
|
||||
@ -282,43 +260,6 @@ class PPOAgentFactory(OnpolicyAgentFactory, _ActorCriticMixin):
|
||||
)
|
||||
|
||||
|
||||
class ParamTransformerAlpha(ParamTransformer):
|
||||
def __init__(self, envs: Environments, optim_factory: OptimizerFactory, device: TDevice):
|
||||
self.envs = envs
|
||||
self.optim_factory = optim_factory
|
||||
self.device = device
|
||||
|
||||
def transform(self, kwargs: dict[str, Any]) -> None:
|
||||
key = "alpha"
|
||||
alpha = self.get(kwargs, key)
|
||||
if isinstance(alpha, AutoAlphaFactory):
|
||||
kwargs[key] = alpha.create_auto_alpha(self.envs, self.optim_factory, self.device)
|
||||
|
||||
|
||||
class ParamTransformerMultiLRScheduler(ParamTransformer):
|
||||
def __init__(self, optim_key_list: list[tuple[torch.optim.Optimizer, str]]):
|
||||
self.optim_key_list = optim_key_list
|
||||
|
||||
def transform(self, kwargs: dict[str, Any]) -> None:
|
||||
lr_schedulers = []
|
||||
for optim, lr_scheduler_factory_key in self.optim_key_list:
|
||||
lr_scheduler_factory: LRSchedulerFactory | None = self.get(
|
||||
kwargs,
|
||||
lr_scheduler_factory_key,
|
||||
drop=True,
|
||||
)
|
||||
if lr_scheduler_factory is not None:
|
||||
lr_schedulers.append(lr_scheduler_factory.create_scheduler(optim))
|
||||
match len(lr_schedulers):
|
||||
case 0:
|
||||
lr_scheduler = None
|
||||
case 1:
|
||||
lr_scheduler = lr_schedulers[0]
|
||||
case _:
|
||||
lr_scheduler = MultipleLRSchedulers(*lr_schedulers)
|
||||
kwargs["lr_scheduler"] = lr_scheduler
|
||||
|
||||
|
||||
class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
|
||||
def __init__(
|
||||
self,
|
||||
@ -346,15 +287,14 @@ class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
|
||||
critic1 = self.create_critic_module_opt(envs, device, self.params.critic1_lr)
|
||||
critic2 = self.create_critic2_module_opt(envs, device, self.params.critic2_lr)
|
||||
kwargs = self.params.create_kwargs(
|
||||
ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"),
|
||||
ParamTransformerMultiLRScheduler(
|
||||
[
|
||||
(actor.optim, "actor_lr_scheduler_factory"),
|
||||
(critic1.optim, "critic1_lr_scheduler_factory"),
|
||||
(critic2.optim, "critic2_lr_scheduler_factory"),
|
||||
],
|
||||
ParamTransformerData(
|
||||
envs=envs,
|
||||
device=device,
|
||||
optim_factory=self.optim_factory,
|
||||
actor=actor,
|
||||
critic1=critic1,
|
||||
critic2=critic2,
|
||||
),
|
||||
ParamTransformerAlpha(envs, optim_factory=self.optim_factory, device=device),
|
||||
)
|
||||
return SACPolicy(
|
||||
actor=actor.module,
|
||||
@ -369,28 +309,6 @@ class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
|
||||
)
|
||||
|
||||
|
||||
class ParamTransformerNoiseFactory(ParamTransformer):
|
||||
def __init__(self, key: str, envs: Environments):
|
||||
self.key = key
|
||||
self.envs = envs
|
||||
|
||||
def transform(self, kwargs: dict[str, Any]) -> None:
|
||||
value = kwargs[self.key]
|
||||
if isinstance(value, NoiseFactory):
|
||||
kwargs[self.key] = value.create_noise(self.envs)
|
||||
|
||||
|
||||
class ParamTransformerFloatEnvParamFactory(ParamTransformer):
|
||||
def __init__(self, key: str, envs: Environments):
|
||||
self.key = key
|
||||
self.envs = envs
|
||||
|
||||
def transform(self, kwargs: dict[str, Any]) -> None:
|
||||
value = kwargs[self.key]
|
||||
if isinstance(value, FloatEnvParamFactory):
|
||||
kwargs[self.key] = value.create_param(self.envs)
|
||||
|
||||
|
||||
class TD3AgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
|
||||
def __init__(
|
||||
self,
|
||||
@ -418,17 +336,14 @@ class TD3AgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
|
||||
critic1 = self.create_critic_module_opt(envs, device, self.params.critic1_lr)
|
||||
critic2 = self.create_critic2_module_opt(envs, device, self.params.critic2_lr)
|
||||
kwargs = self.params.create_kwargs(
|
||||
ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"),
|
||||
ParamTransformerMultiLRScheduler(
|
||||
[
|
||||
(actor.optim, "actor_lr_scheduler_factory"),
|
||||
(critic1.optim, "critic1_lr_scheduler_factory"),
|
||||
(critic2.optim, "critic2_lr_scheduler_factory"),
|
||||
],
|
||||
ParamTransformerData(
|
||||
envs=envs,
|
||||
device=device,
|
||||
optim_factory=self.optim_factory,
|
||||
actor=actor,
|
||||
critic1=critic1,
|
||||
critic2=critic2,
|
||||
),
|
||||
ParamTransformerNoiseFactory("exploration_noise", envs),
|
||||
ParamTransformerFloatEnvParamFactory("policy_noise", envs),
|
||||
ParamTransformerFloatEnvParamFactory("noise_clip", envs),
|
||||
)
|
||||
return TD3Policy(
|
||||
actor=actor.module,
|
||||
|
@ -1,19 +1,41 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import asdict, dataclass
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Any, Literal
|
||||
|
||||
import torch
|
||||
|
||||
from tianshou.exploration import BaseNoise
|
||||
from tianshou.highlevel.env import Environments
|
||||
from tianshou.highlevel.module import ModuleOpt, TDevice
|
||||
from tianshou.highlevel.optim import OptimizerFactory
|
||||
from tianshou.highlevel.params.alpha import AutoAlphaFactory
|
||||
from tianshou.highlevel.params.env_param import FloatEnvParamFactory
|
||||
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
|
||||
from tianshou.highlevel.params.noise import NoiseFactory
|
||||
from tianshou.utils import MultipleLRSchedulers
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class ParamTransformerData:
|
||||
"""Holds data that can be used by `ParamTransformer` instances to perform their transformation.
|
||||
|
||||
The representation contains the superset of all data items that are required by different types of agent factories.
|
||||
An agent factory is expected to set only the attributes that are relevant to its parameters.
|
||||
"""
|
||||
|
||||
envs: Environments
|
||||
device: TDevice
|
||||
optim_factory: OptimizerFactory
|
||||
optim: torch.optim.Optimizer | None = None
|
||||
"""the single optimizer for the case where there is just one"""
|
||||
actor: ModuleOpt | None = None
|
||||
critic1: ModuleOpt | None = None
|
||||
critic2: ModuleOpt | None = None
|
||||
|
||||
|
||||
class ParamTransformer(ABC):
|
||||
@abstractmethod
|
||||
def transform(self, kwargs: dict[str, Any]) -> None:
|
||||
def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@ -24,13 +46,152 @@ class ParamTransformer(ABC):
|
||||
return value
|
||||
|
||||
|
||||
class ParamTransformerDrop(ParamTransformer):
|
||||
def __init__(self, *keys: str):
|
||||
self.keys = keys
|
||||
|
||||
def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None:
|
||||
for k in self.keys:
|
||||
del kwargs[k]
|
||||
|
||||
|
||||
class ParamTransformerLRScheduler(ParamTransformer):
|
||||
"""Transforms a key containing a learning rate scheduler factory (removed) into a key containing
|
||||
a learning rate scheduler (added) for the data member `optim`.
|
||||
"""
|
||||
|
||||
def __init__(self, key_scheduler_factory: str, key_scheduler: str):
|
||||
self.key_scheduler_factory = key_scheduler_factory
|
||||
self.key_scheduler = key_scheduler
|
||||
|
||||
def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
|
||||
assert data.optim is not None
|
||||
factory: LRSchedulerFactory | None = self.get(params, self.key_scheduler_factory, drop=True)
|
||||
params[self.key_scheduler] = (
|
||||
factory.create_scheduler(data.optim) if factory is not None else None
|
||||
)
|
||||
|
||||
|
||||
class ParamTransformerMultiLRScheduler(ParamTransformer):
|
||||
"""Transforms several scheduler factories into a single scheduler, which may be a MultipleLRSchedulers instance
|
||||
if more than one factory is indeed given.
|
||||
"""
|
||||
|
||||
def __init__(self, optim_key_list: list[tuple[torch.optim.Optimizer, str]], key_scheduler: str):
|
||||
""":param optim_key_list: a list of tuples (optimizer, key of learning rate factory)
|
||||
:param key_scheduler: the key under which to store the resulting learning rate scheduler
|
||||
"""
|
||||
self.optim_key_list = optim_key_list
|
||||
self.key_scheduler = key_scheduler
|
||||
|
||||
def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
|
||||
lr_schedulers = []
|
||||
for optim, lr_scheduler_factory_key in self.optim_key_list:
|
||||
lr_scheduler_factory: LRSchedulerFactory | None = self.get(
|
||||
params,
|
||||
lr_scheduler_factory_key,
|
||||
drop=True,
|
||||
)
|
||||
if lr_scheduler_factory is not None:
|
||||
lr_schedulers.append(lr_scheduler_factory.create_scheduler(optim))
|
||||
match len(lr_schedulers):
|
||||
case 0:
|
||||
lr_scheduler = None
|
||||
case 1:
|
||||
lr_scheduler = lr_schedulers[0]
|
||||
case _:
|
||||
lr_scheduler = MultipleLRSchedulers(*lr_schedulers)
|
||||
params[self.key_scheduler] = lr_scheduler
|
||||
|
||||
|
||||
class ParamTransformerActorDualCriticsLRScheduler(ParamTransformer):
|
||||
def __init__(
|
||||
self,
|
||||
key_scheduler_factory_actor: str,
|
||||
key_scheduler_factory_critic1: str,
|
||||
key_scheduler_factory_critic2: str,
|
||||
key_scheduler: str,
|
||||
):
|
||||
self.key_factory_actor = key_scheduler_factory_actor
|
||||
self.key_factory_critic1 = key_scheduler_factory_critic1
|
||||
self.key_factory_critic2 = key_scheduler_factory_critic2
|
||||
self.key_scheduler = key_scheduler
|
||||
|
||||
def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
|
||||
transformer = ParamTransformerMultiLRScheduler(
|
||||
[
|
||||
(data.actor.optim, self.key_factory_actor),
|
||||
(data.critic1.optim, self.key_factory_critic1),
|
||||
(data.critic2.optim, self.key_factory_critic2),
|
||||
],
|
||||
self.key_scheduler,
|
||||
)
|
||||
transformer.transform(params, data)
|
||||
|
||||
|
||||
class ParamTransformerAutoAlpha(ParamTransformer):
|
||||
def __init__(self, key: str):
|
||||
self.key = key
|
||||
|
||||
def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None:
|
||||
alpha = self.get(kwargs, self.key)
|
||||
if isinstance(alpha, AutoAlphaFactory):
|
||||
kwargs[self.key] = alpha.create_auto_alpha(data.envs, data.optim_factory, data.device)
|
||||
|
||||
|
||||
class ParamTransformerNoiseFactory(ParamTransformer):
|
||||
def __init__(self, key: str):
|
||||
self.key = key
|
||||
|
||||
def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
|
||||
value = params[self.key]
|
||||
if isinstance(value, NoiseFactory):
|
||||
params[self.key] = value.create_noise(data.envs)
|
||||
|
||||
|
||||
class ParamTransformerFloatEnvParamFactory(ParamTransformer):
|
||||
def __init__(self, key: str):
|
||||
self.key = key
|
||||
|
||||
def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None:
|
||||
value = kwargs[self.key]
|
||||
if isinstance(value, FloatEnvParamFactory):
|
||||
kwargs[self.key] = value.create_param(data.envs)
|
||||
|
||||
|
||||
class ITransformableParams(ABC):
|
||||
@abstractmethod
|
||||
def _add_transformer(self, transformer: ParamTransformer):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Params:
|
||||
def create_kwargs(self, *transformers: ParamTransformer) -> dict[str, Any]:
|
||||
d = asdict(self)
|
||||
for transformer in transformers:
|
||||
transformer.transform(d)
|
||||
return d
|
||||
class Params(ITransformableParams):
|
||||
_param_transformers: list[ParamTransformer] = field(
|
||||
init=False,
|
||||
default_factory=list,
|
||||
repr=False,
|
||||
)
|
||||
|
||||
def _add_transformer(self, transformer: ParamTransformer):
|
||||
self._param_transformers.append(transformer)
|
||||
|
||||
def create_kwargs(self, data: ParamTransformerData) -> dict[str, Any]:
|
||||
params = asdict(self)
|
||||
for transformer in self._param_transformers:
|
||||
transformer.transform(params, data)
|
||||
del params["_param_transformers"]
|
||||
return params
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParamsMixinLearningRateWithScheduler(ITransformableParams, ABC):
|
||||
lr: float = 1e-3
|
||||
lr_scheduler_factory: LRSchedulerFactory | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
self._add_transformer(ParamTransformerDrop("lr"))
|
||||
self._add_transformer(ParamTransformerLRScheduler("lr_scheduler_factory", "lr_scheduler"))
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -54,7 +215,7 @@ class A2CParams(PGParams):
|
||||
|
||||
|
||||
@dataclass
|
||||
class PPOParams(A2CParams):
|
||||
class PPOParams(A2CParams, ParamsMixinLearningRateWithScheduler):
|
||||
"""PPO specific config."""
|
||||
|
||||
eps_clip: float = 0.2
|
||||
@ -62,12 +223,10 @@ class PPOParams(A2CParams):
|
||||
value_clip: bool = False
|
||||
advantage_normalization: bool = True
|
||||
recompute_advantage: bool = False
|
||||
lr: float = 1e-3
|
||||
lr_scheduler_factory: LRSchedulerFactory | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActorAndDualCriticsParams(Params):
|
||||
class ParamsMixinActorAndDualCritics(ITransformableParams, ABC):
|
||||
actor_lr: float = 1e-3
|
||||
critic1_lr: float = 1e-3
|
||||
critic2_lr: float = 1e-3
|
||||
@ -75,21 +234,37 @@ class ActorAndDualCriticsParams(Params):
|
||||
critic1_lr_scheduler_factory: LRSchedulerFactory | None = None
|
||||
critic2_lr_scheduler_factory: LRSchedulerFactory | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
self._add_transformer(ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"))
|
||||
self._add_transformer(
|
||||
ParamTransformerActorDualCriticsLRScheduler(
|
||||
"actor_lr_scheduler_factory",
|
||||
"critic1_lr_scheduler_factory",
|
||||
"critic2_lr_scheduler_factory",
|
||||
"lr_scheduler",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SACParams(ActorAndDualCriticsParams):
|
||||
class SACParams(Params, ParamsMixinActorAndDualCritics):
|
||||
tau: float = 0.005
|
||||
gamma: float = 0.99
|
||||
alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] | AutoAlphaFactory = 0.2
|
||||
estimation_step: int = 1
|
||||
exploration_noise: BaseNoise | Literal["default"] | None = None
|
||||
exploration_noise: BaseNoise | Literal["default"] | NoiseFactory | None = None
|
||||
deterministic_eval: bool = True
|
||||
action_scaling: bool = True
|
||||
action_bound_method: Literal["clip"] | None = "clip"
|
||||
|
||||
def __post_init__(self):
|
||||
ParamsMixinActorAndDualCritics.__post_init__(self)
|
||||
self._add_transformer(ParamTransformerAutoAlpha("alpha"))
|
||||
self._add_transformer(ParamTransformerNoiseFactory("exploration_noise"))
|
||||
|
||||
|
||||
@dataclass
|
||||
class TD3Params(ActorAndDualCriticsParams):
|
||||
class TD3Params(Params, ParamsMixinActorAndDualCritics):
|
||||
tau: float = 0.005
|
||||
gamma: float = 0.99
|
||||
exploration_noise: BaseNoise | Literal["default"] | NoiseFactory | None = "default"
|
||||
@ -99,3 +274,9 @@ class TD3Params(ActorAndDualCriticsParams):
|
||||
estimation_step: int = 1
|
||||
action_scaling: bool = True
|
||||
action_bound_method: Literal["clip"] | None = "clip"
|
||||
|
||||
def __post_init__(self):
|
||||
ParamsMixinActorAndDualCritics.__post_init__(self)
|
||||
self._add_transformer(ParamTransformerNoiseFactory("exploration_noise"))
|
||||
self._add_transformer(ParamTransformerFloatEnvParamFactory("policy_noise"))
|
||||
self._add_transformer(ParamTransformerFloatEnvParamFactory("noise_clip"))
|
||||
|
Loading…
x
Reference in New Issue
Block a user