Tianshou/tianshou/utils/net/discrete.py
2021-05-29 09:44:23 +08:00

205 lines
7.5 KiB
Python

import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
from typing import Any, Dict, Tuple, Union, Optional, Sequence
from tianshou.utils.net.common import MLP
class Actor(nn.Module):
    """Simple actor network.

    Will create an actor operated in discrete action space with structure of
    preprocess_net ---> action_shape.

    :param preprocess_net: a self-defined preprocess_net which output a
        flattened hidden state.
    :param action_shape: a sequence of int for the shape of action.
    :param hidden_sizes: a sequence of int for constructing the MLP after
        preprocess_net. Default to empty sequence (where the MLP now contains
        only a single linear layer).
    :param bool softmax_output: whether to apply a softmax layer over the last
        layer's output.
    :param int preprocess_net_output_dim: the output dimension of
        preprocess_net. Only used when ``preprocess_net`` has no
        ``output_dim`` attribute; otherwise that attribute takes precedence.
    :param device: the device passed to the MLP head. Default to "cpu".

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.

    .. seealso::

        Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
        of how preprocess_net is suggested to be defined.
    """

    def __init__(
        self,
        preprocess_net: nn.Module,
        action_shape: Sequence[int],
        hidden_sizes: Sequence[int] = (),
        softmax_output: bool = True,
        preprocess_net_output_dim: Optional[int] = None,
        device: Union[str, int, torch.device] = "cpu",
    ) -> None:
        super().__init__()
        self.device = device
        self.preprocess = preprocess_net
        # Flatten a (possibly multi-dimensional) action shape into a single
        # output axis; cast to a plain Python int.
        self.output_dim = int(np.prod(action_shape))
        # An ``output_dim`` attribute on preprocess_net wins over the
        # explicit ``preprocess_net_output_dim`` argument.
        input_dim = getattr(preprocess_net, "output_dim",
                            preprocess_net_output_dim)
        self.last = MLP(input_dim, self.output_dim,
                        hidden_sizes, device=self.device)
        self.softmax_output = softmax_output

    def forward(
        self,
        s: Union[np.ndarray, torch.Tensor],
        state: Any = None,
        info: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, Any]:
        """Mapping: s -> logits (or action probabilities).

        :param s: the observation batch fed to ``preprocess_net``.
        :param state: the hidden state forwarded to ``preprocess_net``.
        :param info: not used by this method; accepted for interface
            compatibility. Default to None (avoids a shared mutable default).
        :return: a tuple ``(logits, h)``. ``logits`` is the MLP head output,
            passed through a softmax over the last axis when
            ``softmax_output`` is True. ``h`` is the state returned by
            ``preprocess_net``.
        """
        logits, h = self.preprocess(s, state)
        logits = self.last(logits)
        if self.softmax_output:
            logits = F.softmax(logits, dim=-1)
        return logits, h
class Critic(nn.Module):
    """Simple critic network. Will create a critic operated in discrete \
    action space with structure of preprocess_net ---> 1(q value).

    :param preprocess_net: a self-defined preprocess_net which output a
        flattened hidden state.
    :param hidden_sizes: a sequence of int for constructing the MLP after
        preprocess_net. Default to empty sequence (where the MLP now contains
        only a single linear layer).
    :param int last_size: the output dimension of Critic network. Default to 1.
    :param int preprocess_net_output_dim: the output dimension of
        preprocess_net. Only used when ``preprocess_net`` has no
        ``output_dim`` attribute; otherwise that attribute takes precedence.
    :param device: the device passed to the MLP head. Default to "cpu".

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.

    .. seealso::

        Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
        of how preprocess_net is suggested to be defined.
    """

    def __init__(
        self,
        preprocess_net: nn.Module,
        hidden_sizes: Sequence[int] = (),
        last_size: int = 1,
        preprocess_net_output_dim: Optional[int] = None,
        device: Union[str, int, torch.device] = "cpu",
    ) -> None:
        super().__init__()
        self.device = device
        self.preprocess = preprocess_net
        self.output_dim = last_size
        # An ``output_dim`` attribute on preprocess_net wins over the
        # explicit ``preprocess_net_output_dim`` argument.
        input_dim = getattr(preprocess_net, "output_dim",
                            preprocess_net_output_dim)
        self.last = MLP(input_dim, last_size,
                        hidden_sizes, device=self.device)

    def forward(
        self, s: Union[np.ndarray, torch.Tensor], **kwargs: Any
    ) -> torch.Tensor:
        """Mapping: s -> V(s).

        :param s: the observation batch fed to ``preprocess_net``.
        :keyword state: optional hidden state forwarded to ``preprocess_net``
            (None when absent). Other keyword arguments are ignored.
        :return: the MLP head output. The state returned by
            ``preprocess_net`` is discarded.
        """
        logits, _ = self.preprocess(s, state=kwargs.get("state", None))
        return self.last(logits)
class CosineEmbeddingNetwork(nn.Module):
    """Cosine embedding network for IQN.

    Embeds every quantile fraction ``tau`` as
    ``ReLU(Linear([cos(pi * i * tau) for i in 1..num_cosines]))``, turning a
    ``(batch_size, N)`` tensor of fractions into a
    ``(batch_size, N, embedding_dim)`` tensor.

    :param num_cosines: the number of cosines used for the embedding.
    :param embedding_dim: the dimension of the embedding/output.

    .. note::

        From https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/blob/master
        /fqf_iqn_qrdqn/network.py .
    """

    def __init__(self, num_cosines: int, embedding_dim: int) -> None:
        super().__init__()
        linear = nn.Linear(in_features=num_cosines, out_features=embedding_dim)
        self.net = nn.Sequential(linear, nn.ReLU())
        self.num_cosines = num_cosines
        self.embedding_dim = embedding_dim

    def forward(self, taus: torch.Tensor) -> torch.Tensor:
        """Map fractions ``taus`` (batch_size, N) to (batch_size, N, dim)."""
        batch_size, n_taus = taus.shape[0], taus.shape[1]
        # Frequencies pi * i for i = 1..num_cosines, laid out on the last axis
        # so they broadcast against one tau per (batch, sample) position.
        freqs = torch.arange(
            start=1, end=self.num_cosines + 1,
            dtype=taus.dtype, device=taus.device,
        )
        freqs = (np.pi * freqs).view(1, 1, self.num_cosines)
        phases = taus.view(batch_size, n_taus, 1) * freqs
        # Flatten so the linear layer sees one row of cosines per tau.
        flat_cosines = torch.cos(phases).view(
            batch_size * n_taus, self.num_cosines
        )
        embedded = self.net(flat_cosines)
        return embedded.view(batch_size, n_taus, self.embedding_dim)
class ImplicitQuantileNetwork(Critic):
    """Implicit Quantile Network.

    :param preprocess_net: a self-defined preprocess_net which output a
        flattened hidden state.
    :param action_shape: a sequence of int for the shape of action. The
        network outputs ``prod(action_shape)`` values per quantile sample.
    :param hidden_sizes: a sequence of int for constructing the MLP after
        preprocess_net. Default to empty sequence (where the MLP now contains
        only a single linear layer).
    :param int num_cosines: the number of cosines to use for cosine embedding.
        Default to 64.
    :param int preprocess_net_output_dim: the output dimension of
        preprocess_net. Only used when ``preprocess_net`` has no
        ``output_dim`` attribute.
    :param device: the device for the MLP head and the cosine embedding
        model. Default to "cpu".

    .. note::

        Although this class inherits Critic, it is actually a quantile Q-Network
        with output shape (batch_size, action_dim, sample_size).

        The second item of the first return value is tau vector.
    """

    def __init__(
        self,
        preprocess_net: nn.Module,
        action_shape: Sequence[int],
        hidden_sizes: Sequence[int] = (),
        num_cosines: int = 64,
        preprocess_net_output_dim: Optional[int] = None,
        device: Union[str, int, torch.device] = "cpu"
    ) -> None:
        # Cast to a plain int (as Actor does) so ``self.output_dim`` is not a
        # numpy integer scalar.
        last_size = int(np.prod(action_shape))
        super().__init__(preprocess_net, hidden_sizes, last_size,
                         preprocess_net_output_dim, device)
        self.input_dim = getattr(preprocess_net, "output_dim",
                                 preprocess_net_output_dim)
        # The embedding width equals the preprocess output width so the two
        # can be multiplied elementwise in forward().
        self.embed_model = CosineEmbeddingNetwork(num_cosines,
                                                  self.input_dim).to(device)

    def forward(  # type: ignore
        self, s: Union[np.ndarray, torch.Tensor], sample_size: int, **kwargs: Any
    ) -> Tuple[Any, torch.Tensor]:
        r"""Mapping: s -> Q(s, \*).

        :param s: the observation batch fed to ``preprocess_net``.
        :param int sample_size: the number of quantile fractions ``tau`` drawn
            (uniformly, via ``torch.rand``) for each batch element.
        :keyword state: optional hidden state forwarded to ``preprocess_net``.
        :return: ``((out, taus), h)``. ``out`` has shape
            (batch_size, action_dim, sample_size), ``taus`` has shape
            (batch_size, sample_size), and ``h`` is the state returned by
            ``preprocess_net``.
        """
        logits, h = self.preprocess(s, state=kwargs.get("state", None))
        # Sample fractions.
        batch_size = logits.size(0)
        taus = torch.rand(batch_size, sample_size,
                          dtype=logits.dtype, device=logits.device)
        # Assumes logits is (batch_size, input_dim); it is broadcast against
        # the (batch_size, sample_size, input_dim) tau embeddings.
        embedding = (logits.unsqueeze(1) * self.embed_model(taus)).view(
            batch_size * sample_size, -1
        )
        # (batch * samples, action_dim) -> (batch, action_dim, samples)
        out = self.last(embedding).view(batch_size,
                                        sample_size, -1).transpose(1, 2)
        return (out, taus), h