Improve buffer.prev() & buffer.next() (#294)
parent 7036073649
commit e99e1b0fdd
@@ -67,7 +67,7 @@ def test_replaybuffer(size=10, bufsize=20):
     assert b.info.a[1] == 4 and b.info.b.c[1] == 0
     assert b.info.d.e[1] == -np.inf
     # test batch-style adding method, where len(batch) == 1
-    batch.done = 1
+    batch.done = [1]
     batch.info.e = np.zeros([1, 4])
     batch = Batch.stack([batch])
     ptr, ep_rew, ep_len, ep_idx = b.add(batch, buffer_ids=[0])
@@ -79,6 +79,13 @@ def test_replaybuffer(size=10, bufsize=20):
     assert b.info.e.shape == (b.maxsize, 1, 4)
     with pytest.raises(IndexError):
         b[22]
+    # test prev / next
+    assert np.all(b.prev(np.array([0, 1, 2])) == [0, 1, 1])
+    assert np.all(b.next(np.array([0, 1, 2])) == [0, 2, 2])
+    batch.done = [0]
+    b.add(batch, buffer_ids=[0])
+    assert np.all(b.prev(np.array([0, 1, 2, 3])) == [0, 1, 1, 3])
+    assert np.all(b.next(np.array([0, 1, 2, 3])) == [0, 2, 2, 3])


 def test_ignore_obs_next(size=10):
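The new assertions pin down the contract of prev() and next(): prev() maps an index to the previous transition inside the same episode and leaves an episode-start index unchanged, while next() maps to the following transition and leaves an episode-end (or the newest, still-collecting) index unchanged. The standalone NumPy sketch below mirrors the single-buffer formula that appears further down in this diff; the buffer state (done flags, last written slot) is a made-up one chosen so the outputs match the asserted values, not necessarily the exact state of b in the test.

import numpy as np

done = np.array([True, False, True])  # hypothetical: a 1-step episode at 0, an episode ending at 2
size, last = 3, 2                     # 3 slots in use, slot 2 written most recently

def prev(index):
    index = (np.asarray(index) - 1) % size
    end_flag = done[index] | (index == last)
    return (index + end_flag) % size

def next_(index):
    index = np.asarray(index)
    end_flag = done[index] | (index == last)
    return (index + (1 - end_flag)) % size

print(prev(np.array([0, 1, 2])))    # [0 1 1]
print(next_(np.array([0, 1, 2])))   # [0 2 2]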
@@ -718,7 +725,6 @@ if __name__ == '__main__':
     test_stack()
     test_segtree()
     test_priortized_replaybuffer()
-    test_priortized_replaybuffer(233333, 200000)
     test_update()
     test_pickle()
     test_hdf5()
@@ -1,6 +1,7 @@
 import h5py
 import torch
 import numpy as np
+from numba import njit
 from typing import Any, Dict, List, Tuple, Union, Sequence, Optional

 from tianshou.data.batch import _create_value
@@ -116,6 +117,7 @@ class ReplayBuffer:

     def reset(self) -> None:
         """Clear all the data in replay buffer and episode statistics."""
+        self.last_index = np.array([0])
         self._index = self._size = 0
         self._ep_rew, self._ep_len, self._ep_idx = 0.0, 0, 0

@@ -137,7 +139,7 @@ class ReplayBuffer:
         The index won't be modified if it is the beginning of an episode.
         """
         index = (index - 1) % self._size
-        end_flag = self.done[index] | np.isin(index, self.unfinished_index())
+        end_flag = self.done[index] | (index == self.last_index[0])
         return (index + end_flag) % self._size

     def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
@@ -145,7 +147,7 @@

         The index won't be modified if it is the end of an episode.
         """
-        end_flag = self.done[index] | np.isin(index, self.unfinished_index())
+        end_flag = self.done[index] | (index == self.last_index[0])
         return (index + (1 - end_flag)) % self._size

     def update(self, buffer: "ReplayBuffer") -> np.ndarray:
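Both prev() and next() previously detected the boundary of the still-collecting episode with np.isin(index, self.unfinished_index()), which builds and scans an index array on every call. The new code keeps the most recently written slot in self.last_index[0], so the same boundary test becomes a single comparison. A small sketch of the equivalence on a made-up state (unfinished_index() itself is not shown in this diff; the value below is what it would plausibly report for this state):

import numpy as np

done = np.array([False, True, False, False, False])
last_index = 4                    # most recently written slot
unfinished_tail = np.array([4])   # assumed report of unfinished_index() here

index = np.arange(5)
old_flag = done[index] | np.isin(index, unfinished_tail)
new_flag = done[index] | (index == last_index)
assert np.array_equal(old_flag, new_flag)  # both mark slots 1 and 4 as episode ends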
@@ -163,6 +165,7 @@ class ReplayBuffer:
         to_indices = []
         for _ in range(len(from_indices)):
             to_indices.append(self._index)
+            self.last_index[0] = self._index
             self._index = (self._index + 1) % self.maxsize
             self._size = min(self._size + 1, self.maxsize)
         to_indices = np.array(to_indices)
@@ -180,7 +183,7 @@ class ReplayBuffer:
         Return (index_to_be_modified, episode_reward, episode_length,
         episode_start_index).
        """
-        ptr = self._index
+        self.last_index[0] = ptr = self._index
         self._size = min(self._size + 1, self.maxsize)
         self._index = (self._index + 1) % self.maxsize

@@ -296,6 +299,13 @@ class ReplayBuffer:
         """Return the stacked result.

         E.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t], where s is self.key, t is the index.

+        :param index: the index for getting stacked data (t in the example).
+        :param str key: the key to get, should be one of the reserved_keys.
+        :param default_value: if the given key's data is not found and default_value is
+            set, return this default_value.
+        :param int stack_num: the stack num (4 in the example). Default to
+            self.stack_num.
         """
         if key not in self._meta and default_value is not None:
             return default_value
@@ -306,7 +316,10 @@ class ReplayBuffer:
         if stack_num == 1:  # the most often case
             return val[index]
         stack: List[Any] = []
-        indice = np.asarray(index)
+        if isinstance(index, list):
+            indice = np.array(index)
+        else:
+            indice = index
         for _ in range(stack_num):
             stack = [val[indice]] + stack
             indice = self.prev(indice)
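A plausible reason for swapping the unconditional np.asarray(index) for an explicit type check (an assumption, the diff does not state it) is that scalar indices now stay scalars, which lines up with how the manager's rewritten prev()/next() below dispatch on the index type and return a scalar for scalar input:

import numpy as np

i = 3
print(type(np.asarray(i)))                 # <class 'numpy.ndarray'>, a 0-d array
print(isinstance(i, (list, np.ndarray)))   # False, so the scalar branch applies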
@@ -453,12 +466,24 @@ class ReplayBufferManager(ReplayBuffer):
             offset.append(size)
             size += buf.maxsize
         self._offset = np.array(offset)
+        self._extend_offset = np.array(offset + [size])
+        self._lengths = np.zeros_like(offset)
         super().__init__(size=size, **kwargs)
+        self._compile()
+
+    def _compile(self) -> None:
+        lens = last = index = np.array([0])
+        offset = np.array([0, 1])
+        done = np.array([False, False])
+        _prev_index(index, offset, done, last, lens)
+        _next_index(index, offset, done, last, lens)

     def __len__(self) -> int:
-        return sum([len(buf) for buf in self.buffers])
+        return self._lengths.sum()

     def reset(self) -> None:
+        self.last_index = self._offset.copy()
+        self._lengths = np.zeros_like(self._offset)
         for buf in self.buffers:
             buf.reset()

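_compile() looks like a warm-up: it calls the two numba kernels once on tiny dummy arrays at construction time, so the JIT compilation cost is not paid inside the first real prev()/next() call. A generic sketch of the same pattern (the kernel name is illustrative, not from this commit):

import numpy as np
from numba import njit

@njit
def _toy_kernel(x):
    # numba compiles this body on the first call for the given argument types
    return x.sum()

_toy_kernel(np.array([0]))  # warm-up call: compilation happens here, once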
@@ -477,22 +502,20 @@ class ReplayBufferManager(ReplayBuffer):
         ])

     def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
-        index = np.asarray(index) % self.maxsize
-        prev_indices = np.zeros_like(index)
-        for offset, buf in zip(self._offset, self.buffers):
-            mask = (offset <= index) & (index < offset + buf.maxsize)
-            if np.any(mask):
-                prev_indices[mask] = buf.prev(index[mask] - offset) + offset
-        return prev_indices
+        if isinstance(index, (list, np.ndarray)):
+            return _prev_index(np.asarray(index), self._extend_offset,
+                               self.done, self.last_index, self._lengths)
+        else:
+            return _prev_index(np.array([index]), self._extend_offset,
+                               self.done, self.last_index, self._lengths)[0]

     def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
-        index = np.asarray(index) % self.maxsize
-        next_indices = np.zeros_like(index)
-        for offset, buf in zip(self._offset, self.buffers):
-            mask = (offset <= index) & (index < offset + buf.maxsize)
-            if np.any(mask):
-                next_indices[mask] = buf.next(index[mask] - offset) + offset
-        return next_indices
+        if isinstance(index, (list, np.ndarray)):
+            return _next_index(np.asarray(index), self._extend_offset,
+                               self.done, self.last_index, self._lengths)
+        else:
+            return _next_index(np.array([index]), self._extend_offset,
+                               self.done, self.last_index, self._lengths)[0]

     def update(self, buffer: ReplayBuffer) -> np.ndarray:
         """The ReplayBufferManager cannot be updated by any buffer."""
@@ -534,6 +557,8 @@ class ReplayBufferManager(ReplayBuffer):
             ep_lens.append(ep_len)
            ep_rews.append(ep_rew)
            ep_idxs.append(ep_idx + self._offset[buffer_id])
+            self.last_index[buffer_id] = ptr + self._offset[buffer_id]
+            self._lengths[buffer_id] = len(self.buffers[buffer_id])
         ptrs = np.array(ptrs)
         try:
             self._meta[ptrs] = batch
@@ -564,9 +589,8 @@ class ReplayBufferManager(ReplayBuffer):
         if batch_size == 0:  # get all available indices
             sample_num = np.zeros(self.buffer_num, np.int)
         else:
-            buffer_lens = np.array([len(buf) for buf in self.buffers])
             buffer_idx = np.random.choice(
-                self.buffer_num, batch_size, p=buffer_lens / buffer_lens.sum()
+                self.buffer_num, batch_size, p=self._lengths / self._lengths.sum()
             )
             sample_num = np.bincount(buffer_idx, minlength=self.buffer_num)
         # avoid batch_size > 0 and sample_num == 0 -> get child's all data
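Because self._lengths is kept in sync in add() and reset(), the sampler can split a batch across sub-buffers proportionally to their cached lengths instead of recomputing len(buf) for every call. A standalone sketch of that allocation step with made-up numbers:

import numpy as np

lengths = np.array([5, 0, 3])  # hypothetical per-sub-buffer lengths
batch_size = 8
buffer_idx = np.random.choice(len(lengths), batch_size, p=lengths / lengths.sum())
sample_num = np.bincount(buffer_idx, minlength=len(lengths))
print(sample_num)  # a random split of the 8 draws, e.g. [5 0 3]; the empty buffer gets 0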
@@ -726,6 +750,51 @@ class CachedReplayBuffer(ReplayBufferManager):
             updated_ep_idx.append(index[0])
             updated_ptr.append(index[-1])
             self.buffers[buffer_idx].reset()
+            self._lengths[0] = len(self.main_buffer)
+            self._lengths[buffer_idx] = 0
+            self.last_index[0] = index[-1]
+            self.last_index[buffer_idx] = self._offset[buffer_idx]
         ptr[done] = updated_ptr
         ep_idx[done] = updated_ep_idx
         return ptr, ep_rew, ep_len, ep_idx
+
+
+@njit
+def _prev_index(
+    index: np.ndarray,
+    offset: np.ndarray,
+    done: np.ndarray,
+    last_index: np.ndarray,
+    lengths: np.ndarray,
+) -> np.ndarray:
+    index = index % offset[-1]
+    prev_index = np.zeros_like(index)
+    for start, end, cur_len, last in zip(offset[:-1], offset[1:], lengths, last_index):
+        mask = (start <= index) & (index < end)
+        cur_len = max(1, cur_len)
+        if np.sum(mask) > 0:
+            subind = index[mask]
+            subind = (subind - start - 1) % cur_len
+            end_flag = done[subind + start] | (subind + start == last)
+            prev_index[mask] = (subind + end_flag) % cur_len + start
+    return prev_index
+
+
+@njit
+def _next_index(
+    index: np.ndarray,
+    offset: np.ndarray,
+    done: np.ndarray,
+    last_index: np.ndarray,
+    lengths: np.ndarray,
+) -> np.ndarray:
+    index = index % offset[-1]
+    next_index = np.zeros_like(index)
+    for start, end, cur_len, last in zip(offset[:-1], offset[1:], lengths, last_index):
+        mask = (start <= index) & (index < end)
+        cur_len = max(1, cur_len)
+        if np.sum(mask) > 0:
+            subind = index[mask]
+            end_flag = done[subind] | (subind == last)
+            next_index[mask] = (subind - start + 1 - end_flag) % cur_len + start
+    return next_index
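To see what the two kernels compute, here is a hand-traced call on a made-up layout of two sub-buffers of size 3 (offset [0, 3, 6]). Buffer 0 holds one finished episode in slots 0-2; buffer 1 holds a finished episode in slots 3-4 plus one freshly written, still-open step in slot 5. The expected outputs are traced by hand from the code above, assuming _prev_index and _next_index are in scope (for example, running inside the same module):

import numpy as np

offset = np.array([0, 3, 6])             # sub-buffer boundaries
done = np.array([False, False, True,     # buffer 0: one finished episode
                 False, True, False])    # buffer 1: finished episode + 1 new step
last_index = np.array([2, 5])            # last slot written in each sub-buffer
lengths = np.array([3, 3])               # both sub-buffers are full

index = np.arange(6)
print(_prev_index(index, offset, done, last_index, lengths))
# -> [0 0 1 3 3 5]: slots 0 and 3 start episodes, slot 5 starts the still-open one
print(_next_index(index, offset, done, last_index, lengths))
# -> [1 2 2 4 4 5]: slots 2 and 4 end episodes, slot 5 is the newest (unfinished) step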
@@ -220,12 +220,12 @@ class BasePolicy(ABC, nn.Module):
         Use Implementation of Generalized Advantage Estimator (arXiv:1506.02438)
         to calculate q function/reward to go of given batch.

-        :param Batch batch: a data batch which contains several episodes of data
-            in sequential order. Mind that the end of each finished episode of batch
+        :param Batch batch: a data batch which contains several episodes of data in
+            sequential order. Mind that the end of each finished episode of batch
             should be marked by done flag, unfinished (or collecting) episodes will be
             recongized by buffer.unfinished_index().
-        :param np.ndarray indice: tell batch's location in buffer, batch is
-            equal to buffer[indice].
+        :param numpy.ndarray indice: tell batch's location in buffer, batch is equal to
+            buffer[indice].
         :param np.ndarray v_s_: the value function of all next states :math:`V(s')`.
         :param float gamma: the discount factor, should be in [0, 1]. Default to 0.99.
         :param float gae_lambda: the parameter for Generalized Advantage Estimation,