Improve high-level policy parametrisation
Policy objects are now parametrised by converting the parameter dataclass instances to kwargs, applying injectable conversions (ParamTransformer instances) along the way.
parent 37dc07e487
commit 367778d37f
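As context for the diff below, here is a runnable sketch of the pattern this commit introduces. It assumes this commit's tianshou is importable; DropKeys is a hypothetical stand-in for the ParamTransformerDrop defined in agent.py, and simply dropping lr_scheduler_factory is a simplification of what ParamTransformerLRScheduler actually does (it converts the factory into a scheduler).

# Sketch only: a Params dataclass is flattened to a kwargs dict, and injected
# transformers rewrite that dict before it is splatted into the policy constructor.
from typing import Any

from tianshou.highlevel.params.policy_params import ParamTransformer, PPOParams


class DropKeys(ParamTransformer):
    """Hypothetical stand-in for ParamTransformerDrop from agent.py."""

    def __init__(self, *keys: str) -> None:
        self.keys = keys

    def transform(self, kwargs: dict[str, Any]) -> None:
        for k in self.keys:  # remove fields the policy constructor does not accept
            del kwargs[k]


kwargs = PPOParams(eps_clip=0.1).create_kwargs(DropKeys("lr", "lr_scheduler_factory"))
# kwargs now holds discount_factor, gae_lambda, eps_clip, ... ready for PPOPolicy(**kwargs, ...)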
@@ -9,13 +9,13 @@ from jsonargparse import CLI
from torch.distributions import Independent, Normal

from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.agent import PPOConfig
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.experiment import (
    PPOExperimentBuilder,
    RLExperimentConfig,
)
from tianshou.highlevel.optim import LinearLRSchedulerFactory
from tianshou.highlevel.params.lr_scheduler import LinearLRSchedulerFactory
from tianshou.highlevel.params.policy_params import PPOParams


def main(
@@ -65,22 +65,21 @@ def main(
        return Independent(Normal(*logits), 1)

    experiment = (
        PPOExperimentBuilder(experiment_config, env_factory, sampling_config)
        PPOExperimentBuilder(experiment_config, env_factory, sampling_config, dist_fn)
        .with_ppo_params(
            PPOConfig(
                gamma=gamma,
            PPOParams(
                discount_factor=gamma,
                gae_lambda=gae_lambda,
                action_bound_method=bound_action_method,
                rew_norm=rew_norm,
                reward_normalization=rew_norm,
                ent_coef=ent_coef,
                vf_coef=vf_coef,
                max_grad_norm=max_grad_norm,
                value_clip=value_clip,
                norm_adv=norm_adv,
                advantage_normalization=norm_adv,
                eps_clip=eps_clip,
                dual_clip=dual_clip,
                recompute_adv=recompute_adv,
                dist_fn=dist_fn,
                recompute_advantage=recompute_adv,
                lr=lr,
                lr_scheduler_factory=LinearLRSchedulerFactory(sampling_config)
                if lr_decay
@@ -7,7 +7,8 @@ from collections.abc import Sequence
from jsonargparse import CLI

from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.agent import DefaultAutoAlphaFactory, SACConfig
from tianshou.highlevel.params.policy_params import SACParams
from tianshou.highlevel.params.alpha import DefaultAutoAlphaFactory
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.experiment import (
    RLExperimentConfig,
@@ -58,7 +59,7 @@ def main(
    experiment = (
        SACExperimentBuilder(experiment_config, env_factory, sampling_config)
        .with_sac_params(
            SACConfig(
            SACParams(
                tau=tau,
                gamma=gamma,
                alpha=DefaultAutoAlphaFactory(lr=alpha_lr) if auto_alpha else alpha,
@@ -1,10 +1,8 @@
import os
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from typing import Literal
from typing import Dict, Any, List, Tuple

import numpy as np
import torch

from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
@@ -13,9 +11,14 @@ from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.env import Environments
from tianshou.highlevel.logger import Logger
from tianshou.highlevel.module import ActorFactory, CriticFactory, TDevice
from tianshou.highlevel.optim import LRSchedulerFactory, OptimizerFactory
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.highlevel.params.alpha import AutoAlphaFactory
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
from tianshou.highlevel.params.policy_params import PPOParams, ParamTransformer, SACParams
from tianshou.policy import BasePolicy, PPOPolicy, SACPolicy
from tianshou.policy.modelfree.pg import TDistParams
from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
from tianshou.utils import MultipleLRSchedulers
from tianshou.utils.net.common import ActorCritic

CHECKPOINT_DICT_KEY_MODEL = "model"
@@ -128,178 +131,132 @@ class OffpolicyAgentFactory(AgentFactory, ABC):
        )


@dataclass
class RLAgentConfig:
    """Config common to most RL algorithms."""
class ParamTransformerDrop(ParamTransformer):
    def __init__(self, *keys: str):
        self.keys = keys

    gamma: float = 0.99
    """Discount factor"""
    gae_lambda: float = 0.95
    """For Generalized Advantage Estimate (equivalent to TD(lambda))"""
    action_bound_method: Literal["clip", "tanh"] | None = "clip"
    """How to map original actions in range (-inf, inf) to [-1, 1]"""
    rew_norm: bool = True
    """Whether to normalize rewards"""
    def transform(self, kwargs: Dict[str, Any]) -> None:
        for k in self.keys:
            del kwargs[k]


@dataclass
class PGConfig(RLAgentConfig):
    """Config of general policy-gradient algorithms."""
class ParamTransformerLRScheduler(ParamTransformer):
    def __init__(self, optim: torch.optim.Optimizer):
        self.optim = optim

    ent_coef: float = 0.0
    vf_coef: float = 0.25
    max_grad_norm: float = 0.5


@dataclass
class PPOConfig(PGConfig):
    """PPO specific config."""

    value_clip: bool = False
    norm_adv: bool = False
    """Whether to normalize advantages"""
    eps_clip: float = 0.2
    dual_clip: float | None = None
    recompute_adv: bool = True
    dist_fn: Callable = None
    lr: float = 1e-3
    lr_scheduler_factory: LRSchedulerFactory | None = None
    def transform(self, kwargs: Dict[str, Any]) -> None:
        factory: LRSchedulerFactory | None = self.get(kwargs, "lr_scheduler_factory", drop=True)
        kwargs["lr_scheduler"] = factory.create_scheduler(self.optim) if factory is not None else None


class PPOAgentFactory(OnpolicyAgentFactory):
    def __init__(
        self,
        config: PPOConfig,
        params: PPOParams,
        sampling_config: RLSamplingConfig,
        actor_factory: ActorFactory,
        critic_factory: CriticFactory,
        optimizer_factory: OptimizerFactory,
        dist_fn: Callable[[TDistParams], torch.distributions.Distribution],
    ):
        super().__init__(sampling_config)
        self.optimizer_factory = optimizer_factory
        self.critic_factory = critic_factory
        self.actor_factory = actor_factory
        self.config = config
        self.config = params
        self.dist_fn = dist_fn

    def create_policy(self, envs: Environments, device: TDevice) -> PPOPolicy:
        actor = self.actor_factory.create_module(envs, device)
        critic = self.critic_factory.create_module(envs, device, use_action=False)
        actor_critic = ActorCritic(actor, critic)
        optim = self.optimizer_factory.create_optimizer(actor_critic, self.config.lr)
        if self.config.lr_scheduler_factory is not None:
            lr_scheduler = self.config.lr_scheduler_factory.create_scheduler(optim)
        else:
            lr_scheduler = None
        kwargs = self.config.create_kwargs(
            ParamTransformerDrop("lr"),
            ParamTransformerLRScheduler(optim))
        return PPOPolicy(
            # nn-stuff
            actor,
            critic,
            optim,
            dist_fn=self.config.dist_fn,
            lr_scheduler=lr_scheduler,
            # env-stuff
            actor=actor,
            critic=critic,
            optim=optim,
            dist_fn=self.dist_fn,
            action_space=envs.get_action_space(),
            action_scaling=True,
            # general_config
            discount_factor=self.config.gamma,
            gae_lambda=self.config.gae_lambda,
            reward_normalization=self.config.rew_norm,
            action_bound_method=self.config.action_bound_method,
            # pg_config
            max_grad_norm=self.config.max_grad_norm,
            vf_coef=self.config.vf_coef,
            ent_coef=self.config.ent_coef,
            # ppo_config
            eps_clip=self.config.eps_clip,
            value_clip=self.config.value_clip,
            dual_clip=self.config.dual_clip,
            advantage_normalization=self.config.norm_adv,
            recompute_advantage=self.config.recompute_adv,
            **kwargs
        )


class AutoAlphaFactory(ABC):
    @abstractmethod
    def create_auto_alpha(
        self,
        envs: Environments,
        optim_factory: OptimizerFactory,
        device: TDevice,
    ):
        pass
class ParamTransformerAlpha(ParamTransformer):
    def __init__(self, envs: Environments, optim_factory: OptimizerFactory, device: TDevice):
        self.envs = envs
        self.optim_factory = optim_factory
        self.device = device

    def transform(self, kwargs: Dict[str, Any]) -> None:
        key = "alpha"
        alpha = self.get(kwargs, key)
        if isinstance(alpha, AutoAlphaFactory):
            kwargs[key] = alpha.create_auto_alpha(self.envs, self.optim_factory, self.device)


class DefaultAutoAlphaFactory(AutoAlphaFactory):  # TODO better name?
    def __init__(self, lr: float = 3e-4):
        self.lr = lr
class ParamTransformerMultiLRScheduler(ParamTransformer):
    def __init__(self, optim_key_list: List[Tuple[torch.optim.Optimizer, str]]):
        self.optim_key_list = optim_key_list

    def create_auto_alpha(
        self,
        envs: Environments,
        optim_factory: OptimizerFactory,
        device: TDevice,
    ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
        target_entropy = -np.prod(envs.get_action_shape())
        log_alpha = torch.zeros(1, requires_grad=True, device=device)
        alpha_optim = torch.optim.Adam([log_alpha], lr=self.lr)
        return target_entropy, log_alpha, alpha_optim


@dataclass
class SACConfig:
    tau: float = 0.005
    gamma: float = 0.99
    alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] | AutoAlphaFactory = 0.2
    reward_normalization: bool = False
    estimation_step: int = 1
    deterministic_eval: bool = True
    actor_lr: float = 1e-3
    critic1_lr: float = 1e-3
    critic2_lr: float = 1e-3
    def transform(self, kwargs: Dict[str, Any]) -> None:
        lr_schedulers = []
        for optim, lr_scheduler_factory_key in self.optim_key_list:
            lr_scheduler_factory: LRSchedulerFactory | None = self.get(kwargs, lr_scheduler_factory_key, drop=True)
            if lr_scheduler_factory is not None:
                lr_schedulers.append(lr_scheduler_factory.create_scheduler(optim))
        match len(lr_schedulers):
            case 0:
                lr_scheduler = None
            case 1:
                lr_scheduler = lr_schedulers[0]
            case _:
                lr_scheduler = MultipleLRSchedulers(*lr_schedulers)
        kwargs["lr_scheduler"] = lr_scheduler


class SACAgentFactory(OffpolicyAgentFactory):
    def __init__(
        self,
        config: SACConfig,
        params: SACParams,
        sampling_config: RLSamplingConfig,
        actor_factory: ActorFactory,
        critic1_factory: CriticFactory,
        critic2_factory: CriticFactory,
        optim_factory: OptimizerFactory,
        exploration_noise: BaseNoise | None = None,
    ):
        super().__init__(sampling_config)
        self.critic2_factory = critic2_factory
        self.critic1_factory = critic1_factory
        self.actor_factory = actor_factory
        self.exploration_noise = exploration_noise
        self.optim_factory = optim_factory
        self.config = config
        self.params = params

    def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
        actor = self.actor_factory.create_module(envs, device)
        critic1 = self.critic1_factory.create_module(envs, device, use_action=True)
        critic2 = self.critic2_factory.create_module(envs, device, use_action=True)
        actor_optim = self.optim_factory.create_optimizer(actor, lr=self.config.actor_lr)
        critic1_optim = self.optim_factory.create_optimizer(critic1, lr=self.config.critic1_lr)
        critic2_optim = self.optim_factory.create_optimizer(critic2, lr=self.config.critic2_lr)
        if isinstance(self.config.alpha, AutoAlphaFactory):
            alpha = self.config.alpha.create_auto_alpha(envs, self.optim_factory, device)
        else:
            alpha = self.config.alpha
        actor_optim = self.optim_factory.create_optimizer(actor, lr=self.params.actor_lr)
        critic1_optim = self.optim_factory.create_optimizer(critic1, lr=self.params.critic1_lr)
        critic2_optim = self.optim_factory.create_optimizer(critic2, lr=self.params.critic2_lr)
        kwargs = self.params.create_kwargs(
            ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"),
            ParamTransformerMultiLRScheduler([
                (actor_optim, "actor_lr_scheduler_factory"),
                (critic1_optim, "critic1_lr_scheduler_factory"),
                (critic2_optim, "critic2_lr_scheduler_factory")]
            ),
            ParamTransformerAlpha(envs, optim_factory=self.optim_factory, device=device))
        return SACPolicy(
            actor,
            actor_optim,
            critic1,
            critic1_optim,
            critic2,
            critic2_optim,
            tau=self.config.tau,
            gamma=self.config.gamma,
            alpha=alpha,
            estimation_step=self.config.estimation_step,
            actor=actor,
            actor_optim=actor_optim,
            critic=critic1,
            critic_optim=critic1_optim,
            critic2=critic2,
            critic2_optim=critic2_optim,
            action_space=envs.get_action_space(),
            deterministic_eval=self.config.deterministic_eval,
            exploration_noise=self.exploration_noise,
            observation_space=envs.get_observation_space(),
            **kwargs
        )
@@ -28,19 +28,22 @@ class Environments(ABC):
        self.test_envs = test_envs

    def info(self) -> dict[str, Any]:
        return {"action_shape": self.get_action_shape(), "state_shape": self.get_state_shape()}
        return {"action_shape": self.get_action_shape(), "state_shape": self.get_observation_shape()}

    @abstractmethod
    def get_action_shape(self) -> TShape:
        pass

    @abstractmethod
    def get_state_shape(self) -> TShape:
    def get_observation_shape(self) -> TShape:
        pass

    def get_action_space(self) -> gym.Space:
        return self.env.action_space

    def get_observation_space(self) -> gym.Space:
        return self.env.observation_space

    @abstractmethod
    def get_type(self) -> EnvType:
        pass
@@ -75,7 +78,7 @@ class ContinuousEnvironments(Environments):
    def get_action_shape(self) -> TShape:
        return self.action_shape

    def get_state_shape(self) -> TShape:
    def get_observation_shape(self) -> TShape:
        return self.state_shape

    def get_type(self):
@@ -2,13 +2,13 @@ from abc import abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from pprint import pprint
from typing import Generic, TypeVar
from typing import Generic, TypeVar, Callable

import numpy as np
import torch

from tianshou.data import Collector
from tianshou.highlevel.agent import AgentFactory, PPOAgentFactory, PPOConfig, SACConfig
from tianshou.highlevel.agent import AgentFactory, PPOAgentFactory, SACAgentFactory
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.env import EnvFactory
from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory
@@ -19,7 +19,9 @@ from tianshou.highlevel.module import (
    DefaultCriticFactory,
)
from tianshou.highlevel.optim import AdamOptimizerFactory, OptimizerFactory
from tianshou.highlevel.params.policy_params import PPOParams, SACParams
from tianshou.policy import BasePolicy
from tianshou.policy.modelfree.pg import TDistParams
from tianshou.trainer import BaseTrainer

TPolicy = TypeVar("TPolicy", bound=BasePolicy)
@@ -294,13 +296,15 @@ class PPOExperimentBuilder(
        experiment_config: RLExperimentConfig,
        env_factory: EnvFactory,
        sampling_config: RLSamplingConfig,
        dist_fn: Callable[[TDistParams], torch.distributions.Distribution],
    ):
        super().__init__(experiment_config, env_factory, sampling_config)
        _BuilderMixinActorFactory.__init__(self)
        _BuilderMixinSingleCriticFactory.__init__(self)
        self._params: PPOConfig = PPOConfig()
        self._params: PPOParams = PPOParams()
        self._dist_fn = dist_fn

    def with_ppo_params(self, params: PPOConfig) -> "PPOExperimentBuilder":
    def with_ppo_params(self, params: PPOParams) -> "PPOExperimentBuilder":
        self._params = params
        return self

@@ -312,6 +316,7 @@
            self._get_actor_factory(),
            self._get_critic_factory(0),
            self._get_optim_factory(),
            self._dist_fn
        )


@@ -327,8 +332,12 @@ class SACExperimentBuilder(
        super().__init__(experiment_config, env_factory, sampling_config)
        _BuilderMixinActorFactory.__init__(self)
        _BuilderMixinDualCriticFactory.__init__(self)
        self._params: SACConfig = SACConfig()
        self._params: SACParams = SACParams()

    def with_sac_params(self, params: SACConfig) -> "SACExperimentBuilder":
    def with_sac_params(self, params: SACParams) -> "SACExperimentBuilder":
        self._params = params
        return self

    def _create_agent_factory(self) -> AgentFactory:
        return SACAgentFactory(self._params, self._sampling_config, self._get_actor_factory(),
                               self._get_critic_factory(0), self._get_critic_factory(1), self._get_optim_factory())
@@ -91,7 +91,7 @@ class ContinuousActorProbFactory(ContinuousActorFactory):

    def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
        net_a = Net(
            envs.get_state_shape(),
            envs.get_observation_shape(),
            hidden_sizes=self.hidden_sizes,
            activation=nn.Tanh,
            device=device,
@@ -148,7 +148,7 @@ class ContinuousNetCriticFactory(ContinuousCriticFactory):
    def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
        action_shape = envs.get_action_shape() if use_action else 0
        net_c = Net(
            envs.get_state_shape(),
            envs.get_observation_shape(),
            action_shape=action_shape,
            hidden_sizes=self.hidden_sizes,
            concat=use_action,
@@ -2,13 +2,9 @@ from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import Any

import numpy as np
import torch
from torch import Tensor
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR, LRScheduler

from tianshou.highlevel.config import RLSamplingConfig

TParams = Iterable[Tensor] | Iterable[dict[str, Any]]

@@ -44,19 +40,3 @@ class AdamOptimizerFactory(OptimizerFactory):
        )


class LRSchedulerFactory(ABC):
    @abstractmethod
    def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
        pass


class LinearLRSchedulerFactory(LRSchedulerFactory):
    def __init__(self, sampling_config: RLSamplingConfig):
        self.sampling_config = sampling_config

    def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
        max_update_num = (
            np.ceil(self.sampling_config.step_per_epoch / self.sampling_config.step_per_collect)
            * self.sampling_config.num_epochs
        )
        return LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
tianshou/highlevel/params/__init__.py (new file, 0 lines)

tianshou/highlevel/params/alpha.py (new file, 35 lines)
@@ -0,0 +1,35 @@
from abc import ABC, abstractmethod

import numpy as np
import torch

from tianshou.highlevel.env import Environments
from tianshou.highlevel.module import TDevice
from tianshou.highlevel.optim import OptimizerFactory


class AutoAlphaFactory(ABC):
    @abstractmethod
    def create_auto_alpha(
        self,
        envs: Environments,
        optim_factory: OptimizerFactory,
        device: TDevice,
    ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
        pass


class DefaultAutoAlphaFactory(AutoAlphaFactory):  # TODO better name?
    def __init__(self, lr: float = 3e-4):
        self.lr = lr

    def create_auto_alpha(
        self,
        envs: Environments,
        optim_factory: OptimizerFactory,
        device: TDevice,
    ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
        target_entropy = -np.prod(envs.get_action_shape())
        log_alpha = torch.zeros(1, requires_grad=True, device=device)
        alpha_optim = torch.optim.Adam([log_alpha], lr=self.lr)
        return target_entropy, log_alpha, alpha_optim
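For reference, a standalone illustration of the triple that create_auto_alpha builds for SACPolicy; the action shape and learning rate below are made-up stand-ins for what the factory reads from Environments and its constructor.

import numpy as np
import torch

action_shape = (6,)  # stand-in for envs.get_action_shape()
target_entropy = -float(np.prod(action_shape))
log_alpha = torch.zeros(1, requires_grad=True)
alpha_optim = torch.optim.Adam([log_alpha], lr=3e-4)
# SACPolicy receives alpha=(target_entropy, log_alpha, alpha_optim) and tunes alpha automatically.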
tianshou/highlevel/params/lr_scheduler.py (new file, 25 lines)
@@ -0,0 +1,25 @@
from abc import ABC, abstractmethod

import numpy as np
import torch
from torch.optim.lr_scheduler import LRScheduler, LambdaLR

from tianshou.highlevel.config import RLSamplingConfig


class LRSchedulerFactory(ABC):
    @abstractmethod
    def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
        pass


class LinearLRSchedulerFactory(LRSchedulerFactory):
    def __init__(self, sampling_config: RLSamplingConfig):
        self.sampling_config = sampling_config

    def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
        max_update_num = (
            np.ceil(self.sampling_config.step_per_epoch / self.sampling_config.step_per_collect)
            * self.sampling_config.num_epochs
        )
        return LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
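A self-contained sketch of the schedule LinearLRSchedulerFactory produces; the sampling numbers are arbitrary stand-ins for the RLSamplingConfig fields it reads.

import numpy as np
import torch
from torch.optim.lr_scheduler import LambdaLR

# stand-ins for sampling_config.step_per_epoch, .step_per_collect and .num_epochs
step_per_epoch, step_per_collect, num_epochs = 30000, 2048, 100
max_update_num = np.ceil(step_per_epoch / step_per_collect) * num_epochs

optim = torch.optim.Adam([torch.zeros(1, requires_grad=True)], lr=3e-4)
# the learning rate falls linearly from its initial value towards zero over all policy updates
scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)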
tianshou/highlevel/params/policy_params.py (new file, 80 lines)
@@ -0,0 +1,80 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, asdict
from typing import Dict, Any, Literal

import torch

from tianshou.exploration import BaseNoise
from tianshou.highlevel.params.alpha import AutoAlphaFactory
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory


class ParamTransformer(ABC):
    @abstractmethod
    def transform(self, kwargs: Dict[str, Any]) -> None:
        pass

    @staticmethod
    def get(d: Dict[str, Any], key: str, drop: bool = False) -> Any:
        value = d[key]
        if drop:
            del d[key]
        return value


@dataclass
class Params:
    def create_kwargs(self, *transformers: ParamTransformer) -> Dict[str, Any]:
        d = asdict(self)
        for transformer in transformers:
            transformer.transform(d)
        return d


@dataclass
class PGParams(Params):
    """Config of general policy-gradient algorithms."""
    discount_factor: float = 0.99
    reward_normalization: bool = False
    deterministic_eval: bool = False
    action_scaling: bool = True
    action_bound_method: Literal["clip", "tanh"] | None = "clip"


@dataclass
class A2CParams(PGParams):
    vf_coef: float = 0.5
    ent_coef: float = 0.01
    max_grad_norm: float | None = None
    gae_lambda: float = 0.95
    max_batchsize: int = 256


@dataclass
class PPOParams(A2CParams):
    """PPO specific config."""
    eps_clip: float = 0.2
    dual_clip: float | None = None
    value_clip: bool = False
    advantage_normalization: bool = True
    recompute_advantage: bool = False
    lr: float = 1e-3
    lr_scheduler_factory: LRSchedulerFactory | None = None


@dataclass
class SACParams(Params):
    tau: float = 0.005
    gamma: float = 0.99
    alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] | AutoAlphaFactory = 0.2
    estimation_step: int = 1
    exploration_noise: BaseNoise | Literal["default"] | None = None
    deterministic_eval: bool = True
    action_scaling: bool = True
    action_bound_method: Literal["clip"] | None = "clip"
    actor_lr: float = 1e-3
    critic1_lr: float = 1e-3
    critic2_lr: float = 1e-3
    actor_lr_scheduler_factory: LRSchedulerFactory | None = None
    critic1_lr_scheduler_factory: LRSchedulerFactory | None = None
    critic2_lr_scheduler_factory: LRSchedulerFactory | None = None
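A quick illustration (again assuming this commit's tianshou is importable) of why the agent factories inject transformers: with none applied, create_kwargs leaves factory objects and the lr fields in the dict, which the policy constructor would not accept.

from tianshou.highlevel.params.alpha import DefaultAutoAlphaFactory
from tianshou.highlevel.params.policy_params import SACParams

params = SACParams(tau=0.01, alpha=DefaultAutoAlphaFactory(lr=3e-4))
raw = params.create_kwargs()  # no transformers injected
assert isinstance(raw["alpha"], DefaultAutoAlphaFactory)  # still a factory, not a usable alpha
assert "actor_lr" in raw and "actor_lr_scheduler_factory" in raw  # must be dropped/converted first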