From 17ef4dd5ebede9afc800b9c9fe92d7e5fe05805f Mon Sep 17 00:00:00 2001
From: Dominik Jain
Date: Tue, 10 Oct 2023 15:49:05 +0200
Subject: [PATCH] Support REDQ in high-level API

* Implement example mujoco_redq_hl
* Add abstraction CriticEnsembleFactory with default implementations to suit REDQ
* Fix type annotation of linear_layer in Net, MLP, Critic (was incompatible with REDQ usage)
---
 examples/mujoco/mujoco_ddpg_hl.py          |   2 +-
 examples/mujoco/mujoco_redq_hl.py          |  87 ++++++++++++++++
 tianshou/highlevel/agent.py                |  56 +++++++++-
 tianshou/highlevel/experiment.py           |  61 ++++++++++-
 tianshou/highlevel/module/critic.py        | 116 ++++++++++++++++++---
 tianshou/highlevel/params/policy_params.py |  16 +++
 tianshou/utils/net/common.py               |   8 +-
 tianshou/utils/net/continuous.py           |   4 +-
 8 files changed, 327 insertions(+), 23 deletions(-)
 create mode 100644 examples/mujoco/mujoco_redq_hl.py

diff --git a/examples/mujoco/mujoco_ddpg_hl.py b/examples/mujoco/mujoco_ddpg_hl.py
index 5763d9d..d4e676a 100644
--- a/examples/mujoco/mujoco_ddpg_hl.py
+++ b/examples/mujoco/mujoco_ddpg_hl.py
@@ -38,7 +38,7 @@ def main(
     test_num: int = 10,
 ):
     now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
-    log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
+    log_name = os.path.join(task, "ddpg", str(experiment_config.seed), now)
 
     sampling_config = SamplingConfig(
         num_epochs=epoch,
diff --git a/examples/mujoco/mujoco_redq_hl.py b/examples/mujoco/mujoco_redq_hl.py
new file mode 100644
index 0000000..8a924ac
--- /dev/null
+++ b/examples/mujoco/mujoco_redq_hl.py
@@ -0,0 +1,87 @@
+#!/usr/bin/env python3
+
+import os
+from collections.abc import Sequence
+from typing import Literal
+
+from jsonargparse import CLI
+
+from examples.mujoco.mujoco_env import MujocoEnvFactory
+from tianshou.highlevel.config import SamplingConfig
+from tianshou.highlevel.experiment import (
+    ExperimentConfig,
+    REDQExperimentBuilder,
+)
+from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault
+from tianshou.highlevel.params.policy_params import REDQParams
+from tianshou.utils import logging
+from tianshou.utils.logging import datetime_tag
+
+
+def main(
+    experiment_config: ExperimentConfig,
+    task: str = "Ant-v3",
+    buffer_size: int = 1000000,
+    hidden_sizes: Sequence[int] = (256, 256),
+    ensemble_size: int = 10,
+    subset_size: int = 2,
+    actor_lr: float = 1e-3,
+    critic_lr: float = 1e-3,
+    gamma: float = 0.99,
+    tau: float = 0.005,
+    alpha: float = 0.2,
+    auto_alpha: bool = False,
+    alpha_lr: float = 3e-4,
+    start_timesteps: int = 10000,
+    epoch: int = 200,
+    step_per_epoch: int = 5000,
+    step_per_collect: int = 1,
+    update_per_step: int = 20,
+    n_step: int = 1,
+    batch_size: int = 256,
+    target_mode: Literal["mean", "min"] = "min",
+    training_num: int = 1,
+    test_num: int = 10,
+):
+    log_name = os.path.join(task, "redq", str(experiment_config.seed), datetime_tag())
+
+    sampling_config = SamplingConfig(
+        num_epochs=epoch,
+        step_per_epoch=step_per_epoch,
+        batch_size=batch_size,
+        num_train_envs=training_num,
+        num_test_envs=test_num,
+        buffer_size=buffer_size,
+        step_per_collect=step_per_collect,
+        update_per_step=update_per_step,
+        repeat_per_collect=None,
+        start_timesteps=start_timesteps,
+        start_timesteps_random=True,
+    )
+
+    env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
+
+    experiment = (
+        REDQExperimentBuilder(env_factory, experiment_config, sampling_config)
+        .with_redq_params(
+            REDQParams(
+                actor_lr=actor_lr,
+                critic_lr=critic_lr,
+                gamma=gamma,
+                tau=tau,
+                alpha=AutoAlphaFactoryDefault(lr=alpha_lr) if auto_alpha else alpha,
+                estimation_step=n_step,
+                target_mode=target_mode,
+                subset_size=subset_size,
+                ensemble_size=ensemble_size,
+            ),
+        )
+        .with_actor_factory_default(hidden_sizes)
+        .with_critic_ensemble_factory_default(hidden_sizes)
+        .build()
+    )
+    experiment.run(log_name)
+
+
+if __name__ == "__main__":
+    logging.run_main(lambda: CLI(main))
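Usage note: since the example exposes `main` through jsonargparse's `CLI`, every parameter of `main` can be overridden from the command line, e.g. `python examples/mujoco/mujoco_redq_hl.py --task Ant-v3 --auto_alpha true --epoch 200` (flag spellings assumed from jsonargparse's default argument handling).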
diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py
index ea4afc1..f92dc66 100644
--- a/tianshou/highlevel/agent.py
+++ b/tianshou/highlevel/agent.py
@@ -14,7 +14,7 @@ from tianshou.highlevel.module.actor import (
     ActorFactory,
 )
 from tianshou.highlevel.module.core import TDevice
-from tianshou.highlevel.module.critic import CriticFactory
+from tianshou.highlevel.module.critic import CriticEnsembleFactory, CriticFactory
 from tianshou.highlevel.module.module_opt import (
     ActorCriticModuleOpt,
 )
@@ -28,6 +28,7 @@ from tianshou.highlevel.params.policy_params import (
     ParamTransformerData,
     PGParams,
     PPOParams,
+    REDQParams,
     SACParams,
     TD3Params,
     TRPOParams,
@@ -42,6 +43,7 @@ from tianshou.policy import (
     NPGPolicy,
     PGPolicy,
     PPOPolicy,
+    REDQPolicy,
     SACPolicy,
     TD3Policy,
     TRPOPolicy,
@@ -565,6 +567,58 @@ class DDPGAgentFactory(OffpolicyAgentFactory):
         )
 
 
+class REDQAgentFactory(OffpolicyAgentFactory):
+    def __init__(
+        self,
+        params: REDQParams,
+        sampling_config: SamplingConfig,
+        actor_factory: ActorFactory,
+        critic_ensemble_factory: CriticEnsembleFactory,
+        optim_factory: OptimizerFactory,
+    ):
+        super().__init__(sampling_config, optim_factory)
+        self.critic_ensemble_factory = critic_ensemble_factory
+        self.actor_factory = actor_factory
+        self.params = params
+        self.optim_factory = optim_factory
+
+    def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
+        envs.get_type().assert_continuous(self)
+        actor = self.actor_factory.create_module_opt(
+            envs,
+            device,
+            self.optim_factory,
+            self.params.actor_lr,
+        )
+        critic_ensemble = self.critic_ensemble_factory.create_module_opt(
+            envs,
+            device,
+            self.params.ensemble_size,
+            True,
+            self.optim_factory,
+            self.params.critic_lr,
+        )
+        kwargs = self.params.create_kwargs(
+            ParamTransformerData(
+                envs=envs,
+                device=device,
+                optim_factory=self.optim_factory,
+                actor=actor,
+                critic1=critic_ensemble,
+            ),
+        )
+        action_space = cast(gymnasium.spaces.Box, envs.get_action_space())
+        return REDQPolicy(
+            actor=actor.module,
+            actor_optim=actor.optim,
+            critic=critic_ensemble.module,
+            critic_optim=critic_ensemble.optim,
+            action_space=action_space,
+            observation_space=envs.get_observation_space(),
+            **kwargs,
+        )
+
+
 class SACAgentFactory(OffpolicyAgentFactory):
     def __init__(
         self,
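Context for the `ensemble_size`, `subset_size`, and `target_mode` parameters forwarded to `REDQPolicy` above: REDQ computes its Bellman target from a random subset of the critic ensemble. A minimal conceptual sketch of that reduction (illustrative only, not tianshou's actual implementation):

```python
import torch


def redq_target_q(q_values: torch.Tensor, subset_size: int, target_mode: str) -> torch.Tensor:
    """Reduce target-critic outputs of shape (ensemble_size, batch) to a (batch,) target.

    Sampling a random subset of critics and taking the min over it is what keeps
    Q-value overestimation in check even at high update-to-data ratios.
    """
    idx = torch.randperm(q_values.shape[0])[:subset_size]  # random critic subset
    subset = q_values[idx]
    if target_mode == "min":
        return subset.min(dim=0).values
    if target_mode == "mean":
        return subset.mean(dim=0)
    raise ValueError(f"unknown target_mode: {target_mode}")
```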
diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py
index 554af84..eb8eeae 100644
--- a/tianshou/highlevel/experiment.py
+++ b/tianshou/highlevel/experiment.py
@@ -17,6 +17,7 @@ from tianshou.highlevel.agent import (
     NPGAgentFactory,
     PGAgentFactory,
     PPOAgentFactory,
+    REDQAgentFactory,
     SACAgentFactory,
     TD3AgentFactory,
     TRPOAgentFactory,
@@ -29,7 +30,12 @@ from tianshou.highlevel.module.actor import (
     ActorFactoryDefault,
     ContinuousActorType,
 )
-from tianshou.highlevel.module.critic import CriticFactory, CriticFactoryDefault
+from tianshou.highlevel.module.critic import (
+    CriticEnsembleFactory,
+    CriticEnsembleFactoryDefault,
+    CriticFactory,
+    CriticFactoryDefault,
+)
 from tianshou.highlevel.optim import OptimizerFactory, OptimizerFactoryAdam
 from tianshou.highlevel.params.policy_params import (
     A2CParams,
@@ -38,6 +44,7 @@ from tianshou.highlevel.params.policy_params import (
     NPGParams,
     PGParams,
     PPOParams,
+    REDQParams,
     SACParams,
     TD3Params,
     TRPOParams,
@@ -404,6 +411,28 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
         return self
 
 
+class _BuilderMixinCriticEnsembleFactory:
+    def __init__(self) -> None:
+        self.critic_ensemble_factory: CriticEnsembleFactory | None = None
+
+    def with_critic_ensemble_factory(self, factory: CriticEnsembleFactory) -> Self:
+        self.critic_ensemble_factory = factory
+        return self
+
+    def with_critic_ensemble_factory_default(
+        self,
+        hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
+    ) -> Self:
+        self.critic_ensemble_factory = CriticEnsembleFactoryDefault(hidden_sizes)
+        return self
+
+    def _get_critic_ensemble_factory(self) -> CriticEnsembleFactory:
+        if self.critic_ensemble_factory is None:
+            return CriticEnsembleFactoryDefault()
+        else:
+            return self.critic_ensemble_factory
+
+
 class PGExperimentBuilder(
     ExperimentBuilder,
     _BuilderMixinActorFactory_ContinuousGaussian,
@@ -621,6 +650,36 @@ class DDPGExperimentBuilder(
         )
 
 
+class REDQExperimentBuilder(
+    ExperimentBuilder,
+    _BuilderMixinActorFactory_ContinuousGaussian,
+    _BuilderMixinCriticEnsembleFactory,
+):
+    def __init__(
+        self,
+        env_factory: EnvFactory,
+        experiment_config: ExperimentConfig | None = None,
+        sampling_config: SamplingConfig | None = None,
+    ):
+        super().__init__(env_factory, experiment_config, sampling_config)
+        _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
+        _BuilderMixinCriticEnsembleFactory.__init__(self)
+        self._params: REDQParams = REDQParams()
+
+    def with_redq_params(self, params: REDQParams) -> Self:
+        self._params = params
+        return self
+
+    def _create_agent_factory(self) -> AgentFactory:
+        return REDQAgentFactory(
+            self._params,
+            self._sampling_config,
+            self._get_actor_factory(),
+            self._get_critic_ensemble_factory(),
+            self._get_optim_factory(),
+        )
+
+
 class SACExperimentBuilder(
     ExperimentBuilder,
     _BuilderMixinActorFactory_ContinuousGaussian,
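The mixin gives REDQ the same factory-override pattern as the existing critic mixins: the default ensemble factory is used unless one is set explicitly. A usage sketch with a custom ensemble network (hedged: `my_env_factory` is a placeholder for any `EnvFactory` implementation, and the wider layers are arbitrary):

```python
from tianshou.highlevel.experiment import REDQExperimentBuilder
from tianshou.highlevel.module.critic import CriticEnsembleFactoryContinuousNet

# my_env_factory: any EnvFactory implementation (placeholder)
experiment = (
    REDQExperimentBuilder(my_env_factory)  # experiment/sampling configs default to None
    .with_critic_ensemble_factory(CriticEnsembleFactoryContinuousNet((512, 512)))
    .build()
)
```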
diff --git a/tianshou/highlevel/module/critic.py b/tianshou/highlevel/module/critic.py
index 914d35f..e6edfc7 100644
--- a/tianshou/highlevel/module/critic.py
+++ b/tianshou/highlevel/module/critic.py
@@ -8,7 +8,7 @@ from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
 from tianshou.highlevel.module.module_opt import ModuleOpt
 from tianshou.highlevel.optim import OptimizerFactory
 from tianshou.utils.net import continuous, discrete
-from tianshou.utils.net.common import Net
+from tianshou.utils.net.common import EnsembleLinear, Net
 from tianshou.utils.string import ToStringMixin
 
 
@@ -39,21 +39,16 @@
         self.hidden_sizes = hidden_sizes
 
     def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
+        factory: CriticFactory
         env_type = envs.get_type()
-        if env_type == EnvType.CONTINUOUS:
-            return CriticFactoryContinuousNet(self.hidden_sizes).create_module(
-                envs,
-                device,
-                use_action,
-            )
-        elif env_type == EnvType.DISCRETE:
-            return CriticFactoryDiscreteNet(self.hidden_sizes).create_module(
-                envs,
-                device,
-                use_action,
-            )
-        else:
-            raise ValueError(f"{env_type} not supported")
+        match env_type:
+            case EnvType.CONTINUOUS:
+                factory = CriticFactoryContinuousNet(self.hidden_sizes)
+            case EnvType.DISCRETE:
+                factory = CriticFactoryDiscreteNet(self.hidden_sizes)
+            case _:
+                raise ValueError(f"{env_type} not supported")
+        return factory.create_module(envs, device, use_action)
 
 
 class CriticFactoryContinuousNet(CriticFactory):
@@ -92,3 +87,94 @@ class CriticFactoryDiscreteNet(CriticFactory):
         critic = discrete.Critic(net_c, device=device).to(device)
         init_linear_orthogonal(critic)
         return critic
+
+
+class CriticEnsembleFactory:
+    @abstractmethod
+    def create_module(
+        self,
+        envs: Environments,
+        device: TDevice,
+        ensemble_size: int,
+        use_action: bool,
+    ) -> nn.Module:
+        pass
+
+    def create_module_opt(
+        self,
+        envs: Environments,
+        device: TDevice,
+        ensemble_size: int,
+        use_action: bool,
+        optim_factory: OptimizerFactory,
+        lr: float,
+    ) -> ModuleOpt:
+        module = self.create_module(envs, device, ensemble_size, use_action)
+        opt = optim_factory.create_optimizer(module, lr)
+        return ModuleOpt(module, opt)
+
+
+class CriticEnsembleFactoryDefault(CriticEnsembleFactory):
+    """A critic ensemble factory which, depending on the type of environment, creates a suitable MLP-based critic ensemble."""
+
+    DEFAULT_HIDDEN_SIZES = (64, 64)
+
+    def __init__(self, hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES):
+        self.hidden_sizes = hidden_sizes
+
+    def create_module(
+        self,
+        envs: Environments,
+        device: TDevice,
+        ensemble_size: int,
+        use_action: bool,
+    ) -> nn.Module:
+        env_type = envs.get_type()
+        factory: CriticEnsembleFactory
+        match env_type:
+            case EnvType.CONTINUOUS:
+                factory = CriticEnsembleFactoryContinuousNet(self.hidden_sizes)
+            case EnvType.DISCRETE:
+                raise NotImplementedError("No default is implemented for the discrete case")
+            case _:
+                raise ValueError(f"{env_type} not supported")
+        return factory.create_module(
+            envs,
+            device,
+            ensemble_size,
+            use_action,
+        )
+
+
+class CriticEnsembleFactoryContinuousNet(CriticEnsembleFactory):
+    def __init__(self, hidden_sizes: Sequence[int]):
+        self.hidden_sizes = hidden_sizes
+
+    def create_module(
+        self,
+        envs: Environments,
+        device: TDevice,
+        ensemble_size: int,
+        use_action: bool,
+    ) -> nn.Module:
+        def linear_layer(x: int, y: int) -> EnsembleLinear:
+            return EnsembleLinear(ensemble_size, x, y)
+
+        action_shape = envs.get_action_shape() if use_action else 0
+        net_c = Net(
+            envs.get_observation_shape(),
+            action_shape=action_shape,
+            hidden_sizes=self.hidden_sizes,
+            concat=use_action,
+            activation=nn.Tanh,
+            device=device,
+            linear_layer=linear_layer,
+        )
+        critic = continuous.Critic(
+            net_c,
+            device=device,
+            linear_layer=linear_layer,
+            flatten_input=False,
+        ).to(device)
+        init_linear_orthogonal(critic)
+        return critic
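The `linear_layer` closure and `flatten_input=False` above are what make a single `Net` act as a whole ensemble: `EnsembleLinear` carries a batched weight, so outputs gain a leading `ensemble_size` axis that must not be flattened away. A small shape sketch (output shapes assumed from `EnsembleLinear`'s broadcasting matmul):

```python
import torch

from tianshou.utils.net.common import EnsembleLinear

layer = EnsembleLinear(10, 8, 4)  # ensemble_size=10, in_feature=8, out_feature=4
x = torch.zeros(256, 8)           # (batch, features), as seen by the first layer
print(layer(x).shape)             # torch.Size([10, 256, 4]): ensemble axis prepended
y = torch.zeros(10, 256, 8)       # deeper layers already carry the ensemble axis
print(layer(y).shape)             # torch.Size([10, 256, 4])
```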
diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py
index 0a602f5..7bf5ef4 100644
--- a/tianshou/highlevel/params/policy_params.py
+++ b/tianshou/highlevel/params/policy_params.py
@@ -399,6 +399,22 @@ class DDPGParams(Params, ParamsMixinActorAndCritic):
         return transformers
 
 
+@dataclass
+class REDQParams(DDPGParams):
+    ensemble_size: int = 10
+    subset_size: int = 2
+    alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] | AutoAlphaFactory = 0.2
+    estimation_step: int = 1
+    actor_delay: int = 20
+    deterministic_eval: bool = True
+    target_mode: Literal["mean", "min"] = "min"
+
+    def _get_param_transformers(self) -> list[ParamTransformer]:
+        transformers = super()._get_param_transformers()
+        transformers.append(ParamTransformerAutoAlpha("alpha"))
+        return transformers
+
+
 @dataclass
 class TD3Params(Params, ParamsMixinActorAndDualCritics):
     tau: float = 0.005
diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py
index fa6ec00..7603454 100644
--- a/tianshou/utils/net/common.py
+++ b/tianshou/utils/net/common.py
@@ -12,6 +12,7 @@ from tianshou.data.types import RecurrentStateBatch
 ModuleType = type[nn.Module]
 ArgsType = tuple[Any, ...] | dict[Any, Any] | Sequence[tuple[Any, ...]] | Sequence[dict[Any, Any]]
 TActionShape: TypeAlias = Sequence[int] | int
+TLinearLayer: TypeAlias = Callable[[int, int], nn.Module]
 
 
 def miniblock(
@@ -77,7 +78,7 @@ class MLP(nn.Module):
         activation: ModuleType | Sequence[ModuleType] | None = nn.ReLU,
         act_args: ArgsType | None = None,
         device: str | int | torch.device | None = None,
-        linear_layer: type[nn.Linear] = nn.Linear,
+        linear_layer: TLinearLayer = nn.Linear,
         flatten_input: bool = True,
     ) -> None:
         super().__init__()
@@ -183,7 +184,8 @@ class Net(NetBase):
         pass a tuple of two dict (first for Q and second for V) stating
         self-defined arguments as stated in
         class:`~tianshou.utils.net.common.MLP`. Default to None.
-    :param linear_layer: use this module as linear layer. Default to nn.Linear.
+    :param linear_layer: use this module constructor, which takes input and
+        output dimensions as arguments, to build each linear layer. Default to nn.Linear.
 
     .. seealso::
 
@@ -209,7 +211,7 @@ class Net(NetBase):
         concat: bool = False,
         num_atoms: int = 1,
         dueling_param: tuple[dict[str, Any], dict[str, Any]] | None = None,
-        linear_layer: type[nn.Linear] = nn.Linear,
+        linear_layer: TLinearLayer = nn.Linear,
     ) -> None:
         super().__init__()
         self.device = device
diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py
index d70ff02..c69c249 100644
--- a/tianshou/utils/net/continuous.py
+++ b/tianshou/utils/net/continuous.py
@@ -6,7 +6,7 @@ import numpy as np
 import torch
 from torch import nn
 
-from tianshou.utils.net.common import MLP, BaseActor, TActionShape
+from tianshou.utils.net.common import MLP, BaseActor, TActionShape, TLinearLayer
 
 SIGMA_MIN = -20
 SIGMA_MAX = 2
@@ -108,7 +108,7 @@ class Critic(nn.Module):
         hidden_sizes: Sequence[int] = (),
         device: str | int | torch.device = "cpu",
         preprocess_net_output_dim: int | None = None,
-        linear_layer: type[nn.Linear] = nn.Linear,
+        linear_layer: TLinearLayer = nn.Linear,
         flatten_input: bool = True,
     ) -> None:
         super().__init__()
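Beyond REDQ, the relaxed `TLinearLayer` alias means any `(in_features, out_features) -> nn.Module` callable now satisfies the `linear_layer` parameter, not just `nn.Linear` subclasses. A hedged example with a spectral-norm wrapper (purely illustrative, not part of this patch):

```python
import torch.nn as nn

from tianshou.utils.net.common import Net


def sn_linear(in_features: int, out_features: int) -> nn.Module:
    # Apply spectral normalization to each linear layer of the network.
    return nn.utils.spectral_norm(nn.Linear(in_features, out_features))


net = Net(state_shape=(8,), action_shape=4, hidden_sizes=(64, 64), linear_layer=sn_linear)
```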