From e993425aa1efe1bc3779eb16a4aafe4e0dd4c650 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 26 Sep 2023 15:35:18 +0200 Subject: [PATCH] Add high-level API support for TD3 * Created mixins for agent factories to reduce code duplication * Further factorised params & mixins for experiment factories * Additional parameter abstractions * Implement high-level MuJoCo TD3 example --- examples/mujoco/mujoco_ppo_hl.py | 2 +- examples/mujoco/mujoco_sac_hl.py | 8 +- examples/mujoco/mujoco_td3_hl.py | 85 +++++++ tianshou/highlevel/agent.py | 281 +++++++++++++++++---- tianshou/highlevel/config.py | 1 + tianshou/highlevel/env.py | 15 +- tianshou/highlevel/experiment.py | 132 ++++++++-- tianshou/highlevel/logger.py | 6 +- tianshou/highlevel/module.py | 110 ++++++-- tianshou/highlevel/optim.py | 6 - tianshou/highlevel/params/env_param.py | 24 ++ tianshou/highlevel/params/lr_scheduler.py | 2 +- tianshou/highlevel/params/noise.py | 25 ++ tianshou/highlevel/params/policy_params.py | 45 +++- 14 files changed, 626 insertions(+), 116 deletions(-) create mode 100644 examples/mujoco/mujoco_td3_hl.py create mode 100644 tianshou/highlevel/params/env_param.py create mode 100644 tianshou/highlevel/params/noise.py diff --git a/examples/mujoco/mujoco_ppo_hl.py b/examples/mujoco/mujoco_ppo_hl.py index f4ef4fc..125c081 100644 --- a/examples/mujoco/mujoco_ppo_hl.py +++ b/examples/mujoco/mujoco_ppo_hl.py @@ -66,7 +66,7 @@ def main( experiment = ( PPOExperimentBuilder(experiment_config, env_factory, sampling_config, dist_fn) - .with_ppo_params( + .with_params( PPOParams( discount_factor=gamma, gae_lambda=gae_lambda, diff --git a/examples/mujoco/mujoco_sac_hl.py b/examples/mujoco/mujoco_sac_hl.py index 8f4ddd0..bd4282d 100644 --- a/examples/mujoco/mujoco_sac_hl.py +++ b/examples/mujoco/mujoco_sac_hl.py @@ -7,13 +7,13 @@ from collections.abc import Sequence from jsonargparse import CLI from examples.mujoco.mujoco_env import MujocoEnvFactory -from tianshou.highlevel.params.policy_params import SACParams -from tianshou.highlevel.params.alpha import DefaultAutoAlphaFactory from tianshou.highlevel.config import RLSamplingConfig from tianshou.highlevel.experiment import ( RLExperimentConfig, SACExperimentBuilder, ) +from tianshou.highlevel.params.alpha import DefaultAutoAlphaFactory +from tianshou.highlevel.params.policy_params import SACParams def main( @@ -70,7 +70,9 @@ def main( ), ) .with_actor_factory_default( - hidden_sizes, continuous_unbounded=True, continuous_conditioned_sigma=True, + hidden_sizes, + continuous_unbounded=True, + continuous_conditioned_sigma=True, ) .with_common_critic_factory_default(hidden_sizes) .build() diff --git a/examples/mujoco/mujoco_td3_hl.py b/examples/mujoco/mujoco_td3_hl.py new file mode 100644 index 0000000..760a980 --- /dev/null +++ b/examples/mujoco/mujoco_td3_hl.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python3 + +import datetime +import os +from collections.abc import Sequence + +from jsonargparse import CLI + +from examples.mujoco.mujoco_env import MujocoEnvFactory +from tianshou.highlevel.config import RLSamplingConfig +from tianshou.highlevel.experiment import ( + RLExperimentConfig, + TD3ExperimentBuilder, +) +from tianshou.highlevel.params.env_param import MaxActionScaledFloatEnvParamFactory +from tianshou.highlevel.params.noise import MaxActionScaledGaussianNoiseFactory +from tianshou.highlevel.params.policy_params import TD3Params + + +def main( + experiment_config: RLExperimentConfig, + task: str = "Ant-v3", + buffer_size: int = 1000000, + hidden_sizes: Sequence[int] = (256, 
256), + actor_lr: float = 3e-4, + critic_lr: float = 3e-4, + gamma: float = 0.99, + tau: float = 0.005, + exploration_noise: float = 0.1, + policy_noise: float = 0.2, + noise_clip: float = 0.5, + update_actor_freq: int = 2, + start_timesteps: int = 25000, + epoch: int = 200, + step_per_epoch: int = 5000, + step_per_collect: int = 1, + update_per_step: int = 1, + n_step: int = 1, + batch_size: int = 256, + training_num: int = 1, + test_num: int = 10, +): + now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") + log_name = os.path.join(task, "td3", str(experiment_config.seed), now) + + sampling_config = RLSamplingConfig( + num_epochs=epoch, + step_per_epoch=step_per_epoch, + num_train_envs=training_num, + num_test_envs=test_num, + buffer_size=buffer_size, + batch_size=batch_size, + step_per_collect=step_per_collect, + update_per_step=update_per_step, + start_timesteps=start_timesteps, + start_timesteps_random=True, + ) + + env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config) + + experiment = ( + TD3ExperimentBuilder(experiment_config, env_factory, sampling_config) + .with_td3_params( + TD3Params( + tau=tau, + gamma=gamma, + estimation_step=n_step, + update_actor_freq=update_actor_freq, + noise_clip=MaxActionScaledFloatEnvParamFactory(noise_clip), + policy_noise=MaxActionScaledFloatEnvParamFactory(policy_noise), + exploration_noise=MaxActionScaledGaussianNoiseFactory(exploration_noise), + actor_lr=actor_lr, + critic1_lr=critic_lr, + critic2_lr=critic_lr, + ), + ) + .with_actor_factory_default(hidden_sizes) + .with_common_critic_factory_default(hidden_sizes) + .build() + ) + experiment.run(log_name) + + +if __name__ == "__main__": + CLI(main) diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py index 98ec59e..b435f7b 100644 --- a/tianshou/highlevel/agent.py +++ b/tianshou/highlevel/agent.py @@ -1,21 +1,35 @@ import os from abc import ABC, abstractmethod from collections.abc import Callable -from typing import Dict, Any, List, Tuple +from typing import Any import torch from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer -from tianshou.exploration import BaseNoise from tianshou.highlevel.config import RLSamplingConfig from tianshou.highlevel.env import Environments from tianshou.highlevel.logger import Logger -from tianshou.highlevel.module import ActorFactory, CriticFactory, TDevice +from tianshou.highlevel.module import ( + ActorCriticModuleOpt, + ActorFactory, + ActorModuleOptFactory, + CriticFactory, + CriticModuleOptFactory, + ModuleOpt, + TDevice, +) from tianshou.highlevel.optim import OptimizerFactory from tianshou.highlevel.params.alpha import AutoAlphaFactory +from tianshou.highlevel.params.env_param import FloatEnvParamFactory from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory -from tianshou.highlevel.params.policy_params import PPOParams, ParamTransformer, SACParams -from tianshou.policy import BasePolicy, PPOPolicy, SACPolicy +from tianshou.highlevel.params.noise import NoiseFactory +from tianshou.highlevel.params.policy_params import ( + ParamTransformer, + PPOParams, + SACParams, + TD3Params, +) +from tianshou.policy import BasePolicy, PPOPolicy, SACPolicy, TD3Policy from tianshou.policy.modelfree.pg import TDistParams from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer from tianshou.utils import MultipleLRSchedulers @@ -135,7 +149,7 @@ class ParamTransformerDrop(ParamTransformer): def __init__(self, *keys: str): self.keys = keys - def transform(self, kwargs: Dict[str, Any]) 
-> None: + def transform(self, kwargs: dict[str, Any]) -> None: for k in self.keys: del kwargs[k] @@ -144,12 +158,94 @@ class ParamTransformerLRScheduler(ParamTransformer): def __init__(self, optim: torch.optim.Optimizer): self.optim = optim - def transform(self, kwargs: Dict[str, Any]) -> None: + def transform(self, kwargs: dict[str, Any]) -> None: factory: LRSchedulerFactory | None = self.get(kwargs, "lr_scheduler_factory", drop=True) - kwargs["lr_scheduler"] = factory.create_scheduler(self.optim) if factory is not None else None + kwargs["lr_scheduler"] = ( + factory.create_scheduler(self.optim) if factory is not None else None + ) -class PPOAgentFactory(OnpolicyAgentFactory): +class _ActorMixin: + def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory): + self.actor_module_opt_factory = ActorModuleOptFactory(actor_factory, optim_factory) + + def create_actor_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt: + return self.actor_module_opt_factory.create_module_opt(envs, device, lr) + + +class _ActorCriticMixin: + """Mixin for agents that use an ActorCritic module with a single optimizer.""" + + def __init__( + self, + actor_factory: ActorFactory, + critic_factory: CriticFactory, + optim_factory: OptimizerFactory, + critic_use_action: bool, + ): + self.actor_factory = actor_factory + self.critic_factory = critic_factory + self.optim_factory = optim_factory + self.critic_use_action = critic_use_action + + def create_actor_critic_module_opt( + self, + envs: Environments, + device: TDevice, + lr: float, + ) -> ActorCriticModuleOpt: + actor = self.actor_factory.create_module(envs, device) + critic = self.critic_factory.create_module(envs, device, use_action=self.critic_use_action) + actor_critic = ActorCritic(actor, critic) + optim = self.optim_factory.create_optimizer(actor_critic, lr) + return ActorCriticModuleOpt(actor_critic, optim) + + +class _ActorAndCriticMixin(_ActorMixin): + def __init__( + self, + actor_factory: ActorFactory, + critic_factory: CriticFactory, + optim_factory: OptimizerFactory, + critic_use_action: bool, + ): + super().__init__(actor_factory, optim_factory) + self.critic_module_opt_factory = CriticModuleOptFactory( + critic_factory, + optim_factory, + critic_use_action, + ) + + def create_critic_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt: + return self.critic_module_opt_factory.create_module_opt(envs, device, lr) + + +class _ActorAndDualCriticsMixin(_ActorAndCriticMixin): + def __init__( + self, + actor_factory: ActorFactory, + critic_factory: CriticFactory, + critic2_factory: CriticFactory, + optim_factory: OptimizerFactory, + critic_use_action: bool, + ): + super().__init__(actor_factory, critic_factory, optim_factory, critic_use_action) + self.critic2_module_opt_factory = CriticModuleOptFactory( + critic2_factory, + optim_factory, + critic_use_action, + ) + + def create_critic2_module_opt( + self, + envs: Environments, + device: TDevice, + lr: float, + ) -> ModuleOpt: + return self.critic2_module_opt_factory.create_module_opt(envs, device, lr) + + +class PPOAgentFactory(OnpolicyAgentFactory, _ActorCriticMixin): def __init__( self, params: PPOParams, @@ -160,27 +256,29 @@ class PPOAgentFactory(OnpolicyAgentFactory): dist_fn: Callable[[TDistParams], torch.distributions.Distribution], ): super().__init__(sampling_config) - self.optimizer_factory = optimizer_factory - self.critic_factory = critic_factory - self.actor_factory = actor_factory - self.config = params + 
_ActorCriticMixin.__init__( + self, + actor_factory, + critic_factory, + optimizer_factory, + critic_use_action=False, + ) + self.params = params self.dist_fn = dist_fn def create_policy(self, envs: Environments, device: TDevice) -> PPOPolicy: - actor = self.actor_factory.create_module(envs, device) - critic = self.critic_factory.create_module(envs, device, use_action=False) - actor_critic = ActorCritic(actor, critic) - optim = self.optimizer_factory.create_optimizer(actor_critic, self.config.lr) - kwargs = self.config.create_kwargs( + actor_critic = self.create_actor_critic_module_opt(envs, device, self.params.lr) + kwargs = self.params.create_kwargs( ParamTransformerDrop("lr"), - ParamTransformerLRScheduler(optim)) + ParamTransformerLRScheduler(actor_critic.optim), + ) return PPOPolicy( - actor=actor, - critic=critic, - optim=optim, + actor=actor_critic.actor, + critic=actor_critic.critic, + optim=actor_critic.optim, dist_fn=self.dist_fn, action_space=envs.get_action_space(), - **kwargs + **kwargs, ) @@ -190,7 +288,7 @@ class ParamTransformerAlpha(ParamTransformer): self.optim_factory = optim_factory self.device = device - def transform(self, kwargs: Dict[str, Any]) -> None: + def transform(self, kwargs: dict[str, Any]) -> None: key = "alpha" alpha = self.get(kwargs, key) if isinstance(alpha, AutoAlphaFactory): @@ -198,13 +296,17 @@ class ParamTransformerAlpha(ParamTransformer): class ParamTransformerMultiLRScheduler(ParamTransformer): - def __init__(self, optim_key_list: List[Tuple[torch.optim.Optimizer, str]]): + def __init__(self, optim_key_list: list[tuple[torch.optim.Optimizer, str]]): self.optim_key_list = optim_key_list - def transform(self, kwargs: Dict[str, Any]) -> None: + def transform(self, kwargs: dict[str, Any]) -> None: lr_schedulers = [] for optim, lr_scheduler_factory_key in self.optim_key_list: - lr_scheduler_factory: LRSchedulerFactory | None = self.get(kwargs, lr_scheduler_factory_key, drop=True) + lr_scheduler_factory: LRSchedulerFactory | None = self.get( + kwargs, + lr_scheduler_factory_key, + drop=True, + ) if lr_scheduler_factory is not None: lr_schedulers.append(lr_scheduler_factory.create_scheduler(optim)) match len(lr_schedulers): @@ -217,7 +319,7 @@ class ParamTransformerMultiLRScheduler(ParamTransformer): kwargs["lr_scheduler"] = lr_scheduler -class SACAgentFactory(OffpolicyAgentFactory): +class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin): def __init__( self, params: SACParams, @@ -228,35 +330,114 @@ class SACAgentFactory(OffpolicyAgentFactory): optim_factory: OptimizerFactory, ): super().__init__(sampling_config) - self.critic2_factory = critic2_factory - self.critic1_factory = critic1_factory - self.actor_factory = actor_factory - self.optim_factory = optim_factory + _ActorAndDualCriticsMixin.__init__( + self, + actor_factory, + critic1_factory, + critic2_factory, + optim_factory, + critic_use_action=True, + ) self.params = params + self.optim_factory = optim_factory def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy: - actor = self.actor_factory.create_module(envs, device) - critic1 = self.critic1_factory.create_module(envs, device, use_action=True) - critic2 = self.critic2_factory.create_module(envs, device, use_action=True) - actor_optim = self.optim_factory.create_optimizer(actor, lr=self.params.actor_lr) - critic1_optim = self.optim_factory.create_optimizer(critic1, lr=self.params.critic1_lr) - critic2_optim = self.optim_factory.create_optimizer(critic2, lr=self.params.critic2_lr) + actor = 
self.create_actor_module_opt(envs, device, self.params.actor_lr) + critic1 = self.create_critic_module_opt(envs, device, self.params.critic1_lr) + critic2 = self.create_critic2_module_opt(envs, device, self.params.critic2_lr) kwargs = self.params.create_kwargs( ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"), - ParamTransformerMultiLRScheduler([ - (actor_optim, "actor_lr_scheduler_factory"), - (critic1_optim, "critic1_lr_scheduler_factory"), - (critic2_optim, "critic2_lr_scheduler_factory")] + ParamTransformerMultiLRScheduler( + [ + (actor.optim, "actor_lr_scheduler_factory"), + (critic1.optim, "critic1_lr_scheduler_factory"), + (critic2.optim, "critic2_lr_scheduler_factory"), + ], ), - ParamTransformerAlpha(envs, optim_factory=self.optim_factory, device=device)) + ParamTransformerAlpha(envs, optim_factory=self.optim_factory, device=device), + ) return SACPolicy( - actor=actor, - actor_optim=actor_optim, - critic=critic1, - critic_optim=critic1_optim, - critic2=critic2, - critic2_optim=critic2_optim, + actor=actor.module, + actor_optim=actor.optim, + critic=critic1.module, + critic_optim=critic1.optim, + critic2=critic2.module, + critic2_optim=critic2.optim, action_space=envs.get_action_space(), observation_space=envs.get_observation_space(), - **kwargs + **kwargs, + ) + + +class ParamTransformerNoiseFactory(ParamTransformer): + def __init__(self, key: str, envs: Environments): + self.key = key + self.envs = envs + + def transform(self, kwargs: dict[str, Any]) -> None: + value = kwargs[self.key] + if isinstance(value, NoiseFactory): + kwargs[self.key] = value.create_noise(self.envs) + + +class ParamTransformerFloatEnvParamFactory(ParamTransformer): + def __init__(self, key: str, envs: Environments): + self.key = key + self.envs = envs + + def transform(self, kwargs: dict[str, Any]) -> None: + value = kwargs[self.key] + if isinstance(value, FloatEnvParamFactory): + kwargs[self.key] = value.create_param(self.envs) + + +class TD3AgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin): + def __init__( + self, + params: TD3Params, + sampling_config: RLSamplingConfig, + actor_factory: ActorFactory, + critic1_factory: CriticFactory, + critic2_factory: CriticFactory, + optim_factory: OptimizerFactory, + ): + super().__init__(sampling_config) + _ActorAndDualCriticsMixin.__init__( + self, + actor_factory, + critic1_factory, + critic2_factory, + optim_factory, + critic_use_action=True, + ) + self.params = params + self.optim_factory = optim_factory + + def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy: + actor = self.create_actor_module_opt(envs, device, self.params.actor_lr) + critic1 = self.create_critic_module_opt(envs, device, self.params.critic1_lr) + critic2 = self.create_critic2_module_opt(envs, device, self.params.critic2_lr) + kwargs = self.params.create_kwargs( + ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"), + ParamTransformerMultiLRScheduler( + [ + (actor.optim, "actor_lr_scheduler_factory"), + (critic1.optim, "critic1_lr_scheduler_factory"), + (critic2.optim, "critic2_lr_scheduler_factory"), + ], + ), + ParamTransformerNoiseFactory("exploration_noise", envs), + ParamTransformerFloatEnvParamFactory("policy_noise", envs), + ParamTransformerFloatEnvParamFactory("noise_clip", envs), + ) + return TD3Policy( + actor=actor.module, + actor_optim=actor.optim, + critic=critic1.module, + critic_optim=critic1.optim, + critic2=critic2.module, + critic2_optim=critic2.optim, + action_space=envs.get_action_space(), + 
observation_space=envs.get_observation_space(), + **kwargs, ) diff --git a/tianshou/highlevel/config.py b/tianshou/highlevel/config.py index b188f36..de5d247 100644 --- a/tianshou/highlevel/config.py +++ b/tianshou/highlevel/config.py @@ -5,6 +5,7 @@ from dataclasses import dataclass class RLSamplingConfig: """Sampling, epochs, parallelization, buffers, collectors, and batching.""" + # TODO: What are reasonable defaults? num_epochs: int = 100 step_per_epoch: int = 30000 batch_size: int = 64 diff --git a/tianshou/highlevel/env.py b/tianshou/highlevel/env.py index bf5ddb6..8f4bd0c 100644 --- a/tianshou/highlevel/env.py +++ b/tianshou/highlevel/env.py @@ -20,6 +20,14 @@ class EnvType(Enum): def is_continuous(self): return self == EnvType.CONTINUOUS + def assert_continuous(self, requiring_entity: Any): + if not self.is_continuous(): + raise AssertionError(f"{requiring_entity} requires continuous environments") + + def assert_discrete(self, requiring_entity: Any): + if not self.is_discrete(): + raise AssertionError(f"{requiring_entity} requires discrete environments") + class Environments(ABC): def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv): @@ -28,7 +36,10 @@ class Environments(ABC): self.test_envs = test_envs def info(self) -> dict[str, Any]: - return {"action_shape": self.get_action_shape(), "state_shape": self.get_observation_shape()} + return { + "action_shape": self.get_action_shape(), + "state_shape": self.get_observation_shape(), + } @abstractmethod def get_action_shape(self) -> TShape: @@ -81,7 +92,7 @@ class ContinuousEnvironments(Environments): def get_observation_shape(self) -> TShape: return self.state_shape - def get_type(self): + def get_type(self) -> EnvType: return EnvType.CONTINUOUS diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index e02d613..c483d76 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -1,25 +1,31 @@ from abc import abstractmethod -from collections.abc import Sequence +from collections.abc import Callable, Sequence from dataclasses import dataclass from pprint import pprint -from typing import Generic, TypeVar, Callable +from typing import Generic, Self, TypeVar import numpy as np import torch from tianshou.data import Collector -from tianshou.highlevel.agent import AgentFactory, PPOAgentFactory, SACAgentFactory +from tianshou.highlevel.agent import ( + AgentFactory, + PPOAgentFactory, + SACAgentFactory, + TD3AgentFactory, +) from tianshou.highlevel.config import RLSamplingConfig from tianshou.highlevel.env import EnvFactory from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory from tianshou.highlevel.module import ( ActorFactory, + ContinuousActorType, CriticFactory, DefaultActorFactory, DefaultCriticFactory, ) from tianshou.highlevel.optim import AdamOptimizerFactory, OptimizerFactory -from tianshou.highlevel.params.policy_params import PPOParams, SACParams +from tianshou.highlevel.params.policy_params import PPOParams, SACParams, TD3Params from tianshou.policy import BasePolicy from tianshou.policy.modelfree.pg import TDistParams from tianshou.trainer import BaseTrainer @@ -150,7 +156,10 @@ class RLExperimentBuilder: return self def with_optim_factory_default( - self: TBuilder, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, + self: TBuilder, + betas=(0.9, 0.999), + eps=1e-08, + weight_decay=0, ) -> TBuilder: """Configures the use of the default optimizer, Adam, with the given parameters. 
@@ -174,12 +183,16 @@ class RLExperimentBuilder:
 
     def build(self) -> RLExperiment:
         return RLExperiment(
-            self._config, self._env_factory, self._create_agent_factory(), self._logger_factory,
+            self._config,
+            self._env_factory,
+            self._create_agent_factory(),
+            self._logger_factory,
         )
 
 
 class _BuilderMixinActorFactory:
-    def __init__(self):
+    def __init__(self, continuous_actor_type: ContinuousActorType):
+        self._continuous_actor_type = continuous_actor_type
         self._actor_factory: ActorFactory | None = None
 
     def with_actor_factory(self: TBuilder, actor_factory: ActorFactory) -> TBuilder:
@@ -187,7 +200,7 @@ class _BuilderMixinActorFactory:
         self._actor_factory = actor_factory
         return self
 
-    def with_actor_factory_default(
+    def _with_actor_factory_default(
         self: TBuilder,
         hidden_sizes: Sequence[int],
         continuous_unbounded=False,
@@ -195,6 +208,7 @@ class _BuilderMixinActorFactory:
     ) -> TBuilder:
         self: TBuilder | _BuilderMixinActorFactory
         self._actor_factory = DefaultActorFactory(
+            self._continuous_actor_type,
             hidden_sizes,
             continuous_unbounded=continuous_unbounded,
             continuous_conditioned_sigma=continuous_conditioned_sigma,
@@ -203,11 +217,40 @@ class _BuilderMixinActorFactory:
 
     def _get_actor_factory(self):
         if self._actor_factory is None:
-            return DefaultActorFactory()
+            return DefaultActorFactory(self._continuous_actor_type)
         else:
             return self._actor_factory
 
 
+class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
+    """Specialization of the actor mixin where, in the continuous case, the actor uses a Gaussian (probabilistic) policy."""
+
+    def __init__(self):
+        super().__init__(ContinuousActorType.GAUSSIAN)
+
+    def with_actor_factory_default(
+        self,
+        hidden_sizes: Sequence[int],
+        continuous_unbounded=False,
+        continuous_conditioned_sigma=False,
+    ) -> Self:
+        return super()._with_actor_factory_default(
+            hidden_sizes,
+            continuous_unbounded=continuous_unbounded,
+            continuous_conditioned_sigma=continuous_conditioned_sigma,
+        )
+
+
+class _BuilderMixinActorFactory_ContinuousDeterministic(_BuilderMixinActorFactory):
+    """Specialization of the actor mixin where, in the continuous case, the actor uses a deterministic policy."""
+
+    def __init__(self):
+        super().__init__(ContinuousActorType.DETERMINISTIC)
+
+    def with_actor_factory_default(self, hidden_sizes: Sequence[int]) -> Self:
+        return super()._with_actor_factory_default(hidden_sizes)
+
+
 class _BuilderMixinCriticsFactory:
     def __init__(self, num_critics: int):
         self._critic_factories: list[CriticFactory | None] = [None] * num_critics
@@ -238,7 +281,8 @@ class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
         return self
 
     def with_critic_factory_default(
-        self: TBuilder, hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
+        self: TBuilder,
+        hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
     ) -> TBuilder:
         self: TBuilder | "_BuilderMixinSingleCriticFactory"
         self._with_critic_factory_default(0, hidden_sizes)
@@ -256,7 +300,8 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
         return self
 
     def with_common_critic_factory_default(
-        self, hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
+        self,
+        hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
     ) -> TBuilder:
         self: TBuilder | "_BuilderMixinDualCriticFactory"
         for i in range(len(self._critic_factories)):
@@ -269,7 +314,8 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
         return self
 
     def with_critic1_factory_default(
-        self, hidden_sizes: Sequence[int] = 
DefaultCriticFactory.DEFAULT_HIDDEN_SIZES, + self, + hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES, ) -> TBuilder: self: TBuilder | "_BuilderMixinDualCriticFactory" self._with_critic_factory_default(0, hidden_sizes) @@ -281,7 +327,8 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory): return self def with_critic2_factory_default( - self, hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES, + self, + hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES, ) -> TBuilder: self: TBuilder | "_BuilderMixinDualCriticFactory" self._with_critic_factory_default(0, hidden_sizes) @@ -289,7 +336,9 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory): class PPOExperimentBuilder( - RLExperimentBuilder, _BuilderMixinActorFactory, _BuilderMixinSingleCriticFactory, + RLExperimentBuilder, + _BuilderMixinActorFactory_ContinuousGaussian, + _BuilderMixinSingleCriticFactory, ): def __init__( self, @@ -299,12 +348,12 @@ class PPOExperimentBuilder( dist_fn: Callable[[TDistParams], torch.distributions.Distribution], ): super().__init__(experiment_config, env_factory, sampling_config) - _BuilderMixinActorFactory.__init__(self) + _BuilderMixinActorFactory_ContinuousGaussian.__init__(self) _BuilderMixinSingleCriticFactory.__init__(self) self._params: PPOParams = PPOParams() self._dist_fn = dist_fn - def with_ppo_params(self, params: PPOParams) -> "PPOExperimentBuilder": + def with_ppo_params(self, params: PPOParams) -> Self: self._params = params return self @@ -316,12 +365,14 @@ class PPOExperimentBuilder( self._get_actor_factory(), self._get_critic_factory(0), self._get_optim_factory(), - self._dist_fn + self._dist_fn, ) class SACExperimentBuilder( - RLExperimentBuilder, _BuilderMixinActorFactory, _BuilderMixinDualCriticFactory, + RLExperimentBuilder, + _BuilderMixinActorFactory_ContinuousGaussian, + _BuilderMixinDualCriticFactory, ): def __init__( self, @@ -330,14 +381,51 @@ class SACExperimentBuilder( sampling_config: RLSamplingConfig, ): super().__init__(experiment_config, env_factory, sampling_config) - _BuilderMixinActorFactory.__init__(self) + _BuilderMixinActorFactory_ContinuousGaussian.__init__(self) _BuilderMixinDualCriticFactory.__init__(self) self._params: SACParams = SACParams() - def with_sac_params(self, params: SACParams) -> "SACExperimentBuilder": + def with_sac_params(self, params: SACParams) -> Self: self._params = params return self def _create_agent_factory(self) -> AgentFactory: - return SACAgentFactory(self._params, self._sampling_config, self._get_actor_factory(), - self._get_critic_factory(0), self._get_critic_factory(1), self._get_optim_factory()) + return SACAgentFactory( + self._params, + self._sampling_config, + self._get_actor_factory(), + self._get_critic_factory(0), + self._get_critic_factory(1), + self._get_optim_factory(), + ) + + +class TD3ExperimentBuilder( + RLExperimentBuilder, + _BuilderMixinActorFactory_ContinuousDeterministic, + _BuilderMixinDualCriticFactory, +): + def __init__( + self, + experiment_config: RLExperimentConfig, + env_factory: EnvFactory, + sampling_config: RLSamplingConfig, + ): + super().__init__(experiment_config, env_factory, sampling_config) + _BuilderMixinActorFactory_ContinuousDeterministic.__init__(self) + _BuilderMixinDualCriticFactory.__init__(self) + self._params: TD3Params = TD3Params() + + def with_td3_params(self, params: TD3Params) -> Self: + self._params = params + return self + + def _create_agent_factory(self) -> AgentFactory: + return 
TD3AgentFactory( + self._params, + self._sampling_config, + self._get_actor_factory(), + self._get_critic_factory(0), + self._get_critic_factory(1), + self._get_optim_factory(), + ) diff --git a/tianshou/highlevel/logger.py b/tianshou/highlevel/logger.py index 69db740..c556f4f 100644 --- a/tianshou/highlevel/logger.py +++ b/tianshou/highlevel/logger.py @@ -1,13 +1,13 @@ import os from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Literal +from typing import Literal, TypeAlias from torch.utils.tensorboard import SummaryWriter from tianshou.utils import TensorboardLogger, WandbLogger -TLogger = TensorboardLogger | WandbLogger +TLogger: TypeAlias = TensorboardLogger | WandbLogger @dataclass @@ -30,7 +30,7 @@ class DefaultLoggerFactory(LoggerFactory): wandb_project: str | None = None, ): if logger_type == "wandb" and wandb_project is None: - raise ValueError("Must provide 'wand_project'") + raise ValueError("Must provide 'wandb_project'") self.log_dir = log_dir self.logger_type = logger_type self.wandb_project = wandb_project diff --git a/tianshou/highlevel/module.py b/tianshou/highlevel/module.py index 979c88e..e3e7f42 100644 --- a/tianshou/highlevel/module.py +++ b/tianshou/highlevel/module.py @@ -1,16 +1,18 @@ from abc import ABC, abstractmethod from collections.abc import Sequence +from dataclasses import dataclass +from typing import TypeAlias import numpy as np import torch from torch import nn from tianshou.highlevel.env import Environments, EnvType -from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ActorProb -from tianshou.utils.net.continuous import Critic as ContinuousCritic +from tianshou.highlevel.optim import OptimizerFactory +from tianshou.utils.net import continuous +from tianshou.utils.net.common import ActorCritic, Net -TDevice = str | int | torch.device +TDevice: TypeAlias = str | int | torch.device def init_linear_orthogonal(module: torch.nn.Module): @@ -24,6 +26,11 @@ def init_linear_orthogonal(module: torch.nn.Module): torch.nn.init.zeros_(m.bias) +class ContinuousActorType: + GAUSSIAN = "gaussian" + DETERMINISTIC = "deterministic" + + class ActorFactory(ABC): @abstractmethod def create_module(self, envs: Environments, device: TDevice) -> nn.Module: @@ -47,30 +54,36 @@ class ActorFactory(ABC): class DefaultActorFactory(ActorFactory): + """An actor factory which, depending on the type of environment, creates a suitable MLP-based policy.""" + DEFAULT_HIDDEN_SIZES = (64, 64) def __init__( self, + continuous_actor_type: ContinuousActorType, hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES, continuous_unbounded=False, continuous_conditioned_sigma=False, ): + self.continuous_actor_type = continuous_actor_type self.continuous_unbounded = continuous_unbounded self.continuous_conditioned_sigma = continuous_conditioned_sigma self.hidden_sizes = hidden_sizes - """ - An actor factory which, depending on the type of environment, creates a suitable MLP-based policy - """ - def create_module(self, envs: Environments, device: TDevice) -> nn.Module: env_type = envs.get_type() if env_type == EnvType.CONTINUOUS: - factory = ContinuousActorProbFactory( - self.hidden_sizes, - unbounded=self.continuous_unbounded, - conditioned_sigma=self.continuous_conditioned_sigma, - ) + match self.continuous_actor_type: + case ContinuousActorType.GAUSSIAN: + factory = ContinuousActorFactoryGaussian( + self.hidden_sizes, + unbounded=self.continuous_unbounded, + conditioned_sigma=self.continuous_conditioned_sigma, + ) + case 
ContinuousActorType.DETERMINISTIC: + factory = ContinuousActorFactoryDeterministic(self.hidden_sizes) + case _: + raise ValueError(self.continuous_actor_type) return factory.create_module(envs, device) elif env_type == EnvType.DISCRETE: raise NotImplementedError @@ -82,8 +95,25 @@ class ContinuousActorFactory(ActorFactory, ABC): """Serves as a type bound for actor factories that are suitable for continuous action spaces.""" +class ContinuousActorFactoryDeterministic(ContinuousActorFactory): + def __init__(self, hidden_sizes: Sequence[int]): + self.hidden_sizes = hidden_sizes -class ContinuousActorProbFactory(ContinuousActorFactory): + def create_module(self, envs: Environments, device: TDevice) -> nn.Module: + net_a = Net( + envs.get_observation_shape(), + hidden_sizes=self.hidden_sizes, + device=device, + ) + return continuous.Actor( + net_a, + envs.get_action_shape(), + hidden_sizes=(), + device=device, + ).to(device) + + +class ContinuousActorFactoryGaussian(ContinuousActorFactory): def __init__(self, hidden_sizes: Sequence[int], unbounded=True, conditioned_sigma=False): self.hidden_sizes = hidden_sizes self.unbounded = unbounded @@ -96,7 +126,7 @@ class ContinuousActorProbFactory(ContinuousActorFactory): activation=nn.Tanh, device=device, ) - actor = ActorProb( + actor = continuous.ActorProb( net_a, envs.get_action_shape(), unbounded=self.unbounded, @@ -155,6 +185,54 @@ class ContinuousNetCriticFactory(ContinuousCriticFactory): activation=nn.Tanh, device=device, ) - critic = ContinuousCritic(net_c, device=device).to(device) + critic = continuous.Critic(net_c, device=device).to(device) init_linear_orthogonal(critic) return critic + + +@dataclass +class ModuleOpt: + module: torch.nn.Module + optim: torch.optim.Optimizer + + +@dataclass +class ActorCriticModuleOpt: + actor_critic_module: ActorCritic + optim: torch.optim.Optimizer + + @property + def actor(self): + return self.actor_critic_module.actor + + @property + def critic(self): + return self.actor_critic_module.critic + + +class ActorModuleOptFactory: + def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory): + self.actor_factory = actor_factory + self.optim_factory = optim_factory + + def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt: + actor = self.actor_factory.create_module(envs, device) + opt = self.optim_factory.create_optimizer(actor, lr) + return ModuleOpt(actor, opt) + + +class CriticModuleOptFactory: + def __init__( + self, + critic_factory: CriticFactory, + optim_factory: OptimizerFactory, + use_action: bool, + ): + self.critic_factory = critic_factory + self.optim_factory = optim_factory + self.use_action = use_action + + def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt: + critic = self.critic_factory.create_module(envs, device, self.use_action) + opt = self.optim_factory.create_optimizer(critic, lr) + return ModuleOpt(critic, opt) diff --git a/tianshou/highlevel/optim.py b/tianshou/highlevel/optim.py index a2a0be1..d63ce6b 100644 --- a/tianshou/highlevel/optim.py +++ b/tianshou/highlevel/optim.py @@ -1,13 +1,9 @@ from abc import ABC, abstractmethod -from collections.abc import Iterable from typing import Any import torch -from torch import Tensor from torch.optim import Adam -TParams = Iterable[Tensor] | Iterable[dict[str, Any]] - class OptimizerFactory(ABC): @abstractmethod @@ -38,5 +34,3 @@ class AdamOptimizerFactory(OptimizerFactory): eps=self.eps, weight_decay=self.weight_decay, ) - - diff --git 
a/tianshou/highlevel/params/env_param.py b/tianshou/highlevel/params/env_param.py new file mode 100644 index 0000000..8798d3a --- /dev/null +++ b/tianshou/highlevel/params/env_param.py @@ -0,0 +1,24 @@ +"""Factories for the generation of environment-dependent parameters.""" +from abc import ABC, abstractmethod +from typing import TypeVar + +from tianshou.highlevel.env import ContinuousEnvironments, Environments + +T = TypeVar("T") + + +class FloatEnvParamFactory(ABC): + @abstractmethod + def create_param(self, envs: Environments) -> float: + pass + + +class MaxActionScaledFloatEnvParamFactory(FloatEnvParamFactory): + def __init__(self, value: float): + """:param value: value with which to scale the max action value""" + self.value = value + + def create_param(self, envs: Environments) -> float: + envs.get_type().assert_continuous(self) + envs: ContinuousEnvironments + return envs.max_action * self.value diff --git a/tianshou/highlevel/params/lr_scheduler.py b/tianshou/highlevel/params/lr_scheduler.py index 80699cd..5c412d8 100644 --- a/tianshou/highlevel/params/lr_scheduler.py +++ b/tianshou/highlevel/params/lr_scheduler.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod import numpy as np import torch -from torch.optim.lr_scheduler import LRScheduler, LambdaLR +from torch.optim.lr_scheduler import LambdaLR, LRScheduler from tianshou.highlevel.config import RLSamplingConfig diff --git a/tianshou/highlevel/params/noise.py b/tianshou/highlevel/params/noise.py new file mode 100644 index 0000000..8017cd3 --- /dev/null +++ b/tianshou/highlevel/params/noise.py @@ -0,0 +1,25 @@ +from abc import ABC, abstractmethod + +from tianshou.exploration import BaseNoise, GaussianNoise +from tianshou.highlevel.env import ContinuousEnvironments, Environments + + +class NoiseFactory(ABC): + @abstractmethod + def create_noise(self, envs: Environments) -> BaseNoise: + pass + + +class MaxActionScaledGaussianNoiseFactory(NoiseFactory): + """Factory for Gaussian noise where the standard deviation is a fraction of the maximum action value. + + This factory can only be applied to continuous action spaces. 
+ """ + + def __init__(self, std_fraction: float): + self.std_fraction = std_fraction + + def create_noise(self, envs: Environments) -> BaseNoise: + envs.get_type().assert_continuous(self) + envs: ContinuousEnvironments + return GaussianNoise(sigma=envs.max_action * self.std_fraction) diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py index 2df996b..03a0417 100644 --- a/tianshou/highlevel/params/policy_params.py +++ b/tianshou/highlevel/params/policy_params.py @@ -1,21 +1,23 @@ from abc import ABC, abstractmethod -from dataclasses import dataclass, asdict -from typing import Dict, Any, Literal +from dataclasses import asdict, dataclass +from typing import Any, Literal import torch from tianshou.exploration import BaseNoise from tianshou.highlevel.params.alpha import AutoAlphaFactory +from tianshou.highlevel.params.env_param import FloatEnvParamFactory from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory +from tianshou.highlevel.params.noise import NoiseFactory class ParamTransformer(ABC): @abstractmethod - def transform(self, kwargs: Dict[str, Any]) -> None: + def transform(self, kwargs: dict[str, Any]) -> None: pass @staticmethod - def get(d: Dict[str, Any], key: str, drop: bool = False) -> Any: + def get(d: dict[str, Any], key: str, drop: bool = False) -> Any: value = d[key] if drop: del d[key] @@ -24,7 +26,7 @@ class ParamTransformer(ABC): @dataclass class Params: - def create_kwargs(self, *transformers: ParamTransformer) -> Dict[str, Any]: + def create_kwargs(self, *transformers: ParamTransformer) -> dict[str, Any]: d = asdict(self) for transformer in transformers: transformer.transform(d) @@ -34,6 +36,7 @@ class Params: @dataclass class PGParams(Params): """Config of general policy-gradient algorithms.""" + discount_factor: float = 0.99 reward_normalization: bool = False deterministic_eval: bool = False @@ -53,6 +56,7 @@ class A2CParams(PGParams): @dataclass class PPOParams(A2CParams): """PPO specific config.""" + eps_clip: float = 0.2 dual_clip: float | None = None value_clip: bool = False @@ -63,7 +67,17 @@ class PPOParams(A2CParams): @dataclass -class SACParams(Params): +class ActorAndDualCriticsParams(Params): + actor_lr: float = 1e-3 + critic1_lr: float = 1e-3 + critic2_lr: float = 1e-3 + actor_lr_scheduler_factory: LRSchedulerFactory | None = None + critic1_lr_scheduler_factory: LRSchedulerFactory | None = None + critic2_lr_scheduler_factory: LRSchedulerFactory | None = None + + +@dataclass +class SACParams(ActorAndDualCriticsParams): tau: float = 0.005 gamma: float = 0.99 alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] | AutoAlphaFactory = 0.2 @@ -72,9 +86,16 @@ class SACParams(Params): deterministic_eval: bool = True action_scaling: bool = True action_bound_method: Literal["clip"] | None = "clip" - actor_lr: float = 1e-3 - critic1_lr: float = 1e-3 - critic2_lr: float = 1e-3 - actor_lr_scheduler_factory: LRSchedulerFactory | None = None - critic1_lr_scheduler_factory: LRSchedulerFactory | None = None - critic2_lr_scheduler_factory: LRSchedulerFactory | None = None + + +@dataclass +class TD3Params(ActorAndDualCriticsParams): + tau: float = 0.005 + gamma: float = 0.99 + exploration_noise: BaseNoise | Literal["default"] | NoiseFactory | None = "default" + policy_noise: float | FloatEnvParamFactory = 0.2 + noise_clip: float | FloatEnvParamFactory = 0.5 + update_actor_freq: int = 2 + estimation_step: int = 1 + action_scaling: bool = True + action_bound_method: Literal["clip"] | None = "clip"