# Tianshou/tianshou/utils/net/continuous.py
from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union

import numpy as np
import torch
from torch import nn

from tianshou.utils.net.common import MLP

SIGMA_MIN = -20
SIGMA_MAX = 2


class Actor(nn.Module):
    """Simple actor network. Will create an actor operated in continuous \
    action space with structure of preprocess_net ---> action_shape.

    :param preprocess_net: a self-defined preprocess_net which outputs a
        flattened hidden state.
    :param action_shape: a sequence of int for the shape of action.
    :param hidden_sizes: a sequence of int for constructing the MLP after
        preprocess_net. Default to empty sequence (where the MLP now contains
        only a single linear layer).
    :param float max_action: the scale for the final action logits. Default to
        1.
    :param int preprocess_net_output_dim: the output dimension of
        preprocess_net.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.

    .. seealso::

        Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
        of how preprocess_net is suggested to be defined.
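
    Example (a minimal usage sketch, not part of the original docstring; it
    assumes a :class:`~tianshou.utils.net.common.Net` preprocess network and
    a 3-dimensional action space)::

        from tianshou.utils.net.common import Net

        net = Net(state_shape=(4,), hidden_sizes=(64, 64))
        actor = Actor(net, action_shape=(3,), max_action=1.0)
        obs = np.random.rand(2, 4).astype(np.float32)  # a batch of 2 observations
        act, hidden = actor(obs)  # act lies in [-1, 1] and has shape (2, 3)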
"""
def __init__(
self,
preprocess_net: nn.Module,
action_shape: Sequence[int],
hidden_sizes: Sequence[int] = (),
max_action: float = 1.0,
device: Union[str, int, torch.device] = "cpu",
preprocess_net_output_dim: Optional[int] = None,
) -> None:
super().__init__()
self.device = device
self.preprocess = preprocess_net
self.output_dim = int(np.prod(action_shape))
input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
self.last = MLP(
input_dim, # type: ignore
self.output_dim,
hidden_sizes,
device=self.device
)
self._max = max_action
def forward(
self,
obs: Union[np.ndarray, torch.Tensor],
state: Any = None,
info: Dict[str, Any] = {},
) -> Tuple[torch.Tensor, Any]:
"""Mapping: obs -> logits -> action."""
logits, hidden = self.preprocess(obs, state)
logits = self._max * torch.tanh(self.last(logits))
return logits, hidden


class Critic(nn.Module):
    """Simple critic network. Will create a critic operated in continuous \
    action space with structure of preprocess_net ---> 1(q value).

    :param preprocess_net: a self-defined preprocess_net which outputs a
        flattened hidden state.
    :param hidden_sizes: a sequence of int for constructing the MLP after
        preprocess_net. Default to empty sequence (where the MLP now contains
        only a single linear layer).
    :param int preprocess_net_output_dim: the output dimension of
        preprocess_net.
    :param linear_layer: use this module as linear layer. Default to nn.Linear.
    :param bool flatten_input: whether to flatten input data for the last layer.
        Default to True.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.

    .. seealso::

        Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
        of how preprocess_net is suggested to be defined.
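
    Example (a minimal usage sketch, not part of the original docstring; it
    assumes a ``Net`` built with ``concat=True`` so that it accepts the
    concatenated (state, action) input)::

        from tianshou.utils.net.common import Net

        net = Net(
            state_shape=(4,), action_shape=(3,), hidden_sizes=(64, 64), concat=True
        )
        critic = Critic(net)
        obs = np.random.rand(2, 4).astype(np.float32)
        act = np.random.rand(2, 3).astype(np.float32)
        q = critic(obs, act)  # Q(s, a) of shape (2, 1)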
"""
def __init__(
self,
preprocess_net: nn.Module,
hidden_sizes: Sequence[int] = (),
device: Union[str, int, torch.device] = "cpu",
preprocess_net_output_dim: Optional[int] = None,
linear_layer: Type[nn.Linear] = nn.Linear,
flatten_input: bool = True,
) -> None:
super().__init__()
self.device = device
self.preprocess = preprocess_net
self.output_dim = 1
input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
self.last = MLP(
input_dim, # type: ignore
1,
hidden_sizes,
device=self.device,
linear_layer=linear_layer,
flatten_input=flatten_input,
)
def forward(
self,
obs: Union[np.ndarray, torch.Tensor],
act: Optional[Union[np.ndarray, torch.Tensor]] = None,
info: Dict[str, Any] = {},
) -> torch.Tensor:
"""Mapping: (s, a) -> logits -> Q(s, a)."""
obs = torch.as_tensor(
obs,
device=self.device, # type: ignore
dtype=torch.float32,
).flatten(1)
if act is not None:
act = torch.as_tensor(
act,
device=self.device, # type: ignore
dtype=torch.float32,
).flatten(1)
obs = torch.cat([obs, act], dim=1)
logits, hidden = self.preprocess(obs)
logits = self.last(logits)
return logits


class ActorProb(nn.Module):
    """Simple actor network (output with a Gauss distribution).

    :param preprocess_net: a self-defined preprocess_net which outputs a
        flattened hidden state.
    :param action_shape: a sequence of int for the shape of action.
    :param hidden_sizes: a sequence of int for constructing the MLP after
        preprocess_net. Default to empty sequence (where the MLP now contains
        only a single linear layer).
    :param float max_action: the scale for the final action logits. Default to
        1.
    :param bool unbounded: whether to keep the final logits unbounded, i.e.,
        skip the tanh squashing and max_action scaling of mu. Default to
        False.
    :param bool conditioned_sigma: True when sigma is calculated from the
        input, False when sigma is an independent parameter. Default to False.
    :param int preprocess_net_output_dim: the output dimension of
        preprocess_net.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.

    .. seealso::

        Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
        of how preprocess_net is suggested to be defined.
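
    Example (a minimal usage sketch, not part of the original docstring)::

        from tianshou.utils.net.common import Net

        net = Net(state_shape=(4,), hidden_sizes=(64, 64))
        actor = ActorProb(net, action_shape=(3,), unbounded=True)
        obs = np.random.rand(2, 4).astype(np.float32)
        (mu, sigma), state = actor(obs)  # parameters of a Gaussian policy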
"""
def __init__(
self,
preprocess_net: nn.Module,
action_shape: Sequence[int],
hidden_sizes: Sequence[int] = (),
max_action: float = 1.0,
device: Union[str, int, torch.device] = "cpu",
unbounded: bool = False,
conditioned_sigma: bool = False,
preprocess_net_output_dim: Optional[int] = None,
) -> None:
super().__init__()
self.preprocess = preprocess_net
self.device = device
self.output_dim = int(np.prod(action_shape))
input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
self.mu = MLP(
input_dim, # type: ignore
self.output_dim,
hidden_sizes,
device=self.device
)
self._c_sigma = conditioned_sigma
if conditioned_sigma:
self.sigma = MLP(
input_dim, # type: ignore
self.output_dim,
hidden_sizes,
device=self.device
)
else:
self.sigma_param = nn.Parameter(torch.zeros(self.output_dim, 1))
self._max = max_action
self._unbounded = unbounded
def forward(
self,
obs: Union[np.ndarray, torch.Tensor],
state: Any = None,
info: Dict[str, Any] = {},
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Any]:
"""Mapping: obs -> logits -> (mu, sigma)."""
logits, hidden = self.preprocess(obs, state)
mu = self.mu(logits)
if not self._unbounded:
mu = self._max * torch.tanh(mu)
if self._c_sigma:
sigma = torch.clamp(self.sigma(logits), min=SIGMA_MIN, max=SIGMA_MAX).exp()
else:
shape = [1] * len(mu.shape)
shape[1] = -1
sigma = (self.sigma_param.view(shape) + torch.zeros_like(mu)).exp()
return (mu, sigma), state


class RecurrentActorProb(nn.Module):
    """Recurrent version of ActorProb.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.
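
    Example (a minimal usage sketch, not part of the original docstring; the
    LSTM hidden and cell states are carried through the ``state`` dict)::

        actor = RecurrentActorProb(layer_num=1, state_shape=(4,), action_shape=(3,))
        obs = np.random.rand(2, 4).astype(np.float32)  # [bsz, dim] at evaluation
        (mu, sigma), state = actor(obs)         # state: {"hidden": ..., "cell": ...}
        (mu, sigma), state = actor(obs, state)  # feed the state back next step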
"""
def __init__(
self,
layer_num: int,
state_shape: Sequence[int],
action_shape: Sequence[int],
hidden_layer_size: int = 128,
max_action: float = 1.0,
device: Union[str, int, torch.device] = "cpu",
unbounded: bool = False,
conditioned_sigma: bool = False,
) -> None:
super().__init__()
self.device = device
self.nn = nn.LSTM(
input_size=int(np.prod(state_shape)),
hidden_size=hidden_layer_size,
num_layers=layer_num,
batch_first=True,
)
output_dim = int(np.prod(action_shape))
self.mu = nn.Linear(hidden_layer_size, output_dim)
self._c_sigma = conditioned_sigma
if conditioned_sigma:
self.sigma = nn.Linear(hidden_layer_size, output_dim)
else:
self.sigma_param = nn.Parameter(torch.zeros(output_dim, 1))
self._max = max_action
self._unbounded = unbounded
def forward(
self,
obs: Union[np.ndarray, torch.Tensor],
state: Optional[Dict[str, torch.Tensor]] = None,
info: Dict[str, Any] = {},
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Dict[str, torch.Tensor]]:
"""Almost the same as :class:`~tianshou.utils.net.common.Recurrent`."""
obs = torch.as_tensor(
obs,
device=self.device, # type: ignore
dtype=torch.float32,
)
# obs [bsz, len, dim] (training) or [bsz, dim] (evaluation)
# In short, the tensor's shape in training phase is longer than which
# in evaluation phase.
if len(obs.shape) == 2:
obs = obs.unsqueeze(-2)
self.nn.flatten_parameters()
if state is None:
obs, (hidden, cell) = self.nn(obs)
else:
# we store the stack data in [bsz, len, ...] format
# but pytorch rnn needs [len, bsz, ...]
obs, (hidden, cell) = self.nn(
obs, (
state["hidden"].transpose(0, 1).contiguous(),
state["cell"].transpose(0, 1).contiguous()
)
)
logits = obs[:, -1]
mu = self.mu(logits)
if not self._unbounded:
mu = self._max * torch.tanh(mu)
if self._c_sigma:
sigma = torch.clamp(self.sigma(logits), min=SIGMA_MIN, max=SIGMA_MAX).exp()
else:
shape = [1] * len(mu.shape)
shape[1] = -1
sigma = (self.sigma_param.view(shape) + torch.zeros_like(mu)).exp()
# please ensure the first dim is batch size: [bsz, len, ...]
return (mu, sigma), {
"hidden": hidden.transpose(0, 1).detach(),
"cell": cell.transpose(0, 1).detach()
}


class RecurrentCritic(nn.Module):
    """Recurrent version of Critic.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.
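
    Example (a minimal usage sketch, not part of the original docstring; the
    critic expects stacked observations of shape [bsz, len, dim])::

        critic = RecurrentCritic(layer_num=1, state_shape=(4,), action_shape=(3,))
        obs = np.random.rand(2, 5, 4).astype(np.float32)  # [bsz, len, dim]
        act = np.random.rand(2, 3).astype(np.float32)
        q = critic(obs, act)  # shape (2, 1)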
"""
def __init__(
self,
layer_num: int,
state_shape: Sequence[int],
action_shape: Sequence[int] = [0],
device: Union[str, int, torch.device] = "cpu",
hidden_layer_size: int = 128,
) -> None:
super().__init__()
self.state_shape = state_shape
self.action_shape = action_shape
self.device = device
self.nn = nn.LSTM(
input_size=int(np.prod(state_shape)),
hidden_size=hidden_layer_size,
num_layers=layer_num,
batch_first=True,
)
self.fc2 = nn.Linear(hidden_layer_size + int(np.prod(action_shape)), 1)
def forward(
self,
obs: Union[np.ndarray, torch.Tensor],
act: Optional[Union[np.ndarray, torch.Tensor]] = None,
info: Dict[str, Any] = {},
) -> torch.Tensor:
"""Almost the same as :class:`~tianshou.utils.net.common.Recurrent`."""
obs = torch.as_tensor(
obs,
device=self.device, # type: ignore
dtype=torch.float32,
)
# obs [bsz, len, dim] (training) or [bsz, dim] (evaluation)
# In short, the tensor's shape in training phase is longer than which
# in evaluation phase.
assert len(obs.shape) == 3
self.nn.flatten_parameters()
obs, (hidden, cell) = self.nn(obs)
obs = obs[:, -1]
if act is not None:
act = torch.as_tensor(
act,
device=self.device, # type: ignore
dtype=torch.float32,
)
obs = torch.cat([obs, act], dim=1)
obs = self.fc2(obs)
return obs


class Perturbation(nn.Module):
    """Implementation of perturbation network in BCQ algorithm. Given a state and \
    action, it can generate a perturbed action.

    :param torch.nn.Module preprocess_net: a self-defined preprocess_net which
        outputs a flattened hidden state.
    :param float max_action: the maximum value of each dimension of action.
    :param Union[str, int, torch.device] device: which device to create this model on.
        Default to cpu.
    :param float phi: max perturbation parameter for BCQ. Default to 0.05.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.

    .. seealso::

        You can refer to `examples/offline/offline_bcq.py` to see how to use it.
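
    Example (a minimal usage sketch, not part of the original docstring; the
    preprocess ``Net`` must take state_dim + action_dim inputs and output
    action_dim logits)::

        from tianshou.utils.net.common import Net

        net = Net(state_shape=(4 + 3,), action_shape=(3,), hidden_sizes=(64, 64))
        perturb = Perturbation(net, max_action=1.0, phi=0.05)
        state = torch.rand(2, 4)
        action = torch.rand(2, 3) * 2 - 1
        new_action = perturb(state, action)  # stays within [-1, 1]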
"""
def __init__(
self,
preprocess_net: nn.Module,
max_action: float,
device: Union[str, int, torch.device] = "cpu",
phi: float = 0.05
):
# preprocess_net: input_dim=state_dim+action_dim, output_dim=action_dim
super(Perturbation, self).__init__()
self.preprocess_net = preprocess_net
self.device = device
self.max_action = max_action
self.phi = phi
def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
# preprocess_net
logits = self.preprocess_net(torch.cat([state, action], -1))[0]
noise = self.phi * self.max_action * torch.tanh(logits)
# clip to [-max_action, max_action]
return (noise + action).clamp(-self.max_action, self.max_action)


class VAE(nn.Module):
    """Implementation of VAE. It models the distribution of action. Given a \
    state, it can generate actions similar to those in batch. It is used \
    in the BCQ algorithm.

    :param torch.nn.Module encoder: the encoder in VAE. Its input_dim must be
        state_dim + action_dim, and output_dim must be hidden_dim.
    :param torch.nn.Module decoder: the decoder in VAE. Its input_dim must be
        state_dim + latent_dim, and output_dim must be action_dim.
    :param int hidden_dim: the size of the last linear-layer in encoder.
    :param int latent_dim: the size of latent layer.
    :param float max_action: the maximum value of each dimension of action.
    :param Union[str, torch.device] device: which device to create this model on.
        Default to "cpu".

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.

    .. seealso::

        You can refer to `examples/offline/offline_bcq.py` to see how to use it.
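
    Example (a minimal usage sketch, not part of the original docstring; the
    encoder and decoder here are small MLPs satisfying the dimension
    requirements above)::

        from tianshou.utils.net.common import MLP

        state_dim, action_dim, hidden_dim, latent_dim = 4, 3, 64, 6
        vae = VAE(
            encoder=MLP(state_dim + action_dim, hidden_dim, (64,)),
            decoder=MLP(state_dim + latent_dim, action_dim, (64,)),
            hidden_dim=hidden_dim,
            latent_dim=latent_dim,
            max_action=1.0,
        )
        state = torch.rand(2, state_dim)
        action = torch.rand(2, action_dim) * 2 - 1
        recon, mean, std = vae(state, action)  # reconstruction + posterior stats
        sampled = vae.decode(state)            # decode actions sampled from prior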
"""
def __init__(
self,
encoder: nn.Module,
decoder: nn.Module,
hidden_dim: int,
latent_dim: int,
max_action: float,
device: Union[str, torch.device] = "cpu"
):
super(VAE, self).__init__()
self.encoder = encoder
self.mean = nn.Linear(hidden_dim, latent_dim)
self.log_std = nn.Linear(hidden_dim, latent_dim)
self.decoder = decoder
self.max_action = max_action
self.latent_dim = latent_dim
self.device = device
def forward(
self, state: torch.Tensor, action: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# [state, action] -> z , [state, z] -> action
latent_z = self.encoder(torch.cat([state, action], -1))
# shape of z: (state.shape[:-1], hidden_dim)
mean = self.mean(latent_z)
# Clamped for numerical stability
log_std = self.log_std(latent_z).clamp(-4, 15)
std = torch.exp(log_std)
# shape of mean, std: (state.shape[:-1], latent_dim)
latent_z = mean + std * torch.randn_like(std) # (state.shape[:-1], latent_dim)
reconstruction = self.decode(state, latent_z) # (state.shape[:-1], action_dim)
return reconstruction, mean, std
def decode(
self,
state: torch.Tensor,
latent_z: Union[torch.Tensor, None] = None
) -> torch.Tensor:
# decode(state) -> action
if latent_z is None:
# state.shape[0] may be batch_size
# latent vector clipped to [-0.5, 0.5]
latent_z = torch.randn(state.shape[:-1] + (self.latent_dim, )) \
.to(self.device).clamp(-0.5, 0.5)
# decode z with state!
return self.max_action * \
torch.tanh(self.decoder(torch.cat([state, latent_z], -1)))