# tianshou/highlevel/params/policy_params.py
from abc import ABC, abstractmethod
from dataclasses import dataclass, asdict
from typing import Dict, Any, Literal
import torch
from tianshou.exploration import BaseNoise
from tianshou.highlevel.params.alpha import AutoAlphaFactory
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
class ParamTransformer(ABC):
    """Base class for transformations applied in place to a parameter dictionary."""

    @abstractmethod
    def transform(self, kwargs: Dict[str, Any]) -> None:
        """Mutate *kwargs* in place; subclasses define the concrete change."""

    @staticmethod
    def get(d: Dict[str, Any], key: str, drop: bool = False) -> Any:
        """Fetch the value stored under *key* in *d*.

        :param d: the dictionary to read from
        :param key: the key to look up (a missing key raises ``KeyError``)
        :param drop: if True, also remove the entry from *d*
        :return: the stored value
        """
        if drop:
            return d.pop(key)
        return d[key]
@dataclass
class Params:
    """Base class for algorithm parameter dataclasses."""

    def create_kwargs(self, *transformers: ParamTransformer) -> Dict[str, Any]:
        """Turn the dataclass fields into a kwargs dictionary.

        Each transformer is applied to the dictionary in the given order and
        may mutate it in place (rename, replace or drop entries).
        """
        kwargs = asdict(self)
        for param_transformer in transformers:
            param_transformer.transform(kwargs)
        return kwargs
@dataclass
class PGParams(Params):
    """Config of general policy-gradient algorithms."""

    # discount factor (gamma) applied to future rewards; expected in [0, 1]
    discount_factor: float = 0.99
    # whether to normalize rewards/returns — exact normalization scheme is
    # defined by the consuming policy, not visible here
    reward_normalization: bool = False
    # if True, select actions deterministically during evaluation
    # (presumably mode/mean of the action distribution — defined downstream)
    deterministic_eval: bool = False
    # whether to scale actions to the environment's action space bounds
    action_scaling: bool = True
    # how raw actions are bounded: "clip", "tanh", or None for no bounding
    action_bound_method: Literal["clip", "tanh"] | None = "clip"
@dataclass
class A2CParams(PGParams):
    """Config of advantage actor-critic (A2C)-style algorithms."""

    # weight of the value-function loss in the total loss
    vf_coef: float = 0.5
    # weight of the entropy bonus in the total loss
    ent_coef: float = 0.01
    # gradient clipping threshold; None disables gradient clipping
    max_grad_norm: float | None = None
    # lambda for generalized advantage estimation (GAE), in [0, 1]
    gae_lambda: float = 0.95
    # maximum batch size — presumably for chunked GAE/value computation;
    # confirm against the consuming policy
    max_batchsize: int = 256
@dataclass
class PPOParams(A2CParams):
    """PPO specific config."""

    # clipping ratio epsilon for the PPO surrogate objective
    eps_clip: float = 0.2
    # dual-clip lower bound on the objective; None disables dual clipping
    dual_clip: float | None = None
    # whether to also clip the value-function update
    value_clip: bool = False
    # whether to normalize advantages per batch
    advantage_normalization: bool = True
    # whether to recompute advantages during each update epoch
    recompute_advantage: bool = False
    # optimizer learning rate
    lr: float = 1e-3
    # optional factory for a learning-rate scheduler; None keeps lr constant
    lr_scheduler_factory: LRSchedulerFactory | None = None
@dataclass
class SACParams(Params):
    """Config of soft actor-critic (SAC)."""

    # coefficient for the soft (Polyak) update of target networks
    tau: float = 0.005
    # discount factor applied to future rewards
    gamma: float = 0.99
    # entropy temperature: a fixed float, a (value, tensor, optimizer) tuple
    # (presumably for manual auto-tuning setups — confirm with the policy),
    # or a factory that creates an auto-tuned alpha
    alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] | AutoAlphaFactory = 0.2
    # number of steps for n-step return estimation
    estimation_step: int = 1
    # exploration noise added to actions; "default" selects a built-in
    # noise defined downstream, None adds no noise
    exploration_noise: BaseNoise | Literal["default"] | None = None
    # if True, select actions deterministically during evaluation
    deterministic_eval: bool = True
    # whether to scale actions to the environment's action space bounds
    action_scaling: bool = True
    # how raw actions are bounded: "clip" or None for no bounding
    action_bound_method: Literal["clip"] | None = "clip"
    # learning rates for the actor and the two critics
    actor_lr: float = 1e-3
    critic1_lr: float = 1e-3
    critic2_lr: float = 1e-3
    # optional per-optimizer learning-rate scheduler factories; None keeps
    # the corresponding learning rate constant
    actor_lr_scheduler_factory: LRSchedulerFactory | None = None
    critic1_lr_scheduler_factory: LRSchedulerFactory | None = None
    critic2_lr_scheduler_factory: LRSchedulerFactory | None = None