Add CachedReplayBuffer and ReplayBufferManager (#278)

This is the second of the 6 commits outlined in #274. It makes a minor refactor of ReplayBuffer and adds two new buffer classes, CachedReplayBuffer and ReplayBufferManager. See #274 for more detail.

1. Add ReplayBufferManager (which handles a list of buffers) and CachedReplayBuffer;
2. Make sure the reserved keys cannot be overwritten by assignments like `buffer.done = xxx`;
3. Add a `set_batch` method for manually choosing the batch that the ReplayBuffer manages;
4. Add a `sample_index` method, the same as `sample` but returning only the index instead of both the batch data and the index;
5. Add `prev` (one-step previous transition index), `next` (one-step next transition index) and `unfinished_index` (the last modified index whose done==False);
6. Separate out an `alloc_fn` method that allocates new memory for `self._meta` when a new `(key, value)` pair comes in;
7. Move the buffer's documentation to `docs/tutorials/concepts.rst`.
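
The index semantics in item 5 can be sketched in pure Python for a single buffer. This is only an illustration under stated assumptions, not the library's implementation: the helper names `prev_index`, `next_index` and `unfinished` are hypothetical, the buffer is assumed to hold `len(done)` valid slots, and `last_index` is assumed to be the most recently written slot. The sample data mirrors the documentation example added in this commit (13 valid transitions, with `done` set at indices 6 and 10).

```python
from typing import List, Sequence


def prev_index(done: Sequence[bool], index: int, last_index: int) -> int:
    """Return the one-step previous index, staying inside the same episode."""
    before = (index - 1) % len(done)
    # Stepping back crosses an episode boundary if the slot before is terminal,
    # or if it is the most recently written slot (circular wrap-around).
    crosses = done[before] or before == last_index
    return index if crosses else before


def next_index(done: Sequence[bool], index: int, last_index: int) -> int:
    """Return the one-step next index, staying inside the same episode."""
    # A terminal slot, or the most recently written slot, has no successor.
    stops = done[index] or index == last_index
    return index if stops else (index + 1) % len(done)


def unfinished(done: Sequence[bool], last_index: int) -> List[int]:
    """Return the last written index if its episode has not finished yet."""
    if len(done) == 0 or done[last_index]:
        return []
    return [last_index]


# Hypothetical data matching the documentation example in this commit.
done = [False] * 13
done[6] = done[10] = True
last_index = 12

prev_all = [prev_index(done, i, last_index) for i in range(13)]
next_all = [next_index(done, i, last_index) for i in range(13)]
print(prev_all)  # [0, 0, 1, 2, 3, 4, 5, 7, 7, 8, 9, 11, 11]
print(next_all)  # [1, 2, 3, 4, 5, 6, 6, 8, 9, 10, 10, 12, 12]
print(unfinished(done, last_index))  # [12]
```

Both helpers return the input index itself at an episode boundary, which is why indices 7 and 11 map to themselves under `prev_index` and indices 6 and 10 map to themselves under `next_index`.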

Co-authored-by: n+e <trinkle23897@gmail.com>
ChenDRAG 2021-01-29 12:23:18 +08:00 committed by GitHub
parent 1eb6137645
commit f0129f4ca7
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 971 additions and 319 deletions

@ -53,11 +53,163 @@ In short, you can define a :class:`~tianshou.data.Batch` with any key-value pair
Buffer Buffer
------ ------
.. automodule:: tianshou.data.ReplayBuffer :class:`~tianshou.data.ReplayBuffer` stores data generated from interaction between the policy and environment. ReplayBuffer can be considered as a specialized form (or management) of Batch. It stores all the data in a batch with circular-queue style.
:members:
:noindex:
Tianshou provides other type of data buffer such as :class:`~tianshou.data.ListReplayBuffer` (based on list), :class:`~tianshou.data.PrioritizedReplayBuffer` (based on Segment Tree and ``numpy.ndarray``). Check out :class:`~tianshou.data.ReplayBuffer` for more detail. The current implementation of Tianshou typically use 7 reserved keys in
:class:`~tianshou.data.Batch`:
* ``obs`` the observation of step :math:`t` ;
* ``act`` the action of step :math:`t` ;
* ``rew`` the reward of step :math:`t` ;
* ``done`` the done flag of step :math:`t` ;
* ``obs_next`` the observation of step :math:`t+1` ;
* ``info`` the info of step :math:`t` (in ``gym.Env``, the ``env.step()`` function returns 4 arguments, and the last one is ``info``);
* ``policy`` the data computed by policy in step :math:`t`;
The following code snippet illustrates its usage, including:
- the basic data storage: ``add()``;
- get attribute, get slicing data, ...;
- sample from buffer: ``sample_index(batch_size)`` and ``sample(batch_size)``;
- get previous/next transition index within episodes: ``prev(index)`` and ``next(index)``;
- save/load data from buffer: pickle and HDF5;
::
>>> import pickle, numpy as np
>>> from tianshou.data import ReplayBuffer
>>> buf = ReplayBuffer(size=20)
>>> for i in range(3):
... buf.add(obs=i, act=i, rew=i, done=0, obs_next=i + 1, info={})
>>> buf.obs
# since we set size = 20, len(buf.obs) == 20.
array([0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
>>> # but there are only three valid items, so len(buf) == 3.
>>> len(buf)
3
>>> # save to file "buf.pkl"
>>> pickle.dump(buf, open('buf.pkl', 'wb'))
>>> # save to HDF5 file
>>> buf.save_hdf5('buf.hdf5')
>>> buf2 = ReplayBuffer(size=10)
>>> for i in range(15):
... done = i % 4 == 0
... buf2.add(obs=i, act=i, rew=i, done=done, obs_next=i + 1, info={})
>>> len(buf2)
10
>>> buf2.obs
# since its size = 10, it only stores the last 10 steps' result.
array([10, 11, 12, 13, 14, 5, 6, 7, 8, 9])
>>> # move buf2's result into buf (meanwhile keep it chronologically)
>>> buf.update(buf2)
>>> buf.obs
array([ 0, 1, 2, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 0, 0, 0,
0, 0, 0, 0])
>>> # get all available index by using batch_size = 0
>>> indice = buf.sample_index(0)
>>> indice
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
>>> # get one step previous/next transition
>>> buf.prev(indice)
array([ 0, 0, 1, 2, 3, 4, 5, 7, 7, 8, 9, 11, 11])
>>> buf.next(indice)
array([ 1, 2, 3, 4, 5, 6, 6, 8, 9, 10, 10, 12, 12])
>>> # get a random sample from buffer
>>> # the batch_data is equal to buf[indice].
>>> batch_data, indice = buf.sample(batch_size=4)
>>> batch_data.obs == buf[indice].obs
array([ True, True, True, True])
>>> len(buf)
13
>>> buf = pickle.load(open('buf.pkl', 'rb')) # load from "buf.pkl"
>>> len(buf)
3
>>> # load complete buffer from HDF5 file
>>> buf = ReplayBuffer.load_hdf5('buf.hdf5')
>>> len(buf)
3
:class:`~tianshou.data.ReplayBuffer` also supports frame_stack sampling (typically for RNN usage, see issue#19), ignoring storing the next observation (save memory in Atari tasks), and multi-modal observation (see issue#38):
.. raw:: html
<details>
<summary>Advance usage of ReplayBuffer</summary>
.. code-block:: python
>>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
>>> for i in range(16):
... done = i % 5 == 0
... ep_len, ep_rew = buf.add(obs={'id': i}, act=i, rew=i,
... done=done, obs_next={'id': i + 1})
... print(i, ep_len, ep_rew)
0 1 0.0
1 0 0.0
2 0 0.0
3 0 0.0
4 0 0.0
5 5 15.0
6 0 0.0
7 0 0.0
8 0 0.0
9 0 0.0
10 5 40.0
11 0 0.0
12 0 0.0
13 0 0.0
14 0 0.0
15 5 65.0
>>> print(buf) # you can see obs_next is not saved in buf
ReplayBuffer(
obs: Batch(
id: array([ 9, 10, 11, 12, 13, 14, 15, 7, 8]),
),
act: array([ 9, 10, 11, 12, 13, 14, 15, 7, 8]),
rew: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
done: array([False, True, False, False, False, False, True, False,
False]),
info: Batch(),
policy: Batch(),
)
>>> index = np.arange(len(buf))
>>> print(buf.get(index, 'obs').id)
[[ 7 7 8 9]
[ 7 8 9 10]
[11 11 11 11]
[11 11 11 12]
[11 11 12 13]
[11 12 13 14]
[12 13 14 15]
[ 7 7 7 7]
[ 7 7 7 8]]
>>> # here is another way to get the stacked data
>>> # (stack only for obs and obs_next)
>>> abs(buf.get(index, 'obs')['id'] - buf[index].obs.id).sum().sum()
0
>>> # we can get obs_next through __getitem__, even if it doesn't exist
>>> print(buf[:].obs_next.id)
[[ 7 8 9 10]
[ 7 8 9 10]
[11 11 11 12]
[11 11 12 13]
[11 12 13 14]
[12 13 14 15]
[12 13 14 15]
[ 7 7 7 8]
[ 7 7 8 9]]
.. raw:: html
</details><br>
Tianshou provides other type of data buffer such as :class:`~tianshou.data.ListReplayBuffer` (based on list), :class:`~tianshou.data.PrioritizedReplayBuffer` (based on Segment Tree and ``numpy.ndarray``), :class:`~tianshou.data.CachedReplayBuffer` (add different episodes' data but without losing chronological order). Check out :class:`~tianshou.data.ReplayBuffer` for more detail.
Policy Policy

@ -7,9 +7,10 @@ import h5py
import numpy as np import numpy as np
from timeit import timeit from timeit import timeit
from tianshou.data import Batch, SegmentTree, \
ReplayBuffer, ListReplayBuffer, PrioritizedReplayBuffer
from tianshou.data.utils.converter import to_hdf5 from tianshou.data.utils.converter import to_hdf5
from tianshou.data import Batch, SegmentTree, ReplayBuffer
from tianshou.data import ListReplayBuffer, PrioritizedReplayBuffer
from tianshou.data import ReplayBufferManager, CachedReplayBuffer
if __name__ == '__main__': if __name__ == '__main__':
from env import MyTestEnv from env import MyTestEnv
@ -38,11 +39,14 @@ def test_replaybuffer(size=10, bufsize=20):
assert (data.obs < size).all() assert (data.obs < size).all()
assert (0 <= data.done).all() and (data.done <= 1).all() assert (0 <= data.done).all() and (data.done <= 1).all()
b = ReplayBuffer(size=10) b = ReplayBuffer(size=10)
b.add(1, 1, 1, 'str', 1, {'a': 3, 'b': {'c': 5.0}}) # neg bsz should return empty index
assert b.sample_index(-1).tolist() == []
b.add(1, 1, 1, 1, 'str', {'a': 3, 'b': {'c': 5.0}})
assert b.obs[0] == 1 assert b.obs[0] == 1
assert b.done[0] == 'str' assert b.done[0]
assert b.obs_next[0] == 'str'
assert np.all(b.obs[1:] == 0) assert np.all(b.obs[1:] == 0)
assert np.all(b.done[1:] == np.array(None)) assert np.all(b.obs_next[1:] == np.array(None))
assert b.info.a[0] == 3 and b.info.a.dtype == np.integer assert b.info.a[0] == 3 and b.info.a.dtype == np.integer
assert np.all(b.info.a[1:] == 0) assert np.all(b.info.a[1:] == 0)
assert b.info.b.c[0] == 5.0 and b.info.b.c.dtype == np.inexact assert b.info.b.c[0] == 5.0 and b.info.b.c.dtype == np.inexact
@ -91,7 +95,7 @@ def test_ignore_obs_next(size=10):
assert data.obs_next assert data.obs_next
def test_stack(size=5, bufsize=9, stack_num=4): def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3):
env = MyTestEnv(size) env = MyTestEnv(size)
buf = ReplayBuffer(bufsize, stack_num=stack_num) buf = ReplayBuffer(bufsize, stack_num=stack_num)
buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True) buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True)
@ -115,7 +119,9 @@ def test_stack(size=5, bufsize=9, stack_num=4):
_, indice = buf2.sample(0) _, indice = buf2.sample(0)
assert indice.tolist() == [2, 6] assert indice.tolist() == [2, 6]
_, indice = buf2.sample(1) _, indice = buf2.sample(1)
assert indice in [2, 6] assert indice[0] in [2, 6]
batch, indice = buf2.sample(-1) # neg bsz -> no data
assert indice.tolist() == [] and len(batch) == 0
with pytest.raises(IndexError): with pytest.raises(IndexError):
buf[bufsize * 2] buf[bufsize * 2]
@ -152,6 +158,12 @@ def test_update():
assert len(buf1) == len(buf2) assert len(buf1) == len(buf2)
assert (buf2[0].obs == buf1[1].obs).all() assert (buf2[0].obs == buf1[1].obs).all()
assert (buf2[-1].obs == buf1[0].obs).all() assert (buf2[-1].obs == buf1[0].obs).all()
b = ListReplayBuffer()
with pytest.raises(NotImplementedError):
b.update(b)
b = CachedReplayBuffer(ReplayBuffer(10), 4, 5)
with pytest.raises(NotImplementedError):
b.update(b)
def test_segtree(): def test_segtree():
@ -260,8 +272,7 @@ def test_pickle():
vbuf = ReplayBuffer(size, stack_num=2) vbuf = ReplayBuffer(size, stack_num=2)
lbuf = ListReplayBuffer() lbuf = ListReplayBuffer()
pbuf = PrioritizedReplayBuffer(size, 0.6, 0.4) pbuf = PrioritizedReplayBuffer(size, 0.6, 0.4)
device = 'cuda' if torch.cuda.is_available() else 'cpu' rew = np.array([1, 1])
rew = torch.tensor([1.]).to(device)
for i in range(4): for i in range(4):
vbuf.add(obs=Batch(index=np.array([i])), act=0, rew=rew, done=0) vbuf.add(obs=Batch(index=np.array([i])), act=0, rew=rew, done=0)
for i in range(3): for i in range(3):
@ -287,18 +298,18 @@ def test_hdf5():
buffers = { buffers = {
"array": ReplayBuffer(size, stack_num=2), "array": ReplayBuffer(size, stack_num=2),
"list": ListReplayBuffer(), "list": ListReplayBuffer(),
"prioritized": PrioritizedReplayBuffer(size, 0.6, 0.4) "prioritized": PrioritizedReplayBuffer(size, 0.6, 0.4),
} }
buffer_types = {k: b.__class__ for k, b in buffers.items()} buffer_types = {k: b.__class__ for k, b in buffers.items()}
device = 'cuda' if torch.cuda.is_available() else 'cpu' device = 'cuda' if torch.cuda.is_available() else 'cpu'
rew = torch.tensor([1.]).to(device) info_t = torch.tensor([1.]).to(device)
for i in range(4): for i in range(4):
kwargs = { kwargs = {
'obs': Batch(index=np.array([i])), 'obs': Batch(index=np.array([i])),
'act': i, 'act': i,
'rew': rew, 'rew': np.array([1, 2]),
'done': 0, 'done': i % 3 == 2,
'info': {"number": {"n": i}, 'extra': None}, 'info': {"number": {"n": i, "t": info_t}, 'extra': None},
} }
buffers["array"].add(**kwargs) buffers["array"].add(**kwargs)
buffers["list"].add(**kwargs) buffers["list"].add(**kwargs)
@ -320,10 +331,10 @@ def test_hdf5():
assert len(_buffers[k]) == len(buffers[k]) assert len(_buffers[k]) == len(buffers[k])
assert np.allclose(_buffers[k].act, buffers[k].act) assert np.allclose(_buffers[k].act, buffers[k].act)
assert _buffers[k].stack_num == buffers[k].stack_num assert _buffers[k].stack_num == buffers[k].stack_num
assert _buffers[k]._maxsize == buffers[k]._maxsize assert _buffers[k].maxsize == buffers[k].maxsize
assert _buffers[k]._index == buffers[k]._index
assert np.all(_buffers[k]._indices == buffers[k]._indices) assert np.all(_buffers[k]._indices == buffers[k]._indices)
for k in ["array", "prioritized"]: for k in ["array", "prioritized"]:
assert _buffers[k]._index == buffers[k]._index
assert isinstance(buffers[k].get(0, "info"), Batch) assert isinstance(buffers[k].get(0, "info"), Batch)
assert isinstance(_buffers[k].get(0, "info"), Batch) assert isinstance(_buffers[k].get(0, "info"), Batch)
for k in ["array"]: for k in ["array"]:
@ -332,28 +343,350 @@ def test_hdf5():
assert np.all( assert np.all(
buffers[k][:].info.extra == _buffers[k][:].info.extra) buffers[k][:].info.extra == _buffers[k][:].info.extra)
for path in paths.values():
os.remove(path)
# raise exception when value cannot be pickled # raise exception when value cannot be pickled
data = {"not_supported": lambda x: x*x} data = {"not_supported": lambda x: x * x}
grp = h5py.Group grp = h5py.Group
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
to_hdf5(data, grp) to_hdf5(data, grp)
# ndarray with data type not supported by HDF5 that cannot be pickled # ndarray with data type not supported by HDF5 that cannot be pickled
data = {"not_supported": np.array(lambda x: x*x)} data = {"not_supported": np.array(lambda x: x * x)}
grp = h5py.Group grp = h5py.Group
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
to_hdf5(data, grp) to_hdf5(data, grp)
def test_replaybuffermanager():
buf = ReplayBufferManager([ReplayBuffer(size=5) for i in range(4)])
ep_len, ep_rew = buf.add(obs=[1, 2, 3], act=[1, 2, 3], rew=[1, 2, 3],
done=[0, 0, 1], buffer_ids=[0, 1, 2])
assert np.allclose(ep_len, [0, 0, 1]) and np.allclose(ep_rew, [0, 0, 3])
with pytest.raises(NotImplementedError):
# ReplayBufferManager cannot be updated
buf.update(buf)
# sample index / prev / next / unfinished_index
indice = buf.sample_index(11000)
assert np.bincount(indice)[[0, 5, 10]].min() >= 3000 # uniform sample
batch, indice = buf.sample(0)
assert np.allclose(indice, [0, 5, 10])
indice_prev = buf.prev(indice)
assert np.allclose(indice_prev, indice), indice_prev
indice_next = buf.next(indice)
assert np.allclose(indice_next, indice), indice_next
assert np.allclose(buf.unfinished_index(), [0, 5])
buf.add(obs=[4], act=[4], rew=[4], done=[1], buffer_ids=[3])
assert np.allclose(buf.unfinished_index(), [0, 5])
batch, indice = buf.sample(10)
batch, indice = buf.sample(0)
assert np.allclose(indice, [0, 5, 10, 15])
indice_prev = buf.prev(indice)
assert np.allclose(indice_prev, indice), indice_prev
indice_next = buf.next(indice)
assert np.allclose(indice_next, indice), indice_next
data = np.array([0, 0, 0, 0])
buf.add(obs=data, act=data, rew=data, done=data, buffer_ids=[0, 1, 2, 3])
buf.add(obs=data, act=data, rew=data, done=1 - data,
buffer_ids=[0, 1, 2, 3])
assert len(buf) == 12
buf.add(obs=data, act=data, rew=data, done=data, buffer_ids=[0, 1, 2, 3])
buf.add(obs=data, act=data, rew=data, done=[0, 1, 0, 1],
buffer_ids=[0, 1, 2, 3])
assert len(buf) == 20
indice = buf.sample_index(120000)
assert np.bincount(indice).min() >= 5000
batch, indice = buf.sample(10)
indice = buf.sample_index(0)
assert np.allclose(indice, np.arange(len(buf)))
# check the actual data stored in buf._meta
assert np.allclose(buf.done, [
0, 0, 1, 0, 0,
0, 0, 1, 0, 1,
1, 0, 1, 0, 0,
1, 0, 1, 0, 1,
])
assert np.allclose(buf.prev(indice), [
0, 0, 1, 3, 3,
5, 5, 6, 8, 8,
10, 11, 11, 13, 13,
15, 16, 16, 18, 18,
])
assert np.allclose(buf.next(indice), [
1, 2, 2, 4, 4,
6, 7, 7, 9, 9,
10, 12, 12, 14, 14,
15, 17, 17, 19, 19,
])
assert np.allclose(buf.unfinished_index(), [4, 14])
ep_len, ep_rew = buf.add(obs=[1], act=[1], rew=[1], done=[1],
buffer_ids=[2])
assert np.allclose(ep_len, [3]) and np.allclose(ep_rew, [1])
assert np.allclose(buf.unfinished_index(), [4])
indice = list(sorted(buf.sample_index(0)))
assert np.allclose(indice, np.arange(len(buf)))
assert np.allclose(buf.prev(indice), [
0, 0, 1, 3, 3,
5, 5, 6, 8, 8,
14, 11, 11, 13, 13,
15, 16, 16, 18, 18,
])
assert np.allclose(buf.next(indice), [
1, 2, 2, 4, 4,
6, 7, 7, 9, 9,
10, 12, 12, 14, 10,
15, 17, 17, 19, 19,
])
# corner case: list, int and -1
assert buf.prev(-1) == buf.prev([buf.maxsize - 1])[0]
assert buf.next(-1) == buf.next([buf.maxsize - 1])[0]
batch = buf._meta
batch.info.n = np.ones(buf.maxsize)
buf.set_batch(batch)
assert np.allclose(buf.buffers[-1].info.n, [1] * 5)
assert buf.sample_index(-1).tolist() == []
assert np.array([ReplayBuffer(0, ignore_obs_next=True)]).dtype == np.object
def test_cachedbuffer():
buf = CachedReplayBuffer(ReplayBuffer(10), 4, 5)
assert buf.sample_index(0).tolist() == []
# check the normal function/usage/storage in CachedReplayBuffer
ep_len, ep_rew = buf.add(obs=[1], act=[1], rew=[1], done=[0],
cached_buffer_ids=[1])
obs = np.zeros(buf.maxsize)
obs[15] = 1
indice = buf.sample_index(0)
assert np.allclose(indice, [15])
assert np.allclose(buf.prev(indice), [15])
assert np.allclose(buf.next(indice), [15])
assert np.allclose(buf.obs, obs)
assert np.allclose(ep_len, [0]) and np.allclose(ep_rew, [0.0])
ep_len, ep_rew = buf.add(obs=[2], act=[2], rew=[2], done=[1],
cached_buffer_ids=[3])
obs[[0, 25]] = 2
indice = buf.sample_index(0)
assert np.allclose(indice, [0, 15])
assert np.allclose(buf.prev(indice), [0, 15])
assert np.allclose(buf.next(indice), [0, 15])
assert np.allclose(buf.obs, obs)
assert np.allclose(ep_len, [1]) and np.allclose(ep_rew, [2.0])
assert np.allclose(buf.unfinished_index(), [15])
assert np.allclose(buf.sample_index(0), [0, 15])
ep_len, ep_rew = buf.add(obs=[3, 4], act=[3, 4], rew=[3, 4],
done=[0, 1], cached_buffer_ids=[3, 1])
assert np.allclose(ep_len, [0, 2]) and np.allclose(ep_rew, [0, 5.0])
obs[[0, 1, 2, 15, 16, 25]] = [2, 1, 4, 1, 4, 3]
assert np.allclose(buf.obs, obs)
assert np.allclose(buf.unfinished_index(), [25])
indice = buf.sample_index(0)
assert np.allclose(indice, [0, 1, 2, 25])
assert np.allclose(buf.done[indice], [1, 0, 1, 0])
assert np.allclose(buf.prev(indice), [0, 1, 1, 25])
assert np.allclose(buf.next(indice), [0, 2, 2, 25])
indice = buf.sample_index(10000)
assert np.bincount(indice)[[0, 1, 2, 25]].min() > 2000 # uniform sample
# cached buffer with main_buffer size == 0 (no update)
# used in test_collector
buf = CachedReplayBuffer(ReplayBuffer(0, sample_avail=True), 4, 5)
data = np.zeros(4)
rew = np.ones([4, 4])
buf.add(obs=data, act=data, rew=rew, done=[0, 0, 1, 1], obs_next=data)
buf.add(obs=data, act=data, rew=rew, done=[0, 0, 0, 0], obs_next=data)
buf.add(obs=data, act=data, rew=rew, done=[1, 1, 1, 1], obs_next=data)
buf.add(obs=data, act=data, rew=rew, done=[0, 0, 0, 0], obs_next=data)
buf.add(obs=data, act=data, rew=rew, done=[0, 1, 0, 1], obs_next=data)
assert np.allclose(buf.done, [
0, 0, 1, 0, 0,
0, 1, 1, 0, 0,
0, 0, 0, 0, 0,
0, 1, 0, 0, 0,
])
indice = buf.sample_index(0)
assert np.allclose(indice, [0, 1, 10, 11])
assert np.allclose(buf.prev(indice), [0, 0, 10, 10])
assert np.allclose(buf.next(indice), [1, 1, 11, 11])
def test_multibuf_stack():
size = 5
bufsize = 9
stack_num = 4
cached_num = 3
env = MyTestEnv(size)
# test if CachedReplayBuffer can handle stack_num + ignore_obs_next
buf4 = CachedReplayBuffer(
ReplayBuffer(bufsize, stack_num=stack_num, ignore_obs_next=True),
cached_num, size)
# test if CachedReplayBuffer can handle super corner case:
# prio-buffer + stack_num + ignore_obs_next + sample_avail
buf5 = CachedReplayBuffer(
PrioritizedReplayBuffer(bufsize, 0.6, 0.4, stack_num=stack_num,
ignore_obs_next=True, sample_avail=True),
cached_num, size)
obs = env.reset(1)
for i in range(18):
obs_next, rew, done, info = env.step(1)
obs_list = np.array([obs + size * i for i in range(cached_num)])
act_list = [1] * cached_num
rew_list = [rew] * cached_num
done_list = [done] * cached_num
obs_next_list = -obs_list
info_list = [info] * cached_num
buf4.add(obs_list, act_list, rew_list, done_list,
obs_next_list, info_list)
buf5.add(obs_list, act_list, rew_list, done_list,
obs_next_list, info_list)
obs = obs_next
if done:
obs = env.reset(1)
# check the `add` order is correct
assert np.allclose(buf4.obs.reshape(-1), [
12, 13, 14, 4, 6, 7, 8, 9, 11, # main_buffer
1, 2, 3, 4, 0, # cached_buffer[0]
6, 7, 8, 9, 0, # cached_buffer[1]
11, 12, 13, 14, 0, # cached_buffer[2]
]), buf4.obs
assert np.allclose(buf4.done, [
0, 0, 1, 1, 0, 0, 0, 1, 0, # main_buffer
0, 0, 0, 1, 0, # cached_buffer[0]
0, 0, 0, 1, 0, # cached_buffer[1]
0, 0, 0, 1, 0, # cached_buffer[2]
]), buf4.done
assert np.allclose(buf4.unfinished_index(), [10, 15, 20])
indice = sorted(buf4.sample_index(0))
assert np.allclose(indice, list(range(bufsize)) + [9, 10, 14, 15, 19, 20])
assert np.allclose(buf4[indice].obs[..., 0], [
[11, 11, 11, 12], [11, 11, 12, 13], [11, 12, 13, 14],
[4, 4, 4, 4], [6, 6, 6, 6], [6, 6, 6, 7],
[6, 6, 7, 8], [6, 7, 8, 9], [11, 11, 11, 11],
[1, 1, 1, 1], [1, 1, 1, 2], [6, 6, 6, 6],
[6, 6, 6, 7], [11, 11, 11, 11], [11, 11, 11, 12],
])
assert np.allclose(buf4[indice].obs_next[..., 0], [
[11, 11, 12, 13], [11, 12, 13, 14], [11, 12, 13, 14],
[4, 4, 4, 4], [6, 6, 6, 7], [6, 6, 7, 8],
[6, 7, 8, 9], [6, 7, 8, 9], [11, 11, 11, 12],
[1, 1, 1, 2], [1, 1, 1, 2], [6, 6, 6, 7],
[6, 6, 6, 7], [11, 11, 11, 12], [11, 11, 11, 12],
])
assert np.all(buf4.done == buf5.done)
indice = buf5.sample_index(0)
assert np.allclose(sorted(indice), [2, 7])
assert np.all(np.isin(buf5.sample_index(100), indice))
# manually change the stack num
buf5.stack_num = 2
for buf in buf5.buffers:
buf.stack_num = 2
indice = buf5.sample_index(0)
assert np.allclose(sorted(indice), [0, 1, 2, 5, 6, 7, 10, 15, 20])
batch, _ = buf5.sample(0)
assert np.allclose(buf5[np.arange(buf5.maxsize)].weight, 1)
buf5.update_weight(indice, batch.weight * 0)
weight = buf5[np.arange(buf5.maxsize)].weight
modified_weight = weight[[0, 1, 2, 5, 6, 7]]
assert modified_weight.min() == modified_weight.max()
assert modified_weight.max() < 1
unmodified_weight = weight[[3, 4, 8]]
assert unmodified_weight.min() == unmodified_weight.max()
assert unmodified_weight.max() < 1
cached_weight = weight[9:]
assert cached_weight.min() == cached_weight.max() == 1
# test Atari with CachedReplayBuffer, save_only_last_obs + ignore_obs_next
buf6 = CachedReplayBuffer(
ReplayBuffer(bufsize, stack_num=stack_num,
save_only_last_obs=True, ignore_obs_next=True),
cached_num, size)
obs = np.random.rand(size, 4, 84, 84)
buf6.add(obs=[obs[2], obs[0]], act=[1, 1], rew=[0, 0], done=[0, 1],
obs_next=[obs[3], obs[1]], cached_buffer_ids=[1, 2])
assert buf6.obs.shape == (buf6.maxsize, 84, 84)
assert np.allclose(buf6.obs[0], obs[0, -1])
assert np.allclose(buf6.obs[14], obs[2, -1])
assert np.allclose(buf6.obs[19], obs[0, -1])
assert buf6[0].obs.shape == (4, 84, 84)
def test_multibuf_hdf5():
size = 100
buffers = {
"vector": ReplayBufferManager([ReplayBuffer(size) for i in range(4)]),
"cached": CachedReplayBuffer(ReplayBuffer(size), 4, size)
}
buffer_types = {k: b.__class__ for k, b in buffers.items()}
device = 'cuda' if torch.cuda.is_available() else 'cpu'
info_t = torch.tensor([1.]).to(device)
for i in range(4):
kwargs = {
'obs': Batch(index=np.array([i])),
'act': i,
'rew': np.array([1, 2]),
'done': i % 3 == 2,
'info': {"number": {"n": i, "t": info_t}, 'extra': None},
}
buffers["vector"].add(**Batch.cat([[kwargs], [kwargs], [kwargs]]),
buffer_ids=[0, 1, 2])
buffers["cached"].add(**Batch.cat([[kwargs], [kwargs], [kwargs]]),
cached_buffer_ids=[0, 1, 2])
# save
paths = {}
for k, buf in buffers.items():
f, path = tempfile.mkstemp(suffix='.hdf5')
os.close(f)
buf.save_hdf5(path)
paths[k] = path
# load replay buffer
_buffers = {k: buffer_types[k].load_hdf5(paths[k]) for k in paths.keys()}
# compare
for k in buffers.keys():
assert len(_buffers[k]) == len(buffers[k])
assert np.allclose(_buffers[k].act, buffers[k].act)
assert _buffers[k].stack_num == buffers[k].stack_num
assert _buffers[k].maxsize == buffers[k].maxsize
assert np.all(_buffers[k]._indices == buffers[k]._indices)
# check shallow copy in ReplayBufferManager
for k in ["vector", "cached"]:
buffers[k].info.number.n[0] = -100
assert buffers[k].buffers[0].info.number.n[0] == -100
# check if still behave normally
for k in ["vector", "cached"]:
kwargs = {
'obs': Batch(index=np.array([5])),
'act': 5,
'rew': np.array([2, 1]),
'done': False,
'info': {"number": {"n": i}, 'Timelimit.truncate': True},
}
buffers[k].add(**Batch.cat([[kwargs], [kwargs], [kwargs], [kwargs]]))
act = np.zeros(buffers[k].maxsize)
if k == "vector":
act[np.arange(5)] = np.array([0, 1, 2, 3, 5])
act[np.arange(5) + size] = np.array([0, 1, 2, 3, 5])
act[np.arange(5) + size * 2] = np.array([0, 1, 2, 3, 5])
act[size * 3] = 5
elif k == "cached":
act[np.arange(9)] = np.array([0, 1, 2, 0, 1, 2, 0, 1, 2])
act[np.arange(3) + size] = np.array([3, 5, 2])
act[np.arange(3) + size * 2] = np.array([3, 5, 2])
act[np.arange(3) + size * 3] = np.array([3, 5, 2])
act[size * 4] = 5
assert np.allclose(buffers[k].act, act)
for path in paths.values():
os.remove(path)
if __name__ == '__main__': if __name__ == '__main__':
test_hdf5()
test_replaybuffer() test_replaybuffer()
test_ignore_obs_next() test_ignore_obs_next()
test_stack() test_stack()
test_pickle()
test_segtree() test_segtree()
test_priortized_replaybuffer() test_priortized_replaybuffer()
test_priortized_replaybuffer(233333, 200000) test_priortized_replaybuffer(233333, 200000)
test_update() test_update()
test_pickle()
test_hdf5()
test_replaybuffermanager()
test_cachedbuffer()
test_multibuf_stack()
test_multibuf_hdf5()

@ -18,7 +18,7 @@ from tianshou.utils.net.continuous import ActorProb, Critic
def get_args(): def get_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='Pendulum-v0') parser.add_argument('--task', type=str, default='Pendulum-v0')
parser.add_argument('--seed', type=int, default=1626) parser.add_argument('--seed', type=int, default=1)
parser.add_argument('--buffer-size', type=int, default=20000) parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--lr', type=float, default=1e-3) parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--gamma', type=float, default=0.99) parser.add_argument('--gamma', type=float, default=0.99)

@ -16,7 +16,7 @@ from tianshou.data import Collector, ReplayBuffer, PrioritizedReplayBuffer
def get_args(): def get_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='CartPole-v0') parser.add_argument('--task', type=str, default='CartPole-v0')
parser.add_argument('--seed', type=int, default=1626) parser.add_argument('--seed', type=int, default=1)
parser.add_argument('--eps-test', type=float, default=0.05) parser.add_argument('--eps-test', type=float, default=0.05)
parser.add_argument('--eps-train', type=float, default=0.1) parser.add_argument('--eps-train', type=float, default=0.1)
parser.add_argument('--buffer-size', type=int, default=20000) parser.add_argument('--buffer-size', type=int, default=20000)

@ -1,8 +1,8 @@
from tianshou.data.batch import Batch from tianshou.data.batch import Batch
from tianshou.data.utils.converter import to_numpy, to_torch, to_torch_as from tianshou.data.utils.converter import to_numpy, to_torch, to_torch_as
from tianshou.data.utils.segtree import SegmentTree from tianshou.data.utils.segtree import SegmentTree
from tianshou.data.buffer import ReplayBuffer, \ from tianshou.data.buffer import ReplayBuffer, ListReplayBuffer, \
ListReplayBuffer, PrioritizedReplayBuffer PrioritizedReplayBuffer, ReplayBufferManager, CachedReplayBuffer
from tianshou.data.collector import Collector from tianshou.data.collector import Collector
__all__ = [ __all__ = [
@ -14,5 +14,7 @@ __all__ = [
"ReplayBuffer", "ReplayBuffer",
"ListReplayBuffer", "ListReplayBuffer",
"PrioritizedReplayBuffer", "PrioritizedReplayBuffer",
"ReplayBufferManager",
"CachedReplayBuffer",
"Collector", "Collector",
] ]

@ -1,5 +1,6 @@
import h5py import h5py
import torch import torch
import warnings
import numpy as np import numpy as np
from numbers import Number from numbers import Number
from typing import Any, Dict, List, Tuple, Union, Optional from typing import Any, Dict, List, Tuple, Union, Optional
@ -13,121 +14,13 @@ class ReplayBuffer:
""":class:`~tianshou.data.ReplayBuffer` stores data generated from \ """:class:`~tianshou.data.ReplayBuffer` stores data generated from \
interaction between the policy and environment. interaction between the policy and environment.
The current implementation of Tianshou typically use 7 reserved keys in ReplayBuffer can be considered as a specialized form (or management) of
:class:`~tianshou.data.Batch`: Batch. It stores all the data in a batch with circular-queue style.
* ``obs`` the observation of step :math:`t` ; For the example usage of ReplayBuffer, please check out Section Buffer in
* ``act`` the action of step :math:`t` ; :doc:`/tutorials/concepts`.
* ``rew`` the reward of step :math:`t` ;
* ``done`` the done flag of step :math:`t` ;
* ``obs_next`` the observation of step :math:`t+1` ;
* ``info`` the info of step :math:`t` (in ``gym.Env``, the ``env.step()`` \
function returns 4 arguments, and the last one is ``info``);
* ``policy`` the data computed by policy in step :math:`t`;
The following code snippet illustrates its usage: :param int size: the maximum size of replay buffer.
::
>>> import pickle, numpy as np
>>> from tianshou.data import ReplayBuffer
>>> buf = ReplayBuffer(size=20)
>>> for i in range(3):
... buf.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
>>> buf.obs
# since we set size = 20, len(buf.obs) == 20.
array([0., 1., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0.])
>>> # but there are only three valid items, so len(buf) == 3.
>>> len(buf)
3
>>> # save to file "buf.pkl"
>>> pickle.dump(buf, open('buf.pkl', 'wb'))
>>> # save to HDF5 file
>>> buf.save_hdf5('buf.hdf5')
>>> buf2 = ReplayBuffer(size=10)
>>> for i in range(15):
... buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
>>> len(buf2)
10
>>> buf2.obs
# since its size = 10, it only stores the last 10 steps' result.
array([10., 11., 12., 13., 14., 5., 6., 7., 8., 9.])
>>> # move buf2's result into buf (meanwhile keep it chronologically)
>>> buf.update(buf2)
array([ 0., 1., 2., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14.,
0., 0., 0., 0., 0., 0., 0.])
>>> # get a random sample from buffer
>>> # the batch_data is equal to buf[indice].
>>> batch_data, indice = buf.sample(batch_size=4)
>>> batch_data.obs == buf[indice].obs
array([ True, True, True, True])
>>> len(buf)
13
>>> buf = pickle.load(open('buf.pkl', 'rb')) # load from "buf.pkl"
>>> len(buf)
3
>>> # load complete buffer from HDF5 file
>>> buf = ReplayBuffer.load_hdf5('buf.hdf5')
>>> len(buf)
3
>>> # load contents of HDF5 file into existing buffer
>>> # (only possible if size of buffer and data in file match)
>>> buf.load_contents_hdf5('buf.hdf5')
>>> len(buf)
3
:class:`~tianshou.data.ReplayBuffer` also supports frame_stack sampling
(typically for RNN usage, see issue#19), ignoring storing the next
observation (save memory in atari tasks), and multi-modal observation (see
issue#38):
::
>>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
>>> for i in range(16):
... done = i % 5 == 0
... buf.add(obs={'id': i}, act=i, rew=i, done=done,
... obs_next={'id': i + 1})
>>> print(buf) # you can see obs_next is not saved in buf
ReplayBuffer(
act: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
done: array([0., 1., 0., 0., 0., 0., 1., 0., 0.]),
info: Batch(),
obs: Batch(
id: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
),
policy: Batch(),
rew: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
)
>>> index = np.arange(len(buf))
>>> print(buf.get(index, 'obs').id)
[[ 7. 7. 8. 9.]
[ 7. 8. 9. 10.]
[11. 11. 11. 11.]
[11. 11. 11. 12.]
[11. 11. 12. 13.]
[11. 12. 13. 14.]
[12. 13. 14. 15.]
[ 7. 7. 7. 7.]
[ 7. 7. 7. 8.]]
>>> # here is another way to get the stacked data
>>> # (stack only for obs and obs_next)
>>> abs(buf.get(index, 'obs')['id'] - buf[index].obs.id).sum().sum()
0.0
>>> # we can get obs_next through __getitem__, even if it doesn't exist
>>> print(buf[:].obs_next.id)
[[ 7. 8. 9. 10.]
[ 7. 8. 9. 10.]
[11. 11. 11. 12.]
[11. 11. 12. 13.]
[11. 12. 13. 14.]
[12. 13. 14. 15.]
[12. 13. 14. 15.]
[ 7. 7. 7. 8.]
[ 7. 7. 8. 9.]]
:param int size: the size of replay buffer.
:param int stack_num: the frame-stack sampling argument, should be greater :param int stack_num: the frame-stack sampling argument, should be greater
than or equal to 1, defaults to 1 (no stacking). than or equal to 1, defaults to 1 (no stacking).
:param bool ignore_obs_next: whether to store obs_next, defaults to False. :param bool ignore_obs_next: whether to store obs_next, defaults to False.
@ -136,9 +29,11 @@ class ReplayBuffer:
False. False.
:param bool sample_avail: the parameter indicating sampling only available :param bool sample_avail: the parameter indicating sampling only available
index when using frame-stack sampling method, defaults to False. index when using frame-stack sampling method, defaults to False.
This feature is not supported in Prioritized Replay Buffer currently.
""" """
_reserved_keys = ("obs", "act", "rew", "done",
"obs_next", "info", "policy")
def __init__( def __init__(
self, self,
size: int, size: int,
@ -147,16 +42,20 @@ class ReplayBuffer:
save_only_last_obs: bool = False, save_only_last_obs: bool = False,
sample_avail: bool = False, sample_avail: bool = False,
) -> None: ) -> None:
self.options: Dict[str, Any] = {
"stack_num": stack_num,
"ignore_obs_next": ignore_obs_next,
"save_only_last_obs": save_only_last_obs,
"sample_avail": sample_avail,
}
super().__init__() super().__init__()
self._maxsize = size self.maxsize = size
self._indices = np.arange(size) assert stack_num > 0, "stack_num should be greater than 0"
self.stack_num = stack_num self.stack_num = stack_num
self._avail = sample_avail and stack_num > 1 self._indices = np.arange(size)
self._avail_index: List[int] = [] self._save_obs_next = not ignore_obs_next
self._save_s_ = not ignore_obs_next self._save_only_last_obs = save_only_last_obs
self._last_obs = save_only_last_obs self._sample_avail = sample_avail
self._index = 0
self._size = 0
self._meta: Batch = Batch() self._meta: Batch = Batch()
self.reset() self.reset()
@ -181,20 +80,92 @@ class ReplayBuffer:
We need it because pickling buffer does not work out-of-the-box We need it because pickling buffer does not work out-of-the-box
("buffer.__getattr__" is customized). ("buffer.__getattr__" is customized).
""" """
self._indices = np.arange(state["_maxsize"])
self.__dict__.update(state) self.__dict__.update(state)
# compatible with version == 0.3.1's HDF5 data format
self._indices = np.arange(self.maxsize)
def __getstate__(self) -> dict: def __setattr__(self, key: str, value: Any) -> None:
exclude = {"_indices"} """Set self.key = value."""
state = {k: v for k, v in self.__dict__.items() if k not in exclude} assert key not in self._reserved_keys, (
return state "key '{}' is reserved and cannot be assigned".format(key))
super().__setattr__(key, value)
def save_hdf5(self, path: str) -> None:
"""Save replay buffer to HDF5 file."""
with h5py.File(path, "w") as f:
to_hdf5(self.__dict__, f)
@classmethod
def load_hdf5(
cls, path: str, device: Optional[str] = None
) -> "ReplayBuffer":
"""Load replay buffer from HDF5 file."""
with h5py.File(path, "r") as f:
buf = cls.__new__(cls)
buf.__setstate__(from_hdf5(f, device=device))
return buf
def reset(self) -> None:
"""Clear all the data in replay buffer and episode statistics."""
self._index = self._size = 0
self._episode_length, self._episode_reward = 0, 0.0
def set_batch(self, batch: Batch) -> None:
"""Manually choose the batch you want the ReplayBuffer to manage."""
assert len(batch) == self.maxsize and \
set(batch.keys()).issubset(self._reserved_keys), \
"Input batch doesn't meet ReplayBuffer's data form requirement."
self._meta = batch
def unfinished_index(self) -> np.ndarray:
"""Return the index of unfinished episode."""
last = (self._index - 1) % self._size if self._size else 0
return np.array(
[last] if not self.done[last] and self._size else [], np.int)
def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
"""Return the index of previous transition.
The index won't be modified if it is the beginning of an episode.
"""
index = (index - 1) % self._size
end_flag = self.done[index] | np.isin(index, self.unfinished_index())
return (index + end_flag) % self._size
def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
"""Return the index of next transition.
The index won't be modified if it is the end of an episode.
"""
end_flag = self.done[index] | np.isin(index, self.unfinished_index())
return (index + (1 - end_flag)) % self._size
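The boundary behaviour of `prev` and `next` can be easier to follow on a small example. The following is a minimal pure-Python sketch, not the code from this commit: the helper names `prev_index` and `next_index` and the explicit `done`, `index`, `size` arguments are assumptions made for illustration, standing in for `self.done`, `self._index` and `self._size`.

```python
from typing import List


def unfinished_index(done: List[bool], index: int, size: int) -> List[int]:
    """Return [last written slot] if its episode is still running, else []."""
    if size == 0:
        return []
    last = (index - 1) % size
    return [] if done[last] else [last]


def prev_index(i: int, done: List[bool], index: int, size: int) -> int:
    """Step back one transition, staying put at the start of an episode."""
    p = (i - 1) % size
    # If the slot before i ends an episode (done) or is the newest write
    # (unfinished), then i is the first step of its own episode.
    at_start = done[p] or p in unfinished_index(done, index, size)
    return (p + 1) % size if at_start else p


def next_index(i: int, done: List[bool], index: int, size: int) -> int:
    """Step forward one transition, staying put at the end of an episode."""
    at_end = done[i] or i in unfinished_index(done, index, size)
    return i if at_end else (i + 1) % size


# Five slots, all filled; the write pointer has wrapped around to slot 2.
# Episode A occupies slots 2..4 (oldest data, finished at slot 4).
# Episode B occupies slots 0..1 and is still running.
done = [False, False, False, False, True]
index, size = 2, 5

prev_all = [prev_index(i, done, index, size) for i in range(size)]
next_all = [next_index(i, done, index, size) for i in range(size)]
```

In this layout slot 2 maps to itself under `prev_index` even though slot 1 holds data, because slot 1 belongs to the newer, unfinished episode; this is the case the `unfinished_index()` check appears to guard against.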
def update(self, buffer: "ReplayBuffer") -> None:
"""Move the data from the given buffer to current buffer."""
if len(buffer) == 0 or self.maxsize == 0:
return
stack_num, buffer.stack_num = buffer.stack_num, 1
save_only_last_obs = self._save_only_last_obs
self._save_only_last_obs = False
indices = buffer.sample_index(0) # get all available indices
for i in indices:
self.add(**buffer[i]) # type: ignore
buffer.stack_num = stack_num
self._save_only_last_obs = save_only_last_obs
def _buffer_allocator(self, key: List[str], value: Any) -> None:
"""Allocate memory on buffer._meta for new (key, value) pair."""
data = self._meta
for k in key[:-1]:
data = data[k]
data[key[-1]] = _create_value(value, self.maxsize)
def _add_to_buffer(self, name: str, inst: Any) -> None: def _add_to_buffer(self, name: str, inst: Any) -> None:
try: try:
value = self._meta.__dict__[name] value = self._meta.__dict__[name]
except KeyError: except KeyError:
self._meta.__dict__[name] = _create_value(inst, self._maxsize) self._buffer_allocator([name], inst)
value = self._meta.__dict__[name] value = self._meta[name]
if isinstance(inst, (torch.Tensor, np.ndarray)): if isinstance(inst, (torch.Tensor, np.ndarray)):
if inst.shape != value.shape[1:]: if inst.shape != value.shape[1:]:
raise ValueError( raise ValueError(
@ -203,33 +174,10 @@ class ReplayBuffer:
) )
try: try:
value[self._index] = inst value[self._index] = inst
except KeyError: except KeyError: # inst is a dict/Batch
for key in set(inst.keys()).difference(value.__dict__.keys()): for key in set(inst.keys()).difference(value.keys()):
value.__dict__[key] = _create_value(inst[key], self._maxsize) self._buffer_allocator([name, key], inst[key])
value[self._index] = inst self._meta[name][self._index] = inst
@property
def stack_num(self) -> int:
return self._stack
@stack_num.setter
def stack_num(self, num: int) -> None:
assert num > 0, "stack_num should greater than 0"
self._stack = num
def update(self, buffer: "ReplayBuffer") -> None:
"""Move the data from the given buffer to self."""
if len(buffer) == 0:
return
i = begin = buffer._index % len(buffer)
stack_num_orig = buffer.stack_num
buffer.stack_num = 1
while True:
self.add(**buffer[i]) # type: ignore
i = (i + 1) % len(buffer)
if i == begin:
break
buffer.stack_num = stack_num_orig
def add( def add(
self, self,
@ -241,117 +189,110 @@ class ReplayBuffer:
info: Optional[Union[dict, Batch]] = {}, info: Optional[Union[dict, Batch]] = {},
policy: Optional[Union[dict, Batch]] = {}, policy: Optional[Union[dict, Batch]] = {},
**kwargs: Any, **kwargs: Any,
) -> None: ) -> Tuple[int, Union[float, np.ndarray]]:
"""Add a batch of data into replay buffer.""" """Add a batch of data into replay buffer.
Return (episode_length, episode_reward) if one episode is terminated,
otherwise return (0, 0.0).
"""
assert isinstance( assert isinstance(
info, (dict, Batch) info, (dict, Batch)
), "You should return a dict in the last argument of env.step()." ), "You should return a dict in the last argument of env.step()."
if self._last_obs: if self._save_only_last_obs:
obs = obs[-1] obs = obs[-1]
self._add_to_buffer("obs", obs) self._add_to_buffer("obs", obs)
self._add_to_buffer("act", act) self._add_to_buffer("act", act)
# make sure the reward is a float instead of an int # make sure the data type of reward is float instead of int
self._add_to_buffer("rew", rew * 1.0) # type: ignore # but rew may be np.ndarray, so that we cannot use float(rew)
self._add_to_buffer("done", done) rew = rew * 1.0 # type: ignore
if self._save_s_: self._add_to_buffer("rew", rew)
self._add_to_buffer("done", bool(done)) # done should be a bool scalar
if self._save_obs_next:
if obs_next is None: if obs_next is None:
obs_next = Batch() obs_next = Batch()
elif self._last_obs: elif self._save_only_last_obs:
obs_next = obs_next[-1] obs_next = obs_next[-1]
self._add_to_buffer("obs_next", obs_next) self._add_to_buffer("obs_next", obs_next)
self._add_to_buffer("info", info) self._add_to_buffer("info", info)
self._add_to_buffer("policy", policy) self._add_to_buffer("policy", policy)
# maintain available index for frame-stack sampling if self.maxsize > 0:
if self._avail: self._size = min(self._size + 1, self.maxsize)
# update current frame self._index = (self._index + 1) % self.maxsize
avail = sum(self.done[i] for i in range( else: # TODO: remove this after deleting ListReplayBuffer
self._index - self.stack_num + 1, self._index)) == 0 self._size = self._index = self._size + 1
if self._size < self.stack_num - 1:
avail = False
if avail and self._index not in self._avail_index:
self._avail_index.append(self._index)
elif not avail and self._index in self._avail_index:
self._avail_index.remove(self._index)
# remove the later available frame because of broken storage
t = (self._index + self.stack_num - 1) % self._maxsize
if t in self._avail_index:
self._avail_index.remove(t)
if self._maxsize > 0: self._episode_reward += rew
self._size = min(self._size + 1, self._maxsize) self._episode_length += 1
self._index = (self._index + 1) % self._maxsize
if done:
result = self._episode_length, self._episode_reward
self._episode_length, self._episode_reward = 0, 0.0
return result
else: else:
self._size = self._index = self._index + 1 return 0, self._episode_reward * 0.0
def reset(self) -> None: def sample_index(self, batch_size: int) -> np.ndarray:
"""Clear all the data in replay buffer.""" """Get a random sample of index with size = batch_size.
self._index = 0
self._size = 0 Return all available indices in the buffer if batch_size is 0; return
self._avail_index = [] an empty numpy array if batch_size < 0 or no available index can be
sampled.
"""
if self.stack_num == 1 or not self._sample_avail: # most often case
if batch_size > 0:
return np.random.choice(self._size, batch_size)
elif batch_size == 0: # construct current available indices
return np.concatenate([
np.arange(self._index, self._size),
np.arange(self._index)])
else:
return np.array([], np.int)
else:
if batch_size < 0:
return np.array([], np.int)
all_indices = prev_indices = np.concatenate([
np.arange(self._index, self._size), np.arange(self._index)])
for _ in range(self.stack_num - 2):
prev_indices = self.prev(prev_indices)
all_indices = all_indices[prev_indices != self.prev(prev_indices)]
if batch_size > 0:
return np.random.choice(all_indices, batch_size)
else:
return all_indices
def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
"""Get a random sample from buffer with size equal to batch_size. """Get a random sample from buffer with size = batch_size.
Return all the data in the buffer if batch_size is 0. Return all the data in the buffer if batch_size is 0.
:return: Sample data and its corresponding index inside the buffer. :return: Sample data and its corresponding index inside the buffer.
""" """
if batch_size > 0: indices = self.sample_index(batch_size)
_all = self._avail_index if self._avail else self._size return self[indices], indices
indice = np.random.choice(_all, batch_size)
else:
if self._avail:
indice = np.array(self._avail_index)
else:
indice = np.concatenate([
np.arange(self._index, self._size),
np.arange(0, self._index),
])
assert len(indice) > 0, "No available indice can be sampled."
return self[indice], indice
def get( def get(
self, self,
indice: Union[slice, int, np.integer, np.ndarray], index: Union[int, np.integer, np.ndarray],
key: str, key: str,
stack_num: Optional[int] = None, stack_num: Optional[int] = None,
) -> Union[Batch, np.ndarray]: ) -> Union[Batch, np.ndarray]:
"""Return the stacked result. """Return the stacked result.
E.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t], where s is self.key, t is the E.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t], where s is self.key, t is the
indice. The stack_num (here equals to 4) is given from buffer index.
initialization procedure.
""" """
if stack_num is None: if stack_num is None:
stack_num = self.stack_num stack_num = self.stack_num
val = self._meta[key]
try:
if stack_num == 1: # the most often case if stack_num == 1: # the most often case
if key != "obs_next" or self._save_s_: return val[index]
val = self._meta.__dict__[key]
try:
return val[indice]
except IndexError as e:
if not (isinstance(val, Batch) and val.is_empty()):
raise e # val != Batch()
return Batch()
indice = self._indices[:self._size][indice]
done = self._meta.__dict__["done"]
if key == "obs_next" and not self._save_s_:
indice += 1 - done[indice].astype(np.int)
indice[indice == self._size] = 0
key = "obs"
val = self._meta.__dict__[key]
try:
if stack_num == 1:
return val[indice]
stack: List[Any] = [] stack: List[Any] = []
indice = np.asarray(index)
for _ in range(stack_num): for _ in range(stack_num):
stack = [val[indice]] + stack stack = [val[indice]] + stack
pre_indice = np.asarray(indice - 1) indice = self.prev(indice)
pre_indice[pre_indice == -1] = self._size - 1
indice = np.asarray(
pre_indice + done[pre_indice].astype(np.int))
indice[indice == self._size] = 0
if isinstance(val, Batch): if isinstance(val, Batch):
return Batch.stack(stack, axis=indice.ndim) return Batch.stack(stack, axis=indice.ndim)
else: else:
@ -369,31 +310,24 @@ class ReplayBuffer:
If stack_num is larger than 1, return the stacked obs and obs_next with If stack_num is larger than 1, return the stacked obs and obs_next with
shape (batch, len, ...). shape (batch, len, ...).
""" """
if isinstance(index, slice): # change slice to np array
index = self._indices[:len(self)][index]
# raise KeyError first instead of AttributeError, to support np.array
obs = self.get(index, "obs")
if self._save_obs_next:
obs_next = self.get(index, "obs_next")
else:
obs_next = self.get(self.next(index), "obs")
return Batch( return Batch(
obs=self.get(index, "obs"), obs=obs,
act=self.act[index], act=self.act[index],
rew=self.rew[index], rew=self.rew[index],
done=self.done[index], done=self.done[index],
obs_next=self.get(index, "obs_next"), obs_next=obs_next,
info=self.get(index, "info"), info=self.get(index, "info"),
policy=self.get(index, "policy"), policy=self.get(index, "policy"),
) )
def save_hdf5(self, path: str) -> None:
"""Save replay buffer to HDF5 file."""
with h5py.File(path, "w") as f:
to_hdf5(self.__getstate__(), f)
@classmethod
def load_hdf5(
cls, path: str, device: Optional[str] = None
) -> "ReplayBuffer":
"""Load replay buffer from HDF5 file."""
with h5py.File(path, "r") as f:
buf = cls.__new__(cls)
buf.__setstate__(from_hdf5(f, device=device))
return buf
class ListReplayBuffer(ReplayBuffer): class ListReplayBuffer(ReplayBuffer):
"""List-based replay buffer. """List-based replay buffer.
@ -411,24 +345,27 @@ class ListReplayBuffer(ReplayBuffer):
""" """
def __init__(self, **kwargs: Any) -> None: def __init__(self, **kwargs: Any) -> None:
warnings.warn("ListReplayBuffer will be replaced in version 0.4.0.")
super().__init__(size=0, ignore_obs_next=False, **kwargs) super().__init__(size=0, ignore_obs_next=False, **kwargs)
def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
raise NotImplementedError("ListReplayBuffer cannot be sampled!") raise NotImplementedError("ListReplayBuffer cannot be sampled!")
def _add_to_buffer( def _add_to_buffer(self, name: str, inst: Any) -> None:
self, name: str, inst: Union[dict, Batch, np.ndarray, float, int, bool] if self._meta.get(name) is None:
) -> None:
if self._meta.__dict__.get(name) is None:
self._meta.__dict__[name] = [] self._meta.__dict__[name] = []
self._meta.__dict__[name].append(inst) self._meta[name].append(inst)
def reset(self) -> None: def reset(self) -> None:
self._index = self._size = 0 super().reset()
for k in list(self._meta.__dict__.keys()): for k in self._meta.keys():
if isinstance(self._meta.__dict__[k], list): if isinstance(self._meta[k], list):
self._meta.__dict__[k] = [] self._meta.__dict__[k] = []
def update(self, buffer: ReplayBuffer) -> None:
"""The ListReplayBuffer cannot be updated by any buffer."""
raise NotImplementedError
class PrioritizedReplayBuffer(ReplayBuffer): class PrioritizedReplayBuffer(ReplayBuffer):
"""Implementation of Prioritized Experience Replay. arXiv:1511.05952. """Implementation of Prioritized Experience Replay. arXiv:1511.05952.
@ -464,8 +401,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
policy: Optional[Union[dict, Batch]] = {}, policy: Optional[Union[dict, Batch]] = {},
weight: Optional[Union[Number, np.number]] = None, weight: Optional[Union[Number, np.number]] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> Tuple[int, Union[float, np.ndarray]]:
"""Add a batch of data into replay buffer."""
if weight is None: if weight is None:
weight = self._max_prio weight = self._max_prio
else: else:
@ -473,60 +409,289 @@ class PrioritizedReplayBuffer(ReplayBuffer):
self._max_prio = max(self._max_prio, weight) self._max_prio = max(self._max_prio, weight)
self._min_prio = min(self._min_prio, weight) self._min_prio = min(self._min_prio, weight)
self.weight[self._index] = weight ** self._alpha self.weight[self._index] = weight ** self._alpha
super().add(obs, act, rew, done, obs_next, info, policy, **kwargs) return super().add(obs, act, rew, done, obs_next,
info, policy, **kwargs)
def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: def sample_index(self, batch_size: int) -> np.ndarray:
"""Get a random sample from buffer with priority probability. if batch_size > 0 and self._size > 0:
scalar = np.random.rand(batch_size) * self.weight.reduce()
return self.weight.get_prefix_sum_idx(scalar)
else:
return super().sample_index(batch_size)
Return all the data in the buffer if batch_size is 0. def get_weight(
self, index: Union[slice, int, np.integer, np.ndarray]
:return: Sample data and its corresponding index inside the buffer. ) -> np.ndarray:
"""Get the importance sampling weight.
The "weight" in the returned Batch is the weight on loss function The "weight" in the returned Batch is the weight on loss function
to de-bias the sampling process (some transition tuples are sampled to de-bias the sampling process (some transition tuples are sampled
more often so their losses are weighted less). more often so their losses are weighted less).
""" """
assert self._size > 0, "Cannot sample a buffer with 0 size!"
if batch_size == 0:
indice = np.concatenate([
np.arange(self._index, self._size),
np.arange(0, self._index),
])
else:
scalar = np.random.rand(batch_size) * self.weight.reduce()
indice = self.weight.get_prefix_sum_idx(scalar)
batch = self[indice]
# importance sampling weight calculation # importance sampling weight calculation
# original formula: ((p_j/p_sum*N)**(-beta))/((p_min/p_sum*N)**(-beta)) # original formula: ((p_j/p_sum*N)**(-beta))/((p_min/p_sum*N)**(-beta))
# simplified formula: (p_j/p_min)**(-beta) # simplified formula: (p_j/p_min)**(-beta)
batch.weight = (batch.weight / self._min_prio) ** (-self._beta) return (self.weight[index] / self._min_prio) ** (-self._beta)
return batch, indice
def update_weight( def update_weight(
self, self,
indice: Union[np.ndarray], index: np.ndarray,
new_weight: Union[np.ndarray, torch.Tensor] new_weight: Union[np.ndarray, torch.Tensor],
) -> None: ) -> None:
"""Update priority weight by indice in this buffer. """Update priority weight by index in this buffer.
:param np.ndarray indice: indice you want to update weight. :param np.ndarray index: index you want to update weight.
:param np.ndarray new_weight: new priority weight you want to update. :param np.ndarray new_weight: new priority weight you want to update.
""" """
weight = np.abs(to_numpy(new_weight)) + self.__eps weight = np.abs(to_numpy(new_weight)) + self.__eps
self.weight[indice] = weight ** self._alpha self.weight[index] = weight ** self._alpha
self._max_prio = max(self._max_prio, weight.max()) self._max_prio = max(self._max_prio, weight.max())
self._min_prio = min(self._min_prio, weight.min()) self._min_prio = min(self._min_prio, weight.min())
def __getitem__( def __getitem__(
self, index: Union[slice, int, np.integer, np.ndarray] self, index: Union[slice, int, np.integer, np.ndarray]
) -> Batch: ) -> Batch:
return Batch( batch = super().__getitem__(index)
obs=self.get(index, "obs"), batch.weight = self.get_weight(index)
act=self.act[index], return batch
rew=self.rew[index],
done=self.done[index],
obs_next=self.get(index, "obs_next"), class ReplayBufferManager(ReplayBuffer):
info=self.get(index, "info"), """ReplayBufferManager contains a list of ReplayBuffer.
policy=self.get(index, "policy"),
weight=self.weight[index], These replay buffers have contiguous memory layout, and the storage space
) each buffer has is a shallow copy of the topmost memory.
:param list buffer_list: a list of ReplayBuffers to be handled.
.. seealso::
Please refer to :class:`~tianshou.data.ReplayBuffer` for more detailed
explanation.
"""
def __init__(self, buffer_list: List[ReplayBuffer], **kwargs: Any) -> None:
self.buffer_num = len(buffer_list)
self.buffers = buffer_list
self._offset = []
offset = 0
for buf in self.buffers:
# overwrite sub-buffers' _buffer_allocator so that
# the top buffer can allocate new memory for all sub-buffers
buf._buffer_allocator = self._buffer_allocator # type: ignore
assert buf._meta.is_empty()
self._offset.append(offset)
offset += buf.maxsize
super().__init__(size=offset, **kwargs)
def __len__(self) -> int:
return sum([len(buf) for buf in self.buffers])
def reset(self) -> None:
for buf in self.buffers:
buf.reset()
def _set_batch_for_children(self) -> None:
for offset, buf in zip(self._offset, self.buffers):
buf.set_batch(self._meta[offset:offset + buf.maxsize])
def set_batch(self, batch: Batch) -> None:
super().set_batch(batch)
self._set_batch_for_children()
def unfinished_index(self) -> np.ndarray:
return np.concatenate([
buf.unfinished_index() + offset
for offset, buf in zip(self._offset, self.buffers)])
def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
index = np.asarray(index) % self.maxsize
prev_indices = np.zeros_like(index)
for offset, buf in zip(self._offset, self.buffers):
mask = (offset <= index) & (index < offset + buf.maxsize)
if np.any(mask):
prev_indices[mask] = buf.prev(index[mask] - offset) + offset
return prev_indices
def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
index = np.asarray(index) % self.maxsize
next_indices = np.zeros_like(index)
for offset, buf in zip(self._offset, self.buffers):
mask = (offset <= index) & (index < offset + buf.maxsize)
if np.any(mask):
next_indices[mask] = buf.next(index[mask] - offset) + offset
return next_indices
def update(self, buffer: ReplayBuffer) -> None:
"""The ReplayBufferManager cannot be updated by any buffer."""
raise NotImplementedError
def _buffer_allocator(self, key: List[str], value: Any) -> None:
super()._buffer_allocator(key, value)
self._set_batch_for_children()
def add( # type: ignore
self,
obs: Any,
act: Any,
rew: np.ndarray,
done: np.ndarray,
obs_next: Any = Batch(),
info: Optional[Batch] = Batch(),
policy: Optional[Batch] = Batch(),
buffer_ids: Optional[Union[np.ndarray, List[int]]] = None,
**kwargs: Any
) -> Tuple[np.ndarray, np.ndarray]:
"""Add a batch of data into ReplayBufferManager.
The length (first dimension) of each data item must equal the length of
buffer_ids. By default buffer_ids is [0, 1, ..., buffer_num - 1].
Return the arrays of episode_length and episode_reward with shape
(len(buffer_ids), ...), where (episode_length[i], episode_reward[i])
is the episode result of the buffer buffer_ids[i].
"""
if buffer_ids is None:
buffer_ids = np.arange(self.buffer_num)
# assume each element in buffer_ids is unique
assert np.bincount(buffer_ids).max() == 1
batch = Batch(obs=obs, act=act, rew=rew, done=done,
obs_next=obs_next, info=info, policy=policy)
assert len(buffer_ids) == len(batch)
episode_lengths = [] # (len(buffer_ids),)
episode_rewards = [] # (len(buffer_ids), ...)
for batch_idx, buffer_id in enumerate(buffer_ids):
length, reward = self.buffers[buffer_id].add(**batch[batch_idx])
episode_lengths.append(length)
episode_rewards.append(reward)
return np.stack(episode_lengths), np.stack(episode_rewards)
def sample_index(self, batch_size: int) -> np.ndarray:
if batch_size < 0:
return np.array([], np.int)
if self._sample_avail and self.stack_num > 1:
all_indices = np.concatenate([
buf.sample_index(0) + offset
for offset, buf in zip(self._offset, self.buffers)])
if batch_size == 0:
return all_indices
else:
return np.random.choice(all_indices, batch_size)
if batch_size == 0: # get all available indices
sample_num = np.zeros(self.buffer_num, np.int)
else:
buffer_lens = np.array([len(buf) for buf in self.buffers])
buffer_idx = np.random.choice(self.buffer_num, batch_size,
p=buffer_lens / buffer_lens.sum())
sample_num = np.bincount(buffer_idx, minlength=self.buffer_num)
# avoid batch_size > 0 and sample_num == 0 -> get child's all data
sample_num[sample_num == 0] = -1
return np.concatenate([
buf.sample_index(bsz) + offset
for offset, buf, bsz in zip(self._offset, self.buffers, sample_num)
])
class CachedReplayBuffer(ReplayBufferManager):
"""CachedReplayBuffer contains a given main buffer and n cached buffers, \
cached_buffer_num * ReplayBuffer(size=max_episode_length).
The memory layout is: ``| main_buffer | cached_buffers[0] |
cached_buffers[1] | ... | cached_buffers[cached_buffer_num - 1]``.
The data is first stored in cached buffers. When the episode is
terminated, the data will move to the main buffer and the corresponding
cached buffer will be reset.
:param ReplayBuffer main_buffer: the main buffer whose ``.update()``
function behaves normally.
:param int cached_buffer_num: the number of ReplayBuffers to create as
cached buffers.
:param int max_episode_length: the maximum length of one episode, used as
each cached buffer's maxsize.
.. seealso::
Please refer to :class:`~tianshou.data.ReplayBuffer` or
:class:`~tianshou.data.ReplayBufferManager` for more detailed
explanation.
"""
def __init__(
self,
main_buffer: ReplayBuffer,
cached_buffer_num: int,
max_episode_length: int,
) -> None:
assert cached_buffer_num > 0 and max_episode_length > 0
self._is_prioritized = isinstance(main_buffer, PrioritizedReplayBuffer)
kwargs = main_buffer.options
buffers = [main_buffer] + [ReplayBuffer(max_episode_length, **kwargs)
for _ in range(cached_buffer_num)]
super().__init__(buffer_list=buffers, **kwargs)
self.main_buffer = self.buffers[0]
self.cached_buffers = self.buffers[1:]
self.cached_buffer_num = cached_buffer_num
def add( # type: ignore
self,
obs: Any,
act: Any,
rew: np.ndarray,
done: np.ndarray,
obs_next: Any = Batch(),
info: Optional[Batch] = Batch(),
policy: Optional[Batch] = Batch(),
cached_buffer_ids: Optional[Union[np.ndarray, List[int]]] = None,
**kwargs: Any,
) -> Tuple[np.ndarray, np.ndarray]:
"""Add a batch of data into CachedReplayBuffer.
The length (first dimension) of each data item must equal the length of
cached_buffer_ids. By default cached_buffer_ids is [0, 1, ...,
cached_buffer_num - 1].
Return the arrays of episode_length and episode_reward with shape
(len(cached_buffer_ids), ...), where (episode_length[i],
episode_reward[i]) is the episode result of the cached buffer
cached_buffer_ids[i].
"""
if cached_buffer_ids is None:
cached_buffer_ids = np.arange(self.cached_buffer_num)
else: # make sure it is np.ndarray
cached_buffer_ids = np.asarray(cached_buffer_ids)
# in self.buffers, the first buffer is main_buffer
buffer_ids = cached_buffer_ids + 1 # type: ignore
result = super().add(obs, act, rew, done, obs_next, info,
policy, buffer_ids=buffer_ids, **kwargs)
# find the terminated episode, move data from cached buf to main buf
for buffer_idx in cached_buffer_ids[np.asarray(done, np.bool_)]:
self.main_buffer.update(self.cached_buffers[buffer_idx])
self.cached_buffers[buffer_idx].reset()
return result
def __getitem__(
self, index: Union[slice, int, np.integer, np.ndarray]
) -> Batch:
batch = super().__getitem__(index)
if self._is_prioritized:
indice = self._indices[index]
mask = indice < self.main_buffer.maxsize
batch.weight = np.ones(len(indice))
batch.weight[mask] = self.main_buffer.get_weight(indice[mask])
return batch
def update_weight(
self,
index: np.ndarray,
new_weight: Union[np.ndarray, torch.Tensor],
) -> None:
"""Update priority weight by index in main buffer.
:param np.ndarray index: index you want to update weight.
:param np.ndarray new_weight: new priority weight you want to update.
"""
if self._is_prioritized:
mask = index < self.main_buffer.maxsize
self.main_buffer.update_weight(index[mask], new_weight[mask])
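The cache-then-flush flow described in the `CachedReplayBuffer` docstring can be illustrated with a toy model. This is only a sketch under simplifying assumptions: `ToyCachedBuffer` and the `(obs, done)` tuple layout are hypothetical, a `deque` stands in for the circular main buffer, and the shared contiguous memory used by `ReplayBufferManager` is not modelled at all.

```python
from collections import deque
from typing import Deque, List, Sequence, Tuple

Transition = Tuple[int, bool]  # (observation id, done flag); a toy stand-in


class ToyCachedBuffer:
    """Toy model: per-environment caches flushed into one main buffer."""

    def __init__(self, main_size: int, cached_buffer_num: int) -> None:
        # deque(maxlen=...) drops the oldest items, like a circular queue
        self.main: Deque[Transition] = deque(maxlen=main_size)
        self.caches: List[List[Transition]] = [
            [] for _ in range(cached_buffer_num)]

    def add(self, obs: Sequence[int], done: Sequence[bool],
            cached_buffer_ids: Sequence[int]) -> None:
        for o, d, i in zip(obs, done, cached_buffer_ids):
            self.caches[i].append((o, d))
            if d:  # episode finished: move it to the main buffer in order
                self.main.extend(self.caches[i])
                self.caches[i].clear()


buf = ToyCachedBuffer(main_size=8, cached_buffer_num=2)
buf.add(obs=[0, 100], done=[False, False], cached_buffer_ids=[0, 1])
buf.add(obs=[1, 101], done=[True, False], cached_buffer_ids=[0, 1])
buf.add(obs=[102], done=[True], cached_buffer_ids=[1])
main_obs = [o for o, _ in buf.main]
```

Because each episode is flushed as a whole, transitions from different environments never interleave inside the main buffer, which is what lets frame stacking and `prev`/`next` walk along a single episode.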