From 41bd463a7bac45dc2435fb9cb95b383d909099f9 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Wed, 18 Oct 2023 13:57:36 +0200 Subject: [PATCH] Allow to configure activation function in default networks * Set ReLU as default in all actor and critic factories * Configure non-default in applicable MuJoCo examples --- examples/mujoco/mujoco_a2c_hl.py | 5 +-- examples/mujoco/mujoco_npg_hl.py | 5 +-- examples/mujoco/mujoco_ppo_hl.py | 5 +-- examples/mujoco/mujoco_reinforce_hl.py | 3 +- examples/mujoco/mujoco_td3_hl.py | 5 +-- examples/mujoco/mujoco_trpo_hl.py | 5 +-- tianshou/highlevel/experiment.py | 45 ++++++++++++++++++++------ tianshou/highlevel/module/actor.py | 27 +++++++++++++--- tianshou/highlevel/module/critic.py | 29 ++++++++++++----- 9 files changed, 96 insertions(+), 33 deletions(-) diff --git a/examples/mujoco/mujoco_a2c_hl.py b/examples/mujoco/mujoco_a2c_hl.py index ae19684..76de96f 100644 --- a/examples/mujoco/mujoco_a2c_hl.py +++ b/examples/mujoco/mujoco_a2c_hl.py @@ -5,6 +5,7 @@ from collections.abc import Sequence from typing import Literal from jsonargparse import CLI +from torch import nn from examples.mujoco.mujoco_env import MujocoEnvFactory from tianshou.highlevel.config import SamplingConfig @@ -74,8 +75,8 @@ def main( ), ) .with_optim_factory(OptimizerFactoryRMSprop(eps=1e-5, alpha=0.99)) - .with_actor_factory_default(hidden_sizes, continuous_unbounded=True) - .with_critic_factory_default(hidden_sizes) + .with_actor_factory_default(hidden_sizes, nn.Tanh, continuous_unbounded=True) + .with_critic_factory_default(hidden_sizes, nn.Tanh) .build() ) experiment.run(log_name) diff --git a/examples/mujoco/mujoco_npg_hl.py b/examples/mujoco/mujoco_npg_hl.py index 27c916b..bcf2bc6 100644 --- a/examples/mujoco/mujoco_npg_hl.py +++ b/examples/mujoco/mujoco_npg_hl.py @@ -4,6 +4,7 @@ import os from collections.abc import Sequence from typing import Literal +import torch from jsonargparse import CLI from examples.mujoco.mujoco_env import MujocoEnvFactory @@ -76,8 +77,8 @@ def main( dist_fn=DistributionFunctionFactoryIndependentGaussians(), ), ) - .with_actor_factory_default(hidden_sizes, continuous_unbounded=True) - .with_critic_factory_default(hidden_sizes) + .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True) + .with_critic_factory_default(hidden_sizes, torch.nn.Tanh) .build() ) experiment.run(log_name) diff --git a/examples/mujoco/mujoco_ppo_hl.py b/examples/mujoco/mujoco_ppo_hl.py index 8ad374d..0d121aa 100644 --- a/examples/mujoco/mujoco_ppo_hl.py +++ b/examples/mujoco/mujoco_ppo_hl.py @@ -4,6 +4,7 @@ import os from collections.abc import Sequence from typing import Literal +import torch from jsonargparse import CLI from examples.mujoco.mujoco_env import MujocoEnvFactory @@ -86,8 +87,8 @@ def main( dist_fn=DistributionFunctionFactoryIndependentGaussians(), ), ) - .with_actor_factory_default(hidden_sizes, continuous_unbounded=True) - .with_critic_factory_default(hidden_sizes) + .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True) + .with_critic_factory_default(hidden_sizes, torch.nn.Tanh) .build() ) experiment.run(log_name) diff --git a/examples/mujoco/mujoco_reinforce_hl.py b/examples/mujoco/mujoco_reinforce_hl.py index 33b1462..fed5b06 100644 --- a/examples/mujoco/mujoco_reinforce_hl.py +++ b/examples/mujoco/mujoco_reinforce_hl.py @@ -4,6 +4,7 @@ import os from collections.abc import Sequence from typing import Literal +import torch from jsonargparse import CLI from examples.mujoco.mujoco_env import MujocoEnvFactory @@ -64,7 +65,7 @@ def main( else None, ), ) - .with_actor_factory_default(hidden_sizes, continuous_unbounded=True) + .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True) .build() ) experiment.run(log_name) diff --git a/examples/mujoco/mujoco_td3_hl.py b/examples/mujoco/mujoco_td3_hl.py index fd68549..1a5a3ad 100644 --- a/examples/mujoco/mujoco_td3_hl.py +++ b/examples/mujoco/mujoco_td3_hl.py @@ -3,6 +3,7 @@ import os from collections.abc import Sequence +import torch from jsonargparse import CLI from examples.mujoco.mujoco_env import MujocoEnvFactory @@ -76,8 +77,8 @@ def main( critic2_lr=critic_lr, ), ) - .with_actor_factory_default(hidden_sizes) - .with_common_critic_factory_default(hidden_sizes) + .with_actor_factory_default(hidden_sizes, torch.nn.Tanh) + .with_common_critic_factory_default(hidden_sizes, torch.nn.Tanh) .build() ) experiment.run(log_name) diff --git a/examples/mujoco/mujoco_trpo_hl.py b/examples/mujoco/mujoco_trpo_hl.py index 7df9580..2e11f4e 100644 --- a/examples/mujoco/mujoco_trpo_hl.py +++ b/examples/mujoco/mujoco_trpo_hl.py @@ -4,6 +4,7 @@ import os from collections.abc import Sequence from typing import Literal +import torch from jsonargparse import CLI from examples.mujoco.mujoco_env import MujocoEnvFactory @@ -80,8 +81,8 @@ def main( dist_fn=DistributionFunctionFactoryIndependentGaussians(), ), ) - .with_actor_factory_default(hidden_sizes, continuous_unbounded=True) - .with_critic_factory_default(hidden_sizes) + .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True) + .with_critic_factory_default(hidden_sizes, torch.nn.Tanh) .build() ) experiment.run(log_name) diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index 29488c0..4db4655 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -80,6 +80,7 @@ from tianshou.highlevel.world import World from tianshou.policy import BasePolicy from tianshou.utils import LazyLogger, logging from tianshou.utils.logging import datetime_tag +from tianshou.utils.net.common import ModuleType from tianshou.utils.string import ToStringMixin log = logging.getLogger(__name__) @@ -440,6 +441,7 @@ class _BuilderMixinActorFactory(ActorFutureProviderProtocol): def _with_actor_factory_default( self, hidden_sizes: Sequence[int], + hidden_activation: ModuleType = torch.nn.ReLU, continuous_unbounded: bool = False, continuous_conditioned_sigma: bool = False, ) -> Self: @@ -452,6 +454,7 @@ class _BuilderMixinActorFactory(ActorFutureProviderProtocol): self._actor_factory = ActorFactoryDefault( self._continuous_actor_type, hidden_sizes, + hidden_activation=hidden_activation, continuous_unbounded=continuous_unbounded, continuous_conditioned_sigma=continuous_conditioned_sigma, ) @@ -481,6 +484,7 @@ class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory): def with_actor_factory_default( self, hidden_sizes: Sequence[int], + hidden_activation: ModuleType = torch.nn.ReLU, continuous_unbounded: bool = False, continuous_conditioned_sigma: bool = False, ) -> Self: @@ -488,6 +492,7 @@ class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory): The default actor factory uses an MLP-style architecture. :param hidden_sizes: dimensions of hidden layers used by the network + :param hidden_activation: the activation function to use for hidden layers :param continuous_unbounded: whether, for continuous action spaces, to apply tanh activation on final logits :param continuous_conditioned_sigma: whether, for continuous action spaces, the standard deviation of continuous actions (sigma) shall be computed from the input; if False, sigma is an independent parameter. @@ -495,6 +500,7 @@ class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory): """ return super()._with_actor_factory_default( hidden_sizes, + hidden_activation=hidden_activation, continuous_unbounded=continuous_unbounded, continuous_conditioned_sigma=continuous_conditioned_sigma, ) @@ -506,14 +512,19 @@ class _BuilderMixinActorFactory_ContinuousDeterministic(_BuilderMixinActorFactor def __init__(self) -> None: super().__init__(ContinuousActorType.DETERMINISTIC) - def with_actor_factory_default(self, hidden_sizes: Sequence[int]) -> Self: + def with_actor_factory_default( + self, + hidden_sizes: Sequence[int], + hidden_activation: ModuleType = torch.nn.ReLU, + ) -> Self: """Defines use of the default actor factory, allowing its parameters it to be customized. The default actor factory uses an MLP-style architecture. :param hidden_sizes: dimensions of hidden layers used by the network + :param hidden_activation: the activation function to use for hidden layers :return: the builder """ - return super()._with_actor_factory_default(hidden_sizes) + return super()._with_actor_factory_default(hidden_sizes, hidden_activation) class _BuilderMixinCriticsFactory: @@ -525,8 +536,16 @@ class _BuilderMixinCriticsFactory: self._critic_factories[idx] = critic_factory return self - def _with_critic_factory_default(self, idx: int, hidden_sizes: Sequence[int]) -> Self: - self._critic_factories[idx] = CriticFactoryDefault(hidden_sizes) + def _with_critic_factory_default( + self, + idx: int, + hidden_sizes: Sequence[int], + hidden_activation: ModuleType = torch.nn.ReLU, + ) -> Self: + self._critic_factories[idx] = CriticFactoryDefault( + hidden_sizes, + hidden_activation=hidden_activation, + ) return self def _with_critic_factory_use_actor(self, idx: int) -> Self: @@ -559,13 +578,15 @@ class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory): def with_critic_factory_default( self, hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES, + hidden_activation: ModuleType = torch.nn.ReLU, ) -> Self: """Makes the critic use the default, MLP-style architecture with the given parameters. :param hidden_sizes: the sequence of dimensions to use in hidden layers of the network + :param hidden_activation: the activation function to use for hidden layers :return: the builder """ - self._with_critic_factory_default(0, hidden_sizes) + self._with_critic_factory_default(0, hidden_sizes, hidden_activation) return self @@ -595,14 +616,16 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory): def with_common_critic_factory_default( self, hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES, + hidden_activation: ModuleType = torch.nn.ReLU, ) -> Self: """Makes both critics use the default, MLP-style architecture with the given parameters. :param hidden_sizes: the sequence of dimensions to use in hidden layers of the network + :param hidden_activation: the activation function to use for hidden layers :return: the builder """ for i in range(len(self._critic_factories)): - self._with_critic_factory_default(i, hidden_sizes) + self._with_critic_factory_default(i, hidden_sizes, hidden_activation) return self def with_common_critic_factory_use_actor(self) -> Self: @@ -623,13 +646,15 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory): def with_critic1_factory_default( self, hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES, + hidden_activation: ModuleType = torch.nn.ReLU, ) -> Self: """Makes the first critic use the default, MLP-style architecture with the given parameters. :param hidden_sizes: the sequence of dimensions to use in hidden layers of the network + :param hidden_activation: the activation function to use for hidden layers :return: the builder """ - self._with_critic_factory_default(0, hidden_sizes) + self._with_critic_factory_default(0, hidden_sizes, hidden_activation) return self def with_critic1_factory_use_actor(self) -> Self: @@ -648,13 +673,15 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory): def with_critic2_factory_default( self, hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES, + hidden_activation: ModuleType = torch.nn.ReLU, ) -> Self: """Makes the second critic use the default, MLP-style architecture with the given parameters. :param hidden_sizes: the sequence of dimensions to use in hidden layers of the network + :param hidden_activation: the activation function to use for hidden layers :return: the builder """ - self._with_critic_factory_default(0, hidden_sizes) + self._with_critic_factory_default(0, hidden_sizes, hidden_activation) return self def with_critic2_factory_use_actor(self) -> Self: @@ -670,7 +697,7 @@ class _BuilderMixinCriticEnsembleFactory: """Specifies that the given factory shall be used for the critic ensemble. If unspecified, the default factory (:class:`CriticEnsembleFactoryDefault`) is used. - :param critic_factory: the critic factory + :param factory: the critic ensemble factory :return: the builder """ self.critic_ensemble_factory = factory diff --git a/tianshou/highlevel/module/actor.py b/tianshou/highlevel/module/actor.py index b1d4792..2c7b4d7 100644 --- a/tianshou/highlevel/module/actor.py +++ b/tianshou/highlevel/module/actor.py @@ -20,7 +20,7 @@ from tianshou.highlevel.module.intermediate import ( from tianshou.highlevel.module.module_opt import ModuleOpt from tianshou.highlevel.optim import OptimizerFactory from tianshou.utils.net import continuous, discrete -from tianshou.utils.net.common import BaseActor, Net +from tianshou.utils.net.common import BaseActor, ModuleType, Net from tianshou.utils.string import ToStringMixin @@ -92,6 +92,7 @@ class ActorFactoryDefault(ActorFactory): self, continuous_actor_type: ContinuousActorType, hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES, + hidden_activation: ModuleType = nn.ReLU, continuous_unbounded: bool = False, continuous_conditioned_sigma: bool = False, discrete_softmax: bool = True, @@ -100,6 +101,7 @@ class ActorFactoryDefault(ActorFactory): self.continuous_unbounded = continuous_unbounded self.continuous_conditioned_sigma = continuous_conditioned_sigma self.hidden_sizes = hidden_sizes + self.hidden_activation = hidden_activation self.discrete_softmax = discrete_softmax def create_module(self, envs: Environments, device: TDevice) -> BaseActor: @@ -110,11 +112,15 @@ class ActorFactoryDefault(ActorFactory): case ContinuousActorType.GAUSSIAN: factory = ActorFactoryContinuousGaussianNet( self.hidden_sizes, + activation=self.hidden_activation, unbounded=self.continuous_unbounded, conditioned_sigma=self.continuous_conditioned_sigma, ) case ContinuousActorType.DETERMINISTIC: - factory = ActorFactoryContinuousDeterministicNet(self.hidden_sizes) + factory = ActorFactoryContinuousDeterministicNet( + self.hidden_sizes, + activation=self.hidden_activation, + ) case ContinuousActorType.UNSUPPORTED: raise ValueError("Continuous action spaces are not supported by the algorithm") case _: @@ -135,13 +141,15 @@ class ActorFactoryContinuous(ActorFactory, ABC): class ActorFactoryContinuousDeterministicNet(ActorFactoryContinuous): - def __init__(self, hidden_sizes: Sequence[int]): + def __init__(self, hidden_sizes: Sequence[int], activation: ModuleType = nn.ReLU): self.hidden_sizes = hidden_sizes + self.activation = activation def create_module(self, envs: Environments, device: TDevice) -> BaseActor: net_a = Net( envs.get_observation_shape(), hidden_sizes=self.hidden_sizes, + activation=self.activation, device=device, ) return continuous.Actor( @@ -158,6 +166,7 @@ class ActorFactoryContinuousGaussianNet(ActorFactoryContinuous): hidden_sizes: Sequence[int], unbounded: bool = True, conditioned_sigma: bool = False, + activation: ModuleType = nn.ReLU, ): """:param hidden_sizes: the sequence of hidden dimensions to use in the network structure :param unbounded: whether to apply tanh activation on final logits @@ -167,12 +176,13 @@ class ActorFactoryContinuousGaussianNet(ActorFactoryContinuous): self.hidden_sizes = hidden_sizes self.unbounded = unbounded self.conditioned_sigma = conditioned_sigma + self.activation = activation def create_module(self, envs: Environments, device: TDevice) -> BaseActor: net_a = Net( envs.get_observation_shape(), hidden_sizes=self.hidden_sizes, - activation=nn.Tanh, + activation=self.activation, device=device, ) actor = continuous.ActorProb( @@ -192,14 +202,21 @@ class ActorFactoryContinuousGaussianNet(ActorFactoryContinuous): class ActorFactoryDiscreteNet(ActorFactory): - def __init__(self, hidden_sizes: Sequence[int], softmax_output: bool = True): + def __init__( + self, + hidden_sizes: Sequence[int], + softmax_output: bool = True, + activation: ModuleType = nn.ReLU, + ): self.hidden_sizes = hidden_sizes self.softmax_output = softmax_output + self.activation = activation def create_module(self, envs: Environments, device: TDevice) -> BaseActor: net_a = Net( envs.get_observation_shape(), hidden_sizes=self.hidden_sizes, + activation=self.activation, device=device, ) return discrete.Actor( diff --git a/tianshou/highlevel/module/critic.py b/tianshou/highlevel/module/critic.py index 1576bc2..6d3a7b1 100644 --- a/tianshou/highlevel/module/critic.py +++ b/tianshou/highlevel/module/critic.py @@ -10,7 +10,7 @@ from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal from tianshou.highlevel.module.module_opt import ModuleOpt from tianshou.highlevel.optim import OptimizerFactory from tianshou.utils.net import continuous, discrete -from tianshou.utils.net.common import BaseActor, EnsembleLinear, Net +from tianshou.utils.net.common import BaseActor, EnsembleLinear, ModuleType, Net from tianshou.utils.string import ToStringMixin @@ -68,8 +68,13 @@ class CriticFactoryDefault(CriticFactory): DEFAULT_HIDDEN_SIZES = (64, 64) - def __init__(self, hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES): + def __init__( + self, + hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES, + hidden_activation: ModuleType = nn.ReLU, + ): self.hidden_sizes = hidden_sizes + self.hidden_activation = hidden_activation def create_module( self, @@ -82,9 +87,15 @@ class CriticFactoryDefault(CriticFactory): env_type = envs.get_type() match env_type: case EnvType.CONTINUOUS: - factory = CriticFactoryContinuousNet(self.hidden_sizes) + factory = CriticFactoryContinuousNet( + self.hidden_sizes, + activation=self.hidden_activation, + ) case EnvType.DISCRETE: - factory = CriticFactoryDiscreteNet(self.hidden_sizes) + factory = CriticFactoryDiscreteNet( + self.hidden_sizes, + activation=self.hidden_activation, + ) case _: raise ValueError(f"{env_type} not supported") return factory.create_module( @@ -96,8 +107,9 @@ class CriticFactoryDefault(CriticFactory): class CriticFactoryContinuousNet(CriticFactory): - def __init__(self, hidden_sizes: Sequence[int]): + def __init__(self, hidden_sizes: Sequence[int], activation: ModuleType = nn.ReLU): self.hidden_sizes = hidden_sizes + self.activation = activation def create_module( self, @@ -112,7 +124,7 @@ class CriticFactoryContinuousNet(CriticFactory): action_shape=action_shape, hidden_sizes=self.hidden_sizes, concat=use_action, - activation=nn.Tanh, + activation=self.activation, device=device, ) critic = continuous.Critic(net_c, device=device).to(device) @@ -121,8 +133,9 @@ class CriticFactoryContinuousNet(CriticFactory): class CriticFactoryDiscreteNet(CriticFactory): - def __init__(self, hidden_sizes: Sequence[int]): + def __init__(self, hidden_sizes: Sequence[int], activation: ModuleType = nn.ReLU): self.hidden_sizes = hidden_sizes + self.activation = activation def create_module( self, @@ -137,7 +150,7 @@ class CriticFactoryDiscreteNet(CriticFactory): action_shape=action_shape, hidden_sizes=self.hidden_sizes, concat=use_action, - activation=nn.Tanh, + activation=self.activation, device=device, ) last_size = (