"""
Replay/rollout buffer base classes, adapted from stable_baselines3.
"""
|
||
from abc import ABC, abstractmethod
|
||
from gym import spaces
|
||
from typing import Any, Dict, Generator, List, Optional, Union, NamedTuple
|
||
from stable_baselines3.common.vec_env import VecNormalize
|
||
import torch as th
|
||
import numpy as np
|
||
import warnings
|
||
from stable_baselines3.common.type_aliases import (
|
||
ReplayBufferSamples,
|
||
RolloutBufferSamples,
|
||
)
|
||
|
||
# psutil is optional: it is only used to estimate whether the replay
# buffer fits in memory, so degrade gracefully when it is absent.
try:
    import psutil
except ImportError:
    psutil = None
|
||
|
||
|
||
class BaseBuffer(ABC):
    """
    Base class that represents a buffer (rollout or replay).

    :param buffer_size: Max number of elements in the buffer
    :param observation_dim: Dimension of a single observation
    :param device: PyTorch device to which the sampled values will be converted
    :param n_envs: Number of parallel environments
    """

    def __init__(
        self,
        buffer_size: int,
        observation_dim: int,
        device: Union[th.device, str] = "cpu",
        n_envs: int = 1,
    ):
        super().__init__()
        self.buffer_size = buffer_size
        self.observation_dim = observation_dim

        # `pos` is the index of the next slot to write; `full` flips to True
        # once the buffer has wrapped around at least once.
        self.pos = 0
        self.full = False
        self.device = device
        self.n_envs = n_envs

    @staticmethod
    def swap_and_flatten(arr: np.ndarray) -> np.ndarray:
        """
        Swap and then flatten axes 0 (buffer_size) and 1 (n_envs)
        to convert shape from [n_steps, n_envs, ...] (where ... is the shape
        of the features) to [n_steps * n_envs, ...] (which maintains the order).

        :param arr: Array of shape [n_steps, n_envs, ...]
        :return: Array of shape [n_steps * n_envs, ...]
        """
        shape = arr.shape
        if len(shape) < 3:
            # Promote 2D input to 3D so the reshape below is uniform
            shape = shape + (1,)
        return arr.swapaxes(0, 1).reshape(shape[0] * shape[1], *shape[2:])

    def size(self) -> int:
        """
        :return: The current number of elements stored in the buffer
        """
        return self.buffer_size if self.full else self.pos

    def add(self, *args, **kwargs) -> None:
        """
        Add elements to the buffer.
        """
        raise NotImplementedError()

    def extend(self, *args, **kwargs) -> None:
        """
        Add a new batch of transitions to the buffer.
        """
        # Do a for loop along the batch axis
        for data in zip(*args):
            self.add(*data)

    def reset(self) -> None:
        """
        Reset the buffer to its empty state.
        """
        self.pos = 0
        self.full = False

    def sample(self, batch_size: int, env: Optional[VecNormalize] = None):
        """
        Sample uniformly at random from the filled portion of the buffer.

        :param batch_size: Number of elements to sample
        :param env: Associated gym VecEnv to normalize the
            observations/rewards when sampling
        :return: Samples (see :meth:`_get_samples`)
        """
        upper_bound = self.buffer_size if self.full else self.pos
        batch_inds = np.random.randint(0, upper_bound, size=batch_size)
        return self._get_samples(batch_inds, env=env)

    @abstractmethod
    def _get_samples(
        self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None
    ) -> Union[ReplayBufferSamples, RolloutBufferSamples]:
        """
        Gather the stored data at the given buffer indices.

        :param batch_inds: Indices into the buffer
        :param env: Associated gym VecEnv used for normalization
        :return: Samples as a (subclass-specific) named tuple
        """
        raise NotImplementedError()

    def to_torch(self, array: np.ndarray, copy: bool = True) -> th.Tensor:
        """
        Convert a numpy array to a PyTorch tensor on ``self.device``.
        Note: it copies the data by default.

        :param array: Numpy array to convert
        :param copy: Whether to copy or not the data
            (may be useful to avoid changing things by reference)
        :return: Tensor on ``self.device``
        """
        if copy:
            return th.tensor(array).to(self.device)
        return th.as_tensor(array).to(self.device)
|
||
|
||
|
||
class ReplayBufferSamples(NamedTuple):
    """Batch of transitions returned by ``ReplayBuffer._get_samples``.

    NOTE(review): this shadows the ``ReplayBufferSamples`` imported from
    ``stable_baselines3.common.type_aliases`` at the top of the file — the
    local definition wins for the code below; confirm this is intentional.
    """

    observations: th.Tensor  # sampled observations
    goals: th.Tensor  # sampled goal vectors
    depths: th.Tensor  # sampled depth images
    map_id: th.Tensor  # sampled map identifiers
|
||
|
||
|
||
class ReplayBuffer(BaseBuffer):
    """
    Replay buffer storing observations, goals, depth images and map ids.

    Storage arrays (all shaped ``[buffer_size, n_envs, ...]``):
    ``self.observations``, ``self.goals``, ``self.depths``, ``self.map_ids``.

    :param buffer_size: Total max number of transitions (divided across envs)
    :param observation_dim: Shape of a single observation as a tuple of ints
        (it is concatenated onto the storage shape below, so it must be a
        tuple, not a gym ``spaces.Space``)
    :param image_WxH: (width, height) of the depth image
    :param device: PyTorch device to which sampled values will be converted
    :param n_envs: Number of parallel environments
    :param optimize_memory_usage: Enable the memory-efficient sampling variant
    :param handle_timeout_termination: Handle timeout terminations separately
        (see https://github.com/DLR-RM/stable-baselines3/issues/284)
    """

    def __init__(
        self,
        buffer_size: int,
        observation_dim: tuple,
        image_WxH: tuple,
        device: Union[th.device, str] = "cpu",
        n_envs: int = 1,
        optimize_memory_usage: bool = False,
        handle_timeout_termination: bool = True,
    ):
        super().__init__(buffer_size, observation_dim, device, n_envs=n_envs)

        # Spread the requested capacity across environments (at least one slot)
        self.buffer_size = max(buffer_size // n_envs, 1)

        # Check that the replay buffer can fit into memory (psutil is optional)
        if psutil is not None:
            mem_available = psutil.virtual_memory().available

        self.optimize_memory_usage = optimize_memory_usage

        self.observations = np.zeros((self.buffer_size, self.n_envs) + observation_dim, dtype=np.float32)
        self.goals = np.zeros((self.buffer_size, self.n_envs, 3), dtype=np.float32)
        # Depth images are stored channel-first: (1, height, width)
        self.depths = np.zeros((self.buffer_size, self.n_envs, 1, image_WxH[1], image_WxH[0]), dtype=np.float32)
        self.map_ids = np.zeros((self.buffer_size, self.n_envs, 1), dtype=np.float32)

        # Handle timeouts termination properly if needed
        # see https://github.com/DLR-RM/stable-baselines3/issues/284
        self.handle_timeout_termination = handle_timeout_termination
        self.timeouts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)

        if psutil is not None:
            total_memory_usage = (
                self.observations.nbytes + self.goals.nbytes + self.depths.nbytes + self.map_ids.nbytes
            )
            if total_memory_usage > mem_available:
                # Convert to GB for the warning message
                total_memory_usage /= 1e9
                mem_available /= 1e9
                warnings.warn(
                    "This system does not have apparently enough memory to store the complete "
                    f"replay buffer {total_memory_usage:.2f}GB > {mem_available:.2f}GB"
                )

    def add(
        self,
        obs: np.ndarray,
        goal: np.ndarray,
        depth: np.ndarray,
        map_id: int,
        infos: List[Dict[str, Any]],
    ) -> None:
        """
        Add one transition (for all parallel envs) at the current position.

        :param obs: Observation, broadcastable to (n_envs, *observation_dim)
        :param goal: Goal vector, broadcastable to (n_envs, 3)
        :param depth: Depth image, broadcastable to (n_envs, 1, H, W)
        :param map_id: Map identifier (scalar or per-env values)
        :param infos: Per-env info dicts; only "TimeLimit.truncated" is read
        """
        # TODO: removed the obs format adjustment — check that observations
        # can still be stored correctly (translated from the original TODO)

        # Copy to avoid modification by reference
        self.observations[self.pos] = np.array(obs).copy()
        self.goals[self.pos] = np.array(goal).copy()
        self.depths[self.pos] = np.array(depth).copy()
        self.map_ids[self.pos] = np.array(map_id).copy()

        if self.handle_timeout_termination:
            self.timeouts[self.pos] = np.array([info.get("TimeLimit.truncated", False) for info in infos])

        # Advance the write cursor, wrapping around when the buffer is full
        self.pos += 1
        if self.pos == self.buffer_size:
            self.full = True
            self.pos = 0

    def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayBufferSamples:
        """
        Sample elements from the replay buffer.
        Custom sampling when using the memory efficient variant,
        as we should not sample the element with index `self.pos`.
        See https://github.com/DLR-RM/stable-baselines3/pull/28#issuecomment-637559274

        NOTE(review): this buffer stores no next_obs; the memory-efficient
        branch is inherited from SB3 — confirm it is still meaningful here.

        :param batch_size: Number of elements to sample
        :param env: Associated gym VecEnv to normalize the
            observations/rewards when sampling
        :return: A :class:`ReplayBufferSamples` named tuple
        """
        if not self.optimize_memory_usage:
            return super().sample(batch_size=batch_size, env=env)
        # Do not sample the element with index `self.pos` as the transitions is invalid
        # (we use only one array to store `obs` and `next_obs`)
        if self.full:
            batch_inds = (np.random.randint(1, self.buffer_size, size=batch_size) + self.pos) % self.buffer_size
        else:
            batch_inds = np.random.randint(0, self.pos, size=batch_size)
        return self._get_samples(batch_inds, env=env)

    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> ReplayBufferSamples:
        """
        Gather the stored data at the given indices, picking a random env
        per sample.

        :param batch_inds: Buffer indices to gather
        :param env: Unused here — no normalization is applied
        :return: A :class:`ReplayBufferSamples` named tuple
        """
        # Pick one random environment per sampled index
        env_indices = np.random.randint(0, high=self.n_envs, size=(len(batch_inds),))

        data = (
            self.observations[batch_inds, env_indices, :],
            self.goals[batch_inds, env_indices, :],
            self.depths[batch_inds, env_indices, :],
            self.map_ids[batch_inds, env_indices, :],
        )
        return ReplayBufferSamples(*data)
|