* Allow specifying trainer callbacks (train_fn, test_fn, stop_fn) in the high-level API, adding the necessary abstractions and pass-on mechanisms
* Add example atari_dqn_hl
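For context, the sketch below illustrates what such trainer callbacks look like, using the callback signatures of Tianshou's trainers (train_fn/test_fn receive the epoch and environment step, stop_fn receives the mean test reward). The new high-level abstractions and pass-on mechanisms themselves live in other files of this change; the bodies and the threshold value here are purely illustrative assumptions.

# Hypothetical callback definitions (illustrative only; not taken from this change):
def train_fn(epoch: int, env_step: int) -> None:
    # e.g. anneal the exploration epsilon of a DQN policy as env_step grows
    ...

def test_fn(epoch: int, env_step: int | None) -> None:
    # e.g. switch the policy to a small, fixed test-time epsilon
    ...

def stop_fn(mean_rewards: float) -> bool:
    # stop training once the mean test reward reaches an illustrative threshold
    return mean_rewards >= 20.0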

from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass
from typing import Any, Literal, Protocol

import torch

from tianshou.exploration import BaseNoise
from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.core import TDevice
from tianshou.highlevel.module.module_opt import ModuleOpt
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.highlevel.params.alpha import AutoAlphaFactory
from tianshou.highlevel.params.dist_fn import (
    DistributionFunctionFactory,
    DistributionFunctionFactoryDefault,
    TDistributionFunction,
)
from tianshou.highlevel.params.env_param import EnvValueFactory, FloatEnvValueFactory
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
from tianshou.highlevel.params.noise import NoiseFactory
from tianshou.utils import MultipleLRSchedulers


@dataclass(kw_only=True)
class ParamTransformerData:
    """Holds data that can be used by `ParamTransformer` instances to perform their transformation.

    The representation contains the superset of all data items that are required by different
    types of agent factories. An agent factory is expected to set only the attributes that are
    relevant to its parameters.
    """

    envs: Environments
    device: TDevice
    optim_factory: OptimizerFactory
    optim: torch.optim.Optimizer | None = None
    """the single optimizer for the case where there is just one"""
    actor: ModuleOpt | None = None
    critic1: ModuleOpt | None = None
    critic2: ModuleOpt | None = None


class ParamTransformer(ABC):
    """Transforms one or more parameters from the representation used by the high-level API
    to the representation required by the (low-level) policy implementation.

    It operates directly on a dictionary of keyword arguments, which is initially
    generated from the parameter dataclass (subclass of `Params`).
    """

    @abstractmethod
    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
        pass

    @staticmethod
    def get(d: dict[str, Any], key: str, drop: bool = False) -> Any:
        value = d[key]
        if drop:
            del d[key]
        return value
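

# Illustrative sketch (not part of the library): a minimal ParamTransformer that renames a key
# in place. Given params = {"discount_factor": 0.99}, a hypothetical transformer like the one
# below would rewrite it to {"gamma": 0.99}:
#
#     class ParamTransformerRenameKey(ParamTransformer):
#         def __init__(self, old_key: str, new_key: str):
#             self.old_key = old_key
#             self.new_key = new_key
#
#         def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
#             params[self.new_key] = self.get(params, self.old_key, drop=True)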


class ParamTransformerDrop(ParamTransformer):
    def __init__(self, *keys: str):
        self.keys = keys

    def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None:
        for k in self.keys:
            del kwargs[k]


class ParamTransformerChangeValue(ParamTransformer):
    def __init__(self, key: str):
        self.key = key

    def transform(self, params: dict[str, Any], data: ParamTransformerData):
        params[self.key] = self.change_value(params[self.key], data)

    @abstractmethod
    def change_value(self, value: Any, data: ParamTransformerData) -> Any:
        pass


class ParamTransformerLRScheduler(ParamTransformer):
    """Transforms a key containing a learning rate scheduler factory (removed) into a key containing
    a learning rate scheduler (added) for the data member `optim`.
    """

    def __init__(self, key_scheduler_factory: str, key_scheduler: str):
        self.key_scheduler_factory = key_scheduler_factory
        self.key_scheduler = key_scheduler

    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
        assert data.optim is not None
        factory: LRSchedulerFactory | None = self.get(params, self.key_scheduler_factory, drop=True)
        params[self.key_scheduler] = (
            factory.create_scheduler(data.optim) if factory is not None else None
        )
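

# Illustrative example (assumed keys): with params = {"lr_scheduler_factory": factory} and
# data.optim set, ParamTransformerLRScheduler("lr_scheduler_factory", "lr_scheduler") removes
# "lr_scheduler_factory" and adds "lr_scheduler" = factory.create_scheduler(data.optim),
# or None if no factory was given.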


class ParamTransformerMultiLRScheduler(ParamTransformer):
    """Transforms several scheduler factories into a single scheduler, which may be a MultipleLRSchedulers instance
    if more than one factory is indeed given.
    """

    def __init__(self, optim_key_list: list[tuple[torch.optim.Optimizer, str]], key_scheduler: str):
        """:param optim_key_list: a list of tuples (optimizer, key of learning rate factory)
        :param key_scheduler: the key under which to store the resulting learning rate scheduler
        """
        self.optim_key_list = optim_key_list
        self.key_scheduler = key_scheduler

    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
        lr_schedulers = []
        for optim, lr_scheduler_factory_key in self.optim_key_list:
            lr_scheduler_factory: LRSchedulerFactory | None = self.get(
                params,
                lr_scheduler_factory_key,
                drop=True,
            )
            if lr_scheduler_factory is not None:
                lr_schedulers.append(lr_scheduler_factory.create_scheduler(optim))
        match len(lr_schedulers):
            case 0:
                lr_scheduler = None
            case 1:
                lr_scheduler = lr_schedulers[0]
            case _:
                lr_scheduler = MultipleLRSchedulers(*lr_schedulers)
        params[self.key_scheduler] = lr_scheduler


class ParamTransformerActorAndCriticLRScheduler(ParamTransformer):
    def __init__(
        self,
        key_scheduler_factory_actor: str,
        key_scheduler_factory_critic: str,
        key_scheduler: str,
    ):
        self.key_factory_actor = key_scheduler_factory_actor
        self.key_factory_critic = key_scheduler_factory_critic
        self.key_scheduler = key_scheduler

    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
        transformer = ParamTransformerMultiLRScheduler(
            [
                (data.actor.optim, self.key_factory_actor),
                (data.critic1.optim, self.key_factory_critic),
            ],
            self.key_scheduler,
        )
        transformer.transform(params, data)


class ParamTransformerActorDualCriticsLRScheduler(ParamTransformer):
    def __init__(
        self,
        key_scheduler_factory_actor: str,
        key_scheduler_factory_critic1: str,
        key_scheduler_factory_critic2: str,
        key_scheduler: str,
    ):
        self.key_factory_actor = key_scheduler_factory_actor
        self.key_factory_critic1 = key_scheduler_factory_critic1
        self.key_factory_critic2 = key_scheduler_factory_critic2
        self.key_scheduler = key_scheduler

    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
        transformer = ParamTransformerMultiLRScheduler(
            [
                (data.actor.optim, self.key_factory_actor),
                (data.critic1.optim, self.key_factory_critic1),
                (data.critic2.optim, self.key_factory_critic2),
            ],
            self.key_scheduler,
        )
        transformer.transform(params, data)


class ParamTransformerAutoAlpha(ParamTransformer):
    def __init__(self, key: str):
        self.key = key

    def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None:
        alpha = self.get(kwargs, self.key)
        if isinstance(alpha, AutoAlphaFactory):
            kwargs[self.key] = alpha.create_auto_alpha(data.envs, data.optim_factory, data.device)


class ParamTransformerNoiseFactory(ParamTransformer):
    def __init__(self, key: str):
        self.key = key

    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
        value = params[self.key]
        if isinstance(value, NoiseFactory):
            params[self.key] = value.create_noise(data.envs)


class ParamTransformerFloatEnvParamFactory(ParamTransformer):
    def __init__(self, key: str):
        self.key = key

    def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None:
        value = kwargs[self.key]
        if isinstance(value, EnvValueFactory):
            kwargs[self.key] = value.create_value(data.envs)


class ParamTransformerDistributionFunction(ParamTransformer):
    def __init__(self, key: str):
        self.key = key

    def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None:
        value = kwargs[self.key]
        if value == "default":
            kwargs[self.key] = DistributionFunctionFactoryDefault().create_dist_fn(data.envs)
        elif isinstance(value, DistributionFunctionFactory):
            kwargs[self.key] = value.create_dist_fn(data.envs)


class ParamTransformerActionScaling(ParamTransformerChangeValue):
    def change_value(self, value: Any, data: ParamTransformerData) -> Any:
        if value == "default":
            return data.envs.get_type().is_continuous()
        else:
            return value


class GetParamTransformersProtocol(Protocol):
    def _get_param_transformers(self) -> list[ParamTransformer]:
        pass


@dataclass
class Params(GetParamTransformersProtocol):
    def create_kwargs(self, data: ParamTransformerData) -> dict[str, Any]:
        params = asdict(self)
        for transformer in self._get_param_transformers():
            transformer.transform(params, data)
        return params

    def _get_param_transformers(self) -> list[ParamTransformer]:
        return []
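

# Illustrative end-to-end sketch (hypothetical subclass, not part of the library), showing how
# `create_kwargs` turns a parameter dataclass into keyword arguments for a low-level policy:
#
#     @dataclass
#     class MyParams(Params):
#         gamma: float = 0.99
#         lr: float = 1e-3
#
#         def _get_param_transformers(self) -> list[ParamTransformer]:
#             # drop "lr": the low-level policy receives an optimizer, not a learning rate
#             return [ParamTransformerDrop("lr")]
#
#     # MyParams().create_kwargs(data) -> {"gamma": 0.99}; the "lr" entry has been removed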


@dataclass
class ParamsMixinLearningRateWithScheduler(GetParamTransformersProtocol):
    lr: float = 1e-3
    lr_scheduler_factory: LRSchedulerFactory | None = None

    def _get_param_transformers(self):
        return [
            ParamTransformerDrop("lr"),
            ParamTransformerLRScheduler("lr_scheduler_factory", "lr_scheduler"),
        ]


@dataclass
class ParamsMixinActorAndCritic(GetParamTransformersProtocol):
    actor_lr: float = 1e-3
    critic_lr: float = 1e-3
    actor_lr_scheduler_factory: LRSchedulerFactory | None = None
    critic_lr_scheduler_factory: LRSchedulerFactory | None = None

    def _get_param_transformers(self):
        return [
            ParamTransformerDrop("actor_lr", "critic_lr"),
            ParamTransformerActorAndCriticLRScheduler(
                "actor_lr_scheduler_factory",
                "critic_lr_scheduler_factory",
                "lr_scheduler",
            ),
        ]


@dataclass
class PGParams(Params):
    """Config of general policy-gradient algorithms."""

    discount_factor: float = 0.99
    reward_normalization: bool = False
    deterministic_eval: bool = False
    action_scaling: bool | Literal["default"] = "default"
    """whether to apply action scaling; when set to "default", it will be enabled for continuous action spaces"""
    action_bound_method: Literal["clip", "tanh"] | None = "clip"

    def _get_param_transformers(self) -> list[ParamTransformer]:
        transformers = super()._get_param_transformers()
        transformers.append(ParamTransformerActionScaling("action_scaling"))
        return transformers
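

# Illustrative note: with the default action_scaling="default", ParamTransformerActionScaling
# resolves the value via data.envs.get_type().is_continuous(), e.g.
# {"action_scaling": "default"} -> {"action_scaling": True} for a continuous-action environment;
# an explicit bool is passed through unchanged.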


@dataclass
class A2CParams(PGParams, ParamsMixinLearningRateWithScheduler):
    vf_coef: float = 0.5
    ent_coef: float = 0.01
    max_grad_norm: float | None = None
    gae_lambda: float = 0.95
    max_batchsize: int = 256
    dist_fn: TDistributionFunction | DistributionFunctionFactory | Literal["default"] = "default"

    def _get_param_transformers(self) -> list[ParamTransformer]:
        transformers = super()._get_param_transformers()
        transformers.extend(ParamsMixinLearningRateWithScheduler._get_param_transformers(self))
        transformers.append(ParamTransformerDistributionFunction("dist_fn"))
        return transformers


@dataclass
class PPOParams(A2CParams):
    """PPO specific config."""

    eps_clip: float = 0.2
    dual_clip: float | None = None
    value_clip: bool = False
    advantage_normalization: bool = True
    recompute_advantage: bool = False


@dataclass
class ParamsMixinActorAndDualCritics(GetParamTransformersProtocol):
    actor_lr: float = 1e-3
    critic1_lr: float = 1e-3
    critic2_lr: float = 1e-3
    actor_lr_scheduler_factory: LRSchedulerFactory | None = None
    critic1_lr_scheduler_factory: LRSchedulerFactory | None = None
    critic2_lr_scheduler_factory: LRSchedulerFactory | None = None

    def _get_param_transformers(self):
        return [
            ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"),
            ParamTransformerActorDualCriticsLRScheduler(
                "actor_lr_scheduler_factory",
                "critic1_lr_scheduler_factory",
                "critic2_lr_scheduler_factory",
                "lr_scheduler",
            ),
        ]


@dataclass
class SACParams(Params, ParamsMixinActorAndDualCritics):
    tau: float = 0.005
    gamma: float = 0.99
    alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] | AutoAlphaFactory = 0.2
    estimation_step: int = 1
    exploration_noise: BaseNoise | Literal["default"] | NoiseFactory | None = None
    deterministic_eval: bool = True
    action_scaling: bool = True
    action_bound_method: Literal["clip"] | None = "clip"

    def _get_param_transformers(self):
        transformers = super()._get_param_transformers()
        transformers.extend(ParamsMixinActorAndDualCritics._get_param_transformers(self))
        transformers.append(ParamTransformerAutoAlpha("alpha"))
        transformers.append(ParamTransformerNoiseFactory("exploration_noise"))
        return transformers
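

# Illustrative note: when `alpha` is an AutoAlphaFactory, ParamTransformerAutoAlpha replaces it
# with the value returned by create_auto_alpha(envs, optim_factory, device), which per the
# annotation above is a (float, torch.Tensor, torch.optim.Optimizer) tuple for automatic entropy
# tuning. Similarly, a NoiseFactory given as `exploration_noise` is resolved to a concrete
# BaseNoise instance for the given environments.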


@dataclass
class DQNParams(Params, ParamsMixinLearningRateWithScheduler):
    discount_factor: float = 0.99
    estimation_step: int = 1
    target_update_freq: int = 0
    reward_normalization: bool = False
    is_double: bool = True
    clip_loss_grad: bool = False

    def _get_param_transformers(self):
        transformers = super()._get_param_transformers()
        transformers.extend(ParamsMixinLearningRateWithScheduler._get_param_transformers(self))
        return transformers


@dataclass
class DDPGParams(Params, ParamsMixinActorAndCritic):
    tau: float = 0.005
    gamma: float = 0.99
    exploration_noise: BaseNoise | Literal["default"] | NoiseFactory | None = "default"
    estimation_step: int = 1
    action_scaling: bool = True
    action_bound_method: Literal["clip"] | None = "clip"

    def _get_param_transformers(self):
        transformers = super()._get_param_transformers()
        transformers.extend(ParamsMixinActorAndCritic._get_param_transformers(self))
        transformers.append(ParamTransformerNoiseFactory("exploration_noise"))
        return transformers


@dataclass
class TD3Params(Params, ParamsMixinActorAndDualCritics):
    tau: float = 0.005
    gamma: float = 0.99
    exploration_noise: BaseNoise | Literal["default"] | NoiseFactory | None = "default"
    policy_noise: float | FloatEnvValueFactory = 0.2
    noise_clip: float | FloatEnvValueFactory = 0.5
    update_actor_freq: int = 2
    estimation_step: int = 1
    action_scaling: bool = True
    action_bound_method: Literal["clip"] | None = "clip"

    def _get_param_transformers(self):
        transformers = super()._get_param_transformers()
        transformers.extend(ParamsMixinActorAndDualCritics._get_param_transformers(self))
        transformers.append(ParamTransformerNoiseFactory("exploration_noise"))
        transformers.append(ParamTransformerFloatEnvParamFactory("policy_noise"))
        transformers.append(ParamTransformerFloatEnvParamFactory("noise_clip"))
        return transformers
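

# Illustrative note: `policy_noise` and `noise_clip` may be given as FloatEnvValueFactory
# instances (for example, a hypothetical factory that scales a base value by the environment's
# maximum action); ParamTransformerFloatEnvParamFactory then resolves them to plain floats
# via create_value(envs) before the kwargs reach the low-level policy.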