Add high-level API support for TD3

* Created mixins for agent factories to reduce code duplication
* Further factorised params & mixins for experiment factories
* Additional parameter abstractions
* Implemented high-level MuJoCo TD3 example
This commit is contained in:
Dominik Jain 2023-09-26 15:35:18 +02:00
parent 6a739384ee
commit e993425aa1
14 changed files with 626 additions and 116 deletions

View File

@ -66,7 +66,7 @@ def main(
experiment = ( experiment = (
PPOExperimentBuilder(experiment_config, env_factory, sampling_config, dist_fn) PPOExperimentBuilder(experiment_config, env_factory, sampling_config, dist_fn)
.with_ppo_params( .with_params(
PPOParams( PPOParams(
discount_factor=gamma, discount_factor=gamma,
gae_lambda=gae_lambda, gae_lambda=gae_lambda,

View File

@ -7,13 +7,13 @@ from collections.abc import Sequence
from jsonargparse import CLI from jsonargparse import CLI
from examples.mujoco.mujoco_env import MujocoEnvFactory from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.params.policy_params import SACParams
from tianshou.highlevel.params.alpha import DefaultAutoAlphaFactory
from tianshou.highlevel.config import RLSamplingConfig from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.experiment import ( from tianshou.highlevel.experiment import (
RLExperimentConfig, RLExperimentConfig,
SACExperimentBuilder, SACExperimentBuilder,
) )
from tianshou.highlevel.params.alpha import DefaultAutoAlphaFactory
from tianshou.highlevel.params.policy_params import SACParams
def main( def main(
@ -70,7 +70,9 @@ def main(
), ),
) )
.with_actor_factory_default( .with_actor_factory_default(
hidden_sizes, continuous_unbounded=True, continuous_conditioned_sigma=True, hidden_sizes,
continuous_unbounded=True,
continuous_conditioned_sigma=True,
) )
.with_common_critic_factory_default(hidden_sizes) .with_common_critic_factory_default(hidden_sizes)
.build() .build()

View File

@ -0,0 +1,85 @@
#!/usr/bin/env python3
import datetime
import os
from collections.abc import Sequence
from jsonargparse import CLI
from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.experiment import (
RLExperimentConfig,
TD3ExperimentBuilder,
)
from tianshou.highlevel.params.env_param import MaxActionScaledFloatEnvParamFactory
from tianshou.highlevel.params.noise import MaxActionScaledGaussianNoiseFactory
from tianshou.highlevel.params.policy_params import TD3Params
def main(
    experiment_config: RLExperimentConfig,
    task: str = "Ant-v3",
    buffer_size: int = 1000000,
    hidden_sizes: Sequence[int] = (256, 256),
    actor_lr: float = 3e-4,
    critic_lr: float = 3e-4,
    gamma: float = 0.99,
    tau: float = 0.005,
    exploration_noise: float = 0.1,
    policy_noise: float = 0.2,
    noise_clip: float = 0.5,
    update_actor_freq: int = 2,
    start_timesteps: int = 25000,
    epoch: int = 200,
    step_per_epoch: int = 5000,
    step_per_collect: int = 1,
    update_per_step: int = 1,
    n_step: int = 1,
    batch_size: int = 256,
    training_num: int = 1,
    test_num: int = 10,
):
    """Train TD3 on a MuJoCo task using the high-level experiment API.

    All parameters are exposed on the command line via jsonargparse's CLI.
    """
    # Log directory layout: <task>/td3/<seed>/<timestamp>
    timestamp = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    log_name = os.path.join(task, "td3", str(experiment_config.seed), timestamp)

    sampling_config = RLSamplingConfig(
        num_epochs=epoch,
        step_per_epoch=step_per_epoch,
        num_train_envs=training_num,
        num_test_envs=test_num,
        buffer_size=buffer_size,
        batch_size=batch_size,
        step_per_collect=step_per_collect,
        update_per_step=update_per_step,
        start_timesteps=start_timesteps,
        start_timesteps_random=True,
    )
    env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)

    # Noise/clip magnitudes are specified as fractions of the environment's
    # max_action; the factories resolve them against the actual env later.
    td3_params = TD3Params(
        tau=tau,
        gamma=gamma,
        estimation_step=n_step,
        update_actor_freq=update_actor_freq,
        noise_clip=MaxActionScaledFloatEnvParamFactory(noise_clip),
        policy_noise=MaxActionScaledFloatEnvParamFactory(policy_noise),
        exploration_noise=MaxActionScaledGaussianNoiseFactory(exploration_noise),
        actor_lr=actor_lr,
        critic1_lr=critic_lr,
        critic2_lr=critic_lr,
    )

    builder = TD3ExperimentBuilder(experiment_config, env_factory, sampling_config)
    builder.with_td3_params(td3_params)
    builder.with_actor_factory_default(hidden_sizes)
    builder.with_common_critic_factory_default(hidden_sizes)
    experiment = builder.build()
    experiment.run(log_name)


if __name__ == "__main__":
    CLI(main)

View File

@ -1,21 +1,35 @@
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Callable from collections.abc import Callable
from typing import Dict, Any, List, Tuple from typing import Any
import torch import torch
from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
from tianshou.exploration import BaseNoise
from tianshou.highlevel.config import RLSamplingConfig from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.env import Environments from tianshou.highlevel.env import Environments
from tianshou.highlevel.logger import Logger from tianshou.highlevel.logger import Logger
from tianshou.highlevel.module import ActorFactory, CriticFactory, TDevice from tianshou.highlevel.module import (
ActorCriticModuleOpt,
ActorFactory,
ActorModuleOptFactory,
CriticFactory,
CriticModuleOptFactory,
ModuleOpt,
TDevice,
)
from tianshou.highlevel.optim import OptimizerFactory from tianshou.highlevel.optim import OptimizerFactory
from tianshou.highlevel.params.alpha import AutoAlphaFactory from tianshou.highlevel.params.alpha import AutoAlphaFactory
from tianshou.highlevel.params.env_param import FloatEnvParamFactory
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
from tianshou.highlevel.params.policy_params import PPOParams, ParamTransformer, SACParams from tianshou.highlevel.params.noise import NoiseFactory
from tianshou.policy import BasePolicy, PPOPolicy, SACPolicy from tianshou.highlevel.params.policy_params import (
ParamTransformer,
PPOParams,
SACParams,
TD3Params,
)
from tianshou.policy import BasePolicy, PPOPolicy, SACPolicy, TD3Policy
from tianshou.policy.modelfree.pg import TDistParams from tianshou.policy.modelfree.pg import TDistParams
from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
from tianshou.utils import MultipleLRSchedulers from tianshou.utils import MultipleLRSchedulers
@ -135,7 +149,7 @@ class ParamTransformerDrop(ParamTransformer):
def __init__(self, *keys: str): def __init__(self, *keys: str):
self.keys = keys self.keys = keys
def transform(self, kwargs: Dict[str, Any]) -> None: def transform(self, kwargs: dict[str, Any]) -> None:
for k in self.keys: for k in self.keys:
del kwargs[k] del kwargs[k]
@ -144,12 +158,94 @@ class ParamTransformerLRScheduler(ParamTransformer):
def __init__(self, optim: torch.optim.Optimizer): def __init__(self, optim: torch.optim.Optimizer):
self.optim = optim self.optim = optim
def transform(self, kwargs: Dict[str, Any]) -> None: def transform(self, kwargs: dict[str, Any]) -> None:
factory: LRSchedulerFactory | None = self.get(kwargs, "lr_scheduler_factory", drop=True) factory: LRSchedulerFactory | None = self.get(kwargs, "lr_scheduler_factory", drop=True)
kwargs["lr_scheduler"] = factory.create_scheduler(self.optim) if factory is not None else None kwargs["lr_scheduler"] = (
factory.create_scheduler(self.optim) if factory is not None else None
)
class PPOAgentFactory(OnpolicyAgentFactory): class _ActorMixin:
def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory):
    # Bundle actor-module and optimizer creation behind a single factory.
    self.actor_module_opt_factory = ActorModuleOptFactory(actor_factory, optim_factory)
def create_actor_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
    """Create the actor module together with its optimizer (learning rate ``lr``)."""
    return self.actor_module_opt_factory.create_module_opt(envs, device, lr)
class _ActorCriticMixin:
    """Mixin for agents that use an ActorCritic module with a single optimizer."""

    def __init__(
        self,
        actor_factory: ActorFactory,
        critic_factory: CriticFactory,
        optim_factory: OptimizerFactory,
        critic_use_action: bool,
    ):
        # critic_use_action selects Q(s, a) (True) vs V(s) (False) critics
        self.actor_factory = actor_factory
        self.critic_factory = critic_factory
        self.optim_factory = optim_factory
        self.critic_use_action = critic_use_action

    def create_actor_critic_module_opt(
        self,
        envs: Environments,
        device: TDevice,
        lr: float,
    ) -> ActorCriticModuleOpt:
        """Build actor and critic nets, combine them in an ActorCritic wrapper,
        and attach a single shared optimizer with learning rate ``lr``.
        """
        actor_net = self.actor_factory.create_module(envs, device)
        critic_net = self.critic_factory.create_module(
            envs,
            device,
            use_action=self.critic_use_action,
        )
        combined = ActorCritic(actor_net, critic_net)
        optimizer = self.optim_factory.create_optimizer(combined, lr)
        return ActorCriticModuleOpt(combined, optimizer)
class _ActorAndCriticMixin(_ActorMixin):
    """Mixin for agents with an actor and one critic, each having its own optimizer."""

    def __init__(
        self,
        actor_factory: ActorFactory,
        critic_factory: CriticFactory,
        optim_factory: OptimizerFactory,
        critic_use_action: bool,
    ):
        super().__init__(actor_factory, optim_factory)
        # critic_use_action selects Q(s, a) (True) vs V(s) (False) critics
        self.critic_module_opt_factory = CriticModuleOptFactory(
            critic_factory,
            optim_factory,
            critic_use_action,
        )

    def create_critic_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
        """Create the critic module together with its optimizer (learning rate ``lr``)."""
        factory = self.critic_module_opt_factory
        return factory.create_module_opt(envs, device, lr)
class _ActorAndDualCriticsMixin(_ActorAndCriticMixin):
    """Mixin for agents with an actor and two critics (e.g. SAC/TD3-style agents)."""

    def __init__(
        self,
        actor_factory: ActorFactory,
        critic_factory: CriticFactory,
        critic2_factory: CriticFactory,
        optim_factory: OptimizerFactory,
        critic_use_action: bool,
    ):
        super().__init__(actor_factory, critic_factory, optim_factory, critic_use_action)
        # Second critic mirrors the first: same optimizer factory and use_action flag.
        self.critic2_module_opt_factory = CriticModuleOptFactory(
            critic2_factory,
            optim_factory,
            critic_use_action,
        )

    def create_critic2_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
        """Create the second critic together with its optimizer (learning rate ``lr``)."""
        factory = self.critic2_module_opt_factory
        return factory.create_module_opt(envs, device, lr)
class PPOAgentFactory(OnpolicyAgentFactory, _ActorCriticMixin):
def __init__( def __init__(
self, self,
params: PPOParams, params: PPOParams,
@ -160,27 +256,29 @@ class PPOAgentFactory(OnpolicyAgentFactory):
dist_fn: Callable[[TDistParams], torch.distributions.Distribution], dist_fn: Callable[[TDistParams], torch.distributions.Distribution],
): ):
super().__init__(sampling_config) super().__init__(sampling_config)
self.optimizer_factory = optimizer_factory _ActorCriticMixin.__init__(
self.critic_factory = critic_factory self,
self.actor_factory = actor_factory actor_factory,
self.config = params critic_factory,
optimizer_factory,
critic_use_action=False,
)
self.params = params
self.dist_fn = dist_fn self.dist_fn = dist_fn
def create_policy(self, envs: Environments, device: TDevice) -> PPOPolicy: def create_policy(self, envs: Environments, device: TDevice) -> PPOPolicy:
actor = self.actor_factory.create_module(envs, device) actor_critic = self.create_actor_critic_module_opt(envs, device, self.params.lr)
critic = self.critic_factory.create_module(envs, device, use_action=False) kwargs = self.params.create_kwargs(
actor_critic = ActorCritic(actor, critic)
optim = self.optimizer_factory.create_optimizer(actor_critic, self.config.lr)
kwargs = self.config.create_kwargs(
ParamTransformerDrop("lr"), ParamTransformerDrop("lr"),
ParamTransformerLRScheduler(optim)) ParamTransformerLRScheduler(actor_critic.optim),
)
return PPOPolicy( return PPOPolicy(
actor=actor, actor=actor_critic.actor,
critic=critic, critic=actor_critic.critic,
optim=optim, optim=actor_critic.optim,
dist_fn=self.dist_fn, dist_fn=self.dist_fn,
action_space=envs.get_action_space(), action_space=envs.get_action_space(),
**kwargs **kwargs,
) )
@ -190,7 +288,7 @@ class ParamTransformerAlpha(ParamTransformer):
self.optim_factory = optim_factory self.optim_factory = optim_factory
self.device = device self.device = device
def transform(self, kwargs: Dict[str, Any]) -> None: def transform(self, kwargs: dict[str, Any]) -> None:
key = "alpha" key = "alpha"
alpha = self.get(kwargs, key) alpha = self.get(kwargs, key)
if isinstance(alpha, AutoAlphaFactory): if isinstance(alpha, AutoAlphaFactory):
@ -198,13 +296,17 @@ class ParamTransformerAlpha(ParamTransformer):
class ParamTransformerMultiLRScheduler(ParamTransformer): class ParamTransformerMultiLRScheduler(ParamTransformer):
def __init__(self, optim_key_list: List[Tuple[torch.optim.Optimizer, str]]): def __init__(self, optim_key_list: list[tuple[torch.optim.Optimizer, str]]):
self.optim_key_list = optim_key_list self.optim_key_list = optim_key_list
def transform(self, kwargs: Dict[str, Any]) -> None: def transform(self, kwargs: dict[str, Any]) -> None:
lr_schedulers = [] lr_schedulers = []
for optim, lr_scheduler_factory_key in self.optim_key_list: for optim, lr_scheduler_factory_key in self.optim_key_list:
lr_scheduler_factory: LRSchedulerFactory | None = self.get(kwargs, lr_scheduler_factory_key, drop=True) lr_scheduler_factory: LRSchedulerFactory | None = self.get(
kwargs,
lr_scheduler_factory_key,
drop=True,
)
if lr_scheduler_factory is not None: if lr_scheduler_factory is not None:
lr_schedulers.append(lr_scheduler_factory.create_scheduler(optim)) lr_schedulers.append(lr_scheduler_factory.create_scheduler(optim))
match len(lr_schedulers): match len(lr_schedulers):
@ -217,7 +319,7 @@ class ParamTransformerMultiLRScheduler(ParamTransformer):
kwargs["lr_scheduler"] = lr_scheduler kwargs["lr_scheduler"] = lr_scheduler
class SACAgentFactory(OffpolicyAgentFactory): class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
def __init__( def __init__(
self, self,
params: SACParams, params: SACParams,
@ -228,35 +330,114 @@ class SACAgentFactory(OffpolicyAgentFactory):
optim_factory: OptimizerFactory, optim_factory: OptimizerFactory,
): ):
super().__init__(sampling_config) super().__init__(sampling_config)
self.critic2_factory = critic2_factory _ActorAndDualCriticsMixin.__init__(
self.critic1_factory = critic1_factory self,
self.actor_factory = actor_factory actor_factory,
self.optim_factory = optim_factory critic1_factory,
critic2_factory,
optim_factory,
critic_use_action=True,
)
self.params = params self.params = params
self.optim_factory = optim_factory
def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy: def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
actor = self.actor_factory.create_module(envs, device) actor = self.create_actor_module_opt(envs, device, self.params.actor_lr)
critic1 = self.critic1_factory.create_module(envs, device, use_action=True) critic1 = self.create_critic_module_opt(envs, device, self.params.critic1_lr)
critic2 = self.critic2_factory.create_module(envs, device, use_action=True) critic2 = self.create_critic2_module_opt(envs, device, self.params.critic2_lr)
actor_optim = self.optim_factory.create_optimizer(actor, lr=self.params.actor_lr)
critic1_optim = self.optim_factory.create_optimizer(critic1, lr=self.params.critic1_lr)
critic2_optim = self.optim_factory.create_optimizer(critic2, lr=self.params.critic2_lr)
kwargs = self.params.create_kwargs( kwargs = self.params.create_kwargs(
ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"), ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"),
ParamTransformerMultiLRScheduler([ ParamTransformerMultiLRScheduler(
(actor_optim, "actor_lr_scheduler_factory"), [
(critic1_optim, "critic1_lr_scheduler_factory"), (actor.optim, "actor_lr_scheduler_factory"),
(critic2_optim, "critic2_lr_scheduler_factory")] (critic1.optim, "critic1_lr_scheduler_factory"),
(critic2.optim, "critic2_lr_scheduler_factory"),
],
), ),
ParamTransformerAlpha(envs, optim_factory=self.optim_factory, device=device)) ParamTransformerAlpha(envs, optim_factory=self.optim_factory, device=device),
)
return SACPolicy( return SACPolicy(
actor=actor, actor=actor.module,
actor_optim=actor_optim, actor_optim=actor.optim,
critic=critic1, critic=critic1.module,
critic_optim=critic1_optim, critic_optim=critic1.optim,
critic2=critic2, critic2=critic2.module,
critic2_optim=critic2_optim, critic2_optim=critic2.optim,
action_space=envs.get_action_space(), action_space=envs.get_action_space(),
observation_space=envs.get_observation_space(), observation_space=envs.get_observation_space(),
**kwargs **kwargs,
)
class ParamTransformerNoiseFactory(ParamTransformer):
    """Replaces a NoiseFactory under ``key`` with noise created for the given envs.

    Non-factory values are passed through unchanged.
    """

    def __init__(self, key: str, envs: Environments):
        self.key = key
        self.envs = envs

    def transform(self, kwargs: dict[str, Any]) -> None:
        if isinstance(value := kwargs[self.key], NoiseFactory):
            kwargs[self.key] = value.create_noise(self.envs)
class ParamTransformerFloatEnvParamFactory(ParamTransformer):
    """Replaces a FloatEnvParamFactory under ``key`` with a float derived from the envs.

    Non-factory values are passed through unchanged.
    """

    def __init__(self, key: str, envs: Environments):
        self.key = key
        self.envs = envs

    def transform(self, kwargs: dict[str, Any]) -> None:
        if isinstance(value := kwargs[self.key], FloatEnvParamFactory):
            kwargs[self.key] = value.create_param(self.envs)
class TD3AgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
    """Factory for TD3 agents: one actor and two critics, each with its own optimizer."""

    def __init__(
        self,
        params: TD3Params,
        sampling_config: RLSamplingConfig,
        actor_factory: ActorFactory,
        critic1_factory: CriticFactory,
        critic2_factory: CriticFactory,
        optim_factory: OptimizerFactory,
    ):
        super().__init__(sampling_config)
        _ActorAndDualCriticsMixin.__init__(
            self,
            actor_factory,
            critic1_factory,
            critic2_factory,
            optim_factory,
            critic_use_action=True,
        )
        self.params = params
        self.optim_factory = optim_factory

    def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
        """Instantiate a TD3Policy from the configured factories and parameters."""
        actor_opt = self.create_actor_module_opt(envs, device, self.params.actor_lr)
        critic1_opt = self.create_critic_module_opt(envs, device, self.params.critic1_lr)
        critic2_opt = self.create_critic2_module_opt(envs, device, self.params.critic2_lr)
        scheduler_specs = [
            (actor_opt.optim, "actor_lr_scheduler_factory"),
            (critic1_opt.optim, "critic1_lr_scheduler_factory"),
            (critic2_opt.optim, "critic2_lr_scheduler_factory"),
        ]
        # Learning rates were consumed by the optimizers above; the remaining
        # transformers resolve noise/env-scaled parameter factories against envs.
        kwargs = self.params.create_kwargs(
            ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"),
            ParamTransformerMultiLRScheduler(scheduler_specs),
            ParamTransformerNoiseFactory("exploration_noise", envs),
            ParamTransformerFloatEnvParamFactory("policy_noise", envs),
            ParamTransformerFloatEnvParamFactory("noise_clip", envs),
        )
        return TD3Policy(
            actor=actor_opt.module,
            actor_optim=actor_opt.optim,
            critic=critic1_opt.module,
            critic_optim=critic1_opt.optim,
            critic2=critic2_opt.module,
            critic2_optim=critic2_opt.optim,
            action_space=envs.get_action_space(),
            observation_space=envs.get_observation_space(),
            **kwargs,
        )

View File

@ -5,6 +5,7 @@ from dataclasses import dataclass
class RLSamplingConfig: class RLSamplingConfig:
"""Sampling, epochs, parallelization, buffers, collectors, and batching.""" """Sampling, epochs, parallelization, buffers, collectors, and batching."""
# TODO: What are reasonable defaults?
num_epochs: int = 100 num_epochs: int = 100
step_per_epoch: int = 30000 step_per_epoch: int = 30000
batch_size: int = 64 batch_size: int = 64

View File

@ -20,6 +20,14 @@ class EnvType(Enum):
def is_continuous(self): def is_continuous(self):
return self == EnvType.CONTINUOUS return self == EnvType.CONTINUOUS
def assert_continuous(self, requiring_entity: Any):
    """Raise an AssertionError if this env type is not continuous.

    :param requiring_entity: the object requiring continuity; included in the error message
    """
    if not self.is_continuous():
        raise AssertionError(f"{requiring_entity} requires continuous environments")
def assert_discrete(self, requiring_entity: Any):
    """Raise an AssertionError if this env type is not discrete.

    :param requiring_entity: the object requiring discreteness; included in the error message
    """
    if not self.is_discrete():
        raise AssertionError(f"{requiring_entity} requires discrete environments")
class Environments(ABC): class Environments(ABC):
def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv): def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
@ -28,7 +36,10 @@ class Environments(ABC):
self.test_envs = test_envs self.test_envs = test_envs
def info(self) -> dict[str, Any]: def info(self) -> dict[str, Any]:
return {"action_shape": self.get_action_shape(), "state_shape": self.get_observation_shape()} return {
"action_shape": self.get_action_shape(),
"state_shape": self.get_observation_shape(),
}
@abstractmethod @abstractmethod
def get_action_shape(self) -> TShape: def get_action_shape(self) -> TShape:
@ -81,7 +92,7 @@ class ContinuousEnvironments(Environments):
def get_observation_shape(self) -> TShape: def get_observation_shape(self) -> TShape:
return self.state_shape return self.state_shape
def get_type(self): def get_type(self) -> EnvType:
return EnvType.CONTINUOUS return EnvType.CONTINUOUS

View File

@ -1,25 +1,31 @@
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Sequence from collections.abc import Callable, Sequence
from dataclasses import dataclass from dataclasses import dataclass
from pprint import pprint from pprint import pprint
from typing import Generic, TypeVar, Callable from typing import Generic, Self, TypeVar
import numpy as np import numpy as np
import torch import torch
from tianshou.data import Collector from tianshou.data import Collector
from tianshou.highlevel.agent import AgentFactory, PPOAgentFactory, SACAgentFactory from tianshou.highlevel.agent import (
AgentFactory,
PPOAgentFactory,
SACAgentFactory,
TD3AgentFactory,
)
from tianshou.highlevel.config import RLSamplingConfig from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.env import EnvFactory from tianshou.highlevel.env import EnvFactory
from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory
from tianshou.highlevel.module import ( from tianshou.highlevel.module import (
ActorFactory, ActorFactory,
ContinuousActorType,
CriticFactory, CriticFactory,
DefaultActorFactory, DefaultActorFactory,
DefaultCriticFactory, DefaultCriticFactory,
) )
from tianshou.highlevel.optim import AdamOptimizerFactory, OptimizerFactory from tianshou.highlevel.optim import AdamOptimizerFactory, OptimizerFactory
from tianshou.highlevel.params.policy_params import PPOParams, SACParams from tianshou.highlevel.params.policy_params import PPOParams, SACParams, TD3Params
from tianshou.policy import BasePolicy from tianshou.policy import BasePolicy
from tianshou.policy.modelfree.pg import TDistParams from tianshou.policy.modelfree.pg import TDistParams
from tianshou.trainer import BaseTrainer from tianshou.trainer import BaseTrainer
@ -150,7 +156,10 @@ class RLExperimentBuilder:
return self return self
def with_optim_factory_default( def with_optim_factory_default(
self: TBuilder, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, self: TBuilder,
betas=(0.9, 0.999),
eps=1e-08,
weight_decay=0,
) -> TBuilder: ) -> TBuilder:
"""Configures the use of the default optimizer, Adam, with the given parameters. """Configures the use of the default optimizer, Adam, with the given parameters.
@ -174,12 +183,16 @@ class RLExperimentBuilder:
def build(self) -> RLExperiment: def build(self) -> RLExperiment:
return RLExperiment( return RLExperiment(
self._config, self._env_factory, self._create_agent_factory(), self._logger_factory, self._config,
self._env_factory,
self._create_agent_factory(),
self._logger_factory,
) )
class _BuilderMixinActorFactory: class _BuilderMixinActorFactory:
def __init__(self): def __init__(self, continuous_actor_type: ContinuousActorType):
self._continuous_actor_type = continuous_actor_type
self._actor_factory: ActorFactory | None = None self._actor_factory: ActorFactory | None = None
def with_actor_factory(self: TBuilder, actor_factory: ActorFactory) -> TBuilder: def with_actor_factory(self: TBuilder, actor_factory: ActorFactory) -> TBuilder:
@ -187,7 +200,7 @@ class _BuilderMixinActorFactory:
self._actor_factory = actor_factory self._actor_factory = actor_factory
return self return self
def with_actor_factory_default( def _with_actor_factory_default(
self: TBuilder, self: TBuilder,
hidden_sizes: Sequence[int], hidden_sizes: Sequence[int],
continuous_unbounded=False, continuous_unbounded=False,
@ -195,6 +208,7 @@ class _BuilderMixinActorFactory:
) -> TBuilder: ) -> TBuilder:
self: TBuilder | _BuilderMixinActorFactory self: TBuilder | _BuilderMixinActorFactory
self._actor_factory = DefaultActorFactory( self._actor_factory = DefaultActorFactory(
self._continuous_actor_type,
hidden_sizes, hidden_sizes,
continuous_unbounded=continuous_unbounded, continuous_unbounded=continuous_unbounded,
continuous_conditioned_sigma=continuous_conditioned_sigma, continuous_conditioned_sigma=continuous_conditioned_sigma,
@ -203,11 +217,40 @@ class _BuilderMixinActorFactory:
def _get_actor_factory(self): def _get_actor_factory(self):
if self._actor_factory is None: if self._actor_factory is None:
return DefaultActorFactory() return DefaultActorFactory(self._continuous_actor_type)
else: else:
return self._actor_factory return self._actor_factory
class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
    """Specialization of the actor mixin where, in the continuous case, the actor
    follows a Gaussian policy.

    (Fixes copy-pasted docstring which incorrectly said "deterministic policy";
    this class passes ContinuousActorType.GAUSSIAN to the base mixin.)
    """

    def __init__(self):
        super().__init__(ContinuousActorType.GAUSSIAN)

    def with_actor_factory_default(
        self,
        hidden_sizes: Sequence[int],
        continuous_unbounded=False,
        continuous_conditioned_sigma=False,
    ) -> Self:
        """Use the default actor factory with the given hidden layer sizes.

        :param hidden_sizes: hidden layer sizes of the actor network
        :param continuous_unbounded: whether the Gaussian mean output is unbounded
        :param continuous_conditioned_sigma: whether sigma is conditioned on the input
        :return: the builder (for chaining)
        """
        return super()._with_actor_factory_default(
            hidden_sizes,
            continuous_unbounded=continuous_unbounded,
            continuous_conditioned_sigma=continuous_conditioned_sigma,
        )
class _BuilderMixinActorFactory_ContinuousDeterministic(_BuilderMixinActorFactory):
    """Specialization of the actor mixin where, in the continuous case, the actor uses a deterministic policy."""

    def __init__(self):
        super().__init__(ContinuousActorType.DETERMINISTIC)

    def with_actor_factory_default(self, hidden_sizes: Sequence[int]) -> Self:
        # Deterministic actors take no distribution options
        # (no unbounded/conditioned-sigma parameters as in the Gaussian case).
        return super()._with_actor_factory_default(hidden_sizes)
class _BuilderMixinCriticsFactory: class _BuilderMixinCriticsFactory:
def __init__(self, num_critics: int): def __init__(self, num_critics: int):
self._critic_factories: list[CriticFactory | None] = [None] * num_critics self._critic_factories: list[CriticFactory | None] = [None] * num_critics
@ -238,7 +281,8 @@ class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
return self return self
def with_critic_factory_default( def with_critic_factory_default(
self: TBuilder, hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES, self: TBuilder,
hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
) -> TBuilder: ) -> TBuilder:
self: TBuilder | "_BuilderMixinSingleCriticFactory" self: TBuilder | "_BuilderMixinSingleCriticFactory"
self._with_critic_factory_default(0, hidden_sizes) self._with_critic_factory_default(0, hidden_sizes)
@ -256,7 +300,8 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
return self return self
def with_common_critic_factory_default( def with_common_critic_factory_default(
self, hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES, self,
hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
) -> TBuilder: ) -> TBuilder:
self: TBuilder | "_BuilderMixinDualCriticFactory" self: TBuilder | "_BuilderMixinDualCriticFactory"
for i in range(len(self._critic_factories)): for i in range(len(self._critic_factories)):
@ -269,7 +314,8 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
return self return self
def with_critic1_factory_default( def with_critic1_factory_default(
self, hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES, self,
hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
) -> TBuilder: ) -> TBuilder:
self: TBuilder | "_BuilderMixinDualCriticFactory" self: TBuilder | "_BuilderMixinDualCriticFactory"
self._with_critic_factory_default(0, hidden_sizes) self._with_critic_factory_default(0, hidden_sizes)
@ -281,7 +327,8 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
return self return self
def with_critic2_factory_default( def with_critic2_factory_default(
self, hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES, self,
hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
) -> TBuilder: ) -> TBuilder:
self: TBuilder | "_BuilderMixinDualCriticFactory" self: TBuilder | "_BuilderMixinDualCriticFactory"
self._with_critic_factory_default(0, hidden_sizes) self._with_critic_factory_default(0, hidden_sizes)
@ -289,7 +336,9 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
class PPOExperimentBuilder( class PPOExperimentBuilder(
RLExperimentBuilder, _BuilderMixinActorFactory, _BuilderMixinSingleCriticFactory, RLExperimentBuilder,
_BuilderMixinActorFactory_ContinuousGaussian,
_BuilderMixinSingleCriticFactory,
): ):
def __init__( def __init__(
self, self,
@ -299,12 +348,12 @@ class PPOExperimentBuilder(
dist_fn: Callable[[TDistParams], torch.distributions.Distribution], dist_fn: Callable[[TDistParams], torch.distributions.Distribution],
): ):
super().__init__(experiment_config, env_factory, sampling_config) super().__init__(experiment_config, env_factory, sampling_config)
_BuilderMixinActorFactory.__init__(self) _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
_BuilderMixinSingleCriticFactory.__init__(self) _BuilderMixinSingleCriticFactory.__init__(self)
self._params: PPOParams = PPOParams() self._params: PPOParams = PPOParams()
self._dist_fn = dist_fn self._dist_fn = dist_fn
def with_ppo_params(self, params: PPOParams) -> "PPOExperimentBuilder": def with_ppo_params(self, params: PPOParams) -> Self:
self._params = params self._params = params
return self return self
@ -316,12 +365,14 @@ class PPOExperimentBuilder(
self._get_actor_factory(), self._get_actor_factory(),
self._get_critic_factory(0), self._get_critic_factory(0),
self._get_optim_factory(), self._get_optim_factory(),
self._dist_fn self._dist_fn,
) )
class SACExperimentBuilder( class SACExperimentBuilder(
RLExperimentBuilder, _BuilderMixinActorFactory, _BuilderMixinDualCriticFactory, RLExperimentBuilder,
_BuilderMixinActorFactory_ContinuousGaussian,
_BuilderMixinDualCriticFactory,
): ):
def __init__( def __init__(
self, self,
@ -330,14 +381,51 @@ class SACExperimentBuilder(
sampling_config: RLSamplingConfig, sampling_config: RLSamplingConfig,
): ):
super().__init__(experiment_config, env_factory, sampling_config) super().__init__(experiment_config, env_factory, sampling_config)
_BuilderMixinActorFactory.__init__(self) _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
_BuilderMixinDualCriticFactory.__init__(self) _BuilderMixinDualCriticFactory.__init__(self)
self._params: SACParams = SACParams() self._params: SACParams = SACParams()
def with_sac_params(self, params: SACParams) -> "SACExperimentBuilder": def with_sac_params(self, params: SACParams) -> Self:
self._params = params self._params = params
return self return self
def _create_agent_factory(self) -> AgentFactory: def _create_agent_factory(self) -> AgentFactory:
return SACAgentFactory(self._params, self._sampling_config, self._get_actor_factory(), return SACAgentFactory(
self._get_critic_factory(0), self._get_critic_factory(1), self._get_optim_factory()) self._params,
self._sampling_config,
self._get_actor_factory(),
self._get_critic_factory(0),
self._get_critic_factory(1),
self._get_optim_factory(),
)
class TD3ExperimentBuilder(
    RLExperimentBuilder,
    _BuilderMixinActorFactory_ContinuousDeterministic,
    _BuilderMixinDualCriticFactory,
):
    """Experiment builder for TD3: deterministic continuous actor plus two critics."""

    def __init__(
        self,
        experiment_config: RLExperimentConfig,
        env_factory: EnvFactory,
        sampling_config: RLSamplingConfig,
    ):
        super().__init__(experiment_config, env_factory, sampling_config)
        _BuilderMixinActorFactory_ContinuousDeterministic.__init__(self)
        _BuilderMixinDualCriticFactory.__init__(self)
        self._params: TD3Params = TD3Params()

    def with_td3_params(self, params: TD3Params) -> Self:
        """Set the TD3 hyperparameters, replacing the defaults.

        :return: the builder (for chaining)
        """
        self._params = params
        return self

    def _create_agent_factory(self) -> AgentFactory:
        actor_factory = self._get_actor_factory()
        critic1_factory = self._get_critic_factory(0)
        critic2_factory = self._get_critic_factory(1)
        return TD3AgentFactory(
            self._params,
            self._sampling_config,
            actor_factory,
            critic1_factory,
            critic2_factory,
            self._get_optim_factory(),
        )

View File

@ -1,13 +1,13 @@
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import Literal from typing import Literal, TypeAlias
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils import TensorboardLogger, WandbLogger
TLogger = TensorboardLogger | WandbLogger TLogger: TypeAlias = TensorboardLogger | WandbLogger
@dataclass @dataclass
@ -30,7 +30,7 @@ class DefaultLoggerFactory(LoggerFactory):
wandb_project: str | None = None, wandb_project: str | None = None,
): ):
if logger_type == "wandb" and wandb_project is None: if logger_type == "wandb" and wandb_project is None:
raise ValueError("Must provide 'wand_project'") raise ValueError("Must provide 'wandb_project'")
self.log_dir = log_dir self.log_dir = log_dir
self.logger_type = logger_type self.logger_type = logger_type
self.wandb_project = wandb_project self.wandb_project = wandb_project

View File

@ -1,16 +1,18 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import dataclass
from typing import TypeAlias
import numpy as np import numpy as np
import torch import torch
from torch import nn from torch import nn
from tianshou.highlevel.env import Environments, EnvType from tianshou.highlevel.env import Environments, EnvType
from tianshou.utils.net.common import Net from tianshou.highlevel.optim import OptimizerFactory
from tianshou.utils.net.continuous import ActorProb from tianshou.utils.net import continuous
from tianshou.utils.net.continuous import Critic as ContinuousCritic from tianshou.utils.net.common import ActorCritic, Net
TDevice = str | int | torch.device TDevice: TypeAlias = str | int | torch.device
def init_linear_orthogonal(module: torch.nn.Module): def init_linear_orthogonal(module: torch.nn.Module):
@ -24,6 +26,11 @@ def init_linear_orthogonal(module: torch.nn.Module):
torch.nn.init.zeros_(m.bias) torch.nn.init.zeros_(m.bias)
class ContinuousActorType:
    """String constants identifying the kind of actor to build for continuous action spaces."""

    GAUSSIAN = "gaussian"
    DETERMINISTIC = "deterministic"
class ActorFactory(ABC): class ActorFactory(ABC):
@abstractmethod @abstractmethod
def create_module(self, envs: Environments, device: TDevice) -> nn.Module: def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
@ -47,30 +54,36 @@ class ActorFactory(ABC):
class DefaultActorFactory(ActorFactory): class DefaultActorFactory(ActorFactory):
"""An actor factory which, depending on the type of environment, creates a suitable MLP-based policy."""
DEFAULT_HIDDEN_SIZES = (64, 64) DEFAULT_HIDDEN_SIZES = (64, 64)
def __init__( def __init__(
self, self,
continuous_actor_type: ContinuousActorType,
hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES, hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES,
continuous_unbounded=False, continuous_unbounded=False,
continuous_conditioned_sigma=False, continuous_conditioned_sigma=False,
): ):
self.continuous_actor_type = continuous_actor_type
self.continuous_unbounded = continuous_unbounded self.continuous_unbounded = continuous_unbounded
self.continuous_conditioned_sigma = continuous_conditioned_sigma self.continuous_conditioned_sigma = continuous_conditioned_sigma
self.hidden_sizes = hidden_sizes self.hidden_sizes = hidden_sizes
"""
An actor factory which, depending on the type of environment, creates a suitable MLP-based policy
"""
def create_module(self, envs: Environments, device: TDevice) -> nn.Module: def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
env_type = envs.get_type() env_type = envs.get_type()
if env_type == EnvType.CONTINUOUS: if env_type == EnvType.CONTINUOUS:
factory = ContinuousActorProbFactory( match self.continuous_actor_type:
self.hidden_sizes, case ContinuousActorType.GAUSSIAN:
unbounded=self.continuous_unbounded, factory = ContinuousActorFactoryGaussian(
conditioned_sigma=self.continuous_conditioned_sigma, self.hidden_sizes,
) unbounded=self.continuous_unbounded,
conditioned_sigma=self.continuous_conditioned_sigma,
)
case ContinuousActorType.DETERMINISTIC:
factory = ContinuousActorFactoryDeterministic(self.hidden_sizes)
case _:
raise ValueError(self.continuous_actor_type)
return factory.create_module(envs, device) return factory.create_module(envs, device)
elif env_type == EnvType.DISCRETE: elif env_type == EnvType.DISCRETE:
raise NotImplementedError raise NotImplementedError
@ -82,8 +95,25 @@ class ContinuousActorFactory(ActorFactory, ABC):
"""Serves as a type bound for actor factories that are suitable for continuous action spaces.""" """Serves as a type bound for actor factories that are suitable for continuous action spaces."""
class ContinuousActorFactoryDeterministic(ContinuousActorFactory):
def __init__(self, hidden_sizes: Sequence[int]):
self.hidden_sizes = hidden_sizes
class ContinuousActorProbFactory(ContinuousActorFactory): def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
net_a = Net(
envs.get_observation_shape(),
hidden_sizes=self.hidden_sizes,
device=device,
)
return continuous.Actor(
net_a,
envs.get_action_shape(),
hidden_sizes=(),
device=device,
).to(device)
class ContinuousActorFactoryGaussian(ContinuousActorFactory):
def __init__(self, hidden_sizes: Sequence[int], unbounded=True, conditioned_sigma=False): def __init__(self, hidden_sizes: Sequence[int], unbounded=True, conditioned_sigma=False):
self.hidden_sizes = hidden_sizes self.hidden_sizes = hidden_sizes
self.unbounded = unbounded self.unbounded = unbounded
@ -96,7 +126,7 @@ class ContinuousActorProbFactory(ContinuousActorFactory):
activation=nn.Tanh, activation=nn.Tanh,
device=device, device=device,
) )
actor = ActorProb( actor = continuous.ActorProb(
net_a, net_a,
envs.get_action_shape(), envs.get_action_shape(),
unbounded=self.unbounded, unbounded=self.unbounded,
@ -155,6 +185,54 @@ class ContinuousNetCriticFactory(ContinuousCriticFactory):
activation=nn.Tanh, activation=nn.Tanh,
device=device, device=device,
) )
critic = ContinuousCritic(net_c, device=device).to(device) critic = continuous.Critic(net_c, device=device).to(device)
init_linear_orthogonal(critic) init_linear_orthogonal(critic)
return critic return critic
@dataclass
class ModuleOpt:
    """Container pairing a torch module with the optimizer that updates its parameters."""

    # the network being optimized
    module: torch.nn.Module
    # the optimizer over the module's parameters
    optim: torch.optim.Optimizer
@dataclass
class ActorCriticModuleOpt:
    """Container pairing a combined actor-critic module with a single optimizer over both."""

    # the combined module holding both actor and critic networks
    actor_critic_module: ActorCritic
    # a single optimizer covering the combined module's parameters
    optim: torch.optim.Optimizer

    @property
    def actor(self):
        """The actor submodule of the combined actor-critic module."""
        return self.actor_critic_module.actor

    @property
    def critic(self):
        """The critic submodule of the combined actor-critic module."""
        return self.actor_critic_module.critic
class ActorModuleOptFactory:
    """Creates an actor module together with an optimizer for its parameters."""

    def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory):
        # factories delegated to when building the module/optimizer pair
        self.actor_factory = actor_factory
        self.optim_factory = optim_factory

    def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
        """Build the actor for the given environments and pair it with an optimizer.

        :param envs: the environments, which determine the module's input/output shapes
        :param device: the torch device on which to create the module
        :param lr: the learning rate for the optimizer
        :return: the actor module together with its optimizer
        """
        module = self.actor_factory.create_module(envs, device)
        return ModuleOpt(module, self.optim_factory.create_optimizer(module, lr))
class CriticModuleOptFactory:
    """Creates a critic module together with an optimizer for its parameters."""

    def __init__(
        self,
        critic_factory: CriticFactory,
        optim_factory: OptimizerFactory,
        use_action: bool,
    ):
        # factories delegated to when building the module/optimizer pair
        self.critic_factory = critic_factory
        self.optim_factory = optim_factory
        # whether the critic receives the action as an additional input
        self.use_action = use_action

    def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
        """Build the critic for the given environments and pair it with an optimizer.

        :param envs: the environments, which determine the module's input shapes
        :param device: the torch device on which to create the module
        :param lr: the learning rate for the optimizer
        :return: the critic module together with its optimizer
        """
        module = self.critic_factory.create_module(envs, device, self.use_action)
        return ModuleOpt(module, self.optim_factory.create_optimizer(module, lr))

View File

@ -1,13 +1,9 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import Any from typing import Any
import torch import torch
from torch import Tensor
from torch.optim import Adam from torch.optim import Adam
TParams = Iterable[Tensor] | Iterable[dict[str, Any]]
class OptimizerFactory(ABC): class OptimizerFactory(ABC):
@abstractmethod @abstractmethod
@ -38,5 +34,3 @@ class AdamOptimizerFactory(OptimizerFactory):
eps=self.eps, eps=self.eps,
weight_decay=self.weight_decay, weight_decay=self.weight_decay,
) )

View File

@ -0,0 +1,24 @@
"""Factories for the generation of environment-dependent parameters."""
from abc import ABC, abstractmethod
from typing import TypeVar
from tianshou.highlevel.env import ContinuousEnvironments, Environments
T = TypeVar("T")
class FloatEnvParamFactory(ABC):
    """Factory for a float parameter whose value depends on the environments being used."""

    @abstractmethod
    def create_param(self, envs: Environments) -> float:
        """Compute the parameter value for the given environments."""
        pass
class MaxActionScaledFloatEnvParamFactory(FloatEnvParamFactory):
    """Produces a float parameter by scaling the environments' maximum action value.

    Applicable to continuous action spaces only.
    """

    def __init__(self, value: float):
        """:param value: value with which to scale the max action value"""
        self.value = value

    def create_param(self, envs: Environments) -> float:
        # fail early if the environments are not continuous (max_action is undefined otherwise)
        envs.get_type().assert_continuous(self)
        envs: ContinuousEnvironments  # narrow the type for static checkers after the runtime check
        return envs.max_action * self.value

View File

@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
import numpy as np import numpy as np
import torch import torch
from torch.optim.lr_scheduler import LRScheduler, LambdaLR from torch.optim.lr_scheduler import LambdaLR, LRScheduler
from tianshou.highlevel.config import RLSamplingConfig from tianshou.highlevel.config import RLSamplingConfig

View File

@ -0,0 +1,25 @@
from abc import ABC, abstractmethod
from tianshou.exploration import BaseNoise, GaussianNoise
from tianshou.highlevel.env import ContinuousEnvironments, Environments
class NoiseFactory(ABC):
    """Factory for exploration noise whose configuration depends on the environments being used."""

    @abstractmethod
    def create_noise(self, envs: Environments) -> BaseNoise:
        """Create the noise instance for the given environments."""
        pass
class MaxActionScaledGaussianNoiseFactory(NoiseFactory):
    """Factory for Gaussian noise where the standard deviation is a fraction of the maximum action value.

    This factory can only be applied to continuous action spaces.
    """

    def __init__(self, std_fraction: float):
        # fraction of the environments' max_action to use as the Gaussian sigma
        self.std_fraction = std_fraction

    def create_noise(self, envs: Environments) -> BaseNoise:
        # fail early if the environments are not continuous (max_action is undefined otherwise)
        envs.get_type().assert_continuous(self)
        envs: ContinuousEnvironments  # narrow the type for static checkers after the runtime check
        return GaussianNoise(sigma=envs.max_action * self.std_fraction)

View File

@ -1,21 +1,23 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, asdict from dataclasses import asdict, dataclass
from typing import Dict, Any, Literal from typing import Any, Literal
import torch import torch
from tianshou.exploration import BaseNoise from tianshou.exploration import BaseNoise
from tianshou.highlevel.params.alpha import AutoAlphaFactory from tianshou.highlevel.params.alpha import AutoAlphaFactory
from tianshou.highlevel.params.env_param import FloatEnvParamFactory
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
from tianshou.highlevel.params.noise import NoiseFactory
class ParamTransformer(ABC): class ParamTransformer(ABC):
@abstractmethod @abstractmethod
def transform(self, kwargs: Dict[str, Any]) -> None: def transform(self, kwargs: dict[str, Any]) -> None:
pass pass
@staticmethod @staticmethod
def get(d: Dict[str, Any], key: str, drop: bool = False) -> Any: def get(d: dict[str, Any], key: str, drop: bool = False) -> Any:
value = d[key] value = d[key]
if drop: if drop:
del d[key] del d[key]
@ -24,7 +26,7 @@ class ParamTransformer(ABC):
@dataclass @dataclass
class Params: class Params:
def create_kwargs(self, *transformers: ParamTransformer) -> Dict[str, Any]: def create_kwargs(self, *transformers: ParamTransformer) -> dict[str, Any]:
d = asdict(self) d = asdict(self)
for transformer in transformers: for transformer in transformers:
transformer.transform(d) transformer.transform(d)
@ -34,6 +36,7 @@ class Params:
@dataclass @dataclass
class PGParams(Params): class PGParams(Params):
"""Config of general policy-gradient algorithms.""" """Config of general policy-gradient algorithms."""
discount_factor: float = 0.99 discount_factor: float = 0.99
reward_normalization: bool = False reward_normalization: bool = False
deterministic_eval: bool = False deterministic_eval: bool = False
@ -53,6 +56,7 @@ class A2CParams(PGParams):
@dataclass @dataclass
class PPOParams(A2CParams): class PPOParams(A2CParams):
"""PPO specific config.""" """PPO specific config."""
eps_clip: float = 0.2 eps_clip: float = 0.2
dual_clip: float | None = None dual_clip: float | None = None
value_clip: bool = False value_clip: bool = False
@ -63,7 +67,17 @@ class PPOParams(A2CParams):
@dataclass @dataclass
class SACParams(Params): class ActorAndDualCriticsParams(Params):
actor_lr: float = 1e-3
critic1_lr: float = 1e-3
critic2_lr: float = 1e-3
actor_lr_scheduler_factory: LRSchedulerFactory | None = None
critic1_lr_scheduler_factory: LRSchedulerFactory | None = None
critic2_lr_scheduler_factory: LRSchedulerFactory | None = None
@dataclass
class SACParams(ActorAndDualCriticsParams):
tau: float = 0.005 tau: float = 0.005
gamma: float = 0.99 gamma: float = 0.99
alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] | AutoAlphaFactory = 0.2 alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] | AutoAlphaFactory = 0.2
@ -72,9 +86,16 @@ class SACParams(Params):
deterministic_eval: bool = True deterministic_eval: bool = True
action_scaling: bool = True action_scaling: bool = True
action_bound_method: Literal["clip"] | None = "clip" action_bound_method: Literal["clip"] | None = "clip"
actor_lr: float = 1e-3
critic1_lr: float = 1e-3
critic2_lr: float = 1e-3 @dataclass
actor_lr_scheduler_factory: LRSchedulerFactory | None = None class TD3Params(ActorAndDualCriticsParams):
critic1_lr_scheduler_factory: LRSchedulerFactory | None = None tau: float = 0.005
critic2_lr_scheduler_factory: LRSchedulerFactory | None = None gamma: float = 0.99
exploration_noise: BaseNoise | Literal["default"] | NoiseFactory | None = "default"
policy_noise: float | FloatEnvParamFactory = 0.2
noise_clip: float | FloatEnvParamFactory = 0.5
update_actor_freq: int = 2
estimation_step: int = 1
action_scaling: bool = True
action_bound_method: Literal["clip"] | None = "clip"