Add CachedReplayBuffer and ReplayBufferManager (#278)
This is the second of the 6 commits mentioned in #274. It makes a minor refactor of ReplayBuffer and adds two new buffer classes, CachedReplayBuffer and ReplayBufferManager. See #274 for more detail.

1. Add ReplayBufferManager (handles a list of buffers) and CachedReplayBuffer;
2. Make sure the reserved keys cannot be edited by assignments like `buffer.done = xxx`;
3. Add a `set_batch` method for manually choosing the batch the ReplayBuffer handles;
4. Add a `sample_index` method, the same as `sample` but returning only the index instead of both the index and the batch data;
5. Add `prev` (one-step previous transition index), `next` (one-step next transition index) and `unfinished_index` (the last modified index whose done==False);
6. Separate out an `alloc_fn` method for allocating new memory for `self._meta` when a new `(key, value)` pair comes in;
7. Move the buffer's documentation to `docs/tutorials/concepts.rst`.

Co-authored-by: n+e <trinkle23897@gmail.com>
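Item 2 can be illustrated with a small guard on attribute assignment. This is only a sketch of the idea, not the code in this commit: `GuardedBuffer` is a hypothetical class, the key list is taken from the documentation added below, and raising `AttributeError` is an assumption made for the example.

```python
class GuardedBuffer:
    """Toy container whose reserved attributes cannot be reassigned."""

    # The 7 reserved keys named in the documentation part of this commit.
    _reserved_keys = ("obs", "act", "rew", "done", "obs_next", "info", "policy")

    def __setattr__(self, key, value):
        # Refuse direct assignment such as `buffer.done = xxx`.
        if key in self._reserved_keys:
            raise AttributeError(
                f"key '{key}' is reserved and cannot be assigned")
        super().__setattr__(key, value)


buf = GuardedBuffer()
buf.note = "non-reserved attributes can still be set"
try:
    buf.done = True
except AttributeError as err:
    print(err)
```

Data for the reserved keys would then have to go through a dedicated method (such as `add` or `set_batch` in the commit), which keeps the stored arrays consistent with one another.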
parent 1eb6137645
commit f0129f4ca7
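The `prev` and `next` semantics from item 5 are easiest to see on a small layout. The following is a minimal NumPy sketch, not the implementation in this commit: the helper names `prev_index` and `next_index` and the explicit `done` and `last` arguments are assumptions made for illustration. An episode boundary is taken to be any slot whose `done` flag is set, plus the most recently written slot.

```python
import numpy as np


def prev_index(index, done, last):
    """One-step previous index; stays put at the first step of an episode."""
    index = np.asarray(index)
    size = len(done)
    before = (index - 1) % size
    # If the slot before closes an episode (done, or newest slot written),
    # `index` is the first step of its own episode and has no predecessor.
    boundary = done[before] | (before == last)
    return np.where(boundary, index, before)


def next_index(index, done, last):
    """One-step next index; stays put at the last step of an episode."""
    index = np.asarray(index)
    size = len(done)
    boundary = done[index] | (index == last)
    return np.where(boundary, index, (index + 1) % size)


# 13 stored transitions; the steps in slots 6 and 10 ended an episode.
done = np.zeros(13, dtype=bool)
done[[6, 10]] = True
last = 12  # slot written most recently
indice = np.arange(13)

print(prev_index(indice, done, last))
print(next_index(indice, done, last))
```

On this layout the two calls give the same arrays as the `buf.prev(indice)` and `buf.next(indice)` outputs in the documentation example that this commit adds.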
@@ -53,11 +53,163 @@ In short, you can define a :class:`~tianshou.data.Batch` with any key-value pair
 Buffer
 ------
 
-.. automodule:: tianshou.data.ReplayBuffer
-    :members:
-    :noindex:
-
-Tianshou provides other type of data buffer such as :class:`~tianshou.data.ListReplayBuffer` (based on list), :class:`~tianshou.data.PrioritizedReplayBuffer` (based on Segment Tree and ``numpy.ndarray``). Check out :class:`~tianshou.data.ReplayBuffer` for more detail.
+:class:`~tianshou.data.ReplayBuffer` stores data generated from interaction between the policy and environment. ReplayBuffer can be considered as a specialized form (or management) of Batch. It stores all the data in a batch with circular-queue style.
+
+The current implementation of Tianshou typically use 7 reserved keys in
+:class:`~tianshou.data.Batch`:
+
+* ``obs`` the observation of step :math:`t` ;
+* ``act`` the action of step :math:`t` ;
+* ``rew`` the reward of step :math:`t` ;
+* ``done`` the done flag of step :math:`t` ;
+* ``obs_next`` the observation of step :math:`t+1` ;
+* ``info`` the info of step :math:`t` (in ``gym.Env``, the ``env.step()`` function returns 4 arguments, and the last one is ``info``);
+* ``policy`` the data computed by policy in step :math:`t`;
+
+The following code snippet illustrates its usage, including:
+
+- the basic data storage: ``add()``;
+- get attribute, get slicing data, ...;
+- sample from buffer: ``sample_index(batch_size)`` and ``sample(batch_size)``;
+- get previous/next transition index within episodes: ``prev(index)`` and ``next(index)``;
+- save/load data from buffer: pickle and HDF5;
+
+::
+
+    >>> import pickle, numpy as np
+    >>> from tianshou.data import ReplayBuffer
+    >>> buf = ReplayBuffer(size=20)
+    >>> for i in range(3):
+    ...     buf.add(obs=i, act=i, rew=i, done=0, obs_next=i + 1, info={})
+
+    >>> buf.obs
+    # since we set size = 20, len(buf.obs) == 20.
+    array([0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
+    >>> # but there are only three valid items, so len(buf) == 3.
+    >>> len(buf)
+    3
+    >>> # save to file "buf.pkl"
+    >>> pickle.dump(buf, open('buf.pkl', 'wb'))
+    >>> # save to HDF5 file
+    >>> buf.save_hdf5('buf.hdf5')
+
+    >>> buf2 = ReplayBuffer(size=10)
+    >>> for i in range(15):
+    ...     done = i % 4 == 0
+    ...     buf2.add(obs=i, act=i, rew=i, done=done, obs_next=i + 1, info={})
+    >>> len(buf2)
+    10
+    >>> buf2.obs
+    # since its size = 10, it only stores the last 10 steps' result.
+    array([10, 11, 12, 13, 14,  5,  6,  7,  8,  9])
+
+    >>> # move buf2's result into buf (meanwhile keep it chronologically)
+    >>> buf.update(buf2)
+    >>> buf.obs
+    array([ 0,  1,  2,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14,  0,  0,  0,
+            0,  0,  0,  0])
+
+    >>> # get all available index by using batch_size = 0
+    >>> indice = buf.sample_index(0)
+    >>> indice
+    array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])
+    >>> # get one step previous/next transition
+    >>> buf.prev(indice)
+    array([ 0,  0,  1,  2,  3,  4,  5,  7,  7,  8,  9, 11, 11])
+    >>> buf.next(indice)
+    array([ 1,  2,  3,  4,  5,  6,  6,  8,  9, 10, 10, 12, 12])
+
+    >>> # get a random sample from buffer
+    >>> # the batch_data is equal to buf[indice].
+    >>> batch_data, indice = buf.sample(batch_size=4)
+    >>> batch_data.obs == buf[indice].obs
+    array([ True,  True,  True,  True])
+    >>> len(buf)
+    13
+
+    >>> buf = pickle.load(open('buf.pkl', 'rb'))  # load from "buf.pkl"
+    >>> len(buf)
+    3
+    >>> # load complete buffer from HDF5 file
+    >>> buf = ReplayBuffer.load_hdf5('buf.hdf5')
+    >>> len(buf)
+    3
+
+:class:`~tianshou.data.ReplayBuffer` also supports frame_stack sampling (typically for RNN usage, see issue#19), ignoring storing the next observation (save memory in Atari tasks), and multi-modal observation (see issue#38):
+
+.. raw:: html
+
+   <details>
+   <summary>Advance usage of ReplayBuffer</summary>
+
+.. code-block:: python
+
+    >>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
+    >>> for i in range(16):
+    ...     done = i % 5 == 0
+    ...     ep_len, ep_rew = buf.add(obs={'id': i}, act=i, rew=i,
+    ...                              done=done, obs_next={'id': i + 1})
+    ...     print(i, ep_len, ep_rew)
+    0 1 0.0
+    1 0 0.0
+    2 0 0.0
+    3 0 0.0
+    4 0 0.0
+    5 5 15.0
+    6 0 0.0
+    7 0 0.0
+    8 0 0.0
+    9 0 0.0
+    10 5 40.0
+    11 0 0.0
+    12 0 0.0
+    13 0 0.0
+    14 0 0.0
+    15 5 65.0
+    >>> print(buf)  # you can see obs_next is not saved in buf
+    ReplayBuffer(
+        obs: Batch(
+                 id: array([ 9, 10, 11, 12, 13, 14, 15,  7,  8]),
+             ),
+        act: array([ 9, 10, 11, 12, 13, 14, 15,  7,  8]),
+        rew: array([ 9., 10., 11., 12., 13., 14., 15.,  7.,  8.]),
+        done: array([False,  True, False, False, False, False,  True, False,
+                     False]),
+        info: Batch(),
+        policy: Batch(),
+    )
+    >>> index = np.arange(len(buf))
+    >>> print(buf.get(index, 'obs').id)
+    [[ 7  7  8  9]
+     [ 7  8  9 10]
+     [11 11 11 11]
+     [11 11 11 12]
+     [11 11 12 13]
+     [11 12 13 14]
+     [12 13 14 15]
+     [ 7  7  7  7]
+     [ 7  7  7  8]]
+    >>> # here is another way to get the stacked data
+    >>> # (stack only for obs and obs_next)
+    >>> abs(buf.get(index, 'obs')['id'] - buf[index].obs.id).sum().sum()
+    0
+    >>> # we can get obs_next through __getitem__, even if it doesn't exist
+    >>> print(buf[:].obs_next.id)
+    [[ 7  8  9 10]
+     [ 7  8  9 10]
+     [11 11 11 12]
+     [11 11 12 13]
+     [11 12 13 14]
+     [12 13 14 15]
+     [12 13 14 15]
+     [ 7  7  7  8]
+     [ 7  7  8  9]]
+
+.. raw:: html
+
+   </details><br>
+
+Tianshou provides other type of data buffer such as :class:`~tianshou.data.ListReplayBuffer` (based on list), :class:`~tianshou.data.PrioritizedReplayBuffer` (based on Segment Tree and ``numpy.ndarray``), :class:`~tianshou.data.CachedReplayBuffer` (add different episodes' data but without losing chronological order). Check out :class:`~tianshou.data.ReplayBuffer` for more detail.
 
 
 Policy
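The wrap-around storage and the stacked-frame lookup shown in the documentation hunk above can be reproduced with plain NumPy. This is a sketch under stated assumptions rather than the code of this commit: `write_circular` and `stack_frames` are hypothetical helpers, the buffer is assumed to be full, and stacking is modelled as repeatedly stepping to the previous slot without crossing an episode boundary.

```python
import numpy as np


def write_circular(maxsize, steps):
    """Write (obs, done) pairs into fixed-size arrays, overwriting the oldest."""
    obs = np.zeros(maxsize, dtype=int)
    done = np.zeros(maxsize, dtype=bool)
    last = -1
    for i, (o, d) in enumerate(steps):
        last = i % maxsize  # slots are reused once the buffer is full
        obs[last], done[last] = o, d
    return obs, done, last


def stack_frames(obs, done, last, index, stack_num):
    """Gather `stack_num` consecutive frames ending at each index.

    Assumes every slot holds valid data (the buffer has wrapped around).
    """
    size = len(obs)
    frames = [np.asarray(index)]
    for _ in range(stack_num - 1):
        cur = frames[-1]
        before = (cur - 1) % size
        # Do not step back across an episode end or past the newest slot;
        # repeat the current frame instead.
        boundary = done[before] | (before == last)
        frames.append(np.where(boundary, cur, before))
    # Oldest frame first, current frame last.
    return obs[np.stack(frames[::-1], axis=-1)]


# Same data as the documentation example: 16 steps written into 9 slots.
obs, done, last = write_circular(9, [(i, i % 5 == 0) for i in range(16)])
stacked = stack_frames(obs, done, last, np.arange(9), stack_num=4)
print(obs)
print(stacked)
```

With this input, `obs` ends up in the same order as the `id` array printed by `print(buf)` in the example, and `stacked` matches the matrix printed by `buf.get(index, 'obs').id`.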
@@ -7,9 +7,10 @@ import h5py
 import numpy as np
 from timeit import timeit
 
-from tianshou.data import Batch, SegmentTree, \
-    ReplayBuffer, ListReplayBuffer, PrioritizedReplayBuffer
 from tianshou.data.utils.converter import to_hdf5
+from tianshou.data import Batch, SegmentTree, ReplayBuffer
+from tianshou.data import ListReplayBuffer, PrioritizedReplayBuffer
+from tianshou.data import ReplayBufferManager, CachedReplayBuffer
 
 if __name__ == '__main__':
     from env import MyTestEnv
@@ -38,11 +39,14 @@ def test_replaybuffer(size=10, bufsize=20):
     assert (data.obs < size).all()
     assert (0 <= data.done).all() and (data.done <= 1).all()
     b = ReplayBuffer(size=10)
-    b.add(1, 1, 1, 'str', 1, {'a': 3, 'b': {'c': 5.0}})
+    # neg bsz should return empty index
+    assert b.sample_index(-1).tolist() == []
+    b.add(1, 1, 1, 1, 'str', {'a': 3, 'b': {'c': 5.0}})
     assert b.obs[0] == 1
-    assert b.done[0] == 'str'
+    assert b.done[0]
+    assert b.obs_next[0] == 'str'
     assert np.all(b.obs[1:] == 0)
-    assert np.all(b.done[1:] == np.array(None))
+    assert np.all(b.obs_next[1:] == np.array(None))
     assert b.info.a[0] == 3 and b.info.a.dtype == np.integer
     assert np.all(b.info.a[1:] == 0)
     assert b.info.b.c[0] == 5.0 and b.info.b.c.dtype == np.inexact
@@ -91,7 +95,7 @@ def test_ignore_obs_next(size=10):
     assert data.obs_next
 
 
-def test_stack(size=5, bufsize=9, stack_num=4):
+def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3):
     env = MyTestEnv(size)
     buf = ReplayBuffer(bufsize, stack_num=stack_num)
     buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True)
@@ -115,7 +119,9 @@ def test_stack(size=5, bufsize=9, stack_num=4):
     _, indice = buf2.sample(0)
     assert indice.tolist() == [2, 6]
     _, indice = buf2.sample(1)
-    assert indice in [2, 6]
+    assert indice[0] in [2, 6]
+    batch, indice = buf2.sample(-1)  # neg bsz -> no data
+    assert indice.tolist() == [] and len(batch) == 0
     with pytest.raises(IndexError):
         buf[bufsize * 2]
 
@@ -152,6 +158,12 @@ def test_update():
     assert len(buf1) == len(buf2)
     assert (buf2[0].obs == buf1[1].obs).all()
     assert (buf2[-1].obs == buf1[0].obs).all()
+    b = ListReplayBuffer()
+    with pytest.raises(NotImplementedError):
+        b.update(b)
+    b = CachedReplayBuffer(ReplayBuffer(10), 4, 5)
+    with pytest.raises(NotImplementedError):
+        b.update(b)
 
 
 def test_segtree():
@@ -260,8 +272,7 @@ def test_pickle():
     vbuf = ReplayBuffer(size, stack_num=2)
     lbuf = ListReplayBuffer()
     pbuf = PrioritizedReplayBuffer(size, 0.6, 0.4)
-    device = 'cuda' if torch.cuda.is_available() else 'cpu'
-    rew = torch.tensor([1.]).to(device)
+    rew = np.array([1, 1])
     for i in range(4):
         vbuf.add(obs=Batch(index=np.array([i])), act=0, rew=rew, done=0)
     for i in range(3):
@@ -287,18 +298,18 @@ def test_hdf5():
     buffers = {
         "array": ReplayBuffer(size, stack_num=2),
         "list": ListReplayBuffer(),
-        "prioritized": PrioritizedReplayBuffer(size, 0.6, 0.4)
+        "prioritized": PrioritizedReplayBuffer(size, 0.6, 0.4),
     }
     buffer_types = {k: b.__class__ for k, b in buffers.items()}
     device = 'cuda' if torch.cuda.is_available() else 'cpu'
-    rew = torch.tensor([1.]).to(device)
+    info_t = torch.tensor([1.]).to(device)
     for i in range(4):
         kwargs = {
             'obs': Batch(index=np.array([i])),
             'act': i,
-            'rew': rew,
-            'done': 0,
-            'info': {"number": {"n": i}, 'extra': None},
+            'rew': np.array([1, 2]),
+            'done': i % 3 == 2,
+            'info': {"number": {"n": i, "t": info_t}, 'extra': None},
         }
         buffers["array"].add(**kwargs)
         buffers["list"].add(**kwargs)
@@ -320,10 +331,10 @@ def test_hdf5():
         assert len(_buffers[k]) == len(buffers[k])
         assert np.allclose(_buffers[k].act, buffers[k].act)
         assert _buffers[k].stack_num == buffers[k].stack_num
-        assert _buffers[k]._maxsize == buffers[k]._maxsize
-        assert _buffers[k]._index == buffers[k]._index
+        assert _buffers[k].maxsize == buffers[k].maxsize
         assert np.all(_buffers[k]._indices == buffers[k]._indices)
     for k in ["array", "prioritized"]:
+        assert _buffers[k]._index == buffers[k]._index
         assert isinstance(buffers[k].get(0, "info"), Batch)
         assert isinstance(_buffers[k].get(0, "info"), Batch)
     for k in ["array"]:
@@ -332,28 +343,350 @@ def test_hdf5():
         assert np.all(
             buffers[k][:].info.extra == _buffers[k][:].info.extra)
 
-    for path in paths.values():
-        os.remove(path)
-
     # raise exception when value cannot be pickled
-    data = {"not_supported": lambda x: x*x}
+    data = {"not_supported": lambda x: x * x}
     grp = h5py.Group
     with pytest.raises(NotImplementedError):
         to_hdf5(data, grp)
     # ndarray with data type not supported by HDF5 that cannot be pickled
-    data = {"not_supported": np.array(lambda x: x*x)}
+    data = {"not_supported": np.array(lambda x: x * x)}
     grp = h5py.Group
     with pytest.raises(RuntimeError):
         to_hdf5(data, grp)
 
 
+def test_replaybuffermanager():
+    buf = ReplayBufferManager([ReplayBuffer(size=5) for i in range(4)])
+    ep_len, ep_rew = buf.add(obs=[1, 2, 3], act=[1, 2, 3], rew=[1, 2, 3],
+                             done=[0, 0, 1], buffer_ids=[0, 1, 2])
+    assert np.allclose(ep_len, [0, 0, 1]) and np.allclose(ep_rew, [0, 0, 3])
+    with pytest.raises(NotImplementedError):
+        # ReplayBufferManager cannot be updated
+        buf.update(buf)
+    # sample index / prev / next / unfinished_index
+    indice = buf.sample_index(11000)
+    assert np.bincount(indice)[[0, 5, 10]].min() >= 3000  # uniform sample
+    batch, indice = buf.sample(0)
+    assert np.allclose(indice, [0, 5, 10])
+    indice_prev = buf.prev(indice)
+    assert np.allclose(indice_prev, indice), indice_prev
+    indice_next = buf.next(indice)
+    assert np.allclose(indice_next, indice), indice_next
+    assert np.allclose(buf.unfinished_index(), [0, 5])
+    buf.add(obs=[4], act=[4], rew=[4], done=[1], buffer_ids=[3])
+    assert np.allclose(buf.unfinished_index(), [0, 5])
+    batch, indice = buf.sample(10)
+    batch, indice = buf.sample(0)
+    assert np.allclose(indice, [0, 5, 10, 15])
+    indice_prev = buf.prev(indice)
+    assert np.allclose(indice_prev, indice), indice_prev
+    indice_next = buf.next(indice)
+    assert np.allclose(indice_next, indice), indice_next
+    data = np.array([0, 0, 0, 0])
+    buf.add(obs=data, act=data, rew=data, done=data, buffer_ids=[0, 1, 2, 3])
+    buf.add(obs=data, act=data, rew=data, done=1 - data,
+            buffer_ids=[0, 1, 2, 3])
+    assert len(buf) == 12
+    buf.add(obs=data, act=data, rew=data, done=data, buffer_ids=[0, 1, 2, 3])
+    buf.add(obs=data, act=data, rew=data, done=[0, 1, 0, 1],
+            buffer_ids=[0, 1, 2, 3])
+    assert len(buf) == 20
+    indice = buf.sample_index(120000)
+    assert np.bincount(indice).min() >= 5000
+    batch, indice = buf.sample(10)
+    indice = buf.sample_index(0)
+    assert np.allclose(indice, np.arange(len(buf)))
+    # check the actual data stored in buf._meta
+    assert np.allclose(buf.done, [
+        0, 0, 1, 0, 0,
+        0, 0, 1, 0, 1,
+        1, 0, 1, 0, 0,
+        1, 0, 1, 0, 1,
+    ])
+    assert np.allclose(buf.prev(indice), [
+        0, 0, 1, 3, 3,
+        5, 5, 6, 8, 8,
+        10, 11, 11, 13, 13,
+        15, 16, 16, 18, 18,
+    ])
+    assert np.allclose(buf.next(indice), [
+        1, 2, 2, 4, 4,
+        6, 7, 7, 9, 9,
+        10, 12, 12, 14, 14,
+        15, 17, 17, 19, 19,
+    ])
+    assert np.allclose(buf.unfinished_index(), [4, 14])
+    ep_len, ep_rew = buf.add(obs=[1], act=[1], rew=[1], done=[1],
+                             buffer_ids=[2])
+    assert np.allclose(ep_len, [3]) and np.allclose(ep_rew, [1])
+    assert np.allclose(buf.unfinished_index(), [4])
+    indice = list(sorted(buf.sample_index(0)))
+    assert np.allclose(indice, np.arange(len(buf)))
+    assert np.allclose(buf.prev(indice), [
+        0, 0, 1, 3, 3,
+        5, 5, 6, 8, 8,
+        14, 11, 11, 13, 13,
+        15, 16, 16, 18, 18,
+    ])
+    assert np.allclose(buf.next(indice), [
+        1, 2, 2, 4, 4,
+        6, 7, 7, 9, 9,
+        10, 12, 12, 14, 10,
+        15, 17, 17, 19, 19,
+    ])
+    # corner case: list, int and -1
+    assert buf.prev(-1) == buf.prev([buf.maxsize - 1])[0]
+    assert buf.next(-1) == buf.next([buf.maxsize - 1])[0]
+    batch = buf._meta
+    batch.info.n = np.ones(buf.maxsize)
+    buf.set_batch(batch)
+    assert np.allclose(buf.buffers[-1].info.n, [1] * 5)
+    assert buf.sample_index(-1).tolist() == []
+    assert np.array([ReplayBuffer(0, ignore_obs_next=True)]).dtype == np.object
+
+
+def test_cachedbuffer():
+    buf = CachedReplayBuffer(ReplayBuffer(10), 4, 5)
+    assert buf.sample_index(0).tolist() == []
+    # check the normal function/usage/storage in CachedReplayBuffer
+    ep_len, ep_rew = buf.add(obs=[1], act=[1], rew=[1], done=[0],
+                             cached_buffer_ids=[1])
+    obs = np.zeros(buf.maxsize)
+    obs[15] = 1
+    indice = buf.sample_index(0)
+    assert np.allclose(indice, [15])
+    assert np.allclose(buf.prev(indice), [15])
+    assert np.allclose(buf.next(indice), [15])
+    assert np.allclose(buf.obs, obs)
+    assert np.allclose(ep_len, [0]) and np.allclose(ep_rew, [0.0])
+    ep_len, ep_rew = buf.add(obs=[2], act=[2], rew=[2], done=[1],
+                             cached_buffer_ids=[3])
+    obs[[0, 25]] = 2
+    indice = buf.sample_index(0)
+    assert np.allclose(indice, [0, 15])
+    assert np.allclose(buf.prev(indice), [0, 15])
+    assert np.allclose(buf.next(indice), [0, 15])
+    assert np.allclose(buf.obs, obs)
+    assert np.allclose(ep_len, [1]) and np.allclose(ep_rew, [2.0])
+    assert np.allclose(buf.unfinished_index(), [15])
+    assert np.allclose(buf.sample_index(0), [0, 15])
+    ep_len, ep_rew = buf.add(obs=[3, 4], act=[3, 4], rew=[3, 4],
+                             done=[0, 1], cached_buffer_ids=[3, 1])
+    assert np.allclose(ep_len, [0, 2]) and np.allclose(ep_rew, [0, 5.0])
+    obs[[0, 1, 2, 15, 16, 25]] = [2, 1, 4, 1, 4, 3]
+    assert np.allclose(buf.obs, obs)
+    assert np.allclose(buf.unfinished_index(), [25])
+    indice = buf.sample_index(0)
+    assert np.allclose(indice, [0, 1, 2, 25])
+    assert np.allclose(buf.done[indice], [1, 0, 1, 0])
+    assert np.allclose(buf.prev(indice), [0, 1, 1, 25])
+    assert np.allclose(buf.next(indice), [0, 2, 2, 25])
+    indice = buf.sample_index(10000)
+    assert np.bincount(indice)[[0, 1, 2, 25]].min() > 2000  # uniform sample
+    # cached buffer with main_buffer size == 0 (no update)
+    # used in test_collector
+    buf = CachedReplayBuffer(ReplayBuffer(0, sample_avail=True), 4, 5)
+    data = np.zeros(4)
+    rew = np.ones([4, 4])
+    buf.add(obs=data, act=data, rew=rew, done=[0, 0, 1, 1], obs_next=data)
+    buf.add(obs=data, act=data, rew=rew, done=[0, 0, 0, 0], obs_next=data)
+    buf.add(obs=data, act=data, rew=rew, done=[1, 1, 1, 1], obs_next=data)
+    buf.add(obs=data, act=data, rew=rew, done=[0, 0, 0, 0], obs_next=data)
+    buf.add(obs=data, act=data, rew=rew, done=[0, 1, 0, 1], obs_next=data)
+    assert np.allclose(buf.done, [
+        0, 0, 1, 0, 0,
+        0, 1, 1, 0, 0,
+        0, 0, 0, 0, 0,
+        0, 1, 0, 0, 0,
+    ])
+    indice = buf.sample_index(0)
+    assert np.allclose(indice, [0, 1, 10, 11])
+    assert np.allclose(buf.prev(indice), [0, 0, 10, 10])
+    assert np.allclose(buf.next(indice), [1, 1, 11, 11])
+
+
+def test_multibuf_stack():
+    size = 5
+    bufsize = 9
+    stack_num = 4
+    cached_num = 3
+    env = MyTestEnv(size)
+    # test if CachedReplayBuffer can handle stack_num + ignore_obs_next
+    buf4 = CachedReplayBuffer(
+        ReplayBuffer(bufsize, stack_num=stack_num, ignore_obs_next=True),
+        cached_num, size)
+    # test if CachedReplayBuffer can handle super corner case:
+    # prio-buffer + stack_num + ignore_obs_next + sample_avail
+    buf5 = CachedReplayBuffer(
+        PrioritizedReplayBuffer(bufsize, 0.6, 0.4, stack_num=stack_num,
+                                ignore_obs_next=True, sample_avail=True),
+        cached_num, size)
+    obs = env.reset(1)
+    for i in range(18):
+        obs_next, rew, done, info = env.step(1)
+        obs_list = np.array([obs + size * i for i in range(cached_num)])
+        act_list = [1] * cached_num
+        rew_list = [rew] * cached_num
+        done_list = [done] * cached_num
+        obs_next_list = -obs_list
+        info_list = [info] * cached_num
+        buf4.add(obs_list, act_list, rew_list, done_list,
+                 obs_next_list, info_list)
+        buf5.add(obs_list, act_list, rew_list, done_list,
+                 obs_next_list, info_list)
+        obs = obs_next
+        if done:
+            obs = env.reset(1)
+    # check the `add` order is correct
+    assert np.allclose(buf4.obs.reshape(-1), [
+        12, 13, 14, 4, 6, 7, 8, 9, 11,  # main_buffer
+        1, 2, 3, 4, 0,  # cached_buffer[0]
+        6, 7, 8, 9, 0,  # cached_buffer[1]
+        11, 12, 13, 14, 0,  # cached_buffer[2]
+    ]), buf4.obs
+    assert np.allclose(buf4.done, [
+        0, 0, 1, 1, 0, 0, 0, 1, 0,  # main_buffer
+        0, 0, 0, 1, 0,  # cached_buffer[0]
+        0, 0, 0, 1, 0,  # cached_buffer[1]
+        0, 0, 0, 1, 0,  # cached_buffer[2]
+    ]), buf4.done
+    assert np.allclose(buf4.unfinished_index(), [10, 15, 20])
+    indice = sorted(buf4.sample_index(0))
+    assert np.allclose(indice, list(range(bufsize)) + [9, 10, 14, 15, 19, 20])
+    assert np.allclose(buf4[indice].obs[..., 0], [
+        [11, 11, 11, 12], [11, 11, 12, 13], [11, 12, 13, 14],
+        [4, 4, 4, 4], [6, 6, 6, 6], [6, 6, 6, 7],
+        [6, 6, 7, 8], [6, 7, 8, 9], [11, 11, 11, 11],
+        [1, 1, 1, 1], [1, 1, 1, 2], [6, 6, 6, 6],
+        [6, 6, 6, 7], [11, 11, 11, 11], [11, 11, 11, 12],
+    ])
+    assert np.allclose(buf4[indice].obs_next[..., 0], [
+        [11, 11, 12, 13], [11, 12, 13, 14], [11, 12, 13, 14],
+        [4, 4, 4, 4], [6, 6, 6, 7], [6, 6, 7, 8],
+        [6, 7, 8, 9], [6, 7, 8, 9], [11, 11, 11, 12],
+        [1, 1, 1, 2], [1, 1, 1, 2], [6, 6, 6, 7],
+        [6, 6, 6, 7], [11, 11, 11, 12], [11, 11, 11, 12],
+    ])
+    assert np.all(buf4.done == buf5.done)
+    indice = buf5.sample_index(0)
+    assert np.allclose(sorted(indice), [2, 7])
+    assert np.all(np.isin(buf5.sample_index(100), indice))
+    # manually change the stack num
+    buf5.stack_num = 2
+    for buf in buf5.buffers:
+        buf.stack_num = 2
+    indice = buf5.sample_index(0)
+    assert np.allclose(sorted(indice), [0, 1, 2, 5, 6, 7, 10, 15, 20])
+    batch, _ = buf5.sample(0)
+    assert np.allclose(buf5[np.arange(buf5.maxsize)].weight, 1)
+    buf5.update_weight(indice, batch.weight * 0)
+    weight = buf5[np.arange(buf5.maxsize)].weight
+    modified_weight = weight[[0, 1, 2, 5, 6, 7]]
+    assert modified_weight.min() == modified_weight.max()
+    assert modified_weight.max() < 1
+    unmodified_weight = weight[[3, 4, 8]]
+    assert unmodified_weight.min() == unmodified_weight.max()
+    assert unmodified_weight.max() < 1
+    cached_weight = weight[9:]
+    assert cached_weight.min() == cached_weight.max() == 1
+    # test Atari with CachedReplayBuffer, save_only_last_obs + ignore_obs_next
+    buf6 = CachedReplayBuffer(
+        ReplayBuffer(bufsize, stack_num=stack_num,
+                     save_only_last_obs=True, ignore_obs_next=True),
+        cached_num, size)
+    obs = np.random.rand(size, 4, 84, 84)
+    buf6.add(obs=[obs[2], obs[0]], act=[1, 1], rew=[0, 0], done=[0, 1],
+             obs_next=[obs[3], obs[1]], cached_buffer_ids=[1, 2])
+    assert buf6.obs.shape == (buf6.maxsize, 84, 84)
+    assert np.allclose(buf6.obs[0], obs[0, -1])
+    assert np.allclose(buf6.obs[14], obs[2, -1])
+    assert np.allclose(buf6.obs[19], obs[0, -1])
+    assert buf6[0].obs.shape == (4, 84, 84)
+
+
+def test_multibuf_hdf5():
+    size = 100
+    buffers = {
+        "vector": ReplayBufferManager([ReplayBuffer(size) for i in range(4)]),
+        "cached": CachedReplayBuffer(ReplayBuffer(size), 4, size)
+    }
+    buffer_types = {k: b.__class__ for k, b in buffers.items()}
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    info_t = torch.tensor([1.]).to(device)
+    for i in range(4):
+        kwargs = {
+            'obs': Batch(index=np.array([i])),
+            'act': i,
+            'rew': np.array([1, 2]),
+            'done': i % 3 == 2,
+            'info': {"number": {"n": i, "t": info_t}, 'extra': None},
+        }
+        buffers["vector"].add(**Batch.cat([[kwargs], [kwargs], [kwargs]]),
+                              buffer_ids=[0, 1, 2])
+        buffers["cached"].add(**Batch.cat([[kwargs], [kwargs], [kwargs]]),
+                              cached_buffer_ids=[0, 1, 2])
+
+    # save
+    paths = {}
+    for k, buf in buffers.items():
+        f, path = tempfile.mkstemp(suffix='.hdf5')
+        os.close(f)
+        buf.save_hdf5(path)
+        paths[k] = path
+
+    # load replay buffer
+    _buffers = {k: buffer_types[k].load_hdf5(paths[k]) for k in paths.keys()}
+
+    # compare
+    for k in buffers.keys():
+        assert len(_buffers[k]) == len(buffers[k])
+        assert np.allclose(_buffers[k].act, buffers[k].act)
+        assert _buffers[k].stack_num == buffers[k].stack_num
+        assert _buffers[k].maxsize == buffers[k].maxsize
+        assert np.all(_buffers[k]._indices == buffers[k]._indices)
+    # check shallow copy in ReplayBufferManager
+    for k in ["vector", "cached"]:
+        buffers[k].info.number.n[0] = -100
+        assert buffers[k].buffers[0].info.number.n[0] == -100
+    # check if still behave normally
+    for k in ["vector", "cached"]:
+        kwargs = {
+            'obs': Batch(index=np.array([5])),
+            'act': 5,
+            'rew': np.array([2, 1]),
+            'done': False,
+            'info': {"number": {"n": i}, 'Timelimit.truncate': True},
+        }
+        buffers[k].add(**Batch.cat([[kwargs], [kwargs], [kwargs], [kwargs]]))
+        act = np.zeros(buffers[k].maxsize)
+        if k == "vector":
+            act[np.arange(5)] = np.array([0, 1, 2, 3, 5])
+            act[np.arange(5) + size] = np.array([0, 1, 2, 3, 5])
+            act[np.arange(5) + size * 2] = np.array([0, 1, 2, 3, 5])
+            act[size * 3] = 5
+        elif k == "cached":
+            act[np.arange(9)] = np.array([0, 1, 2, 0, 1, 2, 0, 1, 2])
+            act[np.arange(3) + size] = np.array([3, 5, 2])
+            act[np.arange(3) + size * 2] = np.array([3, 5, 2])
+            act[np.arange(3) + size * 3] = np.array([3, 5, 2])
+            act[size * 4] = 5
+        assert np.allclose(buffers[k].act, act)
+
+    for path in paths.values():
+        os.remove(path)
+
+
 if __name__ == '__main__':
-    test_hdf5()
     test_replaybuffer()
     test_ignore_obs_next()
     test_stack()
-    test_pickle()
     test_segtree()
     test_priortized_replaybuffer()
     test_priortized_replaybuffer(233333, 200000)
     test_update()
+    test_pickle()
+    test_hdf5()
+    test_replaybuffermanager()
+    test_cachedbuffer()
+    test_multibuf_stack()
+    test_multibuf_hdf5()
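The global indices asserted in `test_replaybuffermanager` (0, 5, 10, 15 for the first slot of each sub-buffer) are consistent with laying the sub-buffers end to end in one flat index space. A rough sketch of that layout follows; `TinyManager` and its methods are hypothetical, track only slot bookkeeping, and are not the `ReplayBufferManager` API.

```python
import numpy as np


class TinyManager:
    """Hypothetical flat index space over several circular sub-buffers."""

    def __init__(self, sizes):
        self.sizes = np.asarray(sizes)
        # Sub-buffer k owns global slots offsets[k] .. offsets[k] + sizes[k] - 1.
        self.offsets = np.concatenate([[0], np.cumsum(self.sizes)[:-1]])
        self.cursor = np.zeros(len(sizes), dtype=int)  # next local write slot
        self.length = np.zeros(len(sizes), dtype=int)  # valid items per buffer

    def add(self, buffer_ids):
        """Store one transition in each listed sub-buffer; return global slots."""
        slots = []
        for k in buffer_ids:
            slots.append(int(self.offsets[k] + self.cursor[k]))
            self.cursor[k] = (self.cursor[k] + 1) % self.sizes[k]
            self.length[k] = min(self.length[k] + 1, self.sizes[k])
        return slots

    def valid_indices(self):
        """All global slots that currently hold data, in ascending order."""
        return np.concatenate([
            off + np.arange(n) for off, n in zip(self.offsets, self.length)
        ])

    def __len__(self):
        return int(self.length.sum())


manager = TinyManager([5, 5, 5, 5])
manager.add([0, 1, 2])
print(manager.valid_indices())  # first slot of sub-buffers 0, 1 and 2
manager.add([3])
for _ in range(4):
    manager.add([0, 1, 2, 3])
print(len(manager))
```

Keeping one flat index space means a sampled index identifies both the sub-buffer and the local slot, which is what lets `prev`, `next` and `sample_index` in the test operate on plain integer arrays.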
@@ -18,7 +18,7 @@ from tianshou.utils.net.continuous import ActorProb, Critic
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='Pendulum-v0')
-    parser.add_argument('--seed', type=int, default=1626)
+    parser.add_argument('--seed', type=int, default=1)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--lr', type=float, default=1e-3)
     parser.add_argument('--gamma', type=float, default=0.99)
@ -16,7 +16,7 @@ from tianshou.data import Collector, ReplayBuffer, PrioritizedReplayBuffer
|
|||||||
def get_args():
|
def get_args():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--task', type=str, default='CartPole-v0')
|
parser.add_argument('--task', type=str, default='CartPole-v0')
|
||||||
parser.add_argument('--seed', type=int, default=1626)
|
parser.add_argument('--seed', type=int, default=1)
|
||||||
parser.add_argument('--eps-test', type=float, default=0.05)
|
parser.add_argument('--eps-test', type=float, default=0.05)
|
||||||
parser.add_argument('--eps-train', type=float, default=0.1)
|
parser.add_argument('--eps-train', type=float, default=0.1)
|
||||||
parser.add_argument('--buffer-size', type=int, default=20000)
|
parser.add_argument('--buffer-size', type=int, default=20000)
|
||||||
|
|||||||
@ -1,8 +1,8 @@
|
|||||||
from tianshou.data.batch import Batch
|
from tianshou.data.batch import Batch
|
||||||
from tianshou.data.utils.converter import to_numpy, to_torch, to_torch_as
|
from tianshou.data.utils.converter import to_numpy, to_torch, to_torch_as
|
||||||
from tianshou.data.utils.segtree import SegmentTree
|
from tianshou.data.utils.segtree import SegmentTree
|
||||||
from tianshou.data.buffer import ReplayBuffer, \
|
from tianshou.data.buffer import ReplayBuffer, ListReplayBuffer, \
|
||||||
ListReplayBuffer, PrioritizedReplayBuffer
|
PrioritizedReplayBuffer, ReplayBufferManager, CachedReplayBuffer
|
||||||
from tianshou.data.collector import Collector
|
from tianshou.data.collector import Collector
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -14,5 +14,7 @@ __all__ = [
|
|||||||
"ReplayBuffer",
|
"ReplayBuffer",
|
||||||
"ListReplayBuffer",
|
"ListReplayBuffer",
|
||||||
"PrioritizedReplayBuffer",
|
"PrioritizedReplayBuffer",
|
||||||
|
"ReplayBufferManager",
|
||||||
|
"CachedReplayBuffer",
|
||||||
"Collector",
|
"Collector",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
import h5py
|
import h5py
|
||||||
import torch
|
import torch
|
||||||
|
import warnings
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from numbers import Number
|
from numbers import Number
|
||||||
from typing import Any, Dict, List, Tuple, Union, Optional
|
from typing import Any, Dict, List, Tuple, Union, Optional
|
||||||
@ -13,121 +14,13 @@ class ReplayBuffer:
|
|||||||
""":class:`~tianshou.data.ReplayBuffer` stores data generated from \
|
""":class:`~tianshou.data.ReplayBuffer` stores data generated from \
|
||||||
interaction between the policy and environment.
|
interaction between the policy and environment.
|
||||||
|
|
||||||
The current implementation of Tianshou typically use 7 reserved keys in
|
ReplayBuffer can be considered as a specialized form (or management) of
|
||||||
:class:`~tianshou.data.Batch`:
|
Batch. It stores all the data in a batch with circular-queue style.
|
||||||
|
|
||||||
* ``obs`` the observation of step :math:`t` ;
|
For the example usage of ReplayBuffer, please check out Section Buffer in
|
||||||
* ``act`` the action of step :math:`t` ;
|
:doc:`/tutorials/concepts`.
|
||||||
* ``rew`` the reward of step :math:`t` ;
|
|
||||||
* ``done`` the done flag of step :math:`t` ;
|
|
||||||
* ``obs_next`` the observation of step :math:`t+1` ;
|
|
||||||
* ``info`` the info of step :math:`t` (in ``gym.Env``, the ``env.step()`` \
|
|
||||||
function returns 4 arguments, and the last one is ``info``);
|
|
||||||
* ``policy`` the data computed by policy in step :math:`t`;
|
|
||||||
|
|
||||||
The following code snippet illustrates its usage:
|
:param int size: the maximum size of replay buffer.
|
||||||
::
|
|
||||||
|
|
||||||
>>> import pickle, numpy as np
|
|
||||||
>>> from tianshou.data import ReplayBuffer
|
|
||||||
>>> buf = ReplayBuffer(size=20)
|
|
||||||
>>> for i in range(3):
|
|
||||||
... buf.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
|
|
||||||
>>> buf.obs
|
|
||||||
# since we set size = 20, len(buf.obs) == 20.
|
|
||||||
array([0., 1., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
|
|
||||||
0., 0., 0., 0.])
|
|
||||||
>>> # but there are only three valid items, so len(buf) == 3.
|
|
||||||
>>> len(buf)
|
|
||||||
3
|
|
||||||
>>> # save to file "buf.pkl"
|
|
||||||
>>> pickle.dump(buf, open('buf.pkl', 'wb'))
|
|
||||||
>>> # save to HDF5 file
|
|
||||||
>>> buf.save_hdf5('buf.hdf5')
|
|
||||||
>>> buf2 = ReplayBuffer(size=10)
|
|
||||||
>>> for i in range(15):
|
|
||||||
... buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
|
|
||||||
>>> len(buf2)
|
|
||||||
10
|
|
||||||
>>> buf2.obs
|
|
||||||
# since its size = 10, it only stores the last 10 steps' result.
|
|
||||||
array([10., 11., 12., 13., 14., 5., 6., 7., 8., 9.])
|
|
||||||
|
|
||||||
>>> # move buf2's result into buf (meanwhile keep it chronologically)
|
|
||||||
>>> buf.update(buf2)
|
|
||||||
array([ 0., 1., 2., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14.,
|
|
||||||
0., 0., 0., 0., 0., 0., 0.])
|
|
||||||
|
|
||||||
>>> # get a random sample from buffer
|
|
||||||
>>> # the batch_data is equal to buf[indice].
|
|
||||||
>>> batch_data, indice = buf.sample(batch_size=4)
|
|
||||||
>>> batch_data.obs == buf[indice].obs
|
|
||||||
array([ True, True, True, True])
|
|
||||||
>>> len(buf)
|
|
||||||
13
|
|
||||||
>>> buf = pickle.load(open('buf.pkl', 'rb')) # load from "buf.pkl"
|
|
||||||
>>> len(buf)
|
|
||||||
3
|
|
||||||
>>> # load complete buffer from HDF5 file
|
|
||||||
>>> buf = ReplayBuffer.load_hdf5('buf.hdf5')
|
|
||||||
>>> len(buf)
|
|
||||||
3
|
|
||||||
>>> # load contents of HDF5 file into existing buffer
|
|
||||||
>>> # (only possible if size of buffer and data in file match)
|
|
||||||
>>> buf.load_contents_hdf5('buf.hdf5')
|
|
||||||
>>> len(buf)
|
|
||||||
3
|
|
||||||
|
|
||||||
:class:`~tianshou.data.ReplayBuffer` also supports frame_stack sampling
|
|
||||||
(typically for RNN usage, see issue#19), ignoring storing the next
|
|
||||||
observation (save memory in atari tasks), and multi-modal observation (see
|
|
||||||
issue#38):
|
|
||||||
::
|
|
||||||
|
|
||||||
>>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
|
|
||||||
>>> for i in range(16):
|
|
||||||
... done = i % 5 == 0
|
|
||||||
... buf.add(obs={'id': i}, act=i, rew=i, done=done,
|
|
||||||
... obs_next={'id': i + 1})
|
|
||||||
>>> print(buf) # you can see obs_next is not saved in buf
|
|
||||||
ReplayBuffer(
|
|
||||||
act: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
|
|
||||||
done: array([0., 1., 0., 0., 0., 0., 1., 0., 0.]),
|
|
||||||
info: Batch(),
|
|
||||||
obs: Batch(
|
|
||||||
id: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
|
|
||||||
),
|
|
||||||
policy: Batch(),
|
|
||||||
rew: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
|
|
||||||
)
|
|
||||||
>>> index = np.arange(len(buf))
|
|
||||||
>>> print(buf.get(index, 'obs').id)
|
|
||||||
[[ 7. 7. 8. 9.]
|
|
||||||
[ 7. 8. 9. 10.]
|
|
||||||
[11. 11. 11. 11.]
|
|
||||||
[11. 11. 11. 12.]
|
|
||||||
[11. 11. 12. 13.]
|
|
||||||
[11. 12. 13. 14.]
|
|
||||||
[12. 13. 14. 15.]
|
|
||||||
[ 7. 7. 7. 7.]
|
|
||||||
[ 7. 7. 7. 8.]]
|
|
||||||
>>> # here is another way to get the stacked data
|
|
||||||
>>> # (stack only for obs and obs_next)
|
|
||||||
>>> abs(buf.get(index, 'obs')['id'] - buf[index].obs.id).sum().sum()
|
|
||||||
0.0
|
|
||||||
>>> # we can get obs_next through __getitem__, even if it doesn't exist
|
|
||||||
>>> print(buf[:].obs_next.id)
|
|
||||||
[[ 7. 8. 9. 10.]
|
|
||||||
[ 7. 8. 9. 10.]
|
|
||||||
[11. 11. 11. 12.]
|
|
||||||
[11. 11. 12. 13.]
|
|
||||||
[11. 12. 13. 14.]
|
|
||||||
[12. 13. 14. 15.]
|
|
||||||
[12. 13. 14. 15.]
|
|
||||||
[ 7. 7. 7. 8.]
|
|
||||||
[ 7. 7. 8. 9.]]
|
|
||||||
|
|
||||||
:param int size: the size of replay buffer.
|
|
||||||
:param int stack_num: the frame-stack sampling argument, should be greater
|
:param int stack_num: the frame-stack sampling argument, should be greater
|
||||||
than or equal to 1, defaults to 1 (no stacking).
|
than or equal to 1, defaults to 1 (no stacking).
|
||||||
:param bool ignore_obs_next: whether to store obs_next, defaults to False.
|
:param bool ignore_obs_next: whether to store obs_next, defaults to False.
|
||||||
@ -136,9 +29,11 @@ class ReplayBuffer:
|
|||||||
False.
|
False.
|
||||||
:param bool sample_avail: the parameter indicating sampling only available
|
:param bool sample_avail: the parameter indicating sampling only available
|
||||||
index when using frame-stack sampling method, defaults to False.
|
index when using frame-stack sampling method, defaults to False.
|
||||||
This feature is not supported in Prioritized Replay Buffer currently.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_reserved_keys = ("obs", "act", "rew", "done",
|
||||||
|
"obs_next", "info", "policy")
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
size: int,
|
size: int,
|
||||||
@ -147,16 +42,20 @@ class ReplayBuffer:
|
|||||||
save_only_last_obs: bool = False,
|
save_only_last_obs: bool = False,
|
||||||
sample_avail: bool = False,
|
sample_avail: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
self.options: Dict[str, Any] = {
|
||||||
|
"stack_num": stack_num,
|
||||||
|
"ignore_obs_next": ignore_obs_next,
|
||||||
|
"save_only_last_obs": save_only_last_obs,
|
||||||
|
"sample_avail": sample_avail,
|
||||||
|
}
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._maxsize = size
|
self.maxsize = size
|
||||||
self._indices = np.arange(size)
|
assert stack_num > 0, "stack_num should be greater than 0"
|
||||||
self.stack_num = stack_num
|
self.stack_num = stack_num
|
||||||
self._avail = sample_avail and stack_num > 1
|
self._indices = np.arange(size)
|
||||||
self._avail_index: List[int] = []
|
self._save_obs_next = not ignore_obs_next
|
||||||
self._save_s_ = not ignore_obs_next
|
self._save_only_last_obs = save_only_last_obs
|
||||||
self._last_obs = save_only_last_obs
|
self._sample_avail = sample_avail
|
||||||
self._index = 0
|
|
||||||
self._size = 0
|
|
||||||
self._meta: Batch = Batch()
|
self._meta: Batch = Batch()
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
@ -181,20 +80,92 @@ class ReplayBuffer:
|
|||||||
We need it because pickling buffer does not work out-of-the-box
|
We need it because pickling buffer does not work out-of-the-box
|
||||||
("buffer.__getattr__" is customized).
|
("buffer.__getattr__" is customized).
|
||||||
"""
|
"""
|
||||||
self._indices = np.arange(state["_maxsize"])
|
|
||||||
self.__dict__.update(state)
|
self.__dict__.update(state)
|
||||||
|
# compatible with version == 0.3.1's HDF5 data format
|
||||||
|
self._indices = np.arange(self.maxsize)
|
||||||
|
|
||||||
def __getstate__(self) -> dict:
|
def __setattr__(self, key: str, value: Any) -> None:
|
||||||
exclude = {"_indices"}
|
"""Set self.key = value."""
|
||||||
state = {k: v for k, v in self.__dict__.items() if k not in exclude}
|
assert key not in self._reserved_keys, (
|
||||||
return state
|
"key '{}' is reserved and cannot be assigned".format(key))
|
||||||
|
super().__setattr__(key, value)
|
||||||
|
|
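The reserved-key guard added in `__setattr__` can be tried out in isolation. The following is a minimal sketch under stated assumptions, not the Tianshou class itself: `GuardedBuffer` is a hypothetical stand-in that copies only the `_reserved_keys` tuple and the assertion from the hunk above.

```python
class GuardedBuffer:
    """Hypothetical stand-in that mimics only the reserved-key guard."""

    _reserved_keys = ("obs", "act", "rew", "done",
                      "obs_next", "info", "policy")

    def __init__(self, size):
        # "maxsize" is not a reserved key, so this assignment is allowed
        self.maxsize = size

    def __setattr__(self, key, value):
        # refuse direct assignment such as ``buf.done = ...``
        assert key not in self._reserved_keys, (
            "key '{}' is reserved and cannot be assigned".format(key))
        super().__setattr__(key, value)


buf = GuardedBuffer(size=4)
try:
    buf.done = True          # reserved key: the assertion fires
    blocked = False
except AssertionError:
    blocked = True
```

Because the check is an `assert`, it disappears when Python runs with `-O`; this sketch keeps that behaviour rather than swapping in an exception type the diff does not use.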
||||||
|
def save_hdf5(self, path: str) -> None:
|
||||||
|
"""Save replay buffer to HDF5 file."""
|
||||||
|
with h5py.File(path, "w") as f:
|
||||||
|
to_hdf5(self.__dict__, f)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load_hdf5(
|
||||||
|
cls, path: str, device: Optional[str] = None
|
||||||
|
) -> "ReplayBuffer":
|
||||||
|
"""Load replay buffer from HDF5 file."""
|
||||||
|
with h5py.File(path, "r") as f:
|
||||||
|
buf = cls.__new__(cls)
|
||||||
|
buf.__setstate__(from_hdf5(f, device=device))
|
||||||
|
return buf
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
"""Clear all the data in replay buffer and episode statistics."""
|
||||||
|
self._index = self._size = 0
|
||||||
|
self._episode_length, self._episode_reward = 0, 0.0
|
||||||
|
|
||||||
|
def set_batch(self, batch: Batch) -> None:
|
||||||
|
"""Manually choose the batch you want the ReplayBuffer to manage."""
|
||||||
|
assert len(batch) == self.maxsize and \
|
||||||
|
set(batch.keys()).issubset(self._reserved_keys), \
|
||||||
|
"Input batch doesn't meet ReplayBuffer's data form requirement."
|
||||||
|
self._meta = batch
|
||||||
|
|
||||||
|
def unfinished_index(self) -> np.ndarray:
|
||||||
|
"""Return the last modified index if its episode is unfinished."""
|
||||||
|
last = (self._index - 1) % self._size if self._size else 0
|
||||||
|
return np.array(
|
||||||
|
[last] if not self.done[last] and self._size else [], np.int)
|
||||||
|
|
||||||
|
def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
|
||||||
|
"""Return the index of previous transition.
|
||||||
|
|
||||||
|
The index won't be modified if it is the beginning of an episode.
|
||||||
|
"""
|
||||||
|
index = (index - 1) % self._size
|
||||||
|
end_flag = self.done[index] | np.isin(index, self.unfinished_index())
|
||||||
|
return (index + end_flag) % self._size
|
||||||
|
|
||||||
|
def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
|
||||||
|
"""Return the index of next transition.
|
||||||
|
|
||||||
|
The index won't be modified if it is the end of an episode.
|
||||||
|
"""
|
||||||
|
end_flag = self.done[index] | np.isin(index, self.unfinished_index())
|
||||||
|
return (index + (1 - end_flag)) % self._size
|
||||||
|
|
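The episode-aware stepping done by `prev` and `next` is easier to follow on a tiny example. The sketch below re-implements the same arithmetic on plain Python lists; the function names, the `done` list and the `index`/`size` arguments are illustrative assumptions (they stand in for `self.done`, `self._index` and `self._size`) and the code handles one scalar index at a time instead of a NumPy array.

```python
def unfinished_index(done, index, size):
    """Last written slot if its episode has not finished, else nothing."""
    if size == 0:
        return []
    last = (index - 1) % size
    return [] if done[last] else [last]


def prev_index(i, done, index, size):
    """Step back once, staying put at the first transition of an episode."""
    p = (i - 1) % size
    # the slot before i either ends an episode or is the newest write
    end_flag = done[p] or p in unfinished_index(done, index, size)
    return (p + 1) % size if end_flag else p


def next_index(i, done, index, size):
    """Step forward once, staying put at the last transition of an episode."""
    end_flag = done[i] or i in unfinished_index(done, index, size)
    return i if end_flag else (i + 1) % size


def stacked(i, stack_num, done, index, size):
    """Frame-stack indices ending at i, built by repeated prev_index."""
    stack = []
    for _ in range(stack_num):
        stack = [i] + stack
        i = prev_index(i, done, index, size)
    return stack


# five transitions written into a larger buffer; the first episode ends at 2
done = [False, False, True, False, False]
index, size = 5, 5
prevs = [prev_index(i, done, index, size) for i in range(size)]
nexts = [next_index(i, done, index, size) for i in range(size)]
frames = stacked(4, 3, done, index, size)
```

Here `frames` repeats index 3 because stepping back from the start of the second episode does not cross into the first one, which is the padding behaviour the frame-stack code relies on.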
||||||
|
def update(self, buffer: "ReplayBuffer") -> None:
|
||||||
|
"""Move the data from the given buffer to current buffer."""
|
||||||
|
if len(buffer) == 0 or self.maxsize == 0:
|
||||||
|
return
|
||||||
|
stack_num, buffer.stack_num = buffer.stack_num, 1
|
||||||
|
save_only_last_obs = self._save_only_last_obs
|
||||||
|
self._save_only_last_obs = False
|
||||||
|
indices = buffer.sample_index(0) # get all available indices
|
||||||
|
for i in indices:
|
||||||
|
self.add(**buffer[i]) # type: ignore
|
||||||
|
buffer.stack_num = stack_num
|
||||||
|
self._save_only_last_obs = save_only_last_obs
|
||||||
|
|
||||||
|
def _buffer_allocator(self, key: List[str], value: Any) -> None:
|
||||||
|
"""Allocate memory on buffer._meta for new (key, value) pair."""
|
||||||
|
data = self._meta
|
||||||
|
for k in key[:-1]:
|
||||||
|
data = data[k]
|
||||||
|
data[key[-1]] = _create_value(value, self.maxsize)
|
||||||
|
|
||||||
def _add_to_buffer(self, name: str, inst: Any) -> None:
|
def _add_to_buffer(self, name: str, inst: Any) -> None:
|
||||||
try:
|
try:
|
||||||
value = self._meta.__dict__[name]
|
value = self._meta.__dict__[name]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
self._meta.__dict__[name] = _create_value(inst, self._maxsize)
|
self._buffer_allocator([name], inst)
|
||||||
value = self._meta.__dict__[name]
|
value = self._meta[name]
|
||||||
if isinstance(inst, (torch.Tensor, np.ndarray)):
|
if isinstance(inst, (torch.Tensor, np.ndarray)):
|
||||||
if inst.shape != value.shape[1:]:
|
if inst.shape != value.shape[1:]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -203,33 +174,10 @@ class ReplayBuffer:
|
|||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
value[self._index] = inst
|
value[self._index] = inst
|
||||||
except KeyError:
|
except KeyError: # inst is a dict/Batch
|
||||||
for key in set(inst.keys()).difference(value.__dict__.keys()):
|
for key in set(inst.keys()).difference(value.keys()):
|
||||||
value.__dict__[key] = _create_value(inst[key], self._maxsize)
|
self._buffer_allocator([name, key], inst[key])
|
||||||
value[self._index] = inst
|
self._meta[name][self._index] = inst
|
||||||
|
|
||||||
@property
|
|
||||||
def stack_num(self) -> int:
|
|
||||||
return self._stack
|
|
||||||
|
|
||||||
@stack_num.setter
|
|
||||||
def stack_num(self, num: int) -> None:
|
|
||||||
assert num > 0, "stack_num should greater than 0"
|
|
||||||
self._stack = num
|
|
||||||
|
|
||||||
def update(self, buffer: "ReplayBuffer") -> None:
|
|
||||||
"""Move the data from the given buffer to self."""
|
|
||||||
if len(buffer) == 0:
|
|
||||||
return
|
|
||||||
i = begin = buffer._index % len(buffer)
|
|
||||||
stack_num_orig = buffer.stack_num
|
|
||||||
buffer.stack_num = 1
|
|
||||||
while True:
|
|
||||||
self.add(**buffer[i]) # type: ignore
|
|
||||||
i = (i + 1) % len(buffer)
|
|
||||||
if i == begin:
|
|
||||||
break
|
|
||||||
buffer.stack_num = stack_num_orig
|
|
||||||
|
|
||||||
def add(
|
def add(
|
||||||
self,
|
self,
|
||||||
@ -241,117 +189,110 @@ class ReplayBuffer:
|
|||||||
info: Optional[Union[dict, Batch]] = {},
|
info: Optional[Union[dict, Batch]] = {},
|
||||||
policy: Optional[Union[dict, Batch]] = {},
|
policy: Optional[Union[dict, Batch]] = {},
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> Tuple[int, Union[float, np.ndarray]]:
|
||||||
"""Add a batch of data into replay buffer."""
|
"""Add a batch of data into replay buffer.
|
||||||
|
|
||||||
|
Return (episode_length, episode_reward) if one episode is terminated,
|
||||||
|
otherwise return (0, 0.0).
|
||||||
|
"""
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
info, (dict, Batch)
|
info, (dict, Batch)
|
||||||
), "You should return a dict in the last argument of env.step()."
|
), "You should return a dict in the last argument of env.step()."
|
||||||
if self._last_obs:
|
if self._save_only_last_obs:
|
||||||
obs = obs[-1]
|
obs = obs[-1]
|
||||||
self._add_to_buffer("obs", obs)
|
self._add_to_buffer("obs", obs)
|
||||||
self._add_to_buffer("act", act)
|
self._add_to_buffer("act", act)
|
||||||
# make sure the reward is a float instead of an int
|
# make sure the data type of reward is float instead of int
|
||||||
self._add_to_buffer("rew", rew * 1.0) # type: ignore
|
# but rew may be np.ndarray, so that we cannot use float(rew)
|
||||||
self._add_to_buffer("done", done)
|
rew = rew * 1.0 # type: ignore
|
||||||
if self._save_s_:
|
self._add_to_buffer("rew", rew)
|
||||||
|
self._add_to_buffer("done", bool(done)) # done should be a bool scalar
|
||||||
|
if self._save_obs_next:
|
||||||
if obs_next is None:
|
if obs_next is None:
|
||||||
obs_next = Batch()
|
obs_next = Batch()
|
||||||
elif self._last_obs:
|
elif self._save_only_last_obs:
|
||||||
obs_next = obs_next[-1]
|
obs_next = obs_next[-1]
|
||||||
self._add_to_buffer("obs_next", obs_next)
|
self._add_to_buffer("obs_next", obs_next)
|
||||||
self._add_to_buffer("info", info)
|
self._add_to_buffer("info", info)
|
||||||
self._add_to_buffer("policy", policy)
|
self._add_to_buffer("policy", policy)
|
||||||
|
|
||||||
# maintain available index for frame-stack sampling
|
if self.maxsize > 0:
|
||||||
if self._avail:
|
self._size = min(self._size + 1, self.maxsize)
|
||||||
# update current frame
|
self._index = (self._index + 1) % self.maxsize
|
||||||
avail = sum(self.done[i] for i in range(
|
else: # TODO: remove this after deleting ListReplayBuffer
|
||||||
self._index - self.stack_num + 1, self._index)) == 0
|
self._size = self._index = self._size + 1
|
||||||
if self._size < self.stack_num - 1:
|
|
||||||
avail = False
|
|
||||||
if avail and self._index not in self._avail_index:
|
|
||||||
self._avail_index.append(self._index)
|
|
||||||
elif not avail and self._index in self._avail_index:
|
|
||||||
self._avail_index.remove(self._index)
|
|
||||||
# remove the later available frame because of broken storage
|
|
||||||
t = (self._index + self.stack_num - 1) % self._maxsize
|
|
||||||
if t in self._avail_index:
|
|
||||||
self._avail_index.remove(t)
|
|
||||||
|
|
||||||
if self._maxsize > 0:
|
self._episode_reward += rew
|
||||||
self._size = min(self._size + 1, self._maxsize)
|
self._episode_length += 1
|
||||||
self._index = (self._index + 1) % self._maxsize
|
|
||||||
|
if done:
|
||||||
|
result = self._episode_length, self._episode_reward
|
||||||
|
self._episode_length, self._episode_reward = 0, 0.0
|
||||||
|
return result
|
||||||
else:
|
else:
|
||||||
self._size = self._index = self._index + 1
|
return 0, self._episode_reward * 0.0
|
||||||
|
|
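The new return value of `add` can be summarised with a small stand-alone tracker. This is a sketch under assumptions: `EpisodeTracker` is a hypothetical class that keeps only the two counters from the hunk above and handles scalar rewards; the storage of the transition itself is left out.

```python
class EpisodeTracker:
    """Hypothetical helper: only the episode statistics kept by add()."""

    def __init__(self):
        self._episode_length, self._episode_reward = 0, 0.0

    def add(self, rew, done):
        rew = rew * 1.0  # force float, as the buffer does
        self._episode_reward += rew
        self._episode_length += 1
        if done:
            result = self._episode_length, self._episode_reward
            self._episode_length, self._episode_reward = 0, 0.0
            return result
        # multiplying by 0.0 keeps the type (and shape) of the reward
        return 0, self._episode_reward * 0.0


tracker = EpisodeTracker()
results = [tracker.add(rew, done)
           for rew, done in [(1, False), (2, False), (3, True)]]
second_episode = tracker.add(5, True)
```

Returning `self._episode_reward * 0.0` instead of a literal `0.0` appears intended to keep the result array-shaped when `rew` is an `np.ndarray`; with scalar rewards, as here, it is simply `0.0`.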
||||||
def reset(self) -> None:
|
def sample_index(self, batch_size: int) -> np.ndarray:
|
||||||
"""Clear all the data in replay buffer."""
|
"""Get a random sample of index with size = batch_size.
|
||||||
self._index = 0
|
|
||||||
self._size = 0
|
Return all available indices in the buffer if batch_size is 0; return
|
||||||
self._avail_index = []
|
an empty numpy array if batch_size < 0 or no available index can be
|
||||||
|
sampled.
|
||||||
|
"""
|
||||||
|
if self.stack_num == 1 or not self._sample_avail: # most often case
|
||||||
|
if batch_size > 0:
|
||||||
|
return np.random.choice(self._size, batch_size)
|
||||||
|
elif batch_size == 0: # construct current available indices
|
||||||
|
return np.concatenate([
|
||||||
|
np.arange(self._index, self._size),
|
||||||
|
np.arange(self._index)])
|
||||||
|
else:
|
||||||
|
return np.array([], np.int)
|
||||||
|
else:
|
||||||
|
if batch_size < 0:
|
||||||
|
return np.array([], np.int)
|
||||||
|
all_indices = prev_indices = np.concatenate([
|
||||||
|
np.arange(self._index, self._size), np.arange(self._index)])
|
||||||
|
for _ in range(self.stack_num - 2):
|
||||||
|
prev_indices = self.prev(prev_indices)
|
||||||
|
all_indices = all_indices[prev_indices != self.prev(prev_indices)]
|
||||||
|
if batch_size > 0:
|
||||||
|
return np.random.choice(all_indices, batch_size)
|
||||||
|
else:
|
||||||
|
return all_indices
|
||||||
|
|
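The circular-queue bookkeeping in `add` and the `batch_size == 0` branch of `sample_index` fit together as follows. The sketch is a simplified, list-based illustration: `RingIndex` is a hypothetical helper, it covers only the branch without frame stacking, and `random.choices` stands in for `np.random.choice` (both draw uniformly with replacement).

```python
import random


class RingIndex:
    """Hypothetical helper: write pointer, valid size and index sampling."""

    def __init__(self, maxsize):
        self.maxsize = maxsize
        self.index = 0   # next slot to overwrite
        self.size = 0    # number of valid slots
        self.data = [None] * maxsize

    def add(self, item):
        self.data[self.index] = item
        self.size = min(self.size + 1, self.maxsize)
        self.index = (self.index + 1) % self.maxsize

    def sample_index(self, batch_size, rng=random):
        if batch_size > 0:
            # uniform draw over the valid slots, with replacement
            return rng.choices(range(self.size), k=batch_size)
        if batch_size == 0:
            # every valid slot, oldest first
            return (list(range(self.index, self.size))
                    + list(range(self.index)))
        return []  # negative batch_size: nothing to sample


ring = RingIndex(maxsize=10)
for step in range(15):
    ring.add(step)
# reading the slots in the returned order gives chronological data
ordered = [ring.data[i] for i in ring.sample_index(0)]
```

After 15 writes into 10 slots the write pointer sits at 5, so the oldest surviving item lives at slot 5 and the concatenation starts there.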
||||||
def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
|
def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
|
||||||
"""Get a random sample from buffer with size equal to batch_size.
|
"""Get a random sample from buffer with size = batch_size.
|
||||||
|
|
||||||
Return all the data in the buffer if batch_size is 0.
|
Return all the data in the buffer if batch_size is 0.
|
||||||
|
|
||||||
:return: Sample data and its corresponding index inside the buffer.
|
:return: Sample data and its corresponding index inside the buffer.
|
||||||
"""
|
"""
|
||||||
if batch_size > 0:
|
indices = self.sample_index(batch_size)
|
||||||
_all = self._avail_index if self._avail else self._size
|
return self[indices], indices
|
||||||
indice = np.random.choice(_all, batch_size)
|
|
||||||
else:
|
|
||||||
if self._avail:
|
|
||||||
indice = np.array(self._avail_index)
|
|
||||||
else:
|
|
||||||
indice = np.concatenate([
|
|
||||||
np.arange(self._index, self._size),
|
|
||||||
np.arange(0, self._index),
|
|
||||||
])
|
|
||||||
assert len(indice) > 0, "No available indice can be sampled."
|
|
||||||
return self[indice], indice
|
|
||||||
|
|
||||||
def get(
|
def get(
|
||||||
self,
|
self,
|
||||||
indice: Union[slice, int, np.integer, np.ndarray],
|
index: Union[int, np.integer, np.ndarray],
|
||||||
key: str,
|
key: str,
|
||||||
stack_num: Optional[int] = None,
|
stack_num: Optional[int] = None,
|
||||||
) -> Union[Batch, np.ndarray]:
|
) -> Union[Batch, np.ndarray]:
|
||||||
"""Return the stacked result.
|
"""Return the stacked result.
|
||||||
|
|
||||||
E.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t], where s is self.key, t is the
|
E.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t], where s is self.key, t is the
|
||||||
indice. The stack_num (here equals to 4) is given from buffer
|
index.
|
||||||
initialization procedure.
|
|
||||||
"""
|
"""
|
||||||
if stack_num is None:
|
if stack_num is None:
|
||||||
stack_num = self.stack_num
|
stack_num = self.stack_num
|
||||||
|
val = self._meta[key]
|
||||||
|
try:
|
||||||
if stack_num == 1: # the most often case
|
if stack_num == 1: # the most often case
|
||||||
if key != "obs_next" or self._save_s_:
|
return val[index]
|
||||||
val = self._meta.__dict__[key]
|
|
||||||
try:
|
|
||||||
return val[indice]
|
|
||||||
except IndexError as e:
|
|
||||||
if not (isinstance(val, Batch) and val.is_empty()):
|
|
||||||
raise e # val != Batch()
|
|
||||||
return Batch()
|
|
||||||
indice = self._indices[:self._size][indice]
|
|
||||||
done = self._meta.__dict__["done"]
|
|
||||||
if key == "obs_next" and not self._save_s_:
|
|
||||||
indice += 1 - done[indice].astype(np.int)
|
|
||||||
indice[indice == self._size] = 0
|
|
||||||
key = "obs"
|
|
||||||
val = self._meta.__dict__[key]
|
|
||||||
try:
|
|
||||||
if stack_num == 1:
|
|
||||||
return val[indice]
|
|
||||||
stack: List[Any] = []
|
stack: List[Any] = []
|
||||||
|
indice = np.asarray(index)
|
||||||
for _ in range(stack_num):
|
for _ in range(stack_num):
|
||||||
stack = [val[indice]] + stack
|
stack = [val[indice]] + stack
|
||||||
pre_indice = np.asarray(indice - 1)
|
indice = self.prev(indice)
|
||||||
pre_indice[pre_indice == -1] = self._size - 1
|
|
||||||
indice = np.asarray(
|
|
||||||
pre_indice + done[pre_indice].astype(np.int))
|
|
||||||
indice[indice == self._size] = 0
|
|
||||||
if isinstance(val, Batch):
|
if isinstance(val, Batch):
|
||||||
return Batch.stack(stack, axis=indice.ndim)
|
return Batch.stack(stack, axis=indice.ndim)
|
||||||
else:
|
else:
|
||||||
@ -369,31 +310,24 @@ class ReplayBuffer:
|
|||||||
If stack_num is larger than 1, return the stacked obs and obs_next with
|
If stack_num is larger than 1, return the stacked obs and obs_next with
|
||||||
shape (batch, len, ...).
|
shape (batch, len, ...).
|
||||||
"""
|
"""
|
||||||
|
if isinstance(index, slice): # change slice to np array
|
||||||
|
index = self._indices[:len(self)][index]
|
||||||
|
# raise KeyError first instead of AttributeError, to support np.array
|
||||||
|
obs = self.get(index, "obs")
|
||||||
|
if self._save_obs_next:
|
||||||
|
obs_next = self.get(index, "obs_next")
|
||||||
|
else:
|
||||||
|
obs_next = self.get(self.next(index), "obs")
|
||||||
return Batch(
|
return Batch(
|
||||||
obs=self.get(index, "obs"),
|
obs=obs,
|
||||||
act=self.act[index],
|
act=self.act[index],
|
||||||
rew=self.rew[index],
|
rew=self.rew[index],
|
||||||
done=self.done[index],
|
done=self.done[index],
|
||||||
obs_next=self.get(index, "obs_next"),
|
obs_next=obs_next,
|
||||||
info=self.get(index, "info"),
|
info=self.get(index, "info"),
|
||||||
policy=self.get(index, "policy"),
|
policy=self.get(index, "policy"),
|
||||||
)
|
)
|
||||||
|
|
||||||
def save_hdf5(self, path: str) -> None:
|
|
||||||
"""Save replay buffer to HDF5 file."""
|
|
||||||
with h5py.File(path, "w") as f:
|
|
||||||
to_hdf5(self.__getstate__(), f)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def load_hdf5(
|
|
||||||
cls, path: str, device: Optional[str] = None
|
|
||||||
) -> "ReplayBuffer":
|
|
||||||
"""Load replay buffer from HDF5 file."""
|
|
||||||
with h5py.File(path, "r") as f:
|
|
||||||
buf = cls.__new__(cls)
|
|
||||||
buf.__setstate__(from_hdf5(f, device=device))
|
|
||||||
return buf
|
|
||||||
|
|
||||||
|
|
||||||
class ListReplayBuffer(ReplayBuffer):
|
class ListReplayBuffer(ReplayBuffer):
|
||||||
"""List-based replay buffer.
|
"""List-based replay buffer.
|
||||||
@ -411,24 +345,27 @@ class ListReplayBuffer(ReplayBuffer):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, **kwargs: Any) -> None:
|
def __init__(self, **kwargs: Any) -> None:
|
||||||
|
warnings.warn("ListReplayBuffer will be replaced in version 0.4.0.")
|
||||||
super().__init__(size=0, ignore_obs_next=False, **kwargs)
|
super().__init__(size=0, ignore_obs_next=False, **kwargs)
|
||||||
|
|
||||||
def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
|
def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
|
||||||
raise NotImplementedError("ListReplayBuffer cannot be sampled!")
|
raise NotImplementedError("ListReplayBuffer cannot be sampled!")
|
||||||
|
|
||||||
def _add_to_buffer(
|
def _add_to_buffer(self, name: str, inst: Any) -> None:
|
||||||
self, name: str, inst: Union[dict, Batch, np.ndarray, float, int, bool]
|
if self._meta.get(name) is None:
|
||||||
) -> None:
|
|
||||||
if self._meta.__dict__.get(name) is None:
|
|
||||||
self._meta.__dict__[name] = []
|
self._meta.__dict__[name] = []
|
||||||
self._meta.__dict__[name].append(inst)
|
self._meta[name].append(inst)
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
self._index = self._size = 0
|
super().reset()
|
||||||
for k in list(self._meta.__dict__.keys()):
|
for k in self._meta.keys():
|
||||||
if isinstance(self._meta.__dict__[k], list):
|
if isinstance(self._meta[k], list):
|
||||||
self._meta.__dict__[k] = []
|
self._meta.__dict__[k] = []
|
||||||
|
|
||||||
|
def update(self, buffer: ReplayBuffer) -> None:
|
||||||
|
"""The ListReplayBuffer cannot be updated by any buffer."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class PrioritizedReplayBuffer(ReplayBuffer):
|
class PrioritizedReplayBuffer(ReplayBuffer):
|
||||||
"""Implementation of Prioritized Experience Replay. arXiv:1511.05952.
|
"""Implementation of Prioritized Experience Replay. arXiv:1511.05952.
|
||||||
@ -464,8 +401,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
|||||||
policy: Optional[Union[dict, Batch]] = {},
|
policy: Optional[Union[dict, Batch]] = {},
|
||||||
weight: Optional[Union[Number, np.number]] = None,
|
weight: Optional[Union[Number, np.number]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> Tuple[int, Union[float, np.ndarray]]:
|
||||||
"""Add a batch of data into replay buffer."""
|
|
||||||
if weight is None:
|
if weight is None:
|
||||||
weight = self._max_prio
|
weight = self._max_prio
|
||||||
else:
|
else:
|
||||||
@ -473,60 +409,289 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
|||||||
self._max_prio = max(self._max_prio, weight)
|
self._max_prio = max(self._max_prio, weight)
|
||||||
self._min_prio = min(self._min_prio, weight)
|
self._min_prio = min(self._min_prio, weight)
|
||||||
self.weight[self._index] = weight ** self._alpha
|
self.weight[self._index] = weight ** self._alpha
|
||||||
super().add(obs, act, rew, done, obs_next, info, policy, **kwargs)
|
return super().add(obs, act, rew, done, obs_next,
|
||||||
|
info, policy, **kwargs)
|
||||||
|
|
||||||
def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
|
def sample_index(self, batch_size: int) -> np.ndarray:
|
||||||
"""Get a random sample from buffer with priority probability.
|
if batch_size > 0 and self._size > 0:
|
||||||
|
scalar = np.random.rand(batch_size) * self.weight.reduce()
|
||||||
|
return self.weight.get_prefix_sum_idx(scalar)
|
||||||
|
else:
|
||||||
|
return super().sample_index(batch_size)
|
||||||
|
|
||||||
Return all the data in the buffer if batch_size is 0.
|
def get_weight(
|
||||||
|
self, index: Union[slice, int, np.integer, np.ndarray]
|
||||||
:return: Sample data and its corresponding index inside the buffer.
|
) -> np.ndarray:
|
||||||
|
"""Get the importance sampling weight.
|
||||||
|
|
||||||
The "weight" in the returned Batch is the weight on loss function
|
The "weight" in the returned Batch is the weight on loss function
|
||||||
to de-bias the sampling process (some transition tuples are sampled
|
to de-bias the sampling process (some transition tuples are sampled
|
||||||
more often so their losses are weighted less).
|
more often so their losses are weighted less).
|
||||||
"""
|
"""
|
||||||
assert self._size > 0, "Cannot sample a buffer with 0 size!"
|
|
||||||
if batch_size == 0:
|
|
||||||
indice = np.concatenate([
|
|
||||||
np.arange(self._index, self._size),
|
|
||||||
np.arange(0, self._index),
|
|
||||||
])
|
|
||||||
else:
|
|
||||||
scalar = np.random.rand(batch_size) * self.weight.reduce()
|
|
||||||
indice = self.weight.get_prefix_sum_idx(scalar)
|
|
||||||
batch = self[indice]
|
|
||||||
# importance sampling weight calculation
|
# importance sampling weight calculation
|
||||||
# original formula: ((p_j/p_sum*N)**(-beta))/((p_min/p_sum*N)**(-beta))
|
# original formula: ((p_j/p_sum*N)**(-beta))/((p_min/p_sum*N)**(-beta))
|
||||||
# simplified formula: (p_j/p_min)**(-beta)
|
# simplified formula: (p_j/p_min)**(-beta)
|
||||||
batch.weight = (batch.weight / self._min_prio) ** (-self._beta)
|
return (self.weight[index] / self._min_prio) ** (-self._beta)
|
||||||
return batch, indice
|
|
||||||
|
|
||||||
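The simplification stated in the comments above (the `p_sum` and `N` terms cancel) can be checked numerically. The sketch below is illustrative only: the function names and the sample priorities are assumptions and are not part of this patch.

```python
import math


def is_weight_original(p_j, p_min, p_sum, n, beta):
    """Original formula: weight of sample j divided by the maximum weight."""
    return ((p_j / p_sum * n) ** (-beta)) / ((p_min / p_sum * n) ** (-beta))


def is_weight_simplified(p_j, p_min, beta):
    """Simplified formula: p_sum and N cancel out of the ratio."""
    return (p_j / p_min) ** (-beta)


# Hypothetical priorities, chosen only for illustration.
priorities = [0.5, 1.0, 2.0, 4.0]
p_min, p_sum, n, beta = min(priorities), sum(priorities), len(priorities), 0.4

original = [is_weight_original(p, p_min, p_sum, n, beta) for p in priorities]
simplified = [is_weight_simplified(p, p_min, beta) for p in priorities]

# Both formulas agree up to floating-point error.
assert all(math.isclose(a, b) for a, b in zip(original, simplified))
```

With this normalisation the lowest-priority transition gets weight 1 and every higher-priority transition gets a smaller weight, which matches the docstring's remark that frequently sampled transitions have their losses weighted less.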
def update_weight(
|
def update_weight(
|
||||||
self,
|
self,
|
||||||
indice: Union[np.ndarray],
|
index: np.ndarray,
|
||||||
new_weight: Union[np.ndarray, torch.Tensor]
|
new_weight: Union[np.ndarray, torch.Tensor],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Update priority weight by indice in this buffer.
|
"""Update priority weight by index in this buffer.
|
||||||
|
|
||||||
:param np.ndarray indice: indice you want to update weight.
|
:param np.ndarray index: index whose weight you want to update.
|
||||||
:param np.ndarray new_weight: new priority weight you want to update.
|
:param np.ndarray new_weight: new priority weight you want to update.
|
||||||
"""
|
"""
|
||||||
weight = np.abs(to_numpy(new_weight)) + self.__eps
|
weight = np.abs(to_numpy(new_weight)) + self.__eps
|
||||||
self.weight[indice] = weight ** self._alpha
|
self.weight[index] = weight ** self._alpha
|
||||||
self._max_prio = max(self._max_prio, weight.max())
|
self._max_prio = max(self._max_prio, weight.max())
|
||||||
self._min_prio = min(self._min_prio, weight.min())
|
self._min_prio = min(self._min_prio, weight.min())
|
||||||
|
|
||||||
def __getitem__(
|
def __getitem__(
|
||||||
self, index: Union[slice, int, np.integer, np.ndarray]
|
self, index: Union[slice, int, np.integer, np.ndarray]
|
||||||
) -> Batch:
|
) -> Batch:
|
||||||
return Batch(
|
batch = super().__getitem__(index)
|
||||||
obs=self.get(index, "obs"),
|
batch.weight = self.get_weight(index)
|
||||||
act=self.act[index],
|
return batch
|
||||||
rew=self.rew[index],
|
|
||||||
done=self.done[index],
|
|
||||||
obs_next=self.get(index, "obs_next"),
|
class ReplayBufferManager(ReplayBuffer):
|
||||||
info=self.get(index, "info"),
|
"""ReplayBufferManager contains a list of ReplayBuffer.
|
||||||
policy=self.get(index, "policy"),
|
|
||||||
weight=self.weight[index],
|
These replay buffers share a contiguous memory layout, and the storage space
|
||||||
)
|
of each buffer is a shallow copy of the topmost memory.
|
||||||
|
|
||||||
|
:param list buffer_list: a list of ReplayBuffers to be managed.
|
||||||
|
|
||||||
|
.. seealso::
|
||||||
|
|
||||||
|
Please refer to :class:`~tianshou.data.ReplayBuffer` for more detailed
|
||||||
|
explanation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, buffer_list: List[ReplayBuffer], **kwargs: Any) -> None:
|
||||||
|
self.buffer_num = len(buffer_list)
|
||||||
|
self.buffers = buffer_list
|
||||||
|
self._offset = []
|
||||||
|
offset = 0
|
||||||
|
for buf in self.buffers:
|
||||||
|
# overwrite sub-buffers' _buffer_allocator so that
|
||||||
|
# the top buffer can allocate new memory for all sub-buffers
|
||||||
|
buf._buffer_allocator = self._buffer_allocator # type: ignore
|
||||||
|
assert buf._meta.is_empty()
|
||||||
|
self._offset.append(offset)
|
||||||
|
offset += buf.maxsize
|
||||||
|
super().__init__(size=offset, **kwargs)
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return sum([len(buf) for buf in self.buffers])
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
for buf in self.buffers:
|
||||||
|
buf.reset()
|
||||||
|
|
||||||
|
def _set_batch_for_children(self) -> None:
|
||||||
|
for offset, buf in zip(self._offset, self.buffers):
|
||||||
|
buf.set_batch(self._meta[offset:offset + buf.maxsize])
|
||||||
|
|
||||||
|
def set_batch(self, batch: Batch) -> None:
|
||||||
|
super().set_batch(batch)
|
||||||
|
self._set_batch_for_children()
|
||||||
|
|
||||||
|
def unfinished_index(self) -> np.ndarray:
|
||||||
|
return np.concatenate([
|
||||||
|
buf.unfinished_index() + offset
|
||||||
|
for offset, buf in zip(self._offset, self.buffers)])
|
||||||
|
|
||||||
|
def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
|
||||||
|
index = np.asarray(index) % self.maxsize
|
||||||
|
prev_indices = np.zeros_like(index)
|
||||||
|
for offset, buf in zip(self._offset, self.buffers):
|
||||||
|
mask = (offset <= index) & (index < offset + buf.maxsize)
|
||||||
|
if np.any(mask):
|
||||||
|
prev_indices[mask] = buf.prev(index[mask] - offset) + offset
|
||||||
|
return prev_indices
|
||||||
|
|
||||||
|
def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
|
||||||
|
index = np.asarray(index) % self.maxsize
|
||||||
|
next_indices = np.zeros_like(index)
|
||||||
|
for offset, buf in zip(self._offset, self.buffers):
|
||||||
|
mask = (offset <= index) & (index < offset + buf.maxsize)
|
||||||
|
if np.any(mask):
|
||||||
|
next_indices[mask] = buf.next(index[mask] - offset) + offset
|
||||||
|
return next_indices
|
||||||
|
|
||||||
|
def update(self, buffer: ReplayBuffer) -> None:
|
||||||
|
"""The ReplayBufferManager cannot be updated by any buffer."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def _buffer_allocator(self, key: List[str], value: Any) -> None:
|
||||||
|
super()._buffer_allocator(key, value)
|
||||||
|
self._set_batch_for_children()
|
||||||
|
|
||||||
|
def add( # type: ignore
|
||||||
|
self,
|
||||||
|
obs: Any,
|
||||||
|
act: Any,
|
||||||
|
rew: np.ndarray,
|
||||||
|
done: np.ndarray,
|
||||||
|
obs_next: Any = Batch(),
|
||||||
|
info: Optional[Batch] = Batch(),
|
||||||
|
policy: Optional[Batch] = Batch(),
|
||||||
|
buffer_ids: Optional[Union[np.ndarray, List[int]]] = None,
|
||||||
|
**kwargs: Any
|
||||||
|
) -> Tuple[np.ndarray, np.ndarray]:
|
||||||
|
"""Add a batch of data into ReplayBufferManager.
|
||||||
|
|
||||||
|
The length (first dimension) of each data item must equal the length of
|
||||||
|
buffer_ids. By default buffer_ids is [0, 1, ..., buffer_num - 1].
|
||||||
|
|
||||||
|
Return the array of episode_length and episode_reward with shape
|
||||||
|
(len(buffer_ids), ...), where (episode_length[i], episode_reward[i])
|
||||||
|
refers to the buffer_ids[i]'s corresponding episode result.
|
||||||
|
"""
|
||||||
|
if buffer_ids is None:
|
||||||
|
buffer_ids = np.arange(self.buffer_num)
|
||||||
|
# assume each element in buffer_ids is unique
|
||||||
|
assert np.bincount(buffer_ids).max() == 1
|
||||||
|
batch = Batch(obs=obs, act=act, rew=rew, done=done,
|
||||||
|
obs_next=obs_next, info=info, policy=policy)
|
||||||
|
assert len(buffer_ids) == len(batch)
|
||||||
|
episode_lengths = [] # (len(buffer_ids),)
|
||||||
|
episode_rewards = [] # (len(buffer_ids), ...)
|
||||||
|
for batch_idx, buffer_id in enumerate(buffer_ids):
|
||||||
|
length, reward = self.buffers[buffer_id].add(**batch[batch_idx])
|
||||||
|
episode_lengths.append(length)
|
||||||
|
episode_rewards.append(reward)
|
||||||
|
return np.stack(episode_lengths), np.stack(episode_rewards)
|
||||||
|
|
||||||
|
def sample_index(self, batch_size: int) -> np.ndarray:
|
||||||
|
if batch_size < 0:
|
||||||
|
return np.array([], int)
|
||||||
|
if self._sample_avail and self.stack_num > 1:
|
||||||
|
all_indices = np.concatenate([
|
||||||
|
buf.sample_index(0) + offset
|
||||||
|
for offset, buf in zip(self._offset, self.buffers)])
|
||||||
|
if batch_size == 0:
|
||||||
|
return all_indices
|
||||||
|
else:
|
||||||
|
return np.random.choice(all_indices, batch_size)
|
||||||
|
if batch_size == 0: # get all available indices
|
||||||
|
sample_num = np.zeros(self.buffer_num, int)
|
||||||
|
else:
|
||||||
|
buffer_lens = np.array([len(buf) for buf in self.buffers])
|
||||||
|
buffer_idx = np.random.choice(self.buffer_num, batch_size,
|
||||||
|
p=buffer_lens / buffer_lens.sum())
|
||||||
|
sample_num = np.bincount(buffer_idx, minlength=self.buffer_num)
|
||||||
|
# avoid batch_size > 0 and sample_num == 0 -> get child's all data
|
||||||
|
sample_num[sample_num == 0] = -1
|
||||||
|
|
||||||
|
return np.concatenate([
|
||||||
|
buf.sample_index(bsz) + offset
|
||||||
|
for offset, buf, bsz in zip(self._offset, self.buffers, sample_num)
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
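`ReplayBufferManager` addresses all sub-buffers through one global index range and uses `_offset` to translate between global and per-buffer indices (see `prev`, `next` and `unfinished_index` above). The following is a minimal pure-Python sketch of that mapping. The helper names `build_offsets` and `locate` are hypothetical and do not exist in the patch.

```python
from bisect import bisect_right


def build_offsets(sizes):
    """Return the global start index of each sub-buffer."""
    offsets, offset = [], 0
    for size in sizes:
        offsets.append(offset)
        offset += size
    return offsets


def locate(index, offsets):
    """Map a global index to (buffer_id, local_index)."""
    buffer_id = bisect_right(offsets, index) - 1
    return buffer_id, index - offsets[buffer_id]


# Hypothetical layout: one buffer of size 5 followed by two of size 3.
sizes = [5, 3, 3]
offsets = build_offsets(sizes)

assert offsets == [0, 5, 8]
assert locate(0, offsets) == (0, 0)
assert locate(5, offsets) == (1, 0)
assert locate(10, offsets) == (2, 2)
```

The manager itself uses a boolean mask per sub-buffer instead of a binary search, which lets it handle a whole array of indices at once. The sketch only shows the index arithmetic for a single index.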
|
class CachedReplayBuffer(ReplayBufferManager):
|
||||||
|
"""CachedReplayBuffer contains a given main buffer and n cached buffers, \
|
||||||
|
cached_buffer_num * ReplayBuffer(size=max_episode_length).
|
||||||
|
|
||||||
|
The memory layout is: ``| main_buffer | cached_buffers[0] |
|
||||||
|
cached_buffers[1] | ... | cached_buffers[cached_buffer_num - 1]``.
|
||||||
|
|
||||||
|
The data is first stored in cached buffers. When the episode is
|
||||||
|
terminated, the data will move to the main buffer and the corresponding
|
||||||
|
cached buffer will be reset.
|
||||||
|
|
||||||
|
:param ReplayBuffer main_buffer: the main buffer whose ``.update()``
|
||||||
|
function behaves normally.
|
||||||
|
:param int cached_buffer_num: number of ReplayBuffers to create as
|
||||||
|
cached buffers.
|
||||||
|
:param int max_episode_length: the maximum length of one episode, used as
|
||||||
|
each cached buffer's maxsize.
|
||||||
|
|
||||||
|
.. seealso::
|
||||||
|
|
||||||
|
Please refer to :class:`~tianshou.data.ReplayBuffer` or
|
||||||
|
:class:`~tianshou.data.ReplayBufferManager` for more detailed
|
||||||
|
explanation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
main_buffer: ReplayBuffer,
|
||||||
|
cached_buffer_num: int,
|
||||||
|
max_episode_length: int,
|
||||||
|
) -> None:
|
||||||
|
assert cached_buffer_num > 0 and max_episode_length > 0
|
||||||
|
self._is_prioritized = isinstance(main_buffer, PrioritizedReplayBuffer)
|
||||||
|
kwargs = main_buffer.options
|
||||||
|
buffers = [main_buffer] + [ReplayBuffer(max_episode_length, **kwargs)
|
||||||
|
for _ in range(cached_buffer_num)]
|
||||||
|
super().__init__(buffer_list=buffers, **kwargs)
|
||||||
|
self.main_buffer = self.buffers[0]
|
||||||
|
self.cached_buffers = self.buffers[1:]
|
||||||
|
self.cached_buffer_num = cached_buffer_num
|
||||||
|
|
||||||
|
def add( # type: ignore
|
||||||
|
self,
|
||||||
|
obs: Any,
|
||||||
|
act: Any,
|
||||||
|
rew: np.ndarray,
|
||||||
|
done: np.ndarray,
|
||||||
|
obs_next: Any = Batch(),
|
||||||
|
info: Optional[Batch] = Batch(),
|
||||||
|
policy: Optional[Batch] = Batch(),
|
||||||
|
cached_buffer_ids: Optional[Union[np.ndarray, List[int]]] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Tuple[np.ndarray, np.ndarray]:
|
||||||
|
"""Add a batch of data into CachedReplayBuffer.
|
||||||
|
|
||||||
|
The length (first dimension) of each data item must equal the length of
|
||||||
|
cached_buffer_ids. By default the cached_buffer_ids is [0, 1, ...,
|
||||||
|
cached_buffer_num - 1].
|
||||||
|
|
||||||
|
Return the array of episode_length and episode_reward with shape
|
||||||
|
(len(cached_buffer_ids), ...), where (episode_length[i],
|
||||||
|
episode_reward[i]) refers to the cached_buffer_ids[i]th cached buffer's
|
||||||
|
corresponding episode result.
|
||||||
|
"""
|
||||||
|
if cached_buffer_ids is None:
|
||||||
|
cached_buffer_ids = np.arange(self.cached_buffer_num)
|
||||||
|
else: # make sure it is np.ndarray
|
||||||
|
cached_buffer_ids = np.asarray(cached_buffer_ids)
|
||||||
|
# in self.buffers, the first buffer is main_buffer
|
||||||
|
buffer_ids = cached_buffer_ids + 1 # type: ignore
|
||||||
|
result = super().add(obs, act, rew, done, obs_next, info,
|
||||||
|
policy, buffer_ids=buffer_ids, **kwargs)
|
||||||
|
# find the terminated episode, move data from cached buf to main buf
|
||||||
|
for buffer_idx in cached_buffer_ids[np.asarray(done, np.bool_)]:
|
||||||
|
self.main_buffer.update(self.cached_buffers[buffer_idx])
|
||||||
|
self.cached_buffers[buffer_idx].reset()
|
||||||
|
return result
|
||||||
|
|
||||||
|
def __getitem__(
|
||||||
|
self, index: Union[slice, int, np.integer, np.ndarray]
|
||||||
|
) -> Batch:
|
||||||
|
batch = super().__getitem__(index)
|
||||||
|
if self._is_prioritized:
|
||||||
|
indice = self._indices[index]
|
||||||
|
mask = indice < self.main_buffer.maxsize
|
||||||
|
batch.weight = np.ones(len(indice))
|
||||||
|
batch.weight[mask] = self.main_buffer.get_weight(indice[mask])
|
||||||
|
return batch
|
||||||
|
|
||||||
|
def update_weight(
|
||||||
|
self,
|
||||||
|
index: np.ndarray,
|
||||||
|
new_weight: Union[np.ndarray, torch.Tensor],
|
||||||
|
) -> None:
|
||||||
|
"""Update priority weight by index in main buffer.
|
||||||
|
|
||||||
|
:param np.ndarray index: index whose weight you want to update.
|
||||||
|
:param np.ndarray new_weight: new priority weight you want to update.
|
||||||
|
"""
|
||||||
|
if self._is_prioritized:
|
||||||
|
mask = index < self.main_buffer.maxsize
|
||||||
|
self.main_buffer.update_weight(index[mask], new_weight[mask])
|
||||||
|
|||||||
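The data flow described in the `CachedReplayBuffer` docstring (transitions go to a per-environment cache first and move to the main buffer when the episode terminates) can be sketched with plain Python containers. `ToyCachedBuffer` is a hypothetical stand-in written for illustration. It does not reproduce the shared-memory layout, the return values, or the prioritized weights of the real class.

```python
from collections import deque


class ToyCachedBuffer:
    """Toy model: one main circular buffer plus one cache per environment."""

    def __init__(self, main_size, cached_buffer_num):
        self.main = deque(maxlen=main_size)  # oldest entries are overwritten
        self.caches = [[] for _ in range(cached_buffer_num)]

    def add(self, transitions, dones, cached_buffer_ids=None):
        # By default, transition i belongs to cache i.
        if cached_buffer_ids is None:
            cached_buffer_ids = range(len(self.caches))
        for buf_id, transition, done in zip(
            cached_buffer_ids, transitions, dones
        ):
            self.caches[buf_id].append(transition)
            if done:
                # Episode finished: move it to the main buffer, reset cache.
                self.main.extend(self.caches[buf_id])
                self.caches[buf_id].clear()


buf = ToyCachedBuffer(main_size=10, cached_buffer_num=2)
buf.add(["a0", "b0"], dones=[False, False])
buf.add(["a1", "b1"], dones=[True, False])

assert list(buf.main) == ["a0", "a1"]    # finished episode of env 0
assert buf.caches == [[], ["b0", "b1"]]  # env 1 is still running
```

Keeping unfinished episodes out of the main buffer means the main buffer only ever holds complete episodes stored contiguously, even when several environments are stepped in parallel.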