Improve high-level policy parametrisation

Policy objects are now parametrised by converting the parameter dataclass instances to kwargs, using some injectable conversions along the way.
Dominik Jain 2023-09-25 17:56:37 +02:00
parent 37dc07e487
commit 367778d37f
11 changed files with 254 additions and 165 deletions
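
In short: each algorithm's parameters now live in a dataclass (PPOParams, SACParams, ...); the agent factory converts such an instance into a kwargs dict via dataclasses.asdict and then runs a chain of injectable ParamTransformer objects over that dict (dropping keys, building LR schedulers, resolving auto-alpha factories) before passing it to the policy constructor. A minimal, self-contained sketch of the mechanism, using simplified/hypothetical names (DropKeys, MyAlgoParams) rather than the actual classes introduced below:

from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass
from typing import Any, Dict


class ParamTransformer(ABC):
    """Injectable hook that rewrites the kwargs dict derived from a params dataclass."""

    @abstractmethod
    def transform(self, kwargs: Dict[str, Any]) -> None:
        ...


class DropKeys(ParamTransformer):
    """Remove keys that the policy constructor does not accept directly."""

    def __init__(self, *keys: str):
        self.keys = keys

    def transform(self, kwargs: Dict[str, Any]) -> None:
        for k in self.keys:
            del kwargs[k]


@dataclass
class Params:
    def create_kwargs(self, *transformers: ParamTransformer) -> Dict[str, Any]:
        d = asdict(self)
        for t in transformers:
            t.transform(d)
        return d


@dataclass
class MyAlgoParams(Params):  # hypothetical algorithm parameters
    gamma: float = 0.99
    lr: float = 1e-3


params = MyAlgoParams(gamma=0.98)
# "lr" is consumed by the optimizer factory, so it is dropped before policy construction
print(params.create_kwargs(DropKeys("lr")))  # {'gamma': 0.98}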

View File

@@ -9,13 +9,13 @@ from jsonargparse import CLI
 from torch.distributions import Independent, Normal
 from examples.mujoco.mujoco_env import MujocoEnvFactory
-from tianshou.highlevel.agent import PPOConfig
 from tianshou.highlevel.config import RLSamplingConfig
 from tianshou.highlevel.experiment import (
     PPOExperimentBuilder,
     RLExperimentConfig,
 )
-from tianshou.highlevel.optim import LinearLRSchedulerFactory
+from tianshou.highlevel.params.lr_scheduler import LinearLRSchedulerFactory
+from tianshou.highlevel.params.policy_params import PPOParams
 def main(
@@ -65,22 +65,21 @@ def main(
         return Independent(Normal(*logits), 1)
     experiment = (
-        PPOExperimentBuilder(experiment_config, env_factory, sampling_config)
+        PPOExperimentBuilder(experiment_config, env_factory, sampling_config, dist_fn)
         .with_ppo_params(
-            PPOConfig(
-                gamma=gamma,
+            PPOParams(
+                discount_factor=gamma,
                 gae_lambda=gae_lambda,
                 action_bound_method=bound_action_method,
-                rew_norm=rew_norm,
+                reward_normalization=rew_norm,
                 ent_coef=ent_coef,
                 vf_coef=vf_coef,
                 max_grad_norm=max_grad_norm,
                 value_clip=value_clip,
-                norm_adv=norm_adv,
+                advantage_normalization=norm_adv,
                 eps_clip=eps_clip,
                 dual_clip=dual_clip,
-                recompute_adv=recompute_adv,
-                dist_fn=dist_fn,
+                recompute_advantage=recompute_adv,
                 lr=lr,
                 lr_scheduler_factory=LinearLRSchedulerFactory(sampling_config)
                 if lr_decay

View File

@@ -7,7 +7,8 @@ from collections.abc import Sequence
 from jsonargparse import CLI
 from examples.mujoco.mujoco_env import MujocoEnvFactory
-from tianshou.highlevel.agent import DefaultAutoAlphaFactory, SACConfig
+from tianshou.highlevel.params.policy_params import SACParams
+from tianshou.highlevel.params.alpha import DefaultAutoAlphaFactory
 from tianshou.highlevel.config import RLSamplingConfig
 from tianshou.highlevel.experiment import (
     RLExperimentConfig,
@@ -58,7 +59,7 @@ def main(
     experiment = (
         SACExperimentBuilder(experiment_config, env_factory, sampling_config)
         .with_sac_params(
-            SACConfig(
+            SACParams(
                 tau=tau,
                 gamma=gamma,
                 alpha=DefaultAutoAlphaFactory(lr=alpha_lr) if auto_alpha else alpha,

View File

@@ -1,10 +1,8 @@
 import os
 from abc import ABC, abstractmethod
 from collections.abc import Callable
-from dataclasses import dataclass
-from typing import Literal
-import numpy as np
+from typing import Dict, Any, List, Tuple
 import torch
 from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
@@ -13,9 +11,14 @@ from tianshou.highlevel.config import RLSamplingConfig
 from tianshou.highlevel.env import Environments
 from tianshou.highlevel.logger import Logger
 from tianshou.highlevel.module import ActorFactory, CriticFactory, TDevice
-from tianshou.highlevel.optim import LRSchedulerFactory, OptimizerFactory
+from tianshou.highlevel.optim import OptimizerFactory
+from tianshou.highlevel.params.alpha import AutoAlphaFactory
+from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
+from tianshou.highlevel.params.policy_params import PPOParams, ParamTransformer, SACParams
 from tianshou.policy import BasePolicy, PPOPolicy, SACPolicy
+from tianshou.policy.modelfree.pg import TDistParams
 from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
+from tianshou.utils import MultipleLRSchedulers
 from tianshou.utils.net.common import ActorCritic
 CHECKPOINT_DICT_KEY_MODEL = "model"
@@ -128,178 +131,132 @@ class OffpolicyAgentFactory(AgentFactory, ABC):
         )
-@dataclass
-class RLAgentConfig:
-    """Config common to most RL algorithms."""
-    gamma: float = 0.99
-    """Discount factor"""
-    gae_lambda: float = 0.95
-    """For Generalized Advantage Estimate (equivalent to TD(lambda))"""
-    action_bound_method: Literal["clip", "tanh"] | None = "clip"
-    """How to map original actions in range (-inf, inf) to [-1, 1]"""
-    rew_norm: bool = True
-    """Whether to normalize rewards"""
-@dataclass
-class PGConfig(RLAgentConfig):
-    """Config of general policy-gradient algorithms."""
-    ent_coef: float = 0.0
-    vf_coef: float = 0.25
-    max_grad_norm: float = 0.5
-@dataclass
-class PPOConfig(PGConfig):
-    """PPO specific config."""
-    value_clip: bool = False
-    norm_adv: bool = False
-    """Whether to normalize advantages"""
-    eps_clip: float = 0.2
-    dual_clip: float | None = None
-    recompute_adv: bool = True
-    dist_fn: Callable = None
-    lr: float = 1e-3
-    lr_scheduler_factory: LRSchedulerFactory | None = None
+class ParamTransformerDrop(ParamTransformer):
+    def __init__(self, *keys: str):
+        self.keys = keys
+    def transform(self, kwargs: Dict[str, Any]) -> None:
+        for k in self.keys:
+            del kwargs[k]
+class ParamTransformerLRScheduler(ParamTransformer):
+    def __init__(self, optim: torch.optim.Optimizer):
+        self.optim = optim
+    def transform(self, kwargs: Dict[str, Any]) -> None:
+        factory: LRSchedulerFactory | None = self.get(kwargs, "lr_scheduler_factory", drop=True)
+        kwargs["lr_scheduler"] = factory.create_scheduler(self.optim) if factory is not None else None
 class PPOAgentFactory(OnpolicyAgentFactory):
     def __init__(
         self,
-        config: PPOConfig,
+        params: PPOParams,
         sampling_config: RLSamplingConfig,
         actor_factory: ActorFactory,
         critic_factory: CriticFactory,
         optimizer_factory: OptimizerFactory,
+        dist_fn: Callable[[TDistParams], torch.distributions.Distribution],
     ):
         super().__init__(sampling_config)
         self.optimizer_factory = optimizer_factory
         self.critic_factory = critic_factory
         self.actor_factory = actor_factory
-        self.config = config
+        self.config = params
+        self.dist_fn = dist_fn
     def create_policy(self, envs: Environments, device: TDevice) -> PPOPolicy:
         actor = self.actor_factory.create_module(envs, device)
         critic = self.critic_factory.create_module(envs, device, use_action=False)
         actor_critic = ActorCritic(actor, critic)
         optim = self.optimizer_factory.create_optimizer(actor_critic, self.config.lr)
-        if self.config.lr_scheduler_factory is not None:
-            lr_scheduler = self.config.lr_scheduler_factory.create_scheduler(optim)
-        else:
-            lr_scheduler = None
+        kwargs = self.config.create_kwargs(
+            ParamTransformerDrop("lr"),
+            ParamTransformerLRScheduler(optim))
         return PPOPolicy(
-            # nn-stuff
-            actor,
-            critic,
-            optim,
-            dist_fn=self.config.dist_fn,
-            lr_scheduler=lr_scheduler,
-            # env-stuff
+            actor=actor,
+            critic=critic,
+            optim=optim,
+            dist_fn=self.dist_fn,
             action_space=envs.get_action_space(),
-            action_scaling=True,
-            # general_config
-            discount_factor=self.config.gamma,
-            gae_lambda=self.config.gae_lambda,
-            reward_normalization=self.config.rew_norm,
-            action_bound_method=self.config.action_bound_method,
-            # pg_config
-            max_grad_norm=self.config.max_grad_norm,
-            vf_coef=self.config.vf_coef,
-            ent_coef=self.config.ent_coef,
-            # ppo_config
-            eps_clip=self.config.eps_clip,
-            value_clip=self.config.value_clip,
-            dual_clip=self.config.dual_clip,
-            advantage_normalization=self.config.norm_adv,
-            recompute_advantage=self.config.recompute_adv,
+            **kwargs
         )
-class AutoAlphaFactory(ABC):
-    @abstractmethod
-    def create_auto_alpha(
-        self,
-        envs: Environments,
-        optim_factory: OptimizerFactory,
-        device: TDevice,
-    ):
-        pass
-class DefaultAutoAlphaFactory(AutoAlphaFactory):  # TODO better name?
-    def __init__(self, lr: float = 3e-4):
-        self.lr = lr
-    def create_auto_alpha(
-        self,
-        envs: Environments,
-        optim_factory: OptimizerFactory,
-        device: TDevice,
-    ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
-        target_entropy = -np.prod(envs.get_action_shape())
-        log_alpha = torch.zeros(1, requires_grad=True, device=device)
-        alpha_optim = torch.optim.Adam([log_alpha], lr=self.lr)
-        return target_entropy, log_alpha, alpha_optim
-@dataclass
-class SACConfig:
-    tau: float = 0.005
-    gamma: float = 0.99
-    alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] | AutoAlphaFactory = 0.2
-    reward_normalization: bool = False
-    estimation_step: int = 1
-    deterministic_eval: bool = True
-    actor_lr: float = 1e-3
-    critic1_lr: float = 1e-3
-    critic2_lr: float = 1e-3
+class ParamTransformerAlpha(ParamTransformer):
+    def __init__(self, envs: Environments, optim_factory: OptimizerFactory, device: TDevice):
+        self.envs = envs
+        self.optim_factory = optim_factory
+        self.device = device
+    def transform(self, kwargs: Dict[str, Any]) -> None:
+        key = "alpha"
+        alpha = self.get(kwargs, key)
+        if isinstance(alpha, AutoAlphaFactory):
+            kwargs[key] = alpha.create_auto_alpha(self.envs, self.optim_factory, self.device)
+class ParamTransformerMultiLRScheduler(ParamTransformer):
+    def __init__(self, optim_key_list: List[Tuple[torch.optim.Optimizer, str]]):
+        self.optim_key_list = optim_key_list
+    def transform(self, kwargs: Dict[str, Any]) -> None:
+        lr_schedulers = []
+        for optim, lr_scheduler_factory_key in self.optim_key_list:
+            lr_scheduler_factory: LRSchedulerFactory | None = self.get(kwargs, lr_scheduler_factory_key, drop=True)
+            if lr_scheduler_factory is not None:
+                lr_schedulers.append(lr_scheduler_factory.create_scheduler(optim))
+        match len(lr_schedulers):
+            case 0:
+                lr_scheduler = None
+            case 1:
+                lr_scheduler = lr_schedulers[0]
+            case _:
+                lr_scheduler = MultipleLRSchedulers(*lr_schedulers)
+        kwargs["lr_scheduler"] = lr_scheduler
 class SACAgentFactory(OffpolicyAgentFactory):
     def __init__(
         self,
-        config: SACConfig,
+        params: SACParams,
         sampling_config: RLSamplingConfig,
         actor_factory: ActorFactory,
         critic1_factory: CriticFactory,
         critic2_factory: CriticFactory,
         optim_factory: OptimizerFactory,
-        exploration_noise: BaseNoise | None = None,
     ):
         super().__init__(sampling_config)
         self.critic2_factory = critic2_factory
         self.critic1_factory = critic1_factory
         self.actor_factory = actor_factory
-        self.exploration_noise = exploration_noise
         self.optim_factory = optim_factory
-        self.config = config
+        self.params = params
     def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
         actor = self.actor_factory.create_module(envs, device)
         critic1 = self.critic1_factory.create_module(envs, device, use_action=True)
         critic2 = self.critic2_factory.create_module(envs, device, use_action=True)
-        actor_optim = self.optim_factory.create_optimizer(actor, lr=self.config.actor_lr)
-        critic1_optim = self.optim_factory.create_optimizer(critic1, lr=self.config.critic1_lr)
-        critic2_optim = self.optim_factory.create_optimizer(critic2, lr=self.config.critic2_lr)
-        if isinstance(self.config.alpha, AutoAlphaFactory):
-            alpha = self.config.alpha.create_auto_alpha(envs, self.optim_factory, device)
-        else:
-            alpha = self.config.alpha
+        actor_optim = self.optim_factory.create_optimizer(actor, lr=self.params.actor_lr)
+        critic1_optim = self.optim_factory.create_optimizer(critic1, lr=self.params.critic1_lr)
+        critic2_optim = self.optim_factory.create_optimizer(critic2, lr=self.params.critic2_lr)
+        kwargs = self.params.create_kwargs(
+            ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"),
+            ParamTransformerMultiLRScheduler([
+                (actor_optim, "actor_lr_scheduler_factory"),
+                (critic1_optim, "critic1_lr_scheduler_factory"),
+                (critic2_optim, "critic2_lr_scheduler_factory")]
+            ),
+            ParamTransformerAlpha(envs, optim_factory=self.optim_factory, device=device))
         return SACPolicy(
-            actor,
-            actor_optim,
-            critic1,
-            critic1_optim,
-            critic2,
-            critic2_optim,
-            tau=self.config.tau,
-            gamma=self.config.gamma,
-            alpha=alpha,
-            estimation_step=self.config.estimation_step,
+            actor=actor,
+            actor_optim=actor_optim,
+            critic=critic1,
+            critic_optim=critic1_optim,
+            critic2=critic2,
+            critic2_optim=critic2_optim,
             action_space=envs.get_action_space(),
-            deterministic_eval=self.config.deterministic_eval,
-            exploration_noise=self.exploration_noise,
+            observation_space=envs.get_observation_space(),
+            **kwargs
         )

View File

@@ -28,19 +28,22 @@ class Environments(ABC):
         self.test_envs = test_envs
     def info(self) -> dict[str, Any]:
-        return {"action_shape": self.get_action_shape(), "state_shape": self.get_state_shape()}
+        return {"action_shape": self.get_action_shape(), "state_shape": self.get_observation_shape()}
     @abstractmethod
     def get_action_shape(self) -> TShape:
         pass
     @abstractmethod
-    def get_state_shape(self) -> TShape:
+    def get_observation_shape(self) -> TShape:
         pass
     def get_action_space(self) -> gym.Space:
         return self.env.action_space
+    def get_observation_space(self) -> gym.Space:
+        return self.env.observation_space
     @abstractmethod
     def get_type(self) -> EnvType:
         pass
@@ -75,7 +78,7 @@ class ContinuousEnvironments(Environments):
     def get_action_shape(self) -> TShape:
         return self.action_shape
-    def get_state_shape(self) -> TShape:
+    def get_observation_shape(self) -> TShape:
         return self.state_shape
     def get_type(self):

View File

@@ -2,13 +2,13 @@ from abc import abstractmethod
 from collections.abc import Sequence
 from dataclasses import dataclass
 from pprint import pprint
-from typing import Generic, TypeVar
+from typing import Generic, TypeVar, Callable
 import numpy as np
 import torch
 from tianshou.data import Collector
-from tianshou.highlevel.agent import AgentFactory, PPOAgentFactory, PPOConfig, SACConfig
+from tianshou.highlevel.agent import AgentFactory, PPOAgentFactory, SACAgentFactory
 from tianshou.highlevel.config import RLSamplingConfig
 from tianshou.highlevel.env import EnvFactory
 from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory
@@ -19,7 +19,9 @@ from tianshou.highlevel.module import (
     DefaultCriticFactory,
 )
 from tianshou.highlevel.optim import AdamOptimizerFactory, OptimizerFactory
+from tianshou.highlevel.params.policy_params import PPOParams, SACParams
 from tianshou.policy import BasePolicy
+from tianshou.policy.modelfree.pg import TDistParams
 from tianshou.trainer import BaseTrainer
 TPolicy = TypeVar("TPolicy", bound=BasePolicy)
@@ -294,13 +296,15 @@ class PPOExperimentBuilder(
         experiment_config: RLExperimentConfig,
         env_factory: EnvFactory,
         sampling_config: RLSamplingConfig,
+        dist_fn: Callable[[TDistParams], torch.distributions.Distribution],
     ):
         super().__init__(experiment_config, env_factory, sampling_config)
         _BuilderMixinActorFactory.__init__(self)
         _BuilderMixinSingleCriticFactory.__init__(self)
-        self._params: PPOConfig = PPOConfig()
-    def with_ppo_params(self, params: PPOConfig) -> "PPOExperimentBuilder":
+        self._params: PPOParams = PPOParams()
+        self._dist_fn = dist_fn
+    def with_ppo_params(self, params: PPOParams) -> "PPOExperimentBuilder":
         self._params = params
         return self
@@ -312,6 +316,7 @@ class PPOExperimentBuilder(
             self._get_actor_factory(),
             self._get_critic_factory(0),
             self._get_optim_factory(),
+            self._dist_fn
         )
@@ -327,8 +332,12 @@ class SACExperimentBuilder(
         super().__init__(experiment_config, env_factory, sampling_config)
         _BuilderMixinActorFactory.__init__(self)
         _BuilderMixinDualCriticFactory.__init__(self)
-        self._params: SACConfig = SACConfig()
-    def with_sac_params(self, params: SACConfig) -> "SACExperimentBuilder":
+        self._params: SACParams = SACParams()
+    def with_sac_params(self, params: SACParams) -> "SACExperimentBuilder":
         self._params = params
         return self
+    def _create_agent_factory(self) -> AgentFactory:
+        return SACAgentFactory(self._params, self._sampling_config, self._get_actor_factory(),
+                               self._get_critic_factory(0), self._get_critic_factory(1), self._get_optim_factory())

View File

@@ -91,7 +91,7 @@ class ContinuousActorProbFactory(ContinuousActorFactory):
     def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
         net_a = Net(
-            envs.get_state_shape(),
+            envs.get_observation_shape(),
             hidden_sizes=self.hidden_sizes,
             activation=nn.Tanh,
             device=device,
@@ -148,7 +148,7 @@ class ContinuousNetCriticFactory(ContinuousCriticFactory):
     def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
         action_shape = envs.get_action_shape() if use_action else 0
         net_c = Net(
-            envs.get_state_shape(),
+            envs.get_observation_shape(),
             action_shape=action_shape,
             hidden_sizes=self.hidden_sizes,
             concat=use_action,

View File

@@ -2,13 +2,9 @@ from abc import ABC, abstractmethod
 from collections.abc import Iterable
 from typing import Any
-import numpy as np
 import torch
 from torch import Tensor
 from torch.optim import Adam
-from torch.optim.lr_scheduler import LambdaLR, LRScheduler
-from tianshou.highlevel.config import RLSamplingConfig
 TParams = Iterable[Tensor] | Iterable[dict[str, Any]]
@@ -44,19 +40,3 @@
         )
-class LRSchedulerFactory(ABC):
-    @abstractmethod
-    def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
-        pass
-class LinearLRSchedulerFactory(LRSchedulerFactory):
-    def __init__(self, sampling_config: RLSamplingConfig):
-        self.sampling_config = sampling_config
-    def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
-        max_update_num = (
-            np.ceil(self.sampling_config.step_per_epoch / self.sampling_config.step_per_collect)
-            * self.sampling_config.num_epochs
-        )
-        return LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)

View File

View File

@@ -0,0 +1,35 @@
+from abc import ABC, abstractmethod
+import numpy as np
+import torch
+from tianshou.highlevel.env import Environments
+from tianshou.highlevel.module import TDevice
+from tianshou.highlevel.optim import OptimizerFactory
+class AutoAlphaFactory(ABC):
+    @abstractmethod
+    def create_auto_alpha(
+        self,
+        envs: Environments,
+        optim_factory: OptimizerFactory,
+        device: TDevice,
+    ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
+        pass
+class DefaultAutoAlphaFactory(AutoAlphaFactory):  # TODO better name?
+    def __init__(self, lr: float = 3e-4):
+        self.lr = lr
+    def create_auto_alpha(
+        self,
+        envs: Environments,
+        optim_factory: OptimizerFactory,
+        device: TDevice,
+    ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
+        target_entropy = -np.prod(envs.get_action_shape())
+        log_alpha = torch.zeros(1, requires_grad=True, device=device)
+        alpha_optim = torch.optim.Adam([log_alpha], lr=self.lr)
+        return target_entropy, log_alpha, alpha_optim

View File

@@ -0,0 +1,25 @@
+from abc import ABC, abstractmethod
+import numpy as np
+import torch
+from torch.optim.lr_scheduler import LRScheduler, LambdaLR
+from tianshou.highlevel.config import RLSamplingConfig
+class LRSchedulerFactory(ABC):
+    @abstractmethod
+    def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
+        pass
+class LinearLRSchedulerFactory(LRSchedulerFactory):
+    def __init__(self, sampling_config: RLSamplingConfig):
+        self.sampling_config = sampling_config
+    def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
+        max_update_num = (
+            np.ceil(self.sampling_config.step_per_epoch / self.sampling_config.step_per_collect)
+            * self.sampling_config.num_epochs
+        )
+        return LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
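
For illustration only, with made-up sampling-config values, the linear schedule above decays the learning rate to zero over the total number of gradient updates:

import numpy as np

# hypothetical values, chosen only to illustrate the formula in create_scheduler above
step_per_epoch, step_per_collect, num_epochs = 30000, 2048, 100
max_update_num = np.ceil(step_per_epoch / step_per_collect) * num_epochs  # 15 * 100 = 1500

def factor(update_step: int) -> float:
    # LambdaLR multiplies the base learning rate by this value at each scheduler step
    return 1 - update_step / max_update_num

print(factor(0), factor(750), factor(1500))  # 1.0 0.5 0.0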

View File

@@ -0,0 +1,80 @@
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, asdict
+from typing import Dict, Any, Literal
+import torch
+from tianshou.exploration import BaseNoise
+from tianshou.highlevel.params.alpha import AutoAlphaFactory
+from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
+class ParamTransformer(ABC):
+    @abstractmethod
+    def transform(self, kwargs: Dict[str, Any]) -> None:
+        pass
+    @staticmethod
+    def get(d: Dict[str, Any], key: str, drop: bool = False) -> Any:
+        value = d[key]
+        if drop:
+            del d[key]
+        return value
+@dataclass
+class Params:
+    def create_kwargs(self, *transformers: ParamTransformer) -> Dict[str, Any]:
+        d = asdict(self)
+        for transformer in transformers:
+            transformer.transform(d)
+        return d
+@dataclass
+class PGParams(Params):
+    """Config of general policy-gradient algorithms."""
+    discount_factor: float = 0.99
+    reward_normalization: bool = False
+    deterministic_eval: bool = False
+    action_scaling: bool = True
+    action_bound_method: Literal["clip", "tanh"] | None = "clip"
+@dataclass
+class A2CParams(PGParams):
+    vf_coef: float = 0.5
+    ent_coef: float = 0.01
+    max_grad_norm: float | None = None
+    gae_lambda: float = 0.95
+    max_batchsize: int = 256
+@dataclass
+class PPOParams(A2CParams):
+    """PPO specific config."""
+    eps_clip: float = 0.2
+    dual_clip: float | None = None
+    value_clip: bool = False
+    advantage_normalization: bool = True
+    recompute_advantage: bool = False
+    lr: float = 1e-3
+    lr_scheduler_factory: LRSchedulerFactory | None = None
+@dataclass
+class SACParams(Params):
+    tau: float = 0.005
+    gamma: float = 0.99
+    alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] | AutoAlphaFactory = 0.2
+    estimation_step: int = 1
+    exploration_noise: BaseNoise | Literal["default"] | None = None
+    deterministic_eval: bool = True
+    action_scaling: bool = True
+    action_bound_method: Literal["clip"] | None = "clip"
+    actor_lr: float = 1e-3
+    critic1_lr: float = 1e-3
+    critic2_lr: float = 1e-3
+    actor_lr_scheduler_factory: LRSchedulerFactory | None = None
+    critic1_lr_scheduler_factory: LRSchedulerFactory | None = None
+    critic2_lr_scheduler_factory: LRSchedulerFactory | None = None
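
A brief usage sketch against the state of the code in this commit; it mirrors what PPOAgentFactory.create_policy does above (ParamTransformerDrop is defined in tianshou.highlevel.agent):

from tianshou.highlevel.agent import ParamTransformerDrop
from tianshou.highlevel.params.policy_params import PPOParams

params = PPOParams(discount_factor=0.98, lr=3e-4)
# "lr" is consumed when the optimizer is built, so it is dropped from the policy kwargs;
# in PPOAgentFactory a ParamTransformerLRScheduler is additionally applied before the
# remaining entries are passed to PPOPolicy as **kwargs
kwargs = params.create_kwargs(ParamTransformerDrop("lr"))
assert "lr" not in kwargs and kwargs["discount_factor"] == 0.98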