Tianshou/tianshou/highlevel/params/policy_params.py

653 lines
26 KiB
Python
Raw Normal View History

import copy
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import asdict, dataclass, fields
from typing import Any, Literal, Protocol

import torch
from torch.optim.lr_scheduler import LRScheduler

from tianshou.exploration import BaseNoise
from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.core import TDevice
from tianshou.highlevel.module.module_opt import ModuleOpt
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.highlevel.params.alpha import AutoAlphaFactory
from tianshou.highlevel.params.dist_fn import (
    DistributionFunctionFactory,
    DistributionFunctionFactoryDefault,
)
from tianshou.highlevel.params.env_param import EnvValueFactory, FloatEnvValueFactory
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
from tianshou.highlevel.params.noise import NoiseFactory
from tianshou.policy.modelfree.pg import TDistributionFunction
from tianshou.utils import MultipleLRSchedulers
from tianshou.utils.string import ToStringMixin
@dataclass(kw_only=True)
class ParamTransformerData:
    """Bundle of objects that `ParamTransformer` instances may consult during their transformation.

    It holds the union of everything the various agent factories might need. Each agent factory is
    expected to fill in only the attributes relevant to its own parameters; the optional ones
    otherwise keep their default of `None`.
    """

    envs: Environments
    device: TDevice
    optim_factory: OptimizerFactory
    optim: torch.optim.Optimizer | None = None
    """the optimizer, for agents that use exactly one"""
    # actor/critic modules together with their optimizers; the `.optim` attribute of each is read
    # by the actor/critic learning rate scheduler transformers
    actor: ModuleOpt | None = None
    critic1: ModuleOpt | None = None
    critic2: ModuleOpt | None = None
class ParamTransformer(ABC):
    """Converts one or more parameters from their high-level API representation into the
    representation expected by the (low-level) policy implementation.

    Transformers modify, in place, a dictionary of keyword arguments that is initially
    produced from a parameter dataclass (a subclass of `Params`).
    """

    @abstractmethod
    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
        """Modify `params` in place.

        :param params: the keyword arguments to transform
        :param data: auxiliary objects that may be needed for the transformation
        """

    @staticmethod
    def get(d: dict[str, Any], key: str, drop: bool = False) -> Any:
        """Look up `key` in `d`, optionally removing the entry afterwards.

        :param d: the dictionary to read from
        :param key: the key whose value is to be returned
        :param drop: whether to remove `key` from `d`
        :return: the value stored under `key`
        :raises KeyError: if `key` is not contained in `d`
        """
        return d.pop(key) if drop else d[key]
class ParamTransformerDrop(ParamTransformer):
    """Removes the given keys from the parameter dictionary."""

    def __init__(self, *keys: str):
        """:param keys: the names of the parameters to remove"""
        self.keys = keys

    def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None:
        # every key is required to be present; a missing one raises KeyError
        for key_to_remove in self.keys:
            kwargs.pop(key_to_remove)
class ParamTransformerChangeValue(ParamTransformer):
    """Base class for transformers that replace the value stored under a single key."""

    def __init__(self, key: str):
        """:param key: the name of the parameter whose value is to be replaced"""
        self.key = key

    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
        current_value = params[self.key]
        params[self.key] = self.change_value(current_value, data)

    @abstractmethod
    def change_value(self, value: Any, data: ParamTransformerData) -> Any:
        """Compute the replacement value.

        :param value: the current value of the parameter
        :param data: auxiliary objects that may be needed for the transformation
        :return: the value to store under the key instead
        """
class ParamTransformerLRScheduler(ParamTransformer):
    """Transforms a key containing a learning rate scheduler factory (removed) into a key containing
    a learning rate scheduler (added) for the data member `optim`.
    """

    def __init__(self, key_scheduler_factory: str, key_scheduler: str):
        """:param key_scheduler_factory: the key holding the (optional) scheduler factory; it is removed
        :param key_scheduler: the key under which to store the resulting scheduler (None if no factory is given)
        """
        self.key_scheduler_factory = key_scheduler_factory
        self.key_scheduler = key_scheduler

    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
        factory: LRSchedulerFactory | None = self.get(params, self.key_scheduler_factory, drop=True)
        if factory is None:
            # No scheduler is configured, so no optimizer is required either.
            params[self.key_scheduler] = None
            return
        assert data.optim is not None, "an optimizer is required to create a learning rate scheduler"
        params[self.key_scheduler] = factory.create_scheduler(data.optim)
class ParamTransformerMultiLRScheduler(ParamTransformer):
    """Combines several optional learning rate scheduler factories into a single scheduler entry.

    The entry is None if no factory is given, the created scheduler itself if exactly one factory is
    given, and a `MultipleLRSchedulers` instance wrapping all created schedulers otherwise.
    """

    def __init__(self, optim_key_list: list[tuple[torch.optim.Optimizer, str]], key_scheduler: str):
        """:param optim_key_list: pairs (optimizer, key of the corresponding learning rate scheduler factory);
            all factory keys are removed from the parameters
        :param key_scheduler: the key under which to store the resulting learning rate scheduler
        """
        self.optim_key_list = optim_key_list
        self.key_scheduler = key_scheduler

    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
        # every factory key is removed, including those whose value is None
        optims_and_factories: list[tuple[torch.optim.Optimizer, LRSchedulerFactory | None]] = [
            (optimizer, self.get(params, factory_key, drop=True))
            for optimizer, factory_key in self.optim_key_list
        ]
        created = [
            factory.create_scheduler(optimizer)
            for optimizer, factory in optims_and_factories
            if factory is not None
        ]
        combined: LRScheduler | MultipleLRSchedulers | None
        if not created:
            combined = None
        elif len(created) == 1:
            combined = created[0]
        else:
            combined = MultipleLRSchedulers(*created)
        params[self.key_scheduler] = combined
class ParamTransformerActorAndCriticLRScheduler(ParamTransformer):
    """Creates the learning rate scheduler entry for an agent with one actor and one critic, based on
    the optimizers of `data.actor` and `data.critic1`.
    """

    def __init__(
        self,
        key_scheduler_factory_actor: str,
        key_scheduler_factory_critic: str,
        key_scheduler: str,
    ):
        """:param key_scheduler_factory_actor: the key of the (optional) scheduler factory for the actor
        :param key_scheduler_factory_critic: the key of the (optional) scheduler factory for the critic
        :param key_scheduler: the key under which to store the resulting scheduler
        """
        self.key_factory_actor = key_scheduler_factory_actor
        self.key_factory_critic = key_scheduler_factory_critic
        self.key_scheduler = key_scheduler

    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
        actor, critic = data.actor, data.critic1
        assert actor is not None and critic is not None
        optim_key_pairs = [
            (actor.optim, self.key_factory_actor),
            (critic.optim, self.key_factory_critic),
        ]
        delegate = ParamTransformerMultiLRScheduler(optim_key_pairs, self.key_scheduler)
        delegate.transform(params, data)
class ParamTransformerActorDualCriticsLRScheduler(ParamTransformer):
    """Creates the learning rate scheduler entry for an agent with one actor and two critics, based on
    the optimizers of `data.actor`, `data.critic1` and `data.critic2`.
    """

    def __init__(
        self,
        key_scheduler_factory_actor: str,
        key_scheduler_factory_critic1: str,
        key_scheduler_factory_critic2: str,
        key_scheduler: str,
    ):
        """:param key_scheduler_factory_actor: the key of the (optional) scheduler factory for the actor
        :param key_scheduler_factory_critic1: the key of the (optional) scheduler factory for the first critic
        :param key_scheduler_factory_critic2: the key of the (optional) scheduler factory for the second critic
        :param key_scheduler: the key under which to store the resulting scheduler
        """
        self.key_factory_actor = key_scheduler_factory_actor
        self.key_factory_critic1 = key_scheduler_factory_critic1
        self.key_factory_critic2 = key_scheduler_factory_critic2
        self.key_scheduler = key_scheduler

    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
        actor, critic1, critic2 = data.actor, data.critic1, data.critic2
        assert actor is not None and critic1 is not None and critic2 is not None
        optimizers = (actor.optim, critic1.optim, critic2.optim)
        factory_keys = (self.key_factory_actor, self.key_factory_critic1, self.key_factory_critic2)
        delegate = ParamTransformerMultiLRScheduler(
            list(zip(optimizers, factory_keys)),
            self.key_scheduler,
        )
        delegate.transform(params, data)
class ParamTransformerAutoAlpha(ParamTransformer):
    """Replaces an `AutoAlphaFactory` stored under the given key with the representation it creates.

    Any other value (e.g. a constant float) is left unchanged.
    """

    def __init__(self, key: str):
        """:param key: the name of the alpha parameter"""
        self.key = key

    def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None:
        alpha = self.get(kwargs, self.key)
        if not isinstance(alpha, AutoAlphaFactory):
            return
        kwargs[self.key] = alpha.create_auto_alpha(data.envs, data.optim_factory, data.device)
class ParamTransformerNoiseFactory(ParamTransformerChangeValue):
    """Replaces a `NoiseFactory` with the noise it creates for the environments; other values are kept."""

    def change_value(self, value: Any, data: ParamTransformerData) -> Any:
        return value.create_noise(data.envs) if isinstance(value, NoiseFactory) else value
class ParamTransformerFloatEnvParamFactory(ParamTransformerChangeValue):
    """Replaces an `EnvValueFactory` with the value it derives from the environments; other values are kept."""

    def change_value(self, value: Any, data: ParamTransformerData) -> Any:
        if not isinstance(value, EnvValueFactory):
            return value
        return value.create_value(data.envs)
class ParamTransformerDistributionFunction(ParamTransformerChangeValue):
def change_value(self, value: Any, data: ParamTransformerData) -> Any:
if value == "default":
value = DistributionFunctionFactoryDefault().create_dist_fn(data.envs)
elif isinstance(value, DistributionFunctionFactory):
value = value.create_dist_fn(data.envs)
return value
class ParamTransformerActionScaling(ParamTransformerChangeValue):
    """Resolves `"default"` for the action scaling flag: scaling is enabled exactly when the
    environment type is continuous. Any other value is kept unchanged.
    """

    def change_value(self, value: Any, data: ParamTransformerData) -> Any:
        return data.envs.get_type().is_continuous() if value == "default" else value
class GetParamTransformersProtocol(Protocol):
    """Structural interface of classes that contribute parameter transformers."""

    def _get_param_transformers(self) -> list[ParamTransformer]:
        """:return: the transformers to apply, in order, to the keyword arguments created from this object"""
@dataclass
class Params(GetParamTransformersProtocol, ToStringMixin):
    """Base class of the high-level parameter dataclasses.

    Subclasses declare their parameters as dataclass fields and contribute the transformers that
    convert them into the keyword arguments of the low-level policy.
    """

    def create_kwargs(self, data: ParamTransformerData) -> dict[str, Any]:
        """Create the keyword arguments for the low-level policy.

        :param data: auxiliary objects needed by the transformers
        :return: a new dictionary mapping each field name to a deep copy of its value, after all
            transformers returned by `_get_param_transformers` have been applied in order
        """
        # Unlike `dataclasses.asdict`, this does not recursively convert field values that are
        # themselves dataclass instances (e.g. a factory implemented as a dataclass) into plain
        # dicts, which would defeat the `isinstance` checks performed by the transformers.
        # Values are still deep-copied, so transformers cannot affect this object's state.
        params = {
            param_field.name: copy.deepcopy(getattr(self, param_field.name))
            for param_field in fields(self)
        }
        for transformer in self._get_param_transformers():
            transformer.transform(params, data)
        return params

    def _get_param_transformers(self) -> list[ParamTransformer]:
        return []
@dataclass
class ParamsMixinLearningRateWithScheduler(GetParamTransformersProtocol):
    """Mixin for agents with a single optimizer and an optional learning rate scheduler."""

    lr: float = 1e-3
    """the learning rate of the gradient-based optimizer"""
    lr_scheduler_factory: LRSchedulerFactory | None = None
    """optional factory for the creation of a learning rate scheduler"""

    def _get_param_transformers(self) -> list[ParamTransformer]:
        # `lr` is not passed on to the policy; the scheduler factory is replaced by the
        # scheduler created for `data.optim` (stored under "lr_scheduler")
        drop_learning_rate = ParamTransformerDrop("lr")
        create_scheduler = ParamTransformerLRScheduler("lr_scheduler_factory", "lr_scheduler")
        return [drop_learning_rate, create_scheduler]
@dataclass
class ParamsMixinActorAndCritic(GetParamTransformersProtocol):
    """Mixin for agents with one actor and one critic network, each with its own learning rate and
    optional learning rate scheduler.
    """

    actor_lr: float = 1e-3
    """the learning rate for the actor network"""
    critic_lr: float = 1e-3
    """the learning rate for the critic network"""
    actor_lr_scheduler_factory: LRSchedulerFactory | None = None
    """optional factory for the creation of a learning rate scheduler for the actor network"""
    critic_lr_scheduler_factory: LRSchedulerFactory | None = None
    """optional factory for the creation of a learning rate scheduler for the critic network"""

    def _get_param_transformers(self) -> list[ParamTransformer]:
        # the learning rates are not passed on to the policy; both scheduler factories are merged
        # into a single "lr_scheduler" entry
        drop_learning_rates = ParamTransformerDrop("actor_lr", "critic_lr")
        create_schedulers = ParamTransformerActorAndCriticLRScheduler(
            key_scheduler_factory_actor="actor_lr_scheduler_factory",
            key_scheduler_factory_critic="critic_lr_scheduler_factory",
            key_scheduler="lr_scheduler",
        )
        return [drop_learning_rates, create_schedulers]
@dataclass
class ParamsMixinActionScaling(GetParamTransformersProtocol):
    """Mixin for agents supporting action scaling and bounding."""

    action_scaling: bool | Literal["default"] = "default"
    """whether to apply action scaling; when set to "default", it will be enabled for continuous action spaces"""
    action_bound_method: Literal["clip", "tanh"] | None = "clip"
    """
    method to bound action to range [-1, 1]. Only used if the action_space is continuous.
    """

    def _get_param_transformers(self) -> list[ParamTransformer]:
        # Resolve "default" here so that every class using this mixin passes a proper bool to the
        # policy rather than the string "default". The transformation is idempotent (a bool is
        # passed through unchanged), so classes that also register it themselves are unaffected.
        return [ParamTransformerActionScaling("action_scaling")]
@dataclass
class ParamsMixinExplorationNoise(GetParamTransformersProtocol):
    """Mixin for agents that can add exploration noise to their actions."""

    exploration_noise: BaseNoise | Literal["default"] | NoiseFactory | None = None
    """
    If not None, noise is added to the actions for exploration, which is useful for
    "hard exploration" problems.
    The value can be a noise distribution, a factory for the creation of such a distribution, or "default".
    "default" means Gaussian noise with standard deviation 0.1.
    """

    def _get_param_transformers(self) -> list[ParamTransformer]:
        noise_transformer = ParamTransformerNoiseFactory("exploration_noise")
        return [noise_transformer]
@dataclass
class PGParams(Params, ParamsMixinActionScaling, ParamsMixinLearningRateWithScheduler):
    """Parameters of policy gradient agents (the base for the A2C, PPO, NPG and TRPO parameters)."""

    discount_factor: float = 0.99
    """
    discount factor (gamma) applied to future rewards; must lie in [0, 1]
    """
    reward_normalization: bool = False
    """
    if True, returns are normalized by subtracting the running mean and dividing by the running
    standard deviation.
    """
    deterministic_eval: bool = False
    """
    whether to use the deterministic action (the dist's mode) instead of a stochastic one during evaluation.
    Training is not affected.
    """
    dist_fn: TDistributionFunction | DistributionFunctionFactory | Literal["default"] = "default"
    """
    Either a function mapping the model output to a torch distribution, or a factory
    for the creation of such a function.
    "default" selects a factory which creates Gaussian distributions from mean and standard
    deviation in the continuous case and categorical distributions in the discrete case
    (see :class:`DistributionFunctionFactoryDefault`).
    """

    def _get_param_transformers(self) -> list[ParamTransformer]:
        return [
            *super()._get_param_transformers(),
            *ParamsMixinActionScaling._get_param_transformers(self),
            *ParamsMixinLearningRateWithScheduler._get_param_transformers(self),
            ParamTransformerActionScaling("action_scaling"),
            ParamTransformerDistributionFunction("dist_fn"),
        ]
@dataclass
class ParamsMixinGeneralAdvantageEstimation(GetParamTransformersProtocol):
    """Mixin for agents using general advantage estimation (GAE)."""

    gae_lambda: float = 0.95
    """
    controls the blend between Monte Carlo and one-step temporal difference (TD) estimates of the
    advantage function in general advantage estimation (GAE).
    0 yields a purely TD-based estimate, 1 a purely Monte Carlo estimate.
    """
    max_batchsize: int = 256
    """the maximum batch size used when computing general advantage estimation (GAE)"""

    def _get_param_transformers(self) -> list[ParamTransformer]:
        # these parameters are passed on as they are
        no_transformers: list[ParamTransformer] = []
        return no_transformers
@dataclass
class A2CParams(PGParams, ParamsMixinGeneralAdvantageEstimation):
    """Parameters of A2C agents: the policy gradient parameters plus GAE and loss weighting."""

    vf_coef: float = 0.5
    """weight (coefficient) of the value loss within the overall loss"""
    ent_coef: float = 0.01
    """weight (coefficient) of the entropy loss within the overall loss"""
    max_grad_norm: float | None = None
    """maximum gradient norm for gradient clipping during backpropagation"""

    def _get_param_transformers(self) -> list[ParamTransformer]:
        return [
            *super()._get_param_transformers(),
            *ParamsMixinGeneralAdvantageEstimation._get_param_transformers(self),
        ]
@dataclass
class PPOParams(A2CParams):
    """Parameters of PPO agents, extending the A2C parameters."""

    eps_clip: float = 0.2
    """
    determines the range of allowed change in the policy during a policy update:
    The ratio between the probabilities indicated by the new and old policy is
    constrained to stay in the interval [1 - eps_clip, 1 + eps_clip].
    Small values thus force the new policy to stay close to the old policy.
    Typical values range between 0.1 and 0.3.
    The optimal epsilon depends on the environment; more stochastic environments may need larger epsilons.
    """
    dual_clip: float | None = None
    """
    determines an additional lower bound of the clipped surrogate objective for samples with
    negative advantage (corresponds to parameter c in arXiv:1912.09729, Equation 5).
    If set to None, dual clipping is not used and only the clipping described in parameter eps_clip applies.
    If set to a float value c, which must be greater than 1, the objective of a sample with negative
    advantage A is bounded from below by c * A. This prevents excessively large policy updates when
    the probability ratio between the new and the old policy becomes very large.
    """
    value_clip: bool = False
    """
    whether to apply clipping of the predicted value function during policy learning.
    Value clipping discourages large changes in value predictions between updates.
    Inaccurate value predictions can lead to bad policy updates, which can cause training instability.
    Clipping values prevents sporadic large errors from skewing policy updates too much.
    """
    advantage_normalization: bool = True
    """whether to apply per mini-batch advantage normalization."""
    recompute_advantage: bool = False
    """
    whether to recompute advantage every update repeat as described in
    https://arxiv.org/pdf/2006.05990.pdf, Sec. 3.5.
    The original PPO implementation splits the data in each policy iteration
    step into individual transitions and then randomly assigns them to minibatches.
    This makes it impossible to compute advantages as the temporal structure is broken.
    Therefore, the advantages are computed once at the beginning of each policy iteration step and
    then used in minibatch policy and value function optimization.
    This results in higher diversity of data in each minibatch at the cost of
    using slightly stale advantage estimations.
    Enabling this option will, as a remedy to this problem, recompute the advantages at the beginning
    of each pass over the data instead of just once per iteration.
    """
@dataclass
class NPGParams(PGParams, ParamsMixinGeneralAdvantageEstimation):
    """Parameters of natural policy gradient (NPG) agents."""

    optim_critic_iters: int = 5
    """how many times the critic network is optimized per update."""
    actor_step_size: float = 0.5
    """step size of the actor update along the natural gradient direction"""
    advantage_normalization: bool = True
    """whether to apply per mini-batch advantage normalization."""

    def _get_param_transformers(self) -> list[ParamTransformer]:
        return [
            *super()._get_param_transformers(),
            *ParamsMixinGeneralAdvantageEstimation._get_param_transformers(self),
        ]
@dataclass
class TRPOParams(NPGParams):
    """Parameters of TRPO agents: the NPG parameters plus the KL constraint and its line search."""

    max_kl: float = 0.01
    """
    the maximum KL divergence by which each actor network update is constrained.
    """
    backtrack_coeff: float = 0.8
    """
    the factor by which the step size is reduced whenever the constraints are not met.
    """
    max_backtracks: int = 10
    """the maximum number of backtracking steps in the line search when the constraints are not met."""
@dataclass
class ParamsMixinActorAndDualCritics(GetParamTransformersProtocol):
    """Mixin for agents with one actor and two critic networks, each with its own learning rate and
    optional learning rate scheduler.
    """

    actor_lr: float = 1e-3
    """the learning rate for the actor network"""
    critic1_lr: float = 1e-3
    """the learning rate for the first critic network"""
    critic2_lr: float = 1e-3
    """the learning rate for the second critic network"""
    actor_lr_scheduler_factory: LRSchedulerFactory | None = None
    """optional factory for the creation of a learning rate scheduler for the actor network"""
    critic1_lr_scheduler_factory: LRSchedulerFactory | None = None
    """optional factory for the creation of a learning rate scheduler for the first critic network"""
    critic2_lr_scheduler_factory: LRSchedulerFactory | None = None
    """optional factory for the creation of a learning rate scheduler for the second critic network"""

    def _get_param_transformers(self) -> list[ParamTransformer]:
        # the learning rates are not passed on to the policy; the three scheduler factories are
        # merged into a single "lr_scheduler" entry
        drop_learning_rates = ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr")
        create_schedulers = ParamTransformerActorDualCriticsLRScheduler(
            key_scheduler_factory_actor="actor_lr_scheduler_factory",
            key_scheduler_factory_critic1="critic1_lr_scheduler_factory",
            key_scheduler_factory_critic2="critic2_lr_scheduler_factory",
            key_scheduler="lr_scheduler",
        )
        return [drop_learning_rates, create_schedulers]
@dataclass
class _SACParams(Params, ParamsMixinActorAndDualCritics):
    """Parameters shared by the continuous and the discrete SAC variants."""

    tau: float = 0.005
    """
    controls the soft update of the target networks.
    It determines how slowly the target networks track the main networks.
    Smaller tau means slower tracking and more stable learning.
    """
    gamma: float = 0.99
    """discount factor (gamma) for future rewards; must be in [0, 1]"""
    alpha: float | AutoAlphaFactory = 0.2
    """
    controls the relative importance (coefficient) of the entropy term in the loss function.
    This can be a constant or a factory for the creation of a representation that allows the
    parameter to be automatically tuned;
    use :class:`tianshou.highlevel.params.alpha.AutoAlphaFactoryDefault` for the standard
    auto-adjusted alpha.
    """
    estimation_step: int = 1
    """the number of steps to look ahead"""

    def _get_param_transformers(self) -> list[ParamTransformer]:
        transformers = super()._get_param_transformers()
        transformers.extend(ParamsMixinActorAndDualCritics._get_param_transformers(self))
        # a factory-given alpha is replaced by the auto-tuned representation it creates
        transformers.append(ParamTransformerAutoAlpha("alpha"))
        return transformers
@dataclass
class SACParams(_SACParams, ParamsMixinExplorationNoise, ParamsMixinActionScaling):
    """Parameters of (continuous) SAC agents."""

    deterministic_eval: bool = True
    """
    whether evaluation uses the deterministic action (mean of the Gaussian policy) instead of a
    stochastic action sampled from the policy. Training is not affected."""

    def _get_param_transformers(self) -> list[ParamTransformer]:
        return [
            *super()._get_param_transformers(),
            *ParamsMixinExplorationNoise._get_param_transformers(self),
            *ParamsMixinActionScaling._get_param_transformers(self),
        ]
@dataclass
class DiscreteSACParams(_SACParams):
    """Parameters of discrete SAC agents; identical to the parameters shared by all SAC variants."""
@dataclass
class DQNParams(Params, ParamsMixinLearningRateWithScheduler):
    """Parameters of DQN agents."""

    discount_factor: float = 0.99
    """
    discount factor (gamma) applied to future rewards; must lie in [0, 1]
    """
    estimation_step: int = 1
    """the number of steps to look ahead"""
    target_update_freq: int = 0
    """the update frequency of the target network (0 if no target network is to be used)"""
    reward_normalization: bool = False
    """whether returns are normalized to Normal(0, 1)"""
    is_double: bool = True
    """whether double Q-learning is used"""
    clip_loss_grad: bool = False
    """whether the gradient of the loss is clipped in accordance with nature14236, which amounts to
    using the Huber loss instead of the MSE loss."""

    def _get_param_transformers(self) -> list[ParamTransformer]:
        return [
            *super()._get_param_transformers(),
            *ParamsMixinLearningRateWithScheduler._get_param_transformers(self),
        ]
@dataclass
class IQNParams(DQNParams):
    """Parameters of IQN agents, extending the DQN parameters."""

    sample_size: int = 32
    """the number of samples used for policy evaluation"""
    online_sample_size: int = 8
    """the number of samples used for the online model during training"""
    target_sample_size: int = 8
    """the number of samples used for the target model during training."""
    num_quantiles: int = 200
    """the number of quantile midpoints in the inverse cumulative distribution function of the value"""
    hidden_sizes: Sequence[int] = ()
    """the hidden dimensions of the IQN network"""
    num_cosines: int = 64
    """the number of cosines used in the IQN network"""

    def _get_param_transformers(self) -> list[ParamTransformer]:
        # the network architecture parameters are not passed on to the policy
        # (presumably they are consumed when the network is built - TODO confirm)
        return [
            *super()._get_param_transformers(),
            ParamTransformerDrop("hidden_sizes", "num_cosines"),
        ]
@dataclass
class DDPGParams(
    Params,
    ParamsMixinActorAndCritic,
    ParamsMixinExplorationNoise,
    ParamsMixinActionScaling,
):
    """Parameters of DDPG agents."""

    tau: float = 0.005
    """
    controls the soft update of the target network:
    it determines how slowly the target networks follow the main networks.
    A smaller tau means slower tracking and more stable learning.
    """
    gamma: float = 0.99
    """discount factor (gamma) applied to future rewards; must lie in [0, 1]"""
    estimation_step: int = 1
    """the number of steps to look ahead."""

    def _get_param_transformers(self) -> list[ParamTransformer]:
        return [
            *super()._get_param_transformers(),
            *ParamsMixinActorAndCritic._get_param_transformers(self),
            *ParamsMixinExplorationNoise._get_param_transformers(self),
            *ParamsMixinActionScaling._get_param_transformers(self),
        ]
@dataclass
class REDQParams(DDPGParams):
    """Parameters of REDQ agents, extending the DDPG parameters with a critic ensemble."""

    ensemble_size: int = 10
    """the number of sub-networks in the critic ensemble"""
    subset_size: int = 2
    """the number of networks in the subset"""
    alpha: float | AutoAlphaFactory = 0.2
    """
    controls the relative importance (coefficient) of the entropy term in the loss function.
    This can be a constant or a factory for the creation of a representation that allows the
    parameter to be automatically tuned;
    use :class:`tianshou.highlevel.params.alpha.AutoAlphaFactoryDefault` for the standard
    auto-adjusted alpha.
    """
    estimation_step: int = 1
    """the number of steps to look ahead"""
    actor_delay: int = 20
    """the number of critic updates before an actor update"""
    deterministic_eval: bool = True
    """
    whether to use deterministic action (the dist's mode) instead of stochastic one during evaluation.
    Does not affect training.
    """
    target_mode: Literal["mean", "min"] = "min"
    """
    the method with which the values of the critics in the subset are combined:
    "min" uses their minimum, "mean" their average.
    """

    def _get_param_transformers(self) -> list[ParamTransformer]:
        transformers = super()._get_param_transformers()
        # a factory-given alpha is replaced by the auto-tuned representation it creates
        transformers.append(ParamTransformerAutoAlpha("alpha"))
        return transformers
@dataclass
class TD3Params(
    Params,
    ParamsMixinActorAndDualCritics,
    ParamsMixinExplorationNoise,
    ParamsMixinActionScaling,
):
    """Parameters of TD3 agents."""

    tau: float = 0.005
    """
    controls the soft update of the target network.
    It determines how slowly the target networks track the main networks.
    Smaller tau means slower tracking and more stable learning.
    """
    gamma: float = 0.99
    """discount factor (gamma) for future rewards; must be in [0, 1]"""
    policy_noise: float | FloatEnvValueFactory = 0.2
    """
    the scale of the noise used in updating the policy network; may also be given as a factory
    which derives the value from the environment
    """
    noise_clip: float | FloatEnvValueFactory = 0.5
    """
    determines the clipping range of the noise used in updating the policy network as
    [-noise_clip, noise_clip]; may also be given as a factory which derives the value from the environment
    """
    update_actor_freq: int = 2
    """the update frequency of actor network"""
    estimation_step: int = 1
    """the number of steps to look ahead."""

    def _get_param_transformers(self) -> list[ParamTransformer]:
        transformers = super()._get_param_transformers()
        transformers.extend(ParamsMixinActorAndDualCritics._get_param_transformers(self))
        transformers.extend(ParamsMixinExplorationNoise._get_param_transformers(self))
        transformers.extend(ParamsMixinActionScaling._get_param_transformers(self))
        # env-dependent noise parameters are resolved to concrete floats
        transformers.append(ParamTransformerFloatEnvParamFactory("policy_noise"))
        transformers.append(ParamTransformerFloatEnvParamFactory("noise_clip"))
        return transformers