diff --git a/examples/mujoco/mujoco_ppo_hl.py b/examples/mujoco/mujoco_ppo_hl.py
index ac0d2e1..f4ef4fc 100644
--- a/examples/mujoco/mujoco_ppo_hl.py
+++ b/examples/mujoco/mujoco_ppo_hl.py
@@ -9,13 +9,13 @@ from jsonargparse import CLI
 from torch.distributions import Independent, Normal
 
 from examples.mujoco.mujoco_env import MujocoEnvFactory
-from tianshou.highlevel.agent import PPOConfig
 from tianshou.highlevel.config import RLSamplingConfig
 from tianshou.highlevel.experiment import (
     PPOExperimentBuilder,
     RLExperimentConfig,
 )
-from tianshou.highlevel.optim import LinearLRSchedulerFactory
+from tianshou.highlevel.params.lr_scheduler import LinearLRSchedulerFactory
+from tianshou.highlevel.params.policy_params import PPOParams
 
 
 def main(
@@ -65,22 +65,21 @@ def main(
         return Independent(Normal(*logits), 1)
 
     experiment = (
-        PPOExperimentBuilder(experiment_config, env_factory, sampling_config)
+        PPOExperimentBuilder(experiment_config, env_factory, sampling_config, dist_fn)
         .with_ppo_params(
-            PPOConfig(
-                gamma=gamma,
+            PPOParams(
+                discount_factor=gamma,
                 gae_lambda=gae_lambda,
                 action_bound_method=bound_action_method,
-                rew_norm=rew_norm,
+                reward_normalization=rew_norm,
                 ent_coef=ent_coef,
                 vf_coef=vf_coef,
                 max_grad_norm=max_grad_norm,
                 value_clip=value_clip,
-                norm_adv=norm_adv,
+                advantage_normalization=norm_adv,
                 eps_clip=eps_clip,
                 dual_clip=dual_clip,
-                recompute_adv=recompute_adv,
-                dist_fn=dist_fn,
+                recompute_advantage=recompute_adv,
                 lr=lr,
                 lr_scheduler_factory=LinearLRSchedulerFactory(sampling_config)
                 if lr_decay
diff --git a/examples/mujoco/mujoco_sac_hl.py b/examples/mujoco/mujoco_sac_hl.py
index 1d5ceb2..8f4ddd0 100644
--- a/examples/mujoco/mujoco_sac_hl.py
+++ b/examples/mujoco/mujoco_sac_hl.py
@@ -7,7 +7,8 @@ from collections.abc import Sequence
 from jsonargparse import CLI
 
 from examples.mujoco.mujoco_env import MujocoEnvFactory
-from tianshou.highlevel.agent import DefaultAutoAlphaFactory, SACConfig
+from tianshou.highlevel.params.policy_params import SACParams
+from tianshou.highlevel.params.alpha import DefaultAutoAlphaFactory
 from tianshou.highlevel.config import RLSamplingConfig
 from tianshou.highlevel.experiment import (
     RLExperimentConfig,
@@ -58,7 +59,7 @@ def main(
     experiment = (
         SACExperimentBuilder(experiment_config, env_factory, sampling_config)
         .with_sac_params(
-            SACConfig(
+            SACParams(
                 tau=tau,
                 gamma=gamma,
                 alpha=DefaultAutoAlphaFactory(lr=alpha_lr) if auto_alpha else alpha,
diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py
index 91fe3c4..98ec59e 100644
--- a/tianshou/highlevel/agent.py
+++ b/tianshou/highlevel/agent.py
@@ -1,10 +1,8 @@
 import os
 from abc import ABC, abstractmethod
 from collections.abc import Callable
-from dataclasses import dataclass
-from typing import Literal
+from typing import Dict, Any, List, Tuple
 
-import numpy as np
 import torch
 
 from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
@@ -13,9 +11,14 @@ from tianshou.highlevel.config import RLSamplingConfig
 from tianshou.highlevel.env import Environments
 from tianshou.highlevel.logger import Logger
 from tianshou.highlevel.module import ActorFactory, CriticFactory, TDevice
-from tianshou.highlevel.optim import LRSchedulerFactory, OptimizerFactory
+from tianshou.highlevel.optim import OptimizerFactory
+from tianshou.highlevel.params.alpha import AutoAlphaFactory
+from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
+from tianshou.highlevel.params.policy_params import PPOParams, ParamTransformer, SACParams
 from tianshou.policy import BasePolicy, PPOPolicy, SACPolicy
+from tianshou.policy.modelfree.pg import TDistParams
 from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
+from tianshou.utils import MultipleLRSchedulers
 from tianshou.utils.net.common import ActorCritic
 
 CHECKPOINT_DICT_KEY_MODEL = "model"
@@ -128,178 +131,132 @@ class OffpolicyAgentFactory(AgentFactory, ABC):
         )
 
 
-@dataclass
-class RLAgentConfig:
-    """Config common to most RL algorithms."""
+class ParamTransformerDrop(ParamTransformer):
+    def __init__(self, *keys: str):
+        self.keys = keys
 
-    gamma: float = 0.99
-    """Discount factor"""
-    gae_lambda: float = 0.95
-    """For Generalized Advantage Estimate (equivalent to TD(lambda))"""
-    action_bound_method: Literal["clip", "tanh"] | None = "clip"
-    """How to map original actions in range (-inf, inf) to [-1, 1]"""
-    rew_norm: bool = True
-    """Whether to normalize rewards"""
+    def transform(self, kwargs: Dict[str, Any]) -> None:
+        for k in self.keys:
+            del kwargs[k]
 
 
-@dataclass
-class PGConfig(RLAgentConfig):
-    """Config of general policy-gradient algorithms."""
+class ParamTransformerLRScheduler(ParamTransformer):
+    def __init__(self, optim: torch.optim.Optimizer):
+        self.optim = optim
 
-    ent_coef: float = 0.0
-    vf_coef: float = 0.25
-    max_grad_norm: float = 0.5
-
-
-@dataclass
-class PPOConfig(PGConfig):
-    """PPO specific config."""
-
-    value_clip: bool = False
-    norm_adv: bool = False
-    """Whether to normalize advantages"""
-    eps_clip: float = 0.2
-    dual_clip: float | None = None
-    recompute_adv: bool = True
-    dist_fn: Callable = None
-    lr: float = 1e-3
-    lr_scheduler_factory: LRSchedulerFactory | None = None
+    def transform(self, kwargs: Dict[str, Any]) -> None:
+        factory: LRSchedulerFactory | None = self.get(kwargs, "lr_scheduler_factory", drop=True)
+        kwargs["lr_scheduler"] = factory.create_scheduler(self.optim) if factory is not None else None
 
 
 class PPOAgentFactory(OnpolicyAgentFactory):
     def __init__(
         self,
-        config: PPOConfig,
+        params: PPOParams,
         sampling_config: RLSamplingConfig,
         actor_factory: ActorFactory,
         critic_factory: CriticFactory,
         optimizer_factory: OptimizerFactory,
+        dist_fn: Callable[[TDistParams], torch.distributions.Distribution],
     ):
         super().__init__(sampling_config)
         self.optimizer_factory = optimizer_factory
         self.critic_factory = critic_factory
         self.actor_factory = actor_factory
-        self.config = config
+        self.config = params
+        self.dist_fn = dist_fn
 
     def create_policy(self, envs: Environments, device: TDevice) -> PPOPolicy:
         actor = self.actor_factory.create_module(envs, device)
         critic = self.critic_factory.create_module(envs, device, use_action=False)
         actor_critic = ActorCritic(actor, critic)
         optim = self.optimizer_factory.create_optimizer(actor_critic, self.config.lr)
-        if self.config.lr_scheduler_factory is not None:
-            lr_scheduler = self.config.lr_scheduler_factory.create_scheduler(optim)
-        else:
-            lr_scheduler = None
+        kwargs = self.config.create_kwargs(
+            ParamTransformerDrop("lr"),
+            ParamTransformerLRScheduler(optim))
         return PPOPolicy(
-            # nn-stuff
-            actor,
-            critic,
-            optim,
-            dist_fn=self.config.dist_fn,
-            lr_scheduler=lr_scheduler,
-            # env-stuff
+            actor=actor,
+            critic=critic,
+            optim=optim,
+            dist_fn=self.dist_fn,
             action_space=envs.get_action_space(),
-            action_scaling=True,
-            # general_config
-            discount_factor=self.config.gamma,
-            gae_lambda=self.config.gae_lambda,
-            reward_normalization=self.config.rew_norm,
-            action_bound_method=self.config.action_bound_method,
-            # pg_config
-            max_grad_norm=self.config.max_grad_norm,
-            vf_coef=self.config.vf_coef,
-            ent_coef=self.config.ent_coef,
-            # ppo_config
-            eps_clip=self.config.eps_clip,
-            value_clip=self.config.value_clip,
-            dual_clip=self.config.dual_clip,
-            advantage_normalization=self.config.norm_adv,
-            recompute_advantage=self.config.recompute_adv,
+            **kwargs
         )
 
 
-class AutoAlphaFactory(ABC):
-    @abstractmethod
-    def create_auto_alpha(
-        self,
-        envs: Environments,
-        optim_factory: OptimizerFactory,
-        device: TDevice,
-    ):
-        pass
+class ParamTransformerAlpha(ParamTransformer):
+    def __init__(self, envs: Environments, optim_factory: OptimizerFactory, device: TDevice):
+        self.envs = envs
+        self.optim_factory = optim_factory
+        self.device = device
+
+    def transform(self, kwargs: Dict[str, Any]) -> None:
+        key = "alpha"
+        alpha = self.get(kwargs, key)
+        if isinstance(alpha, AutoAlphaFactory):
+            kwargs[key] = alpha.create_auto_alpha(self.envs, self.optim_factory, self.device)
 
 
-class DefaultAutoAlphaFactory(AutoAlphaFactory):  # TODO better name?
-    def __init__(self, lr: float = 3e-4):
-        self.lr = lr
+class ParamTransformerMultiLRScheduler(ParamTransformer):
+    def __init__(self, optim_key_list: List[Tuple[torch.optim.Optimizer, str]]):
+        self.optim_key_list = optim_key_list
 
-    def create_auto_alpha(
-        self,
-        envs: Environments,
-        optim_factory: OptimizerFactory,
-        device: TDevice,
-    ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
-        target_entropy = -np.prod(envs.get_action_shape())
-        log_alpha = torch.zeros(1, requires_grad=True, device=device)
-        alpha_optim = torch.optim.Adam([log_alpha], lr=self.lr)
-        return target_entropy, log_alpha, alpha_optim
-
-
-@dataclass
-class SACConfig:
-    tau: float = 0.005
-    gamma: float = 0.99
-    alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] | AutoAlphaFactory = 0.2
-    reward_normalization: bool = False
-    estimation_step: int = 1
-    deterministic_eval: bool = True
-    actor_lr: float = 1e-3
-    critic1_lr: float = 1e-3
-    critic2_lr: float = 1e-3
+    def transform(self, kwargs: Dict[str, Any]) -> None:
+        lr_schedulers = []
+        for optim, lr_scheduler_factory_key in self.optim_key_list:
+            lr_scheduler_factory: LRSchedulerFactory | None = self.get(kwargs, lr_scheduler_factory_key, drop=True)
+            if lr_scheduler_factory is not None:
+                lr_schedulers.append(lr_scheduler_factory.create_scheduler(optim))
+        match len(lr_schedulers):
+            case 0:
+                lr_scheduler = None
+            case 1:
+                lr_scheduler = lr_schedulers[0]
+            case _:
+                lr_scheduler = MultipleLRSchedulers(*lr_schedulers)
+        kwargs["lr_scheduler"] = lr_scheduler
 
 
 class SACAgentFactory(OffpolicyAgentFactory):
     def __init__(
         self,
-        config: SACConfig,
+        params: SACParams,
         sampling_config: RLSamplingConfig,
         actor_factory: ActorFactory,
         critic1_factory: CriticFactory,
         critic2_factory: CriticFactory,
         optim_factory: OptimizerFactory,
-        exploration_noise: BaseNoise | None = None,
     ):
         super().__init__(sampling_config)
         self.critic2_factory = critic2_factory
         self.critic1_factory = critic1_factory
         self.actor_factory = actor_factory
-        self.exploration_noise = exploration_noise
         self.optim_factory = optim_factory
-        self.config = config
+        self.params = params
 
     def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
         actor = self.actor_factory.create_module(envs, device)
         critic1 = self.critic1_factory.create_module(envs, device, use_action=True)
         critic2 = self.critic2_factory.create_module(envs, device, use_action=True)
-        actor_optim = self.optim_factory.create_optimizer(actor, lr=self.config.actor_lr)
-        critic1_optim = self.optim_factory.create_optimizer(critic1, lr=self.config.critic1_lr)
-        critic2_optim = self.optim_factory.create_optimizer(critic2, lr=self.config.critic2_lr)
-        if isinstance(self.config.alpha, AutoAlphaFactory):
-            alpha = self.config.alpha.create_auto_alpha(envs, self.optim_factory, device)
-        else:
-            alpha = self.config.alpha
+        actor_optim = self.optim_factory.create_optimizer(actor, lr=self.params.actor_lr)
+        critic1_optim = self.optim_factory.create_optimizer(critic1, lr=self.params.critic1_lr)
+        critic2_optim = self.optim_factory.create_optimizer(critic2, lr=self.params.critic2_lr)
+        kwargs = self.params.create_kwargs(
+            ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"),
+            ParamTransformerMultiLRScheduler([
+                (actor_optim, "actor_lr_scheduler_factory"),
+                (critic1_optim, "critic1_lr_scheduler_factory"),
+                (critic2_optim, "critic2_lr_scheduler_factory")]
+            ),
+            ParamTransformerAlpha(envs, optim_factory=self.optim_factory, device=device))
         return SACPolicy(
-            actor,
-            actor_optim,
-            critic1,
-            critic1_optim,
-            critic2,
-            critic2_optim,
-            tau=self.config.tau,
-            gamma=self.config.gamma,
-            alpha=alpha,
-            estimation_step=self.config.estimation_step,
+            actor=actor,
+            actor_optim=actor_optim,
+            critic=critic1,
+            critic_optim=critic1_optim,
+            critic2=critic2,
+            critic2_optim=critic2_optim,
             action_space=envs.get_action_space(),
-            deterministic_eval=self.config.deterministic_eval,
-            exploration_noise=self.exploration_noise,
+            observation_space=envs.get_observation_space(),
+            **kwargs
        )
diff --git a/tianshou/highlevel/env.py b/tianshou/highlevel/env.py
index 273bb0a..bf5ddb6 100644
--- a/tianshou/highlevel/env.py
+++ b/tianshou/highlevel/env.py
@@ -28,19 +28,22 @@ class Environments(ABC):
         self.test_envs = test_envs
 
     def info(self) -> dict[str, Any]:
-        return {"action_shape": self.get_action_shape(), "state_shape": self.get_state_shape()}
+        return {"action_shape": self.get_action_shape(), "state_shape": self.get_observation_shape()}
 
     @abstractmethod
     def get_action_shape(self) -> TShape:
         pass
 
     @abstractmethod
-    def get_state_shape(self) -> TShape:
+    def get_observation_shape(self) -> TShape:
         pass
 
     def get_action_space(self) -> gym.Space:
         return self.env.action_space
 
+    def get_observation_space(self) -> gym.Space:
+        return self.env.observation_space
+
     @abstractmethod
     def get_type(self) -> EnvType:
         pass
@@ -75,7 +78,7 @@ class ContinuousEnvironments(Environments):
     def get_action_shape(self) -> TShape:
         return self.action_shape
 
-    def get_state_shape(self) -> TShape:
+    def get_observation_shape(self) -> TShape:
         return self.state_shape
 
     def get_type(self):
diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py
index 7b65a1f..e02d613 100644
--- a/tianshou/highlevel/experiment.py
+++ b/tianshou/highlevel/experiment.py
@@ -2,13 +2,13 @@ from abc import abstractmethod
 from collections.abc import Sequence
 from dataclasses import dataclass
 from pprint import pprint
-from typing import Generic, TypeVar
+from typing import Generic, TypeVar, Callable
 
 import numpy as np
 import torch
 
 from tianshou.data import Collector
-from tianshou.highlevel.agent import AgentFactory, PPOAgentFactory, PPOConfig, SACConfig
+from tianshou.highlevel.agent import AgentFactory, PPOAgentFactory, SACAgentFactory
 from tianshou.highlevel.config import RLSamplingConfig
 from tianshou.highlevel.env import EnvFactory
 from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory
@@ -19,7 +19,9 @@ from tianshou.highlevel.module import (
     DefaultCriticFactory,
 )
 from tianshou.highlevel.optim import AdamOptimizerFactory, OptimizerFactory
+from tianshou.highlevel.params.policy_params import PPOParams, SACParams
 from tianshou.policy import BasePolicy
+from tianshou.policy.modelfree.pg import TDistParams
 from tianshou.trainer import BaseTrainer
 
 TPolicy = TypeVar("TPolicy", bound=BasePolicy)
@@ -294,13 +296,15 @@
         experiment_config: RLExperimentConfig,
         env_factory: EnvFactory,
         sampling_config: RLSamplingConfig,
+        dist_fn: Callable[[TDistParams], torch.distributions.Distribution],
     ):
         super().__init__(experiment_config, env_factory, sampling_config)
         _BuilderMixinActorFactory.__init__(self)
         _BuilderMixinSingleCriticFactory.__init__(self)
-        self._params: PPOConfig = PPOConfig()
+        self._params: PPOParams = PPOParams()
+        self._dist_fn = dist_fn
 
-    def with_ppo_params(self, params: PPOConfig) -> "PPOExperimentBuilder":
+    def with_ppo_params(self, params: PPOParams) -> "PPOExperimentBuilder":
         self._params = params
         return self
 
@@ -312,6 +316,7 @@
             self._get_actor_factory(),
             self._get_critic_factory(0),
             self._get_optim_factory(),
+            self._dist_fn
         )
 
 
@@ -327,8 +332,12 @@
         super().__init__(experiment_config, env_factory, sampling_config)
         _BuilderMixinActorFactory.__init__(self)
         _BuilderMixinDualCriticFactory.__init__(self)
-        self._params: SACConfig = SACConfig()
+        self._params: SACParams = SACParams()
 
-    def with_sac_params(self, params: SACConfig) -> "SACExperimentBuilder":
+    def with_sac_params(self, params: SACParams) -> "SACExperimentBuilder":
         self._params = params
         return self
+
+    def _create_agent_factory(self) -> AgentFactory:
+        return SACAgentFactory(self._params, self._sampling_config, self._get_actor_factory(),
+            self._get_critic_factory(0), self._get_critic_factory(1), self._get_optim_factory())
diff --git a/tianshou/highlevel/module.py b/tianshou/highlevel/module.py
index 17c5767..979c88e 100644
--- a/tianshou/highlevel/module.py
+++ b/tianshou/highlevel/module.py
@@ -91,7 +91,7 @@ class ContinuousActorProbFactory(ContinuousActorFactory):
 
     def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
         net_a = Net(
-            envs.get_state_shape(),
+            envs.get_observation_shape(),
             hidden_sizes=self.hidden_sizes,
             activation=nn.Tanh,
             device=device,
@@ -148,7 +148,7 @@ class ContinuousNetCriticFactory(ContinuousCriticFactory):
     def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
         action_shape = envs.get_action_shape() if use_action else 0
         net_c = Net(
-            envs.get_state_shape(),
+            envs.get_observation_shape(),
             action_shape=action_shape,
             hidden_sizes=self.hidden_sizes,
             concat=use_action,
diff --git a/tianshou/highlevel/optim.py b/tianshou/highlevel/optim.py
index ee6677a..a2a0be1 100644
--- a/tianshou/highlevel/optim.py
+++ b/tianshou/highlevel/optim.py
@@ -2,13 +2,9 @@ from abc import ABC, abstractmethod
 from collections.abc import Iterable
 from typing import Any
 
-import numpy as np
 import torch
 from torch import Tensor
 from torch.optim import Adam
-from torch.optim.lr_scheduler import LambdaLR, LRScheduler
-
-from tianshou.highlevel.config import RLSamplingConfig
 
 TParams = Iterable[Tensor] | Iterable[dict[str, Any]]
 
@@ -44,19 +40,3 @@
         )
 
 
-class LRSchedulerFactory(ABC):
-    @abstractmethod
-    def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
-        pass
-
-
-class LinearLRSchedulerFactory(LRSchedulerFactory):
-    def __init__(self, sampling_config: RLSamplingConfig):
-        self.sampling_config = sampling_config
-
-    def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
-        max_update_num = (
-            np.ceil(self.sampling_config.step_per_epoch / self.sampling_config.step_per_collect)
-            * self.sampling_config.num_epochs
-        )
-        return LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
diff --git a/tianshou/highlevel/params/__init__.py b/tianshou/highlevel/params/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tianshou/highlevel/params/alpha.py b/tianshou/highlevel/params/alpha.py
new file mode 100644
index 0000000..5f5bcaa
--- /dev/null
+++ b/tianshou/highlevel/params/alpha.py
@@ -0,0 +1,35 @@
+from abc import ABC, abstractmethod
+
+import numpy as np
+import torch
+
+from tianshou.highlevel.env import Environments
+from tianshou.highlevel.module import TDevice
+from tianshou.highlevel.optim import OptimizerFactory
+
+
+class AutoAlphaFactory(ABC):
+    @abstractmethod
+    def create_auto_alpha(
+        self,
+        envs: Environments,
+        optim_factory: OptimizerFactory,
+        device: TDevice,
+    ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
+        pass
+
+
+class DefaultAutoAlphaFactory(AutoAlphaFactory):  # TODO better name?
+    def __init__(self, lr: float = 3e-4):
+        self.lr = lr
+
+    def create_auto_alpha(
+        self,
+        envs: Environments,
+        optim_factory: OptimizerFactory,
+        device: TDevice,
+    ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
+        target_entropy = -np.prod(envs.get_action_shape())
+        log_alpha = torch.zeros(1, requires_grad=True, device=device)
+        alpha_optim = torch.optim.Adam([log_alpha], lr=self.lr)
+        return target_entropy, log_alpha, alpha_optim
diff --git a/tianshou/highlevel/params/lr_scheduler.py b/tianshou/highlevel/params/lr_scheduler.py
new file mode 100644
index 0000000..80699cd
--- /dev/null
+++ b/tianshou/highlevel/params/lr_scheduler.py
@@ -0,0 +1,25 @@
+from abc import ABC, abstractmethod
+
+import numpy as np
+import torch
+from torch.optim.lr_scheduler import LRScheduler, LambdaLR
+
+from tianshou.highlevel.config import RLSamplingConfig
+
+
+class LRSchedulerFactory(ABC):
+    @abstractmethod
+    def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
+        pass
+
+
+class LinearLRSchedulerFactory(LRSchedulerFactory):
+    def __init__(self, sampling_config: RLSamplingConfig):
+        self.sampling_config = sampling_config
+
+    def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
+        max_update_num = (
+            np.ceil(self.sampling_config.step_per_epoch / self.sampling_config.step_per_collect)
+            * self.sampling_config.num_epochs
+        )
+        return LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py
new file mode 100644
index 0000000..2df996b
--- /dev/null
+++ b/tianshou/highlevel/params/policy_params.py
@@ -0,0 +1,80 @@
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, asdict
+from typing import Dict, Any, Literal
+
+import torch
+
+from tianshou.exploration import BaseNoise
+from tianshou.highlevel.params.alpha import AutoAlphaFactory
+from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
+
+
+class ParamTransformer(ABC):
+    @abstractmethod
+    def transform(self, kwargs: Dict[str, Any]) -> None:
+        pass
+
+    @staticmethod
+    def get(d: Dict[str, Any], key: str, drop: bool = False) -> Any:
+        value = d[key]
+        if drop:
+            del d[key]
+        return value
+
+
+@dataclass
+class Params:
+    def create_kwargs(self, *transformers: ParamTransformer) -> Dict[str, Any]:
+        d = asdict(self)
+        for transformer in transformers:
+            transformer.transform(d)
+        return d
+
+
+@dataclass
+class PGParams(Params):
+    """Config of general policy-gradient algorithms."""
+    discount_factor: float = 0.99
+    reward_normalization: bool = False
+    deterministic_eval: bool = False
+    action_scaling: bool = True
+    action_bound_method: Literal["clip", "tanh"] | None = "clip"
+
+
+@dataclass
+class A2CParams(PGParams):
+    vf_coef: float = 0.5
+    ent_coef: float = 0.01
+    max_grad_norm: float | None = None
+    gae_lambda: float = 0.95
+    max_batchsize: int = 256
+
+
+@dataclass
+class PPOParams(A2CParams):
+    """PPO specific config."""
+    eps_clip: float = 0.2
+    dual_clip: float | None = None
+    value_clip: bool = False
+    advantage_normalization: bool = True
+    recompute_advantage: bool = False
+    lr: float = 1e-3
+    lr_scheduler_factory: LRSchedulerFactory | None = None
+
+
+@dataclass
+class SACParams(Params):
+    tau: float = 0.005
+    gamma: float = 0.99
+    alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] | AutoAlphaFactory = 0.2
+    estimation_step: int = 1
+    exploration_noise: BaseNoise | Literal["default"] | None = None
+    deterministic_eval: bool = True
+    action_scaling: bool = True
+    action_bound_method: Literal["clip"] | None = "clip"
+    actor_lr: float = 1e-3
+    critic1_lr: float = 1e-3
+    critic2_lr: float = 1e-3
+    actor_lr_scheduler_factory: LRSchedulerFactory | None = None
+    critic1_lr_scheduler_factory: LRSchedulerFactory | None = None
+    critic2_lr_scheduler_factory: LRSchedulerFactory | None = None