Allow to configure activation function in default networks
* Set ReLU as default in all actor and critic factories * Configure non-default in applicable MuJoCo examples
This commit is contained in:
parent
ed06ab7ff0
commit
41bd463a7b
@ -5,6 +5,7 @@ from collections.abc import Sequence
|
|||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from jsonargparse import CLI
|
from jsonargparse import CLI
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
||||||
from tianshou.highlevel.config import SamplingConfig
|
from tianshou.highlevel.config import SamplingConfig
|
||||||
@ -74,8 +75,8 @@ def main(
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
.with_optim_factory(OptimizerFactoryRMSprop(eps=1e-5, alpha=0.99))
|
.with_optim_factory(OptimizerFactoryRMSprop(eps=1e-5, alpha=0.99))
|
||||||
.with_actor_factory_default(hidden_sizes, continuous_unbounded=True)
|
.with_actor_factory_default(hidden_sizes, nn.Tanh, continuous_unbounded=True)
|
||||||
.with_critic_factory_default(hidden_sizes)
|
.with_critic_factory_default(hidden_sizes, nn.Tanh)
|
||||||
.build()
|
.build()
|
||||||
)
|
)
|
||||||
experiment.run(log_name)
|
experiment.run(log_name)
|
||||||
|
@ -4,6 +4,7 @@ import os
|
|||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
|
import torch
|
||||||
from jsonargparse import CLI
|
from jsonargparse import CLI
|
||||||
|
|
||||||
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
||||||
@ -76,8 +77,8 @@ def main(
|
|||||||
dist_fn=DistributionFunctionFactoryIndependentGaussians(),
|
dist_fn=DistributionFunctionFactoryIndependentGaussians(),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
.with_actor_factory_default(hidden_sizes, continuous_unbounded=True)
|
.with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
|
||||||
.with_critic_factory_default(hidden_sizes)
|
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
|
||||||
.build()
|
.build()
|
||||||
)
|
)
|
||||||
experiment.run(log_name)
|
experiment.run(log_name)
|
||||||
|
@ -4,6 +4,7 @@ import os
|
|||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
|
import torch
|
||||||
from jsonargparse import CLI
|
from jsonargparse import CLI
|
||||||
|
|
||||||
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
||||||
@ -86,8 +87,8 @@ def main(
|
|||||||
dist_fn=DistributionFunctionFactoryIndependentGaussians(),
|
dist_fn=DistributionFunctionFactoryIndependentGaussians(),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
.with_actor_factory_default(hidden_sizes, continuous_unbounded=True)
|
.with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
|
||||||
.with_critic_factory_default(hidden_sizes)
|
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
|
||||||
.build()
|
.build()
|
||||||
)
|
)
|
||||||
experiment.run(log_name)
|
experiment.run(log_name)
|
||||||
|
@ -4,6 +4,7 @@ import os
|
|||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
|
import torch
|
||||||
from jsonargparse import CLI
|
from jsonargparse import CLI
|
||||||
|
|
||||||
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
||||||
@ -64,7 +65,7 @@ def main(
|
|||||||
else None,
|
else None,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
.with_actor_factory_default(hidden_sizes, continuous_unbounded=True)
|
.with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
|
||||||
.build()
|
.build()
|
||||||
)
|
)
|
||||||
experiment.run(log_name)
|
experiment.run(log_name)
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
import os
|
import os
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
import torch
|
||||||
from jsonargparse import CLI
|
from jsonargparse import CLI
|
||||||
|
|
||||||
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
||||||
@ -76,8 +77,8 @@ def main(
|
|||||||
critic2_lr=critic_lr,
|
critic2_lr=critic_lr,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
.with_actor_factory_default(hidden_sizes)
|
.with_actor_factory_default(hidden_sizes, torch.nn.Tanh)
|
||||||
.with_common_critic_factory_default(hidden_sizes)
|
.with_common_critic_factory_default(hidden_sizes, torch.nn.Tanh)
|
||||||
.build()
|
.build()
|
||||||
)
|
)
|
||||||
experiment.run(log_name)
|
experiment.run(log_name)
|
||||||
|
@ -4,6 +4,7 @@ import os
|
|||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
|
import torch
|
||||||
from jsonargparse import CLI
|
from jsonargparse import CLI
|
||||||
|
|
||||||
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
||||||
@ -80,8 +81,8 @@ def main(
|
|||||||
dist_fn=DistributionFunctionFactoryIndependentGaussians(),
|
dist_fn=DistributionFunctionFactoryIndependentGaussians(),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
.with_actor_factory_default(hidden_sizes, continuous_unbounded=True)
|
.with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
|
||||||
.with_critic_factory_default(hidden_sizes)
|
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
|
||||||
.build()
|
.build()
|
||||||
)
|
)
|
||||||
experiment.run(log_name)
|
experiment.run(log_name)
|
||||||
|
@ -80,6 +80,7 @@ from tianshou.highlevel.world import World
|
|||||||
from tianshou.policy import BasePolicy
|
from tianshou.policy import BasePolicy
|
||||||
from tianshou.utils import LazyLogger, logging
|
from tianshou.utils import LazyLogger, logging
|
||||||
from tianshou.utils.logging import datetime_tag
|
from tianshou.utils.logging import datetime_tag
|
||||||
|
from tianshou.utils.net.common import ModuleType
|
||||||
from tianshou.utils.string import ToStringMixin
|
from tianshou.utils.string import ToStringMixin
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
@ -440,6 +441,7 @@ class _BuilderMixinActorFactory(ActorFutureProviderProtocol):
|
|||||||
def _with_actor_factory_default(
|
def _with_actor_factory_default(
|
||||||
self,
|
self,
|
||||||
hidden_sizes: Sequence[int],
|
hidden_sizes: Sequence[int],
|
||||||
|
hidden_activation: ModuleType = torch.nn.ReLU,
|
||||||
continuous_unbounded: bool = False,
|
continuous_unbounded: bool = False,
|
||||||
continuous_conditioned_sigma: bool = False,
|
continuous_conditioned_sigma: bool = False,
|
||||||
) -> Self:
|
) -> Self:
|
||||||
@ -452,6 +454,7 @@ class _BuilderMixinActorFactory(ActorFutureProviderProtocol):
|
|||||||
self._actor_factory = ActorFactoryDefault(
|
self._actor_factory = ActorFactoryDefault(
|
||||||
self._continuous_actor_type,
|
self._continuous_actor_type,
|
||||||
hidden_sizes,
|
hidden_sizes,
|
||||||
|
hidden_activation=hidden_activation,
|
||||||
continuous_unbounded=continuous_unbounded,
|
continuous_unbounded=continuous_unbounded,
|
||||||
continuous_conditioned_sigma=continuous_conditioned_sigma,
|
continuous_conditioned_sigma=continuous_conditioned_sigma,
|
||||||
)
|
)
|
||||||
@ -481,6 +484,7 @@ class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
|
|||||||
def with_actor_factory_default(
|
def with_actor_factory_default(
|
||||||
self,
|
self,
|
||||||
hidden_sizes: Sequence[int],
|
hidden_sizes: Sequence[int],
|
||||||
|
hidden_activation: ModuleType = torch.nn.ReLU,
|
||||||
continuous_unbounded: bool = False,
|
continuous_unbounded: bool = False,
|
||||||
continuous_conditioned_sigma: bool = False,
|
continuous_conditioned_sigma: bool = False,
|
||||||
) -> Self:
|
) -> Self:
|
||||||
@ -488,6 +492,7 @@ class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
|
|||||||
The default actor factory uses an MLP-style architecture.
|
The default actor factory uses an MLP-style architecture.
|
||||||
|
|
||||||
:param hidden_sizes: dimensions of hidden layers used by the network
|
:param hidden_sizes: dimensions of hidden layers used by the network
|
||||||
|
:param hidden_activation: the activation function to use for hidden layers
|
||||||
:param continuous_unbounded: whether, for continuous action spaces, to apply tanh activation on final logits
|
:param continuous_unbounded: whether, for continuous action spaces, to apply tanh activation on final logits
|
||||||
:param continuous_conditioned_sigma: whether, for continuous action spaces, the standard deviation of continuous actions (sigma)
|
:param continuous_conditioned_sigma: whether, for continuous action spaces, the standard deviation of continuous actions (sigma)
|
||||||
shall be computed from the input; if False, sigma is an independent parameter.
|
shall be computed from the input; if False, sigma is an independent parameter.
|
||||||
@ -495,6 +500,7 @@ class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
|
|||||||
"""
|
"""
|
||||||
return super()._with_actor_factory_default(
|
return super()._with_actor_factory_default(
|
||||||
hidden_sizes,
|
hidden_sizes,
|
||||||
|
hidden_activation=hidden_activation,
|
||||||
continuous_unbounded=continuous_unbounded,
|
continuous_unbounded=continuous_unbounded,
|
||||||
continuous_conditioned_sigma=continuous_conditioned_sigma,
|
continuous_conditioned_sigma=continuous_conditioned_sigma,
|
||||||
)
|
)
|
||||||
@ -506,14 +512,19 @@ class _BuilderMixinActorFactory_ContinuousDeterministic(_BuilderMixinActorFactor
|
|||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__(ContinuousActorType.DETERMINISTIC)
|
super().__init__(ContinuousActorType.DETERMINISTIC)
|
||||||
|
|
||||||
def with_actor_factory_default(self, hidden_sizes: Sequence[int]) -> Self:
|
def with_actor_factory_default(
|
||||||
|
self,
|
||||||
|
hidden_sizes: Sequence[int],
|
||||||
|
hidden_activation: ModuleType = torch.nn.ReLU,
|
||||||
|
) -> Self:
|
||||||
"""Defines use of the default actor factory, allowing its parameters it to be customized.
|
"""Defines use of the default actor factory, allowing its parameters it to be customized.
|
||||||
The default actor factory uses an MLP-style architecture.
|
The default actor factory uses an MLP-style architecture.
|
||||||
|
|
||||||
:param hidden_sizes: dimensions of hidden layers used by the network
|
:param hidden_sizes: dimensions of hidden layers used by the network
|
||||||
|
:param hidden_activation: the activation function to use for hidden layers
|
||||||
:return: the builder
|
:return: the builder
|
||||||
"""
|
"""
|
||||||
return super()._with_actor_factory_default(hidden_sizes)
|
return super()._with_actor_factory_default(hidden_sizes, hidden_activation)
|
||||||
|
|
||||||
|
|
||||||
class _BuilderMixinCriticsFactory:
|
class _BuilderMixinCriticsFactory:
|
||||||
@ -525,8 +536,16 @@ class _BuilderMixinCriticsFactory:
|
|||||||
self._critic_factories[idx] = critic_factory
|
self._critic_factories[idx] = critic_factory
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def _with_critic_factory_default(self, idx: int, hidden_sizes: Sequence[int]) -> Self:
|
def _with_critic_factory_default(
|
||||||
self._critic_factories[idx] = CriticFactoryDefault(hidden_sizes)
|
self,
|
||||||
|
idx: int,
|
||||||
|
hidden_sizes: Sequence[int],
|
||||||
|
hidden_activation: ModuleType = torch.nn.ReLU,
|
||||||
|
) -> Self:
|
||||||
|
self._critic_factories[idx] = CriticFactoryDefault(
|
||||||
|
hidden_sizes,
|
||||||
|
hidden_activation=hidden_activation,
|
||||||
|
)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def _with_critic_factory_use_actor(self, idx: int) -> Self:
|
def _with_critic_factory_use_actor(self, idx: int) -> Self:
|
||||||
@ -559,13 +578,15 @@ class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
|
|||||||
def with_critic_factory_default(
|
def with_critic_factory_default(
|
||||||
self,
|
self,
|
||||||
hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
|
hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
|
||||||
|
hidden_activation: ModuleType = torch.nn.ReLU,
|
||||||
) -> Self:
|
) -> Self:
|
||||||
"""Makes the critic use the default, MLP-style architecture with the given parameters.
|
"""Makes the critic use the default, MLP-style architecture with the given parameters.
|
||||||
|
|
||||||
:param hidden_sizes: the sequence of dimensions to use in hidden layers of the network
|
:param hidden_sizes: the sequence of dimensions to use in hidden layers of the network
|
||||||
|
:param hidden_activation: the activation function to use for hidden layers
|
||||||
:return: the builder
|
:return: the builder
|
||||||
"""
|
"""
|
||||||
self._with_critic_factory_default(0, hidden_sizes)
|
self._with_critic_factory_default(0, hidden_sizes, hidden_activation)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
@ -595,14 +616,16 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
|
|||||||
def with_common_critic_factory_default(
|
def with_common_critic_factory_default(
|
||||||
self,
|
self,
|
||||||
hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
|
hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
|
||||||
|
hidden_activation: ModuleType = torch.nn.ReLU,
|
||||||
) -> Self:
|
) -> Self:
|
||||||
"""Makes both critics use the default, MLP-style architecture with the given parameters.
|
"""Makes both critics use the default, MLP-style architecture with the given parameters.
|
||||||
|
|
||||||
:param hidden_sizes: the sequence of dimensions to use in hidden layers of the network
|
:param hidden_sizes: the sequence of dimensions to use in hidden layers of the network
|
||||||
|
:param hidden_activation: the activation function to use for hidden layers
|
||||||
:return: the builder
|
:return: the builder
|
||||||
"""
|
"""
|
||||||
for i in range(len(self._critic_factories)):
|
for i in range(len(self._critic_factories)):
|
||||||
self._with_critic_factory_default(i, hidden_sizes)
|
self._with_critic_factory_default(i, hidden_sizes, hidden_activation)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_common_critic_factory_use_actor(self) -> Self:
|
def with_common_critic_factory_use_actor(self) -> Self:
|
||||||
@ -623,13 +646,15 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
|
|||||||
def with_critic1_factory_default(
|
def with_critic1_factory_default(
|
||||||
self,
|
self,
|
||||||
hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
|
hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
|
||||||
|
hidden_activation: ModuleType = torch.nn.ReLU,
|
||||||
) -> Self:
|
) -> Self:
|
||||||
"""Makes the first critic use the default, MLP-style architecture with the given parameters.
|
"""Makes the first critic use the default, MLP-style architecture with the given parameters.
|
||||||
|
|
||||||
:param hidden_sizes: the sequence of dimensions to use in hidden layers of the network
|
:param hidden_sizes: the sequence of dimensions to use in hidden layers of the network
|
||||||
|
:param hidden_activation: the activation function to use for hidden layers
|
||||||
:return: the builder
|
:return: the builder
|
||||||
"""
|
"""
|
||||||
self._with_critic_factory_default(0, hidden_sizes)
|
self._with_critic_factory_default(0, hidden_sizes, hidden_activation)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_critic1_factory_use_actor(self) -> Self:
|
def with_critic1_factory_use_actor(self) -> Self:
|
||||||
@ -648,13 +673,15 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
|
|||||||
def with_critic2_factory_default(
|
def with_critic2_factory_default(
|
||||||
self,
|
self,
|
||||||
hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
|
hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
|
||||||
|
hidden_activation: ModuleType = torch.nn.ReLU,
|
||||||
) -> Self:
|
) -> Self:
|
||||||
"""Makes the second critic use the default, MLP-style architecture with the given parameters.
|
"""Makes the second critic use the default, MLP-style architecture with the given parameters.
|
||||||
|
|
||||||
:param hidden_sizes: the sequence of dimensions to use in hidden layers of the network
|
:param hidden_sizes: the sequence of dimensions to use in hidden layers of the network
|
||||||
|
:param hidden_activation: the activation function to use for hidden layers
|
||||||
:return: the builder
|
:return: the builder
|
||||||
"""
|
"""
|
||||||
self._with_critic_factory_default(0, hidden_sizes)
|
self._with_critic_factory_default(0, hidden_sizes, hidden_activation)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_critic2_factory_use_actor(self) -> Self:
|
def with_critic2_factory_use_actor(self) -> Self:
|
||||||
@ -670,7 +697,7 @@ class _BuilderMixinCriticEnsembleFactory:
|
|||||||
"""Specifies that the given factory shall be used for the critic ensemble.
|
"""Specifies that the given factory shall be used for the critic ensemble.
|
||||||
If unspecified, the default factory (:class:`CriticEnsembleFactoryDefault`) is used.
|
If unspecified, the default factory (:class:`CriticEnsembleFactoryDefault`) is used.
|
||||||
|
|
||||||
:param critic_factory: the critic factory
|
:param factory: the critic ensemble factory
|
||||||
:return: the builder
|
:return: the builder
|
||||||
"""
|
"""
|
||||||
self.critic_ensemble_factory = factory
|
self.critic_ensemble_factory = factory
|
||||||
|
@ -20,7 +20,7 @@ from tianshou.highlevel.module.intermediate import (
|
|||||||
from tianshou.highlevel.module.module_opt import ModuleOpt
|
from tianshou.highlevel.module.module_opt import ModuleOpt
|
||||||
from tianshou.highlevel.optim import OptimizerFactory
|
from tianshou.highlevel.optim import OptimizerFactory
|
||||||
from tianshou.utils.net import continuous, discrete
|
from tianshou.utils.net import continuous, discrete
|
||||||
from tianshou.utils.net.common import BaseActor, Net
|
from tianshou.utils.net.common import BaseActor, ModuleType, Net
|
||||||
from tianshou.utils.string import ToStringMixin
|
from tianshou.utils.string import ToStringMixin
|
||||||
|
|
||||||
|
|
||||||
@ -92,6 +92,7 @@ class ActorFactoryDefault(ActorFactory):
|
|||||||
self,
|
self,
|
||||||
continuous_actor_type: ContinuousActorType,
|
continuous_actor_type: ContinuousActorType,
|
||||||
hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES,
|
hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES,
|
||||||
|
hidden_activation: ModuleType = nn.ReLU,
|
||||||
continuous_unbounded: bool = False,
|
continuous_unbounded: bool = False,
|
||||||
continuous_conditioned_sigma: bool = False,
|
continuous_conditioned_sigma: bool = False,
|
||||||
discrete_softmax: bool = True,
|
discrete_softmax: bool = True,
|
||||||
@ -100,6 +101,7 @@ class ActorFactoryDefault(ActorFactory):
|
|||||||
self.continuous_unbounded = continuous_unbounded
|
self.continuous_unbounded = continuous_unbounded
|
||||||
self.continuous_conditioned_sigma = continuous_conditioned_sigma
|
self.continuous_conditioned_sigma = continuous_conditioned_sigma
|
||||||
self.hidden_sizes = hidden_sizes
|
self.hidden_sizes = hidden_sizes
|
||||||
|
self.hidden_activation = hidden_activation
|
||||||
self.discrete_softmax = discrete_softmax
|
self.discrete_softmax = discrete_softmax
|
||||||
|
|
||||||
def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
|
def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
|
||||||
@ -110,11 +112,15 @@ class ActorFactoryDefault(ActorFactory):
|
|||||||
case ContinuousActorType.GAUSSIAN:
|
case ContinuousActorType.GAUSSIAN:
|
||||||
factory = ActorFactoryContinuousGaussianNet(
|
factory = ActorFactoryContinuousGaussianNet(
|
||||||
self.hidden_sizes,
|
self.hidden_sizes,
|
||||||
|
activation=self.hidden_activation,
|
||||||
unbounded=self.continuous_unbounded,
|
unbounded=self.continuous_unbounded,
|
||||||
conditioned_sigma=self.continuous_conditioned_sigma,
|
conditioned_sigma=self.continuous_conditioned_sigma,
|
||||||
)
|
)
|
||||||
case ContinuousActorType.DETERMINISTIC:
|
case ContinuousActorType.DETERMINISTIC:
|
||||||
factory = ActorFactoryContinuousDeterministicNet(self.hidden_sizes)
|
factory = ActorFactoryContinuousDeterministicNet(
|
||||||
|
self.hidden_sizes,
|
||||||
|
activation=self.hidden_activation,
|
||||||
|
)
|
||||||
case ContinuousActorType.UNSUPPORTED:
|
case ContinuousActorType.UNSUPPORTED:
|
||||||
raise ValueError("Continuous action spaces are not supported by the algorithm")
|
raise ValueError("Continuous action spaces are not supported by the algorithm")
|
||||||
case _:
|
case _:
|
||||||
@ -135,13 +141,15 @@ class ActorFactoryContinuous(ActorFactory, ABC):
|
|||||||
|
|
||||||
|
|
||||||
class ActorFactoryContinuousDeterministicNet(ActorFactoryContinuous):
|
class ActorFactoryContinuousDeterministicNet(ActorFactoryContinuous):
|
||||||
def __init__(self, hidden_sizes: Sequence[int]):
|
def __init__(self, hidden_sizes: Sequence[int], activation: ModuleType = nn.ReLU):
|
||||||
self.hidden_sizes = hidden_sizes
|
self.hidden_sizes = hidden_sizes
|
||||||
|
self.activation = activation
|
||||||
|
|
||||||
def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
|
def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
|
||||||
net_a = Net(
|
net_a = Net(
|
||||||
envs.get_observation_shape(),
|
envs.get_observation_shape(),
|
||||||
hidden_sizes=self.hidden_sizes,
|
hidden_sizes=self.hidden_sizes,
|
||||||
|
activation=self.activation,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
return continuous.Actor(
|
return continuous.Actor(
|
||||||
@ -158,6 +166,7 @@ class ActorFactoryContinuousGaussianNet(ActorFactoryContinuous):
|
|||||||
hidden_sizes: Sequence[int],
|
hidden_sizes: Sequence[int],
|
||||||
unbounded: bool = True,
|
unbounded: bool = True,
|
||||||
conditioned_sigma: bool = False,
|
conditioned_sigma: bool = False,
|
||||||
|
activation: ModuleType = nn.ReLU,
|
||||||
):
|
):
|
||||||
""":param hidden_sizes: the sequence of hidden dimensions to use in the network structure
|
""":param hidden_sizes: the sequence of hidden dimensions to use in the network structure
|
||||||
:param unbounded: whether to apply tanh activation on final logits
|
:param unbounded: whether to apply tanh activation on final logits
|
||||||
@ -167,12 +176,13 @@ class ActorFactoryContinuousGaussianNet(ActorFactoryContinuous):
|
|||||||
self.hidden_sizes = hidden_sizes
|
self.hidden_sizes = hidden_sizes
|
||||||
self.unbounded = unbounded
|
self.unbounded = unbounded
|
||||||
self.conditioned_sigma = conditioned_sigma
|
self.conditioned_sigma = conditioned_sigma
|
||||||
|
self.activation = activation
|
||||||
|
|
||||||
def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
|
def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
|
||||||
net_a = Net(
|
net_a = Net(
|
||||||
envs.get_observation_shape(),
|
envs.get_observation_shape(),
|
||||||
hidden_sizes=self.hidden_sizes,
|
hidden_sizes=self.hidden_sizes,
|
||||||
activation=nn.Tanh,
|
activation=self.activation,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
actor = continuous.ActorProb(
|
actor = continuous.ActorProb(
|
||||||
@ -192,14 +202,21 @@ class ActorFactoryContinuousGaussianNet(ActorFactoryContinuous):
|
|||||||
|
|
||||||
|
|
||||||
class ActorFactoryDiscreteNet(ActorFactory):
|
class ActorFactoryDiscreteNet(ActorFactory):
|
||||||
def __init__(self, hidden_sizes: Sequence[int], softmax_output: bool = True):
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_sizes: Sequence[int],
|
||||||
|
softmax_output: bool = True,
|
||||||
|
activation: ModuleType = nn.ReLU,
|
||||||
|
):
|
||||||
self.hidden_sizes = hidden_sizes
|
self.hidden_sizes = hidden_sizes
|
||||||
self.softmax_output = softmax_output
|
self.softmax_output = softmax_output
|
||||||
|
self.activation = activation
|
||||||
|
|
||||||
def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
|
def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
|
||||||
net_a = Net(
|
net_a = Net(
|
||||||
envs.get_observation_shape(),
|
envs.get_observation_shape(),
|
||||||
hidden_sizes=self.hidden_sizes,
|
hidden_sizes=self.hidden_sizes,
|
||||||
|
activation=self.activation,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
return discrete.Actor(
|
return discrete.Actor(
|
||||||
|
@ -10,7 +10,7 @@ from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
|
|||||||
from tianshou.highlevel.module.module_opt import ModuleOpt
|
from tianshou.highlevel.module.module_opt import ModuleOpt
|
||||||
from tianshou.highlevel.optim import OptimizerFactory
|
from tianshou.highlevel.optim import OptimizerFactory
|
||||||
from tianshou.utils.net import continuous, discrete
|
from tianshou.utils.net import continuous, discrete
|
||||||
from tianshou.utils.net.common import BaseActor, EnsembleLinear, Net
|
from tianshou.utils.net.common import BaseActor, EnsembleLinear, ModuleType, Net
|
||||||
from tianshou.utils.string import ToStringMixin
|
from tianshou.utils.string import ToStringMixin
|
||||||
|
|
||||||
|
|
||||||
@ -68,8 +68,13 @@ class CriticFactoryDefault(CriticFactory):
|
|||||||
|
|
||||||
DEFAULT_HIDDEN_SIZES = (64, 64)
|
DEFAULT_HIDDEN_SIZES = (64, 64)
|
||||||
|
|
||||||
def __init__(self, hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES):
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES,
|
||||||
|
hidden_activation: ModuleType = nn.ReLU,
|
||||||
|
):
|
||||||
self.hidden_sizes = hidden_sizes
|
self.hidden_sizes = hidden_sizes
|
||||||
|
self.hidden_activation = hidden_activation
|
||||||
|
|
||||||
def create_module(
|
def create_module(
|
||||||
self,
|
self,
|
||||||
@ -82,9 +87,15 @@ class CriticFactoryDefault(CriticFactory):
|
|||||||
env_type = envs.get_type()
|
env_type = envs.get_type()
|
||||||
match env_type:
|
match env_type:
|
||||||
case EnvType.CONTINUOUS:
|
case EnvType.CONTINUOUS:
|
||||||
factory = CriticFactoryContinuousNet(self.hidden_sizes)
|
factory = CriticFactoryContinuousNet(
|
||||||
|
self.hidden_sizes,
|
||||||
|
activation=self.hidden_activation,
|
||||||
|
)
|
||||||
case EnvType.DISCRETE:
|
case EnvType.DISCRETE:
|
||||||
factory = CriticFactoryDiscreteNet(self.hidden_sizes)
|
factory = CriticFactoryDiscreteNet(
|
||||||
|
self.hidden_sizes,
|
||||||
|
activation=self.hidden_activation,
|
||||||
|
)
|
||||||
case _:
|
case _:
|
||||||
raise ValueError(f"{env_type} not supported")
|
raise ValueError(f"{env_type} not supported")
|
||||||
return factory.create_module(
|
return factory.create_module(
|
||||||
@ -96,8 +107,9 @@ class CriticFactoryDefault(CriticFactory):
|
|||||||
|
|
||||||
|
|
||||||
class CriticFactoryContinuousNet(CriticFactory):
|
class CriticFactoryContinuousNet(CriticFactory):
|
||||||
def __init__(self, hidden_sizes: Sequence[int]):
|
def __init__(self, hidden_sizes: Sequence[int], activation: ModuleType = nn.ReLU):
|
||||||
self.hidden_sizes = hidden_sizes
|
self.hidden_sizes = hidden_sizes
|
||||||
|
self.activation = activation
|
||||||
|
|
||||||
def create_module(
|
def create_module(
|
||||||
self,
|
self,
|
||||||
@ -112,7 +124,7 @@ class CriticFactoryContinuousNet(CriticFactory):
|
|||||||
action_shape=action_shape,
|
action_shape=action_shape,
|
||||||
hidden_sizes=self.hidden_sizes,
|
hidden_sizes=self.hidden_sizes,
|
||||||
concat=use_action,
|
concat=use_action,
|
||||||
activation=nn.Tanh,
|
activation=self.activation,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
critic = continuous.Critic(net_c, device=device).to(device)
|
critic = continuous.Critic(net_c, device=device).to(device)
|
||||||
@ -121,8 +133,9 @@ class CriticFactoryContinuousNet(CriticFactory):
|
|||||||
|
|
||||||
|
|
||||||
class CriticFactoryDiscreteNet(CriticFactory):
|
class CriticFactoryDiscreteNet(CriticFactory):
|
||||||
def __init__(self, hidden_sizes: Sequence[int]):
|
def __init__(self, hidden_sizes: Sequence[int], activation: ModuleType = nn.ReLU):
|
||||||
self.hidden_sizes = hidden_sizes
|
self.hidden_sizes = hidden_sizes
|
||||||
|
self.activation = activation
|
||||||
|
|
||||||
def create_module(
|
def create_module(
|
||||||
self,
|
self,
|
||||||
@ -137,7 +150,7 @@ class CriticFactoryDiscreteNet(CriticFactory):
|
|||||||
action_shape=action_shape,
|
action_shape=action_shape,
|
||||||
hidden_sizes=self.hidden_sizes,
|
hidden_sizes=self.hidden_sizes,
|
||||||
concat=use_action,
|
concat=use_action,
|
||||||
activation=nn.Tanh,
|
activation=self.activation,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
last_size = (
|
last_size = (
|
||||||
|
Loading…
x
Reference in New Issue
Block a user