2023-09-25 17:56:37 +02:00
|
|
|
from abc import ABC, abstractmethod
|
2023-10-11 15:31:38 +02:00
|
|
|
from collections.abc import Sequence
|
2023-09-27 18:20:49 +02:00
|
|
|
from dataclasses import asdict, dataclass
|
|
|
|
from typing import Any, Literal, Protocol
|
2023-09-25 17:56:37 +02:00
|
|
|
|
|
|
|
import torch
|
2023-10-09 17:22:52 +02:00
|
|
|
from torch.optim.lr_scheduler import LRScheduler
|
2023-09-25 17:56:37 +02:00
|
|
|
|
|
|
|
from tianshou.exploration import BaseNoise
|
2023-09-26 17:43:16 +02:00
|
|
|
from tianshou.highlevel.env import Environments
|
2023-09-28 20:07:52 +02:00
|
|
|
from tianshou.highlevel.module.core import TDevice
|
|
|
|
from tianshou.highlevel.module.module_opt import ModuleOpt
|
2023-09-26 17:43:16 +02:00
|
|
|
from tianshou.highlevel.optim import OptimizerFactory
|
2023-09-25 17:56:37 +02:00
|
|
|
from tianshou.highlevel.params.alpha import AutoAlphaFactory
|
2023-09-28 14:28:03 +02:00
|
|
|
from tianshou.highlevel.params.dist_fn import (
|
|
|
|
DistributionFunctionFactory,
|
|
|
|
DistributionFunctionFactoryDefault,
|
|
|
|
)
|
2023-09-27 17:20:35 +02:00
|
|
|
from tianshou.highlevel.params.env_param import EnvValueFactory, FloatEnvValueFactory
|
2023-09-25 17:56:37 +02:00
|
|
|
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
|
2023-09-26 15:35:18 +02:00
|
|
|
from tianshou.highlevel.params.noise import NoiseFactory
|
2023-10-10 12:55:25 +02:00
|
|
|
from tianshou.policy.modelfree.pg import TDistributionFunction
|
2023-09-26 17:43:16 +02:00
|
|
|
from tianshou.utils import MultipleLRSchedulers
|
2023-11-07 10:54:22 +01:00
|
|
|
from tianshou.utils.string import ToStringMixin
|
2023-09-26 17:43:16 +02:00
|
|
|
|
|
|
|
|
|
|
|
@dataclass(kw_only=True)
class ParamTransformerData:
    """Bundles the data which `ParamTransformer` instances may draw upon when transforming parameters.

    This class holds the union of all data items required by the different kinds of agent factories.
    Any given agent factory is expected to populate only those attributes that its parameters need.
    """

    envs: Environments
    device: TDevice
    optim_factory: OptimizerFactory
    optim: torch.optim.Optimizer | None = None
    """the single optimizer for the case where there is just one"""
    actor: ModuleOpt | None = None
    """the actor module together with its optimizer (if applicable)"""
    critic1: ModuleOpt | None = None
    """the (first) critic module together with its optimizer (if applicable)"""
    critic2: ModuleOpt | None = None
    """the second critic module together with its optimizer (if applicable)"""
|
2023-09-25 17:56:37 +02:00
|
|
|
|
|
|
|
|
|
|
|
class ParamTransformer(ABC):
    """Converts one or more parameters from their high-level API representation into the form
    expected by the (low-level) policy implementation.

    Transformers act in place on a dictionary of keyword arguments, which is initially derived
    from the parameter dataclass (a subclass of `Params`).
    """

    @abstractmethod
    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
        """Modify `params` in place.

        :param params: the keyword arguments to transform
        :param data: supplementary data that the transformation may use
        """

    @staticmethod
    def get(d: dict[str, Any], key: str, drop: bool = False) -> Any:
        """Look up `key` in `d`, optionally removing the entry.

        :param d: the dictionary to read from
        :param key: the key to look up; a missing key raises KeyError
        :param drop: whether to remove the entry from `d` after reading it
        :return: the value stored under `key`
        """
        return d.pop(key) if drop else d[key]
|
|
|
|
|
|
|
|
|
2023-09-26 17:43:16 +02:00
|
|
|
class ParamTransformerDrop(ParamTransformer):
    """Removes the given keys from the keyword arguments; a key that is absent raises KeyError."""

    def __init__(self, *keys: str):
        """:param keys: the keys to remove"""
        self.keys = keys

    def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None:
        for key_to_remove in self.keys:
            kwargs.pop(key_to_remove)
|
|
|
|
|
|
|
|
|
2023-09-28 20:07:52 +02:00
|
|
|
class ParamTransformerChangeValue(ParamTransformer):
    """Base class for transformers that replace the value stored under a single key with a derived value."""

    def __init__(self, key: str):
        """:param key: the key whose value is to be replaced"""
        self.key = key

    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
        current_value = params[self.key]
        params[self.key] = self.change_value(current_value, data)

    @abstractmethod
    def change_value(self, value: Any, data: ParamTransformerData) -> Any:
        """Compute the replacement value.

        :param value: the value currently stored under the key
        :param data: supplementary data that the computation may use
        :return: the value to store under the key instead
        """
|
|
|
|
|
|
|
|
|
2023-09-26 17:43:16 +02:00
|
|
|
class ParamTransformerLRScheduler(ParamTransformer):
    """Transforms a key containing a learning rate scheduler factory (removed) into a key containing
    a learning rate scheduler (added) for the data member `optim`.
    """

    def __init__(self, key_scheduler_factory: str, key_scheduler: str):
        """:param key_scheduler_factory: the key holding the (optional) scheduler factory; it is removed
        :param key_scheduler: the key under which the created scheduler (or None) is stored
        """
        self.key_scheduler_factory = key_scheduler_factory
        self.key_scheduler = key_scheduler

    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
        factory: LRSchedulerFactory | None = self.get(params, self.key_scheduler_factory, drop=True)
        if factory is None:
            # No scheduler was requested, so no optimizer is needed either; do not demand one.
            params[self.key_scheduler] = None
            return
        assert data.optim is not None, "an optimizer is required to create a learning rate scheduler"
        params[self.key_scheduler] = factory.create_scheduler(data.optim)
|
|
|
|
|
|
|
|
|
|
|
|
class ParamTransformerMultiLRScheduler(ParamTransformer):
    """Turns several learning rate scheduler factories (which are removed) into a single scheduler (which is added).

    When more than one factory is actually given, the resulting scheduler is a `MultipleLRSchedulers` instance;
    when none is given, None is stored.
    """

    def __init__(self, optim_key_list: list[tuple[torch.optim.Optimizer, str]], key_scheduler: str):
        """:param optim_key_list: pairs (optimizer, key of the learning rate scheduler factory for that optimizer)
        :param key_scheduler: the key under which the resulting learning rate scheduler is stored
        """
        self.optim_key_list = optim_key_list
        self.key_scheduler = key_scheduler

    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
        schedulers = []
        for optimizer, factory_key in self.optim_key_list:
            # every factory key is removed, regardless of whether a factory is set
            factory: LRSchedulerFactory | None = self.get(params, factory_key, drop=True)
            if factory is None:
                continue
            schedulers.append(factory.create_scheduler(optimizer))

        combined: LRScheduler | MultipleLRSchedulers | None
        if not schedulers:
            combined = None
        elif len(schedulers) == 1:
            combined = schedulers[0]
        else:
            combined = MultipleLRSchedulers(*schedulers)
        params[self.key_scheduler] = combined
|
|
|
|
|
|
|
|
|
2023-10-03 20:26:39 +02:00
|
|
|
class ParamTransformerActorAndCriticLRScheduler(ParamTransformer):
    """Replaces the scheduler factories of the actor and the critic with one (possibly combined) scheduler."""

    def __init__(
        self,
        key_scheduler_factory_actor: str,
        key_scheduler_factory_critic: str,
        key_scheduler: str,
    ):
        """:param key_scheduler_factory_actor: the key of the actor's scheduler factory
        :param key_scheduler_factory_critic: the key of the critic's scheduler factory
        :param key_scheduler: the key under which the resulting scheduler is stored
        """
        self.key_factory_actor = key_scheduler_factory_actor
        self.key_factory_critic = key_scheduler_factory_critic
        self.key_scheduler = key_scheduler

    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
        assert data.actor is not None and data.critic1 is not None
        optimizer_factory_pairs = [
            (data.actor.optim, self.key_factory_actor),
            (data.critic1.optim, self.key_factory_critic),
        ]
        delegate = ParamTransformerMultiLRScheduler(optimizer_factory_pairs, self.key_scheduler)
        delegate.transform(params, data)
|
|
|
|
|
|
|
|
|
2023-09-26 17:43:16 +02:00
|
|
|
class ParamTransformerActorDualCriticsLRScheduler(ParamTransformer):
    """Replaces the scheduler factories of the actor and both critics with one (possibly combined) scheduler."""

    def __init__(
        self,
        key_scheduler_factory_actor: str,
        key_scheduler_factory_critic1: str,
        key_scheduler_factory_critic2: str,
        key_scheduler: str,
    ):
        """:param key_scheduler_factory_actor: the key of the actor's scheduler factory
        :param key_scheduler_factory_critic1: the key of the first critic's scheduler factory
        :param key_scheduler_factory_critic2: the key of the second critic's scheduler factory
        :param key_scheduler: the key under which the resulting scheduler is stored
        """
        self.key_factory_actor = key_scheduler_factory_actor
        self.key_factory_critic1 = key_scheduler_factory_critic1
        self.key_factory_critic2 = key_scheduler_factory_critic2
        self.key_scheduler = key_scheduler

    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
        assert data.actor is not None and data.critic1 is not None and data.critic2 is not None
        optimizer_factory_pairs = [
            (data.actor.optim, self.key_factory_actor),
            (data.critic1.optim, self.key_factory_critic1),
            (data.critic2.optim, self.key_factory_critic2),
        ]
        delegate = ParamTransformerMultiLRScheduler(optimizer_factory_pairs, self.key_scheduler)
        delegate.transform(params, data)
|
|
|
|
|
|
|
|
|
|
|
|
class ParamTransformerAutoAlpha(ParamTransformer):
    """Replaces an `AutoAlphaFactory` stored under the given key with the auto-alpha representation it creates.

    Any other value (e.g. a constant float) is left untouched.
    """

    def __init__(self, key: str):
        """:param key: the key holding the alpha value or factory"""
        self.key = key

    def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None:
        alpha_value = kwargs[self.key]
        if not isinstance(alpha_value, AutoAlphaFactory):
            return
        kwargs[self.key] = alpha_value.create_auto_alpha(data.envs, data.optim_factory, data.device)
|
|
|
|
|
|
|
|
|
2023-10-10 16:12:29 +02:00
|
|
|
class ParamTransformerNoiseFactory(ParamTransformerChangeValue):
    """Replaces a `NoiseFactory` with the noise it creates for the environments; other values pass through."""

    def change_value(self, value: Any, data: ParamTransformerData) -> Any:
        return value.create_noise(data.envs) if isinstance(value, NoiseFactory) else value
|
2023-09-26 17:43:16 +02:00
|
|
|
|
|
|
|
|
2023-10-10 16:12:29 +02:00
|
|
|
class ParamTransformerFloatEnvParamFactory(ParamTransformerChangeValue):
    """Replaces an `EnvValueFactory` with the value it creates for the environments; other values pass through."""

    def change_value(self, value: Any, data: ParamTransformerData) -> Any:
        return value.create_value(data.envs) if isinstance(value, EnvValueFactory) else value
|
2023-09-26 17:43:16 +02:00
|
|
|
|
|
|
|
|
2023-10-10 16:12:29 +02:00
|
|
|
class ParamTransformerDistributionFunction(ParamTransformerChangeValue):
    """Resolves a distribution function specification into an actual distribution function.

    "default" is resolved via `DistributionFunctionFactoryDefault`, and a `DistributionFunctionFactory`
    is asked to create the function. Any other value (presumably already a distribution function)
    is passed through unchanged.
    """

    def change_value(self, value: Any, data: ParamTransformerData) -> Any:
        if value == "default":
            factory: DistributionFunctionFactory = DistributionFunctionFactoryDefault()
        elif isinstance(value, DistributionFunctionFactory):
            factory = value
        else:
            return value
        return factory.create_dist_fn(data.envs)
|
2023-09-28 14:28:03 +02:00
|
|
|
|
|
|
|
|
2023-09-28 20:07:52 +02:00
|
|
|
class ParamTransformerActionScaling(ParamTransformerChangeValue):
    """Resolves the value "default" into whether the environments' action space is continuous.

    Explicit values are returned as they are.
    """

    def change_value(self, value: Any, data: ParamTransformerData) -> Any:
        if value != "default":
            return value
        return data.envs.get_type().is_continuous()
|
|
|
|
|
|
|
|
|
2023-09-27 18:20:49 +02:00
|
|
|
class GetParamTransformersProtocol(Protocol):
    """Structural interface for parameter classes and mixins that contribute `ParamTransformer` instances."""

    def _get_param_transformers(self) -> list[ParamTransformer]:
        """:return: the transformers to apply to the keyword arguments generated from the parameters"""
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class Params(GetParamTransformersProtocol, ToStringMixin):
    """Base class for the high-level parameter dataclasses.

    Subclasses declare their parameters as dataclass fields and contribute, via `_get_param_transformers`,
    the transformers which turn these fields into the keyword arguments of the low-level policy.
    """

    def create_kwargs(self, data: ParamTransformerData) -> dict[str, Any]:
        """Create the keyword arguments for the low-level policy.

        :param data: the data which the parameter transformers may use
        :return: a dictionary mapping parameter names to their (transformed) values
        """
        # local imports keep this method self-contained
        import copy
        from dataclasses import fields

        # Copy field by field instead of using `dataclasses.asdict`: `asdict` recursively converts any
        # dataclass-valued parameter (e.g. a user-defined factory implemented as a dataclass) into a
        # plain dict, which would silently defeat the isinstance checks performed by the transformers.
        # Deep-copying each value retains the isolation from the user's objects that `asdict` provided.
        params = {f.name: copy.deepcopy(getattr(self, f.name)) for f in fields(self)}
        for transformer in self._get_param_transformers():
            transformer.transform(params, data)
        return params

    def _get_param_transformers(self) -> list[ParamTransformer]:
        return []
|
|
|
|
|
2023-09-26 17:43:16 +02:00
|
|
|
|
2023-09-25 17:56:37 +02:00
|
|
|
@dataclass
class ParamsMixinLearningRateWithScheduler(GetParamTransformersProtocol):
    lr: float = 1e-3
    """the learning rate of the gradient-based optimizer"""
    lr_scheduler_factory: LRSchedulerFactory | None = None
    """factory for creating a learning rate scheduler (if any)"""

    def _get_param_transformers(self) -> list[ParamTransformer]:
        # `lr` is not passed on to the policy (presumably it is consumed when creating the optimizer)
        drop_lr = ParamTransformerDrop("lr")
        create_scheduler = ParamTransformerLRScheduler("lr_scheduler_factory", "lr_scheduler")
        return [drop_lr, create_scheduler]
|
2023-09-25 17:56:37 +02:00
|
|
|
|
|
|
|
|
2023-10-03 20:26:39 +02:00
|
|
|
@dataclass
class ParamsMixinActorAndCritic(GetParamTransformersProtocol):
    actor_lr: float = 1e-3
    """the learning rate for the actor network"""
    critic_lr: float = 1e-3
    """the learning rate for the critic network"""
    actor_lr_scheduler_factory: LRSchedulerFactory | None = None
    """factory for creating a learning rate scheduler for the actor network (if any)"""
    critic_lr_scheduler_factory: LRSchedulerFactory | None = None
    """factory for creating a learning rate scheduler for the critic network (if any)"""

    def _get_param_transformers(self) -> list[ParamTransformer]:
        drop_learning_rates = ParamTransformerDrop("actor_lr", "critic_lr")
        create_scheduler = ParamTransformerActorAndCriticLRScheduler(
            "actor_lr_scheduler_factory",
            "critic_lr_scheduler_factory",
            "lr_scheduler",
        )
        return [drop_learning_rates, create_scheduler]
|
|
|
|
|
|
|
|
|
2023-09-25 17:56:37 +02:00
|
|
|
@dataclass
class ParamsMixinActionScaling(GetParamTransformersProtocol):
    action_scaling: bool | Literal["default"] = "default"
    """whether to apply action scaling; when set to "default", it will be enabled for continuous action spaces"""
    action_bound_method: Literal["clip", "tanh"] | None = "clip"
    """
    method to bound action to range [-1, 1]. Only used if the action_space is continuous.
    """

    def _get_param_transformers(self) -> list[ParamTransformer]:
        # Resolve "default" into a concrete boolean here, so that every parameter class using this mixin
        # (not only those which add the transformer themselves) hands a bool to the policy instead of the
        # string "default". The transformer leaves explicit booleans unchanged, so applying it more than
        # once is harmless.
        return [ParamTransformerActionScaling("action_scaling")]
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class ParamsMixinExplorationNoise(GetParamTransformersProtocol):
    exploration_noise: BaseNoise | Literal["default"] | NoiseFactory | None = None
    """
    If not None, noise is added to the actions for exploration purposes,
    which is useful for "hard exploration" problems.
    The value may be a noise distribution, a factory for the creation of one, or "default".
    When set to "default", Gaussian noise with standard deviation 0.1 is used.
    """

    def _get_param_transformers(self) -> list[ParamTransformer]:
        noise_transformer = ParamTransformerNoiseFactory("exploration_noise")
        return [noise_transformer]
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class PGParams(Params, ParamsMixinActionScaling, ParamsMixinLearningRateWithScheduler):
    discount_factor: float = 0.99
    """
    the discount factor (gamma) applied to future rewards; must lie in [0, 1]
    """
    reward_normalization: bool = False
    """
    if True, returns are normalized by subtracting the running mean and dividing by the running
    standard deviation.
    """
    deterministic_eval: bool = False
    """
    if True, the deterministic action (the distribution's mode) is used during evaluation instead
    of a stochastic one. Training is not affected.
    """
    dist_fn: TDistributionFunction | DistributionFunctionFactory | Literal["default"] = "default"
    """
    either a function mapping the model output to a torch distribution or a factory for the
    creation of such a function.
    When set to "default", a factory is used which creates Gaussian distributions from mean and
    standard deviation in the continuous case and categorical distributions in the discrete case
    (see :class:`DistributionFunctionFactoryDefault`).
    """

    def _get_param_transformers(self) -> list[ParamTransformer]:
        return [
            *super()._get_param_transformers(),
            *ParamsMixinActionScaling._get_param_transformers(self),
            *ParamsMixinLearningRateWithScheduler._get_param_transformers(self),
            ParamTransformerActionScaling("action_scaling"),
            ParamTransformerDistributionFunction("dist_fn"),
        ]
|
|
|
|
|
2023-09-25 17:56:37 +02:00
|
|
|
|
|
|
|
@dataclass
class ParamsMixinGeneralAdvantageEstimation(GetParamTransformersProtocol):
    gae_lambda: float = 0.95
    """
    the blend between Monte Carlo and one-step temporal difference (TD) estimates of the advantage
    function in general advantage estimation (GAE):
    0 yields a purely TD-based estimate, 1 yields a purely Monte Carlo estimate.
    """
    max_batchsize: int = 256
    """the maximum batch size used when computing general advantage estimation (GAE)"""

    def _get_param_transformers(self) -> list[ParamTransformer]:
        # both fields are forwarded to the policy unchanged
        return list()
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class A2CParams(PGParams, ParamsMixinGeneralAdvantageEstimation):
    vf_coef: float = 0.5
    """the weight (coefficient) of the value loss within the overall loss"""
    ent_coef: float = 0.01
    """the weight (coefficient) of the entropy loss within the overall loss"""
    max_grad_norm: float | None = None
    """the maximum gradient norm for gradient clipping during backpropagation (None: no clipping)"""

    def _get_param_transformers(self) -> list[ParamTransformer]:
        return [
            *super()._get_param_transformers(),
            *ParamsMixinGeneralAdvantageEstimation._get_param_transformers(self),
        ]
|
2023-09-25 17:56:37 +02:00
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class PPOParams(A2CParams):
    eps_clip: float = 0.2
    """
    determines the range of allowed change in the policy during a policy update:
    The ratio between the probabilities indicated by the new and old policy is
    constrained to stay in the interval [1 - eps_clip, 1 + eps_clip].
    Small values thus force the new policy to stay close to the old policy.
    Typical values range between 0.1 and 0.3.
    The optimal epsilon depends on the environment; more stochastic environments may need larger epsilons.
    """
    dual_clip: float | None = None
    """
    the dual-clip parameter c from arXiv:1912.09729, Equation 5.
    If set to None, dual clipping is not used and only the bounds described in parameter eps_clip apply.
    If set to a float value c, then for samples with negative advantage A, the (already clipped)
    surrogate objective is additionally bounded from below by c * A. This limits the size of the update
    when the probability ratio becomes very large while the advantage is negative.
    The paper requires c > 1 (NOTE(review): the low-level PPO policy is expected to enforce this -- confirm).
    """
    value_clip: bool = False
    """
    whether to apply clipping of the predicted value function during policy learning.
    Value clipping discourages large changes in value predictions between updates.
    Inaccurate value predictions can lead to bad policy updates, which can cause training instability.
    Clipping values prevents sporadic large errors from skewing policy updates too much.
    """
    advantage_normalization: bool = True
    """whether to apply per mini-batch advantage normalization."""
    recompute_advantage: bool = False
    """
    whether to recompute advantage every update repeat as described in
    https://arxiv.org/pdf/2006.05990.pdf, Sec. 3.5.
    The original PPO implementation splits the data in each policy iteration
    step into individual transitions and then randomly assigns them to minibatches.
    This makes it impossible to compute advantages as the temporal structure is broken.
    Therefore, the advantages are computed once at the beginning of each policy iteration step and
    then used in minibatch policy and value function optimization.
    This results in higher diversity of data in each minibatch at the cost of
    using slightly stale advantage estimations.
    Enabling this option will, as a remedy to this problem, recompute the advantages at the beginning
    of each pass over the data instead of just once per iteration.
    """
|
2023-09-25 17:56:37 +02:00
|
|
|
|
|
|
|
|
2023-10-10 13:47:30 +02:00
|
|
|
@dataclass
class NPGParams(PGParams, ParamsMixinGeneralAdvantageEstimation):
    optim_critic_iters: int = 5
    """how many times the critic network is optimized per update."""
    actor_step_size: float = 0.5
    """the step size of the actor update along the natural gradient direction"""
    advantage_normalization: bool = True
    """whether to apply per mini-batch advantage normalization."""

    def _get_param_transformers(self) -> list[ParamTransformer]:
        return [
            *super()._get_param_transformers(),
            *ParamsMixinGeneralAdvantageEstimation._get_param_transformers(self),
        ]
|
2023-10-10 13:47:30 +02:00
|
|
|
|
|
|
|
|
2023-10-10 14:14:00 +02:00
|
|
|
@dataclass
class TRPOParams(NPGParams):
    max_kl: float = 0.01
    """
    the maximum KL divergence permitted for each update of the actor network.
    """
    backtrack_coeff: float = 0.8
    """
    the factor by which the step size is reduced whenever the constraints are violated.
    """
    max_backtracks: int = 10
    """the maximum number of backtracking steps in the line search when the constraints are violated."""
|
2023-10-10 14:14:00 +02:00
|
|
|
|
|
|
|
|
2023-09-25 17:56:37 +02:00
|
|
|
@dataclass
class ParamsMixinActorAndDualCritics(GetParamTransformersProtocol):
    actor_lr: float = 1e-3
    """the learning rate for the actor network"""
    critic1_lr: float = 1e-3
    """the learning rate for the first critic network"""
    critic2_lr: float = 1e-3
    """the learning rate for the second critic network"""
    actor_lr_scheduler_factory: LRSchedulerFactory | None = None
    """factory for creating a learning rate scheduler for the actor network (if any)"""
    critic1_lr_scheduler_factory: LRSchedulerFactory | None = None
    """factory for creating a learning rate scheduler for the first critic network (if any)"""
    critic2_lr_scheduler_factory: LRSchedulerFactory | None = None
    """factory for creating a learning rate scheduler for the second critic network (if any)"""

    def _get_param_transformers(self) -> list[ParamTransformer]:
        drop_learning_rates = ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr")
        create_scheduler = ParamTransformerActorDualCriticsLRScheduler(
            "actor_lr_scheduler_factory",
            "critic1_lr_scheduler_factory",
            "critic2_lr_scheduler_factory",
            "lr_scheduler",
        )
        return [drop_learning_rates, create_scheduler]
|
2023-09-26 17:43:16 +02:00
|
|
|
|
2023-09-26 15:35:18 +02:00
|
|
|
|
|
|
|
@dataclass
class _SACParams(Params, ParamsMixinActorAndDualCritics):
    tau: float = 0.005
    """
    controls the soft update of the target networks.
    It determines how slowly the target networks track the main networks.
    Smaller tau means slower tracking and more stable learning.
    """
    gamma: float = 0.99
    """discount factor (gamma) for future rewards; must be in [0, 1]"""
    alpha: float | AutoAlphaFactory = 0.2
    """
    controls the relative importance (coefficient) of the entropy term in the loss function,
    i.e. the desired amount of randomness in the optimal policy.
    This can be a constant or a factory for the creation of a representation that allows the
    parameter to be automatically tuned;
    use :class:`tianshou.highlevel.params.alpha.AutoAlphaFactoryDefault` for the standard
    auto-adjusted alpha.
    """
    estimation_step: int = 1
    """the number of steps to look ahead"""

    def _get_param_transformers(self) -> list[ParamTransformer]:
        transformers = super()._get_param_transformers()
        transformers.extend(ParamsMixinActorAndDualCritics._get_param_transformers(self))
        transformers.append(ParamTransformerAutoAlpha("alpha"))
        return transformers
|
|
|
|
|
|
|
|
|
2023-10-10 19:11:49 +02:00
|
|
|
@dataclass
class SACParams(_SACParams, ParamsMixinExplorationNoise, ParamsMixinActionScaling):
    deterministic_eval: bool = True
    """
    if True, evaluation uses the deterministic action (the mean of the Gaussian policy) rather than a
    stochastic action sampled from the policy. Training is not affected.
    """

    def _get_param_transformers(self) -> list[ParamTransformer]:
        inherited = super()._get_param_transformers()
        noise_transformers = ParamsMixinExplorationNoise._get_param_transformers(self)
        scaling_transformers = ParamsMixinActionScaling._get_param_transformers(self)
        return inherited + noise_transformers + scaling_transformers
|
|
|
|
|
|
|
|
|
2023-10-17 17:42:41 +02:00
|
|
|
@dataclass
class DiscreteSACParams(_SACParams):
    """Parameters of discrete SAC; identical to the common SAC parameters, without further additions."""
|
|
|
|
|
|
|
|
|
2023-10-05 15:39:32 +02:00
|
|
|
@dataclass
class DQNParams(Params, ParamsMixinLearningRateWithScheduler):
    discount_factor: float = 0.99
    """
    the discount factor (gamma) applied to future rewards; must lie in [0, 1]
    """
    estimation_step: int = 1
    """how many steps to look ahead"""
    target_update_freq: int = 0
    """the update frequency of the target network (0 if no target network is to be used)"""
    reward_normalization: bool = False
    """whether returns are normalized to Normal(0, 1)"""
    is_double: bool = True
    """whether double Q-learning is used"""
    clip_loss_grad: bool = False
    """whether the gradient of the loss is clipped in accordance with nature14236, which amounts to using
    the Huber loss instead of the MSE loss."""

    def _get_param_transformers(self) -> list[ParamTransformer]:
        transformer_list = list(super()._get_param_transformers())
        transformer_list += ParamsMixinLearningRateWithScheduler._get_param_transformers(self)
        return transformer_list
|
|
|
|
|
|
|
|
|
2023-10-11 15:31:38 +02:00
|
|
|
@dataclass
class IQNParams(DQNParams):
    sample_size: int = 32
    """how many samples are used for policy evaluation"""
    online_sample_size: int = 8
    """how many samples are used for the online model during training"""
    target_sample_size: int = 8
    """how many samples are used for the target model during training."""
    num_quantiles: int = 200
    """how many quantile midpoints are used in the inverse cumulative distribution function of the value"""
    hidden_sizes: Sequence[int] = ()
    """the hidden dimensions of the IQN network"""
    num_cosines: int = 64
    """how many cosines the IQN network uses"""

    def _get_param_transformers(self) -> list[ParamTransformer]:
        # NOTE(review): these describe the network rather than the policy and are therefore not passed on;
        # presumably they are consumed during network construction -- confirm.
        network_only_keys = ("hidden_sizes", "num_cosines")
        return [*super()._get_param_transformers(), ParamTransformerDrop(*network_only_keys)]
|
|
|
|
|
|
|
|
|
2023-10-03 20:26:39 +02:00
|
|
|
@dataclass
class DDPGParams(
    Params,
    ParamsMixinActorAndCritic,
    ParamsMixinExplorationNoise,
    ParamsMixinActionScaling,
):
    tau: float = 0.005
    """
    the coefficient of the soft update of the target networks,
    i.e. how slowly the target networks track the main networks.
    A smaller tau means slower tracking and thus more stable learning.
    """
    gamma: float = 0.99
    """the discount factor (gamma) applied to future rewards; must lie in [0, 1]"""
    estimation_step: int = 1
    """how many steps to look ahead."""

    def _get_param_transformers(self) -> list[ParamTransformer]:
        return [
            *super()._get_param_transformers(),
            *ParamsMixinActorAndCritic._get_param_transformers(self),
            *ParamsMixinExplorationNoise._get_param_transformers(self),
            *ParamsMixinActionScaling._get_param_transformers(self),
        ]
|
2023-09-26 17:43:16 +02:00
|
|
|
|
2023-09-26 15:35:18 +02:00
|
|
|
|
2023-10-10 15:49:05 +02:00
|
|
|
@dataclass
class REDQParams(DDPGParams):
    ensemble_size: int = 10
    """how many sub-networks the critic ensemble comprises"""
    subset_size: int = 2
    """how many networks the (sampled) subset comprises"""
    alpha: float | AutoAlphaFactory = 0.2
    """
    the relative importance (coefficient) of the entropy term in the loss function.
    It is either a constant or a factory creating a representation which allows the parameter to be
    tuned automatically; use :class:`tianshou.highlevel.params.alpha.AutoAlphaFactoryDefault` for the
    standard auto-adjusted alpha.
    """
    estimation_step: int = 1
    """how many steps to look ahead"""
    actor_delay: int = 20
    """how many critic updates take place before each actor update"""
    deterministic_eval: bool = True
    """
    if True, evaluation uses the deterministic action (the distribution's mode) instead of a stochastic one.
    Training is not affected.
    """
    target_mode: Literal["mean", "min"] = "min"
    """
    presumably the way in which the values of the critics in the subset are aggregated
    (mean or minimum) -- NOTE(review): confirm against the REDQ policy.
    """

    def _get_param_transformers(self) -> list[ParamTransformer]:
        return [*super()._get_param_transformers(), ParamTransformerAutoAlpha("alpha")]
|
|
|
|
|
|
|
|
|
2023-09-26 15:35:18 +02:00
|
|
|
@dataclass
class TD3Params(
    Params,
    ParamsMixinActorAndDualCritics,
    ParamsMixinExplorationNoise,
    ParamsMixinActionScaling,
):
    tau: float = 0.005
    """
    controls the soft update of the target network.
    It determines how slowly the target networks track the main networks.
    Smaller tau means slower tracking and more stable learning.
    """
    gamma: float = 0.99
    """discount factor (gamma) for future rewards; must be in [0, 1]"""
    policy_noise: float | FloatEnvValueFactory = 0.2
    """the scale of the noise used in updating the policy network"""
    noise_clip: float | FloatEnvValueFactory = 0.5
    """determines the clipping range of the noise used in updating the policy network as [-noise_clip, noise_clip]"""
    update_actor_freq: int = 2
    """the update frequency of the actor network"""
    estimation_step: int = 1
    """the number of steps to look ahead."""

    def _get_param_transformers(self) -> list[ParamTransformer]:
        transformers = super()._get_param_transformers()
        transformers.extend(ParamsMixinActorAndDualCritics._get_param_transformers(self))
        transformers.extend(ParamsMixinExplorationNoise._get_param_transformers(self))
        transformers.extend(ParamsMixinActionScaling._get_param_transformers(self))
        # values given as factories are resolved into floats for the given environments
        transformers.append(ParamTransformerFloatEnvParamFactory("policy_noise"))
        transformers.append(ParamTransformerFloatEnvParamFactory("noise_clip"))
        return transformers
|