YOPO/flightpolicy/yopo/buffers.py

"""
The code is from stable_baseline3.
"""
from abc import ABC, abstractmethod
from gym import spaces
from typing import Any, Dict, Generator, List, Optional, Union, NamedTuple
from stable_baselines3.common.vec_env import VecNormalize
import torch as th
import numpy as np
import warnings
from stable_baselines3.common.type_aliases import (
ReplayBufferSamples,
RolloutBufferSamples,
)
try:
# Check memory used by replay buffer when possible
import psutil
except ImportError:
psutil = None


class BaseBuffer(ABC):
    """
    Base class that represents a buffer (rollout or replay).

    :param buffer_size: Max number of elements in the buffer
    :param observation_dim: Dimension of the observation vector
    :param device: PyTorch device to which the values will be converted
    :param n_envs: Number of parallel environments
    """

    def __init__(
        self,
        buffer_size: int,
        observation_dim: int,
        device: Union[th.device, str] = "cpu",
        n_envs: int = 1,
    ):
        super(BaseBuffer, self).__init__()
        self.buffer_size = buffer_size
        self.observation_dim = observation_dim
        self.pos = 0
        self.full = False
        self.device = device
        self.n_envs = n_envs

    @staticmethod
    def swap_and_flatten(arr: np.ndarray) -> np.ndarray:
        """
        Swap and then flatten axes 0 (buffer_size) and 1 (n_envs)
        to convert shape from [n_steps, n_envs, ...] (where ... is the shape of the features)
        to [n_steps * n_envs, ...] (which maintains the order).

        :param arr: Array of shape [n_steps, n_envs, ...]
        :return: Flattened array of shape [n_steps * n_envs, ...]
        """
        shape = arr.shape
        if len(shape) < 3:
            shape = shape + (1,)
        return arr.swapaxes(0, 1).reshape(shape[0] * shape[1], *shape[2:])
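
    # A minimal shape sketch (values are illustrative, not from the original
    # source): an array of shape [n_steps=4, n_envs=2, 3] flattens to (8, 3),
    # with all steps of env 0 listed before those of env 1.
    #   arr = np.arange(24).reshape(4, 2, 3)
    #   BaseBuffer.swap_and_flatten(arr).shape  # -> (8, 3)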

    def size(self) -> int:
        """
        :return: The current size of the buffer
        """
        if self.full:
            return self.buffer_size
        return self.pos

    def add(self, *args, **kwargs) -> None:
        """
        Add elements to the buffer.
        """
        raise NotImplementedError()

    def extend(self, *args, **kwargs) -> None:
        """
        Add a new batch of transitions to the buffer.
        """
        # Do a for loop along the batch axis
        for data in zip(*args):
            self.add(*data)

    def reset(self) -> None:
        """
        Reset the buffer.
        """
        self.pos = 0
        self.full = False

    def sample(self, batch_size: int, env: Optional[VecNormalize] = None):
        """
        :param batch_size: Number of elements to sample
        :param env: Associated gym VecEnv to normalize the
            observations/rewards when sampling
        :return:
        """
        upper_bound = self.buffer_size if self.full else self.pos
        batch_inds = np.random.randint(0, upper_bound, size=batch_size)
        return self._get_samples(batch_inds, env=env)

    # Note: "ReplayBufferSamples" below is a forward reference to the
    # NamedTuple defined later in this module.
    @abstractmethod
    def _get_samples(
        self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None
    ) -> Union["ReplayBufferSamples", RolloutBufferSamples]:
        """
        :param batch_inds: Indices of the samples to retrieve
        :param env: Associated gym VecEnv for normalization
        :return:
        """
        raise NotImplementedError()

    def to_torch(self, array: np.ndarray, copy: bool = True) -> th.Tensor:
        """
        Convert a numpy array to a PyTorch tensor.
        Note: it copies the data by default.

        :param array:
        :param copy: Whether to copy the data or not
            (may be useful to avoid changing things by reference)
        :return:
        """
        if copy:
            return th.tensor(array).to(self.device)
        return th.as_tensor(array).to(self.device)
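
    # A minimal sketch of the copy semantics (hypothetical values, not from
    # the original source): with copy=False and device="cpu", the returned
    # tensor shares memory with the numpy array, so in-place edits to one
    # are visible in the other.
    #   a = np.zeros(3, dtype=np.float32)
    #   t = buffer.to_torch(a, copy=False)  # th.as_tensor shares CPU memory
    #   a[0] = 1.0                          # t[0] is now 1.0 as well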


class ReplayBufferSamples(NamedTuple):
    observations: th.Tensor
    goals: th.Tensor
    depths: th.Tensor
    map_id: th.Tensor


class ReplayBuffer(BaseBuffer):
    """
    Replay buffer that stores:
        self.observations
        self.goals
        self.depths
        self.map_ids
    """

    def __init__(
        self,
        buffer_size: int,
        observation_dim: int,
        image_WxH: tuple,
        device: Union[th.device, str] = "cpu",
        n_envs: int = 1,
        optimize_memory_usage: bool = False,
        handle_timeout_termination: bool = True,
    ):
        super(ReplayBuffer, self).__init__(buffer_size, observation_dim, device, n_envs=n_envs)
        # Adjust buffer size
        self.buffer_size = max(buffer_size // n_envs, 1)
        # Check that the replay buffer can fit into the memory
        if psutil is not None:
            mem_available = psutil.virtual_memory().available
        self.optimize_memory_usage = optimize_memory_usage
        self.observations = np.zeros((self.buffer_size, self.n_envs, observation_dim), dtype=np.float32)
        self.goals = np.zeros((self.buffer_size, self.n_envs, 3), dtype=np.float32)
        self.depths = np.zeros((self.buffer_size, self.n_envs, 1, image_WxH[1], image_WxH[0]), dtype=np.float32)
        self.map_ids = np.zeros((self.buffer_size, self.n_envs, 1), dtype=np.float32)
        # Handle timeout termination properly if needed
        # see https://github.com/DLR-RM/stable-baselines3/issues/284
        self.handle_timeout_termination = handle_timeout_termination
        self.timeouts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        if psutil is not None:
            total_memory_usage = self.observations.nbytes + self.goals.nbytes + self.depths.nbytes + self.map_ids.nbytes
            if total_memory_usage > mem_available:
                # Convert to GB
                total_memory_usage /= 1e9
                mem_available /= 1e9
                warnings.warn(
                    "This system does not appear to have enough memory to store the complete "
                    f"replay buffer: {total_memory_usage:.2f}GB > {mem_available:.2f}GB"
                )

    def add(
        self,
        obs: np.ndarray,
        goal: np.ndarray,
        depth: np.ndarray,
        map_id: int,
        infos: List[Dict[str, Any]],
    ) -> None:
        # TODO: removed the obs format adjustment; check whether adding still works correctly
        # Copy to avoid modification by reference
        self.observations[self.pos] = np.array(obs).copy()
        self.goals[self.pos] = np.array(goal).copy()
        self.depths[self.pos] = np.array(depth).copy()
        self.map_ids[self.pos] = np.array(map_id).copy()
        if self.handle_timeout_termination:
            self.timeouts[self.pos] = np.array([info.get("TimeLimit.truncated", False) for info in infos])
        # Advance the ring-buffer write head, wrapping around when full
        self.pos += 1
        if self.pos == self.buffer_size:
            self.full = True
            self.pos = 0

    def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayBufferSamples:
        """
        Sample elements from the replay buffer.
        Custom sampling when using the memory efficient variant,
        as we should not sample the element with index `self.pos`.
        See https://github.com/DLR-RM/stable-baselines3/pull/28#issuecomment-637559274

        :param batch_size: Number of elements to sample
        :param env: Associated gym VecEnv to normalize the
            observations/rewards when sampling
        :return:
        """
        if not self.optimize_memory_usage:
            return super().sample(batch_size=batch_size, env=env)
        # Do not sample the element with index `self.pos` as the transition is invalid
        # (we use only one array to store `obs` and `next_obs`)
        if self.full:
            batch_inds = (np.random.randint(1, self.buffer_size, size=batch_size) + self.pos) % self.buffer_size
        else:
            batch_inds = np.random.randint(0, self.pos, size=batch_size)
        return self._get_samples(batch_inds, env=env)

    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> ReplayBufferSamples:
        # Pick a random environment index for each sampled time step
        env_indices = np.random.randint(0, high=self.n_envs, size=(len(batch_inds),))
        # Note: the fields are returned as numpy arrays; `to_torch` is not applied here
        data = (
            self.observations[batch_inds, env_indices, :],
            self.goals[batch_inds, env_indices, :],
            self.depths[batch_inds, env_indices, :],
            self.map_ids[batch_inds, env_indices, :],
        )
        return ReplayBufferSamples(*data)
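

# A minimal usage sketch (hypothetical shapes and values, not from the
# original source): store a few transitions for a single environment with
# 48x64 depth images, then sample a batch.
#
#   buffer = ReplayBuffer(buffer_size=100, observation_dim=10,
#                         image_WxH=(64, 48), n_envs=1)
#   for step in range(5):
#       buffer.add(obs=np.zeros((1, 10)), goal=np.zeros((1, 3)),
#                  depth=np.zeros((1, 1, 48, 64)), map_id=0,
#                  infos=[{"TimeLimit.truncated": False}])
#   batch = buffer.sample(batch_size=4)  # ReplayBufferSamples of numpy arrays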