Allow to configure activation function in default networks

* Set ReLU as default in all actor and critic factories
* Configure non-default in applicable MuJoCo examples
This commit is contained in:
Dominik Jain 2023-10-18 13:57:36 +02:00
parent ed06ab7ff0
commit 41bd463a7b
9 changed files with 96 additions and 33 deletions

View File

@ -5,6 +5,7 @@ from collections.abc import Sequence
from typing import Literal from typing import Literal
from jsonargparse import CLI from jsonargparse import CLI
from torch import nn
from examples.mujoco.mujoco_env import MujocoEnvFactory from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.config import SamplingConfig
@ -74,8 +75,8 @@ def main(
), ),
) )
.with_optim_factory(OptimizerFactoryRMSprop(eps=1e-5, alpha=0.99)) .with_optim_factory(OptimizerFactoryRMSprop(eps=1e-5, alpha=0.99))
.with_actor_factory_default(hidden_sizes, continuous_unbounded=True) .with_actor_factory_default(hidden_sizes, nn.Tanh, continuous_unbounded=True)
.with_critic_factory_default(hidden_sizes) .with_critic_factory_default(hidden_sizes, nn.Tanh)
.build() .build()
) )
experiment.run(log_name) experiment.run(log_name)

View File

@ -4,6 +4,7 @@ import os
from collections.abc import Sequence from collections.abc import Sequence
from typing import Literal from typing import Literal
import torch
from jsonargparse import CLI from jsonargparse import CLI
from examples.mujoco.mujoco_env import MujocoEnvFactory from examples.mujoco.mujoco_env import MujocoEnvFactory
@ -76,8 +77,8 @@ def main(
dist_fn=DistributionFunctionFactoryIndependentGaussians(), dist_fn=DistributionFunctionFactoryIndependentGaussians(),
), ),
) )
.with_actor_factory_default(hidden_sizes, continuous_unbounded=True) .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
.with_critic_factory_default(hidden_sizes) .with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
.build() .build()
) )
experiment.run(log_name) experiment.run(log_name)

View File

@ -4,6 +4,7 @@ import os
from collections.abc import Sequence from collections.abc import Sequence
from typing import Literal from typing import Literal
import torch
from jsonargparse import CLI from jsonargparse import CLI
from examples.mujoco.mujoco_env import MujocoEnvFactory from examples.mujoco.mujoco_env import MujocoEnvFactory
@ -86,8 +87,8 @@ def main(
dist_fn=DistributionFunctionFactoryIndependentGaussians(), dist_fn=DistributionFunctionFactoryIndependentGaussians(),
), ),
) )
.with_actor_factory_default(hidden_sizes, continuous_unbounded=True) .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
.with_critic_factory_default(hidden_sizes) .with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
.build() .build()
) )
experiment.run(log_name) experiment.run(log_name)

View File

@ -4,6 +4,7 @@ import os
from collections.abc import Sequence from collections.abc import Sequence
from typing import Literal from typing import Literal
import torch
from jsonargparse import CLI from jsonargparse import CLI
from examples.mujoco.mujoco_env import MujocoEnvFactory from examples.mujoco.mujoco_env import MujocoEnvFactory
@ -64,7 +65,7 @@ def main(
else None, else None,
), ),
) )
.with_actor_factory_default(hidden_sizes, continuous_unbounded=True) .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
.build() .build()
) )
experiment.run(log_name) experiment.run(log_name)

View File

@ -3,6 +3,7 @@
import os import os
from collections.abc import Sequence from collections.abc import Sequence
import torch
from jsonargparse import CLI from jsonargparse import CLI
from examples.mujoco.mujoco_env import MujocoEnvFactory from examples.mujoco.mujoco_env import MujocoEnvFactory
@ -76,8 +77,8 @@ def main(
critic2_lr=critic_lr, critic2_lr=critic_lr,
), ),
) )
.with_actor_factory_default(hidden_sizes) .with_actor_factory_default(hidden_sizes, torch.nn.Tanh)
.with_common_critic_factory_default(hidden_sizes) .with_common_critic_factory_default(hidden_sizes, torch.nn.Tanh)
.build() .build()
) )
experiment.run(log_name) experiment.run(log_name)

View File

@ -4,6 +4,7 @@ import os
from collections.abc import Sequence from collections.abc import Sequence
from typing import Literal from typing import Literal
import torch
from jsonargparse import CLI from jsonargparse import CLI
from examples.mujoco.mujoco_env import MujocoEnvFactory from examples.mujoco.mujoco_env import MujocoEnvFactory
@ -80,8 +81,8 @@ def main(
dist_fn=DistributionFunctionFactoryIndependentGaussians(), dist_fn=DistributionFunctionFactoryIndependentGaussians(),
), ),
) )
.with_actor_factory_default(hidden_sizes, continuous_unbounded=True) .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
.with_critic_factory_default(hidden_sizes) .with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
.build() .build()
) )
experiment.run(log_name) experiment.run(log_name)

View File

@ -80,6 +80,7 @@ from tianshou.highlevel.world import World
from tianshou.policy import BasePolicy from tianshou.policy import BasePolicy
from tianshou.utils import LazyLogger, logging from tianshou.utils import LazyLogger, logging
from tianshou.utils.logging import datetime_tag from tianshou.utils.logging import datetime_tag
from tianshou.utils.net.common import ModuleType
from tianshou.utils.string import ToStringMixin from tianshou.utils.string import ToStringMixin
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -440,6 +441,7 @@ class _BuilderMixinActorFactory(ActorFutureProviderProtocol):
def _with_actor_factory_default( def _with_actor_factory_default(
self, self,
hidden_sizes: Sequence[int], hidden_sizes: Sequence[int],
hidden_activation: ModuleType = torch.nn.ReLU,
continuous_unbounded: bool = False, continuous_unbounded: bool = False,
continuous_conditioned_sigma: bool = False, continuous_conditioned_sigma: bool = False,
) -> Self: ) -> Self:
@ -452,6 +454,7 @@ class _BuilderMixinActorFactory(ActorFutureProviderProtocol):
self._actor_factory = ActorFactoryDefault( self._actor_factory = ActorFactoryDefault(
self._continuous_actor_type, self._continuous_actor_type,
hidden_sizes, hidden_sizes,
hidden_activation=hidden_activation,
continuous_unbounded=continuous_unbounded, continuous_unbounded=continuous_unbounded,
continuous_conditioned_sigma=continuous_conditioned_sigma, continuous_conditioned_sigma=continuous_conditioned_sigma,
) )
@ -481,6 +484,7 @@ class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
def with_actor_factory_default( def with_actor_factory_default(
self, self,
hidden_sizes: Sequence[int], hidden_sizes: Sequence[int],
hidden_activation: ModuleType = torch.nn.ReLU,
continuous_unbounded: bool = False, continuous_unbounded: bool = False,
continuous_conditioned_sigma: bool = False, continuous_conditioned_sigma: bool = False,
) -> Self: ) -> Self:
@ -488,6 +492,7 @@ class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
The default actor factory uses an MLP-style architecture. The default actor factory uses an MLP-style architecture.
:param hidden_sizes: dimensions of hidden layers used by the network :param hidden_sizes: dimensions of hidden layers used by the network
:param hidden_activation: the activation function to use for hidden layers
:param continuous_unbounded: whether, for continuous action spaces, to apply tanh activation on final logits :param continuous_unbounded: whether, for continuous action spaces, to apply tanh activation on final logits
:param continuous_conditioned_sigma: whether, for continuous action spaces, the standard deviation of continuous actions (sigma) :param continuous_conditioned_sigma: whether, for continuous action spaces, the standard deviation of continuous actions (sigma)
shall be computed from the input; if False, sigma is an independent parameter. shall be computed from the input; if False, sigma is an independent parameter.
@ -495,6 +500,7 @@ class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
""" """
return super()._with_actor_factory_default( return super()._with_actor_factory_default(
hidden_sizes, hidden_sizes,
hidden_activation=hidden_activation,
continuous_unbounded=continuous_unbounded, continuous_unbounded=continuous_unbounded,
continuous_conditioned_sigma=continuous_conditioned_sigma, continuous_conditioned_sigma=continuous_conditioned_sigma,
) )
@ -506,14 +512,19 @@ class _BuilderMixinActorFactory_ContinuousDeterministic(_BuilderMixinActorFactor
def __init__(self) -> None: def __init__(self) -> None:
super().__init__(ContinuousActorType.DETERMINISTIC) super().__init__(ContinuousActorType.DETERMINISTIC)
def with_actor_factory_default(self, hidden_sizes: Sequence[int]) -> Self: def with_actor_factory_default(
self,
hidden_sizes: Sequence[int],
hidden_activation: ModuleType = torch.nn.ReLU,
) -> Self:
"""Defines use of the default actor factory, allowing its parameters it to be customized. """Defines use of the default actor factory, allowing its parameters it to be customized.
The default actor factory uses an MLP-style architecture. The default actor factory uses an MLP-style architecture.
:param hidden_sizes: dimensions of hidden layers used by the network :param hidden_sizes: dimensions of hidden layers used by the network
:param hidden_activation: the activation function to use for hidden layers
:return: the builder :return: the builder
""" """
return super()._with_actor_factory_default(hidden_sizes) return super()._with_actor_factory_default(hidden_sizes, hidden_activation)
class _BuilderMixinCriticsFactory: class _BuilderMixinCriticsFactory:
@ -525,8 +536,16 @@ class _BuilderMixinCriticsFactory:
self._critic_factories[idx] = critic_factory self._critic_factories[idx] = critic_factory
return self return self
def _with_critic_factory_default(self, idx: int, hidden_sizes: Sequence[int]) -> Self: def _with_critic_factory_default(
self._critic_factories[idx] = CriticFactoryDefault(hidden_sizes) self,
idx: int,
hidden_sizes: Sequence[int],
hidden_activation: ModuleType = torch.nn.ReLU,
) -> Self:
self._critic_factories[idx] = CriticFactoryDefault(
hidden_sizes,
hidden_activation=hidden_activation,
)
return self return self
def _with_critic_factory_use_actor(self, idx: int) -> Self: def _with_critic_factory_use_actor(self, idx: int) -> Self:
@ -559,13 +578,15 @@ class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
def with_critic_factory_default( def with_critic_factory_default(
self, self,
hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES, hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
hidden_activation: ModuleType = torch.nn.ReLU,
) -> Self: ) -> Self:
"""Makes the critic use the default, MLP-style architecture with the given parameters. """Makes the critic use the default, MLP-style architecture with the given parameters.
:param hidden_sizes: the sequence of dimensions to use in hidden layers of the network :param hidden_sizes: the sequence of dimensions to use in hidden layers of the network
:param hidden_activation: the activation function to use for hidden layers
:return: the builder :return: the builder
""" """
self._with_critic_factory_default(0, hidden_sizes) self._with_critic_factory_default(0, hidden_sizes, hidden_activation)
return self return self
@ -595,14 +616,16 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
def with_common_critic_factory_default( def with_common_critic_factory_default(
self, self,
hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES, hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
hidden_activation: ModuleType = torch.nn.ReLU,
) -> Self: ) -> Self:
"""Makes both critics use the default, MLP-style architecture with the given parameters. """Makes both critics use the default, MLP-style architecture with the given parameters.
:param hidden_sizes: the sequence of dimensions to use in hidden layers of the network :param hidden_sizes: the sequence of dimensions to use in hidden layers of the network
:param hidden_activation: the activation function to use for hidden layers
:return: the builder :return: the builder
""" """
for i in range(len(self._critic_factories)): for i in range(len(self._critic_factories)):
self._with_critic_factory_default(i, hidden_sizes) self._with_critic_factory_default(i, hidden_sizes, hidden_activation)
return self return self
def with_common_critic_factory_use_actor(self) -> Self: def with_common_critic_factory_use_actor(self) -> Self:
@ -623,13 +646,15 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
def with_critic1_factory_default( def with_critic1_factory_default(
self, self,
hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES, hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
hidden_activation: ModuleType = torch.nn.ReLU,
) -> Self: ) -> Self:
"""Makes the first critic use the default, MLP-style architecture with the given parameters. """Makes the first critic use the default, MLP-style architecture with the given parameters.
:param hidden_sizes: the sequence of dimensions to use in hidden layers of the network :param hidden_sizes: the sequence of dimensions to use in hidden layers of the network
:param hidden_activation: the activation function to use for hidden layers
:return: the builder :return: the builder
""" """
self._with_critic_factory_default(0, hidden_sizes) self._with_critic_factory_default(0, hidden_sizes, hidden_activation)
return self return self
def with_critic1_factory_use_actor(self) -> Self: def with_critic1_factory_use_actor(self) -> Self:
@ -648,13 +673,15 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
def with_critic2_factory_default( def with_critic2_factory_default(
self, self,
hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES, hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
hidden_activation: ModuleType = torch.nn.ReLU,
) -> Self: ) -> Self:
"""Makes the second critic use the default, MLP-style architecture with the given parameters. """Makes the second critic use the default, MLP-style architecture with the given parameters.
:param hidden_sizes: the sequence of dimensions to use in hidden layers of the network :param hidden_sizes: the sequence of dimensions to use in hidden layers of the network
:param hidden_activation: the activation function to use for hidden layers
:return: the builder :return: the builder
""" """
self._with_critic_factory_default(0, hidden_sizes) self._with_critic_factory_default(0, hidden_sizes, hidden_activation)
return self return self
def with_critic2_factory_use_actor(self) -> Self: def with_critic2_factory_use_actor(self) -> Self:
@ -670,7 +697,7 @@ class _BuilderMixinCriticEnsembleFactory:
"""Specifies that the given factory shall be used for the critic ensemble. """Specifies that the given factory shall be used for the critic ensemble.
If unspecified, the default factory (:class:`CriticEnsembleFactoryDefault`) is used. If unspecified, the default factory (:class:`CriticEnsembleFactoryDefault`) is used.
:param critic_factory: the critic factory :param factory: the critic ensemble factory
:return: the builder :return: the builder
""" """
self.critic_ensemble_factory = factory self.critic_ensemble_factory = factory

View File

@ -20,7 +20,7 @@ from tianshou.highlevel.module.intermediate import (
from tianshou.highlevel.module.module_opt import ModuleOpt from tianshou.highlevel.module.module_opt import ModuleOpt
from tianshou.highlevel.optim import OptimizerFactory from tianshou.highlevel.optim import OptimizerFactory
from tianshou.utils.net import continuous, discrete from tianshou.utils.net import continuous, discrete
from tianshou.utils.net.common import BaseActor, Net from tianshou.utils.net.common import BaseActor, ModuleType, Net
from tianshou.utils.string import ToStringMixin from tianshou.utils.string import ToStringMixin
@ -92,6 +92,7 @@ class ActorFactoryDefault(ActorFactory):
self, self,
continuous_actor_type: ContinuousActorType, continuous_actor_type: ContinuousActorType,
hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES, hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES,
hidden_activation: ModuleType = nn.ReLU,
continuous_unbounded: bool = False, continuous_unbounded: bool = False,
continuous_conditioned_sigma: bool = False, continuous_conditioned_sigma: bool = False,
discrete_softmax: bool = True, discrete_softmax: bool = True,
@ -100,6 +101,7 @@ class ActorFactoryDefault(ActorFactory):
self.continuous_unbounded = continuous_unbounded self.continuous_unbounded = continuous_unbounded
self.continuous_conditioned_sigma = continuous_conditioned_sigma self.continuous_conditioned_sigma = continuous_conditioned_sigma
self.hidden_sizes = hidden_sizes self.hidden_sizes = hidden_sizes
self.hidden_activation = hidden_activation
self.discrete_softmax = discrete_softmax self.discrete_softmax = discrete_softmax
def create_module(self, envs: Environments, device: TDevice) -> BaseActor: def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
@ -110,11 +112,15 @@ class ActorFactoryDefault(ActorFactory):
case ContinuousActorType.GAUSSIAN: case ContinuousActorType.GAUSSIAN:
factory = ActorFactoryContinuousGaussianNet( factory = ActorFactoryContinuousGaussianNet(
self.hidden_sizes, self.hidden_sizes,
activation=self.hidden_activation,
unbounded=self.continuous_unbounded, unbounded=self.continuous_unbounded,
conditioned_sigma=self.continuous_conditioned_sigma, conditioned_sigma=self.continuous_conditioned_sigma,
) )
case ContinuousActorType.DETERMINISTIC: case ContinuousActorType.DETERMINISTIC:
factory = ActorFactoryContinuousDeterministicNet(self.hidden_sizes) factory = ActorFactoryContinuousDeterministicNet(
self.hidden_sizes,
activation=self.hidden_activation,
)
case ContinuousActorType.UNSUPPORTED: case ContinuousActorType.UNSUPPORTED:
raise ValueError("Continuous action spaces are not supported by the algorithm") raise ValueError("Continuous action spaces are not supported by the algorithm")
case _: case _:
@ -135,13 +141,15 @@ class ActorFactoryContinuous(ActorFactory, ABC):
class ActorFactoryContinuousDeterministicNet(ActorFactoryContinuous): class ActorFactoryContinuousDeterministicNet(ActorFactoryContinuous):
def __init__(self, hidden_sizes: Sequence[int]): def __init__(self, hidden_sizes: Sequence[int], activation: ModuleType = nn.ReLU):
self.hidden_sizes = hidden_sizes self.hidden_sizes = hidden_sizes
self.activation = activation
def create_module(self, envs: Environments, device: TDevice) -> BaseActor: def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
net_a = Net( net_a = Net(
envs.get_observation_shape(), envs.get_observation_shape(),
hidden_sizes=self.hidden_sizes, hidden_sizes=self.hidden_sizes,
activation=self.activation,
device=device, device=device,
) )
return continuous.Actor( return continuous.Actor(
@ -158,6 +166,7 @@ class ActorFactoryContinuousGaussianNet(ActorFactoryContinuous):
hidden_sizes: Sequence[int], hidden_sizes: Sequence[int],
unbounded: bool = True, unbounded: bool = True,
conditioned_sigma: bool = False, conditioned_sigma: bool = False,
activation: ModuleType = nn.ReLU,
): ):
""":param hidden_sizes: the sequence of hidden dimensions to use in the network structure """:param hidden_sizes: the sequence of hidden dimensions to use in the network structure
:param unbounded: whether to apply tanh activation on final logits :param unbounded: whether to apply tanh activation on final logits
@ -167,12 +176,13 @@ class ActorFactoryContinuousGaussianNet(ActorFactoryContinuous):
self.hidden_sizes = hidden_sizes self.hidden_sizes = hidden_sizes
self.unbounded = unbounded self.unbounded = unbounded
self.conditioned_sigma = conditioned_sigma self.conditioned_sigma = conditioned_sigma
self.activation = activation
def create_module(self, envs: Environments, device: TDevice) -> BaseActor: def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
net_a = Net( net_a = Net(
envs.get_observation_shape(), envs.get_observation_shape(),
hidden_sizes=self.hidden_sizes, hidden_sizes=self.hidden_sizes,
activation=nn.Tanh, activation=self.activation,
device=device, device=device,
) )
actor = continuous.ActorProb( actor = continuous.ActorProb(
@ -192,14 +202,21 @@ class ActorFactoryContinuousGaussianNet(ActorFactoryContinuous):
class ActorFactoryDiscreteNet(ActorFactory): class ActorFactoryDiscreteNet(ActorFactory):
def __init__(self, hidden_sizes: Sequence[int], softmax_output: bool = True): def __init__(
self,
hidden_sizes: Sequence[int],
softmax_output: bool = True,
activation: ModuleType = nn.ReLU,
):
self.hidden_sizes = hidden_sizes self.hidden_sizes = hidden_sizes
self.softmax_output = softmax_output self.softmax_output = softmax_output
self.activation = activation
def create_module(self, envs: Environments, device: TDevice) -> BaseActor: def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
net_a = Net( net_a = Net(
envs.get_observation_shape(), envs.get_observation_shape(),
hidden_sizes=self.hidden_sizes, hidden_sizes=self.hidden_sizes,
activation=self.activation,
device=device, device=device,
) )
return discrete.Actor( return discrete.Actor(

View File

@ -10,7 +10,7 @@ from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
from tianshou.highlevel.module.module_opt import ModuleOpt from tianshou.highlevel.module.module_opt import ModuleOpt
from tianshou.highlevel.optim import OptimizerFactory from tianshou.highlevel.optim import OptimizerFactory
from tianshou.utils.net import continuous, discrete from tianshou.utils.net import continuous, discrete
from tianshou.utils.net.common import BaseActor, EnsembleLinear, Net from tianshou.utils.net.common import BaseActor, EnsembleLinear, ModuleType, Net
from tianshou.utils.string import ToStringMixin from tianshou.utils.string import ToStringMixin
@ -68,8 +68,13 @@ class CriticFactoryDefault(CriticFactory):
DEFAULT_HIDDEN_SIZES = (64, 64) DEFAULT_HIDDEN_SIZES = (64, 64)
def __init__(self, hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES): def __init__(
self,
hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES,
hidden_activation: ModuleType = nn.ReLU,
):
self.hidden_sizes = hidden_sizes self.hidden_sizes = hidden_sizes
self.hidden_activation = hidden_activation
def create_module( def create_module(
self, self,
@ -82,9 +87,15 @@ class CriticFactoryDefault(CriticFactory):
env_type = envs.get_type() env_type = envs.get_type()
match env_type: match env_type:
case EnvType.CONTINUOUS: case EnvType.CONTINUOUS:
factory = CriticFactoryContinuousNet(self.hidden_sizes) factory = CriticFactoryContinuousNet(
self.hidden_sizes,
activation=self.hidden_activation,
)
case EnvType.DISCRETE: case EnvType.DISCRETE:
factory = CriticFactoryDiscreteNet(self.hidden_sizes) factory = CriticFactoryDiscreteNet(
self.hidden_sizes,
activation=self.hidden_activation,
)
case _: case _:
raise ValueError(f"{env_type} not supported") raise ValueError(f"{env_type} not supported")
return factory.create_module( return factory.create_module(
@ -96,8 +107,9 @@ class CriticFactoryDefault(CriticFactory):
class CriticFactoryContinuousNet(CriticFactory): class CriticFactoryContinuousNet(CriticFactory):
def __init__(self, hidden_sizes: Sequence[int]): def __init__(self, hidden_sizes: Sequence[int], activation: ModuleType = nn.ReLU):
self.hidden_sizes = hidden_sizes self.hidden_sizes = hidden_sizes
self.activation = activation
def create_module( def create_module(
self, self,
@ -112,7 +124,7 @@ class CriticFactoryContinuousNet(CriticFactory):
action_shape=action_shape, action_shape=action_shape,
hidden_sizes=self.hidden_sizes, hidden_sizes=self.hidden_sizes,
concat=use_action, concat=use_action,
activation=nn.Tanh, activation=self.activation,
device=device, device=device,
) )
critic = continuous.Critic(net_c, device=device).to(device) critic = continuous.Critic(net_c, device=device).to(device)
@ -121,8 +133,9 @@ class CriticFactoryContinuousNet(CriticFactory):
class CriticFactoryDiscreteNet(CriticFactory): class CriticFactoryDiscreteNet(CriticFactory):
def __init__(self, hidden_sizes: Sequence[int]): def __init__(self, hidden_sizes: Sequence[int], activation: ModuleType = nn.ReLU):
self.hidden_sizes = hidden_sizes self.hidden_sizes = hidden_sizes
self.activation = activation
def create_module( def create_module(
self, self,
@ -137,7 +150,7 @@ class CriticFactoryDiscreteNet(CriticFactory):
action_shape=action_shape, action_shape=action_shape,
hidden_sizes=self.hidden_sizes, hidden_sizes=self.hidden_sizes,
concat=use_action, concat=use_action,
activation=nn.Tanh, activation=self.activation,
device=device, device=device,
) )
last_size = ( last_size = (