import numpy as np
from numba import njit
from typing import List, Tuple, Union, Sequence, Optional

from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer
from tianshou.data.batch import _create_value, _alloc_by_keys_diff


class ReplayBufferManager(ReplayBuffer):
"""ReplayBufferManager contains a list of ReplayBuffer with exactly the same \
|
|
|
|
configuration.
|
|
|
|
|
|
|
|
These replay buffers have contiguous memory layout, and the storage space each
|
|
|
|
buffer has is a shallow copy of the topmost memory.
|
|
|
|
|
|
|
|
:param buffer_list: a list of ReplayBuffer needed to be handled.
|
|
|
|
|
|
|
|
.. seealso::
|
|
|
|
|
|
|
|
Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
|
|
|
|
"""

    def __init__(self, buffer_list: List[ReplayBuffer]) -> None:
        self.buffer_num = len(buffer_list)
        self.buffers = np.array(buffer_list, dtype=object)
        offset, size = [], 0
        buffer_type = type(self.buffers[0])
        kwargs = self.buffers[0].options
        for buf in self.buffers:
            # every child buffer must be empty and share the first buffer's
            # type and options
            assert buf._meta.is_empty()
            assert isinstance(buf, buffer_type) and buf.options == kwargs
            offset.append(size)
            size += buf.maxsize
        self._offset = np.array(offset)
        self._extend_offset = np.array(offset + [size])
        self._lengths = np.zeros_like(offset)
        super().__init__(size=size, **kwargs)
        self._compile()
        self._meta: Batch
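
    # Worked example (illustrative): for child maxsizes [2, 3, 4],
    # self._offset == [0, 2, 5] and self._extend_offset == [0, 2, 5, 9],
    # i.e. child i owns the flat index range
    # [_extend_offset[i], _extend_offset[i + 1]).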

    def _compile(self) -> None:
        # Warm up the numba-jitted helpers with dummy inputs so the one-off
        # JIT compilation cost is paid at construction time, not on the first
        # real prev/next call.
        lens = last = index = np.array([0])
        offset = np.array([0, 1])
        done = np.array([False, False])
        _prev_index(index, offset, done, last, lens)
        _next_index(index, offset, done, last, lens)

    def __len__(self) -> int:
        return int(self._lengths.sum())

    def reset(self, keep_statistics: bool = False) -> None:
        self.last_index = self._offset.copy()
        self._lengths = np.zeros_like(self._offset)
        for buf in self.buffers:
            buf.reset(keep_statistics=keep_statistics)
    def _set_batch_for_children(self) -> None:
        for offset, buf in zip(self._offset, self.buffers):
            buf.set_batch(self._meta[offset:offset + buf.maxsize])
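
    # Note (an assumption worth stating): the slice above is a view into the
    # manager's flat self._meta, not a copy, so writes through
    # self._meta[ptrs] are immediately visible to the child buffers.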

    def set_batch(self, batch: Batch) -> None:
        super().set_batch(batch)
        self._set_batch_for_children()

    def unfinished_index(self) -> np.ndarray:
        return np.concatenate([
            buf.unfinished_index() + offset
            for offset, buf in zip(self._offset, self.buffers)
        ])

    def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
        if isinstance(index, (list, np.ndarray)):
            return _prev_index(np.asarray(index), self._extend_offset,
                               self.done, self.last_index, self._lengths)
        else:
            return _prev_index(np.array([index]), self._extend_offset,
                               self.done, self.last_index, self._lengths)[0]

    def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
        if isinstance(index, (list, np.ndarray)):
            return _next_index(np.asarray(index), self._extend_offset,
                               self.done, self.last_index, self._lengths)
        else:
            return _next_index(np.array([index]), self._extend_offset,
                               self.done, self.last_index, self._lengths)[0]
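
    # Illustrative behavior (one child buffer of size 5 holding one finished
    # episode): prev pins the episode start to itself and next pins the
    # episode end, mirroring single-ReplayBuffer semantics.
    #
    #     manager.prev(np.arange(5))  # -> [0, 0, 1, 2, 3]
    #     manager.next(np.arange(5))  # -> [1, 2, 3, 4, 4]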

    def update(self, buffer: ReplayBuffer) -> np.ndarray:
        """The ReplayBufferManager cannot be updated by any buffer."""
        raise NotImplementedError

    def add(
        self, batch: Batch, buffer_ids: Optional[Union[np.ndarray, List[int]]] = None
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Add a batch of data into ReplayBufferManager.

        Each of the data's length (first dimension) must equal the length of
        buffer_ids. By default buffer_ids is [0, 1, ..., buffer_num - 1].

        Return (current_index, episode_reward, episode_length,
        episode_start_index). If the episode is not finished, the return values
        of episode_length and episode_reward are 0.
        """
        # preprocess batch: keep only the reserved keys
        b = Batch()
        for key in set(self._reserved_keys).intersection(batch.keys()):
            b.__dict__[key] = batch[key]
        batch = b
        assert {"obs", "act", "rew", "done"}.issubset(batch.keys())
        if self._save_only_last_obs:
            batch.obs = batch.obs[:, -1]
        if not self._save_obs_next:
            batch.pop("obs_next", None)
        elif self._save_only_last_obs:
            batch.obs_next = batch.obs_next[:, -1]
        # get index
        if buffer_ids is None:
            buffer_ids = np.arange(self.buffer_num)
        ptrs, ep_lens, ep_rews, ep_idxs = [], [], [], []
        for batch_idx, buffer_id in enumerate(buffer_ids):
            ptr, ep_rew, ep_len, ep_idx = self.buffers[buffer_id]._add_index(
                batch.rew[batch_idx], batch.done[batch_idx]
            )
            ptrs.append(ptr + self._offset[buffer_id])
            ep_lens.append(ep_len)
            ep_rews.append(ep_rew)
            ep_idxs.append(ep_idx + self._offset[buffer_id])
            self.last_index[buffer_id] = ptr + self._offset[buffer_id]
            self._lengths[buffer_id] = len(self.buffers[buffer_id])
        ptrs = np.array(ptrs)
        try:
            self._meta[ptrs] = batch
        except ValueError:
            # self._meta lacks storage for some key or dtype: normalize the
            # dtypes, (re)allocate, then retry the assignment
            batch.rew = batch.rew.astype(np.float64)
            batch.done = batch.done.astype(np.bool_)
            if self._meta.is_empty():
                self._meta = _create_value(  # type: ignore
                    batch, self.maxsize, stack=False)
            else:  # dynamic key pops up in batch
                _alloc_by_keys_diff(self._meta, batch, self.maxsize, False)
            self._set_batch_for_children()
            self._meta[ptrs] = batch
        return ptrs, np.array(ep_rews), np.array(ep_lens), np.array(ep_idxs)
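
    # Usage sketch (illustrative; obs/act/rew/done are arrays whose first
    # dimension equals buffer_num):
    #
    #     batch = Batch(obs=obs, act=act, rew=rew, done=done)
    #     ptr, ep_rew, ep_len, ep_idx = manager.add(batch)
    #
    # ptr[i] is the flat index the i-th transition was written to; ep_rew[i]
    # and ep_len[i] stay 0 until done[i] marks an episode end.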

    def sample_index(self, batch_size: int) -> np.ndarray:
        if batch_size < 0:
            return np.array([], int)
        if self._sample_avail and self.stack_num > 1:
            all_indices = np.concatenate([
                buf.sample_index(0) + offset
                for offset, buf in zip(self._offset, self.buffers)
            ])
            if batch_size == 0:
                return all_indices
            else:
                return np.random.choice(all_indices, batch_size)
        if batch_size == 0:  # get all available indices
            sample_num = np.zeros(self.buffer_num, int)
        else:
            buffer_idx = np.random.choice(
                self.buffer_num, batch_size, p=self._lengths / self._lengths.sum()
            )
            sample_num = np.bincount(buffer_idx, minlength=self.buffer_num)
            # avoid batch_size > 0 and sample_num == 0 -> get child's all data
            sample_num[sample_num == 0] = -1

        return np.concatenate([
            buf.sample_index(bsz) + offset
            for offset, buf, bsz in zip(self._offset, self.buffers, sample_num)
        ])
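
    # Sampling sketch (illustrative): with child lengths [10, 30],
    # sample_index(4) first draws 4 buffer ids with probabilities
    # [0.25, 0.75], then asks each child for that many indices and shifts
    # them by the child's offset.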


class PrioritizedReplayBufferManager(PrioritizedReplayBuffer, ReplayBufferManager):
    """PrioritizedReplayBufferManager contains a list of PrioritizedReplayBuffer with \
    exactly the same configuration.

    These replay buffers have contiguous memory layout, and the storage space each
    buffer has is a shallow copy of the topmost memory.

    :param buffer_list: a list of PrioritizedReplayBuffer to be managed.

    .. seealso::

        Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
    """

    def __init__(self, buffer_list: Sequence[PrioritizedReplayBuffer]) -> None:
        ReplayBufferManager.__init__(self, buffer_list)  # type: ignore
        kwargs = buffer_list[0].options
        # drop the children's individual weight trees; PrioritizedReplayBuffer's
        # __init__ below builds a single segment tree over the flattened storage
        for buf in buffer_list:
            del buf.weight
        PrioritizedReplayBuffer.__init__(self, self.maxsize, **kwargs)
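
    # Construction sketch (illustrative; the alpha/beta values are arbitrary):
    #
    #     bufs = [
    #         PrioritizedReplayBuffer(size=5, alpha=0.6, beta=0.4)
    #         for _ in range(4)
    #     ]
    #     manager = PrioritizedReplayBufferManager(bufs)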


@njit
def _prev_index(
    index: np.ndarray,
    offset: np.ndarray,
    done: np.ndarray,
    last_index: np.ndarray,
    lengths: np.ndarray,
) -> np.ndarray:
    index = index % offset[-1]
    prev_index = np.zeros_like(index)
    # handle each child buffer's region [start, end) independently
    for start, end, cur_len, last in zip(offset[:-1], offset[1:], lengths, last_index):
        mask = (start <= index) & (index < end)
        cur_len = max(1, cur_len)
        if np.sum(mask) > 0:
            subind = index[mask]
            subind = (subind - start - 1) % cur_len
            end_flag = done[subind + start] | (subind + start == last)
            prev_index[mask] = (subind + end_flag) % cur_len + start
    return prev_index
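
# Worked example (illustrative): one child buffer of size 5 holding a single
# finished episode, i.e. done[4] == True, last_index == [4], lengths == [5],
# offset == [0, 5].  Stepping back wraps within the region and pins index 0
# to itself because its predecessor (index 4) is an episode boundary:
#
#     _prev_index(np.arange(5), offset, done, last_index, lengths)
#     # -> [0, 0, 1, 2, 3]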


@njit
def _next_index(
    index: np.ndarray,
    offset: np.ndarray,
    done: np.ndarray,
    last_index: np.ndarray,
    lengths: np.ndarray,
) -> np.ndarray:
    index = index % offset[-1]
    next_index = np.zeros_like(index)
    # handle each child buffer's region [start, end) independently
    for start, end, cur_len, last in zip(offset[:-1], offset[1:], lengths, last_index):
        mask = (start <= index) & (index < end)
        cur_len = max(1, cur_len)
        if np.sum(mask) > 0:
            subind = index[mask]
            end_flag = done[subind] | (subind == last)
            next_index[mask] = (subind - start + 1 - end_flag) % cur_len + start
    return next_index
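
# Worked example (illustrative): same single-episode buffer as above.
# Stepping forward pins index 4, the episode end, to itself:
#
#     _next_index(np.arange(5), offset, done, last_index, lengths)
#     # -> [1, 2, 3, 4, 4]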