Improve buffer.prev() & buffer.next() (#294)
This commit is contained in:
parent
7036073649
commit
e99e1b0fdd
@ -67,7 +67,7 @@ def test_replaybuffer(size=10, bufsize=20):
|
||||
assert b.info.a[1] == 4 and b.info.b.c[1] == 0
|
||||
assert b.info.d.e[1] == -np.inf
|
||||
# test batch-style adding method, where len(batch) == 1
|
||||
batch.done = 1
|
||||
batch.done = [1]
|
||||
batch.info.e = np.zeros([1, 4])
|
||||
batch = Batch.stack([batch])
|
||||
ptr, ep_rew, ep_len, ep_idx = b.add(batch, buffer_ids=[0])
|
||||
@ -79,6 +79,13 @@ def test_replaybuffer(size=10, bufsize=20):
|
||||
assert b.info.e.shape == (b.maxsize, 1, 4)
|
||||
with pytest.raises(IndexError):
|
||||
b[22]
|
||||
# test prev / next
|
||||
assert np.all(b.prev(np.array([0, 1, 2])) == [0, 1, 1])
|
||||
assert np.all(b.next(np.array([0, 1, 2])) == [0, 2, 2])
|
||||
batch.done = [0]
|
||||
b.add(batch, buffer_ids=[0])
|
||||
assert np.all(b.prev(np.array([0, 1, 2, 3])) == [0, 1, 1, 3])
|
||||
assert np.all(b.next(np.array([0, 1, 2, 3])) == [0, 2, 2, 3])
|
||||
|
||||
|
||||
def test_ignore_obs_next(size=10):
|
||||
@ -718,7 +725,6 @@ if __name__ == '__main__':
|
||||
test_stack()
|
||||
test_segtree()
|
||||
test_priortized_replaybuffer()
|
||||
test_priortized_replaybuffer(233333, 200000)
|
||||
test_update()
|
||||
test_pickle()
|
||||
test_hdf5()
|
||||
|
@ -1,6 +1,7 @@
|
||||
import h5py
|
||||
import torch
|
||||
import numpy as np
|
||||
from numba import njit
|
||||
from typing import Any, Dict, List, Tuple, Union, Sequence, Optional
|
||||
|
||||
from tianshou.data.batch import _create_value
|
||||
@ -116,6 +117,7 @@ class ReplayBuffer:
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Clear all the data in replay buffer and episode statistics."""
|
||||
self.last_index = np.array([0])
|
||||
self._index = self._size = 0
|
||||
self._ep_rew, self._ep_len, self._ep_idx = 0.0, 0, 0
|
||||
|
||||
@ -137,7 +139,7 @@ class ReplayBuffer:
|
||||
The index won't be modified if it is the beginning of an episode.
|
||||
"""
|
||||
index = (index - 1) % self._size
|
||||
end_flag = self.done[index] | np.isin(index, self.unfinished_index())
|
||||
end_flag = self.done[index] | (index == self.last_index[0])
|
||||
return (index + end_flag) % self._size
|
||||
|
||||
def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
|
||||
@ -145,7 +147,7 @@ class ReplayBuffer:
|
||||
|
||||
The index won't be modified if it is the end of an episode.
|
||||
"""
|
||||
end_flag = self.done[index] | np.isin(index, self.unfinished_index())
|
||||
end_flag = self.done[index] | (index == self.last_index[0])
|
||||
return (index + (1 - end_flag)) % self._size
|
||||
|
||||
def update(self, buffer: "ReplayBuffer") -> np.ndarray:
|
||||
@ -163,6 +165,7 @@ class ReplayBuffer:
|
||||
to_indices = []
|
||||
for _ in range(len(from_indices)):
|
||||
to_indices.append(self._index)
|
||||
self.last_index[0] = self._index
|
||||
self._index = (self._index + 1) % self.maxsize
|
||||
self._size = min(self._size + 1, self.maxsize)
|
||||
to_indices = np.array(to_indices)
|
||||
@ -180,7 +183,7 @@ class ReplayBuffer:
|
||||
Return (index_to_be_modified, episode_reward, episode_length,
|
||||
episode_start_index).
|
||||
"""
|
||||
ptr = self._index
|
||||
self.last_index[0] = ptr = self._index
|
||||
self._size = min(self._size + 1, self.maxsize)
|
||||
self._index = (self._index + 1) % self.maxsize
|
||||
|
||||
@ -296,6 +299,13 @@ class ReplayBuffer:
|
||||
"""Return the stacked result.
|
||||
|
||||
E.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t], where s is self.key, t is the index.
|
||||
|
||||
:param index: the index for getting stacked data (t in the example).
|
||||
:param str key: the key to get, should be one of the reserved_keys.
|
||||
:param default_value: if the given key's data is not found and default_value is
|
||||
set, return this default_value.
|
||||
:param int stack_num: the stack num (4 in the example). Default to
|
||||
self.stack_num.
|
||||
"""
|
||||
if key not in self._meta and default_value is not None:
|
||||
return default_value
|
||||
@ -306,7 +316,10 @@ class ReplayBuffer:
|
||||
if stack_num == 1: # the most often case
|
||||
return val[index]
|
||||
stack: List[Any] = []
|
||||
indice = np.asarray(index)
|
||||
if isinstance(index, list):
|
||||
indice = np.array(index)
|
||||
else:
|
||||
indice = index
|
||||
for _ in range(stack_num):
|
||||
stack = [val[indice]] + stack
|
||||
indice = self.prev(indice)
|
||||
@ -453,12 +466,24 @@ class ReplayBufferManager(ReplayBuffer):
|
||||
offset.append(size)
|
||||
size += buf.maxsize
|
||||
self._offset = np.array(offset)
|
||||
self._extend_offset = np.array(offset + [size])
|
||||
self._lengths = np.zeros_like(offset)
|
||||
super().__init__(size=size, **kwargs)
|
||||
self._compile()
|
||||
|
||||
def _compile(self) -> None:
|
||||
lens = last = index = np.array([0])
|
||||
offset = np.array([0, 1])
|
||||
done = np.array([False, False])
|
||||
_prev_index(index, offset, done, last, lens)
|
||||
_next_index(index, offset, done, last, lens)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return sum([len(buf) for buf in self.buffers])
|
||||
return self._lengths.sum()
|
||||
|
||||
def reset(self) -> None:
|
||||
self.last_index = self._offset.copy()
|
||||
self._lengths = np.zeros_like(self._offset)
|
||||
for buf in self.buffers:
|
||||
buf.reset()
|
||||
|
||||
@ -477,22 +502,20 @@ class ReplayBufferManager(ReplayBuffer):
|
||||
])
|
||||
|
||||
def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
|
||||
index = np.asarray(index) % self.maxsize
|
||||
prev_indices = np.zeros_like(index)
|
||||
for offset, buf in zip(self._offset, self.buffers):
|
||||
mask = (offset <= index) & (index < offset + buf.maxsize)
|
||||
if np.any(mask):
|
||||
prev_indices[mask] = buf.prev(index[mask] - offset) + offset
|
||||
return prev_indices
|
||||
if isinstance(index, (list, np.ndarray)):
|
||||
return _prev_index(np.asarray(index), self._extend_offset,
|
||||
self.done, self.last_index, self._lengths)
|
||||
else:
|
||||
return _prev_index(np.array([index]), self._extend_offset,
|
||||
self.done, self.last_index, self._lengths)[0]
|
||||
|
||||
def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
|
||||
index = np.asarray(index) % self.maxsize
|
||||
next_indices = np.zeros_like(index)
|
||||
for offset, buf in zip(self._offset, self.buffers):
|
||||
mask = (offset <= index) & (index < offset + buf.maxsize)
|
||||
if np.any(mask):
|
||||
next_indices[mask] = buf.next(index[mask] - offset) + offset
|
||||
return next_indices
|
||||
if isinstance(index, (list, np.ndarray)):
|
||||
return _next_index(np.asarray(index), self._extend_offset,
|
||||
self.done, self.last_index, self._lengths)
|
||||
else:
|
||||
return _next_index(np.array([index]), self._extend_offset,
|
||||
self.done, self.last_index, self._lengths)[0]
|
||||
|
||||
def update(self, buffer: ReplayBuffer) -> np.ndarray:
|
||||
"""The ReplayBufferManager cannot be updated by any buffer."""
|
||||
@ -534,6 +557,8 @@ class ReplayBufferManager(ReplayBuffer):
|
||||
ep_lens.append(ep_len)
|
||||
ep_rews.append(ep_rew)
|
||||
ep_idxs.append(ep_idx + self._offset[buffer_id])
|
||||
self.last_index[buffer_id] = ptr + self._offset[buffer_id]
|
||||
self._lengths[buffer_id] = len(self.buffers[buffer_id])
|
||||
ptrs = np.array(ptrs)
|
||||
try:
|
||||
self._meta[ptrs] = batch
|
||||
@ -564,9 +589,8 @@ class ReplayBufferManager(ReplayBuffer):
|
||||
if batch_size == 0: # get all available indices
|
||||
sample_num = np.zeros(self.buffer_num, np.int)
|
||||
else:
|
||||
buffer_lens = np.array([len(buf) for buf in self.buffers])
|
||||
buffer_idx = np.random.choice(
|
||||
self.buffer_num, batch_size, p=buffer_lens / buffer_lens.sum()
|
||||
self.buffer_num, batch_size, p=self._lengths / self._lengths.sum()
|
||||
)
|
||||
sample_num = np.bincount(buffer_idx, minlength=self.buffer_num)
|
||||
# avoid batch_size > 0 and sample_num == 0 -> get child's all data
|
||||
@ -726,6 +750,51 @@ class CachedReplayBuffer(ReplayBufferManager):
|
||||
updated_ep_idx.append(index[0])
|
||||
updated_ptr.append(index[-1])
|
||||
self.buffers[buffer_idx].reset()
|
||||
self._lengths[0] = len(self.main_buffer)
|
||||
self._lengths[buffer_idx] = 0
|
||||
self.last_index[0] = index[-1]
|
||||
self.last_index[buffer_idx] = self._offset[buffer_idx]
|
||||
ptr[done] = updated_ptr
|
||||
ep_idx[done] = updated_ep_idx
|
||||
return ptr, ep_rew, ep_len, ep_idx
|
||||
|
||||
|
||||
@njit
|
||||
def _prev_index(
|
||||
index: np.ndarray,
|
||||
offset: np.ndarray,
|
||||
done: np.ndarray,
|
||||
last_index: np.ndarray,
|
||||
lengths: np.ndarray,
|
||||
) -> np.ndarray:
|
||||
index = index % offset[-1]
|
||||
prev_index = np.zeros_like(index)
|
||||
for start, end, cur_len, last in zip(offset[:-1], offset[1:], lengths, last_index):
|
||||
mask = (start <= index) & (index < end)
|
||||
cur_len = max(1, cur_len)
|
||||
if np.sum(mask) > 0:
|
||||
subind = index[mask]
|
||||
subind = (subind - start - 1) % cur_len
|
||||
end_flag = done[subind + start] | (subind + start == last)
|
||||
prev_index[mask] = (subind + end_flag) % cur_len + start
|
||||
return prev_index
|
||||
|
||||
|
||||
@njit
|
||||
def _next_index(
|
||||
index: np.ndarray,
|
||||
offset: np.ndarray,
|
||||
done: np.ndarray,
|
||||
last_index: np.ndarray,
|
||||
lengths: np.ndarray,
|
||||
) -> np.ndarray:
|
||||
index = index % offset[-1]
|
||||
next_index = np.zeros_like(index)
|
||||
for start, end, cur_len, last in zip(offset[:-1], offset[1:], lengths, last_index):
|
||||
mask = (start <= index) & (index < end)
|
||||
cur_len = max(1, cur_len)
|
||||
if np.sum(mask) > 0:
|
||||
subind = index[mask]
|
||||
end_flag = done[subind] | (subind == last)
|
||||
next_index[mask] = (subind - start + 1 - end_flag) % cur_len + start
|
||||
return next_index
|
||||
|
@ -220,12 +220,12 @@ class BasePolicy(ABC, nn.Module):
|
||||
Use Implementation of Generalized Advantage Estimator (arXiv:1506.02438)
|
||||
to calculate q function/reward to go of given batch.
|
||||
|
||||
:param Batch batch: a data batch which contains several episodes of data
|
||||
in sequential order. Mind that the end of each finished episode of batch
|
||||
:param Batch batch: a data batch which contains several episodes of data in
|
||||
sequential order. Mind that the end of each finished episode of batch
|
||||
should be marked by done flag, unfinished (or collecting) episodes will be
|
||||
recongized by buffer.unfinished_index().
|
||||
:param np.ndarray indice: tell batch's location in buffer, batch is
|
||||
equal to buffer[indice].
|
||||
:param numpy.ndarray indice: tell batch's location in buffer, batch is equal to
|
||||
buffer[indice].
|
||||
:param np.ndarray v_s_: the value function of all next states :math:`V(s')`.
|
||||
:param float gamma: the discount factor, should be in [0, 1]. Default to 0.99.
|
||||
:param float gae_lambda: the parameter for Generalized Advantage Estimation,
|
||||
|
Loading…
x
Reference in New Issue
Block a user