Tianshou/tianshou/utils/net/continuous.py

492 lines
17 KiB
Python
Raw Normal View History

import warnings
from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union
2020-03-21 10:58:01 +08:00
import numpy as np
import torch
2020-03-21 10:58:01 +08:00
from torch import nn
from tianshou.utils.net.common import MLP
2020-06-02 22:29:50 +08:00
2020-11-09 16:43:55 +08:00
# Lower/upper bounds used to clamp the output of the conditioned sigma
# network before it is exponentiated into a standard deviation.
SIGMA_MIN = -20
SIGMA_MAX = 2
2020-03-21 10:58:01 +08:00
class Actor(nn.Module):
    """Simple actor network.

    Will create an actor operated in continuous action space with structure
    of preprocess_net ---> action_shape.

    :param preprocess_net: a self-defined preprocess_net which output a
        flattened hidden state.
    :param action_shape: a sequence of int for the shape of action.
    :param hidden_sizes: a sequence of int for constructing the MLP after
        preprocess_net. Default to empty sequence (where the MLP now contains
        only a single linear layer).
    :param float max_action: the scale for the final action logits. Default to
        1.
    :param device: the device handed to the MLP built after preprocess_net.
        Default to "cpu".
    :param int preprocess_net_output_dim: the output dimension of
        preprocess_net. Only consulted when preprocess_net has no
        ``output_dim`` attribute.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.

    .. seealso::

        Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
        of how preprocess_net is suggested to be defined.
    """

    def __init__(
        self,
        preprocess_net: nn.Module,
        action_shape: Sequence[int],
        hidden_sizes: Sequence[int] = (),
        max_action: float = 1.0,
        device: Union[str, int, torch.device] = "cpu",
        preprocess_net_output_dim: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.device = device
        self.preprocess = preprocess_net
        self.output_dim = int(np.prod(action_shape))
        # Prefer the dimension advertised by preprocess_net itself; fall back
        # to the explicitly supplied value when the attribute is missing.
        input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
        self.last = MLP(
            input_dim,  # type: ignore
            self.output_dim,
            hidden_sizes,
            device=self.device,
        )
        self.max_action = max_action

    def forward(
        self,
        obs: Union[np.ndarray, torch.Tensor],
        state: Any = None,
        info: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, Any]:
        """Mapping: obs -> logits -> action.

        :param obs: observation forwarded to preprocess_net.
        :param state: state forwarded to preprocess_net.
        :param info: currently not used by this method. Defaults to None
            instead of a shared mutable dict.

        :return: a tuple of the tanh-squashed output scaled by max_action and
            the hidden state returned by preprocess_net.
        """
        logits, hidden = self.preprocess(obs, state)
        logits = self.max_action * torch.tanh(self.last(logits))
        return logits, hidden
class Critic(nn.Module):
    """Simple critic network.

    Will create a critic operated in continuous action space with structure
    of preprocess_net ---> 1(q value).

    :param preprocess_net: a self-defined preprocess_net which output a
        flattened hidden state.
    :param hidden_sizes: a sequence of int for constructing the MLP after
        preprocess_net. Default to empty sequence (where the MLP now contains
        only a single linear layer).
    :param device: the device that inputs are converted to and that is
        handed to the MLP built after preprocess_net. Default to "cpu".
    :param int preprocess_net_output_dim: the output dimension of
        preprocess_net. Only consulted when preprocess_net has no
        ``output_dim`` attribute.
    :param linear_layer: use this module as linear layer. Default to nn.Linear.
    :param bool flatten_input: whether to flatten input data for the last layer.
        Default to True.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.

    .. seealso::

        Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
        of how preprocess_net is suggested to be defined.
    """

    def __init__(
        self,
        preprocess_net: nn.Module,
        hidden_sizes: Sequence[int] = (),
        device: Union[str, int, torch.device] = "cpu",
        preprocess_net_output_dim: Optional[int] = None,
        linear_layer: Type[nn.Linear] = nn.Linear,
        flatten_input: bool = True,
    ) -> None:
        super().__init__()
        self.device = device
        self.preprocess = preprocess_net
        self.output_dim = 1
        # Prefer the dimension advertised by preprocess_net itself; fall back
        # to the explicitly supplied value when the attribute is missing.
        input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
        self.last = MLP(
            input_dim,  # type: ignore
            1,
            hidden_sizes,
            device=self.device,
            linear_layer=linear_layer,
            flatten_input=flatten_input,
        )

    def forward(
        self,
        obs: Union[np.ndarray, torch.Tensor],
        act: Optional[Union[np.ndarray, torch.Tensor]] = None,
        info: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        """Mapping: (s, a) -> logits -> Q(s, a).

        :param obs: observation; converted to a float32 tensor on
            ``self.device`` and flattened from dimension 1 onwards.
        :param act: optional action; when given it is converted and flattened
            the same way and concatenated to ``obs`` along dimension 1.
        :param info: currently not used by this method. Defaults to None
            instead of a shared mutable dict.

        :return: the output of the final MLP.
        """
        obs = torch.as_tensor(
            obs,
            device=self.device,
            dtype=torch.float32,
        ).flatten(1)
        if act is not None:
            act = torch.as_tensor(
                act,
                device=self.device,
                dtype=torch.float32,
            ).flatten(1)
            obs = torch.cat([obs, act], dim=1)
        # The second element returned by preprocess_net is not needed here.
        logits, _ = self.preprocess(obs)
        return self.last(logits)
2020-03-21 10:58:01 +08:00
2020-03-21 17:04:42 +08:00
class ActorProb(nn.Module):
    """Simple actor network (output with a Gauss distribution).

    :param preprocess_net: a self-defined preprocess_net which output a
        flattened hidden state.
    :param action_shape: a sequence of int for the shape of action.
    :param hidden_sizes: a sequence of int for constructing the MLP after
        preprocess_net. Default to empty sequence (where the MLP now contains
        only a single linear layer).
    :param float max_action: the scale for the final action logits. Default to
        1.
    :param device: the device handed to the MLPs built after preprocess_net.
        Default to "cpu".
    :param bool unbounded: whether to skip the tanh activation on final
        logits. When True, max_action is reset to 1 (with a warning if a
        different value was given). Default to False.
    :param bool conditioned_sigma: True when sigma is calculated from the
        input, False when sigma is an independent parameter. Default to False.
    :param int preprocess_net_output_dim: the output dimension of
        preprocess_net. Only consulted when preprocess_net has no
        ``output_dim`` attribute.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.

    .. seealso::

        Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
        of how preprocess_net is suggested to be defined.
    """

    def __init__(
        self,
        preprocess_net: nn.Module,
        action_shape: Sequence[int],
        hidden_sizes: Sequence[int] = (),
        max_action: float = 1.0,
        device: Union[str, int, torch.device] = "cpu",
        unbounded: bool = False,
        conditioned_sigma: bool = False,
        preprocess_net_output_dim: Optional[int] = None,
    ) -> None:
        super().__init__()
        if unbounded and not np.isclose(max_action, 1.0):
            warnings.warn(
                "Note that max_action input will be discarded when unbounded is True."
            )
            max_action = 1.0
        self.preprocess = preprocess_net
        self.device = device
        self.output_dim = int(np.prod(action_shape))
        # Prefer the dimension advertised by preprocess_net itself; fall back
        # to the explicitly supplied value when the attribute is missing.
        input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
        self.mu = MLP(
            input_dim,  # type: ignore
            self.output_dim,
            hidden_sizes,
            device=self.device,
        )
        self._c_sigma = conditioned_sigma
        if conditioned_sigma:
            self.sigma = MLP(
                input_dim,  # type: ignore
                self.output_dim,
                hidden_sizes,
                device=self.device,
            )
        else:
            # State-independent log-sigma, one entry per action dimension.
            self.sigma_param = nn.Parameter(torch.zeros(self.output_dim, 1))
        self.max_action = max_action
        self._unbounded = unbounded

    def forward(
        self,
        obs: Union[np.ndarray, torch.Tensor],
        state: Any = None,
        info: Optional[Dict[str, Any]] = None,
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Any]:
        """Mapping: obs -> logits -> (mu, sigma).

        :param obs: observation forwarded to preprocess_net.
        :param state: state forwarded to preprocess_net.
        :param info: currently not used by this method. Defaults to None
            instead of a shared mutable dict.

        :return: ``((mu, sigma), hidden)`` where ``hidden`` is the state
            returned by preprocess_net.
        """
        logits, hidden = self.preprocess(obs, state)
        mu = self.mu(logits)
        if not self._unbounded:
            mu = self.max_action * torch.tanh(mu)
        if self._c_sigma:
            sigma = torch.clamp(self.sigma(logits), min=SIGMA_MIN, max=SIGMA_MAX).exp()
        else:
            # Reshape the parameter to (1, -1, 1, ...) so that it broadcasts
            # against mu along dimension 1.
            shape = [1] * len(mu.shape)
            shape[1] = -1
            sigma = (self.sigma_param.view(shape) + torch.zeros_like(mu)).exp()
        # Return the state produced by preprocess_net rather than the input
        # state, so a stateful preprocess_net can carry its state forward.
        return (mu, sigma), hidden
2020-03-21 17:04:42 +08:00
2020-04-30 16:31:40 +08:00
class RecurrentActorProb(nn.Module):
    """Recurrent version of ActorProb.

    :param int layer_num: number of LSTM layers.
    :param state_shape: shape of the observation; its product is the LSTM
        input size.
    :param action_shape: shape of the action; its product is the size of the
        mu (and sigma) output.
    :param int hidden_layer_size: LSTM hidden size. Default to 128.
    :param float max_action: the scale applied to tanh(mu) when the output is
        bounded. Default to 1.
    :param device: the device that observations are converted to. Default to
        "cpu".
    :param bool unbounded: whether to skip the tanh activation on mu. When
        True, max_action is reset to 1 (with a warning if a different value
        was given). Default to False.
    :param bool conditioned_sigma: True when sigma is calculated from the
        LSTM output, False when sigma is an independent parameter. Default to
        False.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.
    """

    def __init__(
        self,
        layer_num: int,
        state_shape: Sequence[int],
        action_shape: Sequence[int],
        hidden_layer_size: int = 128,
        max_action: float = 1.0,
        device: Union[str, int, torch.device] = "cpu",
        unbounded: bool = False,
        conditioned_sigma: bool = False,
    ) -> None:
        super().__init__()
        if unbounded and not np.isclose(max_action, 1.0):
            warnings.warn(
                "Note that max_action input will be discarded when unbounded is True."
            )
            max_action = 1.0
        self.device = device
        self.nn = nn.LSTM(
            input_size=int(np.prod(state_shape)),
            hidden_size=hidden_layer_size,
            num_layers=layer_num,
            batch_first=True,
        )
        output_dim = int(np.prod(action_shape))
        self.mu = nn.Linear(hidden_layer_size, output_dim)
        self._c_sigma = conditioned_sigma
        if conditioned_sigma:
            self.sigma = nn.Linear(hidden_layer_size, output_dim)
        else:
            # State-independent log-sigma, one entry per action dimension.
            self.sigma_param = nn.Parameter(torch.zeros(output_dim, 1))
        self.max_action = max_action
        self._unbounded = unbounded

    def forward(
        self,
        obs: Union[np.ndarray, torch.Tensor],
        state: Optional[Dict[str, torch.Tensor]] = None,
        info: Optional[Dict[str, Any]] = None,
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Dict[str, torch.Tensor]]:
        """Almost the same as :class:`~tianshou.utils.net.common.Recurrent`.

        :param obs: observation; converted to a float32 tensor on
            ``self.device``. A 2-D input gets a length-1 sequence dimension
            inserted.
        :param state: optional dict with batch-first "hidden" and "cell"
            tensors used to initialise the LSTM.
        :param info: currently not used by this method. Defaults to None
            instead of a shared mutable dict.

        :return: ``((mu, sigma), state)`` where ``state`` holds the detached,
            batch-first "hidden" and "cell" tensors of the LSTM.
        """
        obs = torch.as_tensor(
            obs,
            device=self.device,
            dtype=torch.float32,
        )
        # obs [bsz, len, dim] (training) or [bsz, dim] (evaluation)
        # In short, the tensor's shape in training phase is longer than which
        # in evaluation phase.
        if len(obs.shape) == 2:
            obs = obs.unsqueeze(-2)
        self.nn.flatten_parameters()
        if state is None:
            obs, (hidden, cell) = self.nn(obs)
        else:
            # we store the stack data in [bsz, len, ...] format
            # but pytorch rnn needs [len, bsz, ...]
            obs, (hidden, cell) = self.nn(
                obs, (
                    state["hidden"].transpose(0, 1).contiguous(),
                    state["cell"].transpose(0, 1).contiguous(),
                )
            )
        # Only the output of the last time step feeds the heads.
        logits = obs[:, -1]
        mu = self.mu(logits)
        if not self._unbounded:
            mu = self.max_action * torch.tanh(mu)
        if self._c_sigma:
            sigma = torch.clamp(self.sigma(logits), min=SIGMA_MIN, max=SIGMA_MAX).exp()
        else:
            # Reshape the parameter to (1, -1, 1, ...) so that it broadcasts
            # against mu along dimension 1.
            shape = [1] * len(mu.shape)
            shape[1] = -1
            sigma = (self.sigma_param.view(shape) + torch.zeros_like(mu)).exp()
        # please ensure the first dim is batch size: [bsz, len, ...]
        return (mu, sigma), {
            "hidden": hidden.transpose(0, 1).detach(),
            "cell": cell.transpose(0, 1).detach(),
        }
2020-04-30 16:31:40 +08:00
class RecurrentCritic(nn.Module):
    """Recurrent version of Critic.

    :param int layer_num: number of LSTM layers.
    :param state_shape: shape of the observation; its product is the LSTM
        input size.
    :param action_shape: shape of the action; its product is added to the
        input size of the final linear layer. Default to ``(0,)``, i.e. no
        action input.
    :param device: the device that inputs are converted to. Default to "cpu".
    :param int hidden_layer_size: LSTM hidden size. Default to 128.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.
    """

    def __init__(
        self,
        layer_num: int,
        state_shape: Sequence[int],
        # Immutable default: the value is stored on the instance, so a list
        # default would be shared between all instances.
        action_shape: Sequence[int] = (0,),
        device: Union[str, int, torch.device] = "cpu",
        hidden_layer_size: int = 128,
    ) -> None:
        super().__init__()
        self.state_shape = state_shape
        self.action_shape = action_shape
        self.device = device
        self.nn = nn.LSTM(
            input_size=int(np.prod(state_shape)),
            hidden_size=hidden_layer_size,
            num_layers=layer_num,
            batch_first=True,
        )
        self.fc2 = nn.Linear(hidden_layer_size + int(np.prod(action_shape)), 1)

    def forward(
        self,
        obs: Union[np.ndarray, torch.Tensor],
        act: Optional[Union[np.ndarray, torch.Tensor]] = None,
        info: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        """Almost the same as :class:`~tianshou.utils.net.common.Recurrent`.

        :param obs: observation; converted to a float32 tensor on
            ``self.device``. Must be 3-dimensional.
        :param act: optional action; when given it is converted the same way
            and concatenated to the last LSTM output along dimension 1.
        :param info: currently not used by this method. Defaults to None
            instead of a shared mutable dict.

        :return: the output of the final linear layer.
        """
        obs = torch.as_tensor(
            obs,
            device=self.device,
            dtype=torch.float32,
        )
        # obs is required to be [bsz, len, dim]; a 2-D input is not expanded
        # here and is rejected by the assertion below.
        assert len(obs.shape) == 3, "obs must have shape [bsz, len, dim]"
        self.nn.flatten_parameters()
        # The final LSTM state is not needed, only the per-step outputs.
        obs, _ = self.nn(obs)
        # Only the output of the last time step feeds the linear layer.
        obs = obs[:, -1]
        if act is not None:
            act = torch.as_tensor(
                act,
                device=self.device,
                dtype=torch.float32,
            )
            obs = torch.cat([obs, act], dim=1)
        return self.fc2(obs)
class Perturbation(nn.Module):
    """Implementation of perturbation network in BCQ algorithm.

    Given a state and action, it can generate perturbed action.

    :param torch.nn.Module preprocess_net: a self-defined preprocess_net which
        output a flattened hidden state.
    :param float max_action: the maximum value of each dimension of action.
    :param Union[str, int, torch.device] device: which device to create this
        model on. Default to cpu.
    :param float phi: max perturbation parameter for BCQ. Default to 0.05.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.

    .. seealso::

        You can refer to `examples/offline/offline_bcq.py` to see how to use it.
    """

    def __init__(
        self,
        preprocess_net: nn.Module,
        max_action: float,
        device: Union[str, int, torch.device] = "cpu",
        phi: float = 0.05
    ):
        # preprocess_net: input_dim=state_dim+action_dim, output_dim=action_dim
        super().__init__()
        self.preprocess_net = preprocess_net
        self.phi = phi
        self.max_action = max_action
        self.device = device

    def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Return ``action`` plus a bounded perturbation, clipped to the action range."""
        bound = self.max_action
        # preprocess_net consumes the concatenated (state, action) pair; only
        # the first element of what it returns is used.
        joint = torch.cat([state, action], -1)
        logits = self.preprocess_net(joint)[0]
        perturbed = action + self.phi * bound * torch.tanh(logits)
        # clip to [-max_action, max_action]
        return perturbed.clamp(-bound, bound)
class VAE(nn.Module):
    """Implementation of VAE.

    It models the distribution of action. Given a state, it can generate
    actions similar to those in batch. It is used in BCQ algorithm.

    :param torch.nn.Module encoder: the encoder in VAE. Its input_dim must be
        state_dim + action_dim, and output_dim must be hidden_dim.
    :param torch.nn.Module decoder: the decoder in VAE. Its input_dim must be
        state_dim + latent_dim, and output_dim must be action_dim.
    :param int hidden_dim: the size of the last linear-layer in encoder.
    :param int latent_dim: the size of latent layer.
    :param float max_action: the maximum value of each dimension of action.
    :param Union[str, torch.device] device: which device to create this model on.
        Default to "cpu".

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.

    .. seealso::

        You can refer to `examples/offline/offline_bcq.py` to see how to use it.
    """

    def __init__(
        self,
        encoder: nn.Module,
        decoder: nn.Module,
        hidden_dim: int,
        latent_dim: int,
        max_action: float,
        device: Union[str, torch.device] = "cpu"
    ):
        super().__init__()
        # Submodules are registered in the order encoder, mean, log_std,
        # decoder; keep it so parameter ordering stays stable.
        self.encoder = encoder
        self.mean = nn.Linear(hidden_dim, latent_dim)
        self.log_std = nn.Linear(hidden_dim, latent_dim)
        self.decoder = decoder
        self.device = device
        self.latent_dim = latent_dim
        self.max_action = max_action

    def forward(
        self, state: torch.Tensor, action: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode ``(state, action)``, sample a latent, and decode it.

        :return: the reconstructed action together with the mean and std of
            the latent distribution.
        """
        # [state, action] -> z , [state, z] -> action
        encoded = self.encoder(torch.cat([state, action], -1))
        # shape of encoded: (state.shape[:-1], hidden_dim)
        mean = self.mean(encoded)
        # Clamped for numerical stability
        std = self.log_std(encoded).clamp(-4, 15).exp()
        # shape of mean, std: (state.shape[:-1], latent_dim)
        sample = mean + std * torch.randn_like(std)
        # reconstruction has shape (state.shape[:-1], action_dim)
        return self.decode(state, sample), mean, std

    def decode(
        self,
        state: torch.Tensor,
        latent_z: Union[torch.Tensor, None] = None
    ) -> torch.Tensor:
        """Decode an action from ``state`` and a latent vector.

        When ``latent_z`` is omitted, one is drawn from a standard normal and
        clipped to [-0.5, 0.5].
        """
        # decode(state) -> action
        if latent_z is None:
            # state.shape[0] may be batch_size
            noise_shape = state.shape[:-1] + (self.latent_dim, )
            # latent vector clipped to [-0.5, 0.5]
            latent_z = torch.randn(noise_shape).to(self.device).clamp(-0.5, 0.5)
        # decode z with state!
        decoded = self.decoder(torch.cat([state, latent_z], -1))
        return self.max_action * torch.tanh(decoded)