From 2671580c6c3af717c61f315ec7f80ce20bfb3d53 Mon Sep 17 00:00:00 2001
From: Dominik Jain
Date: Tue, 3 Oct 2023 20:26:39 +0200
Subject: [PATCH] Add DDPG high-level API and MuJoCo example

---
 examples/atari/atari_network.py             |  5 +-
 examples/mujoco/mujoco_ddpg_hl.py           | 80 ++++++++++++++++++++++
 tianshou/highlevel/agent.py                 | 63 ++++++++++++++++-
 tianshou/highlevel/config.py                |  2 +-
 tianshou/highlevel/experiment.py            | 36 ++++++++-
 tianshou/highlevel/params/policy_params.py  | 58 ++++++++++++++
 tianshou/highlevel/params/policy_wrapper.py |  3 +-
 7 files changed, 239 insertions(+), 8 deletions(-)
 create mode 100644 examples/mujoco/mujoco_ddpg_hl.py

diff --git a/examples/atari/atari_network.py b/examples/atari/atari_network.py
index 8396900..7c18b35 100644
--- a/examples/atari/atari_network.py
+++ b/examples/atari/atari_network.py
@@ -247,6 +247,9 @@ class ActorFactoryAtariDQN(ActorFactory):
 class FeatureNetFactoryDQN(ModuleFactory):
     def create_module(self, envs: Environments, device: TDevice) -> Module:
         dqn = DQN(
-            *envs.get_observation_shape(), envs.get_action_shape(), device, features_only=True,
+            *envs.get_observation_shape(),
+            envs.get_action_shape(),
+            device,
+            features_only=True,
         )
         return Module(dqn.net, dqn.output_dim)

diff --git a/examples/mujoco/mujoco_ddpg_hl.py b/examples/mujoco/mujoco_ddpg_hl.py
new file mode 100644
index 0000000..097be5d
--- /dev/null
+++ b/examples/mujoco/mujoco_ddpg_hl.py
@@ -0,0 +1,80 @@
+#!/usr/bin/env python3
+
+import datetime
+import os
+from collections.abc import Sequence
+
+from jsonargparse import CLI
+
+from examples.mujoco.mujoco_env import MujocoEnvFactory
+from tianshou.highlevel.config import RLSamplingConfig
+from tianshou.highlevel.experiment import (
+    DDPGExperimentBuilder,
+    RLExperimentConfig,
+)
+from tianshou.highlevel.params.noise import MaxActionScaledGaussian
+from tianshou.highlevel.params.policy_params import DDPGParams
+
+
+def main(
+    experiment_config: RLExperimentConfig,
+    task: str = "Ant-v3",
+    buffer_size: int = 1000000,
+    hidden_sizes: Sequence[int] = (256, 256),
+    actor_lr: float = 1e-3,
+    critic_lr: float = 1e-3,
+    gamma: float = 0.99,
+    tau: float = 0.005,
+    exploration_noise: float = 0.1,
+    start_timesteps: int = 25000,
+    epoch: int = 200,
+    step_per_epoch: int = 5000,
+    step_per_collect: int = 1,
+    update_per_step: int = 1,
+    n_step: int = 1,
+    batch_size: int = 256,
+    training_num: int = 1,
+    test_num: int = 10,
+):
+    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
+    log_name = os.path.join(task, "ddpg", str(experiment_config.seed), now)
+
+    sampling_config = RLSamplingConfig(
+        num_epochs=epoch,
+        step_per_epoch=step_per_epoch,
+        batch_size=batch_size,
+        num_train_envs=training_num,
+        num_test_envs=test_num,
+        buffer_size=buffer_size,
+        step_per_collect=step_per_collect,
+        update_per_step=update_per_step,
+        repeat_per_collect=None,
+        start_timesteps=start_timesteps,
+        start_timesteps_random=True,
+    )
+
+    env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
+
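+    # Assemble and run the experiment via the high-level builder API; actor and
+    # critic networks are created by the default factories with the given sizes.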
+    experiment = (
+        DDPGExperimentBuilder(experiment_config, env_factory, sampling_config)
+        .with_ddpg_params(
+            DDPGParams(
+                actor_lr=actor_lr,
+                critic_lr=critic_lr,
+                gamma=gamma,
+                tau=tau,
+                exploration_noise=MaxActionScaledGaussian(exploration_noise),
+                estimation_step=n_step,
+            ),
+        )
+        .with_actor_factory_default(hidden_sizes)
+        .with_critic_factory_default(hidden_sizes)
+        .build()
+    )
+    experiment.run(log_name)
+
+
+if __name__ == "__main__":
+    CLI(main)

diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py
index 16d90ef..87163cc 100644
--- a/tianshou/highlevel/agent.py
+++ b/tianshou/highlevel/agent.py
@@ -22,6 +22,7 @@ from tianshou.highlevel.module.module_opt import (
 from tianshou.highlevel.optim import OptimizerFactory
 from tianshou.highlevel.params.policy_params import (
     A2CParams,
+    DDPGParams,
     Params,
     ParamTransformerData,
     PPOParams,
@@ -29,7 +30,14 @@ from tianshou.highlevel.params.policy_params import (
     TD3Params,
 )
 from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
-from tianshou.policy import A2CPolicy, BasePolicy, PPOPolicy, SACPolicy, TD3Policy
+from tianshou.policy import (
+    A2CPolicy,
+    BasePolicy,
+    DDPGPolicy,
+    PPOPolicy,
+    SACPolicy,
+    TD3Policy,
+)
 from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
 from tianshou.utils.net import continuous, discrete
 from tianshou.utils.net.common import ActorCritic
@@ -71,7 +79,8 @@ class AgentFactory(ABC):
         return train_collector, test_collector
 
     def set_policy_wrapper_factory(
-        self, policy_wrapper_factory: PolicyWrapperFactory | None,
+        self,
+        policy_wrapper_factory: PolicyWrapperFactory | None,
     ) -> None:
         self.policy_wrapper_factory = policy_wrapper_factory
 
@@ -83,7 +92,10 @@ class AgentFactory(ABC):
         policy = self._create_policy(envs, device)
         if self.policy_wrapper_factory is not None:
             policy = self.policy_wrapper_factory.create_wrapped_policy(
-                policy, envs, self.optim_factory, device,
+                policy,
+                envs,
+                self.optim_factory,
+                device,
             )
         return policy
 
@@ -372,6 +384,51 @@ class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]):
         return self.create_actor_critic_module_opt(envs, device, self.params.lr)
 
 
+class DDPGAgentFactory(OffpolicyAgentFactory, _ActorAndCriticMixin):
+    def __init__(
+        self,
+        params: DDPGParams,
+        sampling_config: RLSamplingConfig,
+        actor_factory: ActorFactory,
+        critic_factory: CriticFactory,
+        optim_factory: OptimizerFactory,
+    ):
+        super().__init__(sampling_config, optim_factory)
+        _ActorAndCriticMixin.__init__(
+            self,
+            actor_factory,
+            critic_factory,
+            optim_factory,
+            critic_use_action=True,
+        )
+        self.params = params
+        self.optim_factory = optim_factory
+
+    def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
+        actor = self.create_actor_module_opt(envs, device, self.params.actor_lr)
+        critic = self.create_critic_module_opt(envs, device, self.params.critic_lr)
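+        # The single critic is registered as critic1 so that shared parameter
+        # transformers (e.g. the LR scheduler transformer) can find its optimizer.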
+        kwargs = self.params.create_kwargs(
+            ParamTransformerData(
+                envs=envs,
+                device=device,
+                optim_factory=self.optim_factory,
+                actor=actor,
+                critic1=critic,
+            ),
+        )
+        return DDPGPolicy(
+            actor=actor.module,
+            actor_optim=actor.optim,
+            critic=critic.module,
+            critic_optim=critic.optim,
+            action_space=envs.get_action_space(),
+            observation_space=envs.get_observation_space(),
+            **kwargs,
+        )
+
+
 class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
     def __init__(
         self,

diff --git a/tianshou/highlevel/config.py b/tianshou/highlevel/config.py
index d90dce1..500dfe8 100644
--- a/tianshou/highlevel/config.py
+++ b/tianshou/highlevel/config.py
@@ -13,7 +13,7 @@ class RLSamplingConfig:
     num_test_envs: int = 10
     buffer_size: int = 4096
     step_per_collect: int = 2048
-    repeat_per_collect: int = 10
+    repeat_per_collect: int | None = 10
     update_per_step: int = 1
     start_timesteps: int = 0
     start_timesteps_random: bool = False

diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py
index 3abaed4..4113cdd 100644
--- a/tianshou/highlevel/experiment.py
+++ b/tianshou/highlevel/experiment.py
@@ -11,6 +11,7 @@ from tianshou.data import Collector
 from tianshou.highlevel.agent import (
     A2CAgentFactory,
     AgentFactory,
+    DDPGAgentFactory,
     PPOAgentFactory,
     SACAgentFactory,
     TD3AgentFactory,
@@ -27,6 +28,7 @@ from tianshou.highlevel.module.critic import CriticFactory, CriticFactoryDefault
 from tianshou.highlevel.optim import OptimizerFactory, OptimizerFactoryAdam
 from tianshou.highlevel.params.policy_params import (
     A2CParams,
+    DDPGParams,
     PPOParams,
     SACParams,
     TD3Params,
@@ -406,13 +408,11 @@ class PPOExperimentBuilder(
         experiment_config: RLExperimentConfig,
         env_factory: EnvFactory,
         sampling_config: RLSamplingConfig,
-        env_config: PersistableConfigProtocol | None = None,
     ):
         super().__init__(experiment_config, env_factory, sampling_config)
         _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
         _BuilderMixinSingleCriticFactory.__init__(self)
         self._params: PPOParams = PPOParams()
-        self._env_config = env_config
 
     def with_ppo_params(self, params: PPOParams) -> Self:
         self._params = params
         return self
@@ -430,6 +430,38 @@ class PPOExperimentBuilder(
         )
 
 
+class DDPGExperimentBuilder(
+    RLExperimentBuilder,
+    _BuilderMixinActorFactory_ContinuousDeterministic,
+    _BuilderMixinSingleCriticFactory,
+):
+    def __init__(
+        self,
+        experiment_config: RLExperimentConfig,
+        env_factory: EnvFactory,
+        sampling_config: RLSamplingConfig,
+        env_config: PersistableConfigProtocol | None = None,
+    ):
+        super().__init__(experiment_config, env_factory, sampling_config)
+        _BuilderMixinActorFactory_ContinuousDeterministic.__init__(self)
+        _BuilderMixinSingleCriticFactory.__init__(self)
+        self._params: DDPGParams = DDPGParams()
+        self._env_config = env_config
+
+    def with_ddpg_params(self, params: DDPGParams) -> Self:
+        self._params = params
+        return self
+
+    def _create_agent_factory(self) -> AgentFactory:
+        return DDPGAgentFactory(
+            self._params,
+            self._sampling_config,
+            self._get_actor_factory(),
+            self._get_critic_factory(0),
+            self._get_optim_factory(),
+        )
+
+
 class SACExperimentBuilder(
     RLExperimentBuilder,
     _BuilderMixinActorFactory_ContinuousGaussian,

diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py
index 2d6c6e9..e2e7257 100644
--- a/tianshou/highlevel/params/policy_params.py
+++ b/tianshou/highlevel/params/policy_params.py
@@ -128,6 +128,28 @@ class ParamTransformerMultiLRScheduler(ParamTransformer):
         params[self.key_scheduler] = lr_scheduler
 
 
+class ParamTransformerActorAndCriticLRScheduler(ParamTransformer):
+    def __init__(
+        self,
+        key_scheduler_factory_actor: str,
+        key_scheduler_factory_critic: str,
+        key_scheduler: str,
+    ):
+        self.key_factory_actor = key_scheduler_factory_actor
+        self.key_factory_critic = key_scheduler_factory_critic
+        self.key_scheduler = key_scheduler
+
+    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
+        transformer = ParamTransformerMultiLRScheduler(
+            [
+                (data.actor.optim, self.key_factory_actor),
+                (data.critic1.optim, self.key_factory_critic),
+            ],
+            self.key_scheduler,
+        )
+        transformer.transform(params, data)
+
+
 class ParamTransformerActorDualCriticsLRScheduler(ParamTransformer):
     def __init__(
         self,
@@ -232,6 +254,26 @@ class ParamsMixinLearningRateWithScheduler(GetParamTransformersProtocol):
         ]
 
 
+@dataclass
+class ParamsMixinActorAndCritic(GetParamTransformersProtocol):
+    actor_lr: float = 1e-3
+    critic_lr: float = 1e-3
+    actor_lr_scheduler_factory: LRSchedulerFactory | None = None
+    critic_lr_scheduler_factory: LRSchedulerFactory | None = None
+
+    def _get_param_transformers(self):
+        return [
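+            # actor_lr and critic_lr are consumed when the optimizers are built,
+            # so they are dropped before the remaining kwargs reach the policy.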
"critic_lr"), + ParamTransformerActorAndCriticLRScheduler( + "actor_lr_scheduler_factory", + "critic_lr_scheduler_factory", + "lr_scheduler", + ), + ] + + @dataclass class PGParams(Params): """Config of general policy-gradient algorithms.""" @@ -316,6 +356,22 @@ class SACParams(Params, ParamsMixinActorAndDualCritics): return transformers +@dataclass +class DDPGParams(Params, ParamsMixinActorAndCritic): + tau: float = 0.005 + gamma: float = 0.99 + exploration_noise: BaseNoise | Literal["default"] | NoiseFactory | None = "default" + estimation_step: int = 1 + action_scaling: bool = True + action_bound_method: Literal["clip"] | None = "clip" + + def _get_param_transformers(self): + transformers = super()._get_param_transformers() + transformers.extend(ParamsMixinActorAndCritic._get_param_transformers(self)) + transformers.append(ParamTransformerNoiseFactory("exploration_noise")) + return transformers + + @dataclass class TD3Params(Params, ParamsMixinActorAndDualCritics): tau: float = 0.005 diff --git a/tianshou/highlevel/params/policy_wrapper.py b/tianshou/highlevel/params/policy_wrapper.py index 008bfd8..fb5d3ee 100644 --- a/tianshou/highlevel/params/policy_wrapper.py +++ b/tianshou/highlevel/params/policy_wrapper.py @@ -25,7 +25,8 @@ class PolicyWrapperFactory(Generic[TPolicyIn, TPolicyOut], ABC): class PolicyWrapperFactoryIntrinsicCuriosity( - Generic[TPolicyIn], PolicyWrapperFactory[TPolicyIn, ICMPolicy], + Generic[TPolicyIn], + PolicyWrapperFactory[TPolicyIn, ICMPolicy], ): def __init__( self,