Support REDQ in high-level API
* Implement example mujoco_redq_hl
* Add abstraction CriticEnsembleFactory with default implementations to suit REDQ
* Fix type annotation of linear_layer in Net, MLP, Critic (was incompatible with REDQ usage)
This commit is contained in:
parent
7af836bd6a
commit
17ef4dd5eb
@ -38,7 +38,7 @@ def main(
|
|||||||
test_num: int = 10,
|
test_num: int = 10,
|
||||||
):
|
):
|
||||||
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
||||||
log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
|
log_name = os.path.join(task, "ddpg", str(experiment_config.seed), now)
|
||||||
|
|
||||||
sampling_config = SamplingConfig(
|
sampling_config = SamplingConfig(
|
||||||
num_epochs=epoch,
|
num_epochs=epoch,
|
||||||
|
87
examples/mujoco/mujoco_redq_hl.py
Normal file
87
examples/mujoco/mujoco_redq_hl.py
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import os
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from jsonargparse import CLI
|
||||||
|
|
||||||
|
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
||||||
|
from tianshou.highlevel.config import SamplingConfig
|
||||||
|
from tianshou.highlevel.experiment import (
|
||||||
|
ExperimentConfig,
|
||||||
|
REDQExperimentBuilder,
|
||||||
|
)
|
||||||
|
from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault
|
||||||
|
from tianshou.highlevel.params.policy_params import REDQParams
|
||||||
|
from tianshou.utils import logging
|
||||||
|
from tianshou.utils.logging import datetime_tag
|
||||||
|
|
||||||
|
|
||||||
|
def main(
    experiment_config: ExperimentConfig,
    task: str = "Ant-v3",
    buffer_size: int = 1000000,
    hidden_sizes: Sequence[int] = (256, 256),
    ensemble_size: int = 10,
    subset_size: int = 2,
    actor_lr: float = 1e-3,
    critic_lr: float = 1e-3,
    gamma: float = 0.99,
    tau: float = 0.005,
    alpha: float = 0.2,
    auto_alpha: bool = False,
    alpha_lr: float = 3e-4,
    start_timesteps: int = 10000,
    epoch: int = 200,
    step_per_epoch: int = 5000,
    step_per_collect: int = 1,
    update_per_step: int = 20,
    n_step: int = 1,
    batch_size: int = 256,
    target_mode: Literal["mean", "min"] = "min",
    training_num: int = 1,
    test_num: int = 10,
):
    """Configure and run a REDQ experiment on a MuJoCo task via the high-level API.

    The run is logged under ``<task>/redq/<seed>/<datetime tag>``.

    :param experiment_config: general experiment settings; its ``seed`` is used
        for the environment factory and for the log name.
    :param task: identifier of the environment handed to ``MujocoEnvFactory``.
    :param hidden_sizes: hidden layer sizes used for both the default actor and
        the default critic ensemble.
    :param alpha: fixed value passed as ``alpha`` when ``auto_alpha`` is False.
    :param auto_alpha: if True, ``alpha`` is replaced by an
        ``AutoAlphaFactoryDefault`` configured with ``alpha_lr``.
    :param n_step: forwarded to ``REDQParams`` as ``estimation_step``.

    The remaining parameters are forwarded unchanged to ``SamplingConfig`` or
    ``REDQParams`` under the names used in the body below.
    """
    # Computed first so the datetime tag reflects the moment the script started.
    run_name = os.path.join(task, "redq", str(experiment_config.seed), datetime_tag())

    sampling = SamplingConfig(
        num_epochs=epoch,
        step_per_epoch=step_per_epoch,
        batch_size=batch_size,
        num_train_envs=training_num,
        num_test_envs=test_num,
        buffer_size=buffer_size,
        step_per_collect=step_per_collect,
        update_per_step=update_per_step,
        repeat_per_collect=None,
        start_timesteps=start_timesteps,
        start_timesteps_random=True,
    )

    envs = MujocoEnvFactory(task, experiment_config.seed, sampling)

    builder = REDQExperimentBuilder(envs, experiment_config, sampling)

    # Either a learnable alpha (via factory) or the fixed user-supplied value.
    alpha_setting = AutoAlphaFactoryDefault(lr=alpha_lr) if auto_alpha else alpha
    redq_params = REDQParams(
        actor_lr=actor_lr,
        critic_lr=critic_lr,
        gamma=gamma,
        tau=tau,
        alpha=alpha_setting,
        estimation_step=n_step,
        target_mode=target_mode,
        subset_size=subset_size,
        ensemble_size=ensemble_size,
    )

    builder = builder.with_redq_params(redq_params)
    builder = builder.with_actor_factory_default(hidden_sizes)
    builder = builder.with_critic_ensemble_factory_default(hidden_sizes)

    experiment = builder.build()
    experiment.run(run_name)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # CLI(main) builds the command-line interface from main's signature; the
    # call is deferred in a lambda so that logging.run_main controls when it runs.
    logging.run_main(lambda: CLI(main))
|
@ -14,7 +14,7 @@ from tianshou.highlevel.module.actor import (
|
|||||||
ActorFactory,
|
ActorFactory,
|
||||||
)
|
)
|
||||||
from tianshou.highlevel.module.core import TDevice
|
from tianshou.highlevel.module.core import TDevice
|
||||||
from tianshou.highlevel.module.critic import CriticFactory
|
from tianshou.highlevel.module.critic import CriticEnsembleFactory, CriticFactory
|
||||||
from tianshou.highlevel.module.module_opt import (
|
from tianshou.highlevel.module.module_opt import (
|
||||||
ActorCriticModuleOpt,
|
ActorCriticModuleOpt,
|
||||||
)
|
)
|
||||||
@ -28,6 +28,7 @@ from tianshou.highlevel.params.policy_params import (
|
|||||||
ParamTransformerData,
|
ParamTransformerData,
|
||||||
PGParams,
|
PGParams,
|
||||||
PPOParams,
|
PPOParams,
|
||||||
|
REDQParams,
|
||||||
SACParams,
|
SACParams,
|
||||||
TD3Params,
|
TD3Params,
|
||||||
TRPOParams,
|
TRPOParams,
|
||||||
@ -42,6 +43,7 @@ from tianshou.policy import (
|
|||||||
NPGPolicy,
|
NPGPolicy,
|
||||||
PGPolicy,
|
PGPolicy,
|
||||||
PPOPolicy,
|
PPOPolicy,
|
||||||
|
REDQPolicy,
|
||||||
SACPolicy,
|
SACPolicy,
|
||||||
TD3Policy,
|
TD3Policy,
|
||||||
TRPOPolicy,
|
TRPOPolicy,
|
||||||
@ -565,6 +567,58 @@ class DDPGAgentFactory(OffpolicyAgentFactory):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class REDQAgentFactory(OffpolicyAgentFactory):
    """Agent factory which builds a :class:`REDQPolicy` from an actor and a critic ensemble."""

    def __init__(
        self,
        params: REDQParams,
        sampling_config: SamplingConfig,
        actor_factory: ActorFactory,
        critic_ensemble_factory: CriticEnsembleFactory,
        optim_factory: OptimizerFactory,
    ):
        """Store the factories and parameters needed to create the policy.

        :param params: REDQ hyperparameters (learning rates, ensemble size, ...).
        :param sampling_config: sampling configuration passed on to the base class.
        :param actor_factory: factory creating the actor module and its optimizer.
        :param critic_ensemble_factory: factory creating the critic ensemble
            module and its optimizer.
        :param optim_factory: optimizer factory used for actor and critics.
        """
        super().__init__(sampling_config, optim_factory)
        self.params = params
        self.actor_factory = actor_factory
        self.critic_ensemble_factory = critic_ensemble_factory
        self.optim_factory = optim_factory

    def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
        """Create the REDQ policy for the given environments on the given device.

        Asserts (via the environment type) that the environments are continuous
        before any module is created.
        """
        envs.get_type().assert_continuous(self)

        actor_opt = self.actor_factory.create_module_opt(
            envs,
            device,
            self.optim_factory,
            self.params.actor_lr,
        )
        # The critics receive the action as input, hence use_action=True.
        critics_opt = self.critic_ensemble_factory.create_module_opt(
            envs,
            device,
            ensemble_size=self.params.ensemble_size,
            use_action=True,
            optim_factory=self.optim_factory,
            lr=self.params.critic_lr,
        )

        # The ensemble is supplied in the ``critic1`` slot of the transformer data.
        transformer_data = ParamTransformerData(
            envs=envs,
            device=device,
            optim_factory=self.optim_factory,
            actor=actor_opt,
            critic1=critics_opt,
        )
        policy_kwargs = self.params.create_kwargs(transformer_data)

        return REDQPolicy(
            actor=actor_opt.module,
            actor_optim=actor_opt.optim,
            critic=critics_opt.module,
            critic_optim=critics_opt.optim,
            action_space=cast(gymnasium.spaces.Box, envs.get_action_space()),
            observation_space=envs.get_observation_space(),
            **policy_kwargs,
        )
|
||||||
|
|
||||||
|
|
||||||
class SACAgentFactory(OffpolicyAgentFactory):
|
class SACAgentFactory(OffpolicyAgentFactory):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -17,6 +17,7 @@ from tianshou.highlevel.agent import (
|
|||||||
NPGAgentFactory,
|
NPGAgentFactory,
|
||||||
PGAgentFactory,
|
PGAgentFactory,
|
||||||
PPOAgentFactory,
|
PPOAgentFactory,
|
||||||
|
REDQAgentFactory,
|
||||||
SACAgentFactory,
|
SACAgentFactory,
|
||||||
TD3AgentFactory,
|
TD3AgentFactory,
|
||||||
TRPOAgentFactory,
|
TRPOAgentFactory,
|
||||||
@ -29,7 +30,12 @@ from tianshou.highlevel.module.actor import (
|
|||||||
ActorFactoryDefault,
|
ActorFactoryDefault,
|
||||||
ContinuousActorType,
|
ContinuousActorType,
|
||||||
)
|
)
|
||||||
from tianshou.highlevel.module.critic import CriticFactory, CriticFactoryDefault
|
from tianshou.highlevel.module.critic import (
|
||||||
|
CriticEnsembleFactory,
|
||||||
|
CriticEnsembleFactoryDefault,
|
||||||
|
CriticFactory,
|
||||||
|
CriticFactoryDefault,
|
||||||
|
)
|
||||||
from tianshou.highlevel.optim import OptimizerFactory, OptimizerFactoryAdam
|
from tianshou.highlevel.optim import OptimizerFactory, OptimizerFactoryAdam
|
||||||
from tianshou.highlevel.params.policy_params import (
|
from tianshou.highlevel.params.policy_params import (
|
||||||
A2CParams,
|
A2CParams,
|
||||||
@ -38,6 +44,7 @@ from tianshou.highlevel.params.policy_params import (
|
|||||||
NPGParams,
|
NPGParams,
|
||||||
PGParams,
|
PGParams,
|
||||||
PPOParams,
|
PPOParams,
|
||||||
|
REDQParams,
|
||||||
SACParams,
|
SACParams,
|
||||||
TD3Params,
|
TD3Params,
|
||||||
TRPOParams,
|
TRPOParams,
|
||||||
@ -404,6 +411,28 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class _BuilderMixinCriticEnsembleFactory:
    """Builder mixin which adds configuration of a critic ensemble factory."""

    def __init__(self) -> None:
        # None means "not configured"; a default factory is then created on demand
        # by _get_critic_ensemble_factory.
        self.critic_ensemble_factory: CriticEnsembleFactory | None = None

    def with_critic_ensemble_factory(self, factory: CriticEnsembleFactory) -> Self:
        """Use the given factory to create the critic ensemble.

        :param factory: the critic ensemble factory to use
        :return: the builder
        """
        self.critic_ensemble_factory = factory
        return self

    def with_critic_ensemble_factory_default(
        self,
        # Default is taken from the ensemble factory itself so that calling this
        # method without arguments agrees with the implicit fallback below.
        # NOTE(review): assumes CriticFactoryDefault.DEFAULT_HIDDEN_SIZES (used
        # previously) has the same value - confirm.
        hidden_sizes: Sequence[int] = CriticEnsembleFactoryDefault.DEFAULT_HIDDEN_SIZES,
    ) -> Self:
        """Use the default critic ensemble factory with the given hidden sizes.

        :param hidden_sizes: hidden layer sizes passed to ``CriticEnsembleFactoryDefault``
        :return: the builder
        """
        self.critic_ensemble_factory = CriticEnsembleFactoryDefault(hidden_sizes)
        return self

    def _get_critic_ensemble_factory(self) -> CriticEnsembleFactory:
        """Return the configured factory, or a default one if none was configured."""
        if self.critic_ensemble_factory is None:
            return CriticEnsembleFactoryDefault()
        return self.critic_ensemble_factory
|
||||||
|
|
||||||
|
|
||||||
class PGExperimentBuilder(
|
class PGExperimentBuilder(
|
||||||
ExperimentBuilder,
|
ExperimentBuilder,
|
||||||
_BuilderMixinActorFactory_ContinuousGaussian,
|
_BuilderMixinActorFactory_ContinuousGaussian,
|
||||||
@ -621,6 +650,37 @@ class DDPGExperimentBuilder(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class REDQExperimentBuilder(
    ExperimentBuilder,
    _BuilderMixinActorFactory_ContinuousGaussian,
    _BuilderMixinCriticEnsembleFactory,
):
    """Experiment builder for REDQ, using a Gaussian actor and a critic ensemble."""

    def __init__(
        self,
        env_factory: EnvFactory,
        experiment_config: ExperimentConfig | None = None,
        sampling_config: SamplingConfig | None = None,
    ):
        """Initialise the builder with default REDQ parameters.

        :param env_factory: factory creating the environments
        :param experiment_config: general experiment configuration, passed to the base builder
        :param sampling_config: sampling configuration, passed to the base builder
        """
        super().__init__(env_factory, experiment_config, sampling_config)
        # The mixins are initialised explicitly, as in the original code.
        _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
        _BuilderMixinCriticEnsembleFactory.__init__(self)
        self._params: REDQParams = REDQParams()

    def with_redq_params(self, params: REDQParams) -> Self:
        """Set the REDQ parameters to use.

        :param params: the parameters
        :return: the builder
        """
        self._params = params
        return self

    # This is a concrete implementation and must not be marked @abstractmethod:
    # under an ABCMeta-based hierarchy that marker would make this builder
    # impossible to instantiate.
    def _create_agent_factory(self) -> AgentFactory:
        """Create the REDQ agent factory from the builder's current configuration."""
        return REDQAgentFactory(
            self._params,
            self._sampling_config,
            self._get_actor_factory(),
            self._get_critic_ensemble_factory(),
            self._get_optim_factory(),
        )
|
||||||
|
|
||||||
|
|
||||||
class SACExperimentBuilder(
|
class SACExperimentBuilder(
|
||||||
ExperimentBuilder,
|
ExperimentBuilder,
|
||||||
_BuilderMixinActorFactory_ContinuousGaussian,
|
_BuilderMixinActorFactory_ContinuousGaussian,
|
||||||
|
@ -8,7 +8,7 @@ from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
|
|||||||
from tianshou.highlevel.module.module_opt import ModuleOpt
|
from tianshou.highlevel.module.module_opt import ModuleOpt
|
||||||
from tianshou.highlevel.optim import OptimizerFactory
|
from tianshou.highlevel.optim import OptimizerFactory
|
||||||
from tianshou.utils.net import continuous, discrete
|
from tianshou.utils.net import continuous, discrete
|
||||||
from tianshou.utils.net.common import Net
|
from tianshou.utils.net.common import EnsembleLinear, Net
|
||||||
from tianshou.utils.string import ToStringMixin
|
from tianshou.utils.string import ToStringMixin
|
||||||
|
|
||||||
|
|
||||||
@ -39,21 +39,16 @@ class CriticFactoryDefault(CriticFactory):
|
|||||||
self.hidden_sizes = hidden_sizes
|
self.hidden_sizes = hidden_sizes
|
||||||
|
|
||||||
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
|
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
|
||||||
|
factory: CriticFactory
|
||||||
env_type = envs.get_type()
|
env_type = envs.get_type()
|
||||||
if env_type == EnvType.CONTINUOUS:
|
match env_type:
|
||||||
return CriticFactoryContinuousNet(self.hidden_sizes).create_module(
|
case EnvType.CONTINUOUS:
|
||||||
envs,
|
factory = CriticFactoryContinuousNet(self.hidden_sizes)
|
||||||
device,
|
case EnvType.DISCRETE:
|
||||||
use_action,
|
factory = CriticFactoryDiscreteNet(self.hidden_sizes)
|
||||||
)
|
case _:
|
||||||
elif env_type == EnvType.DISCRETE:
|
|
||||||
return CriticFactoryDiscreteNet(self.hidden_sizes).create_module(
|
|
||||||
envs,
|
|
||||||
device,
|
|
||||||
use_action,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"{env_type} not supported")
|
raise ValueError(f"{env_type} not supported")
|
||||||
|
return factory.create_module(envs, device, use_action)
|
||||||
|
|
||||||
|
|
||||||
class CriticFactoryContinuousNet(CriticFactory):
|
class CriticFactoryContinuousNet(CriticFactory):
|
||||||
@ -92,3 +87,94 @@ class CriticFactoryDiscreteNet(CriticFactory):
|
|||||||
critic = discrete.Critic(net_c, device=device).to(device)
|
critic = discrete.Critic(net_c, device=device).to(device)
|
||||||
init_linear_orthogonal(critic)
|
init_linear_orthogonal(critic)
|
||||||
return critic
|
return critic
|
||||||
|
|
||||||
|
|
||||||
|
class CriticEnsembleFactory:
    """Interface of factories which create an ensemble of critics as a single module.

    NOTE(review): ``create_module`` is marked ``@abstractmethod``, but this class
    declares no ABC base, so the marker is presumably not enforced at
    instantiation time - confirm whether an ABC base was intended.
    """

    @abstractmethod
    def create_module(
        self,
        envs: Environments,
        device: TDevice,
        ensemble_size: int,
        use_action: bool,
    ) -> nn.Module:
        """Create the critic ensemble module.

        :param envs: the environments for which the module is created
        :param device: the torch device to use
        :param ensemble_size: the number of critics in the ensemble
        :param use_action: whether the critics take the action as additional input
        :return: the module
        """

    def create_module_opt(
        self,
        envs: Environments,
        device: TDevice,
        ensemble_size: int,
        use_action: bool,
        optim_factory: OptimizerFactory,
        lr: float,
    ) -> ModuleOpt:
        """Create the critic ensemble module along with an optimizer for it.

        :param envs: the environments for which the module is created
        :param device: the torch device to use
        :param ensemble_size: the number of critics in the ensemble
        :param use_action: whether the critics take the action as additional input
        :param optim_factory: the factory used to create the optimizer
        :param lr: the learning rate passed to the optimizer factory
        :return: the module paired with its optimizer
        """
        critic_module = self.create_module(envs, device, ensemble_size, use_action)
        optimizer = optim_factory.create_optimizer(critic_module, lr)
        return ModuleOpt(critic_module, optimizer)
|
||||||
|
|
||||||
|
|
||||||
|
class CriticEnsembleFactoryDefault(CriticEnsembleFactory):
    """A critic ensemble factory which, depending on the type of environment, creates a suitable MLP-based critic."""

    DEFAULT_HIDDEN_SIZES = (64, 64)

    def __init__(self, hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES):
        """:param hidden_sizes: hidden layer sizes handed to the concrete factory."""
        self.hidden_sizes = hidden_sizes

    def create_module(
        self,
        envs: Environments,
        device: TDevice,
        ensemble_size: int,
        use_action: bool,
    ) -> nn.Module:
        """Delegate module creation to the factory matching the environment type.

        :raises NotImplementedError: for discrete environments (no default exists)
        :raises ValueError: for any other unsupported environment type
        """
        env_type = envs.get_type()

        if env_type == EnvType.CONTINUOUS:
            delegate: CriticEnsembleFactory = CriticEnsembleFactoryContinuousNet(
                self.hidden_sizes,
            )
            return delegate.create_module(envs, device, ensemble_size, use_action)

        if env_type == EnvType.DISCRETE:
            raise NotImplementedError("No default is implemented for the discrete case")

        raise ValueError(f"{env_type} not supported")
|
||||||
|
|
||||||
|
|
||||||
|
class CriticEnsembleFactoryContinuousNet(CriticEnsembleFactory):
    """Critic ensemble factory for continuous environments, built from ``EnsembleLinear`` layers."""

    def __init__(self, hidden_sizes: Sequence[int]):
        """:param hidden_sizes: hidden layer sizes of the preprocessing network."""
        self.hidden_sizes = hidden_sizes

    def create_module(
        self,
        envs: Environments,
        device: TDevice,
        ensemble_size: int,
        use_action: bool,
    ) -> nn.Module:
        """Create a continuous critic whose linear layers are ensemble layers.

        :param envs: the environments providing observation and action shapes
        :param device: the torch device the critic is moved to
        :param ensemble_size: first argument of every ``EnsembleLinear`` layer
        :param use_action: whether the action is concatenated to the network input
        :return: the critic module, after ``init_linear_orthogonal`` has been applied
        """

        def make_ensemble_linear(in_dim: int, out_dim: int) -> EnsembleLinear:
            # Used in place of nn.Linear for every linear layer of the critic.
            return EnsembleLinear(ensemble_size, in_dim, out_dim)

        if use_action:
            action_shape = envs.get_action_shape()
        else:
            action_shape = 0

        preprocess_net = Net(
            envs.get_observation_shape(),
            action_shape=action_shape,
            hidden_sizes=self.hidden_sizes,
            concat=use_action,
            activation=nn.Tanh,
            device=device,
            linear_layer=make_ensemble_linear,
        )
        critic = continuous.Critic(
            preprocess_net,
            device=device,
            linear_layer=make_ensemble_linear,
            flatten_input=False,
        )
        critic = critic.to(device)
        init_linear_orthogonal(critic)
        return critic
|
||||||
|
@ -399,6 +399,22 @@ class DDPGParams(Params, ParamsMixinActorAndCritic):
|
|||||||
return transformers
|
return transformers
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class REDQParams(DDPGParams):
    """Parameters for REDQ; extends the DDPG parameters with ensemble and alpha settings."""

    # Number of critics requested from the critic ensemble factory.
    ensemble_size: int = 10
    # presumably the number of critics sampled for the target computation - verify against REDQPolicy
    subset_size: int = 2
    # Fixed value, an explicit (value, tensor, optimizer) tuple, or a factory;
    # handled by the ParamTransformerAutoAlpha registered below.
    alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] | AutoAlphaFactory = 0.2
    estimation_step: int = 1
    actor_delay: int = 20
    deterministic_eval: bool = True
    target_mode: Literal["mean", "min"] = "min"

    def _get_param_transformers(self) -> list[ParamTransformer]:
        """Return the parent's transformers plus the transformer for ``alpha``."""
        result = super()._get_param_transformers()
        result.extend([ParamTransformerAutoAlpha("alpha")])
        return result
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TD3Params(Params, ParamsMixinActorAndDualCritics):
|
class TD3Params(Params, ParamsMixinActorAndDualCritics):
|
||||||
tau: float = 0.005
|
tau: float = 0.005
|
||||||
|
@ -12,6 +12,7 @@ from tianshou.data.types import RecurrentStateBatch
|
|||||||
ModuleType = type[nn.Module]
|
ModuleType = type[nn.Module]
|
||||||
ArgsType = tuple[Any, ...] | dict[Any, Any] | Sequence[tuple[Any, ...]] | Sequence[dict[Any, Any]]
|
ArgsType = tuple[Any, ...] | dict[Any, Any] | Sequence[tuple[Any, ...]] | Sequence[dict[Any, Any]]
|
||||||
TActionShape: TypeAlias = Sequence[int] | int
|
TActionShape: TypeAlias = Sequence[int] | int
|
||||||
|
TLinearLayer: TypeAlias = Callable[[int, int], nn.Module]
|
||||||
|
|
||||||
|
|
||||||
def miniblock(
|
def miniblock(
|
||||||
@ -77,7 +78,7 @@ class MLP(nn.Module):
|
|||||||
activation: ModuleType | Sequence[ModuleType] | None = nn.ReLU,
|
activation: ModuleType | Sequence[ModuleType] | None = nn.ReLU,
|
||||||
act_args: ArgsType | None = None,
|
act_args: ArgsType | None = None,
|
||||||
device: str | int | torch.device | None = None,
|
device: str | int | torch.device | None = None,
|
||||||
linear_layer: type[nn.Linear] = nn.Linear,
|
linear_layer: TLinearLayer = nn.Linear,
|
||||||
flatten_input: bool = True,
|
flatten_input: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -183,7 +184,8 @@ class Net(NetBase):
|
|||||||
pass a tuple of two dict (first for Q and second for V) stating
|
pass a tuple of two dict (first for Q and second for V) stating
|
||||||
self-defined arguments as stated in
|
self-defined arguments as stated in
|
||||||
class:`~tianshou.utils.net.common.MLP`. Default to None.
|
class:`~tianshou.utils.net.common.MLP`. Default to None.
|
||||||
:param linear_layer: use this module as linear layer. Default to nn.Linear.
|
:param linear_layer: use this module constructor, which takes the input
|
||||||
|
and output dimension as input, as linear layer. Default to nn.Linear.
|
||||||
|
|
||||||
.. seealso::
|
.. seealso::
|
||||||
|
|
||||||
@ -209,7 +211,7 @@ class Net(NetBase):
|
|||||||
concat: bool = False,
|
concat: bool = False,
|
||||||
num_atoms: int = 1,
|
num_atoms: int = 1,
|
||||||
dueling_param: tuple[dict[str, Any], dict[str, Any]] | None = None,
|
dueling_param: tuple[dict[str, Any], dict[str, Any]] | None = None,
|
||||||
linear_layer: type[nn.Linear] = nn.Linear,
|
linear_layer: TLinearLayer = nn.Linear,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.device = device
|
self.device = device
|
||||||
|
@ -6,7 +6,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from tianshou.utils.net.common import MLP, BaseActor, TActionShape
|
from tianshou.utils.net.common import MLP, BaseActor, TActionShape, TLinearLayer
|
||||||
|
|
||||||
SIGMA_MIN = -20
|
SIGMA_MIN = -20
|
||||||
SIGMA_MAX = 2
|
SIGMA_MAX = 2
|
||||||
@ -108,7 +108,7 @@ class Critic(nn.Module):
|
|||||||
hidden_sizes: Sequence[int] = (),
|
hidden_sizes: Sequence[int] = (),
|
||||||
device: str | int | torch.device = "cpu",
|
device: str | int | torch.device = "cpu",
|
||||||
preprocess_net_output_dim: int | None = None,
|
preprocess_net_output_dim: int | None = None,
|
||||||
linear_layer: type[nn.Linear] = nn.Linear,
|
linear_layer: TLinearLayer = nn.Linear,
|
||||||
flatten_input: bool = True,
|
flatten_input: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user