Tianshou/tianshou/highlevel/params/policy_params.py
Dominik Jain 17ef4dd5eb Support REDQ in high-level API
* Implement example mujoco_redq_hl
* Add abstraction CriticEnsembleFactory with default implementations
  to suit REDQ
* Fix type annotation of linear_layer in Net, MLP, Critic
  (was incompatible with REDQ usage)
2023-10-18 20:44:17 +02:00

437 lines
16 KiB
Python

from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass, fields
from typing import Any, Literal, Protocol

import torch
from torch.optim.lr_scheduler import LRScheduler

from tianshou.exploration import BaseNoise
from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.core import TDevice
from tianshou.highlevel.module.module_opt import ModuleOpt
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.highlevel.params.alpha import AutoAlphaFactory
from tianshou.highlevel.params.dist_fn import (
    DistributionFunctionFactory,
    DistributionFunctionFactoryDefault,
)
from tianshou.highlevel.params.env_param import EnvValueFactory, FloatEnvValueFactory
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
from tianshou.highlevel.params.noise import NoiseFactory
from tianshou.policy.modelfree.pg import TDistributionFunction
from tianshou.utils import MultipleLRSchedulers
@dataclass(kw_only=True)
class ParamTransformerData:
    """Holds data that can be used by `ParamTransformer` instances to perform their transformation.
    The representation contains the superset of all data items that are required by different types of agent factories.
    An agent factory is expected to set only the attributes that are relevant to its parameters.
    """

    envs: Environments
    device: TDevice
    # factory used e.g. to create the optimizer for an auto-tuned alpha (see ParamTransformerAutoAlpha)
    optim_factory: OptimizerFactory
    optim: torch.optim.Optimizer | None = None
    """the single optimizer for the case where there is just one"""
    # actor/critic modules with their optimizers; critic2 is only set by
    # agent factories for algorithms with dual critics (e.g. SAC/TD3)
    actor: ModuleOpt | None = None
    critic1: ModuleOpt | None = None
    critic2: ModuleOpt | None = None
class ParamTransformer(ABC):
    """Base class for transformations of policy parameters.

    Implementations map one or more parameters from the representation used by
    the high-level API to the representation required by the (low-level) policy
    implementation. They mutate a dictionary of keyword arguments in place,
    which is initially generated from the parameter dataclass (a subclass of
    `Params`).
    """

    @abstractmethod
    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
        pass

    @staticmethod
    def get(d: dict[str, Any], key: str, drop: bool = False) -> Any:
        """Retrieve `d[key]`, removing the key from `d` when `drop` is True."""
        return d.pop(key) if drop else d[key]
class ParamTransformerDrop(ParamTransformer):
    """Removes the given keys from the parameter dictionary without replacement."""

    def __init__(self, *keys: str):
        self.keys = keys

    def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None:
        for key in self.keys:
            # pop without default raises KeyError for a missing key, like `del`
            kwargs.pop(key)
class ParamTransformerChangeValue(ParamTransformer):
    """Base class for transformers that replace a single parameter's value in place."""

    def __init__(self, key: str):
        self.key = key

    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
        old_value = params[self.key]
        params[self.key] = self.change_value(old_value, data)

    @abstractmethod
    def change_value(self, value: Any, data: ParamTransformerData) -> Any:
        pass
class ParamTransformerLRScheduler(ParamTransformer):
    """Replaces a learning rate scheduler factory (key removed) with the scheduler
    it creates (key added) for the single optimizer held in `data.optim`.
    """

    def __init__(self, key_scheduler_factory: str, key_scheduler: str):
        self.key_scheduler_factory = key_scheduler_factory
        self.key_scheduler = key_scheduler

    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
        assert data.optim is not None
        factory: LRSchedulerFactory | None = self.get(params, self.key_scheduler_factory, drop=True)
        if factory is None:
            scheduler = None
        else:
            scheduler = factory.create_scheduler(data.optim)
        params[self.key_scheduler] = scheduler
class ParamTransformerMultiLRScheduler(ParamTransformer):
    """Collapses several scheduler factory keys (all removed) into a single scheduler
    entry, which is a `MultipleLRSchedulers` instance when more than one factory
    is actually provided.
    """

    def __init__(self, optim_key_list: list[tuple[torch.optim.Optimizer, str]], key_scheduler: str):
        """:param optim_key_list: a list of tuples (optimizer, key of learning rate factory)
        :param key_scheduler: the key under which to store the resulting learning rate scheduler
        """
        self.optim_key_list = optim_key_list
        self.key_scheduler = key_scheduler

    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
        schedulers = []
        for optimizer, factory_key in self.optim_key_list:
            factory: LRSchedulerFactory | None = self.get(params, factory_key, drop=True)
            if factory is not None:
                schedulers.append(factory.create_scheduler(optimizer))
        # Collapse 0/1/many schedulers into None / the scheduler / a combined scheduler.
        result: LRScheduler | MultipleLRSchedulers | None
        if not schedulers:
            result = None
        elif len(schedulers) == 1:
            result = schedulers[0]
        else:
            result = MultipleLRSchedulers(*schedulers)
        params[self.key_scheduler] = result
class ParamTransformerActorAndCriticLRScheduler(ParamTransformer):
    """Combines the actor's and the critic's scheduler factory keys into a single
    scheduler entry (see `ParamTransformerMultiLRScheduler`).
    """

    def __init__(
        self,
        key_scheduler_factory_actor: str,
        key_scheduler_factory_critic: str,
        key_scheduler: str,
    ):
        self.key_factory_actor = key_scheduler_factory_actor
        self.key_factory_critic = key_scheduler_factory_critic
        self.key_scheduler = key_scheduler

    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
        assert data.actor is not None and data.critic1 is not None
        # Delegate to the generic multi-scheduler transformer.
        optim_keys = [
            (data.actor.optim, self.key_factory_actor),
            (data.critic1.optim, self.key_factory_critic),
        ]
        ParamTransformerMultiLRScheduler(optim_keys, self.key_scheduler).transform(params, data)
class ParamTransformerActorDualCriticsLRScheduler(ParamTransformer):
    """Combines the actor's and both critics' scheduler factory keys into a single
    scheduler entry (see `ParamTransformerMultiLRScheduler`).
    """

    def __init__(
        self,
        key_scheduler_factory_actor: str,
        key_scheduler_factory_critic1: str,
        key_scheduler_factory_critic2: str,
        key_scheduler: str,
    ):
        self.key_factory_actor = key_scheduler_factory_actor
        self.key_factory_critic1 = key_scheduler_factory_critic1
        self.key_factory_critic2 = key_scheduler_factory_critic2
        self.key_scheduler = key_scheduler

    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
        assert data.actor is not None and data.critic1 is not None and data.critic2 is not None
        # Delegate to the generic multi-scheduler transformer.
        optim_keys = [
            (data.actor.optim, self.key_factory_actor),
            (data.critic1.optim, self.key_factory_critic1),
            (data.critic2.optim, self.key_factory_critic2),
        ]
        ParamTransformerMultiLRScheduler(optim_keys, self.key_scheduler).transform(params, data)
class ParamTransformerAutoAlpha(ParamTransformer):
    """Resolves an `AutoAlphaFactory` stored under the given key into the actual
    auto-tuned alpha representation; other values are passed through unchanged.
    """

    def __init__(self, key: str):
        self.key = key

    def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None:
        value = self.get(kwargs, self.key)
        if isinstance(value, AutoAlphaFactory):
            kwargs[self.key] = value.create_auto_alpha(data.envs, data.optim_factory, data.device)
class ParamTransformerNoiseFactory(ParamTransformer):
    """Resolves a `NoiseFactory` stored under the given key into the noise object
    it creates for the given environments; other values pass through unchanged.
    """

    def __init__(self, key: str):
        self.key = key

    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
        candidate = params[self.key]
        if isinstance(candidate, NoiseFactory):
            params[self.key] = candidate.create_noise(data.envs)
class ParamTransformerFloatEnvParamFactory(ParamTransformer):
    """Resolves an `EnvValueFactory` stored under the given key into the concrete
    value it creates for the given environments; other values pass through.
    """

    def __init__(self, key: str):
        self.key = key

    def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None:
        candidate = kwargs[self.key]
        if isinstance(candidate, EnvValueFactory):
            kwargs[self.key] = candidate.create_value(data.envs)
class ParamTransformerDistributionFunction(ParamTransformer):
    """Resolves the distribution function key: the literal "default" is replaced
    by the environment-appropriate default, and a factory is invoked to create
    the actual function; any other value passes through unchanged.
    """

    def __init__(self, key: str):
        self.key = key

    def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None:
        value = kwargs[self.key]
        if value == "default":
            kwargs[self.key] = DistributionFunctionFactoryDefault().create_dist_fn(data.envs)
        elif isinstance(value, DistributionFunctionFactory):
            kwargs[self.key] = value.create_dist_fn(data.envs)
class ParamTransformerActionScaling(ParamTransformerChangeValue):
    """Resolves the "default" literal for action scaling: enabled exactly when the
    environments' action space is continuous; explicit booleans pass through.
    """

    def change_value(self, value: Any, data: ParamTransformerData) -> Any:
        return data.envs.get_type().is_continuous() if value == "default" else value
class GetParamTransformersProtocol(Protocol):
    """Structural type for objects that provide a list of parameter transformers."""

    def _get_param_transformers(self) -> list[ParamTransformer]:
        ...
@dataclass
class Params(GetParamTransformersProtocol):
    """Base class for all parameter dataclasses of the high-level API."""

    def create_kwargs(self, data: ParamTransformerData) -> dict[str, Any]:
        """Create the keyword arguments for the policy constructor from this dataclass.

        :param data: the data that parameter transformers may use to resolve
            factories and other high-level representations
        :return: the dictionary of keyword arguments (after all transformers ran)
        """
        # NOTE: a shallow per-field copy is used instead of `asdict`, because
        # `asdict` recurses into dataclass-typed values and deep-copies everything
        # else; a factory object held in a field (e.g. an AutoAlphaFactory that is
        # itself a dataclass) would be converted to a plain dict, defeating the
        # isinstance checks in the parameter transformers.
        params = {field.name: getattr(self, field.name) for field in fields(self)}
        for transformer in self._get_param_transformers():
            transformer.transform(params, data)
        return params

    def _get_param_transformers(self) -> list[ParamTransformer]:
        # Subclasses override this to register their transformers.
        return []
@dataclass
class ParamsMixinLearningRateWithScheduler(GetParamTransformersProtocol):
    # learning rate for the single optimizer
    lr: float = 1e-3
    # optional factory for a learning rate scheduler attached to that optimizer
    lr_scheduler_factory: LRSchedulerFactory | None = None

    def _get_param_transformers(self) -> list[ParamTransformer]:
        transformers: list[ParamTransformer] = [ParamTransformerDrop("lr")]
        transformers.append(ParamTransformerLRScheduler("lr_scheduler_factory", "lr_scheduler"))
        return transformers
@dataclass
class ParamsMixinActorAndCritic(GetParamTransformersProtocol):
    # learning rates for the actor and critic optimizers
    actor_lr: float = 1e-3
    critic_lr: float = 1e-3
    # optional scheduler factories for the respective optimizers
    actor_lr_scheduler_factory: LRSchedulerFactory | None = None
    critic_lr_scheduler_factory: LRSchedulerFactory | None = None

    def _get_param_transformers(self) -> list[ParamTransformer]:
        scheduler_transformer = ParamTransformerActorAndCriticLRScheduler(
            "actor_lr_scheduler_factory",
            "critic_lr_scheduler_factory",
            "lr_scheduler",
        )
        return [ParamTransformerDrop("actor_lr", "critic_lr"), scheduler_transformer]
@dataclass
class PGParams(Params, ParamsMixinLearningRateWithScheduler):
    discount_factor: float = 0.99
    reward_normalization: bool = False
    deterministic_eval: bool = False
    action_scaling: bool | Literal["default"] = "default"
    """whether to apply action scaling; when set to "default", it will be enabled for continuous action spaces"""
    action_bound_method: Literal["clip", "tanh"] | None = "clip"
    dist_fn: TDistributionFunction | DistributionFunctionFactory | Literal["default"] = "default"

    def _get_param_transformers(self) -> list[ParamTransformer]:
        transformers = super()._get_param_transformers()
        # Params._get_param_transformers does not delegate via super(), so the
        # mixin's transformers must be collected explicitly.
        transformers += ParamsMixinLearningRateWithScheduler._get_param_transformers(self)
        transformers += [
            ParamTransformerActionScaling("action_scaling"),
            ParamTransformerDistributionFunction("dist_fn"),
        ]
        return transformers
@dataclass
class A2CParams(PGParams):
    """Parameters for advantage actor-critic (A2C); field names mirror the kwargs
    of the corresponding tianshou policy — confirm semantics against A2CPolicy.
    """

    # weight of the value-function loss term — per A2CPolicy
    vf_coef: float = 0.5
    # weight of the entropy regularization term — per A2CPolicy
    ent_coef: float = 0.01
    # gradient norm clipping threshold; None presumably disables clipping — confirm
    max_grad_norm: float | None = None
    # lambda parameter for generalized advantage estimation (GAE)
    gae_lambda: float = 0.95
    # maximum batch size; presumably used when computing GAE — confirm against A2CPolicy
    max_batchsize: int = 256
@dataclass
class PPOParams(A2CParams):
    """Parameters for proximal policy optimization (PPO); extends the A2C parameter
    set. Field names mirror the kwargs of PPOPolicy — confirm semantics there.
    """

    # epsilon of the clipped surrogate objective
    eps_clip: float = 0.2
    # if set, enables the dual-clip objective; None disables it — confirm value constraints in PPOPolicy
    dual_clip: float | None = None
    # whether to also clip the value-function loss — per PPOPolicy
    value_clip: bool = False
    # whether to normalize advantages before the update
    advantage_normalization: bool = True
    # whether to recompute advantages during optimization — per PPOPolicy
    recompute_advantage: bool = False
@dataclass
class NPGParams(PGParams):
    """Parameters for natural policy gradient (NPG); extends the policy-gradient
    parameter set. Field names mirror the kwargs of NPGPolicy — confirm there.
    """

    # number of critic optimization iterations per update — per NPGPolicy
    optim_critic_iters: int = 5
    # step size of the natural-gradient actor update — per NPGPolicy
    actor_step_size: float = 0.5
    # whether to normalize advantages before the update
    advantage_normalization: bool = True
    # lambda parameter for generalized advantage estimation (GAE)
    gae_lambda: float = 0.95
    # maximum batch size; presumably used when computing GAE — confirm against NPGPolicy
    max_batchsize: int = 256
@dataclass
class TRPOParams(NPGParams):
    """Parameters for trust region policy optimization (TRPO); extends the NPG
    parameter set. Field names mirror the kwargs of TRPOPolicy — confirm there.
    """

    # maximum KL divergence allowed per update step — per TRPOPolicy
    max_kl: float = 0.01
    # shrink factor applied during the backtracking line search — per TRPOPolicy
    backtrack_coeff: float = 0.8
    # maximum number of backtracking steps in the line search
    max_backtracks: int = 10
@dataclass
class ParamsMixinActorAndDualCritics(GetParamTransformersProtocol):
    # learning rates for the actor and the two critic optimizers
    actor_lr: float = 1e-3
    critic1_lr: float = 1e-3
    critic2_lr: float = 1e-3
    # optional scheduler factories for the respective optimizers
    actor_lr_scheduler_factory: LRSchedulerFactory | None = None
    critic1_lr_scheduler_factory: LRSchedulerFactory | None = None
    critic2_lr_scheduler_factory: LRSchedulerFactory | None = None

    def _get_param_transformers(self) -> list[ParamTransformer]:
        scheduler_transformer = ParamTransformerActorDualCriticsLRScheduler(
            "actor_lr_scheduler_factory",
            "critic1_lr_scheduler_factory",
            "critic2_lr_scheduler_factory",
            "lr_scheduler",
        )
        return [
            ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"),
            scheduler_transformer,
        ]
@dataclass
class SACParams(Params, ParamsMixinActorAndDualCritics):
    """Parameters for soft actor-critic (SAC)."""

    tau: float = 0.005
    gamma: float = 0.99
    # fixed entropy coefficient, an explicit (target, log_alpha, optimizer) tuple,
    # or a factory that creates the auto-tuned representation
    alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] | AutoAlphaFactory = 0.2
    estimation_step: int = 1
    exploration_noise: BaseNoise | Literal["default"] | NoiseFactory | None = None
    deterministic_eval: bool = True
    action_scaling: bool = True
    action_bound_method: Literal["clip"] | None = "clip"

    def _get_param_transformers(self) -> list[ParamTransformer]:
        transformers = super()._get_param_transformers()
        # Params._get_param_transformers does not delegate via super(), so the
        # mixin's transformers must be collected explicitly.
        transformers += ParamsMixinActorAndDualCritics._get_param_transformers(self)
        transformers += [
            ParamTransformerAutoAlpha("alpha"),
            ParamTransformerNoiseFactory("exploration_noise"),
        ]
        return transformers
@dataclass
class DQNParams(Params, ParamsMixinLearningRateWithScheduler):
    """Parameters for deep Q-learning (DQN)."""

    discount_factor: float = 0.99
    estimation_step: int = 1
    target_update_freq: int = 0
    reward_normalization: bool = False
    is_double: bool = True
    clip_loss_grad: bool = False

    def _get_param_transformers(self) -> list[ParamTransformer]:
        # Params._get_param_transformers does not delegate via super(), so the
        # mixin's transformers must be collected explicitly.
        return (
            super()._get_param_transformers()
            + ParamsMixinLearningRateWithScheduler._get_param_transformers(self)
        )
@dataclass
class DDPGParams(Params, ParamsMixinActorAndCritic):
    """Parameters for deep deterministic policy gradient (DDPG)."""

    tau: float = 0.005
    gamma: float = 0.99
    exploration_noise: BaseNoise | Literal["default"] | NoiseFactory | None = "default"
    estimation_step: int = 1
    action_scaling: bool = True
    action_bound_method: Literal["clip"] | None = "clip"

    def _get_param_transformers(self) -> list[ParamTransformer]:
        transformers = super()._get_param_transformers()
        # Params._get_param_transformers does not delegate via super(), so the
        # mixin's transformers must be collected explicitly.
        transformers += ParamsMixinActorAndCritic._get_param_transformers(self)
        transformers.append(ParamTransformerNoiseFactory("exploration_noise"))
        return transformers
@dataclass
class REDQParams(DDPGParams):
    """Parameters for randomized ensembled double Q-learning (REDQ).

    Inherits the DDPG parameter set (including `estimation_step`, default 1; the
    previous redefinition of that field with an identical default was redundant
    and has been removed).
    """

    # number of Q-networks in the critic ensemble
    ensemble_size: int = 10
    # size of the random subset of the ensemble used for the target — confirm against REDQPolicy
    subset_size: int = 2
    # fixed entropy coefficient, an explicit (target, log_alpha, optimizer) tuple,
    # or a factory that creates the auto-tuned representation
    alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] | AutoAlphaFactory = 0.2
    # presumably the number of critic updates per actor update — confirm against REDQPolicy
    actor_delay: int = 20
    deterministic_eval: bool = True
    # how the ensemble subset is aggregated when forming the target value
    target_mode: Literal["mean", "min"] = "min"

    def _get_param_transformers(self) -> list[ParamTransformer]:
        transformers = super()._get_param_transformers()
        transformers.append(ParamTransformerAutoAlpha("alpha"))
        return transformers
@dataclass
class TD3Params(Params, ParamsMixinActorAndDualCritics):
    """Parameters for twin delayed DDPG (TD3)."""

    tau: float = 0.005
    gamma: float = 0.99
    exploration_noise: BaseNoise | Literal["default"] | NoiseFactory | None = "default"
    policy_noise: float | FloatEnvValueFactory = 0.2
    noise_clip: float | FloatEnvValueFactory = 0.5
    update_actor_freq: int = 2
    estimation_step: int = 1
    action_scaling: bool = True
    action_bound_method: Literal["clip"] | None = "clip"

    def _get_param_transformers(self) -> list[ParamTransformer]:
        transformers = super()._get_param_transformers()
        # Params._get_param_transformers does not delegate via super(), so the
        # mixin's transformers must be collected explicitly.
        transformers += ParamsMixinActorAndDualCritics._get_param_transformers(self)
        transformers += [
            ParamTransformerNoiseFactory("exploration_noise"),
            ParamTransformerFloatEnvParamFactory("policy_noise"),
            ParamTransformerFloatEnvParamFactory("noise_clip"),
        ]
        return transformers