205 lines
7.5 KiB
Python
205 lines
7.5 KiB
Python
import torch
|
|
import numpy as np
|
|
from torch import nn
|
|
import torch.nn.functional as F
|
|
from typing import Any, Dict, Tuple, Union, Optional, Sequence
|
|
|
|
from tianshou.utils.net.common import MLP
|
|
|
|
|
|
class Actor(nn.Module):
    """Simple actor network.

    Will create an actor operated in discrete action space with structure of
    preprocess_net ---> action_shape.

    :param preprocess_net: a self-defined preprocess_net which output a
        flattened hidden state.
    :param action_shape: a sequence of int for the shape of action. The
        number of outputs is ``int(np.prod(action_shape))``.
    :param hidden_sizes: a sequence of int for constructing the MLP after
        preprocess_net. Default to empty sequence (where the MLP now contains
        only a single linear layer).
    :param bool softmax_output: whether to apply a softmax layer over the last
        layer's output.
    :param int preprocess_net_output_dim: the output dimension of
        preprocess_net. Only used when ``preprocess_net`` has no
        ``output_dim`` attribute.
    :param device: stored as ``self.device`` and passed to the MLP.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.

    .. seealso::

        Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
        of how preprocess_net is suggested to be defined.
    """

    def __init__(
        self,
        preprocess_net: nn.Module,
        action_shape: Sequence[int],
        hidden_sizes: Sequence[int] = (),
        softmax_output: bool = True,
        preprocess_net_output_dim: Optional[int] = None,
        device: Union[str, int, torch.device] = "cpu",
    ) -> None:
        super().__init__()
        self.device = device
        self.preprocess = preprocess_net
        self.output_dim = int(np.prod(action_shape))
        # Prefer the output size advertised by the preprocess net itself.
        input_dim = getattr(preprocess_net, "output_dim",
                            preprocess_net_output_dim)
        self.last = MLP(input_dim, self.output_dim,
                        hidden_sizes, device=self.device)
        self.softmax_output = softmax_output

    def forward(
        self,
        s: Union[np.ndarray, torch.Tensor],
        state: Any = None,
        info: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, Any]:
        r"""Mapping: s -> Q(s, \*).

        :param s: observation batch, forwarded unchanged to ``preprocess_net``.
        :param state: hidden state, forwarded unchanged to ``preprocess_net``.
        :param info: not used by this method. It defaults to ``None`` rather
            than a shared mutable ``{}``.
        :return: ``(logits, h)``. ``logits`` is the last layer's output,
            softmax-normalized over the last dimension when
            ``softmax_output`` is True. ``h`` is the hidden state returned by
            ``preprocess_net``.
        """
        logits, h = self.preprocess(s, state)
        logits = self.last(logits)
        if self.softmax_output:
            logits = F.softmax(logits, dim=-1)
        return logits, h
|
|
|
|
|
|
class Critic(nn.Module):
    """Simple critic network for discrete action spaces.

    Structure: preprocess_net ---> MLP ---> ``last_size`` outputs (1 by
    default, e.g. a single value estimate per sample).

    :param preprocess_net: a self-defined preprocess_net which output a
        flattened hidden state.
    :param hidden_sizes: a sequence of int for constructing the MLP after
        preprocess_net. Default to empty sequence (where the MLP now contains
        only a single linear layer).
    :param int last_size: the output dimension of Critic network. Default to 1.
    :param int preprocess_net_output_dim: the output dimension of
        preprocess_net. Only used when ``preprocess_net`` has no
        ``output_dim`` attribute.
    :param device: stored as ``self.device`` and passed to the MLP.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.

    .. seealso::

        Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
        of how preprocess_net is suggested to be defined.
    """

    def __init__(
        self,
        preprocess_net: nn.Module,
        hidden_sizes: Sequence[int] = (),
        last_size: int = 1,
        preprocess_net_output_dim: Optional[int] = None,
        device: Union[str, int, torch.device] = "cpu",
    ) -> None:
        super().__init__()
        self.device = device
        self.preprocess = preprocess_net
        self.output_dim = last_size
        # The preprocess net's own ``output_dim`` wins over the explicit arg.
        feature_dim = getattr(
            preprocess_net, "output_dim", preprocess_net_output_dim
        )
        self.last = MLP(feature_dim, last_size, hidden_sizes,
                        device=self.device)

    def forward(
        self, s: Union[np.ndarray, torch.Tensor], **kwargs: Any
    ) -> torch.Tensor:
        """Mapping: s -> V(s).

        :param s: observation batch, forwarded to ``preprocess_net``.
        :param kwargs: only ``state`` is read (default ``None``); it is passed
            to ``preprocess_net`` as the ``state`` keyword.
        :return: the MLP output computed from the preprocessed features.
        """
        hidden_state = kwargs.get("state")
        features, _ = self.preprocess(s, state=hidden_state)
        return self.last(features)
|
|
|
|
|
|
class CosineEmbeddingNetwork(nn.Module):
    """Cosine embedding network for IQN.

    Each scalar fraction ``tau`` is mapped to ``cos(i * pi * tau)`` for
    ``i = 1, ..., num_cosines``. The result is passed through a linear layer
    and a ReLU to give an ``embedding_dim``-dimensional vector.

    :param num_cosines: the number of cosines used for the embedding.
    :param embedding_dim: the dimension of the embedding/output.

    .. note::

        From https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/blob/master
        /fqf_iqn_qrdqn/network.py .
    """

    def __init__(self, num_cosines: int, embedding_dim: int) -> None:
        super().__init__()
        self.net = nn.Sequential(nn.Linear(num_cosines, embedding_dim), nn.ReLU())
        self.num_cosines = num_cosines
        self.embedding_dim = embedding_dim

    def forward(self, taus: torch.Tensor) -> torch.Tensor:
        """Embed a batch of fractions.

        :param taus: tensor viewed as ``(batch_size, N, 1)``, i.e. expected
            to hold ``batch_size * N`` fractions.
        :return: tensor of shape ``(batch_size, N, embedding_dim)``.
        """
        batch_size = taus.size(0)
        num_taus = taus.size(1)
        # Frequencies pi * i for i = 1..num_cosines, laid out to broadcast
        # against the trailing singleton axis of the reshaped taus.
        freqs = np.pi * torch.arange(
            1, self.num_cosines + 1, dtype=taus.dtype, device=taus.device
        ).view(1, 1, self.num_cosines)
        phases = taus.view(batch_size, num_taus, 1) * freqs
        # Flatten to 2-D so the linear layer sees one row per fraction.
        flat_cosines = phases.cos().view(batch_size * num_taus,
                                         self.num_cosines)
        embedded = self.net(flat_cosines)
        return embedded.view(batch_size, num_taus, self.embedding_dim)
|
|
|
|
|
|
class ImplicitQuantileNetwork(Critic):
    """Implicit Quantile Network.

    :param preprocess_net: a self-defined preprocess_net which output a
        flattened hidden state.
    :param action_shape: a sequence of int for the shape of action; the
        number of outputs per sampled fraction is
        ``int(np.prod(action_shape))``.
    :param hidden_sizes: a sequence of int for constructing the MLP after
        preprocess_net. Default to empty sequence (where the MLP now contains
        only a single linear layer).
    :param int num_cosines: the number of cosines to use for cosine embedding.
        Default to 64.
    :param int preprocess_net_output_dim: the output dimension of
        preprocess_net. Only used when ``preprocess_net`` has no
        ``output_dim`` attribute.
    :param device: passed to the parent Critic and used to place the cosine
        embedding model.

    .. note::

        Although this class inherits Critic, it is actually a quantile Q-Network
        with output shape (batch_size, action_dim, sample_size).

        The second item of the first return value is tau vector.
    """

    def __init__(
        self,
        preprocess_net: nn.Module,
        action_shape: Sequence[int],
        hidden_sizes: Sequence[int] = (),
        num_cosines: int = 64,
        preprocess_net_output_dim: Optional[int] = None,
        device: Union[str, int, torch.device] = "cpu"
    ) -> None:
        # Cast to a plain Python int, consistent with Actor. np.prod returns
        # a numpy scalar, and the float 1.0 for an empty shape, which would
        # otherwise leak into ``output_dim`` and the MLP's output size.
        last_size = int(np.prod(action_shape))
        super().__init__(preprocess_net, hidden_sizes, last_size,
                         preprocess_net_output_dim, device)
        self.input_dim = getattr(preprocess_net, "output_dim",
                                 preprocess_net_output_dim)
        # The embedding width equals the preprocess output width so the two
        # can be multiplied element-wise in ``forward``.
        self.embed_model = CosineEmbeddingNetwork(num_cosines,
                                                  self.input_dim).to(device)

    def forward(  # type: ignore
        self, s: Union[np.ndarray, torch.Tensor], sample_size: int, **kwargs: Any
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Any]:
        r"""Mapping: s -> Q(s, \*).

        :param s: observation batch, forwarded to ``preprocess_net``.
        :param sample_size: number of fractions ``tau`` drawn per batch
            element with ``torch.rand`` (uniform on [0, 1)).
        :param kwargs: only ``state`` is read (default ``None``) and passed to
            ``preprocess_net``.
        :return: ``((out, taus), h)``. ``out`` has shape
            ``(batch_size, action_dim, sample_size)``, ``taus`` has shape
            ``(batch_size, sample_size)``, and ``h`` is the hidden state
            returned by ``preprocess_net``.
        """
        logits, h = self.preprocess(s, state=kwargs.get("state", None))
        # Sample fractions.
        batch_size = logits.size(0)
        taus = torch.rand(batch_size, sample_size,
                          dtype=logits.dtype, device=logits.device)
        # Broadcast the state features over the sample axis and modulate them
        # with the tau embeddings. This assumes the preprocess output is
        # (batch_size, input_dim) -- TODO confirm for non-flat preprocess nets.
        embedding = (logits.unsqueeze(1) * self.embed_model(taus)).view(
            batch_size * sample_size, -1
        )
        # (batch*samples, actions) -> (batch, samples, actions)
        #                          -> (batch, actions, samples)
        out = self.last(embedding).view(batch_size,
                                        sample_size, -1).transpose(1, 2)
        return (out, taus), h
|