import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
from typing import Any, Dict, Tuple, Union, Optional, Sequence

from tianshou.utils.net.common import MLP


class Actor(nn.Module):
    """Simple actor network.

    Will create an actor operated in a discrete action space with structure of
    preprocess_net ---> action_shape.

    :param preprocess_net: a self-defined preprocess_net which outputs a
        flattened hidden state.
    :param action_shape: a sequence of int for the shape of the action.
    :param hidden_sizes: a sequence of int for constructing the MLP after
        preprocess_net. Defaults to an empty sequence (where the MLP contains
        only a single linear layer).
    :param bool softmax_output: whether to apply a softmax layer over the last
        layer's output.
    :param int preprocess_net_output_dim: the output dimension of
        preprocess_net; only needed when preprocess_net does not expose an
        ``output_dim`` attribute.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.

    .. seealso::

        Please refer to :class:`~tianshou.utils.net.common.Net` as an example
        of how preprocess_net is suggested to be defined.
    """

    def __init__(
        self,
        preprocess_net: nn.Module,
        action_shape: Sequence[int],
        hidden_sizes: Sequence[int] = (),
        softmax_output: bool = True,
        preprocess_net_output_dim: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.preprocess = preprocess_net
        self.output_dim = int(np.prod(action_shape))
        # Prefer the output_dim attribute exposed by preprocess_net; fall back
        # to the explicitly given preprocess_net_output_dim otherwise.
        input_dim = getattr(preprocess_net, "output_dim",
                            preprocess_net_output_dim)
        self.last = MLP(input_dim, self.output_dim, hidden_sizes)
        self.softmax_output = softmax_output

    def forward(
        self,
        s: Union[np.ndarray, torch.Tensor],
        state: Optional[Any] = None,
        info: Dict[str, Any] = {},
    ) -> Tuple[torch.Tensor, Any]:
        r"""Mapping: s -> Q(s, \*)."""
        logits, h = self.preprocess(s, state)
        logits = self.last(logits)
        if self.softmax_output:
            logits = F.softmax(logits, dim=-1)
        return logits, h
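

# A minimal usage sketch for Actor (illustrative comment only, not executed).
# It assumes tianshou.utils.net.common.Net as the preprocess network, taking
# state_shape and hidden_sizes arguments; the observation dimension (4),
# number of actions (3) and batch size (8) are toy values for the example:
#
#     from tianshou.utils.net.common import Net
#     net = Net(state_shape=4, hidden_sizes=(64, 64))
#     actor = Actor(net, action_shape=3)
#     obs = np.random.rand(8, 4)      # a batch of 8 observations
#     probs, hidden = actor(obs)      # probs: shape (8, 3), each row sums to 1
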
class Critic(nn.Module):
    """Simple critic network. Will create a critic operated in a discrete \
    action space with structure of preprocess_net ---> 1 (value).

    :param preprocess_net: a self-defined preprocess_net which outputs a
        flattened hidden state.
    :param hidden_sizes: a sequence of int for constructing the MLP after
        preprocess_net. Defaults to an empty sequence (where the MLP contains
        only a single linear layer).
    :param int last_size: the output dimension of the Critic network. Defaults
        to 1.
    :param int preprocess_net_output_dim: the output dimension of
        preprocess_net; only needed when preprocess_net does not expose an
        ``output_dim`` attribute.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.

    .. seealso::

        Please refer to :class:`~tianshou.utils.net.common.Net` as an example
        of how preprocess_net is suggested to be defined.
    """

    def __init__(
        self,
        preprocess_net: nn.Module,
        hidden_sizes: Sequence[int] = (),
        last_size: int = 1,
        preprocess_net_output_dim: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.preprocess = preprocess_net
        self.output_dim = last_size
        # Prefer the output_dim attribute exposed by preprocess_net; fall back
        # to the explicitly given preprocess_net_output_dim otherwise.
        input_dim = getattr(preprocess_net, "output_dim",
                            preprocess_net_output_dim)
        self.last = MLP(input_dim, last_size, hidden_sizes)

    def forward(
        self, s: Union[np.ndarray, torch.Tensor], **kwargs: Any
    ) -> torch.Tensor:
        """Mapping: s -> V(s)."""
        logits, _ = self.preprocess(s, state=kwargs.get("state", None))
        return self.last(logits)
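

# A minimal usage sketch for Critic (illustrative comment only, not executed).
# As above, it assumes tianshou.utils.net.common.Net as the preprocess network;
# the shapes are toy values for the example:
#
#     from tianshou.utils.net.common import Net
#     net = Net(state_shape=4, hidden_sizes=(64, 64))
#     critic = Critic(net)            # last_size=1 -> one value per state
#     obs = np.random.rand(8, 4)      # a batch of 8 observations
#     values = critic(obs)            # values: shape (8, 1), i.e. V(s)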