Move parameter transformation directly into parameter objects,
achieving greater separation of concerns and improved maintainability
Dominik Jain 2023-09-26 17:43:16 +02:00
parent 38cf982034
commit d4e604b46e
2 changed files with 217 additions and 121 deletions
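
In short: previously each agent factory assembled its own list of ParamTransformer instances and passed them to create_kwargs; after this commit, each Params dataclass registers its transformers in __post_init__, and the factories pass only a ParamTransformerData bundle. A minimal runnable sketch of the pattern (toy names, not the tianshou classes):

from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass, field
from typing import Any


class Transformer(ABC):
    @abstractmethod
    def transform(self, params: dict[str, Any], data: Any) -> None:
        ...


class DropKeys(Transformer):
    """Removes constructor-only keys that the policy class must not receive."""

    def __init__(self, *keys: str):
        self.keys = keys

    def transform(self, params: dict[str, Any], data: Any) -> None:
        for k in self.keys:
            del params[k]


@dataclass
class ToyParams:
    lr: float = 1e-3
    _transformers: list[Transformer] = field(init=False, default_factory=list, repr=False)

    def __post_init__(self) -> None:
        # The params object itself declares which of its fields need transformation.
        self._transformers.append(DropKeys("lr"))

    def create_kwargs(self, data: Any) -> dict[str, Any]:
        params = asdict(self)
        for t in self._transformers:
            t.transform(params, data)
        del params["_transformers"]
        return params


assert ToyParams(lr=3e-4).create_kwargs(data=None) == {}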

tianshou/highlevel/agent.py

@@ -1,7 +1,6 @@
import os
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Any

import torch
@@ -19,12 +18,8 @@ from tianshou.highlevel.module import (
    TDevice,
)
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.highlevel.params.alpha import AutoAlphaFactory
from tianshou.highlevel.params.env_param import FloatEnvParamFactory
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
from tianshou.highlevel.params.noise import NoiseFactory
from tianshou.highlevel.params.policy_params import (
    ParamTransformer,
    ParamTransformerData,
    PPOParams,
    SACParams,
    TD3Params,
@@ -32,7 +27,6 @@ from tianshou.highlevel.params.policy_params import (
from tianshou.policy import BasePolicy, PPOPolicy, SACPolicy, TD3Policy
from tianshou.policy.modelfree.pg import TDistParams
from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
from tianshou.utils import MultipleLRSchedulers
from tianshou.utils.net.common import ActorCritic
CHECKPOINT_DICT_KEY_MODEL = "model"
@@ -145,26 +139,6 @@ class OffpolicyAgentFactory(AgentFactory, ABC):
        )


class ParamTransformerDrop(ParamTransformer):
    def __init__(self, *keys: str):
        self.keys = keys

    def transform(self, kwargs: dict[str, Any]) -> None:
        for k in self.keys:
            del kwargs[k]


class ParamTransformerLRScheduler(ParamTransformer):
    def __init__(self, optim: torch.optim.Optimizer):
        self.optim = optim

    def transform(self, kwargs: dict[str, Any]) -> None:
        factory: LRSchedulerFactory | None = self.get(kwargs, "lr_scheduler_factory", drop=True)
        kwargs["lr_scheduler"] = (
            factory.create_scheduler(self.optim) if factory is not None else None
        )


class _ActorMixin:
    def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory):
        self.actor_module_opt_factory = ActorModuleOptFactory(actor_factory, optim_factory)
@@ -269,8 +243,12 @@ class PPOAgentFactory(OnpolicyAgentFactory, _ActorCriticMixin):
    def create_policy(self, envs: Environments, device: TDevice) -> PPOPolicy:
        actor_critic = self.create_actor_critic_module_opt(envs, device, self.params.lr)
        kwargs = self.params.create_kwargs(
            ParamTransformerDrop("lr"),
            ParamTransformerLRScheduler(actor_critic.optim),
            ParamTransformerData(
                envs=envs,
                device=device,
                optim_factory=self.optim_factory,
                optim=actor_critic.optim,
            ),
        )
        return PPOPolicy(
            actor=actor_critic.actor,
@@ -282,43 +260,6 @@ class PPOAgentFactory(OnpolicyAgentFactory, _ActorCriticMixin):
        )


class ParamTransformerAlpha(ParamTransformer):
    def __init__(self, envs: Environments, optim_factory: OptimizerFactory, device: TDevice):
        self.envs = envs
        self.optim_factory = optim_factory
        self.device = device

    def transform(self, kwargs: dict[str, Any]) -> None:
        key = "alpha"
        alpha = self.get(kwargs, key)
        if isinstance(alpha, AutoAlphaFactory):
            kwargs[key] = alpha.create_auto_alpha(self.envs, self.optim_factory, self.device)


class ParamTransformerMultiLRScheduler(ParamTransformer):
    def __init__(self, optim_key_list: list[tuple[torch.optim.Optimizer, str]]):
        self.optim_key_list = optim_key_list

    def transform(self, kwargs: dict[str, Any]) -> None:
        lr_schedulers = []
        for optim, lr_scheduler_factory_key in self.optim_key_list:
            lr_scheduler_factory: LRSchedulerFactory | None = self.get(
                kwargs,
                lr_scheduler_factory_key,
                drop=True,
            )
            if lr_scheduler_factory is not None:
                lr_schedulers.append(lr_scheduler_factory.create_scheduler(optim))
        match len(lr_schedulers):
            case 0:
                lr_scheduler = None
            case 1:
                lr_scheduler = lr_schedulers[0]
            case _:
                lr_scheduler = MultipleLRSchedulers(*lr_schedulers)
        kwargs["lr_scheduler"] = lr_scheduler


class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
    def __init__(
        self,
@@ -346,15 +287,14 @@ class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
        critic1 = self.create_critic_module_opt(envs, device, self.params.critic1_lr)
        critic2 = self.create_critic2_module_opt(envs, device, self.params.critic2_lr)
        kwargs = self.params.create_kwargs(
            ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"),
            ParamTransformerMultiLRScheduler(
                [
                    (actor.optim, "actor_lr_scheduler_factory"),
                    (critic1.optim, "critic1_lr_scheduler_factory"),
                    (critic2.optim, "critic2_lr_scheduler_factory"),
                ],
            ParamTransformerData(
                envs=envs,
                device=device,
                optim_factory=self.optim_factory,
                actor=actor,
                critic1=critic1,
                critic2=critic2,
            ),
            ParamTransformerAlpha(envs, optim_factory=self.optim_factory, device=device),
        )
        return SACPolicy(
            actor=actor.module,
@@ -369,28 +309,6 @@ class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
        )


class ParamTransformerNoiseFactory(ParamTransformer):
    def __init__(self, key: str, envs: Environments):
        self.key = key
        self.envs = envs

    def transform(self, kwargs: dict[str, Any]) -> None:
        value = kwargs[self.key]
        if isinstance(value, NoiseFactory):
            kwargs[self.key] = value.create_noise(self.envs)


class ParamTransformerFloatEnvParamFactory(ParamTransformer):
    def __init__(self, key: str, envs: Environments):
        self.key = key
        self.envs = envs

    def transform(self, kwargs: dict[str, Any]) -> None:
        value = kwargs[self.key]
        if isinstance(value, FloatEnvParamFactory):
            kwargs[self.key] = value.create_param(self.envs)


class TD3AgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
    def __init__(
        self,
@@ -418,17 +336,14 @@ class TD3AgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
        critic1 = self.create_critic_module_opt(envs, device, self.params.critic1_lr)
        critic2 = self.create_critic2_module_opt(envs, device, self.params.critic2_lr)
        kwargs = self.params.create_kwargs(
            ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"),
            ParamTransformerMultiLRScheduler(
                [
                    (actor.optim, "actor_lr_scheduler_factory"),
                    (critic1.optim, "critic1_lr_scheduler_factory"),
                    (critic2.optim, "critic2_lr_scheduler_factory"),
                ],
            ParamTransformerData(
                envs=envs,
                device=device,
                optim_factory=self.optim_factory,
                actor=actor,
                critic1=critic1,
                critic2=critic2,
            ),
            ParamTransformerNoiseFactory("exploration_noise", envs),
            ParamTransformerFloatEnvParamFactory("policy_noise", envs),
            ParamTransformerFloatEnvParamFactory("noise_clip", envs),
        )
        return TD3Policy(
            actor=actor.module,

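Note that the ParamTransformerMultiLRScheduler removed from this file reappears in policy_params.py below, now reading its optimizers from ParamTransformerData. Its aggregation rule is easy to miss, so here it is isolated as a runnable toy (MultipleLRSchedulers is stubbed for self-containment):

from typing import Any


class MultiScheduler:
    """Stand-in for tianshou.utils.MultipleLRSchedulers: wraps several schedulers."""

    def __init__(self, *schedulers: Any):
        self.schedulers = schedulers


def fold_schedulers(schedulers: list) -> Any:
    # Zero factories yield None, one yields that scheduler directly,
    # several are wrapped into a single combined scheduler.
    match len(schedulers):
        case 0:
            return None
        case 1:
            return schedulers[0]
        case _:
            return MultiScheduler(*schedulers)


assert fold_schedulers([]) is None
assert fold_schedulers(["actor_sched"]) == "actor_sched"
assert isinstance(fold_schedulers(["actor_sched", "critic1_sched"]), MultiScheduler)
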
tianshou/highlevel/params/policy_params.py

@@ -1,19 +1,41 @@
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass
from dataclasses import asdict, dataclass, field
from typing import Any, Literal

import torch

from tianshou.exploration import BaseNoise
from tianshou.highlevel.env import Environments
from tianshou.highlevel.module import ModuleOpt, TDevice
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.highlevel.params.alpha import AutoAlphaFactory
from tianshou.highlevel.params.env_param import FloatEnvParamFactory
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
from tianshou.highlevel.params.noise import NoiseFactory
from tianshou.utils import MultipleLRSchedulers
@dataclass(kw_only=True)
class ParamTransformerData:
"""Holds data that can be used by `ParamTransformer` instances to perform their transformation.
The representation contains the superset of all data items that are required by different types of agent factories.
An agent factory is expected to set only the attributes that are relevant to its parameters.
"""
envs: Environments
device: TDevice
optim_factory: OptimizerFactory
optim: torch.optim.Optimizer | None = None
"""the single optimizer for the case where there is just one"""
actor: ModuleOpt | None = None
critic1: ModuleOpt | None = None
critic2: ModuleOpt | None = None
class ParamTransformer(ABC):
    @abstractmethod
    def transform(self, kwargs: dict[str, Any]) -> None:
    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
        pass

    @staticmethod
@@ -24,13 +46,152 @@ class ParamTransformer(ABC):
        return value


class ParamTransformerDrop(ParamTransformer):
    def __init__(self, *keys: str):
        self.keys = keys

    def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None:
        for k in self.keys:
            del kwargs[k]


class ParamTransformerLRScheduler(ParamTransformer):
    """Transforms a key containing a learning rate scheduler factory (removed) into a key containing
    a learning rate scheduler (added) for the data member `optim`.
    """

    def __init__(self, key_scheduler_factory: str, key_scheduler: str):
        self.key_scheduler_factory = key_scheduler_factory
        self.key_scheduler = key_scheduler

    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
        assert data.optim is not None
        factory: LRSchedulerFactory | None = self.get(params, self.key_scheduler_factory, drop=True)
        params[self.key_scheduler] = (
            factory.create_scheduler(data.optim) if factory is not None else None
        )


class ParamTransformerMultiLRScheduler(ParamTransformer):
    """Transforms several scheduler factories into a single scheduler, which may be
    a MultipleLRSchedulers instance if more than one factory is given.
    """

    def __init__(self, optim_key_list: list[tuple[torch.optim.Optimizer, str]], key_scheduler: str):
        """:param optim_key_list: a list of tuples (optimizer, key of learning rate factory)
        :param key_scheduler: the key under which to store the resulting learning rate scheduler
        """
        self.optim_key_list = optim_key_list
        self.key_scheduler = key_scheduler

    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
        lr_schedulers = []
        for optim, lr_scheduler_factory_key in self.optim_key_list:
            lr_scheduler_factory: LRSchedulerFactory | None = self.get(
                params,
                lr_scheduler_factory_key,
                drop=True,
            )
            if lr_scheduler_factory is not None:
                lr_schedulers.append(lr_scheduler_factory.create_scheduler(optim))
        match len(lr_schedulers):
            case 0:
                lr_scheduler = None
            case 1:
                lr_scheduler = lr_schedulers[0]
            case _:
                lr_scheduler = MultipleLRSchedulers(*lr_schedulers)
        params[self.key_scheduler] = lr_scheduler


class ParamTransformerActorDualCriticsLRScheduler(ParamTransformer):
    def __init__(
        self,
        key_scheduler_factory_actor: str,
        key_scheduler_factory_critic1: str,
        key_scheduler_factory_critic2: str,
        key_scheduler: str,
    ):
        self.key_factory_actor = key_scheduler_factory_actor
        self.key_factory_critic1 = key_scheduler_factory_critic1
        self.key_factory_critic2 = key_scheduler_factory_critic2
        self.key_scheduler = key_scheduler

    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
        transformer = ParamTransformerMultiLRScheduler(
            [
                (data.actor.optim, self.key_factory_actor),
                (data.critic1.optim, self.key_factory_critic1),
                (data.critic2.optim, self.key_factory_critic2),
            ],
            self.key_scheduler,
        )
        transformer.transform(params, data)


class ParamTransformerAutoAlpha(ParamTransformer):
    def __init__(self, key: str):
        self.key = key

    def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None:
        alpha = self.get(kwargs, self.key)
        if isinstance(alpha, AutoAlphaFactory):
            kwargs[self.key] = alpha.create_auto_alpha(data.envs, data.optim_factory, data.device)


class ParamTransformerNoiseFactory(ParamTransformer):
    def __init__(self, key: str):
        self.key = key

    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
        value = params[self.key]
        if isinstance(value, NoiseFactory):
            params[self.key] = value.create_noise(data.envs)


class ParamTransformerFloatEnvParamFactory(ParamTransformer):
    def __init__(self, key: str):
        self.key = key

    def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None:
        value = kwargs[self.key]
        if isinstance(value, FloatEnvParamFactory):
            kwargs[self.key] = value.create_param(data.envs)


class ITransformableParams(ABC):
    @abstractmethod
    def _add_transformer(self, transformer: ParamTransformer):
        pass


@dataclass
class Params:
    def create_kwargs(self, *transformers: ParamTransformer) -> dict[str, Any]:
        d = asdict(self)
        for transformer in transformers:
            transformer.transform(d)
        return d
class Params(ITransformableParams):
    _param_transformers: list[ParamTransformer] = field(
        init=False,
        default_factory=list,
        repr=False,
    )

    def _add_transformer(self, transformer: ParamTransformer):
        self._param_transformers.append(transformer)

    def create_kwargs(self, data: ParamTransformerData) -> dict[str, Any]:
        params = asdict(self)
        for transformer in self._param_transformers:
            transformer.transform(params, data)
        del params["_param_transformers"]
        return params


@dataclass
class ParamsMixinLearningRateWithScheduler(ITransformableParams, ABC):
    lr: float = 1e-3
    lr_scheduler_factory: LRSchedulerFactory | None = None

    def __post_init__(self):
        self._add_transformer(ParamTransformerDrop("lr"))
        self._add_transformer(ParamTransformerLRScheduler("lr_scheduler_factory", "lr_scheduler"))


@dataclass
@@ -54,7 +215,7 @@ class A2CParams(PGParams):
@dataclass
class PPOParams(A2CParams):
class PPOParams(A2CParams, ParamsMixinLearningRateWithScheduler):
    """PPO specific config."""

    eps_clip: float = 0.2
@@ -62,12 +223,10 @@ class PPOParams(A2CParams):
    value_clip: bool = False
    advantage_normalization: bool = True
    recompute_advantage: bool = False
    lr: float = 1e-3
    lr_scheduler_factory: LRSchedulerFactory | None = None


@dataclass
class ActorAndDualCriticsParams(Params):
class ParamsMixinActorAndDualCritics(ITransformableParams, ABC):
    actor_lr: float = 1e-3
    critic1_lr: float = 1e-3
    critic2_lr: float = 1e-3
@@ -75,21 +234,37 @@ class ActorAndDualCriticsParams(Params):
    critic1_lr_scheduler_factory: LRSchedulerFactory | None = None
    critic2_lr_scheduler_factory: LRSchedulerFactory | None = None

    def __post_init__(self):
        self._add_transformer(ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"))
        self._add_transformer(
            ParamTransformerActorDualCriticsLRScheduler(
                "actor_lr_scheduler_factory",
                "critic1_lr_scheduler_factory",
                "critic2_lr_scheduler_factory",
                "lr_scheduler",
            ),
        )


@dataclass
class SACParams(ActorAndDualCriticsParams):
class SACParams(Params, ParamsMixinActorAndDualCritics):
    tau: float = 0.005
    gamma: float = 0.99
    alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] | AutoAlphaFactory = 0.2
    estimation_step: int = 1
    exploration_noise: BaseNoise | Literal["default"] | None = None
    exploration_noise: BaseNoise | Literal["default"] | NoiseFactory | None = None
    deterministic_eval: bool = True
    action_scaling: bool = True
    action_bound_method: Literal["clip"] | None = "clip"

    def __post_init__(self):
        ParamsMixinActorAndDualCritics.__post_init__(self)
        self._add_transformer(ParamTransformerAutoAlpha("alpha"))
        self._add_transformer(ParamTransformerNoiseFactory("exploration_noise"))


@dataclass
class TD3Params(ActorAndDualCriticsParams):
class TD3Params(Params, ParamsMixinActorAndDualCritics):
    tau: float = 0.005
    gamma: float = 0.99
    exploration_noise: BaseNoise | Literal["default"] | NoiseFactory | None = "default"
@@ -99,3 +274,9 @@ class TD3Params(ActorAndDualCriticsParams):
    estimation_step: int = 1
    action_scaling: bool = True
    action_bound_method: Literal["clip"] | None = "clip"

    def __post_init__(self):
        ParamsMixinActorAndDualCritics.__post_init__(self)
        self._add_transformer(ParamTransformerNoiseFactory("exploration_noise"))
        self._add_transformer(ParamTransformerFloatEnvParamFactory("policy_noise"))
        self._add_transformer(ParamTransformerFloatEnvParamFactory("noise_clip"))
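
A subtlety in the dataclasses above: SACParams and TD3Params combine Params with ParamsMixinActorAndDualCritics, so their __post_init__ must chain the mixin's explicitly before registering algorithm-specific transformers. A runnable toy of that cooperation (stand-in names, not the tianshou classes):

from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass, field
from typing import Any


class Transformer(ABC):
    @abstractmethod
    def transform(self, params: dict[str, Any]) -> None:
        ...


class Rename(Transformer):
    def __init__(self, old: str, new: str):
        self.old, self.new = old, new

    def transform(self, params: dict[str, Any]) -> None:
        params[self.new] = params.pop(self.old)


@dataclass
class ToyBase:
    _transformers: list[Transformer] = field(init=False, default_factory=list, repr=False)

    def create_kwargs(self) -> dict[str, Any]:
        params = asdict(self)
        for t in self._transformers:
            t.transform(params)
        del params["_transformers"]
        return params


@dataclass
class ToyMixinLR(ABC):
    lr: float = 1e-3

    def __post_init__(self):
        self._transformers.append(Rename("lr", "learning_rate"))


@dataclass
class ToyAlgoParams(ToyBase, ToyMixinLR):
    gamma: float = 0.99

    def __post_init__(self):
        # Chain the mixin explicitly (as SACParams/TD3Params do above),
        # then register transformers of our own.
        ToyMixinLR.__post_init__(self)
        self._transformers.append(Rename("gamma", "discount_factor"))


print(ToyAlgoParams().create_kwargs())  # {'learning_rate': 0.001, 'discount_factor': 0.99}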