From 7af836bd6a42ed6ed5152ef07d4a956da7970096 Mon Sep 17 00:00:00 2001
From: Dominik Jain
Date: Tue, 10 Oct 2023 14:14:00 +0200
Subject: [PATCH] Support TRPO in high-level API and add example mujoco_trpo_hl

---
 examples/mujoco/mujoco_trpo_hl.py          | 91 ++++++++++++++++++++++
 tianshou/highlevel/agent.py                | 26 +++++++
 tianshou/highlevel/experiment.py           | 33 ++++++++
 tianshou/highlevel/params/policy_params.py |  7 ++
 4 files changed, 157 insertions(+)
 create mode 100644 examples/mujoco/mujoco_trpo_hl.py

diff --git a/examples/mujoco/mujoco_trpo_hl.py b/examples/mujoco/mujoco_trpo_hl.py
new file mode 100644
index 0000000..72cb2e8
--- /dev/null
+++ b/examples/mujoco/mujoco_trpo_hl.py
@@ -0,0 +1,91 @@
+#!/usr/bin/env python3
+
+import os
+from collections.abc import Sequence
+from typing import Literal
+
+from jsonargparse import CLI
+
+from examples.mujoco.mujoco_env import MujocoEnvFactory
+from tianshou.highlevel.config import SamplingConfig
+from tianshou.highlevel.experiment import (
+    ExperimentConfig,
+    TRPOExperimentBuilder,
+)
+from tianshou.highlevel.params.dist_fn import (
+    DistributionFunctionFactoryIndependentGaussians,
+)
+from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
+from tianshou.highlevel.params.policy_params import TRPOParams
+from tianshou.utils import logging
+from tianshou.utils.logging import datetime_tag
+
+
+def main(
+    experiment_config: ExperimentConfig,
+    task: str = "Ant-v3",
+    buffer_size: int = 4096,
+    hidden_sizes: Sequence[int] = (64, 64),
+    lr: float = 1e-3,
+    gamma: float = 0.99,
+    epoch: int = 100,
+    step_per_epoch: int = 30000,
+    step_per_collect: int = 1024,
+    repeat_per_collect: int = 1,
+    batch_size: int = 99999,
+    training_num: int = 16,
+    test_num: int = 10,
+    rew_norm: bool = True,
+    gae_lambda: float = 0.95,
+    bound_action_method: Literal["clip", "tanh"] = "clip",
+    lr_decay: bool = True,
+    norm_adv: bool = True,
+    optim_critic_iters: int = 20,
+    max_kl: float = 0.01,
+    backtrack_coeff: float = 0.8,
+    max_backtracks: int = 10,
+):
+    log_name = os.path.join(task, "trpo", str(experiment_config.seed), datetime_tag())
+
+    sampling_config = SamplingConfig(
+        num_epochs=epoch,
+        step_per_epoch=step_per_epoch,
+        batch_size=batch_size,
+        num_train_envs=training_num,
+        num_test_envs=test_num,
+        buffer_size=buffer_size,
+        step_per_collect=step_per_collect,
+        repeat_per_collect=repeat_per_collect,
+    )
+
+    env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
+
+    experiment = (
+        TRPOExperimentBuilder(env_factory, experiment_config, sampling_config)
+        .with_trpo_params(
+            TRPOParams(
+                discount_factor=gamma,
+                gae_lambda=gae_lambda,
+                action_bound_method=bound_action_method,
+                reward_normalization=rew_norm,
+                advantage_normalization=norm_adv,
+                optim_critic_iters=optim_critic_iters,
+                max_kl=max_kl,
+                backtrack_coeff=backtrack_coeff,
+                max_backtracks=max_backtracks,
+                lr=lr,
+                lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config)
+                if lr_decay
+                else None,
+                dist_fn=DistributionFunctionFactoryIndependentGaussians(),
+            ),
+        )
+        .with_actor_factory_default(hidden_sizes, continuous_unbounded=True)
+        .with_critic_factory_default(hidden_sizes)
+        .build()
+    )
+    experiment.run(log_name)
+
+
+if __name__ == "__main__":
+    logging.run_main(lambda: CLI(main))
diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py
index bfdbb20..ea4afc1 100644
--- a/tianshou/highlevel/agent.py
+++ b/tianshou/highlevel/agent.py
@@ -30,6 +30,7 @@ from tianshou.highlevel.params.policy_params import (
     PPOParams,
     SACParams,
     TD3Params,
+    TRPOParams,
 )
 from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
 from tianshou.highlevel.trainer import TrainerCallbacks, TrainingContext
@@ -43,6 +44,7 @@ from tianshou.policy import (
     PPOPolicy,
     SACPolicy,
     TD3Policy,
+    TRPOPolicy,
 )
 from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
 from tianshou.utils.net import continuous, discrete
@@ -455,6 +457,30 @@ class NPGAgentFactory(ActorCriticAgentFactory[NPGParams, NPGPolicy]):
         return self.create_actor_critic_module_opt(envs, device, self.params.lr)


+class TRPOAgentFactory(ActorCriticAgentFactory[TRPOParams, TRPOPolicy]):
+    def __init__(
+        self,
+        params: TRPOParams,
+        sampling_config: SamplingConfig,
+        actor_factory: ActorFactory,
+        critic_factory: CriticFactory,
+        optimizer_factory: OptimizerFactory,
+        critic_use_actor_module: bool,
+    ):
+        super().__init__(
+            params,
+            sampling_config,
+            actor_factory,
+            critic_factory,
+            optimizer_factory,
+            TRPOPolicy,
+            critic_use_actor_module,
+        )
+
+    def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
+        return self.create_actor_critic_module_opt(envs, device, self.params.lr)
+
+
 class DQNAgentFactory(OffpolicyAgentFactory):
     def __init__(
         self,
diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py
index 8cdd5f3..554af84 100644
--- a/tianshou/highlevel/experiment.py
+++ b/tianshou/highlevel/experiment.py
@@ -19,6 +19,7 @@ from tianshou.highlevel.agent import (
     PPOAgentFactory,
     SACAgentFactory,
     TD3AgentFactory,
+    TRPOAgentFactory,
 )
 from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.env import EnvFactory, Environments
@@ -39,6 +40,7 @@ from tianshou.highlevel.params.policy_params import (
     PPOParams,
     SACParams,
     TD3Params,
+    TRPOParams,
 )
 from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
 from tianshou.highlevel.persistence import PersistableConfigProtocol
@@ -528,6 +530,37 @@ class NPGExperimentBuilder(
         )


+class TRPOExperimentBuilder(
+    ExperimentBuilder,
+    _BuilderMixinActorFactory_ContinuousGaussian,
+    _BuilderMixinSingleCriticCanUseActorFactory,
+):
+    def __init__(
+        self,
+        env_factory: EnvFactory,
+        experiment_config: ExperimentConfig | None = None,
+        sampling_config: SamplingConfig | None = None,
+    ):
+        super().__init__(env_factory, experiment_config, sampling_config)
+        _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
+        _BuilderMixinSingleCriticCanUseActorFactory.__init__(self)
+        self._params: TRPOParams = TRPOParams()
+
+    def with_trpo_params(self, params: TRPOParams) -> Self:
+        self._params = params
+        return self
+
+    def _create_agent_factory(self) -> AgentFactory:
+        return TRPOAgentFactory(
+            self._params,
+            self._sampling_config,
+            self._get_actor_factory(),
+            self._get_critic_factory(0),
+            self._get_optim_factory(),
+            self._critic_use_actor_module,
+        )
+
+
 class DQNExperimentBuilder(
     ExperimentBuilder,
     _BuilderMixinActorFactory,
diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py
index 3a2ad58..0a602f5 100644
--- a/tianshou/highlevel/params/policy_params.py
+++ b/tianshou/highlevel/params/policy_params.py
@@ -321,6 +321,13 @@ class NPGParams(PGParams):
     max_batchsize: int = 256


+@dataclass
+class TRPOParams(NPGParams):
+    max_kl: float = 0.01
+    backtrack_coeff: float = 0.8
+    max_backtracks: int = 10
+
+
 @dataclass
 class ParamsMixinActorAndDualCritics(GetParamTransformersProtocol):
     actor_lr: float = 1e-3
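
For a quick check that the new pieces fit together, the snippet below builds and runs a short TRPO experiment through the high-level API added by this patch. It is a minimal sketch, not part of the patch: it assumes the default ExperimentConfig/SamplingConfig fields behave as they do in the example script above and that MujocoEnvFactory accepts (task, seed, sampling_config); the log name and the small sampling values are arbitrary. The example script itself exposes all of its arguments through jsonargparse (CLI(main)), so the same run can also be configured from the command line.

# Minimal smoke-test sketch (assumptions noted above); mirrors
# examples/mujoco/mujoco_trpo_hl.py with the default TRPOParams fields
# added by this patch (max_kl=0.01, backtrack_coeff=0.8, max_backtracks=10).
from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import ExperimentConfig, TRPOExperimentBuilder
from tianshou.highlevel.params.policy_params import TRPOParams

experiment_config = ExperimentConfig(seed=0)  # assumed: seed is a constructor field
sampling_config = SamplingConfig(
    num_epochs=1,
    step_per_epoch=5000,
    num_train_envs=4,
    num_test_envs=2,
)
env_factory = MujocoEnvFactory("Ant-v3", experiment_config.seed, sampling_config)

experiment = (
    TRPOExperimentBuilder(env_factory, experiment_config, sampling_config)
    .with_trpo_params(TRPOParams(discount_factor=0.99, gae_lambda=0.95))
    .with_actor_factory_default((64, 64), continuous_unbounded=True)
    .with_critic_factory_default((64, 64))
    .build()
)
experiment.run("trpo-smoke-test")  # arbitrary log name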