Dominik Jain 17ef4dd5eb Support REDQ in high-level API
* Implement example mujoco_redq_hl
* Add abstraction CriticEnsembleFactory with default implementations
  to suit REDQ
* Fix type annotation of linear_layer in Net, MLP, Critic
  (was incompatible with REDQ usage)
2023-10-18 20:44:17 +02:00

181 lines
5.8 KiB
Python

from abc import ABC, abstractmethod
from collections.abc import Sequence
from torch import nn
from tianshou.highlevel.env import Environments, EnvType
from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
from tianshou.highlevel.module.module_opt import ModuleOpt
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.utils.net import continuous, discrete
from tianshou.utils.net.common import EnsembleLinear, Net
from tianshou.utils.string import ToStringMixin
class CriticFactory(ToStringMixin, ABC):
    """Abstract base for factories that build a critic module and, on request, an optimizer for it."""

    @abstractmethod
    def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
        """Build the critic network.

        :param envs: the environments for which the critic is created
        :param device: the device to create the module on
        :param use_action: whether the critic is to receive the action as an input in addition
            to the observation
        :return: the critic module
        """

    def create_module_opt(
        self,
        envs: Environments,
        device: TDevice,
        use_action: bool,
        optim_factory: OptimizerFactory,
        lr: float,
    ) -> ModuleOpt:
        """Build the critic via :meth:`create_module` and pair it with an optimizer.

        :param envs: the environments for which the critic is created
        :param device: the device to create the module on
        :param use_action: forwarded to :meth:`create_module`
        :param optim_factory: the factory used to create the optimizer for the critic
        :param lr: the learning rate passed to ``optim_factory``
        :return: the critic bundled with its optimizer
        """
        critic = self.create_module(envs, device, use_action)
        return ModuleOpt(critic, optim_factory.create_optimizer(critic, lr))
class CriticFactoryDefault(CriticFactory):
    """Critic factory that delegates to an MLP-based critic factory matching the environment type.

    Continuous environments are handled by :class:`CriticFactoryContinuousNet`, discrete ones by
    :class:`CriticFactoryDiscreteNet`; any other environment type is rejected.
    """

    DEFAULT_HIDDEN_SIZES = (64, 64)

    def __init__(self, hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES):
        """:param hidden_sizes: the hidden layer sizes of the MLP built by the delegate factory"""
        self.hidden_sizes = hidden_sizes

    def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
        """Build the critic using the factory that suits ``envs.get_type()``.

        :raises ValueError: if the environment type is neither continuous nor discrete
        """
        env_type = envs.get_type()
        if env_type == EnvType.CONTINUOUS:
            delegate: CriticFactory = CriticFactoryContinuousNet(self.hidden_sizes)
        elif env_type == EnvType.DISCRETE:
            delegate = CriticFactoryDiscreteNet(self.hidden_sizes)
        else:
            raise ValueError(f"{env_type} not supported")
        return delegate.create_module(envs, device, use_action)
class CriticFactoryContinuousNet(CriticFactory):
    """Creates an MLP-based ``continuous.Critic`` with tanh activations.

    After construction, ``init_linear_orthogonal`` is applied to the critic.
    """

    def __init__(self, hidden_sizes: Sequence[int]):
        """:param hidden_sizes: the hidden layer sizes of the preprocessing ``Net``"""
        self.hidden_sizes = hidden_sizes

    def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
        """Build the critic for ``envs``.

        If ``use_action`` is set, the environment's action shape is passed to ``Net`` with
        ``concat=True``; otherwise ``action_shape=0`` and ``concat=False`` are used.
        """
        action_shape = 0
        if use_action:
            action_shape = envs.get_action_shape()
        preprocess = Net(
            envs.get_observation_shape(),
            action_shape=action_shape,
            hidden_sizes=self.hidden_sizes,
            concat=use_action,
            activation=nn.Tanh,
            device=device,
        )
        critic = continuous.Critic(preprocess, device=device)
        critic = critic.to(device)
        init_linear_orthogonal(critic)
        return critic
class CriticFactoryDiscreteNet(CriticFactory):
    """Creates an MLP-based ``discrete.Critic`` with tanh activations.

    After construction, ``init_linear_orthogonal`` is applied to the critic.
    """

    def __init__(self, hidden_sizes: Sequence[int]):
        """:param hidden_sizes: the hidden layer sizes of the preprocessing ``Net``"""
        self.hidden_sizes = hidden_sizes

    def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
        """Build the critic for ``envs``.

        If ``use_action`` is set, the environment's action shape is passed to ``Net`` with
        ``concat=True``; otherwise ``action_shape=0`` and ``concat=False`` are used.
        """
        action_shape = 0
        if use_action:
            action_shape = envs.get_action_shape()
        preprocess = Net(
            envs.get_observation_shape(),
            action_shape=action_shape,
            hidden_sizes=self.hidden_sizes,
            concat=use_action,
            activation=nn.Tanh,
            device=device,
        )
        critic = discrete.Critic(preprocess, device=device)
        critic = critic.to(device)
        init_linear_orthogonal(critic)
        return critic
# Derives from ABC (like CriticFactory) so that @abstractmethod is actually enforced:
# without an ABCMeta metaclass the decorator has no effect and the base class could be
# instantiated, with create_module silently returning None.
class CriticEnsembleFactory(ToStringMixin, ABC):
    """Abstract base for factories that build an ensemble critic module (e.g. for REDQ)
    and, on request, an optimizer for it.
    """

    @abstractmethod
    def create_module(
        self,
        envs: Environments,
        device: TDevice,
        ensemble_size: int,
        use_action: bool,
    ) -> nn.Module:
        """Build the ensemble critic.

        :param envs: the environments for which the critic is created
        :param device: the device to create the module on
        :param ensemble_size: the number of ensemble members
        :param use_action: whether the critic is to receive the action as an input in addition
            to the observation
        :return: the ensemble critic module
        """

    def create_module_opt(
        self,
        envs: Environments,
        device: TDevice,
        ensemble_size: int,
        use_action: bool,
        optim_factory: OptimizerFactory,
        lr: float,
    ) -> ModuleOpt:
        """Build the ensemble critic via :meth:`create_module` and pair it with an optimizer.

        :param envs: the environments for which the critic is created
        :param device: the device to create the module on
        :param ensemble_size: the number of ensemble members
        :param use_action: forwarded to :meth:`create_module`
        :param optim_factory: the factory used to create the optimizer for the critic
        :param lr: the learning rate passed to ``optim_factory``
        :return: the ensemble critic bundled with its optimizer
        """
        module = self.create_module(envs, device, ensemble_size, use_action)
        opt = optim_factory.create_optimizer(module, lr)
        return ModuleOpt(module, opt)
class CriticEnsembleFactoryDefault(CriticEnsembleFactory):
    """Critic ensemble factory that delegates to an MLP-based implementation matching the environment type.

    Only continuous environments are supported (via :class:`CriticEnsembleFactoryContinuousNet`);
    discrete environments raise ``NotImplementedError`` and all other types raise ``ValueError``.
    """

    DEFAULT_HIDDEN_SIZES = (64, 64)

    def __init__(self, hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES):
        """:param hidden_sizes: the hidden layer sizes of the MLP built by the delegate factory"""
        self.hidden_sizes = hidden_sizes

    def create_module(
        self,
        envs: Environments,
        device: TDevice,
        ensemble_size: int,
        use_action: bool,
    ) -> nn.Module:
        """Build the ensemble critic using the factory that suits ``envs.get_type()``.

        :raises NotImplementedError: for discrete environments
        :raises ValueError: for environment types that are neither continuous nor discrete
        """
        env_type = envs.get_type()
        if env_type == EnvType.CONTINUOUS:
            delegate: CriticEnsembleFactory = CriticEnsembleFactoryContinuousNet(self.hidden_sizes)
        elif env_type == EnvType.DISCRETE:
            raise NotImplementedError("No default is implemented for the discrete case")
        else:
            raise ValueError(f"{env_type} not supported")
        return delegate.create_module(envs, device, ensemble_size, use_action)
class CriticEnsembleFactoryContinuousNet(CriticEnsembleFactory):
    """Creates an ensemble ``continuous.Critic`` built from ``EnsembleLinear`` layers.

    Both the preprocessing ``Net`` and the ``continuous.Critic`` receive a ``linear_layer``
    factory producing ``EnsembleLinear(ensemble_size, x, y)``; tanh activations are used, and
    ``init_linear_orthogonal`` is applied to the resulting critic.
    """

    def __init__(self, hidden_sizes: Sequence[int]):
        """:param hidden_sizes: the hidden layer sizes of the preprocessing ``Net``"""
        self.hidden_sizes = hidden_sizes

    def create_module(
        self,
        envs: Environments,
        device: TDevice,
        ensemble_size: int,
        use_action: bool,
    ) -> nn.Module:
        """Build the ensemble critic for ``envs`` with ``ensemble_size`` members.

        If ``use_action`` is set, the environment's action shape is passed to ``Net`` with
        ``concat=True``; otherwise ``action_shape=0`` and ``concat=False`` are used.
        """

        # Parameter names x/y are kept as-is since the network builders call this factory;
        # presumably they are the input/output feature sizes -- TODO confirm against Net/MLP.
        def ensemble_linear(x: int, y: int) -> EnsembleLinear:
            return EnsembleLinear(ensemble_size, x, y)

        action_shape = 0
        if use_action:
            action_shape = envs.get_action_shape()
        preprocess = Net(
            envs.get_observation_shape(),
            action_shape=action_shape,
            hidden_sizes=self.hidden_sizes,
            concat=use_action,
            activation=nn.Tanh,
            device=device,
            linear_layer=ensemble_linear,
        )
        # NOTE(review): flatten_input=False presumably keeps the extra ensemble dimension of
        # the Net output intact -- confirm against continuous.Critic.
        critic = continuous.Critic(
            preprocess,
            device=device,
            linear_layer=ensemble_linear,
            flatten_input=False,
        )
        critic = critic.to(device)
        init_linear_orthogonal(critic)
        return critic