Support REDQ in high-level API

* Implement example mujoco_redq_hl
* Add abstraction CriticEnsembleFactory with default implementations
  to suit REDQ
* Fix type annotation of linear_layer in Net, MLP, Critic
  (was incompatible with REDQ usage; see the sketch below)
Dominik Jain 2023-10-10 15:49:05 +02:00
parent 7af836bd6a
commit 17ef4dd5eb
8 changed files with 328 additions and 23 deletions
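
Regarding the last bullet of the commit message: the old annotation linear_layer: type[nn.Linear] only admits nn.Linear subclasses, whereas the REDQ ensemble critic needs to pass a closure that also binds the ensemble size. Below is a minimal sketch of the incompatibility this commit resolves; EnsembleLinear, Net and the new TLinearLayer alias all appear in the diffs further down, while the shapes are made up purely for illustration.

from torch import nn

from tianshou.utils.net.common import EnsembleLinear, Net

ensemble_size = 10  # the REDQ default used in the new example below


def linear_layer(x: int, y: int) -> EnsembleLinear:
    # A closure over ensemble_size: this is a plain callable, not a subclass of nn.Linear,
    # which is why the annotation becomes TLinearLayer = Callable[[int, int], nn.Module].
    return EnsembleLinear(ensemble_size, x, y)


# Illustrative shapes only; in the diffs below, Net is built from the environment's spaces.
net = Net(
    state_shape=(17,),
    action_shape=(6,),
    hidden_sizes=(256, 256),
    concat=True,
    activation=nn.Tanh,
    linear_layer=linear_layer,
)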

View File

@@ -38,7 +38,7 @@ def main(
    test_num: int = 10,
):
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
    log_name = os.path.join(task, "ddpg", str(experiment_config.seed), now)

    sampling_config = SamplingConfig(
        num_epochs=epoch,

View File

@@ -0,0 +1,87 @@
#!/usr/bin/env python3

import os
from collections.abc import Sequence
from typing import Literal

from jsonargparse import CLI

from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import (
    ExperimentConfig,
    REDQExperimentBuilder,
)
from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault
from tianshou.highlevel.params.policy_params import REDQParams
from tianshou.utils import logging
from tianshou.utils.logging import datetime_tag


def main(
    experiment_config: ExperimentConfig,
    task: str = "Ant-v3",
    buffer_size: int = 1000000,
    hidden_sizes: Sequence[int] = (256, 256),
    ensemble_size: int = 10,
    subset_size: int = 2,
    actor_lr: float = 1e-3,
    critic_lr: float = 1e-3,
    gamma: float = 0.99,
    tau: float = 0.005,
    alpha: float = 0.2,
    auto_alpha: bool = False,
    alpha_lr: float = 3e-4,
    start_timesteps: int = 10000,
    epoch: int = 200,
    step_per_epoch: int = 5000,
    step_per_collect: int = 1,
    update_per_step: int = 20,
    n_step: int = 1,
    batch_size: int = 256,
    target_mode: Literal["mean", "min"] = "min",
    training_num: int = 1,
    test_num: int = 10,
):
    log_name = os.path.join(task, "redq", str(experiment_config.seed), datetime_tag())

    sampling_config = SamplingConfig(
        num_epochs=epoch,
        step_per_epoch=step_per_epoch,
        batch_size=batch_size,
        num_train_envs=training_num,
        num_test_envs=test_num,
        buffer_size=buffer_size,
        step_per_collect=step_per_collect,
        update_per_step=update_per_step,
        repeat_per_collect=None,
        start_timesteps=start_timesteps,
        start_timesteps_random=True,
    )

    env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)

    experiment = (
        REDQExperimentBuilder(env_factory, experiment_config, sampling_config)
        .with_redq_params(
            REDQParams(
                actor_lr=actor_lr,
                critic_lr=critic_lr,
                gamma=gamma,
                tau=tau,
                alpha=AutoAlphaFactoryDefault(lr=alpha_lr) if auto_alpha else alpha,
                estimation_step=n_step,
                target_mode=target_mode,
                subset_size=subset_size,
                ensemble_size=ensemble_size,
            ),
        )
        .with_actor_factory_default(hidden_sizes)
        .with_critic_ensemble_factory_default(hidden_sizes)
        .build()
    )
    experiment.run(log_name)


if __name__ == "__main__":
    logging.run_main(lambda: CLI(main))
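
A hypothetical way to smoke-test this example without going through the jsonargparse CLI, assuming the file lives at examples/mujoco/mujoco_redq_hl.py (as its mujoco_env import suggests) and that ExperimentConfig accepts seed as a keyword, as its use above implies:

from examples.mujoco.mujoco_redq_hl import main
from tianshou.highlevel.experiment import ExperimentConfig

# Short run with made-up small budgets; all other arguments keep the defaults from the example.
main(
    ExperimentConfig(seed=0),
    epoch=1,
    step_per_epoch=1000,
    start_timesteps=500,
)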

View File

@@ -14,7 +14,7 @@ from tianshou.highlevel.module.actor import (
    ActorFactory,
)
from tianshou.highlevel.module.core import TDevice
from tianshou.highlevel.module.critic import CriticFactory
from tianshou.highlevel.module.critic import CriticEnsembleFactory, CriticFactory
from tianshou.highlevel.module.module_opt import (
    ActorCriticModuleOpt,
)
@@ -28,6 +28,7 @@ from tianshou.highlevel.params.policy_params import (
    ParamTransformerData,
    PGParams,
    PPOParams,
    REDQParams,
    SACParams,
    TD3Params,
    TRPOParams,
@@ -42,6 +43,7 @@ from tianshou.policy import (
    NPGPolicy,
    PGPolicy,
    PPOPolicy,
    REDQPolicy,
    SACPolicy,
    TD3Policy,
    TRPOPolicy,
@@ -565,6 +567,58 @@ class DDPGAgentFactory(OffpolicyAgentFactory):
        )


class REDQAgentFactory(OffpolicyAgentFactory):
    def __init__(
        self,
        params: REDQParams,
        sampling_config: SamplingConfig,
        actor_factory: ActorFactory,
        critic_ensemble_factory: CriticEnsembleFactory,
        optim_factory: OptimizerFactory,
    ):
        super().__init__(sampling_config, optim_factory)
        self.critic_ensemble_factory = critic_ensemble_factory
        self.actor_factory = actor_factory
        self.params = params
        self.optim_factory = optim_factory

    def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
        envs.get_type().assert_continuous(self)
        actor = self.actor_factory.create_module_opt(
            envs,
            device,
            self.optim_factory,
            self.params.actor_lr,
        )
        critic_ensemble = self.critic_ensemble_factory.create_module_opt(
            envs,
            device,
            self.params.ensemble_size,
            True,
            self.optim_factory,
            self.params.critic_lr,
        )
        kwargs = self.params.create_kwargs(
            ParamTransformerData(
                envs=envs,
                device=device,
                optim_factory=self.optim_factory,
                actor=actor,
                critic1=critic_ensemble,
            ),
        )
        action_space = cast(gymnasium.spaces.Box, envs.get_action_space())
        return REDQPolicy(
            actor=actor.module,
            actor_optim=actor.optim,
            critic=critic_ensemble.module,
            critic_optim=critic_ensemble.optim,
            action_space=action_space,
            observation_space=envs.get_observation_space(),
            **kwargs,
        )


class SACAgentFactory(OffpolicyAgentFactory):
    def __init__(
        self,

View File

@@ -17,6 +17,7 @@ from tianshou.highlevel.agent import (
    NPGAgentFactory,
    PGAgentFactory,
    PPOAgentFactory,
    REDQAgentFactory,
    SACAgentFactory,
    TD3AgentFactory,
    TRPOAgentFactory,
@@ -29,7 +30,12 @@ from tianshou.highlevel.module.actor import (
    ActorFactoryDefault,
    ContinuousActorType,
)
from tianshou.highlevel.module.critic import CriticFactory, CriticFactoryDefault
from tianshou.highlevel.module.critic import (
    CriticEnsembleFactory,
    CriticEnsembleFactoryDefault,
    CriticFactory,
    CriticFactoryDefault,
)
from tianshou.highlevel.optim import OptimizerFactory, OptimizerFactoryAdam
from tianshou.highlevel.params.policy_params import (
    A2CParams,
@@ -38,6 +44,7 @@ from tianshou.highlevel.params.policy_params import (
    NPGParams,
    PGParams,
    PPOParams,
    REDQParams,
    SACParams,
    TD3Params,
    TRPOParams,
@@ -404,6 +411,28 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
        return self


class _BuilderMixinCriticEnsembleFactory:
    def __init__(self) -> None:
        self.critic_ensemble_factory: CriticEnsembleFactory | None = None

    def with_critic_ensemble_factory(self, factory: CriticEnsembleFactory) -> Self:
        self.critic_ensemble_factory = factory
        return self

    def with_critic_ensemble_factory_default(
        self,
        hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
    ) -> Self:
        self.critic_ensemble_factory = CriticEnsembleFactoryDefault(hidden_sizes)
        return self

    def _get_critic_ensemble_factory(self):
        if self.critic_ensemble_factory is None:
            return CriticEnsembleFactoryDefault()
        else:
            return self.critic_ensemble_factory


class PGExperimentBuilder(
    ExperimentBuilder,
    _BuilderMixinActorFactory_ContinuousGaussian,
@@ -621,6 +650,37 @@ class DDPGExperimentBuilder(
        )


class REDQExperimentBuilder(
    ExperimentBuilder,
    _BuilderMixinActorFactory_ContinuousGaussian,
    _BuilderMixinCriticEnsembleFactory,
):
    def __init__(
        self,
        env_factory: EnvFactory,
        experiment_config: ExperimentConfig | None = None,
        sampling_config: SamplingConfig | None = None,
    ):
        super().__init__(env_factory, experiment_config, sampling_config)
        _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
        _BuilderMixinCriticEnsembleFactory.__init__(self)
        self._params: REDQParams = REDQParams()

    def with_redq_params(self, params: REDQParams) -> Self:
        self._params = params
        return self

    @abstractmethod
    def _create_agent_factory(self) -> AgentFactory:
        return REDQAgentFactory(
            self._params,
            self._sampling_config,
            self._get_actor_factory(),
            self._get_critic_ensemble_factory(),
            self._get_optim_factory(),
        )


class SACExperimentBuilder(
    ExperimentBuilder,
    _BuilderMixinActorFactory_ContinuousGaussian,

View File

@@ -8,7 +8,7 @@ from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
from tianshou.highlevel.module.module_opt import ModuleOpt
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.utils.net import continuous, discrete
from tianshou.utils.net.common import Net
from tianshou.utils.net.common import EnsembleLinear, Net
from tianshou.utils.string import ToStringMixin
@@ -39,21 +39,16 @@ class CriticFactoryDefault(CriticFactory):
        self.hidden_sizes = hidden_sizes

    def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
        factory: CriticFactory
        env_type = envs.get_type()
        if env_type == EnvType.CONTINUOUS:
            return CriticFactoryContinuousNet(self.hidden_sizes).create_module(
                envs,
                device,
                use_action,
            )
        elif env_type == EnvType.DISCRETE:
            return CriticFactoryDiscreteNet(self.hidden_sizes).create_module(
                envs,
                device,
                use_action,
            )
        else:
            raise ValueError(f"{env_type} not supported")
        match env_type:
            case EnvType.CONTINUOUS:
                factory = CriticFactoryContinuousNet(self.hidden_sizes)
            case EnvType.DISCRETE:
                factory = CriticFactoryDiscreteNet(self.hidden_sizes)
            case _:
                raise ValueError(f"{env_type} not supported")
        return factory.create_module(envs, device, use_action)


class CriticFactoryContinuousNet(CriticFactory):
@@ -92,3 +87,94 @@ class CriticFactoryDiscreteNet(CriticFactory):
        critic = discrete.Critic(net_c, device=device).to(device)
        init_linear_orthogonal(critic)
        return critic


class CriticEnsembleFactory:
    @abstractmethod
    def create_module(
        self,
        envs: Environments,
        device: TDevice,
        ensemble_size: int,
        use_action: bool,
    ) -> nn.Module:
        pass

    def create_module_opt(
        self,
        envs: Environments,
        device: TDevice,
        ensemble_size: int,
        use_action: bool,
        optim_factory: OptimizerFactory,
        lr: float,
    ) -> ModuleOpt:
        module = self.create_module(envs, device, ensemble_size, use_action)
        opt = optim_factory.create_optimizer(module, lr)
        return ModuleOpt(module, opt)


class CriticEnsembleFactoryDefault(CriticEnsembleFactory):
    """A critic ensemble factory which, depending on the type of environment, creates a suitable MLP-based critic."""

    DEFAULT_HIDDEN_SIZES = (64, 64)

    def __init__(self, hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES):
        self.hidden_sizes = hidden_sizes

    def create_module(
        self,
        envs: Environments,
        device: TDevice,
        ensemble_size: int,
        use_action: bool,
    ) -> nn.Module:
        env_type = envs.get_type()
        factory: CriticEnsembleFactory
        match env_type:
            case EnvType.CONTINUOUS:
                factory = CriticEnsembleFactoryContinuousNet(self.hidden_sizes)
            case EnvType.DISCRETE:
                raise NotImplementedError("No default is implemented for the discrete case")
            case _:
                raise ValueError(f"{env_type} not supported")
        return factory.create_module(
            envs,
            device,
            ensemble_size,
            use_action,
        )


class CriticEnsembleFactoryContinuousNet(CriticEnsembleFactory):
    def __init__(self, hidden_sizes: Sequence[int]):
        self.hidden_sizes = hidden_sizes

    def create_module(
        self,
        envs: Environments,
        device: TDevice,
        ensemble_size: int,
        use_action: bool,
    ) -> nn.Module:
        def linear_layer(x: int, y: int) -> EnsembleLinear:
            return EnsembleLinear(ensemble_size, x, y)

        action_shape = envs.get_action_shape() if use_action else 0
        net_c = Net(
            envs.get_observation_shape(),
            action_shape=action_shape,
            hidden_sizes=self.hidden_sizes,
            concat=use_action,
            activation=nn.Tanh,
            device=device,
            linear_layer=linear_layer,
        )
        critic = continuous.Critic(
            net_c,
            device=device,
            linear_layer=linear_layer,
            flatten_input=False,
        ).to(device)
        init_linear_orthogonal(critic)
        return critic
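
The CriticEnsembleFactory abstraction above is the extension point referred to in the second bullet of the commit message. A hypothetical sketch of a user-defined factory follows: a ReLU variant of CriticEnsembleFactoryContinuousNet, where the class name is made up and the Environments import path is an assumption, together with how it would be plugged into the new builder.

from torch import nn

from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
from tianshou.highlevel.module.critic import CriticEnsembleFactory
from tianshou.utils.net import continuous
from tianshou.utils.net.common import EnsembleLinear, Net


class CriticEnsembleFactoryReluNet(CriticEnsembleFactory):
    """Hypothetical variant of CriticEnsembleFactoryContinuousNet that uses ReLU activations."""

    def __init__(self, hidden_sizes=(256, 256)):
        self.hidden_sizes = hidden_sizes

    def create_module(
        self,
        envs: Environments,
        device: TDevice,
        ensemble_size: int,
        use_action: bool,
    ) -> nn.Module:
        def linear_layer(x: int, y: int) -> EnsembleLinear:
            return EnsembleLinear(ensemble_size, x, y)

        net_c = Net(
            envs.get_observation_shape(),
            action_shape=envs.get_action_shape() if use_action else 0,
            hidden_sizes=self.hidden_sizes,
            concat=use_action,
            activation=nn.ReLU,  # the only deviation from the Tanh default above
            device=device,
            linear_layer=linear_layer,
        )
        critic = continuous.Critic(
            net_c,
            device=device,
            linear_layer=linear_layer,
            flatten_input=False,
        ).to(device)
        init_linear_orthogonal(critic)
        return critic


# It would replace the default via the builder method added in this commit, e.g.:
#     REDQExperimentBuilder(...).with_critic_ensemble_factory(CriticEnsembleFactoryReluNet())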

View File

@@ -399,6 +399,22 @@ class DDPGParams(Params, ParamsMixinActorAndCritic):
        return transformers


@dataclass
class REDQParams(DDPGParams):
    ensemble_size: int = 10
    subset_size: int = 2
    alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] | AutoAlphaFactory = 0.2
    estimation_step: int = 1
    actor_delay: int = 20
    deterministic_eval: bool = True
    target_mode: Literal["mean", "min"] = "min"

    def _get_param_transformers(self) -> list[ParamTransformer]:
        transformers = super()._get_param_transformers()
        transformers.append(ParamTransformerAutoAlpha("alpha"))
        return transformers


@dataclass
class TD3Params(Params, ParamsMixinActorAndDualCritics):
    tau: float = 0.005

View File

@@ -12,6 +12,7 @@ from tianshou.data.types import RecurrentStateBatch
ModuleType = type[nn.Module]
ArgsType = tuple[Any, ...] | dict[Any, Any] | Sequence[tuple[Any, ...]] | Sequence[dict[Any, Any]]
TActionShape: TypeAlias = Sequence[int] | int
TLinearLayer: TypeAlias = Callable[[int, int], nn.Module]


def miniblock(
@@ -77,7 +78,7 @@ class MLP(nn.Module):
        activation: ModuleType | Sequence[ModuleType] | None = nn.ReLU,
        act_args: ArgsType | None = None,
        device: str | int | torch.device | None = None,
        linear_layer: type[nn.Linear] = nn.Linear,
        linear_layer: TLinearLayer = nn.Linear,
        flatten_input: bool = True,
    ) -> None:
        super().__init__()
@@ -183,7 +184,8 @@ class Net(NetBase):
        pass a tuple of two dict (first for Q and second for V) stating
        self-defined arguments as stated in
        class:`~tianshou.utils.net.common.MLP`. Default to None.
    :param linear_layer: use this module as linear layer. Default to nn.Linear.
    :param linear_layer: use this module constructor, which takes the input
        and output dimension as input, as linear layer. Default to nn.Linear.

    .. seealso::
@@ -209,7 +211,7 @@ class Net(NetBase):
        concat: bool = False,
        num_atoms: int = 1,
        dueling_param: tuple[dict[str, Any], dict[str, Any]] | None = None,
        linear_layer: type[nn.Linear] = nn.Linear,
        linear_layer: TLinearLayer = nn.Linear,
    ) -> None:
        super().__init__()
        self.device = device

View File

@@ -6,7 +6,7 @@ import numpy as np
import torch
from torch import nn

from tianshou.utils.net.common import MLP, BaseActor, TActionShape
from tianshou.utils.net.common import MLP, BaseActor, TActionShape, TLinearLayer

SIGMA_MIN = -20
SIGMA_MAX = 2
@@ -108,7 +108,7 @@ class Critic(nn.Module):
        hidden_sizes: Sequence[int] = (),
        device: str | int | torch.device = "cpu",
        preprocess_net_output_dim: int | None = None,
        linear_layer: type[nn.Linear] = nn.Linear,
        linear_layer: TLinearLayer = nn.Linear,
        flatten_input: bool = True,
    ) -> None:
        super().__init__()