Support REDQ in high-level API
* Implement example mujoco_redq_hl
* Add abstraction CriticEnsembleFactory with default implementations to suit REDQ
* Fix type annotation of linear_layer in Net, MLP, Critic (was incompatible with REDQ usage)
parent 7af836bd6a
commit 17ef4dd5eb
@@ -38,7 +38,7 @@ def main(
     test_num: int = 10,
 ):
     now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
-    log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
+    log_name = os.path.join(task, "ddpg", str(experiment_config.seed), now)

     sampling_config = SamplingConfig(
         num_epochs=epoch,
examples/mujoco/mujoco_redq_hl.py (new file, 87 lines)
@@ -0,0 +1,87 @@
+#!/usr/bin/env python3
+
+import os
+from collections.abc import Sequence
+from typing import Literal
+
+from jsonargparse import CLI
+
+from examples.mujoco.mujoco_env import MujocoEnvFactory
+from tianshou.highlevel.config import SamplingConfig
+from tianshou.highlevel.experiment import (
+    ExperimentConfig,
+    REDQExperimentBuilder,
+)
+from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault
+from tianshou.highlevel.params.policy_params import REDQParams
+from tianshou.utils import logging
+from tianshou.utils.logging import datetime_tag
+
+
+def main(
+    experiment_config: ExperimentConfig,
+    task: str = "Ant-v3",
+    buffer_size: int = 1000000,
+    hidden_sizes: Sequence[int] = (256, 256),
+    ensemble_size: int = 10,
+    subset_size: int = 2,
+    actor_lr: float = 1e-3,
+    critic_lr: float = 1e-3,
+    gamma: float = 0.99,
+    tau: float = 0.005,
+    alpha: float = 0.2,
+    auto_alpha: bool = False,
+    alpha_lr: float = 3e-4,
+    start_timesteps: int = 10000,
+    epoch: int = 200,
+    step_per_epoch: int = 5000,
+    step_per_collect: int = 1,
+    update_per_step: int = 20,
+    n_step: int = 1,
+    batch_size: int = 256,
+    target_mode: Literal["mean", "min"] = "min",
+    training_num: int = 1,
+    test_num: int = 10,
+):
+    log_name = os.path.join(task, "redq", str(experiment_config.seed), datetime_tag())
+
+    sampling_config = SamplingConfig(
+        num_epochs=epoch,
+        step_per_epoch=step_per_epoch,
+        batch_size=batch_size,
+        num_train_envs=training_num,
+        num_test_envs=test_num,
+        buffer_size=buffer_size,
+        step_per_collect=step_per_collect,
+        update_per_step=update_per_step,
+        repeat_per_collect=None,
+        start_timesteps=start_timesteps,
+        start_timesteps_random=True,
+    )
+
+    env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
+
+    experiment = (
+        REDQExperimentBuilder(env_factory, experiment_config, sampling_config)
+        .with_redq_params(
+            REDQParams(
+                actor_lr=actor_lr,
+                critic_lr=critic_lr,
+                gamma=gamma,
+                tau=tau,
+                alpha=AutoAlphaFactoryDefault(lr=alpha_lr) if auto_alpha else alpha,
+                estimation_step=n_step,
+                target_mode=target_mode,
+                subset_size=subset_size,
+                ensemble_size=ensemble_size,
+            ),
+        )
+        .with_actor_factory_default(hidden_sizes)
+        .with_critic_ensemble_factory_default(hidden_sizes)
+        .build()
+    )
+    experiment.run(log_name)
+
+
+if __name__ == "__main__":
+    logging.run_main(lambda: CLI(main))
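Note: assuming the usual jsonargparse behavior, CLI(main) exposes every parameter of main as a command-line flag, so the new example can be launched along these lines (illustrative invocation, not part of the commit):

    python examples/mujoco/mujoco_redq_hl.py --task Ant-v3 --ensemble_size 10 --subset_size 2 --epoch 200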
@@ -14,7 +14,7 @@ from tianshou.highlevel.module.actor import (
     ActorFactory,
 )
 from tianshou.highlevel.module.core import TDevice
-from tianshou.highlevel.module.critic import CriticFactory
+from tianshou.highlevel.module.critic import CriticEnsembleFactory, CriticFactory
 from tianshou.highlevel.module.module_opt import (
     ActorCriticModuleOpt,
 )
@@ -28,6 +28,7 @@ from tianshou.highlevel.params.policy_params import (
     ParamTransformerData,
     PGParams,
     PPOParams,
+    REDQParams,
     SACParams,
     TD3Params,
     TRPOParams,
@@ -42,6 +43,7 @@ from tianshou.policy import (
     NPGPolicy,
     PGPolicy,
     PPOPolicy,
+    REDQPolicy,
     SACPolicy,
     TD3Policy,
     TRPOPolicy,
@@ -565,6 +567,58 @@ class DDPGAgentFactory(OffpolicyAgentFactory):
         )


+class REDQAgentFactory(OffpolicyAgentFactory):
+    def __init__(
+        self,
+        params: REDQParams,
+        sampling_config: SamplingConfig,
+        actor_factory: ActorFactory,
+        critic_ensemble_factory: CriticEnsembleFactory,
+        optim_factory: OptimizerFactory,
+    ):
+        super().__init__(sampling_config, optim_factory)
+        self.critic_ensemble_factory = critic_ensemble_factory
+        self.actor_factory = actor_factory
+        self.params = params
+        self.optim_factory = optim_factory
+
+    def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
+        envs.get_type().assert_continuous(self)
+        actor = self.actor_factory.create_module_opt(
+            envs,
+            device,
+            self.optim_factory,
+            self.params.actor_lr,
+        )
+        critic_ensemble = self.critic_ensemble_factory.create_module_opt(
+            envs,
+            device,
+            self.params.ensemble_size,
+            True,
+            self.optim_factory,
+            self.params.critic_lr,
+        )
+        kwargs = self.params.create_kwargs(
+            ParamTransformerData(
+                envs=envs,
+                device=device,
+                optim_factory=self.optim_factory,
+                actor=actor,
+                critic1=critic_ensemble,
+            ),
+        )
+        action_space = cast(gymnasium.spaces.Box, envs.get_action_space())
+        return REDQPolicy(
+            actor=actor.module,
+            actor_optim=actor.optim,
+            critic=critic_ensemble.module,
+            critic_optim=critic_ensemble.optim,
+            action_space=action_space,
+            observation_space=envs.get_observation_space(),
+            **kwargs,
+        )
+
+
 class SACAgentFactory(OffpolicyAgentFactory):
     def __init__(
         self,
@@ -17,6 +17,7 @@ from tianshou.highlevel.agent import (
     NPGAgentFactory,
     PGAgentFactory,
     PPOAgentFactory,
+    REDQAgentFactory,
     SACAgentFactory,
     TD3AgentFactory,
     TRPOAgentFactory,
@@ -29,7 +30,12 @@ from tianshou.highlevel.module.actor import (
     ActorFactoryDefault,
     ContinuousActorType,
 )
-from tianshou.highlevel.module.critic import CriticFactory, CriticFactoryDefault
+from tianshou.highlevel.module.critic import (
+    CriticEnsembleFactory,
+    CriticEnsembleFactoryDefault,
+    CriticFactory,
+    CriticFactoryDefault,
+)
 from tianshou.highlevel.optim import OptimizerFactory, OptimizerFactoryAdam
 from tianshou.highlevel.params.policy_params import (
     A2CParams,
@@ -38,6 +44,7 @@ from tianshou.highlevel.params.policy_params import (
     NPGParams,
     PGParams,
     PPOParams,
+    REDQParams,
     SACParams,
     TD3Params,
     TRPOParams,
@@ -404,6 +411,28 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
         return self


+class _BuilderMixinCriticEnsembleFactory:
+    def __init__(self) -> None:
+        self.critic_ensemble_factory: CriticEnsembleFactory | None = None
+
+    def with_critic_ensemble_factory(self, factory: CriticEnsembleFactory) -> Self:
+        self.critic_ensemble_factory = factory
+        return self
+
+    def with_critic_ensemble_factory_default(
+        self,
+        hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
+    ) -> Self:
+        self.critic_ensemble_factory = CriticEnsembleFactoryDefault(hidden_sizes)
+        return self
+
+    def _get_critic_ensemble_factory(self):
+        if self.critic_ensemble_factory is None:
+            return CriticEnsembleFactoryDefault()
+        else:
+            return self.critic_ensemble_factory
+
+
 class PGExperimentBuilder(
     ExperimentBuilder,
     _BuilderMixinActorFactory_ContinuousGaussian,
@@ -621,6 +650,36 @@ class DDPGExperimentBuilder(
     )


+class REDQExperimentBuilder(
+    ExperimentBuilder,
+    _BuilderMixinActorFactory_ContinuousGaussian,
+    _BuilderMixinCriticEnsembleFactory,
+):
+    def __init__(
+        self,
+        env_factory: EnvFactory,
+        experiment_config: ExperimentConfig | None = None,
+        sampling_config: SamplingConfig | None = None,
+    ):
+        super().__init__(env_factory, experiment_config, sampling_config)
+        _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
+        _BuilderMixinCriticEnsembleFactory.__init__(self)
+        self._params: REDQParams = REDQParams()
+
+    def with_redq_params(self, params: REDQParams) -> Self:
+        self._params = params
+        return self
+
+    def _create_agent_factory(self) -> AgentFactory:
+        return REDQAgentFactory(
+            self._params,
+            self._sampling_config,
+            self._get_actor_factory(),
+            self._get_critic_ensemble_factory(),
+            self._get_optim_factory(),
+        )
+
+
 class SACExperimentBuilder(
     ExperimentBuilder,
     _BuilderMixinActorFactory_ContinuousGaussian,
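Note: a minimal sketch of configuring REDQ through the fluent interface with a non-default critic ensemble (MyEnsembleFactory is hypothetical; any CriticEnsembleFactory implementation works):

    experiment = (
        REDQExperimentBuilder(env_factory, experiment_config, sampling_config)
        .with_redq_params(REDQParams(ensemble_size=10, subset_size=2))
        .with_actor_factory_default((256, 256))
        .with_critic_ensemble_factory(MyEnsembleFactory())  # hypothetical custom factory
        .build()
    )

If neither with_critic_ensemble_factory nor the _default variant is called, _get_critic_ensemble_factory falls back to CriticEnsembleFactoryDefault().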
@@ -8,7 +8,7 @@ from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
 from tianshou.highlevel.module.module_opt import ModuleOpt
 from tianshou.highlevel.optim import OptimizerFactory
 from tianshou.utils.net import continuous, discrete
-from tianshou.utils.net.common import Net
+from tianshou.utils.net.common import EnsembleLinear, Net
 from tianshou.utils.string import ToStringMixin

@@ -39,21 +39,16 @@ class CriticFactoryDefault(CriticFactory):
         self.hidden_sizes = hidden_sizes

     def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
+        factory: CriticFactory
         env_type = envs.get_type()
-        if env_type == EnvType.CONTINUOUS:
-            return CriticFactoryContinuousNet(self.hidden_sizes).create_module(
-                envs,
-                device,
-                use_action,
-            )
-        elif env_type == EnvType.DISCRETE:
-            return CriticFactoryDiscreteNet(self.hidden_sizes).create_module(
-                envs,
-                device,
-                use_action,
-            )
-        else:
-            raise ValueError(f"{env_type} not supported")
+        match env_type:
+            case EnvType.CONTINUOUS:
+                factory = CriticFactoryContinuousNet(self.hidden_sizes)
+            case EnvType.DISCRETE:
+                factory = CriticFactoryDiscreteNet(self.hidden_sizes)
+            case _:
+                raise ValueError(f"{env_type} not supported")
+        return factory.create_module(envs, device, use_action)


 class CriticFactoryContinuousNet(CriticFactory):
@@ -92,3 +87,94 @@ class CriticFactoryDiscreteNet(CriticFactory):
         critic = discrete.Critic(net_c, device=device).to(device)
         init_linear_orthogonal(critic)
         return critic
+
+
+class CriticEnsembleFactory:
+    @abstractmethod
+    def create_module(
+        self,
+        envs: Environments,
+        device: TDevice,
+        ensemble_size: int,
+        use_action: bool,
+    ) -> nn.Module:
+        pass
+
+    def create_module_opt(
+        self,
+        envs: Environments,
+        device: TDevice,
+        ensemble_size: int,
+        use_action: bool,
+        optim_factory: OptimizerFactory,
+        lr: float,
+    ) -> ModuleOpt:
+        module = self.create_module(envs, device, ensemble_size, use_action)
+        opt = optim_factory.create_optimizer(module, lr)
+        return ModuleOpt(module, opt)
+
+
+class CriticEnsembleFactoryDefault(CriticEnsembleFactory):
+    """A critic ensemble factory which, depending on the type of environment, creates a suitable MLP-based critic."""
+
+    DEFAULT_HIDDEN_SIZES = (64, 64)
+
+    def __init__(self, hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES):
+        self.hidden_sizes = hidden_sizes
+
+    def create_module(
+        self,
+        envs: Environments,
+        device: TDevice,
+        ensemble_size: int,
+        use_action: bool,
+    ) -> nn.Module:
+        env_type = envs.get_type()
+        factory: CriticEnsembleFactory
+        match env_type:
+            case EnvType.CONTINUOUS:
+                factory = CriticEnsembleFactoryContinuousNet(self.hidden_sizes)
+            case EnvType.DISCRETE:
+                raise NotImplementedError("No default is implemented for the discrete case")
+            case _:
+                raise ValueError(f"{env_type} not supported")
+        return factory.create_module(
+            envs,
+            device,
+            ensemble_size,
+            use_action,
+        )
+
+
+class CriticEnsembleFactoryContinuousNet(CriticEnsembleFactory):
+    def __init__(self, hidden_sizes: Sequence[int]):
+        self.hidden_sizes = hidden_sizes
+
+    def create_module(
+        self,
+        envs: Environments,
+        device: TDevice,
+        ensemble_size: int,
+        use_action: bool,
+    ) -> nn.Module:
+        def linear_layer(x: int, y: int) -> EnsembleLinear:
+            return EnsembleLinear(ensemble_size, x, y)
+
+        action_shape = envs.get_action_shape() if use_action else 0
+        net_c = Net(
+            envs.get_observation_shape(),
+            action_shape=action_shape,
+            hidden_sizes=self.hidden_sizes,
+            concat=use_action,
+            activation=nn.Tanh,
+            device=device,
+            linear_layer=linear_layer,
+        )
+        critic = continuous.Critic(
+            net_c,
+            device=device,
+            linear_layer=linear_layer,
+            flatten_input=False,
+        ).to(device)
+        init_linear_orthogonal(critic)
+        return critic
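Note: create_module_opt is a template method, so subclasses implement only create_module and the optimizer wiring is shared. A sketch of direct use (envs and device stand for values normally supplied by the experiment machinery):

    factory = CriticEnsembleFactoryDefault(hidden_sizes=(256, 256))
    critic_ensemble = factory.create_module_opt(
        envs,                    # an Environments instance
        device,                  # e.g. "cuda"
        10,                      # ensemble_size
        True,                    # use_action
        OptimizerFactoryAdam(),  # from tianshou.highlevel.optim
        1e-3,                    # lr
    )  # returns a ModuleOpt bundling the ensemble critic and its optimizer

The linear_layer closure above is the reason for the annotation fix: ensemble_size is bound at module-creation time, which a type[nn.Linear] annotation cannot express.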
@@ -399,6 +399,22 @@ class DDPGParams(Params, ParamsMixinActorAndCritic):
         return transformers


+@dataclass
+class REDQParams(DDPGParams):
+    ensemble_size: int = 10
+    subset_size: int = 2
+    alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] | AutoAlphaFactory = 0.2
+    estimation_step: int = 1
+    actor_delay: int = 20
+    deterministic_eval: bool = True
+    target_mode: Literal["mean", "min"] = "min"
+
+    def _get_param_transformers(self) -> list[ParamTransformer]:
+        transformers = super()._get_param_transformers()
+        transformers.append(ParamTransformerAutoAlpha("alpha"))
+        return transformers
+
+
 @dataclass
 class TD3Params(Params, ParamsMixinActorAndDualCritics):
     tau: float = 0.005
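Note: a rough illustration of how ensemble_size, subset_size, and target_mode interact in REDQ's Bellman target (schematic, not the library's actual implementation):

    import torch

    def ensemble_target(q: torch.Tensor, subset_size: int, target_mode: str) -> torch.Tensor:
        # q: (ensemble_size, batch) Q-estimates from the target critic ensemble
        idx = torch.randperm(q.shape[0])[:subset_size]  # sample critics without replacement
        q_subset = q[idx]
        # "min" gives the pessimistic target from the REDQ paper; "mean" averages instead
        return q_subset.min(dim=0).values if target_mode == "min" else q_subset.mean(dim=0)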
@@ -12,6 +12,7 @@ from tianshou.data.types import RecurrentStateBatch
 ModuleType = type[nn.Module]
 ArgsType = tuple[Any, ...] | dict[Any, Any] | Sequence[tuple[Any, ...]] | Sequence[dict[Any, Any]]
 TActionShape: TypeAlias = Sequence[int] | int
+TLinearLayer: TypeAlias = Callable[[int, int], nn.Module]


 def miniblock(
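Note: the point of the new alias is that type[nn.Linear] only admits nn.Linear (sub)classes, whereas REDQ needs to pass a factory with ensemble_size already bound, which is a callable rather than a class. A sketch:

    from functools import partial

    from tianshou.utils.net.common import EnsembleLinear

    # A valid TLinearLayer: called as (in_features, out_features) -> nn.Module.
    # This is a partial object, not a class, so it never satisfied type[nn.Linear].
    linear_layer = partial(EnsembleLinear, 10)  # ensemble_size = 10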
@@ -77,7 +78,7 @@ class MLP(nn.Module):
         activation: ModuleType | Sequence[ModuleType] | None = nn.ReLU,
         act_args: ArgsType | None = None,
         device: str | int | torch.device | None = None,
-        linear_layer: type[nn.Linear] = nn.Linear,
+        linear_layer: TLinearLayer = nn.Linear,
         flatten_input: bool = True,
     ) -> None:
         super().__init__()
@@ -183,7 +184,8 @@ class Net(NetBase):
         pass a tuple of two dict (first for Q and second for V) stating
         self-defined arguments as stated in
         class:`~tianshou.utils.net.common.MLP`. Default to None.
-    :param linear_layer: use this module as linear layer. Default to nn.Linear.
+    :param linear_layer: use this module constructor, which takes the input
+        and output dimensions as arguments, as the linear layer. Default to nn.Linear.

     .. seealso::

@@ -209,7 +211,7 @@ class Net(NetBase):
         concat: bool = False,
         num_atoms: int = 1,
         dueling_param: tuple[dict[str, Any], dict[str, Any]] | None = None,
-        linear_layer: type[nn.Linear] = nn.Linear,
+        linear_layer: TLinearLayer = nn.Linear,
     ) -> None:
         super().__init__()
         self.device = device
@@ -6,7 +6,7 @@ import numpy as np
 import torch
 from torch import nn

-from tianshou.utils.net.common import MLP, BaseActor, TActionShape
+from tianshou.utils.net.common import MLP, BaseActor, TActionShape, TLinearLayer

 SIGMA_MIN = -20
 SIGMA_MAX = 2
@@ -108,7 +108,7 @@ class Critic(nn.Module):
         hidden_sizes: Sequence[int] = (),
         device: str | int | torch.device = "cpu",
         preprocess_net_output_dim: int | None = None,
-        linear_layer: type[nn.Linear] = nn.Linear,
+        linear_layer: TLinearLayer = nn.Linear,
         flatten_input: bool = True,
     ) -> None:
         super().__init__()