Closes #914

Additional changes:

- Deprecate Python below 3.11
- Remove 3rd-party and throughput tests; this simplifies the install and test pipeline
- Remove gym compatibility and shimmy
- Format with 3.11 conventions; in particular, add `zip(..., strict=True/False)` where possible

Since the additional tests and gym were complicating the CI pipeline (flaky and distribution-dependent), it didn't make sense to fix the current tests in this PR only to delete them in the next one. So this PR changes the build and removes these tests at the same time.
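For context, the `zip` convention in question (illustrative snippet, not taken from the diff): `strict=True` turns a silent length mismatch into a hard error, while an explicit `strict=False` documents that truncating to the shortest iterable is intentional.

```python
episodes = ["ep1", "ep2", "ep3"]
returns = [1.0, 0.5]

# strict=True: raises ValueError because the lengths differ
# list(zip(episodes, returns, strict=True))

# strict=False: the old truncating behavior, now stated explicitly
pairs = list(zip(episodes, returns, strict=False))  # [("ep1", 1.0), ("ep2", 0.5)]
```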
from collections.abc import Callable, Sequence
from typing import Any

import numpy as np
import torch
from torch import nn

from tianshou.utils.net.discrete import NoisyLinear


def layer_init(layer: nn.Module, std: float = np.sqrt(2), bias_const: float = 0.0) -> nn.Module:
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer


def scale_obs(module: type[nn.Module], denom: float = 255.0) -> type[nn.Module]:
    class scaled_module(module):
        def forward(
            self,
            obs: np.ndarray | torch.Tensor,
            state: Any | None = None,
            info: dict[str, Any] | None = None,
        ) -> tuple[torch.Tensor, Any]:
            if info is None:
                info = {}
            return super().forward(obs / denom, state, info)

    return scaled_module


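# A hypothetical usage sketch (not part of the original file): scale_obs wraps
# a network class so that raw uint8 frames in [0, 255] are rescaled to [0, 1]
# before entering the convolutional stack, e.g.
#
#   ScaledDQN = scale_obs(DQN)
#   net = ScaledDQN(c=4, h=84, w=84, action_shape=(6,))

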
class DQN(nn.Module):
    """Reference: Human-level control through deep reinforcement learning.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.
    """

    def __init__(
        self,
        c: int,
        h: int,
        w: int,
        action_shape: Sequence[int],
        device: str | int | torch.device = "cpu",
        features_only: bool = False,
        output_dim: int | None = None,
        layer_init: Callable[[nn.Module], nn.Module] = lambda x: x,
    ) -> None:
        super().__init__()
        self.device = device
        self.net = nn.Sequential(
            layer_init(nn.Conv2d(c, 32, kernel_size=8, stride=4)),
            nn.ReLU(inplace=True),
            layer_init(nn.Conv2d(32, 64, kernel_size=4, stride=2)),
            nn.ReLU(inplace=True),
            layer_init(nn.Conv2d(64, 64, kernel_size=3, stride=1)),
            nn.ReLU(inplace=True),
            nn.Flatten(),
        )
        # infer the flattened feature dimension with a dummy forward pass
        with torch.no_grad():
            self.output_dim = int(np.prod(self.net(torch.zeros(1, c, h, w)).shape[1:]))
        if not features_only:
            self.net = nn.Sequential(
                self.net,
                layer_init(nn.Linear(self.output_dim, 512)),
                nn.ReLU(inplace=True),
                layer_init(nn.Linear(512, int(np.prod(action_shape)))),
            )
            self.output_dim = np.prod(action_shape)
        elif output_dim is not None:
            self.net = nn.Sequential(
                self.net,
                layer_init(nn.Linear(self.output_dim, output_dim)),
                nn.ReLU(inplace=True),
            )
            self.output_dim = output_dim

    def forward(
        self,
        obs: np.ndarray | torch.Tensor,
        state: Any | None = None,
        info: dict[str, Any] | None = None,
    ) -> tuple[torch.Tensor, Any]:
        r"""Mapping: s -> Q(s, \*)."""
        if info is None:
            info = {}
        obs = torch.as_tensor(obs, device=self.device, dtype=torch.float32)
        return self.net(obs), state


class C51(DQN):
    """Reference: A distributional perspective on reinforcement learning.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.
    """

    def __init__(
        self,
        c: int,
        h: int,
        w: int,
        action_shape: Sequence[int],
        num_atoms: int = 51,
        device: str | int | torch.device = "cpu",
    ) -> None:
        self.action_num = np.prod(action_shape)
        super().__init__(c, h, w, [self.action_num * num_atoms], device)
        self.num_atoms = num_atoms

    def forward(
        self,
        obs: np.ndarray | torch.Tensor,
        state: Any | None = None,
        info: dict[str, Any] | None = None,
    ) -> tuple[torch.Tensor, Any]:
        r"""Mapping: x -> Z(x, \*)."""
        if info is None:
            info = {}
        obs, state = super().forward(obs)
        obs = obs.view(-1, self.num_atoms).softmax(dim=-1)
        obs = obs.view(-1, self.action_num, self.num_atoms)
        return obs, state


class Rainbow(DQN):
    """Reference: Rainbow: Combining Improvements in Deep Reinforcement Learning.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.
    """

    def __init__(
        self,
        c: int,
        h: int,
        w: int,
        action_shape: Sequence[int],
        num_atoms: int = 51,
        noisy_std: float = 0.5,
        device: str | int | torch.device = "cpu",
        is_dueling: bool = True,
        is_noisy: bool = True,
    ) -> None:
        super().__init__(c, h, w, action_shape, device, features_only=True)
        self.action_num = np.prod(action_shape)
        self.num_atoms = num_atoms

        def linear(x: int, y: int) -> nn.Module:
            if is_noisy:
                return NoisyLinear(x, y, noisy_std)
            return nn.Linear(x, y)

        self.Q = nn.Sequential(
            linear(self.output_dim, 512),
            nn.ReLU(inplace=True),
            linear(512, self.action_num * self.num_atoms),
        )
        self._is_dueling = is_dueling
        if self._is_dueling:
            self.V = nn.Sequential(
                linear(self.output_dim, 512),
                nn.ReLU(inplace=True),
                linear(512, self.num_atoms),
            )
        self.output_dim = self.action_num * self.num_atoms

    def forward(
        self,
        obs: np.ndarray | torch.Tensor,
        state: Any | None = None,
        info: dict[str, Any] | None = None,
    ) -> tuple[torch.Tensor, Any]:
        r"""Mapping: x -> Z(x, \*)."""
        if info is None:
            info = {}
        obs, state = super().forward(obs)
        q = self.Q(obs)
        q = q.view(-1, self.action_num, self.num_atoms)
        if self._is_dueling:
            v = self.V(obs)
            v = v.view(-1, 1, self.num_atoms)
            # dueling aggregation: subtract the mean advantage before adding the value stream
            logits = q - q.mean(dim=1, keepdim=True) + v
        else:
            logits = q
        probs = logits.softmax(dim=2)
        return probs, state


class QRDQN(DQN):
    """Reference: Distributional Reinforcement Learning with Quantile Regression.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.
    """

    def __init__(
        self,
        c: int,
        h: int,
        w: int,
        action_shape: Sequence[int],
        num_quantiles: int = 200,
        device: str | int | torch.device = "cpu",
    ) -> None:
        self.action_num = np.prod(action_shape)
        super().__init__(c, h, w, [self.action_num * num_quantiles], device)
        self.num_quantiles = num_quantiles

    def forward(
        self,
        obs: np.ndarray | torch.Tensor,
        state: Any | None = None,
        info: dict[str, Any] | None = None,
    ) -> tuple[torch.Tensor, Any]:
        r"""Mapping: x -> Z(x, \*)."""
        if info is None:
            info = {}
        obs, state = super().forward(obs)
        obs = obs.view(-1, self.action_num, self.num_quantiles)
        return obs, state
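

# A minimal smoke test, assuming standard Atari shapes (4 stacked 84x84 frames,
# 6 discrete actions). Illustrative only; not part of the original file.
if __name__ == "__main__":
    frames = np.zeros((2, 4, 84, 84), dtype=np.float32)

    q_net = DQN(4, 84, 84, action_shape=(6,))
    q_values, _ = q_net(frames)
    assert q_values.shape == (2, 6)  # one Q-value per action

    c51 = C51(4, 84, 84, action_shape=(6,), num_atoms=51)
    probs, _ = c51(frames)
    assert probs.shape == (2, 6, 51)  # a distribution over atoms per action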