n+e 09692c84fe
fix numpy>=1.20 typing check (#323)
Change the behavior of to_numpy and to_torch: from now on, dict is automatically converted to Batch and list is automatically converted to np.ndarray (if an error occurs, raise the exception instead of converting each element in the list).
2021-03-30 16:06:03 +08:00

258 lines
10 KiB
Python

import torch
import numpy as np
from torch import nn
from typing import Any, Dict, List, Type, Tuple, Union, Optional, Sequence
# A module *class* (not an instance), e.g. ``nn.ReLU`` or ``nn.LayerNorm``.
ModuleType = Type[nn.Module]


def miniblock(
    input_size: int,
    output_size: int = 0,
    norm_layer: Optional[ModuleType] = None,
    activation: Optional[ModuleType] = None,
) -> List[nn.Module]:
    """Build one hidden-layer unit: ``Linear -> [norm] -> [activation]``.

    :param input_size: ``in_features`` of the linear layer.
    :param output_size: ``out_features`` of the linear layer; also the size
        handed to ``norm_layer``.
    :param norm_layer: optional normalization class, instantiated as
        ``norm_layer(output_size)``.
    :param activation: optional activation class, instantiated with no
        arguments.
    :return: a new list holding the created modules in forward order.
    """
    block: List[nn.Module] = [nn.Linear(input_size, output_size)]
    if norm_layer is not None:
        block.append(norm_layer(output_size))  # type: ignore
    if activation is not None:
        block.append(activation())
    return block
class MLP(nn.Module):
    """Simple MLP backbone.

    Create a MLP of size input_dim * hidden_sizes[0] * hidden_sizes[1] * ...
    * hidden_sizes[-1] * output_dim

    :param int input_dim: dimension of the input vector.
    :param int output_dim: dimension of the output vector. If set to 0, there
        is no final linear layer.
    :param hidden_sizes: shape of MLP passed in as a list, not including
        input_dim and output_dim.
    :param norm_layer: use which normalization before activation, e.g.,
        ``nn.LayerNorm`` and ``nn.BatchNorm1d``. Defaults to no normalization.
        You can also pass a list (or tuple) of normalization modules with the
        same length as hidden_sizes, to use a different normalization module
        in each layer.
    :param activation: which activation to use after each layer. Pass a single
        module class to use the same activation for all layers, or a list (or
        tuple) with the same length as hidden_sizes to use a different
        activation in each layer. Defaults to nn.ReLU.
    :param device: device given to ``torch.as_tensor`` when ``forward``
        converts its input. It does not move this module's parameters.
        Defaults to None.

    The width of the produced features is stored in ``self.output_dim``: it is
    ``output_dim`` when that is non-zero, otherwise the last hidden size (or
    ``input_dim`` when there are no hidden layers).
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int = 0,
        hidden_sizes: Sequence[int] = (),
        norm_layer: Optional[Union[ModuleType, Sequence[ModuleType]]] = None,
        activation: Optional[Union[ModuleType, Sequence[ModuleType]]] = nn.ReLU,
        device: Optional[Union[str, int, torch.device]] = None,
    ) -> None:
        super().__init__()
        self.device = device
        num_hidden = len(hidden_sizes)
        norm_layer_list = self._per_layer(norm_layer, num_hidden)
        activation_list = self._per_layer(activation, num_hidden)
        hidden_sizes = [input_dim] + list(hidden_sizes)
        model: List[nn.Module] = []
        for in_dim, out_dim, norm, activ in zip(
            hidden_sizes[:-1], hidden_sizes[1:], norm_layer_list, activation_list
        ):
            model += miniblock(in_dim, out_dim, norm, activ)
        if output_dim > 0:
            model += [nn.Linear(hidden_sizes[-1], output_dim)]
        self.output_dim = output_dim or hidden_sizes[-1]
        self.model = nn.Sequential(*model)

    @staticmethod
    def _per_layer(
        layer: Optional[Union[ModuleType, Sequence[ModuleType]]],
        num_layers: int,
    ) -> List[Optional[ModuleType]]:
        """Expand a per-layer module argument into a list of ``num_layers``.

        A falsy value (None or an empty sequence) means "no module" for every
        layer. A list or tuple is used element-wise and must match
        ``num_layers``. A single module class is repeated for every layer.
        """
        if not layer:
            return [None] * num_layers
        # Accept any list/tuple, as the ``Sequence[ModuleType]`` annotation
        # promises; a module class itself is never a list or tuple.
        if isinstance(layer, (list, tuple)):
            assert len(layer) == num_layers
            return list(layer)
        return [layer] * num_layers  # type: ignore

    def forward(
        self, x: Union[np.ndarray, torch.Tensor]
    ) -> torch.Tensor:
        """Convert ``x`` to float32, flatten dims 1.. and apply the MLP.

        ``x.flatten(1)`` keeps the first dimension (presumably the batch
        dimension), so ``x`` must have at least 2 dimensions.
        """
        x = torch.as_tensor(
            x, device=self.device, dtype=torch.float32)  # type: ignore
        return self.model(x.flatten(1))
class Net(nn.Module):
    """Wrapper of MLP to support more specific DRL usage.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.

    :param state_shape: int or a sequence of int of the shape of state.
    :param action_shape: int or a sequence of int of the shape of action.
    :param hidden_sizes: shape of MLP passed in as a list.
    :param norm_layer: use which normalization before activation, e.g.,
        ``nn.LayerNorm`` and ``nn.BatchNorm1d``. Defaults to no normalization.
        You can also pass a list of normalization modules with the same length
        as hidden_sizes, to use a different normalization module in each
        layer. The value is forwarded unchanged to the backbone MLP.
    :param activation: which activation to use after each layer. Pass a single
        module class to use the same activation for all layers, or a list with
        the same length as hidden_sizes to use a different activation in each
        layer. Defaults to nn.ReLU. Forwarded unchanged to the backbone MLP.
    :param device: device handed to the backbone MLP, which uses it when
        converting the input to a tensor. Defaults to "cpu".
    :param bool softmax: whether to apply a softmax over the last dimension of
        the output.
    :param bool concat: whether the input shape is concatenated by state_shape
        and action_shape. If it is True, ``action_shape`` is not the output
        shape, but affects the input shape only.
    :param int num_atoms: in order to expand to the net of distributional RL.
        Defaults to 1 (not used).
    :param dueling_param: whether to use dueling network to calculate Q
        values (for Dueling DQN). If you want to use the dueling option, pass
        a tuple of two dicts (first for Q and second for V) holding keyword
        arguments for :class:`~tianshou.utils.net.common.MLP`; their
        ``input_dim`` and ``output_dim`` entries are overwritten by this
        class. Defaults to None (no dueling).

    .. seealso::

        Please refer to :class:`~tianshou.utils.net.common.MLP` for more
        detailed explanation on the usage of activation, norm_layer, etc.

        You can also refer to :class:`~tianshou.utils.net.continuous.Actor`,
        :class:`~tianshou.utils.net.continuous.Critic`, etc, to see how it's
        suggested to be used.
    """

    def __init__(
        self,
        state_shape: Union[int, Sequence[int]],
        action_shape: Union[int, Sequence[int]] = 0,
        hidden_sizes: Sequence[int] = (),
        norm_layer: Optional[Union[ModuleType, Sequence[ModuleType]]] = None,
        activation: Optional[Union[ModuleType, Sequence[ModuleType]]] = nn.ReLU,
        device: Union[str, int, torch.device] = "cpu",
        softmax: bool = False,
        concat: bool = False,
        num_atoms: int = 1,
        dueling_param: Optional[Tuple[Dict[str, Any], Dict[str, Any]]] = None,
    ) -> None:
        super().__init__()
        self.device = device
        self.softmax = softmax
        self.num_atoms = num_atoms
        input_dim = int(np.prod(state_shape))
        action_dim = int(np.prod(action_shape)) * num_atoms
        if concat:
            # The action is part of the input, so it widens the input layer.
            input_dim += action_dim
        self.use_dueling = dueling_param is not None
        # Only the plain (non-dueling, non-concat) net ends in a linear layer
        # of width action_dim; otherwise the MLP outputs hidden features.
        output_dim = action_dim if not self.use_dueling and not concat else 0
        self.model = MLP(input_dim, output_dim, hidden_sizes,
                         norm_layer, activation, device)
        self.output_dim = self.model.output_dim
        if self.use_dueling:  # dueling DQN
            q_user_kwargs, v_user_kwargs = dueling_param  # type: ignore
            q_output_dim, v_output_dim = (
                (0, 0) if concat else (action_dim, num_atoms))
            q_kwargs: Dict[str, Any] = {
                **q_user_kwargs, "input_dim": self.output_dim,
                "output_dim": q_output_dim}
            v_kwargs: Dict[str, Any] = {
                **v_user_kwargs, "input_dim": self.output_dim,
                "output_dim": v_output_dim}
            self.Q, self.V = MLP(**q_kwargs), MLP(**v_kwargs)
            self.output_dim = self.Q.output_dim

    def forward(
        self,
        s: Union[np.ndarray, torch.Tensor],
        state: Any = None,
        info: Dict[str, Any] = {},
    ) -> Tuple[torch.Tensor, Any]:
        """Mapping: s -> flatten (inside MLP) -> logits.

        :param s: observation batch; converted and flattened by the backbone
            MLP.
        :param state: passed through unchanged.
        :param info: not used.
        :return: ``(logits, state)``. When ``num_atoms > 1`` the logits are
            reshaped to ``[bsz, -1, num_atoms]``.
        """
        logits = self.model(s)
        bsz = logits.shape[0]
        if self.use_dueling:  # Dueling DQN
            q, v = self.Q(logits), self.V(logits)
            if self.num_atoms > 1:
                q = q.view(bsz, -1, self.num_atoms)
                v = v.view(bsz, -1, self.num_atoms)
            # Dueling aggregation: centre the Q-stream over dim 1, then add
            # the V-stream (broadcast along dim 1).
            logits = q - q.mean(dim=1, keepdim=True) + v
        elif self.num_atoms > 1:
            logits = logits.view(bsz, -1, self.num_atoms)
        if self.softmax:
            logits = torch.softmax(logits, dim=-1)
        return logits, state
class Recurrent(nn.Module):
    """Simple Recurrent network based on LSTM.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.

    The observation is projected by a linear layer (``fc1``) to
    ``hidden_layer_size`` features and run through a ``layer_num``-layer LSTM
    (``nn``). The LSTM output at the last time step is mapped by ``fc2`` to
    ``prod(action_shape)`` values.
    """

    def __init__(
        self,
        layer_num: int,
        state_shape: Union[int, Sequence[int]],
        action_shape: Union[int, Sequence[int]],
        device: Union[str, int, torch.device] = "cpu",
        hidden_layer_size: int = 128,
    ) -> None:
        super().__init__()
        self.device = device
        obs_features = int(np.prod(state_shape))
        act_features = int(np.prod(action_shape))
        # Submodules are registered in the order nn -> fc1 -> fc2. That order
        # fixes both the state_dict layout and the order of random init.
        self.nn = nn.LSTM(
            input_size=hidden_layer_size,
            hidden_size=hidden_layer_size,
            num_layers=layer_num,
            batch_first=True,
        )
        self.fc1 = nn.Linear(obs_features, hidden_layer_size)
        self.fc2 = nn.Linear(hidden_layer_size, act_features)

    def forward(
        self,
        s: Union[np.ndarray, torch.Tensor],
        state: Optional[Dict[str, torch.Tensor]] = None,
        info: Dict[str, Any] = {},
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Mapping: s -> fc1 -> LSTM -> fc2 (last step) -> logits.

        ``s`` is ``[bsz, dim]`` in evaluation mode (treated as a length-1
        sequence) or ``[bsz, len, dim]`` in training mode. ``state``, when
        given, holds ``"h"`` and ``"c"`` with the batch dimension first, as
        returned by a previous call. ``info`` is not used.

        :return: the logits for the final time step, and a dict with the new
            detached ``"h"`` / ``"c"`` states, batch dimension first.
        """
        obs = torch.as_tensor(
            s, device=self.device, dtype=torch.float32)  # type: ignore
        if obs.dim() == 2:
            # Single step: insert a time axis of length 1.
            obs = obs.unsqueeze(-2)
        features = self.fc1(obs)
        self.nn.flatten_parameters()
        hx: Optional[Tuple[torch.Tensor, torch.Tensor]]
        if state is None:
            hx = None
        else:
            # PyTorch keeps LSTM states as (num_layers, bsz, hidden) even with
            # batch_first=True, while stored states are batch-first: swap back.
            hx = (
                state["h"].transpose(0, 1).contiguous(),
                state["c"].transpose(0, 1).contiguous(),
            )
        seq_out, (h_n, c_n) = self.nn(features, hx)
        logits = self.fc2(seq_out[:, -1])
        # Hand back states batch-first so callers can store them per sample.
        return logits, {
            "h": h_n.transpose(0, 1).detach(),
            "c": c_n.transpose(0, 1).detach(),
        }