Closes #914. Additional changes: - Drop support for Python versions below 3.11. - Remove third-party and throughput tests; this simplifies the install and test pipelines. - Remove gym compatibility and shimmy. - Format with Python 3.11 conventions; in particular, add `zip(..., strict=True/False)` where possible. Since the additional tests and gym were complicating the CI pipeline (they were flaky and dist-dependent), it didn't make sense to fix the existing tests in this PR only to delete them in the next one. So this PR changes the build and removes these tests at the same time.
500 lines
17 KiB
Python
500 lines
17 KiB
Python
import warnings
|
|
from collections.abc import Sequence
|
|
from typing import Any, cast
|
|
|
|
import numpy as np
|
|
import torch
|
|
from torch import nn
|
|
|
|
from tianshou.utils.net.common import MLP
|
|
|
|
# Bounds applied (via torch.clamp) to the output of the input-conditioned sigma
# head before ``.exp()`` in ActorProb / RecurrentActorProb, so the resulting
# sigma lies within [exp(SIGMA_MIN), exp(SIGMA_MAX)].
SIGMA_MIN = -20
SIGMA_MAX = 2
|
|
|
|
|
|
class Actor(nn.Module):
    """Deterministic actor network for a continuous action space.

    The network is ``preprocess_net ---> MLP ---> action_shape``. The MLP
    output is passed through ``tanh`` and then multiplied by ``max_action``.

    :param preprocess_net: a self-defined preprocess_net which outputs a
        flattened hidden state.
    :param action_shape: a sequence of int for the shape of action.
    :param hidden_sizes: a sequence of int for constructing the MLP after
        preprocess_net. Default to empty sequence (where the MLP now contains
        only a single linear layer).
    :param max_action: the scale for the final action logits. Default to 1.
    :param device: passed to the MLP and stored as ``self.device``.
        Default to "cpu".
    :param preprocess_net_output_dim: the output dimension of
        preprocess_net; only consulted when preprocess_net has no
        ``output_dim`` attribute.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.

    .. seealso::

        Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
        of how preprocess_net is suggested to be defined.
    """

    def __init__(
        self,
        preprocess_net: nn.Module,
        action_shape: Sequence[int],
        hidden_sizes: Sequence[int] = (),
        max_action: float = 1.0,
        device: str | int | torch.device = "cpu",
        preprocess_net_output_dim: int | None = None,
    ) -> None:
        super().__init__()
        self.device = device
        # Registered before ``last`` so parameter / state_dict order is unchanged.
        self.preprocess = preprocess_net
        self.output_dim = int(np.prod(action_shape))
        # The preprocess net's own ``output_dim`` attribute wins; otherwise fall
        # back to the explicitly supplied dimension.
        feature_dim = cast(int, getattr(preprocess_net, "output_dim", preprocess_net_output_dim))
        self.last = MLP(feature_dim, self.output_dim, hidden_sizes, device=self.device)
        self.max_action = max_action

    def forward(
        self,
        obs: np.ndarray | torch.Tensor,
        state: Any = None,
        info: dict[str, Any] | None = None,
    ) -> tuple[torch.Tensor, Any]:
        """Mapping: obs -> logits -> action.

        :return: the scaled action and the state returned by preprocess_net.
        """
        info = {} if info is None else info
        features, new_state = self.preprocess(obs, state)
        squashed = torch.tanh(self.last(features))
        return self.max_action * squashed, new_state
|
|
|
|
|
|
class Critic(nn.Module):
    """Simple critic network.

    It will create a critic operated in continuous action space with structure of
    preprocess_net ---> 1 (q value).

    :param preprocess_net: a self-defined preprocess_net which outputs a
        flattened hidden state.
    :param hidden_sizes: a sequence of int for constructing the MLP after
        preprocess_net. Default to empty sequence (where the MLP now contains
        only a single linear layer).
    :param device: passed to the MLP and used as the target device for the
        observation/action tensors in ``forward``. Default to "cpu".
    :param preprocess_net_output_dim: the output dimension of
        preprocess_net; only consulted when preprocess_net has no
        ``output_dim`` attribute.
    :param linear_layer: use this module as linear layer. Default to nn.Linear.
    :param flatten_input: whether to flatten input data for the last layer.
        Default to True.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.

    .. seealso::

        Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
        of how preprocess_net is suggested to be defined.
    """

    def __init__(
        self,
        preprocess_net: nn.Module,
        hidden_sizes: Sequence[int] = (),
        device: str | int | torch.device = "cpu",
        preprocess_net_output_dim: int | None = None,
        linear_layer: type[nn.Linear] = nn.Linear,
        flatten_input: bool = True,
    ) -> None:
        super().__init__()
        self.device = device
        self.preprocess = preprocess_net
        self.output_dim = 1
        # Same input-dim resolution as Actor: attribute first, explicit value second.
        input_dim = cast(int, getattr(preprocess_net, "output_dim", preprocess_net_output_dim))
        self.last = MLP(
            input_dim,
            1,
            hidden_sizes,
            device=self.device,
            linear_layer=linear_layer,
            flatten_input=flatten_input,
        )

    def forward(
        self,
        obs: np.ndarray | torch.Tensor,
        act: np.ndarray | torch.Tensor | None = None,
        info: dict[str, Any] | None = None,
    ) -> torch.Tensor:
        """Mapping: (s, a) -> logits -> Q(s, a).

        ``obs`` and (if given) ``act`` are converted to float32 tensors on
        ``self.device``, flattened from dim 1 onwards, and concatenated along
        dim 1 before being passed to preprocess_net.
        """
        if info is None:
            info = {}
        obs = torch.as_tensor(obs, device=self.device, dtype=torch.float32).flatten(1)
        if act is not None:
            act = torch.as_tensor(act, device=self.device, dtype=torch.float32).flatten(1)
            obs = torch.cat([obs, act], dim=1)
        # The second value returned by preprocess_net (its state) is not needed here.
        logits, _ = self.preprocess(obs)
        return self.last(logits)
|
|
|
|
|
|
class ActorProb(nn.Module):
    """Simple actor network (output with a Gauss distribution).

    :param preprocess_net: a self-defined preprocess_net which outputs a
        flattened hidden state.
    :param action_shape: a sequence of int for the shape of action.
    :param hidden_sizes: a sequence of int for constructing the MLP after
        preprocess_net. Default to empty sequence (where the MLP now contains
        only a single linear layer).
    :param max_action: the scale applied to ``tanh(mu)``. Default to 1.
        Ignored (forced to 1, with a warning) when ``unbounded`` is True.
    :param device: passed to the MLPs and stored as ``self.device``.
        Default to "cpu".
    :param unbounded: if True, the mean is returned as-is; if False (default),
        it is squashed to ``max_action * tanh(mu)``.
    :param conditioned_sigma: True when sigma is calculated from the
        input, False when sigma is an independent parameter. Default to False.
    :param preprocess_net_output_dim: the output dimension of
        preprocess_net; only consulted when preprocess_net has no
        ``output_dim`` attribute.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.

    .. seealso::

        Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
        of how preprocess_net is suggested to be defined.
    """

    def __init__(
        self,
        preprocess_net: nn.Module,
        action_shape: Sequence[int],
        hidden_sizes: Sequence[int] = (),
        max_action: float = 1.0,
        device: str | int | torch.device = "cpu",
        unbounded: bool = False,
        conditioned_sigma: bool = False,
        preprocess_net_output_dim: int | None = None,
    ) -> None:
        super().__init__()
        if unbounded and not np.isclose(max_action, 1.0):
            warnings.warn("Note that max_action input will be discarded when unbounded is True.")
            max_action = 1.0
        self.preprocess = preprocess_net
        self.device = device
        self.output_dim = int(np.prod(action_shape))
        input_dim = cast(int, getattr(preprocess_net, "output_dim", preprocess_net_output_dim))
        self.mu = MLP(input_dim, self.output_dim, hidden_sizes, device=self.device)
        self._c_sigma = conditioned_sigma
        if conditioned_sigma:
            self.sigma = MLP(
                input_dim,
                self.output_dim,
                hidden_sizes,
                device=self.device,
            )
        else:
            # One learnable log-sigma per action dimension (exponentiated in
            # forward), independent of the input.
            self.sigma_param = nn.Parameter(torch.zeros(self.output_dim, 1))
        self.max_action = max_action
        self._unbounded = unbounded

    def forward(
        self,
        obs: np.ndarray | torch.Tensor,
        state: Any = None,
        info: dict[str, Any] | None = None,
    ) -> tuple[tuple[torch.Tensor, torch.Tensor], Any]:
        """Mapping: obs -> logits -> (mu, sigma).

        :return: ``((mu, sigma), hidden)`` where ``hidden`` is the state
            returned by preprocess_net.
        """
        if info is None:
            info = {}
        logits, hidden = self.preprocess(obs, state)
        mu = self.mu(logits)
        if not self._unbounded:
            mu = self.max_action * torch.tanh(mu)
        if self._c_sigma:
            sigma = torch.clamp(self.sigma(logits), min=SIGMA_MIN, max=SIGMA_MAX).exp()
        else:
            # Reshape the (output_dim, 1) parameter so it broadcasts along dim 1
            # of mu, then expand to mu's full shape.
            shape = [1] * len(mu.shape)
            shape[1] = -1
            sigma = (self.sigma_param.view(shape) + torch.zeros_like(mu)).exp()
        # Return the preprocess net's updated state (as Actor does) rather than
        # the incoming one; otherwise a recurrent preprocess net never advances.
        return (mu, sigma), hidden
|
|
|
|
|
|
class RecurrentActorProb(nn.Module):
|
|
"""Recurrent version of ActorProb.
|
|
|
|
For advanced usage (how to customize the network), please refer to
|
|
:ref:`build_the_network`.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
layer_num: int,
|
|
state_shape: Sequence[int],
|
|
action_shape: Sequence[int],
|
|
hidden_layer_size: int = 128,
|
|
max_action: float = 1.0,
|
|
device: str | int | torch.device = "cpu",
|
|
unbounded: bool = False,
|
|
conditioned_sigma: bool = False,
|
|
) -> None:
|
|
super().__init__()
|
|
if unbounded and not np.isclose(max_action, 1.0):
|
|
warnings.warn("Note that max_action input will be discarded when unbounded is True.")
|
|
max_action = 1.0
|
|
self.device = device
|
|
self.nn = nn.LSTM(
|
|
input_size=int(np.prod(state_shape)),
|
|
hidden_size=hidden_layer_size,
|
|
num_layers=layer_num,
|
|
batch_first=True,
|
|
)
|
|
output_dim = int(np.prod(action_shape))
|
|
self.mu = nn.Linear(hidden_layer_size, output_dim)
|
|
self._c_sigma = conditioned_sigma
|
|
if conditioned_sigma:
|
|
self.sigma = nn.Linear(hidden_layer_size, output_dim)
|
|
else:
|
|
self.sigma_param = nn.Parameter(torch.zeros(output_dim, 1))
|
|
self.max_action = max_action
|
|
self._unbounded = unbounded
|
|
|
|
def forward(
|
|
self,
|
|
obs: np.ndarray | torch.Tensor,
|
|
state: dict[str, torch.Tensor] | None = None,
|
|
info: dict[str, Any] | None = None,
|
|
) -> tuple[tuple[torch.Tensor, torch.Tensor], dict[str, torch.Tensor]]:
|
|
"""Almost the same as :class:`~tianshou.utils.net.common.Recurrent`."""
|
|
if info is None:
|
|
info = {}
|
|
obs = torch.as_tensor(
|
|
obs,
|
|
device=self.device,
|
|
dtype=torch.float32,
|
|
)
|
|
# obs [bsz, len, dim] (training) or [bsz, dim] (evaluation)
|
|
# In short, the tensor's shape in training phase is longer than which
|
|
# in evaluation phase.
|
|
if len(obs.shape) == 2:
|
|
obs = obs.unsqueeze(-2)
|
|
self.nn.flatten_parameters()
|
|
if state is None:
|
|
obs, (hidden, cell) = self.nn(obs)
|
|
else:
|
|
# we store the stack data in [bsz, len, ...] format
|
|
# but pytorch rnn needs [len, bsz, ...]
|
|
obs, (hidden, cell) = self.nn(
|
|
obs,
|
|
(
|
|
state["hidden"].transpose(0, 1).contiguous(),
|
|
state["cell"].transpose(0, 1).contiguous(),
|
|
),
|
|
)
|
|
logits = obs[:, -1]
|
|
mu = self.mu(logits)
|
|
if not self._unbounded:
|
|
mu = self.max_action * torch.tanh(mu)
|
|
if self._c_sigma:
|
|
sigma = torch.clamp(self.sigma(logits), min=SIGMA_MIN, max=SIGMA_MAX).exp()
|
|
else:
|
|
shape = [1] * len(mu.shape)
|
|
shape[1] = -1
|
|
sigma = (self.sigma_param.view(shape) + torch.zeros_like(mu)).exp()
|
|
# please ensure the first dim is batch size: [bsz, len, ...]
|
|
return (mu, sigma), {
|
|
"hidden": hidden.transpose(0, 1).detach(),
|
|
"cell": cell.transpose(0, 1).detach(),
|
|
}
|
|
|
|
|
|
class RecurrentCritic(nn.Module):
|
|
"""Recurrent version of Critic.
|
|
|
|
For advanced usage (how to customize the network), please refer to
|
|
:ref:`build_the_network`.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
layer_num: int,
|
|
state_shape: Sequence[int],
|
|
action_shape: Sequence[int] = [0],
|
|
device: str | int | torch.device = "cpu",
|
|
hidden_layer_size: int = 128,
|
|
) -> None:
|
|
super().__init__()
|
|
self.state_shape = state_shape
|
|
self.action_shape = action_shape
|
|
self.device = device
|
|
self.nn = nn.LSTM(
|
|
input_size=int(np.prod(state_shape)),
|
|
hidden_size=hidden_layer_size,
|
|
num_layers=layer_num,
|
|
batch_first=True,
|
|
)
|
|
self.fc2 = nn.Linear(hidden_layer_size + int(np.prod(action_shape)), 1)
|
|
|
|
def forward(
|
|
self,
|
|
obs: np.ndarray | torch.Tensor,
|
|
act: np.ndarray | torch.Tensor | None = None,
|
|
info: dict[str, Any] | None = None,
|
|
) -> torch.Tensor:
|
|
"""Almost the same as :class:`~tianshou.utils.net.common.Recurrent`."""
|
|
if info is None:
|
|
info = {}
|
|
obs = torch.as_tensor(
|
|
obs,
|
|
device=self.device,
|
|
dtype=torch.float32,
|
|
)
|
|
# obs [bsz, len, dim] (training) or [bsz, dim] (evaluation)
|
|
# In short, the tensor's shape in training phase is longer than which
|
|
# in evaluation phase.
|
|
assert len(obs.shape) == 3
|
|
self.nn.flatten_parameters()
|
|
obs, (hidden, cell) = self.nn(obs)
|
|
obs = obs[:, -1]
|
|
if act is not None:
|
|
act = torch.as_tensor(
|
|
act,
|
|
device=self.device,
|
|
dtype=torch.float32,
|
|
)
|
|
obs = torch.cat([obs, act], dim=1)
|
|
return self.fc2(obs)
|
|
|
|
|
|
class Perturbation(nn.Module):
|
|
"""Implementation of perturbation network in BCQ algorithm.
|
|
|
|
Given a state and action, it can generate perturbed action.
|
|
|
|
:param torch.nn.Module preprocess_net: a self-defined preprocess_net which output a
|
|
flattened hidden state.
|
|
:param float max_action: the maximum value of each dimension of action.
|
|
:param Union[str, int, torch.device] device: which device to create this model on.
|
|
Default to cpu.
|
|
:param float phi: max perturbation parameter for BCQ. Default to 0.05.
|
|
|
|
For advanced usage (how to customize the network), please refer to
|
|
:ref:`build_the_network`.
|
|
|
|
.. seealso::
|
|
|
|
You can refer to `examples/offline/offline_bcq.py` to see how to use it.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
preprocess_net: nn.Module,
|
|
max_action: float,
|
|
device: str | int | torch.device = "cpu",
|
|
phi: float = 0.05,
|
|
):
|
|
# preprocess_net: input_dim=state_dim+action_dim, output_dim=action_dim
|
|
super().__init__()
|
|
self.preprocess_net = preprocess_net
|
|
self.device = device
|
|
self.max_action = max_action
|
|
self.phi = phi
|
|
|
|
def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
|
|
# preprocess_net
|
|
logits = self.preprocess_net(torch.cat([state, action], -1))[0]
|
|
noise = self.phi * self.max_action * torch.tanh(logits)
|
|
# clip to [-max_action, max_action]
|
|
return (noise + action).clamp(-self.max_action, self.max_action)
|
|
|
|
|
|
class VAE(nn.Module):
|
|
"""Implementation of VAE.
|
|
|
|
It models the distribution of action. Given a state, it can generate actions similar to those in batch.
|
|
It is used in BCQ algorithm.
|
|
|
|
:param torch.nn.Module encoder: the encoder in VAE. Its input_dim must be
|
|
state_dim + action_dim, and output_dim must be hidden_dim.
|
|
:param torch.nn.Module decoder: the decoder in VAE. Its input_dim must be
|
|
state_dim + latent_dim, and output_dim must be action_dim.
|
|
:param int hidden_dim: the size of the last linear-layer in encoder.
|
|
:param int latent_dim: the size of latent layer.
|
|
:param float max_action: the maximum value of each dimension of action.
|
|
:param Union[str, torch.device] device: which device to create this model on.
|
|
Default to "cpu".
|
|
|
|
For advanced usage (how to customize the network), please refer to
|
|
:ref:`build_the_network`.
|
|
|
|
.. seealso::
|
|
|
|
You can refer to `examples/offline/offline_bcq.py` to see how to use it.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
encoder: nn.Module,
|
|
decoder: nn.Module,
|
|
hidden_dim: int,
|
|
latent_dim: int,
|
|
max_action: float,
|
|
device: str | torch.device = "cpu",
|
|
):
|
|
super().__init__()
|
|
self.encoder = encoder
|
|
|
|
self.mean = nn.Linear(hidden_dim, latent_dim)
|
|
self.log_std = nn.Linear(hidden_dim, latent_dim)
|
|
|
|
self.decoder = decoder
|
|
|
|
self.max_action = max_action
|
|
self.latent_dim = latent_dim
|
|
self.device = device
|
|
|
|
def forward(
|
|
self,
|
|
state: torch.Tensor,
|
|
action: torch.Tensor,
|
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
# [state, action] -> z , [state, z] -> action
|
|
latent_z = self.encoder(torch.cat([state, action], -1))
|
|
# shape of z: (state.shape[:-1], hidden_dim)
|
|
|
|
mean = self.mean(latent_z)
|
|
# Clamped for numerical stability
|
|
log_std = self.log_std(latent_z).clamp(-4, 15)
|
|
std = torch.exp(log_std)
|
|
# shape of mean, std: (state.shape[:-1], latent_dim)
|
|
|
|
latent_z = mean + std * torch.randn_like(std) # (state.shape[:-1], latent_dim)
|
|
|
|
reconstruction = self.decode(state, latent_z) # (state.shape[:-1], action_dim)
|
|
return reconstruction, mean, std
|
|
|
|
def decode(
|
|
self,
|
|
state: torch.Tensor,
|
|
latent_z: torch.Tensor | None = None,
|
|
) -> torch.Tensor:
|
|
# decode(state) -> action
|
|
if latent_z is None:
|
|
# state.shape[0] may be batch_size
|
|
# latent vector clipped to [-0.5, 0.5]
|
|
latent_z = (
|
|
torch.randn(state.shape[:-1] + (self.latent_dim,)).to(self.device).clamp(-0.5, 0.5)
|
|
)
|
|
|
|
# decode z with state!
|
|
return self.max_action * torch.tanh(self.decoder(torch.cat([state, latent_z], -1)))
|