Policy objects are parametrized by converting the parameter dataclass instances to kwargs, applying injectable transformations along the way.
from abc import ABC, abstractmethod
|
|
from dataclasses import dataclass, asdict
|
|
from typing import Dict, Any, Literal
|
|
|
|
import torch
|
|
|
|
from tianshou.exploration import BaseNoise
|
|
from tianshou.highlevel.params.alpha import AutoAlphaFactory
|
|
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
|
|
|
|
|
|
class ParamTransformer(ABC):
    """Base class for transformations applied to a kwargs dictionary (in place).

    Implementations may add, rewrite or remove entries before the kwargs are
    passed on to a policy constructor.
    """

    @abstractmethod
    def transform(self, kwargs: dict[str, Any]) -> None:
        """Modify the given keyword-argument dictionary in place."""

    @staticmethod
    def get(d: dict[str, Any], key: str, drop: bool = False) -> Any:
        """Return ``d[key]``, optionally removing the entry from ``d``.

        :param d: the dictionary to read from.
        :param key: the key to look up; a missing key raises ``KeyError``.
        :param drop: if True, delete the entry from ``d`` after reading it.
        :return: the value stored under ``key``.
        """
        value = d[key]
        if drop:
            del d[key]
        return value
|
|
|
|
|
|
@dataclass
class Params:
    """Base class for algorithm parameter collections.

    Subclasses declare their parameters as dataclass fields; `create_kwargs`
    converts an instance to a keyword-argument dictionary suitable for a
    policy constructor, applying the given transformers along the way.
    """

    def create_kwargs(self, *transformers: ParamTransformer) -> dict[str, Any]:
        """Convert this instance's fields to a kwargs dictionary.

        :param transformers: transformations applied to the dictionary in
            order; each may mutate it in place (add, rewrite or drop entries).
        :return: the resulting keyword-argument dictionary.
        """
        # NOTE: asdict recurses into nested dataclasses, converting them to
        # plain dicts as well.
        kwargs = asdict(self)
        for transformer in transformers:
            transformer.transform(kwargs)
        return kwargs
|
|
|
|
|
|
@dataclass
class PGParams(Params):
    """Config of general policy-gradient algorithms."""

    # Discount factor (gamma) applied to future rewards.
    discount_factor: float = 0.99
    # Whether to normalize estimated returns — presumably to zero mean/unit
    # variance; confirm against the tianshou policy implementation.
    reward_normalization: bool = False
    # If True, act deterministically (rather than sampling) during evaluation
    # — presumably; confirm against the policy implementation.
    deterministic_eval: bool = False
    # Whether to scale actions to the environment's action-space bounds.
    action_scaling: bool = True
    # Method for mapping unbounded network outputs into the action bounds;
    # None disables bounding.
    action_bound_method: Literal["clip", "tanh"] | None = "clip"
|
|
|
|
|
|
@dataclass
class A2CParams(PGParams):
    """Config of the A2C algorithm; extends the general policy-gradient config."""

    # Weight of the value-function loss term in the total loss.
    vf_coef: float = 0.5
    # Weight of the entropy regularization term in the total loss.
    ent_coef: float = 0.01
    # Gradient clipping threshold; None disables gradient clipping.
    max_grad_norm: float | None = None
    # Lambda for generalized advantage estimation (GAE).
    gae_lambda: float = 0.95
    # Maximal batch size — presumably the chunk size used when computing
    # GAE/values over collected data; confirm against the policy implementation.
    max_batchsize: int = 256
|
|
|
|
|
|
@dataclass
class PPOParams(A2CParams):
    """PPO specific config."""

    # Epsilon of PPO's clipped surrogate objective.
    eps_clip: float = 0.2
    # Dual-clip parameter; None disables dual clipping — presumably must be
    # > 1 when set; confirm against the policy implementation.
    dual_clip: float | None = None
    # Whether to also clip the value-function update.
    value_clip: bool = False
    # Whether to normalize advantages (presumably per mini-batch; confirm
    # against the policy implementation).
    advantage_normalization: bool = True
    # Whether to recompute advantages during each training repetition.
    recompute_advantage: bool = False
    # Learning rate of the optimizer.
    lr: float = 1e-3
    # Optional factory creating a learning-rate scheduler for the optimizer.
    lr_scheduler_factory: LRSchedulerFactory | None = None
|
|
|
|
|
|
@dataclass
class SACParams(Params):
    """Config of the SAC algorithm (actor with two critics, per the fields below)."""

    # Soft-update (Polyak) weight for the target networks — presumably;
    # confirm against the tianshou SAC policy implementation.
    tau: float = 0.005
    # Discount factor applied to future rewards.
    gamma: float = 0.99
    # Entropy regularization coefficient: a fixed float, a tuple — presumably
    # (target_entropy, log_alpha tensor, its optimizer) for auto-tuning;
    # confirm against the policy implementation — or a factory that builds
    # the auto-tuned alpha.
    alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] | AutoAlphaFactory = 0.2
    # Number of steps to look ahead in the n-step return estimation.
    estimation_step: int = 1
    # Noise added to actions for exploration; "default" selects a default
    # noise, None disables exploration noise.
    exploration_noise: BaseNoise | Literal["default"] | None = None
    # If True, act deterministically during evaluation — presumably; confirm
    # against the policy implementation.
    deterministic_eval: bool = True
    # Whether to scale actions to the environment's action-space bounds.
    action_scaling: bool = True
    # Method for mapping unbounded network outputs into the action bounds;
    # None disables bounding (only "clip" is supported here).
    action_bound_method: Literal["clip"] | None = "clip"
    # Learning rate of the actor's optimizer.
    actor_lr: float = 1e-3
    # Learning rate of the first critic's optimizer.
    critic1_lr: float = 1e-3
    # Learning rate of the second critic's optimizer.
    critic2_lr: float = 1e-3
    # Optional factory creating an LR scheduler for the actor's optimizer.
    actor_lr_scheduler_factory: LRSchedulerFactory | None = None
    # Optional factory creating an LR scheduler for the first critic's optimizer.
    critic1_lr_scheduler_factory: LRSchedulerFactory | None = None
    # Optional factory creating an LR scheduler for the second critic's optimizer.
    critic2_lr_scheduler_factory: LRSchedulerFactory | None = None
|