diff --git a/examples/mujoco/mujoco_a2c_hl.py b/examples/mujoco/mujoco_a2c_hl.py
new file mode 100644
index 0000000..2d6a583
--- /dev/null
+++ b/examples/mujoco/mujoco_a2c_hl.py
@@ -0,0 +1,85 @@
+#!/usr/bin/env python3
+
+import datetime
+import os
+from collections.abc import Sequence
+from typing import Literal
+
+from jsonargparse import CLI
+
+from examples.mujoco.mujoco_env import MujocoEnvFactory
+from tianshou.highlevel.config import RLSamplingConfig
+from tianshou.highlevel.experiment import (
+    A2CExperimentBuilder,
+    RLExperimentConfig,
+)
+from tianshou.highlevel.optim import OptimizerFactoryRMSprop
+from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
+from tianshou.highlevel.params.policy_params import A2CParams
+
+
+def main(
+    experiment_config: RLExperimentConfig,
+    task: str = "Ant-v3",
+    buffer_size: int = 4096,
+    hidden_sizes: Sequence[int] = (64, 64),
+    lr: float = 7e-4,
+    gamma: float = 0.99,
+    epoch: int = 100,
+    step_per_epoch: int = 30000,
+    step_per_collect: int = 80,
+    repeat_per_collect: int = 1,
+    batch_size: int = 99999,
+    training_num: int = 16,
+    test_num: int = 10,
+    rew_norm: bool = True,
+    vf_coef: float = 0.5,
+    ent_coef: float = 0.01,
+    gae_lambda: float = 0.95,
+    bound_action_method: Literal["clip", "tanh"] = "clip",
+    lr_decay: bool = True,
+    max_grad_norm: float = 0.5,
+):
+    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
+    log_name = os.path.join(task, "a2c", str(experiment_config.seed), now)
+
+    sampling_config = RLSamplingConfig(
+        num_epochs=epoch,
+        step_per_epoch=step_per_epoch,
+        batch_size=batch_size,
+        num_train_envs=training_num,
+        num_test_envs=test_num,
+        buffer_size=buffer_size,
+        step_per_collect=step_per_collect,
+        repeat_per_collect=repeat_per_collect,
+    )
+
+    env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
+
+    experiment = (
+        A2CExperimentBuilder(experiment_config, env_factory, sampling_config)
+        .with_a2c_params(
+            A2CParams(
+                discount_factor=gamma,
+                gae_lambda=gae_lambda,
+                action_bound_method=bound_action_method,
+                reward_normalization=rew_norm,
+                ent_coef=ent_coef,
+                vf_coef=vf_coef,
+                max_grad_norm=max_grad_norm,
+                lr=lr,
+                lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config)
+                if lr_decay
+                else None,
+            ),
+        )
+        .with_optim_factory(OptimizerFactoryRMSprop(eps=1e-5, alpha=0.99))
+        .with_actor_factory_default(hidden_sizes)
+        .with_critic_factory_default(hidden_sizes)
+        .build()
+    )
+    experiment.run(log_name)
+
+
+if __name__ == "__main__":
+    CLI(main)
diff --git a/examples/mujoco/mujoco_ppo_hl.py b/examples/mujoco/mujoco_ppo_hl.py
index 4460af4..814792b 100644
--- a/examples/mujoco/mujoco_ppo_hl.py
+++ b/examples/mujoco/mujoco_ppo_hl.py
@@ -65,8 +65,8 @@ def main(
         return Independent(Normal(*logits), 1)
 
     experiment = (
-        PPOExperimentBuilder(experiment_config, env_factory, sampling_config, dist_fn)
-        .with_params(
+        PPOExperimentBuilder(experiment_config, env_factory, sampling_config)
+        .with_ppo_params(
             PPOParams(
                 discount_factor=gamma,
                 gae_lambda=gae_lambda,
@@ -84,6 +84,7 @@ def main(
                 lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config)
                 if lr_decay
                 else None,
+                dist_fn=dist_fn,
             ),
         )
         .with_actor_factory_default(hidden_sizes)
diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py
index f079bdb..bd6cb81 100644
--- a/tianshou/highlevel/agent.py
+++ b/tianshou/highlevel/agent.py
@@ -1,6 +1,7 @@
 import os
 from abc import ABC, abstractmethod
 from collections.abc import Callable
+from typing import Generic, TypeVar
 
 import torch
 
@@ -19,18 +20,21 @@ from tianshou.highlevel.module import (
 )
 from tianshou.highlevel.optim import OptimizerFactory
 from tianshou.highlevel.params.policy_params import (
+    A2CParams,
+    Params,
     ParamTransformerData,
     PPOParams,
     SACParams,
     TD3Params,
 )
-from tianshou.policy import BasePolicy, PPOPolicy, SACPolicy, TD3Policy
-from tianshou.policy.modelfree.pg import TDistParams
+from tianshou.policy import A2CPolicy, BasePolicy, PPOPolicy, SACPolicy, TD3Policy
 from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
 from tianshou.utils.net.common import ActorCritic
 
 CHECKPOINT_DICT_KEY_MODEL = "model"
 CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms"
+TParams = TypeVar("TParams", bound=Params)
+TPolicy = TypeVar("TPolicy", bound=BasePolicy)
 
 
 class AgentFactory(ABC):
@@ -219,15 +223,20 @@ class _ActorAndDualCriticsMixin(_ActorAndCriticMixin):
         return self.critic2_module_opt_factory.create_module_opt(envs, device, lr)
 
 
-class PPOAgentFactory(OnpolicyAgentFactory, _ActorCriticMixin):
+class ActorCriticAgentFactory(
+    Generic[TParams, TPolicy],
+    OnpolicyAgentFactory,
+    _ActorCriticMixin,
+    ABC,
+):
     def __init__(
         self,
-        params: PPOParams,
+        params: TParams,
         sampling_config: RLSamplingConfig,
         actor_factory: ActorFactory,
         critic_factory: CriticFactory,
         optimizer_factory: OptimizerFactory,
-        dist_fn: Callable[[TDistParams], torch.distributions.Distribution],
+        policy_class: type[TPolicy],
     ):
         super().__init__(sampling_config)
         _ActorCriticMixin.__init__(
@@ -238,10 +247,14 @@ class PPOAgentFactory(OnpolicyAgentFactory, _ActorCriticMixin):
             critic_use_action=False,
         )
         self.params = params
-        self.dist_fn = dist_fn
+        self.policy_class = policy_class
 
-    def create_policy(self, envs: Environments, device: TDevice) -> PPOPolicy:
-        actor_critic = self.create_actor_critic_module_opt(envs, device, self.params.lr)
+    @abstractmethod
+    def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
+        pass
+
+    def _create_kwargs(self, envs: Environments, device: TDevice):
+        actor_critic = self._create_actor_critic(envs, device)
         kwargs = self.params.create_kwargs(
             ParamTransformerData(
                 envs=envs,
@@ -250,15 +263,59 @@ class PPOAgentFactory(OnpolicyAgentFactory, _ActorCriticMixin):
                 optim=actor_critic.optim,
             ),
         )
-        return PPOPolicy(
-            actor=actor_critic.actor,
-            critic=actor_critic.critic,
-            optim=actor_critic.optim,
-            dist_fn=self.dist_fn,
-            action_space=envs.get_action_space(),
-            **kwargs,
+        kwargs["actor"] = actor_critic.actor
+        kwargs["critic"] = actor_critic.critic
+        kwargs["optim"] = actor_critic.optim
+        kwargs["action_space"] = envs.get_action_space()
+        return kwargs
+
+    def create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
+        return self.policy_class(**self._create_kwargs(envs, device))
+
+
+class A2CAgentFactory(ActorCriticAgentFactory[A2CParams, A2CPolicy]):
+    def __init__(
+        self,
+        params: A2CParams,
+        sampling_config: RLSamplingConfig,
+        actor_factory: ActorFactory,
+        critic_factory: CriticFactory,
+        optimizer_factory: OptimizerFactory,
+    ):
+        super().__init__(
+            params,
+            sampling_config,
+            actor_factory,
+            critic_factory,
+            optimizer_factory,
+            A2CPolicy,
         )
 
+    def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
+        return self.create_actor_critic_module_opt(envs, device, self.params.lr)
+
+
+class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]):
+    def __init__(
+        self,
+        params: PPOParams,
+        sampling_config: RLSamplingConfig,
+        actor_factory: ActorFactory,
+        critic_factory: CriticFactory,
+        optimizer_factory: OptimizerFactory,
+    ):
+        super().__init__(
+            params,
+            sampling_config,
+            actor_factory,
+            critic_factory,
+            optimizer_factory,
+            PPOPolicy,
+        )
+
+    def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
+        return self.create_actor_critic_module_opt(envs, device, self.params.lr)
+
 
 class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
     def __init__(
diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py
index 7893e0d..eebd265 100644
--- a/tianshou/highlevel/experiment.py
+++ b/tianshou/highlevel/experiment.py
@@ -9,6 +9,7 @@ import torch
 
 from tianshou.data import Collector
 from tianshou.highlevel.agent import (
+    A2CAgentFactory,
     AgentFactory,
     PPOAgentFactory,
     SACAgentFactory,
@@ -25,10 +26,14 @@ from tianshou.highlevel.module import (
     CriticFactoryDefault,
 )
 from tianshou.highlevel.optim import OptimizerFactory, OptimizerFactoryAdam
-from tianshou.highlevel.params.policy_params import PPOParams, SACParams, TD3Params
+from tianshou.highlevel.params.policy_params import (
+    A2CParams,
+    PPOParams,
+    SACParams,
+    TD3Params,
+)
 from tianshou.highlevel.persistence import PersistableConfigProtocol
 from tianshou.policy import BasePolicy
-from tianshou.policy.modelfree.pg import TDistParams
 from tianshou.trainer import BaseTrainer
 
 TPolicy = TypeVar("TPolicy", bound=BasePolicy)
@@ -234,7 +239,7 @@ class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
     """Specialization of the actor mixin where, in the continuous case, the actor uses a deterministic policy."""
 
     def __init__(self):
-        super().__init__(ContinuousActorType.DETERMINISTIC)
+        super().__init__(ContinuousActorType.GAUSSIAN)
 
     def with_actor_factory_default(
         self,
@@ -343,6 +348,38 @@
         return self
 
 
+class A2CExperimentBuilder(
+    RLExperimentBuilder,
+    _BuilderMixinActorFactory_ContinuousGaussian,
+    _BuilderMixinSingleCriticFactory,
+):
+    def __init__(
+        self,
+        experiment_config: RLExperimentConfig,
+        env_factory: EnvFactory,
+        sampling_config: RLSamplingConfig,
+        env_config: PersistableConfigProtocol | None = None,
+    ):
+        super().__init__(experiment_config, env_factory, sampling_config)
+        _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
+        _BuilderMixinSingleCriticFactory.__init__(self)
+        self._params: A2CParams = A2CParams()
+        self._env_config = env_config
+
+    def with_a2c_params(self, params: A2CParams) -> Self:
+        self._params = params
+        return self
+
+    def _create_agent_factory(self) -> AgentFactory:
+        return A2CAgentFactory(
+            self._params,
+            self._sampling_config,
+            self._get_actor_factory(),
+            self._get_critic_factory(0),
+            self._get_optim_factory(),
+        )
+
+
 class PPOExperimentBuilder(
     RLExperimentBuilder,
     _BuilderMixinActorFactory_ContinuousGaussian,
@@ -353,14 +390,12 @@
         experiment_config: RLExperimentConfig,
         env_factory: EnvFactory,
         sampling_config: RLSamplingConfig,
-        dist_fn: Callable[[TDistParams], torch.distributions.Distribution],
         env_config: PersistableConfigProtocol | None = None,
     ):
-        super().__init__(experiment_config, env_factory, sampling_config, env_config=env_config)
+        super().__init__(experiment_config, env_factory, sampling_config)
         _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
         _BuilderMixinSingleCriticFactory.__init__(self)
         self._params: PPOParams = PPOParams()
-        self._dist_fn = dist_fn
         self._env_config = env_config
 
     def with_ppo_params(self, params: PPOParams) -> Self:
@@ -375,7 +410,6 @@ class PPOExperimentBuilder(
             self._get_actor_factory(),
             self._get_critic_factory(0),
             self._get_optim_factory(),
-            self._dist_fn,
         )
 
 
diff --git a/tianshou/highlevel/optim.py b/tianshou/highlevel/optim.py
index 685c6d9..ef3071d 100644
--- a/tianshou/highlevel/optim.py
+++ b/tianshou/highlevel/optim.py
@@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
 from typing import Any
 
 import torch
-from torch.optim import Adam
+from torch.optim import Adam, RMSprop
 
 
 class OptimizerFactory(ABC):
@@ -43,3 +43,23 @@ class OptimizerFactoryAdam(OptimizerFactory):
             eps=self.eps,
             weight_decay=self.weight_decay,
         )
+
+
+class OptimizerFactoryRMSprop(OptimizerFactory):
+    def __init__(self, alpha=0.99, eps=1e-08, weight_decay=0, momentum=0, centered=False):
+        self.alpha = alpha
+        self.momentum = momentum
+        self.centered = centered
+        self.weight_decay = weight_decay
+        self.eps = eps
+
+    def create_optimizer(self, module: torch.nn.Module, lr: float) -> RMSprop:
+        return RMSprop(
+            module.parameters(),
+            lr=lr,
+            alpha=self.alpha,
+            eps=self.eps,
+            weight_decay=self.weight_decay,
+            momentum=self.momentum,
+            centered=self.centered,
+        )
diff --git a/tianshou/highlevel/params/dist_fn.py b/tianshou/highlevel/params/dist_fn.py
new file mode 100644
index 0000000..a0c1c81
--- /dev/null
+++ b/tianshou/highlevel/params/dist_fn.py
@@ -0,0 +1,35 @@
+from abc import ABC, abstractmethod
+from collections.abc import Callable
+from typing import TypeAlias
+
+import torch
+
+from tianshou.highlevel.env import Environments, EnvType
+from tianshou.policy.modelfree.pg import TDistParams
+
+TDistributionFunction: TypeAlias = Callable[[TDistParams], torch.distributions.Distribution]
+
+
+class DistributionFunctionFactory(ABC):
+    @abstractmethod
+    def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
+        pass
+
+
+def _dist_fn_categorical(p):
+    return torch.distributions.Categorical(logits=p)
+
+
+def _dist_fn_gaussian(*p):
+    return torch.distributions.Independent(torch.distributions.Normal(*p), 1)
+
+
+class DistributionFunctionFactoryDefault(DistributionFunctionFactory):
+    def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
+        match envs.get_type():
+            case EnvType.DISCRETE:
+                return _dist_fn_categorical
+            case EnvType.CONTINUOUS:
+                return _dist_fn_gaussian
+            case _:
+                raise ValueError(envs.get_type())
diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py
index 9fcebc6..dca4a60 100644
--- a/tianshou/highlevel/params/policy_params.py
+++ b/tianshou/highlevel/params/policy_params.py
@@ -9,6 +9,11 @@ from tianshou.highlevel.env import Environments
 from tianshou.highlevel.module import ModuleOpt, TDevice
 from tianshou.highlevel.optim import OptimizerFactory
 from tianshou.highlevel.params.alpha import AutoAlphaFactory
+from tianshou.highlevel.params.dist_fn import (
+    DistributionFunctionFactory,
+    DistributionFunctionFactoryDefault,
+    TDistributionFunction,
+)
 from tianshou.highlevel.params.env_param import EnvValueFactory, FloatEnvValueFactory
 from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
 from tianshou.highlevel.params.noise import NoiseFactory
@@ -34,6 +39,12 @@ class ParamTransformerData:
 
 
 class ParamTransformer(ABC):
+    """Transforms one or more parameters from the representation used by the high-level API
+    to the representation required by the (low-level) policy implementation.
+    It operates directly on a dictionary of keyword arguments, which is initially
+    generated from the parameter dataclass (subclass of `Params`).
+ """ + @abstractmethod def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None: pass @@ -159,6 +170,18 @@ class ParamTransformerFloatEnvParamFactory(ParamTransformer): kwargs[self.key] = value.create_value(data.envs) +class ParamTransformerDistributionFunction(ParamTransformer): + def __init__(self, key: str): + self.key = key + + def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None: + value = kwargs[self.key] + if value == "default": + kwargs[self.key] = DistributionFunctionFactoryDefault().create_dist_fn(data.envs) + elif isinstance(value, DistributionFunctionFactory): + kwargs[self.key] = value.create_dist_fn(data.envs) + + class GetParamTransformersProtocol(Protocol): def _get_param_transformers(self) -> list[ParamTransformer]: pass @@ -200,16 +223,23 @@ class PGParams(Params): @dataclass -class A2CParams(PGParams): +class A2CParams(PGParams, ParamsMixinLearningRateWithScheduler): vf_coef: float = 0.5 ent_coef: float = 0.01 max_grad_norm: float | None = None gae_lambda: float = 0.95 max_batchsize: int = 256 + dist_fn: TDistributionFunction | DistributionFunctionFactory | Literal["default"] = "default" + + def _get_param_transformers(self) -> list[ParamTransformer]: + transformers = super()._get_param_transformers() + transformers.extend(ParamsMixinLearningRateWithScheduler._get_param_transformers(self)) + transformers.append(ParamTransformerDistributionFunction("dist_fn")) + return transformers @dataclass -class PPOParams(A2CParams, ParamsMixinLearningRateWithScheduler): +class PPOParams(A2CParams): """PPO specific config.""" eps_clip: float = 0.2