import warnings
from collections.abc import Sequence
from typing import Any, cast

import numpy as np
import torch
from torch import nn

from tianshou.utils.net.common import MLP, BaseActor, TActionShape, TLinearLayer

SIGMA_MIN = -20
SIGMA_MAX = 2


class Actor(BaseActor):
    """Simple actor network.

    It creates an actor for a continuous action space with the structure
    preprocess_net ---> action_shape.

    :param preprocess_net: a self-defined preprocess_net which outputs a
        flattened hidden state.
    :param action_shape: a sequence of int for the shape of the action.
    :param hidden_sizes: a sequence of int for constructing the MLP after
        preprocess_net. Defaults to an empty sequence (in which case the MLP
        contains only a single linear layer).
    :param max_action: the scale for the final action logits. Defaults to 1.
    :param device: which device to create this model on. Defaults to "cpu".
    :param preprocess_net_output_dim: the output dimension of preprocess_net.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.
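
    A minimal usage sketch (the observation/action shapes and hidden sizes
    below are illustrative assumptions, not fixed by this class):

    .. code-block:: python

        import numpy as np
        from tianshou.utils.net.common import Net
        from tianshou.utils.net.continuous import Actor

        net = Net(state_shape=(3,), hidden_sizes=(64, 64))
        actor = Actor(net, action_shape=(2,), max_action=2.0)
        obs = np.random.randn(4, 3).astype(np.float32)
        act, hidden = actor(obs)  # act has shape (4, 2), values in [-2, 2]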

    .. seealso::

        Please refer to :class:`~tianshou.utils.net.common.Net` as an example
        of how preprocess_net is suggested to be defined.
    """

    def __init__(
        self,
        preprocess_net: nn.Module,
        action_shape: TActionShape,
        hidden_sizes: Sequence[int] = (),
        max_action: float = 1.0,
        device: str | int | torch.device = "cpu",
        preprocess_net_output_dim: int | None = None,
    ) -> None:
        super().__init__()
        self.device = device
        self.preprocess = preprocess_net
        self.output_dim = int(np.prod(action_shape))
        input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
        input_dim = cast(int, input_dim)
        self.last = MLP(
            input_dim,
            self.output_dim,
            hidden_sizes,
            device=self.device,
        )
        self.max_action = max_action

    def get_preprocess_net(self) -> nn.Module:
        return self.preprocess

    def forward(
        self,
        obs: np.ndarray | torch.Tensor,
        state: Any = None,
        info: dict[str, Any] | None = None,
    ) -> tuple[torch.Tensor, Any]:
        """Mapping: obs -> logits -> action."""
        if info is None:
            info = {}
        logits, hidden = self.preprocess(obs, state)
        logits = self.max_action * torch.tanh(self.last(logits))
        return logits, hidden


class Critic(nn.Module):
    """Simple critic network.

    It creates a critic for a continuous action space with the structure
    preprocess_net ---> 1 (Q value).

    :param preprocess_net: a self-defined preprocess_net which outputs a
        flattened hidden state.
    :param hidden_sizes: a sequence of int for constructing the MLP after
        preprocess_net. Defaults to an empty sequence (in which case the MLP
        contains only a single linear layer).
    :param device: which device to create this model on. Defaults to "cpu".
    :param preprocess_net_output_dim: the output dimension of preprocess_net.
    :param linear_layer: use this module as the linear layer. Defaults to nn.Linear.
    :param flatten_input: whether to flatten the input data for the last layer.
        Defaults to True.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.
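
    A minimal usage sketch (the shapes below are illustrative assumptions;
    note that the preprocess net receives the concatenated observation and
    action, hence concat=True):

    .. code-block:: python

        import numpy as np
        from tianshou.utils.net.common import Net
        from tianshou.utils.net.continuous import Critic

        # state_dim=3, action_dim=2 -> preprocess input dim is 3 + 2 = 5
        net = Net(state_shape=(3,), action_shape=(2,), hidden_sizes=(64, 64), concat=True)
        critic = Critic(net)
        obs = np.random.randn(4, 3).astype(np.float32)
        act = np.random.randn(4, 2).astype(np.float32)
        q = critic(obs, act)  # Q values with shape (4, 1)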

    .. seealso::

        Please refer to :class:`~tianshou.utils.net.common.Net` as an example
        of how preprocess_net is suggested to be defined.
    """

    def __init__(
        self,
        preprocess_net: nn.Module,
        hidden_sizes: Sequence[int] = (),
        device: str | int | torch.device = "cpu",
        preprocess_net_output_dim: int | None = None,
        linear_layer: TLinearLayer = nn.Linear,
        flatten_input: bool = True,
    ) -> None:
        super().__init__()
        self.device = device
        self.preprocess = preprocess_net
        self.output_dim = 1
        input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
        self.last = MLP(
            input_dim,  # type: ignore
            1,
            hidden_sizes,
            device=self.device,
            linear_layer=linear_layer,
            flatten_input=flatten_input,
        )

    def forward(
        self,
        obs: np.ndarray | torch.Tensor,
        act: np.ndarray | torch.Tensor | None = None,
        info: dict[str, Any] | None = None,
    ) -> torch.Tensor:
        """Mapping: (s, a) -> logits -> Q(s, a)."""
        if info is None:
            info = {}
        obs = torch.as_tensor(
            obs,
            device=self.device,
            dtype=torch.float32,
        ).flatten(1)
        if act is not None:
            act = torch.as_tensor(
                act,
                device=self.device,
                dtype=torch.float32,
            ).flatten(1)
            obs = torch.cat([obs, act], dim=1)
        logits, hidden = self.preprocess(obs)
        return self.last(logits)


class ActorProb(BaseActor):
    """Simple actor network that outputs the parameters of a Gaussian distribution.

    :param preprocess_net: a self-defined preprocess_net which outputs a
        flattened hidden state.
    :param action_shape: a sequence of int for the shape of the action.
    :param hidden_sizes: a sequence of int for constructing the MLP after
        preprocess_net. Defaults to an empty sequence (in which case the MLP
        contains only a single linear layer).
    :param max_action: the scale for the final action logits. Defaults to 1.
    :param device: which device to create this model on. Defaults to "cpu".
    :param unbounded: whether to leave the final logits unbounded instead of
        squashing them with a tanh activation. Defaults to False.
    :param conditioned_sigma: True when sigma is calculated from the input,
        False when sigma is an independent parameter. Defaults to False.
    :param preprocess_net_output_dim: the output dimension of preprocess_net.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.
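
    A minimal usage sketch (the shapes and options are illustrative
    assumptions; unbounded=True and conditioned_sigma=True mirror the usual
    SAC-style setup):

    .. code-block:: python

        import numpy as np
        from tianshou.utils.net.common import Net
        from tianshou.utils.net.continuous import ActorProb

        net = Net(state_shape=(3,), hidden_sizes=(64, 64))
        actor = ActorProb(net, action_shape=(2,), unbounded=True, conditioned_sigma=True)
        obs = np.random.randn(4, 3).astype(np.float32)
        (mu, sigma), state = actor(obs)  # mu and sigma each have shape (4, 2)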

    .. seealso::

        Please refer to :class:`~tianshou.utils.net.common.Net` as an example
        of how preprocess_net is suggested to be defined.
    """

    def __init__(
        self,
        preprocess_net: nn.Module,
        action_shape: TActionShape,
        hidden_sizes: Sequence[int] = (),
        max_action: float = 1.0,
        device: str | int | torch.device = "cpu",
        unbounded: bool = False,
        conditioned_sigma: bool = False,
        preprocess_net_output_dim: int | None = None,
    ) -> None:
        super().__init__()
        if unbounded and not np.isclose(max_action, 1.0):
            warnings.warn("Note that max_action input will be discarded when unbounded is True.")
            max_action = 1.0
        self.preprocess = preprocess_net
        self.device = device
        self.output_dim = int(np.prod(action_shape))
        input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
        self.mu = MLP(input_dim, self.output_dim, hidden_sizes, device=self.device)  # type: ignore
        self._c_sigma = conditioned_sigma
        if conditioned_sigma:
            self.sigma = MLP(
                input_dim,  # type: ignore
                self.output_dim,
                hidden_sizes,
                device=self.device,
            )
        else:
            self.sigma_param = nn.Parameter(torch.zeros(self.output_dim, 1))
        self.max_action = max_action
        self._unbounded = unbounded

    def get_preprocess_net(self) -> nn.Module:
        return self.preprocess

    def forward(
        self,
        obs: np.ndarray | torch.Tensor,
        state: Any = None,
        info: dict[str, Any] | None = None,
    ) -> tuple[tuple[torch.Tensor, torch.Tensor], Any]:
        """Mapping: obs -> logits -> (mu, sigma)."""
        if info is None:
            info = {}
        logits, hidden = self.preprocess(obs, state)
        mu = self.mu(logits)
        if not self._unbounded:
            mu = self.max_action * torch.tanh(mu)
        if self._c_sigma:
            sigma = torch.clamp(self.sigma(logits), min=SIGMA_MIN, max=SIGMA_MAX).exp()
        else:
            shape = [1] * len(mu.shape)
            shape[1] = -1
            sigma = (self.sigma_param.view(shape) + torch.zeros_like(mu)).exp()
        return (mu, sigma), state


class RecurrentActorProb(nn.Module):
    """Recurrent version of ActorProb.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.
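
    A minimal usage sketch (the shapes are illustrative assumptions; note how
    the recurrent state is carried through the returned dict):

    .. code-block:: python

        import numpy as np
        from tianshou.utils.net.continuous import RecurrentActorProb

        actor = RecurrentActorProb(layer_num=1, state_shape=(3,), action_shape=(2,))
        obs = np.random.randn(4, 3).astype(np.float32)  # [bsz, dim] at evaluation
        (mu, sigma), state = actor(obs)
        (mu, sigma), state = actor(obs, state)  # feed the state back in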
    """

    def __init__(
        self,
        layer_num: int,
        state_shape: Sequence[int],
        action_shape: Sequence[int],
        hidden_layer_size: int = 128,
        max_action: float = 1.0,
        device: str | int | torch.device = "cpu",
        unbounded: bool = False,
        conditioned_sigma: bool = False,
    ) -> None:
        super().__init__()
        if unbounded and not np.isclose(max_action, 1.0):
            warnings.warn("Note that max_action input will be discarded when unbounded is True.")
            max_action = 1.0
        self.device = device
        self.nn = nn.LSTM(
            input_size=int(np.prod(state_shape)),
            hidden_size=hidden_layer_size,
            num_layers=layer_num,
            batch_first=True,
        )
        output_dim = int(np.prod(action_shape))
        self.mu = nn.Linear(hidden_layer_size, output_dim)
        self._c_sigma = conditioned_sigma
        if conditioned_sigma:
            self.sigma = nn.Linear(hidden_layer_size, output_dim)
        else:
            self.sigma_param = nn.Parameter(torch.zeros(output_dim, 1))
        self.max_action = max_action
        self._unbounded = unbounded

    def forward(
        self,
        obs: np.ndarray | torch.Tensor,
        state: dict[str, torch.Tensor] | None = None,
        info: dict[str, Any] | None = None,
    ) -> tuple[tuple[torch.Tensor, torch.Tensor], dict[str, torch.Tensor]]:
        """Almost the same as :class:`~tianshou.utils.net.common.Recurrent`."""
        if info is None:
            info = {}
        obs = torch.as_tensor(
            obs,
            device=self.device,
            dtype=torch.float32,
        )
        # obs is [bsz, len, dim] during training and [bsz, dim] during evaluation;
        # i.e., the tensor has one more dimension during training.
        if len(obs.shape) == 2:
            obs = obs.unsqueeze(-2)
        self.nn.flatten_parameters()
        if state is None:
            obs, (hidden, cell) = self.nn(obs)
        else:
            # we store the stacked data in [bsz, len, ...] format,
            # but pytorch rnn needs [len, bsz, ...]
            obs, (hidden, cell) = self.nn(
                obs,
                (
                    state["hidden"].transpose(0, 1).contiguous(),
                    state["cell"].transpose(0, 1).contiguous(),
                ),
            )
        logits = obs[:, -1]
        mu = self.mu(logits)
        if not self._unbounded:
            mu = self.max_action * torch.tanh(mu)
        if self._c_sigma:
            sigma = torch.clamp(self.sigma(logits), min=SIGMA_MIN, max=SIGMA_MAX).exp()
        else:
            shape = [1] * len(mu.shape)
            shape[1] = -1
            sigma = (self.sigma_param.view(shape) + torch.zeros_like(mu)).exp()
        # please ensure the first dim is batch size: [bsz, len, ...]
        return (mu, sigma), {
            "hidden": hidden.transpose(0, 1).detach(),
            "cell": cell.transpose(0, 1).detach(),
        }


class RecurrentCritic(nn.Module):
    """Recurrent version of Critic.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.
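
    A minimal usage sketch (the shapes are illustrative assumptions; unlike
    the recurrent actor, forward expects a stacked [bsz, len, dim] observation):

    .. code-block:: python

        import numpy as np
        from tianshou.utils.net.continuous import RecurrentCritic

        critic = RecurrentCritic(layer_num=1, state_shape=(3,), action_shape=(2,))
        obs = np.random.randn(4, 8, 3).astype(np.float32)  # [bsz, len, dim]
        act = np.random.randn(4, 2).astype(np.float32)
        q = critic(obs, act)  # Q values with shape (4, 1)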
    """

    def __init__(
        self,
        layer_num: int,
        state_shape: Sequence[int],
        action_shape: Sequence[int] = (0,),
        device: str | int | torch.device = "cpu",
        hidden_layer_size: int = 128,
    ) -> None:
        super().__init__()
        self.state_shape = state_shape
        self.action_shape = action_shape
        self.device = device
        self.nn = nn.LSTM(
            input_size=int(np.prod(state_shape)),
            hidden_size=hidden_layer_size,
            num_layers=layer_num,
            batch_first=True,
        )
        self.fc2 = nn.Linear(hidden_layer_size + int(np.prod(action_shape)), 1)

    def forward(
        self,
        obs: np.ndarray | torch.Tensor,
        act: np.ndarray | torch.Tensor | None = None,
        info: dict[str, Any] | None = None,
    ) -> torch.Tensor:
        """Almost the same as :class:`~tianshou.utils.net.common.Recurrent`."""
        if info is None:
            info = {}
        obs = torch.as_tensor(
            obs,
            device=self.device,
            dtype=torch.float32,
        )
        # obs is [bsz, len, dim] during training and [bsz, dim] during evaluation;
        # here we expect the stacked [bsz, len, dim] form.
        assert len(obs.shape) == 3
        self.nn.flatten_parameters()
        obs, (hidden, cell) = self.nn(obs)
        obs = obs[:, -1]
        if act is not None:
            act = torch.as_tensor(
                act,
                device=self.device,
                dtype=torch.float32,
            )
            obs = torch.cat([obs, act], dim=1)
        return self.fc2(obs)


class Perturbation(nn.Module):
    """Implementation of the perturbation network in the BCQ algorithm.

    Given a state and an action, it generates a perturbed action.

    :param preprocess_net: a self-defined preprocess_net which outputs a
        flattened hidden state.
    :param max_action: the maximum value of each dimension of the action.
    :param device: which device to create this model on. Defaults to "cpu".
    :param phi: max perturbation parameter for BCQ. Defaults to 0.05.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.
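
    A minimal usage sketch (the shapes are illustrative assumptions; the
    preprocess net must map state_dim + action_dim inputs to action_dim
    outputs):

    .. code-block:: python

        import torch
        from tianshou.utils.net.common import Net
        from tianshou.utils.net.continuous import Perturbation

        # state_dim=3, action_dim=2 -> input dim 5, output dim 2
        net = Net(state_shape=(5,), action_shape=(2,), hidden_sizes=(64,))
        perturb = Perturbation(net, max_action=1.0, phi=0.05)
        state = torch.randn(4, 3)
        action = torch.randn(4, 2).clamp(-1.0, 1.0)
        new_action = perturb(state, action)  # close to action, within [-1, 1]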

    .. seealso::

        You can refer to `examples/offline/offline_bcq.py` to see how to use it.
    """

    def __init__(
        self,
        preprocess_net: nn.Module,
        max_action: float,
        device: str | int | torch.device = "cpu",
        phi: float = 0.05,
    ) -> None:
        # preprocess_net: input_dim=state_dim+action_dim, output_dim=action_dim
        super().__init__()
        self.preprocess_net = preprocess_net
        self.device = device
        self.max_action = max_action
        self.phi = phi

    def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        logits = self.preprocess_net(torch.cat([state, action], -1))[0]
        noise = self.phi * self.max_action * torch.tanh(logits)
        # clip to [-max_action, max_action]
        return (noise + action).clamp(-self.max_action, self.max_action)


class VAE(nn.Module):
    """Implementation of VAE.

    It models the distribution of actions: given a state, it can generate
    actions similar to those in the batch. It is used in the BCQ algorithm.

    :param encoder: the encoder in VAE. Its input_dim must be
        state_dim + action_dim, and its output_dim must be hidden_dim.
    :param decoder: the decoder in VAE. Its input_dim must be
        state_dim + latent_dim, and its output_dim must be action_dim.
    :param hidden_dim: the size of the last linear layer in the encoder.
    :param latent_dim: the size of the latent layer.
    :param max_action: the maximum value of each dimension of the action.
    :param device: which device to create this model on. Defaults to "cpu".

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.
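
    A minimal usage sketch (the dimensions and MLP sizes are illustrative
    assumptions):

    .. code-block:: python

        import torch
        from tianshou.utils.net.common import MLP
        from tianshou.utils.net.continuous import VAE

        state_dim, action_dim, hidden_dim, latent_dim = 3, 2, 32, 4
        encoder = MLP(state_dim + action_dim, hidden_dim, hidden_sizes=(64,))
        decoder = MLP(state_dim + latent_dim, action_dim, hidden_sizes=(64,))
        vae = VAE(encoder, decoder, hidden_dim, latent_dim, max_action=1.0)
        state = torch.randn(4, state_dim)
        action = torch.randn(4, action_dim).clamp(-1.0, 1.0)
        recon, mean, std = vae(state, action)  # recon has shape (4, action_dim)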

    .. seealso::

        You can refer to `examples/offline/offline_bcq.py` to see how to use it.
    """

    def __init__(
        self,
        encoder: nn.Module,
        decoder: nn.Module,
        hidden_dim: int,
        latent_dim: int,
        max_action: float,
        device: str | torch.device = "cpu",
    ) -> None:
        super().__init__()
        self.encoder = encoder

        self.mean = nn.Linear(hidden_dim, latent_dim)
        self.log_std = nn.Linear(hidden_dim, latent_dim)

        self.decoder = decoder

        self.max_action = max_action
        self.latent_dim = latent_dim
        self.device = device

    def forward(
        self,
        state: torch.Tensor,
        action: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # [state, action] -> z, [state, z] -> action
        latent_z = self.encoder(torch.cat([state, action], -1))
        # shape of latent_z: (*state.shape[:-1], hidden_dim)

        mean = self.mean(latent_z)
        # clamped for numerical stability
        log_std = self.log_std(latent_z).clamp(-4, 15)
        std = torch.exp(log_std)
        # shape of mean, std: (*state.shape[:-1], latent_dim)

        latent_z = mean + std * torch.randn_like(std)  # (*state.shape[:-1], latent_dim)

        reconstruction = self.decode(state, latent_z)  # (*state.shape[:-1], action_dim)
        return reconstruction, mean, std

    def decode(
        self,
        state: torch.Tensor,
        latent_z: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # decode(state) -> action
        if latent_z is None:
            # state.shape[0] may be batch_size
            # latent vector clipped to [-0.5, 0.5]
            latent_z = (
                torch.randn(state.shape[:-1] + (self.latent_dim,)).to(self.device).clamp(-0.5, 0.5)
            )

        # decode z together with state
        return self.max_action * torch.tanh(self.decoder(torch.cat([state, latent_z], -1)))