2023-09-05 23:34:23 +02:00
|
|
|
from collections.abc import Callable, Sequence
|
|
|
|
from typing import Any
|
2021-09-03 05:05:04 +08:00
|
|
|
|
2021-01-20 16:54:13 +08:00
|
|
|
import numpy as np
|
2021-09-03 05:05:04 +08:00
|
|
|
import torch
|
2021-01-20 16:54:13 +08:00
|
|
|
from torch import nn
|
2021-09-03 05:05:04 +08:00
|
|
|
|
2023-09-28 20:07:52 +02:00
|
|
|
from tianshou.highlevel.env import Environments
|
|
|
|
from tianshou.highlevel.module.actor import ActorFactory
|
2023-10-11 15:31:38 +02:00
|
|
|
from tianshou.highlevel.module.core import (
|
2023-10-16 18:19:31 +02:00
|
|
|
TDevice,
|
|
|
|
)
|
|
|
|
from tianshou.highlevel.module.intermediate import (
|
2023-10-11 15:31:38 +02:00
|
|
|
IntermediateModule,
|
|
|
|
IntermediateModuleFactory,
|
|
|
|
)
|
2024-04-03 18:07:51 +02:00
|
|
|
from tianshou.utils.net.common import NetBase
|
2023-09-28 20:07:52 +02:00
|
|
|
from tianshou.utils.net.discrete import Actor, NoisyLinear
|
2021-01-20 16:54:13 +08:00
|
|
|
|
|
|
|
|
2023-08-25 23:40:56 +02:00
|
|
|
def layer_init(layer: nn.Module, std: float = np.sqrt(2), bias_const: float = 0.0) -> nn.Module:
    """Initialize ``layer`` in place with orthogonal weights and a constant bias.

    :param layer: a module exposing a ``weight`` tensor and optionally a ``bias``
        tensor (e.g. ``nn.Linear`` or ``nn.Conv2d``).
    :param std: gain passed to :func:`torch.nn.init.orthogonal_`.
    :param bias_const: value every bias entry is set to.
    :return: the same ``layer`` object, so the call can be used inline.
    """
    torch.nn.init.orthogonal_(layer.weight, std)
    # Layers created with ``bias=False`` have ``bias is None``; there is nothing
    # to initialize then, so skip instead of crashing inside ``constant_``.
    bias = getattr(layer, "bias", None)
    if bias is not None:
        torch.nn.init.constant_(bias, bias_const)
    return layer
|
|
|
|
|
|
|
|
|
2024-01-11 12:34:26 +01:00
|
|
|
class ScaledObsInputModule(torch.nn.Module):
    """Wrapper that divides observations by a constant before delegating to a network.

    :param module: the wrapped network; its ``output_dim`` is mirrored on the wrapper.
    :param denom: divisor applied to every observation (defaults to 255.0).
    """

    def __init__(self, module: NetBase, denom: float = 255.0) -> None:
        super().__init__()
        self.module = module
        self.denom = denom
        # Mirror the wrapped network's output size: downstream modules read
        # ``output_dim`` from whatever they are given (see usages of get_output_dim).
        self.output_dim = module.output_dim

    def forward(
        self,
        obs: np.ndarray | torch.Tensor,
        state: Any | None = None,
        info: dict[str, Any] | None = None,
    ) -> tuple[torch.Tensor, Any]:
        """Scale ``obs`` by ``1 / denom`` and run the wrapped module on it.

        A missing ``info`` is replaced by an empty dict before delegating.
        """
        scaled_obs = obs / self.denom
        info_dict = {} if info is None else info
        return self.module.forward(scaled_obs, state, info_dict)
|
|
|
|
|
2022-12-04 12:23:18 -08:00
|
|
|
|
2024-04-03 18:07:51 +02:00
|
|
|
def scale_obs(module: NetBase, denom: float = 255.0) -> ScaledObsInputModule:
    """Wrap ``module`` so that its observations are divided by ``denom`` first.

    :param module: network to wrap.
    :param denom: divisor for the observations (defaults to 255.0).
    :return: a new :class:`ScaledObsInputModule` around ``module``.
    """
    wrapped = ScaledObsInputModule(module, denom=denom)
    return wrapped
|
2022-12-04 12:23:18 -08:00
|
|
|
|
|
|
|
|
2024-04-03 18:07:51 +02:00
|
|
|
class DQN(NetBase[Any]):
    """Reference: Human-level control through deep reinforcement learning.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.

    Three convolution + ReLU stages followed by a flatten. Unless
    ``features_only`` is set, a 512-unit hidden layer and a linear Q-value
    head with ``prod(action_shape)`` outputs are appended.

    :param c: number of input channels.
    :param h: input height.
    :param w: input width.
    :param action_shape: shape (or size) of the action space; the Q-value head
        has ``prod(action_shape)`` outputs.
    :param device: device onto which observations are moved in :meth:`forward`.
    :param features_only: if True, omit the Q-value head and output the
        flattened convolutional features instead.
    :param output_dim_added_layer: only allowed with ``features_only=True``; if
        given, a linear layer + ReLU mapping the convolutional features to this
        many outputs is appended.
    :param layer_init: function applied to every conv/linear layer on creation
        (identity by default).
    :raises ValueError: if ``output_dim_added_layer`` is given while
        ``features_only`` is False (it would otherwise be silently ignored).
    """

    def __init__(
        self,
        c: int,
        h: int,
        w: int,
        action_shape: Sequence[int] | int,
        device: str | int | torch.device = "cpu",
        features_only: bool = False,
        output_dim_added_layer: int | None = None,
        layer_init: Callable[[nn.Module], nn.Module] = lambda x: x,
    ) -> None:
        # The extra layer is only built in the ``features_only`` branch below;
        # reject the combination where it would have no effect.
        if not features_only and output_dim_added_layer is not None:
            raise ValueError(
                "Should not provide explicit output dimension using `output_dim_added_layer` when `features_only` is False.",
            )
        super().__init__()
        self.device = device
        self.net = nn.Sequential(
            layer_init(nn.Conv2d(c, 32, kernel_size=8, stride=4)),
            nn.ReLU(inplace=True),
            layer_init(nn.Conv2d(32, 64, kernel_size=4, stride=2)),
            nn.ReLU(inplace=True),
            layer_init(nn.Conv2d(64, 64, kernel_size=3, stride=1)),
            nn.ReLU(inplace=True),
            nn.Flatten(),
        )
        with torch.no_grad():
            # Probe with a dummy batch of one to find the flattened feature size.
            base_cnn_output_dim = int(np.prod(self.net(torch.zeros(1, c, h, w)).shape[1:]))
        if not features_only:
            action_dim = int(np.prod(action_shape))
            self.net = nn.Sequential(
                self.net,
                layer_init(nn.Linear(base_cnn_output_dim, 512)),
                nn.ReLU(inplace=True),
                layer_init(nn.Linear(512, action_dim)),
            )
            self.output_dim = action_dim
        elif output_dim_added_layer is not None:
            self.net = nn.Sequential(
                self.net,
                layer_init(nn.Linear(base_cnn_output_dim, output_dim_added_layer)),
                nn.ReLU(inplace=True),
            )
            # Downstream consumers read ``output_dim``; it must reflect the added layer.
            self.output_dim = output_dim_added_layer
        else:
            self.output_dim = base_cnn_output_dim

    def forward(
        self,
        obs: np.ndarray | torch.Tensor,
        state: Any | None = None,
        info: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> tuple[torch.Tensor, Any]:
        r"""Mapping: s -> Q(s, \*).

        ``obs`` is converted to a float32 tensor on ``self.device``; ``state`` is
        returned unchanged and ``info`` is ignored.
        """
        obs = torch.as_tensor(obs, device=self.device, dtype=torch.float32)
        return self.net(obs), state
|
2021-01-20 16:54:13 +08:00
|
|
|
|
|
|
|
|
|
|
|
class C51(DQN):
    """Reference: A distributional perspective on reinforcement learning.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.

    Reuses the :class:`DQN` architecture with ``action_num * num_atoms`` outputs,
    reshaped to ``(-1, action_num, num_atoms)`` and softmax-normalized over atoms.
    """

    def __init__(
        self,
        c: int,
        h: int,
        w: int,
        action_shape: Sequence[int],
        num_atoms: int = 51,
        device: str | int | torch.device = "cpu",
    ) -> None:
        n_actions = int(np.prod(action_shape))
        self.action_num = n_actions
        super().__init__(c, h, w, [n_actions * num_atoms], device)
        self.num_atoms = num_atoms

    def forward(
        self,
        obs: np.ndarray | torch.Tensor,
        state: Any | None = None,
        info: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> tuple[torch.Tensor, Any]:
        r"""Mapping: x -> Z(x, \*)."""
        flat_logits, state = super().forward(obs)
        atom_probs = flat_logits.view(-1, self.num_atoms).softmax(dim=-1)
        return atom_probs.view(-1, self.action_num, self.num_atoms), state
|
2021-01-28 09:27:05 +08:00
|
|
|
|
|
|
|
|
2021-08-29 08:34:59 -07:00
|
|
|
class Rainbow(DQN):
    """Reference: Rainbow: Combining Improvements in Deep Reinforcement Learning.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.

    The :class:`DQN` convolutional trunk (``features_only=True``) feeds a head
    ``Q`` with ``action_num * num_atoms`` outputs and, if ``is_dueling``, a head
    ``V`` with ``num_atoms`` outputs, combined as ``Q - mean(Q) + V``. Heads use
    :class:`NoisyLinear` when ``is_noisy`` and ``nn.Linear`` otherwise.
    """

    def __init__(
        self,
        c: int,
        h: int,
        w: int,
        action_shape: Sequence[int],
        num_atoms: int = 51,
        noisy_std: float = 0.5,
        device: str | int | torch.device = "cpu",
        is_dueling: bool = True,
        is_noisy: bool = True,
    ) -> None:
        super().__init__(c, h, w, action_shape, device, features_only=True)
        self.action_num = int(np.prod(action_shape))
        self.num_atoms = num_atoms

        def make_linear(in_features: int, out_features: int) -> NoisyLinear | nn.Linear:
            # Choose the layer type once per call according to ``is_noisy``.
            return (
                NoisyLinear(in_features, out_features, noisy_std)
                if is_noisy
                else nn.Linear(in_features, out_features)
            )

        trunk_dim = self.output_dim
        self.Q = nn.Sequential(
            make_linear(trunk_dim, 512),
            nn.ReLU(inplace=True),
            make_linear(512, self.action_num * self.num_atoms),
        )
        self._is_dueling = is_dueling
        if is_dueling:
            self.V = nn.Sequential(
                make_linear(trunk_dim, 512),
                nn.ReLU(inplace=True),
                make_linear(512, self.num_atoms),
            )
        self.output_dim = self.action_num * self.num_atoms

    def forward(
        self,
        obs: np.ndarray | torch.Tensor,
        state: Any | None = None,
        info: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> tuple[torch.Tensor, Any]:
        r"""Mapping: x -> Z(x, \*)."""
        features, state = super().forward(obs)
        q_atoms = self.Q(features).view(-1, self.action_num, self.num_atoms)
        if not self._is_dueling:
            logits = q_atoms
        else:
            v_atoms = self.V(features).view(-1, 1, self.num_atoms)
            logits = q_atoms - q_atoms.mean(dim=1, keepdim=True) + v_atoms
        return logits.softmax(dim=2), state
|
2021-08-29 08:34:59 -07:00
|
|
|
|
|
|
|
|
2021-01-28 09:27:05 +08:00
|
|
|
class QRDQN(DQN):
    """Reference: Distributional Reinforcement Learning with Quantile Regression.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.

    Reuses the :class:`DQN` architecture with ``action_num * num_quantiles``
    outputs, reshaped to ``(-1, action_num, num_quantiles)``.
    """

    def __init__(
        self,
        *,
        c: int,
        h: int,
        w: int,
        action_shape: Sequence[int] | int,
        num_quantiles: int = 200,
        device: str | int | torch.device = "cpu",
    ) -> None:
        n_actions = int(np.prod(action_shape))
        self.action_num = n_actions
        super().__init__(c, h, w, [n_actions * num_quantiles], device)
        self.num_quantiles = num_quantiles

    def forward(
        self,
        obs: np.ndarray | torch.Tensor,
        state: Any | None = None,
        info: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> tuple[torch.Tensor, Any]:
        r"""Mapping: x -> Z(x, \*)."""
        flat_out, state = super().forward(obs)
        quantiles = flat_out.view(-1, self.action_num, self.num_quantiles)
        return quantiles, state
|
2023-09-28 20:07:52 +02:00
|
|
|
|
|
|
|
|
|
|
|
class ActorFactoryAtariDQN(ActorFactory):
    """Actor factory building a discrete :class:`Actor` on top of the Atari :class:`DQN` network.

    :param scale_obs: if True, wrap the network with :func:`scale_obs` so that
        observations are divided by a constant before the convolutional layers.
    :param features_only: forwarded to :class:`DQN`.
    :param output_dim_added_layer: forwarded to :class:`DQN`.
    """

    def __init__(
        self,
        scale_obs: bool = True,
        features_only: bool = False,
        output_dim_added_layer: int | None = None,
    ) -> None:
        self.output_dim_added_layer = output_dim_added_layer
        self.scale_obs = scale_obs
        self.features_only = features_only

    def create_module(self, envs: Environments, device: TDevice) -> Actor:
        c, h, w = envs.get_observation_shape()  # type: ignore # only right shape is a sequence of length 3
        action_shape = envs.get_action_shape()
        # Environments may report the action count as a numpy integer scalar
        # (of any width); normalize it to a plain int.
        if isinstance(action_shape, np.integer):
            action_shape = int(action_shape)
        net: DQN | ScaledObsInputModule
        net = DQN(
            c=c,
            h=h,
            w=w,
            action_shape=action_shape,
            device=device,
            features_only=self.features_only,
            output_dim_added_layer=self.output_dim_added_layer,
            layer_init=layer_init,
        )
        if self.scale_obs:
            # Module-level ``scale_obs`` function, not the boolean attribute.
            net = scale_obs(net)
        # Reuse the normalized shape so the actor head and the network agree.
        return Actor(net, action_shape, device=device, softmax_output=False).to(device)
|
|
|
|
|
|
|
|
|
2023-10-11 15:31:38 +02:00
|
|
|
class IntermediateModuleFactoryAtariDQN(IntermediateModuleFactory):
    """Factory for intermediate modules built from the Atari :class:`DQN` network.

    :param features_only: forwarded to :class:`DQN`.
    :param net_only: if True, the intermediate module is ``DQN.net`` itself,
        i.e. without :meth:`DQN.forward`, which converts observations to float32
        tensors on the device.
    """

    def __init__(self, features_only: bool = False, net_only: bool = False) -> None:
        self.features_only = features_only
        self.net_only = net_only

    def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule:
        """Build the :class:`DQN` for ``envs`` and wrap it with its output dimension.

        :raises ValueError: if the observation shape is not three-dimensional.
        """
        obs_shape = envs.get_observation_shape()
        if isinstance(obs_shape, int | np.integer):
            obs_shape = [obs_shape]
        # Explicit check instead of ``assert``, which is stripped under ``python -O``.
        if len(obs_shape) != 3:
            raise ValueError(
                f"Expected a 3-dimensional (c, h, w) observation shape, got {obs_shape}",
            )
        c, h, w = obs_shape
        action_shape = envs.get_action_shape()
        # Normalize numpy integer scalars (of any width) to a plain int.
        if isinstance(action_shape, np.integer):
            action_shape = int(action_shape)
        dqn = DQN(
            c=c,
            h=h,
            w=w,
            action_shape=action_shape,
            device=device,
            features_only=self.features_only,
        ).to(device)
        module = dqn.net if self.net_only else dqn
        return IntermediateModule(module, dqn.output_dim)
|
|
|
|
|
|
|
|
|
|
|
|
class IntermediateModuleFactoryAtariDQNFeatures(IntermediateModuleFactoryAtariDQN):
    """Shortcut factory exposing only the convolutional feature extractor of :class:`DQN`.

    Equivalent to ``IntermediateModuleFactoryAtariDQN(features_only=True, net_only=True)``.
    """

    def __init__(self) -> None:
        super().__init__(net_only=True, features_only=True)
|