diff --git a/examples/mujoco/mujoco_npg_hl.py b/examples/mujoco/mujoco_npg_hl.py
new file mode 100644
index 0000000..ce04658
--- /dev/null
+++ b/examples/mujoco/mujoco_npg_hl.py
@@ -0,0 +1,87 @@
+#!/usr/bin/env python3
+
+import os
+from collections.abc import Sequence
+from typing import Literal
+
+from jsonargparse import CLI
+
+from examples.mujoco.mujoco_env import MujocoEnvFactory
+from tianshou.highlevel.config import SamplingConfig
+from tianshou.highlevel.experiment import (
+    ExperimentConfig,
+    NPGExperimentBuilder,
+)
+from tianshou.highlevel.params.dist_fn import (
+    DistributionFunctionFactoryIndependentGaussians,
+)
+from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
+from tianshou.highlevel.params.policy_params import NPGParams
+from tianshou.utils import logging
+from tianshou.utils.logging import datetime_tag
+
+
+def main(
+    experiment_config: ExperimentConfig,
+    task: str = "Ant-v3",
+    buffer_size: int = 4096,
+    hidden_sizes: Sequence[int] = (64, 64),
+    lr: float = 1e-3,
+    gamma: float = 0.99,
+    epoch: int = 100,
+    step_per_epoch: int = 30000,
+    step_per_collect: int = 1024,
+    repeat_per_collect: int = 1,
+    batch_size: int = 99999,
+    training_num: int = 16,
+    test_num: int = 10,
+    rew_norm: bool = True,
+    gae_lambda: float = 0.95,
+    bound_action_method: Literal["clip", "tanh"] = "clip",
+    lr_decay: bool = True,
+    norm_adv: bool = True,
+    optim_critic_iters: int = 20,
+    actor_step_size: float = 0.1,
+):
+    log_name = os.path.join(task, "npg", str(experiment_config.seed), datetime_tag())
+
+    sampling_config = SamplingConfig(
+        num_epochs=epoch,
+        step_per_epoch=step_per_epoch,
+        batch_size=batch_size,
+        num_train_envs=training_num,
+        num_test_envs=test_num,
+        buffer_size=buffer_size,
+        step_per_collect=step_per_collect,
+        repeat_per_collect=repeat_per_collect,
+    )
+
+    env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
+
+    experiment = (
+        NPGExperimentBuilder(env_factory, experiment_config, sampling_config)
+        .with_npg_params(
+            NPGParams(
+                discount_factor=gamma,
+                gae_lambda=gae_lambda,
+                action_bound_method=bound_action_method,
+                reward_normalization=rew_norm,
+                advantage_normalization=norm_adv,
+                optim_critic_iters=optim_critic_iters,
+                actor_step_size=actor_step_size,
+                lr=lr,
+                lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config)
+                if lr_decay
+                else None,
+                dist_fn=DistributionFunctionFactoryIndependentGaussians(),
+            ),
+        )
+        .with_actor_factory_default(hidden_sizes, continuous_unbounded=True)
+        .with_critic_factory_default(hidden_sizes)
+        .build()
+    )
+    experiment.run(log_name)
+
+
+if __name__ == "__main__":
+    logging.run_main(lambda: CLI(main))
diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py
index 7851efa..bfdbb20 100644
--- a/tianshou/highlevel/agent.py
+++ b/tianshou/highlevel/agent.py
@@ -23,6 +23,7 @@ from tianshou.highlevel.params.policy_params import (
     A2CParams,
     DDPGParams,
     DQNParams,
+    NPGParams,
     Params,
     ParamTransformerData,
     PGParams,
@@ -37,6 +38,7 @@ from tianshou.policy import (
     BasePolicy,
     DDPGPolicy,
     DQNPolicy,
+    NPGPolicy,
     PGPolicy,
     PPOPolicy,
     SACPolicy,
@@ -429,6 +431,30 @@ class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]):
         return self.create_actor_critic_module_opt(envs, device, self.params.lr)
 
 
+class NPGAgentFactory(ActorCriticAgentFactory[NPGParams, NPGPolicy]):
+    def __init__(
+        self,
+        params: NPGParams,
+        sampling_config: SamplingConfig,
+        actor_factory: ActorFactory,
+        critic_factory: CriticFactory,
+        optimizer_factory: OptimizerFactory,
+        critic_use_actor_module: bool,
+    ):
+        super().__init__(
+            params,
+            sampling_config,
+            actor_factory,
+            critic_factory,
+            optimizer_factory,
+            NPGPolicy,
+            critic_use_actor_module,
+        )
+
+    def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
+        return self.create_actor_critic_module_opt(envs, device, self.params.lr)
+
+
 class DQNAgentFactory(OffpolicyAgentFactory):
     def __init__(
         self,
diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py
index 77e94ec..8cdd5f3 100644
--- a/tianshou/highlevel/experiment.py
+++ b/tianshou/highlevel/experiment.py
@@ -14,6 +14,7 @@ from tianshou.highlevel.agent import (
     AgentFactory,
     DDPGAgentFactory,
     DQNAgentFactory,
+    NPGAgentFactory,
     PGAgentFactory,
     PPOAgentFactory,
     SACAgentFactory,
@@ -33,6 +34,7 @@ from tianshou.highlevel.params.policy_params import (
     A2CParams,
     DDPGParams,
     DQNParams,
+    NPGParams,
     PGParams,
     PPOParams,
     SACParams,
@@ -494,6 +496,37 @@ class PPOExperimentBuilder(
     )
 
 
+class NPGExperimentBuilder(
+    ExperimentBuilder,
+    _BuilderMixinActorFactory_ContinuousGaussian,
+    _BuilderMixinSingleCriticCanUseActorFactory,
+):
+    def __init__(
+        self,
+        env_factory: EnvFactory,
+        experiment_config: ExperimentConfig | None = None,
+        sampling_config: SamplingConfig | None = None,
+    ):
+        super().__init__(env_factory, experiment_config, sampling_config)
+        _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
+        _BuilderMixinSingleCriticCanUseActorFactory.__init__(self)
+        self._params: NPGParams = NPGParams()
+
+    def with_npg_params(self, params: NPGParams) -> Self:
+        self._params = params
+        return self
+
+    def _create_agent_factory(self) -> AgentFactory:
+        return NPGAgentFactory(
+            self._params,
+            self._sampling_config,
+            self._get_actor_factory(),
+            self._get_critic_factory(0),
+            self._get_optim_factory(),
+            self._critic_use_actor_module,
+        )
+
+
 class DQNExperimentBuilder(
     ExperimentBuilder,
     _BuilderMixinActorFactory,
diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py
index f052aec..3a2ad58 100644
--- a/tianshou/highlevel/params/policy_params.py
+++ b/tianshou/highlevel/params/policy_params.py
@@ -312,6 +312,15 @@ class PPOParams(A2CParams):
     recompute_advantage: bool = False
 
 
+@dataclass
+class NPGParams(PGParams):
+    optim_critic_iters: int = 5
+    actor_step_size: float = 0.5
+    advantage_normalization: bool = True
+    gae_lambda: float = 0.95
+    max_batchsize: int = 256
+
+
 @dataclass
 class ParamsMixinActorAndDualCritics(GetParamTransformersProtocol):
     actor_lr: float = 1e-3
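
Usage note (not part of the diff): the new example is driven by jsonargparse, so it can be launched as, for instance, python examples/mujoco/mujoco_npg_hl.py --task Ant-v3 --epoch 100, with each parameter of main exposed as a CLI flag. For programmatic use, the sketch below shows one minimal way to build and run an NPG experiment with the classes added above. The task name, sampling numbers, and log name are illustrative placeholders, it assumes ExperimentConfig can be instantiated with its defaults (the diff only shows it arriving via the CLI), and any NPGParams field left unset falls back to the defaults introduced in policy_params.py.

from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import ExperimentConfig, NPGExperimentBuilder
from tianshou.highlevel.params.policy_params import NPGParams

# Illustrative configuration values; only the API shape is taken from the diff above.
experiment_config = ExperimentConfig()
sampling_config = SamplingConfig(num_epochs=10, step_per_epoch=5000)
env_factory = MujocoEnvFactory("Ant-v3", experiment_config.seed, sampling_config)

experiment = (
    NPGExperimentBuilder(env_factory, experiment_config, sampling_config)
    # Override a couple of NPGParams fields; the rest keep their dataclass defaults.
    .with_npg_params(NPGParams(actor_step_size=0.1, optim_critic_iters=20))
    .with_actor_factory_default((64, 64), continuous_unbounded=True)
    .with_critic_factory_default((64, 64))
    .build()
)
experiment.run("npg-sketch")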