# Tianshou/tianshou/highlevel/params/policy_params.py

from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass
from typing import Any, Literal, Protocol

import torch

from tianshou.exploration import BaseNoise
from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.core import TDevice
from tianshou.highlevel.module.module_opt import ModuleOpt
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.highlevel.params.alpha import AutoAlphaFactory
from tianshou.highlevel.params.dist_fn import (
    DistributionFunctionFactory,
    DistributionFunctionFactoryDefault,
    TDistributionFunction,
)
from tianshou.highlevel.params.env_param import EnvValueFactory, FloatEnvValueFactory
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
from tianshou.highlevel.params.noise import NoiseFactory
from tianshou.utils import MultipleLRSchedulers


@dataclass(kw_only=True)
class ParamTransformerData:
    """Holds data that can be used by `ParamTransformer` instances to perform their transformation.

    The representation contains the superset of all data items that are required by different types of agent factories.
    An agent factory is expected to set only the attributes that are relevant to its parameters.
    """
envs: Environments
device: TDevice
optim_factory: OptimizerFactory
optim: torch.optim.Optimizer | None = None
"""the single optimizer for the case where there is just one"""
actor: ModuleOpt | None = None
critic1: ModuleOpt | None = None
critic2: ModuleOpt | None = None


class ParamTransformer(ABC):
    """Transforms one or more parameters from the representation used by the high-level API
    to the representation required by the (low-level) policy implementation.

    It operates directly on a dictionary of keyword arguments, which is initially
    generated from the parameter dataclass (subclass of `Params`).
    """

    @abstractmethod
    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
        """Transforms, in place, the given dictionary of keyword arguments (generated from a `Params` instance)."""

    @staticmethod
    def get(d: dict[str, Any], key: str, drop: bool = False) -> Any:
        """Returns the value stored under the given key, removing the entry from the dictionary if `drop` is True."""
        value = d[key]
        if drop:
            del d[key]
        return value


class ParamTransformerDrop(ParamTransformer):
    """Removes the given keys from the dictionary of keyword arguments."""

    def __init__(self, *keys: str):
        self.keys = keys

    def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None:
        for k in self.keys:
            del kwargs[k]
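
# Illustrative sketch, assuming a `ParamTransformerData` instance `data` is available: a transformer
# mutates the keyword-argument dictionary in place, e.g. ParamTransformerDrop removes entries that
# the low-level policy constructor does not accept.
#
#     kwargs = {"lr": 1e-3, "gamma": 0.99}
#     ParamTransformerDrop("lr").transform(kwargs, data)  # `data` is unused by this transformer
#     assert kwargs == {"gamma": 0.99}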


class ParamTransformerChangeValue(ParamTransformer):
    """Base class for transformers which change the value stored under a given key."""

    def __init__(self, key: str):
        self.key = key

    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
        params[self.key] = self.change_value(params[self.key], data)

    @abstractmethod
    def change_value(self, value: Any, data: ParamTransformerData) -> Any:
        pass


class ParamTransformerLRScheduler(ParamTransformer):
    """Transforms a key containing a learning rate scheduler factory (removed) into a key containing
    a learning rate scheduler (added) for the data member `optim`.
    """

    def __init__(self, key_scheduler_factory: str, key_scheduler: str):
        self.key_scheduler_factory = key_scheduler_factory
        self.key_scheduler = key_scheduler

    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
assert data.optim is not None
factory: LRSchedulerFactory | None = self.get(params, self.key_scheduler_factory, drop=True)
params[self.key_scheduler] = (
factory.create_scheduler(data.optim) if factory is not None else None
)


class ParamTransformerMultiLRScheduler(ParamTransformer):
    """Transforms several scheduler factories into a single scheduler, which may be a `MultipleLRSchedulers` instance
    if more than one factory is indeed given.
    """

    def __init__(self, optim_key_list: list[tuple[torch.optim.Optimizer, str]], key_scheduler: str):
        """:param optim_key_list: a list of tuples (optimizer, key of the learning rate scheduler factory)
        :param key_scheduler: the key under which to store the resulting learning rate scheduler
        """
        self.optim_key_list = optim_key_list
        self.key_scheduler = key_scheduler

    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
lr_schedulers = []
for optim, lr_scheduler_factory_key in self.optim_key_list:
lr_scheduler_factory: LRSchedulerFactory | None = self.get(
params,
lr_scheduler_factory_key,
drop=True,
)
if lr_scheduler_factory is not None:
lr_schedulers.append(lr_scheduler_factory.create_scheduler(optim))
match len(lr_schedulers):
case 0:
lr_scheduler = None
case 1:
lr_scheduler = lr_schedulers[0]
case _:
lr_scheduler = MultipleLRSchedulers(*lr_schedulers)
params[self.key_scheduler] = lr_scheduler


class ParamTransformerActorAndCriticLRScheduler(ParamTransformer):
    def __init__(
        self,
        key_scheduler_factory_actor: str,
        key_scheduler_factory_critic: str,
        key_scheduler: str,
    ):
        self.key_factory_actor = key_scheduler_factory_actor
        self.key_factory_critic = key_scheduler_factory_critic
        self.key_scheduler = key_scheduler

    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
        assert data.actor is not None and data.critic1 is not None
        transformer = ParamTransformerMultiLRScheduler(
            [
                (data.actor.optim, self.key_factory_actor),
                (data.critic1.optim, self.key_factory_critic),
            ],
            self.key_scheduler,
        )
        transformer.transform(params, data)


class ParamTransformerActorDualCriticsLRScheduler(ParamTransformer):
    def __init__(
        self,
        key_scheduler_factory_actor: str,
        key_scheduler_factory_critic1: str,
        key_scheduler_factory_critic2: str,
        key_scheduler: str,
    ):
        self.key_factory_actor = key_scheduler_factory_actor
        self.key_factory_critic1 = key_scheduler_factory_critic1
        self.key_factory_critic2 = key_scheduler_factory_critic2
        self.key_scheduler = key_scheduler

    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
        assert data.actor is not None and data.critic1 is not None and data.critic2 is not None
        transformer = ParamTransformerMultiLRScheduler(
            [
                (data.actor.optim, self.key_factory_actor),
                (data.critic1.optim, self.key_factory_critic1),
                (data.critic2.optim, self.key_factory_critic2),
            ],
            self.key_scheduler,
        )
        transformer.transform(params, data)


class ParamTransformerAutoAlpha(ParamTransformer):
    """Transforms an `AutoAlphaFactory` value under the given key into the auto-tuned alpha representation it creates (plain floats are left unchanged)."""

    def __init__(self, key: str):
        self.key = key

    def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None:
        alpha = self.get(kwargs, self.key)
        if isinstance(alpha, AutoAlphaFactory):
            kwargs[self.key] = alpha.create_auto_alpha(data.envs, data.optim_factory, data.device)


class ParamTransformerNoiseFactory(ParamTransformer):
    """Transforms a `NoiseFactory` value under the given key into a noise instance created for the given environments."""

    def __init__(self, key: str):
        self.key = key

    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
        value = params[self.key]
        if isinstance(value, NoiseFactory):
            params[self.key] = value.create_noise(data.envs)


class ParamTransformerFloatEnvParamFactory(ParamTransformer):
    """Transforms an `EnvValueFactory` value under the given key into a concrete value derived from the environments."""

    def __init__(self, key: str):
        self.key = key

    def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None:
        value = kwargs[self.key]
        if isinstance(value, EnvValueFactory):
            kwargs[self.key] = value.create_value(data.envs)


class ParamTransformerDistributionFunction(ParamTransformer):
    """Resolves the distribution function under the given key: the value "default" and `DistributionFunctionFactory`
    instances are converted into concrete distribution functions based on the environments.
    """

    def __init__(self, key: str):
        self.key = key

    def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None:
        value = kwargs[self.key]
        if value == "default":
            kwargs[self.key] = DistributionFunctionFactoryDefault().create_dist_fn(data.envs)
        elif isinstance(value, DistributionFunctionFactory):
            kwargs[self.key] = value.create_dist_fn(data.envs)


class ParamTransformerActionScaling(ParamTransformerChangeValue):
    """Resolves the value "default" for the action scaling parameter to a Boolean that depends on whether the action space is continuous."""

    def change_value(self, value: Any, data: ParamTransformerData) -> Any:
        if value == "default":
            return data.envs.get_type().is_continuous()
        else:
            return value


class GetParamTransformersProtocol(Protocol):
    def _get_param_transformers(self) -> list[ParamTransformer]:
        pass


@dataclass
class Params(GetParamTransformersProtocol):
    """Base class for dataclasses which provide the parameters of a policy/algorithm via `create_kwargs`."""

    def create_kwargs(self, data: ParamTransformerData) -> dict[str, Any]:
        """Creates the keyword arguments for the policy by applying all parameter transformers to the dataclass fields."""
        params = asdict(self)
        for transformer in self._get_param_transformers():
            transformer.transform(params, data)
        return params

    def _get_param_transformers(self) -> list[ParamTransformer]:
        return []
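
# Illustrative sketch, assuming `envs`, `device`, `optim_factory` and `actor` have been created by an
# agent factory and `SomePolicy` stands in for a (hypothetical) low-level policy class: the factory
# bundles the shared objects into a `ParamTransformerData` and lets the params object produce the
# keyword arguments for the policy constructor.
#
#     data = ParamTransformerData(envs=envs, device=device, optim_factory=optim_factory, optim=actor.optim, actor=actor)
#     kwargs = params.create_kwargs(data)  # dataclass fields, transformed for the low-level policy
#     policy = SomePolicy(**kwargs)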


@dataclass
class ParamsMixinLearningRateWithScheduler(GetParamTransformersProtocol):
    lr: float = 1e-3
    """the learning rate to use in the gradient-based optimizer"""
    lr_scheduler_factory: LRSchedulerFactory | None = None
    """factory for the creation of a learning rate scheduler (if any)"""

    def _get_param_transformers(self) -> list[ParamTransformer]:
        return [
            ParamTransformerDrop("lr"),
            ParamTransformerLRScheduler("lr_scheduler_factory", "lr_scheduler"),
        ]


@dataclass
class ParamsMixinActorAndCritic(GetParamTransformersProtocol):
    actor_lr: float = 1e-3
    critic_lr: float = 1e-3
    actor_lr_scheduler_factory: LRSchedulerFactory | None = None
    critic_lr_scheduler_factory: LRSchedulerFactory | None = None

    def _get_param_transformers(self) -> list[ParamTransformer]:
        return [
            ParamTransformerDrop("actor_lr", "critic_lr"),
            ParamTransformerActorAndCriticLRScheduler(
                "actor_lr_scheduler_factory",
                "critic_lr_scheduler_factory",
                "lr_scheduler",
            ),
        ]


@dataclass
class PGParams(Params):
    """Config of general policy-gradient algorithms."""

    discount_factor: float = 0.99
    reward_normalization: bool = False
    deterministic_eval: bool = False
    action_scaling: bool | Literal["default"] = "default"
    """whether to apply action scaling; when set to "default", it will be enabled for continuous action spaces"""
    action_bound_method: Literal["clip", "tanh"] | None = "clip"

    def _get_param_transformers(self) -> list[ParamTransformer]:
        transformers = super()._get_param_transformers()
        transformers.append(ParamTransformerActionScaling("action_scaling"))
        return transformers


@dataclass
class A2CParams(PGParams, ParamsMixinLearningRateWithScheduler):
    """Config of Advantage Actor-Critic (A2C)."""

    vf_coef: float = 0.5
    ent_coef: float = 0.01
    max_grad_norm: float | None = None
    gae_lambda: float = 0.95
    max_batchsize: int = 256
    dist_fn: TDistributionFunction | DistributionFunctionFactory | Literal["default"] = "default"

    def _get_param_transformers(self) -> list[ParamTransformer]:
        transformers = super()._get_param_transformers()
        transformers.extend(ParamsMixinLearningRateWithScheduler._get_param_transformers(self))
        transformers.append(ParamTransformerDistributionFunction("dist_fn"))
        return transformers
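
# Illustrative sketch, assuming a `ParamTransformerData` instance `data` with `data.optim` set: for
# A2CParams, `create_kwargs` drops "lr", replaces "lr_scheduler_factory" with an "lr_scheduler" entry
# built from `data.optim`, resolves action_scaling == "default" against the action space, and resolves
# dist_fn == "default" to a concrete distribution function for the given environments.
#
#     params = A2CParams(lr=3e-4, gae_lambda=0.9)
#     kwargs = params.create_kwargs(data)
#     # kwargs contains "lr_scheduler" and a callable "dist_fn", but no "lr" key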


@dataclass
class PPOParams(A2CParams):
    """PPO-specific config."""

    eps_clip: float = 0.2
    dual_clip: float | None = None
    value_clip: bool = False
    advantage_normalization: bool = True
    recompute_advantage: bool = False


@dataclass
class ParamsMixinActorAndDualCritics(GetParamTransformersProtocol):
    actor_lr: float = 1e-3
    critic1_lr: float = 1e-3
    critic2_lr: float = 1e-3
    actor_lr_scheduler_factory: LRSchedulerFactory | None = None
    critic1_lr_scheduler_factory: LRSchedulerFactory | None = None
    critic2_lr_scheduler_factory: LRSchedulerFactory | None = None

    def _get_param_transformers(self) -> list[ParamTransformer]:
        return [
            ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"),
            ParamTransformerActorDualCriticsLRScheduler(
                "actor_lr_scheduler_factory",
                "critic1_lr_scheduler_factory",
                "critic2_lr_scheduler_factory",
                "lr_scheduler",
            ),
        ]


@dataclass
class SACParams(Params, ParamsMixinActorAndDualCritics):
    """Config of Soft Actor-Critic (SAC)."""

    tau: float = 0.005
    gamma: float = 0.99
    alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] | AutoAlphaFactory = 0.2
    estimation_step: int = 1
    exploration_noise: BaseNoise | Literal["default"] | NoiseFactory | None = None
    deterministic_eval: bool = True
    action_scaling: bool = True
    action_bound_method: Literal["clip"] | None = "clip"

    def _get_param_transformers(self) -> list[ParamTransformer]:
        transformers = super()._get_param_transformers()
        transformers.extend(ParamsMixinActorAndDualCritics._get_param_transformers(self))
        transformers.append(ParamTransformerAutoAlpha("alpha"))
        transformers.append(ParamTransformerNoiseFactory("exploration_noise"))
        return transformers
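
# Illustrative sketch, assuming `auto_alpha_factory` is a concrete `AutoAlphaFactory` implementation and
# `data` is a `ParamTransformerData` with actor and both critics set: `alpha` may be given as a plain
# float (kept as is) or as a factory, which ParamTransformerAutoAlpha replaces with the auto-tuned
# representation created via `create_auto_alpha(envs, optim_factory, device)`.
#
#     params = SACParams(alpha=auto_alpha_factory, actor_lr=3e-4)
#     kwargs = params.create_kwargs(data)  # kwargs["alpha"] is no longer the factory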


@dataclass
class DDPGParams(Params, ParamsMixinActorAndCritic):
    """Config of Deep Deterministic Policy Gradient (DDPG)."""

    tau: float = 0.005
    gamma: float = 0.99
    exploration_noise: BaseNoise | Literal["default"] | NoiseFactory | None = "default"
    estimation_step: int = 1
    action_scaling: bool = True
    action_bound_method: Literal["clip"] | None = "clip"

    def _get_param_transformers(self) -> list[ParamTransformer]:
        transformers = super()._get_param_transformers()
        transformers.extend(ParamsMixinActorAndCritic._get_param_transformers(self))
        transformers.append(ParamTransformerNoiseFactory("exploration_noise"))
        return transformers


@dataclass
class TD3Params(Params, ParamsMixinActorAndDualCritics):
    """Config of Twin Delayed DDPG (TD3)."""

    tau: float = 0.005
    gamma: float = 0.99
    exploration_noise: BaseNoise | Literal["default"] | NoiseFactory | None = "default"
    policy_noise: float | FloatEnvValueFactory = 0.2
    noise_clip: float | FloatEnvValueFactory = 0.5
    update_actor_freq: int = 2
    estimation_step: int = 1
    action_scaling: bool = True
    action_bound_method: Literal["clip"] | None = "clip"

    def _get_param_transformers(self) -> list[ParamTransformer]:
        transformers = super()._get_param_transformers()
        transformers.extend(ParamsMixinActorAndDualCritics._get_param_transformers(self))
        transformers.append(ParamTransformerNoiseFactory("exploration_noise"))
        transformers.append(ParamTransformerFloatEnvParamFactory("policy_noise"))
        transformers.append(ParamTransformerFloatEnvParamFactory("noise_clip"))
        return transformers
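
# Illustrative sketch, assuming `policy_noise_factory` is a concrete `FloatEnvValueFactory` (e.g. one
# that derives the value from the environments' maximum action value) and `data` is a suitable
# `ParamTransformerData`: policy_noise and noise_clip may be plain floats or env-dependent factories,
# which ParamTransformerFloatEnvParamFactory resolves to concrete values for the given environments.
#
#     params = TD3Params(policy_noise=policy_noise_factory, noise_clip=0.5)
#     kwargs = params.create_kwargs(data)  # kwargs["policy_noise"] is now a concrete value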