diff --git a/examples/mujoco/mujoco_ppo_hl.py b/examples/mujoco/mujoco_ppo_hl.py
index bf012d3..ac0d2e1 100644
--- a/examples/mujoco/mujoco_ppo_hl.py
+++ b/examples/mujoco/mujoco_ppo_hl.py
@@ -9,18 +9,13 @@
 from jsonargparse import CLI
 from torch.distributions import Independent, Normal
 
 from examples.mujoco.mujoco_env import MujocoEnvFactory
-from tianshou.highlevel.agent import PPOAgentFactory, PPOConfig
+from tianshou.highlevel.agent import PPOConfig
 from tianshou.highlevel.config import RLSamplingConfig
 from tianshou.highlevel.experiment import (
-    RLExperiment,
+    PPOExperimentBuilder,
     RLExperimentConfig,
 )
-from tianshou.highlevel.logger import DefaultLoggerFactory
-from tianshou.highlevel.module import (
-    ContinuousActorProbFactory,
-    ContinuousNetCriticFactory,
-)
-from tianshou.highlevel.optim import AdamOptimizerFactory, LinearLRSchedulerFactory
+from tianshou.highlevel.optim import LinearLRSchedulerFactory
 
 
 def main(
@@ -52,7 +47,6 @@
 ):
     now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
     log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
-    logger_factory = DefaultLoggerFactory()
 
     sampling_config = RLSamplingConfig(
         num_epochs=epoch,
@@ -70,37 +64,33 @@
     def dist_fn(*logits):
         return Independent(Normal(*logits), 1)
 
-    ppo_config = PPOConfig(
-        gamma=gamma,
-        gae_lambda=gae_lambda,
-        action_bound_method=bound_action_method,
-        rew_norm=rew_norm,
-        ent_coef=ent_coef,
-        vf_coef=vf_coef,
-        max_grad_norm=max_grad_norm,
-        value_clip=value_clip,
-        norm_adv=norm_adv,
-        eps_clip=eps_clip,
-        dual_clip=dual_clip,
-        recompute_adv=recompute_adv,
+    experiment = (
+        PPOExperimentBuilder(experiment_config, env_factory, sampling_config)
+        .with_ppo_params(
+            PPOConfig(
+                gamma=gamma,
+                gae_lambda=gae_lambda,
+                action_bound_method=bound_action_method,
+                rew_norm=rew_norm,
+                ent_coef=ent_coef,
+                vf_coef=vf_coef,
+                max_grad_norm=max_grad_norm,
+                value_clip=value_clip,
+                norm_adv=norm_adv,
+                eps_clip=eps_clip,
+                dual_clip=dual_clip,
+                recompute_adv=recompute_adv,
+                dist_fn=dist_fn,
+                lr=lr,
+                lr_scheduler_factory=LinearLRSchedulerFactory(sampling_config)
+                if lr_decay
+                else None,
+            ),
+        )
+        .with_actor_factory_default(hidden_sizes)
+        .with_critic_factory_default(hidden_sizes)
+        .build()
     )
-    actor_factory = ContinuousActorProbFactory(hidden_sizes)
-    critic_factory = ContinuousNetCriticFactory(hidden_sizes)
-    optim_factory = AdamOptimizerFactory()
-    lr_scheduler_factory = LinearLRSchedulerFactory(sampling_config) if lr_decay else None
-    agent_factory = PPOAgentFactory(
-        ppo_config,
-        sampling_config,
-        actor_factory,
-        critic_factory,
-        optim_factory,
-        dist_fn,
-        lr,
-        lr_scheduler_factory,
-    )
-
-    experiment = RLExperiment(experiment_config, env_factory, logger_factory, agent_factory)
-    experiment.run(log_name)
diff --git a/examples/mujoco/mujoco_sac_hl.py b/examples/mujoco/mujoco_sac_hl.py
index d76ff21..1d5ceb2 100644
--- a/examples/mujoco/mujoco_sac_hl.py
+++ b/examples/mujoco/mujoco_sac_hl.py
@@ -7,18 +7,12 @@
 from collections.abc import Sequence
 
 from jsonargparse import CLI
 
 from examples.mujoco.mujoco_env import MujocoEnvFactory
-from tianshou.highlevel.agent import DefaultAutoAlphaFactory, SACAgentFactory, SACConfig
+from tianshou.highlevel.agent import DefaultAutoAlphaFactory, SACConfig
 from tianshou.highlevel.config import RLSamplingConfig
 from tianshou.highlevel.experiment import (
-    RLExperiment,
     RLExperimentConfig,
+    SACExperimentBuilder,
 )
-from tianshou.highlevel.logger import DefaultLoggerFactory
-from tianshou.highlevel.module import (
-    ContinuousActorProbFactory,
-    ContinuousNetCriticFactory,
-)
-from tianshou.highlevel.optim import AdamOptimizerFactory
 
 
 def main(
@@ -45,7 +39,6 @@
 ):
     now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
     log_name = os.path.join(task, "sac", str(experiment_config.seed), now)
-    logger_factory = DefaultLoggerFactory()
 
     sampling_config = RLSamplingConfig(
         num_epochs=epoch,
@@ -62,31 +55,25 @@
     env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
 
-    if auto_alpha:
-        alpha = DefaultAutoAlphaFactory(lr=alpha_lr)
-    sac_config = SACConfig(
-        tau=tau,
-        gamma=gamma,
-        alpha=alpha,
-        estimation_step=n_step,
-        actor_lr=actor_lr,
-        critic1_lr=critic_lr,
-        critic2_lr=critic_lr,
+    experiment = (
+        SACExperimentBuilder(experiment_config, env_factory, sampling_config)
+        .with_sac_params(
+            SACConfig(
+                tau=tau,
+                gamma=gamma,
+                alpha=DefaultAutoAlphaFactory(lr=alpha_lr) if auto_alpha else alpha,
+                estimation_step=n_step,
+                actor_lr=actor_lr,
+                critic1_lr=critic_lr,
+                critic2_lr=critic_lr,
+            ),
+        )
+        .with_actor_factory_default(
+            hidden_sizes, continuous_unbounded=True, continuous_conditioned_sigma=True,
+        )
+        .with_common_critic_factory_default(hidden_sizes)
+        .build()
     )
-    actor_factory = ContinuousActorProbFactory(hidden_sizes, conditioned_sigma=True)
-    critic_factory = ContinuousNetCriticFactory(hidden_sizes)
-    optim_factory = AdamOptimizerFactory()
-    agent_factory = SACAgentFactory(
-        sac_config,
-        sampling_config,
-        actor_factory,
-        critic_factory,
-        critic_factory,
-        optim_factory,
-    )
-
-    experiment = RLExperiment(experiment_config, env_factory, logger_factory, agent_factory)
-    experiment.run(log_name)
diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py
index 4c49b13..91fe3c4 100644
--- a/tianshou/highlevel/agent.py
+++ b/tianshou/highlevel/agent.py
@@ -161,6 +161,9 @@ class PPOConfig(PGConfig):
     eps_clip: float = 0.2
     dual_clip: float | None = None
     recompute_adv: bool = True
+    dist_fn: Callable | None = None
+    lr: float = 1e-3
+    lr_scheduler_factory: LRSchedulerFactory | None = None
 
 
 class PPOAgentFactory(OnpolicyAgentFactory):
@@ -171,26 +174,20 @@ class PPOAgentFactory(OnpolicyAgentFactory):
     def __init__(
         self,
         config: PPOConfig,
         sampling_config: RLSamplingConfig,
         actor_factory: ActorFactory,
         critic_factory: CriticFactory,
         optimizer_factory: OptimizerFactory,
-        dist_fn,
-        lr: float,
-        lr_scheduler_factory: LRSchedulerFactory | None = None,
     ):
         super().__init__(sampling_config)
         self.optimizer_factory = optimizer_factory
         self.critic_factory = critic_factory
         self.actor_factory = actor_factory
         self.config = config
-        self.lr = lr
-        self.lr_scheduler_factory = lr_scheduler_factory
-        self.dist_fn = dist_fn
 
     def create_policy(self, envs: Environments, device: TDevice) -> PPOPolicy:
         actor = self.actor_factory.create_module(envs, device)
         critic = self.critic_factory.create_module(envs, device, use_action=False)
         actor_critic = ActorCritic(actor, critic)
-        optim = self.optimizer_factory.create_optimizer(actor_critic, self.lr)
-        if self.lr_scheduler_factory is not None:
-            lr_scheduler = self.lr_scheduler_factory.create_scheduler(optim)
+        optim = self.optimizer_factory.create_optimizer(actor_critic, self.config.lr)
+        if self.config.lr_scheduler_factory is not None:
+            lr_scheduler = self.config.lr_scheduler_factory.create_scheduler(optim)
         else:
             lr_scheduler = None
         return PPOPolicy(
@@ -198,7 +195,7 @@ class PPOAgentFactory(OnpolicyAgentFactory):
             actor,
             critic,
             optim,
-            dist_fn=self.dist_fn,
+            dist_fn=self.config.dist_fn,
             lr_scheduler=lr_scheduler,
             # env-stuff
             action_space=envs.get_action_space(),
diff --git a/tianshou/highlevel/env.py b/tianshou/highlevel/env.py
index 90a8044..273bb0a 100644
--- a/tianshou/highlevel/env.py
+++ b/tianshou/highlevel/env.py
@@ -1,5 +1,6 @@
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
+from enum import Enum
 from typing import Any
 
 import gymnasium as gym
@@ -9,6 +10,17 @@ from tianshou.env import BaseVectorEnv
 
 TShape = int | Sequence[int]
 
+
+class EnvType(Enum):
+    CONTINUOUS = "continuous"
+    DISCRETE = "discrete"
+
+    def is_discrete(self):
+        return self == EnvType.DISCRETE
+
+    def is_continuous(self):
+        return self == EnvType.CONTINUOUS
+
 
 class Environments(ABC):
     def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
         self.env = env
@@ -29,6 +41,10 @@
     def get_action_space(self) -> gym.Space:
         return self.env.action_space
 
+    @abstractmethod
+    def get_type(self) -> EnvType:
+        pass
+
 
 class ContinuousEnvironments(Environments):
     def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
@@ -62,6 +78,9 @@
     def get_state_shape(self) -> TShape:
         return self.state_shape
 
+    def get_type(self):
+        return EnvType.CONTINUOUS
+
 
 class EnvFactory(ABC):
     @abstractmethod
diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py
index 1012028..7b65a1f 100644
--- a/tianshou/highlevel/experiment.py
+++ b/tianshou/highlevel/experiment.py
@@ -1,3 +1,5 @@
+from abc import abstractmethod
+from collections.abc import Sequence
 from dataclasses import dataclass
 from pprint import pprint
 from typing import Generic, TypeVar
@@ -6,9 +8,17 @@
 import numpy as np
 import torch
 
 from tianshou.data import Collector
-from tianshou.highlevel.agent import AgentFactory
+from tianshou.highlevel.agent import AgentFactory, PPOAgentFactory, PPOConfig, SACConfig
+from tianshou.highlevel.config import RLSamplingConfig
 from tianshou.highlevel.env import EnvFactory
-from tianshou.highlevel.logger import LoggerFactory
+from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory
+from tianshou.highlevel.module import (
+    ActorFactory,
+    CriticFactory,
+    DefaultActorFactory,
+    DefaultCriticFactory,
+)
+from tianshou.highlevel.optim import AdamOptimizerFactory, OptimizerFactory
 from tianshou.policy import BasePolicy
 from tianshou.trainer import BaseTrainer
@@ -38,13 +48,15 @@ class RLExperiment(Generic[TPolicy, TTrainer]):
         self,
         config: RLExperimentConfig,
         env_factory: EnvFactory,
-        logger_factory: LoggerFactory,
         agent_factory: AgentFactory,
+        logger_factory: LoggerFactory | None = None,
     ):
+        if logger_factory is None:
+            logger_factory = DefaultLoggerFactory()
         self.config = config
         self.env_factory = env_factory
-        self.logger_factory = logger_factory
         self.agent_factory = agent_factory
+        self.logger_factory = logger_factory
 
     def _set_seed(self) -> None:
         seed = self.config.seed
@@ -109,3 +121,214 @@ class RLExperiment(Generic[TPolicy, TTrainer]):
         test_collector.reset()
         result = test_collector.collect(n_episode=num_episodes, render=render)
         print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
+
+
+TBuilder = TypeVar("TBuilder", bound="RLExperimentBuilder")
+
+
+class RLExperimentBuilder:
+    def __init__(
+        self,
+        experiment_config: RLExperimentConfig,
+        env_factory: EnvFactory,
+        sampling_config: RLSamplingConfig,
+    ):
+        self._config = experiment_config
+        self._env_factory = env_factory
+        self._sampling_config = sampling_config
+        self._logger_factory: LoggerFactory | None = None
+        self._optim_factory: OptimizerFactory | None = None
+
+    def with_logger_factory(self: TBuilder, logger_factory: LoggerFactory) -> TBuilder:
+        self._logger_factory = logger_factory
+        return self
+
+    def with_optim_factory(self: TBuilder, optim_factory: OptimizerFactory) -> TBuilder:
+        self._optim_factory = optim_factory
+        return self
+
+    def with_optim_factory_default(
+        self: TBuilder, betas=(0.9, 0.999), eps=1e-08, weight_decay=0,
+    ) -> TBuilder:
+        """Configures the use of the default optimizer, Adam, with the given parameters.
+
+        :param betas: coefficients used for computing running averages of gradient and its square
+        :param eps: term added to the denominator to improve numerical stability
+        :param weight_decay: weight decay (L2 penalty)
+        :return: the builder
+        """
+        self._optim_factory = AdamOptimizerFactory(betas=betas, eps=eps, weight_decay=weight_decay)
+        return self
+
+    @abstractmethod
+    def _create_agent_factory(self) -> AgentFactory:
+        pass
+
+    def _get_optim_factory(self) -> OptimizerFactory:
+        if self._optim_factory is None:
+            return AdamOptimizerFactory()
+        else:
+            return self._optim_factory
+
+    def build(self) -> RLExperiment:
+        return RLExperiment(
+            self._config, self._env_factory, self._create_agent_factory(), self._logger_factory,
+        )
+
+
+class _BuilderMixinActorFactory:
+    def __init__(self):
+        self._actor_factory: ActorFactory | None = None
+
+    def with_actor_factory(self: TBuilder, actor_factory: ActorFactory) -> TBuilder:
+        self: TBuilder | _BuilderMixinActorFactory
+        self._actor_factory = actor_factory
+        return self
+
+    def with_actor_factory_default(
+        self: TBuilder,
+        hidden_sizes: Sequence[int],
+        continuous_unbounded=False,
+        continuous_conditioned_sigma=False,
+    ) -> TBuilder:
+        self: TBuilder | _BuilderMixinActorFactory
+        self._actor_factory = DefaultActorFactory(
+            hidden_sizes,
+            continuous_unbounded=continuous_unbounded,
+            continuous_conditioned_sigma=continuous_conditioned_sigma,
+        )
+        return self
+
+    def _get_actor_factory(self):
+        if self._actor_factory is None:
+            return DefaultActorFactory()
+        else:
+            return self._actor_factory
+
+
+class _BuilderMixinCriticsFactory:
+    def __init__(self, num_critics: int):
+        self._critic_factories: list[CriticFactory | None] = [None] * num_critics
+
+    def _with_critic_factory(self, idx: int, critic_factory: CriticFactory):
+        self._critic_factories[idx] = critic_factory
+        return self
+
+    def _with_critic_factory_default(self, idx: int, hidden_sizes: Sequence[int]):
+        self._critic_factories[idx] = DefaultCriticFactory(hidden_sizes)
+        return self
+
+    def _get_critic_factory(self, idx: int):
+        factory = self._critic_factories[idx]
+        if factory is None:
+            return DefaultCriticFactory()
+        else:
+            return factory
+
+
+class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
+    def __init__(self):
+        super().__init__(1)
+
+    def with_critic_factory(self: TBuilder, critic_factory: CriticFactory) -> TBuilder:
+        self: TBuilder | "_BuilderMixinSingleCriticFactory"
+        self._with_critic_factory(0, critic_factory)
+        return self
+
+    def with_critic_factory_default(
+        self: TBuilder, hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
+    ) -> TBuilder:
+        self: TBuilder | "_BuilderMixinSingleCriticFactory"
+        self._with_critic_factory_default(0, hidden_sizes)
+        return self
+
+
+class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
+    def __init__(self):
+        super().__init__(2)
+
+    def with_common_critic_factory(self: TBuilder, critic_factory: CriticFactory) -> TBuilder:
+        self: TBuilder | "_BuilderMixinDualCriticFactory"
+        for i in range(len(self._critic_factories)):
+            self._with_critic_factory(i, critic_factory)
+        return self
+
+    def with_common_critic_factory_default(
+        self, hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
+    ) -> TBuilder:
+        self: TBuilder | "_BuilderMixinDualCriticFactory"
+        for i in range(len(self._critic_factories)):
+            self._with_critic_factory_default(i, hidden_sizes)
+        return self
+
+    def with_critic1_factory(self: TBuilder, critic_factory: CriticFactory) -> TBuilder:
+        self: TBuilder | "_BuilderMixinDualCriticFactory"
+        self._with_critic_factory(0, critic_factory)
+        return self
+
+    def with_critic1_factory_default(
+        self, hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
+    ) -> TBuilder:
+        self: TBuilder | "_BuilderMixinDualCriticFactory"
+        self._with_critic_factory_default(0, hidden_sizes)
+        return self
+
+    def with_critic2_factory(self: TBuilder, critic_factory: CriticFactory) -> TBuilder:
+        self: TBuilder | "_BuilderMixinDualCriticFactory"
+        self._with_critic_factory(1, critic_factory)
+        return self
+
+    def with_critic2_factory_default(
+        self, hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
+    ) -> TBuilder:
+        self: TBuilder | "_BuilderMixinDualCriticFactory"
+        self._with_critic_factory_default(1, hidden_sizes)
+        return self
+
+
+class PPOExperimentBuilder(
+    RLExperimentBuilder, _BuilderMixinActorFactory, _BuilderMixinSingleCriticFactory,
+):
+    def __init__(
+        self,
+        experiment_config: RLExperimentConfig,
+        env_factory: EnvFactory,
+        sampling_config: RLSamplingConfig,
+    ):
+        super().__init__(experiment_config, env_factory, sampling_config)
+        _BuilderMixinActorFactory.__init__(self)
+        _BuilderMixinSingleCriticFactory.__init__(self)
+        self._params: PPOConfig = PPOConfig()
+
+    def with_ppo_params(self, params: PPOConfig) -> "PPOExperimentBuilder":
+        self._params = params
+        return self
+
+    def _create_agent_factory(self) -> AgentFactory:
+        return PPOAgentFactory(
+            self._params,
+            self._sampling_config,
+            self._get_actor_factory(),
+            self._get_critic_factory(0),
+            self._get_optim_factory(),
+        )
+
+
+class SACExperimentBuilder(
+    RLExperimentBuilder, _BuilderMixinActorFactory, _BuilderMixinDualCriticFactory,
+):
+    def __init__(
+        self,
+        experiment_config: RLExperimentConfig,
+        env_factory: EnvFactory,
+        sampling_config: RLSamplingConfig,
+    ):
+        super().__init__(experiment_config, env_factory, sampling_config)
+        _BuilderMixinActorFactory.__init__(self)
+        _BuilderMixinDualCriticFactory.__init__(self)
+        self._params: SACConfig = SACConfig()
+
+    def with_sac_params(self, params: SACConfig) -> "SACExperimentBuilder":
+        self._params = params
+        return self
diff --git a/tianshou/highlevel/module.py b/tianshou/highlevel/module.py
index 8686218..17c5767 100644
--- a/tianshou/highlevel/module.py
+++ b/tianshou/highlevel/module.py
@@ -5,7 +5,7 @@ import numpy as np
 import torch
 from torch import nn
 
-from tianshou.highlevel.env import Environments
+from tianshou.highlevel.env import Environments, EnvType
 from tianshou.utils.net.common import Net
 from tianshou.utils.net.continuous import ActorProb
 from tianshou.utils.net.continuous import Critic as ContinuousCritic
@@ -46,8 +46,41 @@ class ActorFactory(ABC):
             m.weight.data.copy_(0.01 * m.weight.data)
 
 
+class DefaultActorFactory(ActorFactory):
+    """
+    An actor factory which, depending on the type of environment, creates a suitable MLP-based policy
+    """
+
+    DEFAULT_HIDDEN_SIZES = (64, 64)
+
+    def __init__(
+        self,
+        hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES,
+        continuous_unbounded=False,
+        continuous_conditioned_sigma=False,
+    ):
+        self.continuous_unbounded = continuous_unbounded
+        self.continuous_conditioned_sigma = continuous_conditioned_sigma
+        self.hidden_sizes = hidden_sizes
+
+    def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
+        env_type = envs.get_type()
+        if env_type == EnvType.CONTINUOUS:
+            factory = ContinuousActorProbFactory(
+                self.hidden_sizes,
+                unbounded=self.continuous_unbounded,
+                conditioned_sigma=self.continuous_conditioned_sigma,
+            )
+            return factory.create_module(envs, device)
+        elif env_type == EnvType.DISCRETE:
+            raise NotImplementedError
+        else:
+            raise ValueError(f"{env_type} not supported")
+
+
 class ContinuousActorFactory(ActorFactory, ABC):
-    pass
+    """Serves as a type bound for actor factories that are suitable for continuous action spaces."""
+
 
 
 class ContinuousActorProbFactory(ContinuousActorFactory):
@@ -85,13 +118,31 @@ class CriticFactory(ABC):
         pass
 
 
+class DefaultCriticFactory(CriticFactory):
+    """A critic factory which, depending on the type of environment, creates a suitable MLP-based critic."""
+
+    DEFAULT_HIDDEN_SIZES = (64, 64)
+
+    def __init__(self, hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES):
+        self.hidden_sizes = hidden_sizes
+
+    def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
+        env_type = envs.get_type()
+        if env_type == EnvType.CONTINUOUS:
+            factory = ContinuousNetCriticFactory(self.hidden_sizes)
+            return factory.create_module(envs, device, use_action)
+        elif env_type == EnvType.DISCRETE:
+            raise NotImplementedError
+        else:
+            raise ValueError(f"{env_type} not supported")
+
+
 class ContinuousCriticFactory(CriticFactory, ABC):
     pass
 
 
 class ContinuousNetCriticFactory(ContinuousCriticFactory):
-    def __init__(self, hidden_sizes: Sequence[int], action_shape=0):
-        self.action_shape = action_shape
+    def __init__(self, hidden_sizes: Sequence[int]):
         self.hidden_sizes = hidden_sizes
 
     def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
diff --git a/tianshou/highlevel/optim.py b/tianshou/highlevel/optim.py
index edbe0a4..ee6677a 100644
--- a/tianshou/highlevel/optim.py
+++ b/tianshou/highlevel/optim.py
@@ -29,8 +29,19 @@ class TorchOptimizerFactory(OptimizerFactory):
 
 
 class AdamOptimizerFactory(OptimizerFactory):
+    def __init__(self, betas=(0.9, 0.999), eps=1e-08, weight_decay=0):
+        self.weight_decay = weight_decay
+        self.eps = eps
+        self.betas = betas
+
     def create_optimizer(self, module: torch.nn.Module, lr: float) -> Adam:
-        return Adam(module.parameters(), lr=lr)
+        return Adam(
+            module.parameters(),
+            lr=lr,
+            betas=self.betas,
+            eps=self.eps,
+            weight_decay=self.weight_decay,
+        )
 
 
 class LRSchedulerFactory(ABC):