Allow configuring the activation function in default networks

* Set ReLU as the default in all actor and critic factories
* Configure a non-default activation (tanh) in the applicable MuJoCo examples
Dominik Jain 2023-10-18 13:57:36 +02:00
parent ed06ab7ff0
commit 41bd463a7b
9 changed files with 96 additions and 33 deletions
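
The default MLP factories previously hard-coded their hidden activation (tanh for the Gaussian actor and critic networks); the new hidden_activation parameter exposes it, with torch.nn.ReLU as the new default. A minimal usage sketch against the high-level builder API (env_factory, sampling_config, logger_factory, hidden_sizes, and log_name are assumed to be set up as in the example scripts below):

from torch import nn

from tianshou.highlevel.experiment import PPOExperimentBuilder

experiment = (
    PPOExperimentBuilder(env_factory, sampling_config, logger_factory)
    # pass the activation class itself, not an instance; omitting it now means nn.ReLU
    .with_actor_factory_default(hidden_sizes, nn.Tanh, continuous_unbounded=True)
    .with_critic_factory_default(hidden_sizes, nn.Tanh)
    .build()
)
experiment.run(log_name)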

View File

@@ -5,6 +5,7 @@ from collections.abc import Sequence
 from typing import Literal
 
 from jsonargparse import CLI
+from torch import nn
 
 from examples.mujoco.mujoco_env import MujocoEnvFactory
 from tianshou.highlevel.config import SamplingConfig
@@ -74,8 +75,8 @@ def main(
             ),
         )
         .with_optim_factory(OptimizerFactoryRMSprop(eps=1e-5, alpha=0.99))
-        .with_actor_factory_default(hidden_sizes, continuous_unbounded=True)
-        .with_critic_factory_default(hidden_sizes)
+        .with_actor_factory_default(hidden_sizes, nn.Tanh, continuous_unbounded=True)
+        .with_critic_factory_default(hidden_sizes, nn.Tanh)
         .build()
     )
     experiment.run(log_name)

View File

@@ -4,6 +4,7 @@ import os
 from collections.abc import Sequence
 from typing import Literal
 
+import torch
 from jsonargparse import CLI
 
 from examples.mujoco.mujoco_env import MujocoEnvFactory
@@ -76,8 +77,8 @@ def main(
                 dist_fn=DistributionFunctionFactoryIndependentGaussians(),
             ),
         )
-        .with_actor_factory_default(hidden_sizes, continuous_unbounded=True)
-        .with_critic_factory_default(hidden_sizes)
+        .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
+        .with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
         .build()
     )
     experiment.run(log_name)

View File

@@ -4,6 +4,7 @@ import os
 from collections.abc import Sequence
 from typing import Literal
 
+import torch
 from jsonargparse import CLI
 
 from examples.mujoco.mujoco_env import MujocoEnvFactory
@@ -86,8 +87,8 @@ def main(
                 dist_fn=DistributionFunctionFactoryIndependentGaussians(),
             ),
         )
-        .with_actor_factory_default(hidden_sizes, continuous_unbounded=True)
-        .with_critic_factory_default(hidden_sizes)
+        .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
+        .with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
         .build()
     )
     experiment.run(log_name)

View File

@@ -4,6 +4,7 @@ import os
 from collections.abc import Sequence
 from typing import Literal
 
+import torch
 from jsonargparse import CLI
 
 from examples.mujoco.mujoco_env import MujocoEnvFactory
@@ -64,7 +65,7 @@ def main(
                 else None,
             ),
         )
-        .with_actor_factory_default(hidden_sizes, continuous_unbounded=True)
+        .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
         .build()
     )
     experiment.run(log_name)

View File

@@ -3,6 +3,7 @@
 import os
 from collections.abc import Sequence
 
+import torch
 from jsonargparse import CLI
 
 from examples.mujoco.mujoco_env import MujocoEnvFactory
@@ -76,8 +77,8 @@ def main(
                 critic2_lr=critic_lr,
             ),
         )
-        .with_actor_factory_default(hidden_sizes)
-        .with_common_critic_factory_default(hidden_sizes)
+        .with_actor_factory_default(hidden_sizes, torch.nn.Tanh)
+        .with_common_critic_factory_default(hidden_sizes, torch.nn.Tanh)
         .build()
     )
     experiment.run(log_name)

View File

@@ -4,6 +4,7 @@ import os
 from collections.abc import Sequence
 from typing import Literal
 
+import torch
 from jsonargparse import CLI
 
 from examples.mujoco.mujoco_env import MujocoEnvFactory
@@ -80,8 +81,8 @@ def main(
                 dist_fn=DistributionFunctionFactoryIndependentGaussians(),
             ),
         )
-        .with_actor_factory_default(hidden_sizes, continuous_unbounded=True)
-        .with_critic_factory_default(hidden_sizes)
+        .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
+        .with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
         .build()
     )
     experiment.run(log_name)

View File

@@ -80,6 +80,7 @@ from tianshou.highlevel.world import World
 from tianshou.policy import BasePolicy
 from tianshou.utils import LazyLogger, logging
 from tianshou.utils.logging import datetime_tag
+from tianshou.utils.net.common import ModuleType
 from tianshou.utils.string import ToStringMixin
 
 log = logging.getLogger(__name__)
@@ -440,6 +441,7 @@ class _BuilderMixinActorFactory(ActorFutureProviderProtocol):
     def _with_actor_factory_default(
         self,
         hidden_sizes: Sequence[int],
+        hidden_activation: ModuleType = torch.nn.ReLU,
         continuous_unbounded: bool = False,
         continuous_conditioned_sigma: bool = False,
     ) -> Self:
@@ -452,6 +454,7 @@ class _BuilderMixinActorFactory(ActorFutureProviderProtocol):
         self._actor_factory = ActorFactoryDefault(
             self._continuous_actor_type,
             hidden_sizes,
+            hidden_activation=hidden_activation,
             continuous_unbounded=continuous_unbounded,
             continuous_conditioned_sigma=continuous_conditioned_sigma,
         )
@@ -481,6 +484,7 @@ class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
     def with_actor_factory_default(
         self,
         hidden_sizes: Sequence[int],
+        hidden_activation: ModuleType = torch.nn.ReLU,
         continuous_unbounded: bool = False,
         continuous_conditioned_sigma: bool = False,
     ) -> Self:
@@ -488,6 +492,7 @@ class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
         The default actor factory uses an MLP-style architecture.
 
         :param hidden_sizes: dimensions of hidden layers used by the network
+        :param hidden_activation: the activation function to use for hidden layers
         :param continuous_unbounded: whether, for continuous action spaces, to apply tanh activation on final logits
         :param continuous_conditioned_sigma: whether, for continuous action spaces, the standard deviation of continuous actions (sigma)
             shall be computed from the input; if False, sigma is an independent parameter.
@@ -495,6 +500,7 @@ class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
         """
         return super()._with_actor_factory_default(
             hidden_sizes,
+            hidden_activation=hidden_activation,
             continuous_unbounded=continuous_unbounded,
             continuous_conditioned_sigma=continuous_conditioned_sigma,
         )
@@ -506,14 +512,19 @@ class _BuilderMixinActorFactory_ContinuousDeterministic(_BuilderMixinActorFactory):
     def __init__(self) -> None:
         super().__init__(ContinuousActorType.DETERMINISTIC)
 
-    def with_actor_factory_default(self, hidden_sizes: Sequence[int]) -> Self:
+    def with_actor_factory_default(
+        self,
+        hidden_sizes: Sequence[int],
+        hidden_activation: ModuleType = torch.nn.ReLU,
+    ) -> Self:
         """Defines use of the default actor factory, allowing its parameters to be customized.
 
         The default actor factory uses an MLP-style architecture.
 
         :param hidden_sizes: dimensions of hidden layers used by the network
+        :param hidden_activation: the activation function to use for hidden layers
         :return: the builder
         """
-        return super()._with_actor_factory_default(hidden_sizes)
+        return super()._with_actor_factory_default(hidden_sizes, hidden_activation)
 
 
 class _BuilderMixinCriticsFactory:
@@ -525,8 +536,16 @@ class _BuilderMixinCriticsFactory:
         self._critic_factories[idx] = critic_factory
         return self
 
-    def _with_critic_factory_default(self, idx: int, hidden_sizes: Sequence[int]) -> Self:
-        self._critic_factories[idx] = CriticFactoryDefault(hidden_sizes)
+    def _with_critic_factory_default(
+        self,
+        idx: int,
+        hidden_sizes: Sequence[int],
+        hidden_activation: ModuleType = torch.nn.ReLU,
+    ) -> Self:
+        self._critic_factories[idx] = CriticFactoryDefault(
+            hidden_sizes,
+            hidden_activation=hidden_activation,
+        )
         return self
 
     def _with_critic_factory_use_actor(self, idx: int) -> Self:
@@ -559,13 +578,15 @@ class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
     def with_critic_factory_default(
         self,
         hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
+        hidden_activation: ModuleType = torch.nn.ReLU,
     ) -> Self:
         """Makes the critic use the default, MLP-style architecture with the given parameters.
 
         :param hidden_sizes: the sequence of dimensions to use in hidden layers of the network
+        :param hidden_activation: the activation function to use for hidden layers
         :return: the builder
         """
-        self._with_critic_factory_default(0, hidden_sizes)
+        self._with_critic_factory_default(0, hidden_sizes, hidden_activation)
         return self
@@ -595,14 +616,16 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
     def with_common_critic_factory_default(
         self,
         hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
+        hidden_activation: ModuleType = torch.nn.ReLU,
     ) -> Self:
         """Makes both critics use the default, MLP-style architecture with the given parameters.
 
         :param hidden_sizes: the sequence of dimensions to use in hidden layers of the network
+        :param hidden_activation: the activation function to use for hidden layers
         :return: the builder
         """
         for i in range(len(self._critic_factories)):
-            self._with_critic_factory_default(i, hidden_sizes)
+            self._with_critic_factory_default(i, hidden_sizes, hidden_activation)
         return self
 
     def with_common_critic_factory_use_actor(self) -> Self:
@@ -623,13 +646,15 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
     def with_critic1_factory_default(
         self,
         hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
+        hidden_activation: ModuleType = torch.nn.ReLU,
     ) -> Self:
         """Makes the first critic use the default, MLP-style architecture with the given parameters.
 
         :param hidden_sizes: the sequence of dimensions to use in hidden layers of the network
+        :param hidden_activation: the activation function to use for hidden layers
         :return: the builder
         """
-        self._with_critic_factory_default(0, hidden_sizes)
+        self._with_critic_factory_default(0, hidden_sizes, hidden_activation)
         return self
 
     def with_critic1_factory_use_actor(self) -> Self:
@@ -648,13 +673,15 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
     def with_critic2_factory_default(
         self,
         hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
+        hidden_activation: ModuleType = torch.nn.ReLU,
     ) -> Self:
         """Makes the second critic use the default, MLP-style architecture with the given parameters.
 
         :param hidden_sizes: the sequence of dimensions to use in hidden layers of the network
+        :param hidden_activation: the activation function to use for hidden layers
        :return: the builder
         """
-        self._with_critic_factory_default(0, hidden_sizes)
+        self._with_critic_factory_default(1, hidden_sizes, hidden_activation)
         return self
 
     def with_critic2_factory_use_actor(self) -> Self:
@@ -670,7 +697,7 @@ class _BuilderMixinCriticEnsembleFactory:
         """Specifies that the given factory shall be used for the critic ensemble.
 
         If unspecified, the default factory (:class:`CriticEnsembleFactoryDefault`) is used.
-        :param critic_factory: the critic factory
+        :param factory: the critic ensemble factory
         :return: the builder
         """
         self.critic_ensemble_factory = factory
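
For algorithms with twin critics, the dual-critic mixin above offers both a common setter and per-critic setters. A sketch of the resulting API surface (SACExperimentBuilder is one builder that mixes this in; env_factory, sampling_config, and logger_factory are assumed as in the examples):

from torch import nn

from tianshou.highlevel.experiment import SACExperimentBuilder

builder = SACExperimentBuilder(env_factory, sampling_config, logger_factory)
# either one activation for both critics ...
builder.with_common_critic_factory_default((256, 256), nn.ReLU)
# ... or a separate configuration per critic
builder.with_critic1_factory_default((256, 256), nn.ReLU)
builder.with_critic2_factory_default((256, 256), nn.Tanh)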

View File

@@ -20,7 +20,7 @@ from tianshou.highlevel.module.intermediate import (
 from tianshou.highlevel.module.module_opt import ModuleOpt
 from tianshou.highlevel.optim import OptimizerFactory
 from tianshou.utils.net import continuous, discrete
-from tianshou.utils.net.common import BaseActor, Net
+from tianshou.utils.net.common import BaseActor, ModuleType, Net
 from tianshou.utils.string import ToStringMixin
@@ -92,6 +92,7 @@ class ActorFactoryDefault(ActorFactory):
         self,
         continuous_actor_type: ContinuousActorType,
         hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES,
+        hidden_activation: ModuleType = nn.ReLU,
         continuous_unbounded: bool = False,
         continuous_conditioned_sigma: bool = False,
         discrete_softmax: bool = True,
@@ -100,6 +101,7 @@ class ActorFactoryDefault(ActorFactory):
         self.continuous_unbounded = continuous_unbounded
         self.continuous_conditioned_sigma = continuous_conditioned_sigma
         self.hidden_sizes = hidden_sizes
+        self.hidden_activation = hidden_activation
         self.discrete_softmax = discrete_softmax
 
     def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
@@ -110,11 +112,15 @@ class ActorFactoryDefault(ActorFactory):
             case ContinuousActorType.GAUSSIAN:
                 factory = ActorFactoryContinuousGaussianNet(
                     self.hidden_sizes,
+                    activation=self.hidden_activation,
                     unbounded=self.continuous_unbounded,
                     conditioned_sigma=self.continuous_conditioned_sigma,
                 )
             case ContinuousActorType.DETERMINISTIC:
-                factory = ActorFactoryContinuousDeterministicNet(self.hidden_sizes)
+                factory = ActorFactoryContinuousDeterministicNet(
+                    self.hidden_sizes,
+                    activation=self.hidden_activation,
+                )
             case ContinuousActorType.UNSUPPORTED:
                 raise ValueError("Continuous action spaces are not supported by the algorithm")
             case _:
@@ -135,13 +141,15 @@ class ActorFactoryContinuous(ActorFactory, ABC):
 
 class ActorFactoryContinuousDeterministicNet(ActorFactoryContinuous):
-    def __init__(self, hidden_sizes: Sequence[int]):
+    def __init__(self, hidden_sizes: Sequence[int], activation: ModuleType = nn.ReLU):
         self.hidden_sizes = hidden_sizes
+        self.activation = activation
 
     def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
         net_a = Net(
             envs.get_observation_shape(),
             hidden_sizes=self.hidden_sizes,
+            activation=self.activation,
             device=device,
         )
         return continuous.Actor(
@@ -158,6 +166,7 @@ class ActorFactoryContinuousGaussianNet(ActorFactoryContinuous):
         hidden_sizes: Sequence[int],
         unbounded: bool = True,
         conditioned_sigma: bool = False,
+        activation: ModuleType = nn.ReLU,
     ):
         """:param hidden_sizes: the sequence of hidden dimensions to use in the network structure
         :param unbounded: whether to apply tanh activation on final logits
@@ -167,12 +176,13 @@ class ActorFactoryContinuousGaussianNet(ActorFactoryContinuous):
         self.hidden_sizes = hidden_sizes
         self.unbounded = unbounded
         self.conditioned_sigma = conditioned_sigma
+        self.activation = activation
 
     def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
         net_a = Net(
             envs.get_observation_shape(),
             hidden_sizes=self.hidden_sizes,
-            activation=nn.Tanh,
+            activation=self.activation,
             device=device,
         )
         actor = continuous.ActorProb(
@@ -192,14 +202,21 @@
 
 class ActorFactoryDiscreteNet(ActorFactory):
-    def __init__(self, hidden_sizes: Sequence[int], softmax_output: bool = True):
+    def __init__(
+        self,
+        hidden_sizes: Sequence[int],
+        softmax_output: bool = True,
+        activation: ModuleType = nn.ReLU,
+    ):
         self.hidden_sizes = hidden_sizes
         self.softmax_output = softmax_output
+        self.activation = activation
 
     def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
         net_a = Net(
             envs.get_observation_shape(),
             hidden_sizes=self.hidden_sizes,
+            activation=self.activation,
             device=device,
         )
         return discrete.Actor(
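
Each factory above ultimately forwards the activation class to Net, which instantiates it after every hidden linear layer. A standalone illustration with hypothetical shapes:

import torch

from tianshou.utils.net.common import Net

# a 17-dimensional observation space (hypothetical) and two tanh-activated hidden layers
net = Net(state_shape=(17,), hidden_sizes=(64, 64), activation=torch.nn.Tanh)
hidden, state = net(torch.zeros(1, 17))
print(hidden.shape)  # torch.Size([1, 64]); the actor/critic head is attached on top of this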

View File

@@ -10,7 +10,7 @@ from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
 from tianshou.highlevel.module.module_opt import ModuleOpt
 from tianshou.highlevel.optim import OptimizerFactory
 from tianshou.utils.net import continuous, discrete
-from tianshou.utils.net.common import BaseActor, EnsembleLinear, Net
+from tianshou.utils.net.common import BaseActor, EnsembleLinear, ModuleType, Net
 from tianshou.utils.string import ToStringMixin
@@ -68,8 +68,13 @@ class CriticFactoryDefault(CriticFactory):
     DEFAULT_HIDDEN_SIZES = (64, 64)
 
-    def __init__(self, hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES):
+    def __init__(
+        self,
+        hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES,
+        hidden_activation: ModuleType = nn.ReLU,
+    ):
         self.hidden_sizes = hidden_sizes
+        self.hidden_activation = hidden_activation
 
     def create_module(
         self,
@@ -82,9 +87,15 @@ class CriticFactoryDefault(CriticFactory):
         env_type = envs.get_type()
         match env_type:
             case EnvType.CONTINUOUS:
-                factory = CriticFactoryContinuousNet(self.hidden_sizes)
+                factory = CriticFactoryContinuousNet(
+                    self.hidden_sizes,
+                    activation=self.hidden_activation,
+                )
             case EnvType.DISCRETE:
-                factory = CriticFactoryDiscreteNet(self.hidden_sizes)
+                factory = CriticFactoryDiscreteNet(
+                    self.hidden_sizes,
+                    activation=self.hidden_activation,
+                )
             case _:
                 raise ValueError(f"{env_type} not supported")
         return factory.create_module(
@@ -96,8 +107,9 @@ class CriticFactoryDefault(CriticFactory):
 
 class CriticFactoryContinuousNet(CriticFactory):
-    def __init__(self, hidden_sizes: Sequence[int]):
+    def __init__(self, hidden_sizes: Sequence[int], activation: ModuleType = nn.ReLU):
         self.hidden_sizes = hidden_sizes
+        self.activation = activation
 
     def create_module(
         self,
@@ -112,7 +124,7 @@ class CriticFactoryContinuousNet(CriticFactory):
             action_shape=action_shape,
             hidden_sizes=self.hidden_sizes,
             concat=use_action,
-            activation=nn.Tanh,
+            activation=self.activation,
             device=device,
         )
         critic = continuous.Critic(net_c, device=device).to(device)
@@ -121,8 +133,9 @@
 
 class CriticFactoryDiscreteNet(CriticFactory):
-    def __init__(self, hidden_sizes: Sequence[int]):
+    def __init__(self, hidden_sizes: Sequence[int], activation: ModuleType = nn.ReLU):
         self.hidden_sizes = hidden_sizes
+        self.activation = activation
 
     def create_module(
         self,
@@ -137,7 +150,7 @@ class CriticFactoryDiscreteNet(CriticFactory):
             action_shape=action_shape,
             hidden_sizes=self.hidden_sizes,
             concat=use_action,
-            activation=nn.Tanh,
+            activation=self.activation,
             device=device,
         )
         last_size = (
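
The factories also remain usable on their own; a minimal sketch of constructing a tanh-activated critic factory (actually building the module additionally requires an Environments instance and a device, per create_module above):

from torch import nn

from tianshou.highlevel.module.critic import CriticFactoryDefault

critic_factory = CriticFactoryDefault(hidden_sizes=(256, 256), hidden_activation=nn.Tanh)
# critic_module = critic_factory.create_module(envs, device, use_action=True)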