Improve buffer.prev() & buffer.next() (#294)

Trinkle23897 2021-02-22 19:19:22 +08:00
parent 7036073649
commit e99e1b0fdd
3 changed files with 102 additions and 27 deletions

View File

@@ -67,7 +67,7 @@ def test_replaybuffer(size=10, bufsize=20):
assert b.info.a[1] == 4 and b.info.b.c[1] == 0
assert b.info.d.e[1] == -np.inf
# test batch-style adding method, where len(batch) == 1
batch.done = 1
batch.done = [1]
batch.info.e = np.zeros([1, 4])
batch = Batch.stack([batch])
ptr, ep_rew, ep_len, ep_idx = b.add(batch, buffer_ids=[0])
@@ -79,6 +79,13 @@ def test_replaybuffer(size=10, bufsize=20):
assert b.info.e.shape == (b.maxsize, 1, 4)
with pytest.raises(IndexError):
b[22]
# test prev / next
assert np.all(b.prev(np.array([0, 1, 2])) == [0, 1, 1])
assert np.all(b.next(np.array([0, 1, 2])) == [0, 2, 2])
batch.done = [0]
b.add(batch, buffer_ids=[0])
assert np.all(b.prev(np.array([0, 1, 2, 3])) == [0, 1, 1, 3])
assert np.all(b.next(np.array([0, 1, 2, 3])) == [0, 2, 2, 3])
def test_ignore_obs_next(size=10):
@@ -718,7 +725,6 @@ if __name__ == '__main__':
test_stack()
test_segtree()
test_priortized_replaybuffer()
test_priortized_replaybuffer(233333, 200000)
test_update()
test_pickle()
test_hdf5()

View File

@@ -1,6 +1,7 @@
import h5py
import torch
import numpy as np
from numba import njit
from typing import Any, Dict, List, Tuple, Union, Sequence, Optional
from tianshou.data.batch import _create_value
@@ -116,6 +117,7 @@ class ReplayBuffer:
def reset(self) -> None:
"""Clear all the data in replay buffer and episode statistics."""
self.last_index = np.array([0])
self._index = self._size = 0
self._ep_rew, self._ep_len, self._ep_idx = 0.0, 0, 0
@@ -137,7 +139,7 @@
The index won't be modified if it is the beginning of an episode.
"""
index = (index - 1) % self._size
end_flag = self.done[index] | np.isin(index, self.unfinished_index())
end_flag = self.done[index] | (index == self.last_index[0])
return (index + end_flag) % self._size
def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
@@ -145,7 +147,7 @@
The index won't be modified if it is the end of an episode.
"""
end_flag = self.done[index] | np.isin(index, self.unfinished_index())
end_flag = self.done[index] | (index == self.last_index[0])
return (index + (1 - end_flag)) % self._size
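Note: this hunk swaps the np.isin(index, self.unfinished_index()) check for a direct
comparison against self.last_index[0], so detecting an episode boundary becomes a cheap
elementwise comparison. For reference, the rule can be mirrored with plain NumPy; the two
helpers below are hypothetical stand-ins (not part of tianshou's API) that follow the
prev()/next() bodies shown in this hunk:

import numpy as np

def prev_within_episode(index, done, last_index, size):
    # step one slot back in the circular storage
    index = (np.asarray(index) - 1) % size
    # undo the step if we crossed an episode end or the most recently written slot
    end_flag = done[index] | (index == last_index)
    return (index + end_flag) % size

def next_within_episode(index, done, last_index, size):
    index = np.asarray(index)
    # stay put at an episode end or at the most recently written slot
    end_flag = done[index] | (index == last_index)
    return (index + (1 - end_flag)) % size

With done = [True, False, True], last_index = 2 and size = 3, these return [0, 1, 1]
and [0, 2, 2] for index [0, 1, 2], the values asserted in the test diff above.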
def update(self, buffer: "ReplayBuffer") -> np.ndarray:
@@ -163,6 +165,7 @@
to_indices = []
for _ in range(len(from_indices)):
to_indices.append(self._index)
self.last_index[0] = self._index
self._index = (self._index + 1) % self.maxsize
self._size = min(self._size + 1, self.maxsize)
to_indices = np.array(to_indices)
@@ -180,7 +183,7 @@
Return (index_to_be_modified, episode_reward, episode_length,
episode_start_index).
"""
ptr = self._index
self.last_index[0] = ptr = self._index
self._size = min(self._size + 1, self.maxsize)
self._index = (self._index + 1) % self.maxsize
@@ -296,6 +299,13 @@ class ReplayBuffer:
"""Return the stacked result.
E.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t], where s is self.key, t is the index.
:param index: the index for getting stacked data (t in the example).
:param str key: the key to get, should be one of the reserved_keys.
:param default_value: if the given key's data is not found and default_value is
set, return this default_value.
:param int stack_num: the stack num (4 in the example). Default to
self.stack_num.
"""
if key not in self._meta and default_value is not None:
return default_value
@@ -306,7 +316,10 @@
if stack_num == 1: # the most often case
return val[index]
stack: List[Any] = []
indice = np.asarray(index)
if isinstance(index, list):
indice = np.array(index)
else:
indice = index
for _ in range(stack_num):
stack = [val[indice]] + stack
indice = self.prev(indice)
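The loop above assembles the stack by walking backwards with prev(); since prev() stays
put at an episode start, the first frame of an episode is repeated rather than leaking
data from an older episode. A minimal sketch of the same idea (stack_frames is a
hypothetical helper, not tianshou's get(), and the trailing stack axis is just one
reasonable choice):

import numpy as np

def stack_frames(val, index, prev_fn, stack_num):
    # collect [t - stack_num + 1, ..., t] by repeatedly stepping backwards
    stack = []
    indice = np.asarray(index)
    for _ in range(stack_num):
        stack = [val[indice]] + stack
        indice = prev_fn(indice)
    return np.stack(stack, axis=-1)

# toy check: one episode starting at index 0, values 10..14, stack of 3 at t = 2
vals = np.arange(10, 15)
print(stack_frames(vals, 2, lambda i: np.maximum(i - 1, 0), 3))  # [10 11 12]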
@@ -453,12 +466,24 @@ class ReplayBufferManager(ReplayBuffer):
offset.append(size)
size += buf.maxsize
self._offset = np.array(offset)
self._extend_offset = np.array(offset + [size])
self._lengths = np.zeros_like(offset)
super().__init__(size=size, **kwargs)
self._compile()
def _compile(self) -> None:
lens = last = index = np.array([0])
offset = np.array([0, 1])
done = np.array([False, False])
_prev_index(index, offset, done, last, lens)
_next_index(index, offset, done, last, lens)
def __len__(self) -> int:
return sum([len(buf) for buf in self.buffers])
return self._lengths.sum()
def reset(self) -> None:
self.last_index = self._offset.copy()
self._lengths = np.zeros_like(self._offset)
for buf in self.buffers:
buf.reset()
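The manager lays its sub-buffers back to back in one flat index space: _offset stores
each sub-buffer's start, _extend_offset appends the total size so consecutive pairs give
the [start, end) range of each sub-buffer for the numba helpers, and _lengths caches
per-buffer sizes so __len__ and sampling no longer loop over the buffers. _compile calls
the njit helpers once with dummy arguments so numba compiles them at construction time
rather than on the first real call. With illustrative numbers (three sub-buffers of
maxsize 4 each, not taken from the commit):

import numpy as np

offset = np.array([0, 4, 8])             # start index of each sub-buffer
extend_offset = np.array([0, 4, 8, 12])  # consecutive pairs give [start, end) per buffer
lengths = np.zeros_like(offset)          # current number of stored steps per buffer
# global index 6 falls in [4, 8), i.e. buffer 1, at local index 6 - offset[1] = 2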
@@ -477,22 +502,20 @@
])
def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
index = np.asarray(index) % self.maxsize
prev_indices = np.zeros_like(index)
for offset, buf in zip(self._offset, self.buffers):
mask = (offset <= index) & (index < offset + buf.maxsize)
if np.any(mask):
prev_indices[mask] = buf.prev(index[mask] - offset) + offset
return prev_indices
if isinstance(index, (list, np.ndarray)):
return _prev_index(np.asarray(index), self._extend_offset,
self.done, self.last_index, self._lengths)
else:
return _prev_index(np.array([index]), self._extend_offset,
self.done, self.last_index, self._lengths)[0]
def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
index = np.asarray(index) % self.maxsize
next_indices = np.zeros_like(index)
for offset, buf in zip(self._offset, self.buffers):
mask = (offset <= index) & (index < offset + buf.maxsize)
if np.any(mask):
next_indices[mask] = buf.next(index[mask] - offset) + offset
return next_indices
if isinstance(index, (list, np.ndarray)):
return _next_index(np.asarray(index), self._extend_offset,
self.done, self.last_index, self._lengths)
else:
return _next_index(np.array([index]), self._extend_offset,
self.done, self.last_index, self._lengths)[0]
def update(self, buffer: ReplayBuffer) -> np.ndarray:
"""The ReplayBufferManager cannot be updated by any buffer."""
@@ -534,6 +557,8 @@ class ReplayBufferManager(ReplayBuffer):
ep_lens.append(ep_len)
ep_rews.append(ep_rew)
ep_idxs.append(ep_idx + self._offset[buffer_id])
self.last_index[buffer_id] = ptr + self._offset[buffer_id]
self._lengths[buffer_id] = len(self.buffers[buffer_id])
ptrs = np.array(ptrs)
try:
self._meta[ptrs] = batch
@@ -564,9 +589,8 @@
if batch_size == 0: # get all available indices
sample_num = np.zeros(self.buffer_num, np.int)
else:
buffer_lens = np.array([len(buf) for buf in self.buffers])
buffer_idx = np.random.choice(
self.buffer_num, batch_size, p=buffer_lens / buffer_lens.sum()
self.buffer_num, batch_size, p=self._lengths / self._lengths.sum()
)
sample_num = np.bincount(buffer_idx, minlength=self.buffer_num)
# avoid batch_size > 0 and sample_num == 0 -> get child's all data
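This hunk reuses the cached self._lengths instead of recomputing len(buf) for every
sub-buffer when allocating a sample batch. The allocation itself is a two-step NumPy
pattern: draw buffer ids in proportion to current lengths, then count how many draws
landed in each buffer. A small sketch with made-up sizes:

import numpy as np

lengths = np.array([6, 0, 2])   # made-up current sizes of three sub-buffers
batch_size = 4
buffer_idx = np.random.choice(
    len(lengths), batch_size, p=lengths / lengths.sum()
)
sample_num = np.bincount(buffer_idx, minlength=len(lengths))
# e.g. sample_num == [3, 0, 1]: buffer 0 supplies 3 samples, buffer 2 supplies 1,
# and the empty buffer 1 is never chosen because its probability is 0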
@@ -726,6 +750,51 @@ class CachedReplayBuffer(ReplayBufferManager):
updated_ep_idx.append(index[0])
updated_ptr.append(index[-1])
self.buffers[buffer_idx].reset()
self._lengths[0] = len(self.main_buffer)
self._lengths[buffer_idx] = 0
self.last_index[0] = index[-1]
self.last_index[buffer_idx] = self._offset[buffer_idx]
ptr[done] = updated_ptr
ep_idx[done] = updated_ep_idx
return ptr, ep_rew, ep_len, ep_idx
@njit
def _prev_index(
index: np.ndarray,
offset: np.ndarray,
done: np.ndarray,
last_index: np.ndarray,
lengths: np.ndarray,
) -> np.ndarray:
index = index % offset[-1]
prev_index = np.zeros_like(index)
for start, end, cur_len, last in zip(offset[:-1], offset[1:], lengths, last_index):
mask = (start <= index) & (index < end)
cur_len = max(1, cur_len)
if np.sum(mask) > 0:
subind = index[mask]
subind = (subind - start - 1) % cur_len
end_flag = done[subind + start] | (subind + start == last)
prev_index[mask] = (subind + end_flag) % cur_len + start
return prev_index
@njit
def _next_index(
index: np.ndarray,
offset: np.ndarray,
done: np.ndarray,
last_index: np.ndarray,
lengths: np.ndarray,
) -> np.ndarray:
index = index % offset[-1]
next_index = np.zeros_like(index)
for start, end, cur_len, last in zip(offset[:-1], offset[1:], lengths, last_index):
mask = (start <= index) & (index < end)
cur_len = max(1, cur_len)
if np.sum(mask) > 0:
subind = index[mask]
end_flag = done[subind] | (subind == last)
next_index[mask] = (subind - start + 1 - end_flag) % cur_len + start
return next_index
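These module-level njit helpers apply the single-buffer prev/next rule to each
[start, end) slice of the flat index space, using the cached per-buffer lengths so an
index never steps across a sub-buffer boundary. A quick hand check with illustrative
values (assuming _prev_index and _next_index as defined above): two sub-buffers of
maxsize 3 at offsets [0, 3, 6]; buffer 0 holds 3 steps with done = [False, True, False]
(last written index 2), buffer 1 holds 2 unfinished steps (last written index 4):

import numpy as np

offset = np.array([0, 3, 6])
done = np.array([False, True, False, False, False, False])
last_index = np.array([2, 4])
lengths = np.array([3, 2])
index = np.array([0, 1, 2, 3, 4])
_prev_index(index, offset, done, last_index, lengths)  # -> [0, 0, 2, 3, 3]
_next_index(index, offset, done, last_index, lengths)  # -> [1, 1, 2, 4, 4]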

View File

@@ -220,12 +220,12 @@ class BasePolicy(ABC, nn.Module):
Use Implementation of Generalized Advantage Estimator (arXiv:1506.02438)
to calculate q function/reward to go of given batch.
:param Batch batch: a data batch which contains several episodes of data
in sequential order. Mind that the end of each finished episode of batch
:param Batch batch: a data batch which contains several episodes of data in
sequential order. Mind that the end of each finished episode of batch
should be marked by done flag, unfinished (or collecting) episodes will be
recognized by buffer.unfinished_index().
:param np.ndarray indice: tell batch's location in buffer, batch is
equal to buffer[indice].
:param numpy.ndarray indice: tell batch's location in buffer, batch is equal to
buffer[indice].
:param np.ndarray v_s_: the value function of all next states :math:`V(s')`.
:param float gamma: the discount factor, should be in [0, 1]. Default to 0.99.
:param float gae_lambda: the parameter for Generalized Advantage Estimation,