Support REDQ in high-level API

* Implement example mujoco_redq_hl
* Add abstraction CriticEnsembleFactory with default implementations
  to suit REDQ
* Fix type annotation of linear_layer in Net, MLP, Critic
  (was incompatible with REDQ usage)
Dominik Jain 2023-10-10 15:49:05 +02:00
parent 7af836bd6a
commit 17ef4dd5eb
8 changed files with 328 additions and 23 deletions
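In short, the new builder makes REDQ available through the same fluent interface as the other algorithms. A condensed sketch of the added example (values taken from the example's defaults; the log name is arbitrary):

    from examples.mujoco.mujoco_env import MujocoEnvFactory
    from tianshou.highlevel.config import SamplingConfig
    from tianshou.highlevel.experiment import ExperimentConfig, REDQExperimentBuilder
    from tianshou.highlevel.params.policy_params import REDQParams

    experiment_config = ExperimentConfig()
    sampling_config = SamplingConfig(
        num_epochs=200,
        step_per_epoch=5000,
        start_timesteps=10000,
        start_timesteps_random=True,
    )
    env_factory = MujocoEnvFactory("Ant-v3", experiment_config.seed, sampling_config)
    experiment = (
        REDQExperimentBuilder(env_factory, experiment_config, sampling_config)
        .with_redq_params(REDQParams(ensemble_size=10, subset_size=2))
        .with_actor_factory_default((256, 256))
        .with_critic_ensemble_factory_default((256, 256))
        .build()
    )
    experiment.run("redq-demo")  # log name is arbitrary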

View File

@@ -38,7 +38,7 @@ def main(
     test_num: int = 10,
 ):
     now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
-    log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
+    log_name = os.path.join(task, "ddpg", str(experiment_config.seed), now)
     sampling_config = SamplingConfig(
         num_epochs=epoch,

View File

@@ -0,0 +1,87 @@
+#!/usr/bin/env python3
+
+import os
+from collections.abc import Sequence
+from typing import Literal
+
+from jsonargparse import CLI
+
+from examples.mujoco.mujoco_env import MujocoEnvFactory
+from tianshou.highlevel.config import SamplingConfig
+from tianshou.highlevel.experiment import (
+    ExperimentConfig,
+    REDQExperimentBuilder,
+)
+from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault
+from tianshou.highlevel.params.policy_params import REDQParams
+from tianshou.utils import logging
+from tianshou.utils.logging import datetime_tag
+
+
+def main(
+    experiment_config: ExperimentConfig,
+    task: str = "Ant-v3",
+    buffer_size: int = 1000000,
+    hidden_sizes: Sequence[int] = (256, 256),
+    ensemble_size: int = 10,
+    subset_size: int = 2,
+    actor_lr: float = 1e-3,
+    critic_lr: float = 1e-3,
+    gamma: float = 0.99,
+    tau: float = 0.005,
+    alpha: float = 0.2,
+    auto_alpha: bool = False,
+    alpha_lr: float = 3e-4,
+    start_timesteps: int = 10000,
+    epoch: int = 200,
+    step_per_epoch: int = 5000,
+    step_per_collect: int = 1,
+    update_per_step: int = 20,
+    n_step: int = 1,
+    batch_size: int = 256,
+    target_mode: Literal["mean", "min"] = "min",
+    training_num: int = 1,
+    test_num: int = 10,
+):
+    log_name = os.path.join(task, "redq", str(experiment_config.seed), datetime_tag())
+
+    sampling_config = SamplingConfig(
+        num_epochs=epoch,
+        step_per_epoch=step_per_epoch,
+        batch_size=batch_size,
+        num_train_envs=training_num,
+        num_test_envs=test_num,
+        buffer_size=buffer_size,
+        step_per_collect=step_per_collect,
+        update_per_step=update_per_step,
+        repeat_per_collect=None,
+        start_timesteps=start_timesteps,
+        start_timesteps_random=True,
+    )
+
+    env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
+
+    experiment = (
+        REDQExperimentBuilder(env_factory, experiment_config, sampling_config)
+        .with_redq_params(
+            REDQParams(
+                actor_lr=actor_lr,
+                critic_lr=critic_lr,
+                gamma=gamma,
+                tau=tau,
+                alpha=AutoAlphaFactoryDefault(lr=alpha_lr) if auto_alpha else alpha,
+                estimation_step=n_step,
+                target_mode=target_mode,
+                subset_size=subset_size,
+                ensemble_size=ensemble_size,
+            ),
+        )
+        .with_actor_factory_default(hidden_sizes)
+        .with_critic_ensemble_factory_default(hidden_sizes)
+        .build()
+    )
+    experiment.run(log_name)
+
+
+if __name__ == "__main__":
+    logging.run_main(lambda: CLI(main))

View File

@@ -14,7 +14,7 @@ from tianshou.highlevel.module.actor import (
     ActorFactory,
 )
 from tianshou.highlevel.module.core import TDevice
-from tianshou.highlevel.module.critic import CriticFactory
+from tianshou.highlevel.module.critic import CriticEnsembleFactory, CriticFactory
 from tianshou.highlevel.module.module_opt import (
     ActorCriticModuleOpt,
 )
@@ -28,6 +28,7 @@ from tianshou.highlevel.params.policy_params import (
     ParamTransformerData,
     PGParams,
     PPOParams,
+    REDQParams,
     SACParams,
     TD3Params,
     TRPOParams,
@@ -42,6 +43,7 @@ from tianshou.policy import (
     NPGPolicy,
     PGPolicy,
     PPOPolicy,
+    REDQPolicy,
     SACPolicy,
     TD3Policy,
     TRPOPolicy,
@@ -565,6 +567,58 @@ class DDPGAgentFactory(OffpolicyAgentFactory):
         )
 
 
+class REDQAgentFactory(OffpolicyAgentFactory):
+    def __init__(
+        self,
+        params: REDQParams,
+        sampling_config: SamplingConfig,
+        actor_factory: ActorFactory,
+        critic_ensemble_factory: CriticEnsembleFactory,
+        optim_factory: OptimizerFactory,
+    ):
+        super().__init__(sampling_config, optim_factory)
+        self.critic_ensemble_factory = critic_ensemble_factory
+        self.actor_factory = actor_factory
+        self.params = params
+        self.optim_factory = optim_factory
+
+    def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
+        envs.get_type().assert_continuous(self)
+        actor = self.actor_factory.create_module_opt(
+            envs,
+            device,
+            self.optim_factory,
+            self.params.actor_lr,
+        )
+        critic_ensemble = self.critic_ensemble_factory.create_module_opt(
+            envs,
+            device,
+            self.params.ensemble_size,
+            True,
+            self.optim_factory,
+            self.params.critic_lr,
+        )
+        kwargs = self.params.create_kwargs(
+            ParamTransformerData(
+                envs=envs,
+                device=device,
+                optim_factory=self.optim_factory,
+                actor=actor,
+                critic1=critic_ensemble,
+            ),
+        )
+        action_space = cast(gymnasium.spaces.Box, envs.get_action_space())
+        return REDQPolicy(
+            actor=actor.module,
+            actor_optim=actor.optim,
+            critic=critic_ensemble.module,
+            critic_optim=critic_ensemble.optim,
+            action_space=action_space,
+            observation_space=envs.get_observation_space(),
+            **kwargs,
+        )
+
+
 class SACAgentFactory(OffpolicyAgentFactory):
     def __init__(
         self,

View File

@@ -17,6 +17,7 @@ from tianshou.highlevel.agent import (
     NPGAgentFactory,
     PGAgentFactory,
     PPOAgentFactory,
+    REDQAgentFactory,
     SACAgentFactory,
     TD3AgentFactory,
     TRPOAgentFactory,
@@ -29,7 +30,12 @@ from tianshou.highlevel.module.actor import (
     ActorFactoryDefault,
     ContinuousActorType,
 )
-from tianshou.highlevel.module.critic import CriticFactory, CriticFactoryDefault
+from tianshou.highlevel.module.critic import (
+    CriticEnsembleFactory,
+    CriticEnsembleFactoryDefault,
+    CriticFactory,
+    CriticFactoryDefault,
+)
 from tianshou.highlevel.optim import OptimizerFactory, OptimizerFactoryAdam
 from tianshou.highlevel.params.policy_params import (
     A2CParams,
@@ -38,6 +44,7 @@ from tianshou.highlevel.params.policy_params import (
     NPGParams,
     PGParams,
     PPOParams,
+    REDQParams,
     SACParams,
     TD3Params,
     TRPOParams,
@@ -404,6 +411,28 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
         return self
 
 
+class _BuilderMixinCriticEnsembleFactory:
+    def __init__(self) -> None:
+        self.critic_ensemble_factory: CriticEnsembleFactory | None = None
+
+    def with_critic_ensemble_factory(self, factory: CriticEnsembleFactory) -> Self:
+        self.critic_ensemble_factory = factory
+        return self
+
+    def with_critic_ensemble_factory_default(
+        self,
+        hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
+    ) -> Self:
+        self.critic_ensemble_factory = CriticEnsembleFactoryDefault(hidden_sizes)
+        return self
+
+    def _get_critic_ensemble_factory(self) -> CriticEnsembleFactory:
+        if self.critic_ensemble_factory is None:
+            return CriticEnsembleFactoryDefault()
+        else:
+            return self.critic_ensemble_factory
+
+
 class PGExperimentBuilder(
     ExperimentBuilder,
     _BuilderMixinActorFactory_ContinuousGaussian,
@@ -621,6 +650,37 @@ class DDPGExperimentBuilder(
         )
 
 
+class REDQExperimentBuilder(
+    ExperimentBuilder,
+    _BuilderMixinActorFactory_ContinuousGaussian,
+    _BuilderMixinCriticEnsembleFactory,
+):
+    def __init__(
+        self,
+        env_factory: EnvFactory,
+        experiment_config: ExperimentConfig | None = None,
+        sampling_config: SamplingConfig | None = None,
+    ):
+        super().__init__(env_factory, experiment_config, sampling_config)
+        _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
+        _BuilderMixinCriticEnsembleFactory.__init__(self)
+        self._params: REDQParams = REDQParams()
+
+    def with_redq_params(self, params: REDQParams) -> Self:
+        self._params = params
+        return self
+
+    def _create_agent_factory(self) -> AgentFactory:
+        return REDQAgentFactory(
+            self._params,
+            self._sampling_config,
+            self._get_actor_factory(),
+            self._get_critic_ensemble_factory(),
+            self._get_optim_factory(),
+        )
+
+
 class SACExperimentBuilder(
     ExperimentBuilder,
     _BuilderMixinActorFactory_ContinuousGaussian,
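Besides the default, a hand-rolled ensemble factory can be plugged in through the mixin. A sketch reusing env_factory and the configs from the example above, with the continuous-net factory from the critic module below (the wider hidden sizes are an arbitrary illustrative choice, and the default actor factory is assumed to kick in when none is set explicitly):

    from tianshou.highlevel.module.critic import CriticEnsembleFactoryContinuousNet

    experiment = (
        REDQExperimentBuilder(env_factory, experiment_config, sampling_config)
        .with_critic_ensemble_factory(CriticEnsembleFactoryContinuousNet((512, 512)))
        .build()
    )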

View File

@@ -8,7 +8,7 @@ from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
 from tianshou.highlevel.module.module_opt import ModuleOpt
 from tianshou.highlevel.optim import OptimizerFactory
 from tianshou.utils.net import continuous, discrete
-from tianshou.utils.net.common import Net
+from tianshou.utils.net.common import EnsembleLinear, Net
 from tianshou.utils.string import ToStringMixin
@@ -39,21 +39,16 @@ class CriticFactoryDefault(CriticFactory):
         self.hidden_sizes = hidden_sizes
 
     def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
+        factory: CriticFactory
         env_type = envs.get_type()
-        if env_type == EnvType.CONTINUOUS:
-            return CriticFactoryContinuousNet(self.hidden_sizes).create_module(
-                envs,
-                device,
-                use_action,
-            )
-        elif env_type == EnvType.DISCRETE:
-            return CriticFactoryDiscreteNet(self.hidden_sizes).create_module(
-                envs,
-                device,
-                use_action,
-            )
-        else:
-            raise ValueError(f"{env_type} not supported")
+        match env_type:
+            case EnvType.CONTINUOUS:
+                factory = CriticFactoryContinuousNet(self.hidden_sizes)
+            case EnvType.DISCRETE:
+                factory = CriticFactoryDiscreteNet(self.hidden_sizes)
+            case _:
+                raise ValueError(f"{env_type} not supported")
+        return factory.create_module(envs, device, use_action)
 
 
 class CriticFactoryContinuousNet(CriticFactory):
@@ -92,3 +87,94 @@ class CriticFactoryDiscreteNet(CriticFactory):
         critic = discrete.Critic(net_c, device=device).to(device)
         init_linear_orthogonal(critic)
         return critic
+
+
+class CriticEnsembleFactory:
+    @abstractmethod
+    def create_module(
+        self,
+        envs: Environments,
+        device: TDevice,
+        ensemble_size: int,
+        use_action: bool,
+    ) -> nn.Module:
+        pass
+
+    def create_module_opt(
+        self,
+        envs: Environments,
+        device: TDevice,
+        ensemble_size: int,
+        use_action: bool,
+        optim_factory: OptimizerFactory,
+        lr: float,
+    ) -> ModuleOpt:
+        module = self.create_module(envs, device, ensemble_size, use_action)
+        opt = optim_factory.create_optimizer(module, lr)
+        return ModuleOpt(module, opt)
+
+
+class CriticEnsembleFactoryDefault(CriticEnsembleFactory):
+    """A critic ensemble factory which, depending on the type of environment, creates a suitable MLP-based critic."""
+
+    DEFAULT_HIDDEN_SIZES = (64, 64)
+
+    def __init__(self, hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES):
+        self.hidden_sizes = hidden_sizes
+
+    def create_module(
+        self,
+        envs: Environments,
+        device: TDevice,
+        ensemble_size: int,
+        use_action: bool,
+    ) -> nn.Module:
+        env_type = envs.get_type()
+        factory: CriticEnsembleFactory
+        match env_type:
+            case EnvType.CONTINUOUS:
+                factory = CriticEnsembleFactoryContinuousNet(self.hidden_sizes)
+            case EnvType.DISCRETE:
+                raise NotImplementedError("No default is implemented for the discrete case")
+            case _:
+                raise ValueError(f"{env_type} not supported")
+        return factory.create_module(
+            envs,
+            device,
+            ensemble_size,
+            use_action,
+        )
+
+
+class CriticEnsembleFactoryContinuousNet(CriticEnsembleFactory):
+    def __init__(self, hidden_sizes: Sequence[int]):
+        self.hidden_sizes = hidden_sizes
+
+    def create_module(
+        self,
+        envs: Environments,
+        device: TDevice,
+        ensemble_size: int,
+        use_action: bool,
+    ) -> nn.Module:
+        def linear_layer(x: int, y: int) -> EnsembleLinear:
+            return EnsembleLinear(ensemble_size, x, y)
+
+        action_shape = envs.get_action_shape() if use_action else 0
+        net_c = Net(
+            envs.get_observation_shape(),
+            action_shape=action_shape,
+            hidden_sizes=self.hidden_sizes,
+            concat=use_action,
+            activation=nn.Tanh,
+            device=device,
+            linear_layer=linear_layer,
+        )
+        critic = continuous.Critic(
+            net_c,
+            device=device,
+            linear_layer=linear_layer,
+            flatten_input=False,
+        ).to(device)
+        init_linear_orthogonal(critic)
+        return critic

View File

@@ -399,6 +399,22 @@ class DDPGParams(Params, ParamsMixinActorAndCritic):
         return transformers
 
 
+@dataclass
+class REDQParams(DDPGParams):
+    ensemble_size: int = 10
+    subset_size: int = 2
+    alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] | AutoAlphaFactory = 0.2
+    estimation_step: int = 1
+    actor_delay: int = 20
+    deterministic_eval: bool = True
+    target_mode: Literal["mean", "min"] = "min"
+
+    def _get_param_transformers(self) -> list[ParamTransformer]:
+        transformers = super()._get_param_transformers()
+        transformers.append(ParamTransformerAutoAlpha("alpha"))
+        return transformers
+
+
 @dataclass
 class TD3Params(Params, ParamsMixinActorAndDualCritics):
     tau: float = 0.005

View File

@@ -12,6 +12,7 @@ from tianshou.data.types import RecurrentStateBatch
 ModuleType = type[nn.Module]
 ArgsType = tuple[Any, ...] | dict[Any, Any] | Sequence[tuple[Any, ...]] | Sequence[dict[Any, Any]]
 TActionShape: TypeAlias = Sequence[int] | int
+TLinearLayer: TypeAlias = Callable[[int, int], nn.Module]
 
 
 def miniblock(
@@ -77,7 +78,7 @@ class MLP(nn.Module):
         activation: ModuleType | Sequence[ModuleType] | None = nn.ReLU,
         act_args: ArgsType | None = None,
         device: str | int | torch.device | None = None,
-        linear_layer: type[nn.Linear] = nn.Linear,
+        linear_layer: TLinearLayer = nn.Linear,
         flatten_input: bool = True,
     ) -> None:
         super().__init__()
@@ -183,7 +184,8 @@ class Net(NetBase):
         pass a tuple of two dict (first for Q and second for V) stating
         self-defined arguments as stated in
         :class:`~tianshou.utils.net.common.MLP`. Default to None.
-    :param linear_layer: use this module as linear layer. Default to nn.Linear.
+    :param linear_layer: use this module constructor, which takes the input
+        and output dimension as input, as linear layer. Default to nn.Linear.
 
     .. seealso::
@@ -209,7 +211,7 @@ class Net(NetBase):
         concat: bool = False,
         num_atoms: int = 1,
         dueling_param: tuple[dict[str, Any], dict[str, Any]] | None = None,
-        linear_layer: type[nn.Linear] = nn.Linear,
+        linear_layer: TLinearLayer = nn.Linear,
     ) -> None:
         super().__init__()
         self.device = device

View File

@@ -6,7 +6,7 @@ import numpy as np
 import torch
 from torch import nn
 
-from tianshou.utils.net.common import MLP, BaseActor, TActionShape
+from tianshou.utils.net.common import MLP, BaseActor, TActionShape, TLinearLayer
 
 SIGMA_MIN = -20
 SIGMA_MAX = 2
@@ -108,7 +108,7 @@ class Critic(nn.Module):
         hidden_sizes: Sequence[int] = (),
         device: str | int | torch.device = "cpu",
         preprocess_net_output_dim: int | None = None,
-        linear_layer: type[nn.Linear] = nn.Linear,
+        linear_layer: TLinearLayer = nn.Linear,
         flatten_input: bool = True,
     ) -> None:
         super().__init__()
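To illustrate the annotation fix in the two files above: REDQ's ensemble critic passes a closure that builds EnsembleLinear modules, which is a Callable[[int, int], nn.Module] but not a subclass of nn.Linear, so the old type[nn.Linear] annotation rejected it. A minimal self-contained sketch (the observation shape and hidden sizes are arbitrary placeholders):

    from torch import nn
    from tianshou.utils.net.common import EnsembleLinear, Net

    ensemble_size = 10

    def linear_layer(x: int, y: int) -> nn.Module:
        # Satisfies TLinearLayer = Callable[[int, int], nn.Module];
        # EnsembleLinear is not an nn.Linear subclass, so the old
        # type[nn.Linear] annotation was too narrow for this closure.
        return EnsembleLinear(ensemble_size, x, y)

    # Accepted after the fix: Net only ever calls linear_layer(in_dim, out_dim).
    net = Net((17,), hidden_sizes=(256, 256), linear_layer=linear_layer)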