Improve buffer.prev() & buffer.next() (#294)

This commit is contained in:
Trinkle23897 2021-02-22 19:19:22 +08:00
parent 7036073649
commit e99e1b0fdd
3 changed files with 102 additions and 27 deletions

View File

@ -67,7 +67,7 @@ def test_replaybuffer(size=10, bufsize=20):
assert b.info.a[1] == 4 and b.info.b.c[1] == 0 assert b.info.a[1] == 4 and b.info.b.c[1] == 0
assert b.info.d.e[1] == -np.inf assert b.info.d.e[1] == -np.inf
# test batch-style adding method, where len(batch) == 1 # test batch-style adding method, where len(batch) == 1
batch.done = 1 batch.done = [1]
batch.info.e = np.zeros([1, 4]) batch.info.e = np.zeros([1, 4])
batch = Batch.stack([batch]) batch = Batch.stack([batch])
ptr, ep_rew, ep_len, ep_idx = b.add(batch, buffer_ids=[0]) ptr, ep_rew, ep_len, ep_idx = b.add(batch, buffer_ids=[0])
@ -79,6 +79,13 @@ def test_replaybuffer(size=10, bufsize=20):
assert b.info.e.shape == (b.maxsize, 1, 4) assert b.info.e.shape == (b.maxsize, 1, 4)
with pytest.raises(IndexError): with pytest.raises(IndexError):
b[22] b[22]
# test prev / next
assert np.all(b.prev(np.array([0, 1, 2])) == [0, 1, 1])
assert np.all(b.next(np.array([0, 1, 2])) == [0, 2, 2])
batch.done = [0]
b.add(batch, buffer_ids=[0])
assert np.all(b.prev(np.array([0, 1, 2, 3])) == [0, 1, 1, 3])
assert np.all(b.next(np.array([0, 1, 2, 3])) == [0, 2, 2, 3])
def test_ignore_obs_next(size=10): def test_ignore_obs_next(size=10):
@ -718,7 +725,6 @@ if __name__ == '__main__':
test_stack() test_stack()
test_segtree() test_segtree()
test_priortized_replaybuffer() test_priortized_replaybuffer()
test_priortized_replaybuffer(233333, 200000)
test_update() test_update()
test_pickle() test_pickle()
test_hdf5() test_hdf5()

View File

@ -1,6 +1,7 @@
import h5py import h5py
import torch import torch
import numpy as np import numpy as np
from numba import njit
from typing import Any, Dict, List, Tuple, Union, Sequence, Optional from typing import Any, Dict, List, Tuple, Union, Sequence, Optional
from tianshou.data.batch import _create_value from tianshou.data.batch import _create_value
@ -116,6 +117,7 @@ class ReplayBuffer:
def reset(self) -> None: def reset(self) -> None:
"""Clear all the data in replay buffer and episode statistics.""" """Clear all the data in replay buffer and episode statistics."""
self.last_index = np.array([0])
self._index = self._size = 0 self._index = self._size = 0
self._ep_rew, self._ep_len, self._ep_idx = 0.0, 0, 0 self._ep_rew, self._ep_len, self._ep_idx = 0.0, 0, 0
@ -137,7 +139,7 @@ class ReplayBuffer:
The index won't be modified if it is the beginning of an episode. The index won't be modified if it is the beginning of an episode.
""" """
index = (index - 1) % self._size index = (index - 1) % self._size
end_flag = self.done[index] | np.isin(index, self.unfinished_index()) end_flag = self.done[index] | (index == self.last_index[0])
return (index + end_flag) % self._size return (index + end_flag) % self._size
def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
@ -145,7 +147,7 @@ class ReplayBuffer:
The index won't be modified if it is the end of an episode. The index won't be modified if it is the end of an episode.
""" """
end_flag = self.done[index] | np.isin(index, self.unfinished_index()) end_flag = self.done[index] | (index == self.last_index[0])
return (index + (1 - end_flag)) % self._size return (index + (1 - end_flag)) % self._size
def update(self, buffer: "ReplayBuffer") -> np.ndarray: def update(self, buffer: "ReplayBuffer") -> np.ndarray:
@ -163,6 +165,7 @@ class ReplayBuffer:
to_indices = [] to_indices = []
for _ in range(len(from_indices)): for _ in range(len(from_indices)):
to_indices.append(self._index) to_indices.append(self._index)
self.last_index[0] = self._index
self._index = (self._index + 1) % self.maxsize self._index = (self._index + 1) % self.maxsize
self._size = min(self._size + 1, self.maxsize) self._size = min(self._size + 1, self.maxsize)
to_indices = np.array(to_indices) to_indices = np.array(to_indices)
@ -180,7 +183,7 @@ class ReplayBuffer:
Return (index_to_be_modified, episode_reward, episode_length, Return (index_to_be_modified, episode_reward, episode_length,
episode_start_index). episode_start_index).
""" """
ptr = self._index self.last_index[0] = ptr = self._index
self._size = min(self._size + 1, self.maxsize) self._size = min(self._size + 1, self.maxsize)
self._index = (self._index + 1) % self.maxsize self._index = (self._index + 1) % self.maxsize
@ -296,6 +299,13 @@ class ReplayBuffer:
"""Return the stacked result. """Return the stacked result.
E.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t], where s is self.key, t is the index. E.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t], where s is self.key, t is the index.
:param index: the index for getting stacked data (t in the example).
:param str key: the key to get, should be one of the reserved_keys.
:param default_value: if the given key's data is not found and default_value is
set, return this default_value.
:param int stack_num: the stack num (4 in the example). Default to
self.stack_num.
""" """
if key not in self._meta and default_value is not None: if key not in self._meta and default_value is not None:
return default_value return default_value
@ -306,7 +316,10 @@ class ReplayBuffer:
if stack_num == 1: # the most often case if stack_num == 1: # the most often case
return val[index] return val[index]
stack: List[Any] = [] stack: List[Any] = []
indice = np.asarray(index) if isinstance(index, list):
indice = np.array(index)
else:
indice = index
for _ in range(stack_num): for _ in range(stack_num):
stack = [val[indice]] + stack stack = [val[indice]] + stack
indice = self.prev(indice) indice = self.prev(indice)
@ -453,12 +466,24 @@ class ReplayBufferManager(ReplayBuffer):
offset.append(size) offset.append(size)
size += buf.maxsize size += buf.maxsize
self._offset = np.array(offset) self._offset = np.array(offset)
self._extend_offset = np.array(offset + [size])
self._lengths = np.zeros_like(offset)
super().__init__(size=size, **kwargs) super().__init__(size=size, **kwargs)
self._compile()
def _compile(self) -> None:
    """Invoke the jitted index helpers once on tiny placeholder inputs.

    ``_prev_index`` and ``_next_index`` are each called a single time with a
    one-element integer array and a two-element boolean ``done`` array; the
    results are discarded. Presumably this triggers their compilation up
    front so the first real ``prev()`` / ``next()`` call does not pay for it.
    """
    # One array doubles as the index, last_index and lengths arguments.
    placeholder = np.array([0])
    bounds = np.array([0, 1])  # extended offsets describing a single segment
    no_done = np.array([False, False])
    for warm_up in (_prev_index, _next_index):
        warm_up(placeholder, bounds, no_done, placeholder, placeholder)
def __len__(self) -> int: def __len__(self) -> int:
return sum([len(buf) for buf in self.buffers]) return self._lengths.sum()
def reset(self) -> None: def reset(self) -> None:
self.last_index = self._offset.copy()
self._lengths = np.zeros_like(self._offset)
for buf in self.buffers: for buf in self.buffers:
buf.reset() buf.reset()
@ -477,22 +502,20 @@ class ReplayBufferManager(ReplayBuffer):
]) ])
def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
index = np.asarray(index) % self.maxsize if isinstance(index, (list, np.ndarray)):
prev_indices = np.zeros_like(index) return _prev_index(np.asarray(index), self._extend_offset,
for offset, buf in zip(self._offset, self.buffers): self.done, self.last_index, self._lengths)
mask = (offset <= index) & (index < offset + buf.maxsize) else:
if np.any(mask): return _prev_index(np.array([index]), self._extend_offset,
prev_indices[mask] = buf.prev(index[mask] - offset) + offset self.done, self.last_index, self._lengths)[0]
return prev_indices
def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
index = np.asarray(index) % self.maxsize if isinstance(index, (list, np.ndarray)):
next_indices = np.zeros_like(index) return _next_index(np.asarray(index), self._extend_offset,
for offset, buf in zip(self._offset, self.buffers): self.done, self.last_index, self._lengths)
mask = (offset <= index) & (index < offset + buf.maxsize) else:
if np.any(mask): return _next_index(np.array([index]), self._extend_offset,
next_indices[mask] = buf.next(index[mask] - offset) + offset self.done, self.last_index, self._lengths)[0]
return next_indices
def update(self, buffer: ReplayBuffer) -> np.ndarray: def update(self, buffer: ReplayBuffer) -> np.ndarray:
"""The ReplayBufferManager cannot be updated by any buffer.""" """The ReplayBufferManager cannot be updated by any buffer."""
@ -534,6 +557,8 @@ class ReplayBufferManager(ReplayBuffer):
ep_lens.append(ep_len) ep_lens.append(ep_len)
ep_rews.append(ep_rew) ep_rews.append(ep_rew)
ep_idxs.append(ep_idx + self._offset[buffer_id]) ep_idxs.append(ep_idx + self._offset[buffer_id])
self.last_index[buffer_id] = ptr + self._offset[buffer_id]
self._lengths[buffer_id] = len(self.buffers[buffer_id])
ptrs = np.array(ptrs) ptrs = np.array(ptrs)
try: try:
self._meta[ptrs] = batch self._meta[ptrs] = batch
@ -564,9 +589,8 @@ class ReplayBufferManager(ReplayBuffer):
if batch_size == 0: # get all available indices if batch_size == 0: # get all available indices
sample_num = np.zeros(self.buffer_num, np.int) sample_num = np.zeros(self.buffer_num, np.int)
else: else:
buffer_lens = np.array([len(buf) for buf in self.buffers])
buffer_idx = np.random.choice( buffer_idx = np.random.choice(
self.buffer_num, batch_size, p=buffer_lens / buffer_lens.sum() self.buffer_num, batch_size, p=self._lengths / self._lengths.sum()
) )
sample_num = np.bincount(buffer_idx, minlength=self.buffer_num) sample_num = np.bincount(buffer_idx, minlength=self.buffer_num)
# avoid batch_size > 0 and sample_num == 0 -> get child's all data # avoid batch_size > 0 and sample_num == 0 -> get child's all data
@ -726,6 +750,51 @@ class CachedReplayBuffer(ReplayBufferManager):
updated_ep_idx.append(index[0]) updated_ep_idx.append(index[0])
updated_ptr.append(index[-1]) updated_ptr.append(index[-1])
self.buffers[buffer_idx].reset() self.buffers[buffer_idx].reset()
self._lengths[0] = len(self.main_buffer)
self._lengths[buffer_idx] = 0
self.last_index[0] = index[-1]
self.last_index[buffer_idx] = self._offset[buffer_idx]
ptr[done] = updated_ptr ptr[done] = updated_ptr
ep_idx[done] = updated_ep_idx ep_idx[done] = updated_ep_idx
return ptr, ep_rew, ep_len, ep_idx return ptr, ep_rew, ep_len, ep_idx
@njit
def _prev_index(
    index: np.ndarray,
    offset: np.ndarray,
    done: np.ndarray,
    last_index: np.ndarray,
    lengths: np.ndarray,
) -> np.ndarray:
    """Step every index one position backwards inside its own segment.

    Segment ``i`` covers the absolute range ``[offset[i], offset[i + 1])``;
    ``lengths[i]`` and ``last_index[i]`` are read alongside it.

    :param index: absolute indices; wrapped modulo ``offset[-1]`` first.
    :param offset: segment boundaries, one entry more than there are segments.
    :param done: flags looked up by absolute index.
    :param last_index: per-segment absolute index compared with the candidate.
    :param lengths: per-segment length used as the wrap-around modulus
        (values below 1 are treated as 1).

    :return: array shaped like ``index``. An entry is left where it was when
        the slot before it is flagged in ``done`` or equals the segment's
        ``last_index``; otherwise it moves back by one, wrapping within the
        segment length.
    """
    wrapped = index % offset[-1]
    result = np.zeros_like(wrapped)
    for lo, hi, size, tail in zip(offset[:-1], offset[1:], lengths, last_index):
        selected = (wrapped >= lo) & (wrapped < hi)
        if not selected.any():
            continue
        span = max(1, size)  # guard against a zero modulus
        # Candidate predecessor, expressed relative to the segment start.
        local = (wrapped[selected] - lo - 1) % span
        absolute = local + lo
        # Undo the step when the candidate is flagged done or is the last slot.
        at_boundary = done[absolute] | (absolute == tail)
        result[selected] = (local + at_boundary) % span + lo
    return result
@njit
def _next_index(
    index: np.ndarray,
    offset: np.ndarray,
    done: np.ndarray,
    last_index: np.ndarray,
    lengths: np.ndarray,
) -> np.ndarray:
    """Step every index one position forwards inside its own segment.

    Segment ``i`` covers the absolute range ``[offset[i], offset[i + 1])``;
    ``lengths[i]`` and ``last_index[i]`` are read alongside it.

    :param index: absolute indices; wrapped modulo ``offset[-1]`` first.
    :param offset: segment boundaries, one entry more than there are segments.
    :param done: flags looked up by absolute index.
    :param last_index: per-segment absolute index compared with each entry.
    :param lengths: per-segment length used as the wrap-around modulus
        (values below 1 are treated as 1).

    :return: array shaped like ``index``. An entry is left where it was when
        it is flagged in ``done`` or equals the segment's ``last_index``;
        otherwise it moves forward by one, wrapping within the segment length.
    """
    wrapped = index % offset[-1]
    result = np.zeros_like(wrapped)
    for lo, hi, size, tail in zip(offset[:-1], offset[1:], lengths, last_index):
        selected = (wrapped >= lo) & (wrapped < hi)
        if not selected.any():
            continue
        span = max(1, size)  # guard against a zero modulus
        absolute = wrapped[selected]
        # Entries flagged done or sitting on the last slot do not advance.
        at_boundary = done[absolute] | (absolute == tail)
        result[selected] = (absolute - lo + 1 - at_boundary) % span + lo
    return result

View File

@ -220,12 +220,12 @@ class BasePolicy(ABC, nn.Module):
Use Implementation of Generalized Advantage Estimator (arXiv:1506.02438) Use Implementation of Generalized Advantage Estimator (arXiv:1506.02438)
to calculate q function/reward to go of given batch. to calculate q function/reward to go of given batch.
:param Batch batch: a data batch which contains several episodes of data :param Batch batch: a data batch which contains several episodes of data in
in sequential order. Mind that the end of each finished episode of batch sequential order. Mind that the end of each finished episode of batch
should be marked by done flag, unfinished (or collecting) episodes will be should be marked by done flag, unfinished (or collecting) episodes will be
recognized by buffer.unfinished_index(). recognized by buffer.unfinished_index().
:param np.ndarray indice: tell batch's location in buffer, batch is :param numpy.ndarray indice: tell batch's location in buffer, batch is equal to
equal to buffer[indice]. buffer[indice].
:param np.ndarray v_s_: the value function of all next states :math:`V(s')`. :param np.ndarray v_s_: the value function of all next states :math:`V(s')`.
:param float gamma: the discount factor, should be in [0, 1]. Default to 0.99. :param float gamma: the discount factor, should be in [0, 1]. Default to 0.99.
:param float gae_lambda: the parameter for Generalized Advantage Estimation, :param float gae_lambda: the parameter for Generalized Advantage Estimation,