Add CachedReplayBuffer and ReplayBufferManager (#278)

This is the second commit of 6 commits mentioned in #274, which features minor refactor of ReplayBuffer and adding two new ReplayBuffer classes called CachedReplayBuffer and ReplayBufferManager. You can check #274 for more detail.

1. Add ReplayBufferManager (handle a list of buffers) and CachedReplayBuffer;
2. Make sure the reserved keys cannot be edited by methods like `buffer.done = xxx`;
3. Add `set_batch` method for manually choosing the batch the ReplayBuffer wants to handle;
4. Add `sample_index` method, same as `sample` but only return index instead of both index and batch data;
5. Add `prev` (one-step previous transition index), `next` (one-step next transition index) and `unfinished_index` (the last modified index whose done==False);
6. Separate `alloc_fn` method for allocating new memory for `self._meta` when a new `(key, value)` pair comes in;
7. Move buffer's documentation to `docs/tutorials/concepts.rst`.

Co-authored-by: n+e <trinkle23897@gmail.com>
This commit is contained in:
ChenDRAG 2021-01-29 12:23:18 +08:00 committed by GitHub
parent 1eb6137645
commit f0129f4ca7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 971 additions and 319 deletions

View File

@ -53,11 +53,163 @@ In short, you can define a :class:`~tianshou.data.Batch` with any key-value pair
Buffer
------
.. automodule:: tianshou.data.ReplayBuffer
:members:
:noindex:
:class:`~tianshou.data.ReplayBuffer` stores data generated from interaction between the policy and environment. ReplayBuffer can be considered as a specialized form (or management) of Batch. It stores all the data in a batch with circular-queue style.
Tianshou provides other type of data buffer such as :class:`~tianshou.data.ListReplayBuffer` (based on list), :class:`~tianshou.data.PrioritizedReplayBuffer` (based on Segment Tree and ``numpy.ndarray``). Check out :class:`~tianshou.data.ReplayBuffer` for more detail.
The current implementation of Tianshou typically use 7 reserved keys in
:class:`~tianshou.data.Batch`:
* ``obs`` the observation of step :math:`t` ;
* ``act`` the action of step :math:`t` ;
* ``rew`` the reward of step :math:`t` ;
* ``done`` the done flag of step :math:`t` ;
* ``obs_next`` the observation of step :math:`t+1` ;
* ``info`` the info of step :math:`t` (in ``gym.Env``, the ``env.step()`` function returns 4 arguments, and the last one is ``info``);
* ``policy`` the data computed by policy in step :math:`t`;
The following code snippet illustrates its usage, including:
- the basic data storage: ``add()``;
- get attribute, get slicing data, ...;
- sample from buffer: ``sample_index(batch_size)`` and ``sample(batch_size)``;
- get previous/next transition index within episodes: ``prev(index)`` and ``next(index)``;
- save/load data from buffer: pickle and HDF5;
::
>>> import pickle, numpy as np
>>> from tianshou.data import ReplayBuffer
>>> buf = ReplayBuffer(size=20)
>>> for i in range(3):
... buf.add(obs=i, act=i, rew=i, done=0, obs_next=i + 1, info={})
>>> buf.obs
# since we set size = 20, len(buf.obs) == 20.
array([0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
>>> # but there are only three valid items, so len(buf) == 3.
>>> len(buf)
3
>>> # save to file "buf.pkl"
>>> pickle.dump(buf, open('buf.pkl', 'wb'))
>>> # save to HDF5 file
>>> buf.save_hdf5('buf.hdf5')
>>> buf2 = ReplayBuffer(size=10)
>>> for i in range(15):
... done = i % 4 == 0
... buf2.add(obs=i, act=i, rew=i, done=done, obs_next=i + 1, info={})
>>> len(buf2)
10
>>> buf2.obs
# since its size = 10, it only stores the last 10 steps' result.
array([10, 11, 12, 13, 14, 5, 6, 7, 8, 9])
>>> # move buf2's result into buf (meanwhile keep it chronologically)
>>> buf.update(buf2)
>>> buf.obs
array([ 0, 1, 2, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 0, 0, 0,
0, 0, 0, 0])
>>> # get all available index by using batch_size = 0
>>> indice = buf.sample_index(0)
>>> indice
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
>>> # get one step previous/next transition
>>> buf.prev(indice)
array([ 0, 0, 1, 2, 3, 4, 5, 7, 7, 8, 9, 11, 11])
>>> buf.next(indice)
array([ 1, 2, 3, 4, 5, 6, 6, 8, 9, 10, 10, 12, 12])
>>> # get a random sample from buffer
>>> # the batch_data is equal to buf[indice].
>>> batch_data, indice = buf.sample(batch_size=4)
>>> batch_data.obs == buf[indice].obs
array([ True, True, True, True])
>>> len(buf)
13
>>> buf = pickle.load(open('buf.pkl', 'rb')) # load from "buf.pkl"
>>> len(buf)
3
>>> # load complete buffer from HDF5 file
>>> buf = ReplayBuffer.load_hdf5('buf.hdf5')
>>> len(buf)
3
:class:`~tianshou.data.ReplayBuffer` also supports frame_stack sampling (typically for RNN usage, see issue#19), ignoring storing the next observation (save memory in Atari tasks), and multi-modal observation (see issue#38):
.. raw:: html
<details>
<summary>Advance usage of ReplayBuffer</summary>
.. code-block:: python
>>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
>>> for i in range(16):
... done = i % 5 == 0
... ep_len, ep_rew = buf.add(obs={'id': i}, act=i, rew=i,
... done=done, obs_next={'id': i + 1})
... print(i, ep_len, ep_rew)
0 1 0.0
1 0 0.0
2 0 0.0
3 0 0.0
4 0 0.0
5 5 15.0
6 0 0.0
7 0 0.0
8 0 0.0
9 0 0.0
10 5 40.0
11 0 0.0
12 0 0.0
13 0 0.0
14 0 0.0
15 5 65.0
>>> print(buf) # you can see obs_next is not saved in buf
ReplayBuffer(
obs: Batch(
id: array([ 9, 10, 11, 12, 13, 14, 15, 7, 8]),
),
act: array([ 9, 10, 11, 12, 13, 14, 15, 7, 8]),
rew: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
done: array([False, True, False, False, False, False, True, False,
False]),
info: Batch(),
policy: Batch(),
)
>>> index = np.arange(len(buf))
>>> print(buf.get(index, 'obs').id)
[[ 7 7 8 9]
[ 7 8 9 10]
[11 11 11 11]
[11 11 11 12]
[11 11 12 13]
[11 12 13 14]
[12 13 14 15]
[ 7 7 7 7]
[ 7 7 7 8]]
>>> # here is another way to get the stacked data
>>> # (stack only for obs and obs_next)
>>> abs(buf.get(index, 'obs')['id'] - buf[index].obs.id).sum().sum()
0
>>> # we can get obs_next through __getitem__, even if it doesn't exist
>>> print(buf[:].obs_next.id)
[[ 7 8 9 10]
[ 7 8 9 10]
[11 11 11 12]
[11 11 12 13]
[11 12 13 14]
[12 13 14 15]
[12 13 14 15]
[ 7 7 7 8]
[ 7 7 8 9]]
.. raw:: html
</details><br>
Tianshou provides other type of data buffer such as :class:`~tianshou.data.ListReplayBuffer` (based on list), :class:`~tianshou.data.PrioritizedReplayBuffer` (based on Segment Tree and ``numpy.ndarray``), :class:`~tianshou.data.CachedReplayBuffer` (add different episodes' data but without losing chronological order). Check out :class:`~tianshou.data.ReplayBuffer` for more detail.
Policy

View File

@ -7,9 +7,10 @@ import h5py
import numpy as np
from timeit import timeit
from tianshou.data import Batch, SegmentTree, \
ReplayBuffer, ListReplayBuffer, PrioritizedReplayBuffer
from tianshou.data.utils.converter import to_hdf5
from tianshou.data import Batch, SegmentTree, ReplayBuffer
from tianshou.data import ListReplayBuffer, PrioritizedReplayBuffer
from tianshou.data import ReplayBufferManager, CachedReplayBuffer
if __name__ == '__main__':
from env import MyTestEnv
@ -38,11 +39,14 @@ def test_replaybuffer(size=10, bufsize=20):
assert (data.obs < size).all()
assert (0 <= data.done).all() and (data.done <= 1).all()
b = ReplayBuffer(size=10)
b.add(1, 1, 1, 'str', 1, {'a': 3, 'b': {'c': 5.0}})
# neg bsz should return empty index
assert b.sample_index(-1).tolist() == []
b.add(1, 1, 1, 1, 'str', {'a': 3, 'b': {'c': 5.0}})
assert b.obs[0] == 1
assert b.done[0] == 'str'
assert b.done[0]
assert b.obs_next[0] == 'str'
assert np.all(b.obs[1:] == 0)
assert np.all(b.done[1:] == np.array(None))
assert np.all(b.obs_next[1:] == np.array(None))
assert b.info.a[0] == 3 and b.info.a.dtype == np.integer
assert np.all(b.info.a[1:] == 0)
assert b.info.b.c[0] == 5.0 and b.info.b.c.dtype == np.inexact
@ -91,7 +95,7 @@ def test_ignore_obs_next(size=10):
assert data.obs_next
def test_stack(size=5, bufsize=9, stack_num=4):
def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3):
env = MyTestEnv(size)
buf = ReplayBuffer(bufsize, stack_num=stack_num)
buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True)
@ -115,7 +119,9 @@ def test_stack(size=5, bufsize=9, stack_num=4):
_, indice = buf2.sample(0)
assert indice.tolist() == [2, 6]
_, indice = buf2.sample(1)
assert indice in [2, 6]
assert indice[0] in [2, 6]
batch, indice = buf2.sample(-1) # neg bsz -> no data
assert indice.tolist() == [] and len(batch) == 0
with pytest.raises(IndexError):
buf[bufsize * 2]
@ -152,6 +158,12 @@ def test_update():
assert len(buf1) == len(buf2)
assert (buf2[0].obs == buf1[1].obs).all()
assert (buf2[-1].obs == buf1[0].obs).all()
b = ListReplayBuffer()
with pytest.raises(NotImplementedError):
b.update(b)
b = CachedReplayBuffer(ReplayBuffer(10), 4, 5)
with pytest.raises(NotImplementedError):
b.update(b)
def test_segtree():
@ -260,8 +272,7 @@ def test_pickle():
vbuf = ReplayBuffer(size, stack_num=2)
lbuf = ListReplayBuffer()
pbuf = PrioritizedReplayBuffer(size, 0.6, 0.4)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
rew = torch.tensor([1.]).to(device)
rew = np.array([1, 1])
for i in range(4):
vbuf.add(obs=Batch(index=np.array([i])), act=0, rew=rew, done=0)
for i in range(3):
@ -287,18 +298,18 @@ def test_hdf5():
buffers = {
"array": ReplayBuffer(size, stack_num=2),
"list": ListReplayBuffer(),
"prioritized": PrioritizedReplayBuffer(size, 0.6, 0.4)
"prioritized": PrioritizedReplayBuffer(size, 0.6, 0.4),
}
buffer_types = {k: b.__class__ for k, b in buffers.items()}
device = 'cuda' if torch.cuda.is_available() else 'cpu'
rew = torch.tensor([1.]).to(device)
info_t = torch.tensor([1.]).to(device)
for i in range(4):
kwargs = {
'obs': Batch(index=np.array([i])),
'act': i,
'rew': rew,
'done': 0,
'info': {"number": {"n": i}, 'extra': None},
'rew': np.array([1, 2]),
'done': i % 3 == 2,
'info': {"number": {"n": i, "t": info_t}, 'extra': None},
}
buffers["array"].add(**kwargs)
buffers["list"].add(**kwargs)
@ -320,10 +331,10 @@ def test_hdf5():
assert len(_buffers[k]) == len(buffers[k])
assert np.allclose(_buffers[k].act, buffers[k].act)
assert _buffers[k].stack_num == buffers[k].stack_num
assert _buffers[k]._maxsize == buffers[k]._maxsize
assert _buffers[k]._index == buffers[k]._index
assert _buffers[k].maxsize == buffers[k].maxsize
assert np.all(_buffers[k]._indices == buffers[k]._indices)
for k in ["array", "prioritized"]:
assert _buffers[k]._index == buffers[k]._index
assert isinstance(buffers[k].get(0, "info"), Batch)
assert isinstance(_buffers[k].get(0, "info"), Batch)
for k in ["array"]:
@ -332,28 +343,350 @@ def test_hdf5():
assert np.all(
buffers[k][:].info.extra == _buffers[k][:].info.extra)
for path in paths.values():
os.remove(path)
# raise exception when value cannot be pickled
data = {"not_supported": lambda x: x*x}
data = {"not_supported": lambda x: x * x}
grp = h5py.Group
with pytest.raises(NotImplementedError):
to_hdf5(data, grp)
# ndarray with data type not supported by HDF5 that cannot be pickled
data = {"not_supported": np.array(lambda x: x*x)}
data = {"not_supported": np.array(lambda x: x * x)}
grp = h5py.Group
with pytest.raises(RuntimeError):
to_hdf5(data, grp)
def test_replaybuffermanager():
buf = ReplayBufferManager([ReplayBuffer(size=5) for i in range(4)])
ep_len, ep_rew = buf.add(obs=[1, 2, 3], act=[1, 2, 3], rew=[1, 2, 3],
done=[0, 0, 1], buffer_ids=[0, 1, 2])
assert np.allclose(ep_len, [0, 0, 1]) and np.allclose(ep_rew, [0, 0, 3])
with pytest.raises(NotImplementedError):
# ReplayBufferManager cannot be updated
buf.update(buf)
# sample index / prev / next / unfinished_index
indice = buf.sample_index(11000)
assert np.bincount(indice)[[0, 5, 10]].min() >= 3000 # uniform sample
batch, indice = buf.sample(0)
assert np.allclose(indice, [0, 5, 10])
indice_prev = buf.prev(indice)
assert np.allclose(indice_prev, indice), indice_prev
indice_next = buf.next(indice)
assert np.allclose(indice_next, indice), indice_next
assert np.allclose(buf.unfinished_index(), [0, 5])
buf.add(obs=[4], act=[4], rew=[4], done=[1], buffer_ids=[3])
assert np.allclose(buf.unfinished_index(), [0, 5])
batch, indice = buf.sample(10)
batch, indice = buf.sample(0)
assert np.allclose(indice, [0, 5, 10, 15])
indice_prev = buf.prev(indice)
assert np.allclose(indice_prev, indice), indice_prev
indice_next = buf.next(indice)
assert np.allclose(indice_next, indice), indice_next
data = np.array([0, 0, 0, 0])
buf.add(obs=data, act=data, rew=data, done=data, buffer_ids=[0, 1, 2, 3])
buf.add(obs=data, act=data, rew=data, done=1 - data,
buffer_ids=[0, 1, 2, 3])
assert len(buf) == 12
buf.add(obs=data, act=data, rew=data, done=data, buffer_ids=[0, 1, 2, 3])
buf.add(obs=data, act=data, rew=data, done=[0, 1, 0, 1],
buffer_ids=[0, 1, 2, 3])
assert len(buf) == 20
indice = buf.sample_index(120000)
assert np.bincount(indice).min() >= 5000
batch, indice = buf.sample(10)
indice = buf.sample_index(0)
assert np.allclose(indice, np.arange(len(buf)))
# check the actual data stored in buf._meta
assert np.allclose(buf.done, [
0, 0, 1, 0, 0,
0, 0, 1, 0, 1,
1, 0, 1, 0, 0,
1, 0, 1, 0, 1,
])
assert np.allclose(buf.prev(indice), [
0, 0, 1, 3, 3,
5, 5, 6, 8, 8,
10, 11, 11, 13, 13,
15, 16, 16, 18, 18,
])
assert np.allclose(buf.next(indice), [
1, 2, 2, 4, 4,
6, 7, 7, 9, 9,
10, 12, 12, 14, 14,
15, 17, 17, 19, 19,
])
assert np.allclose(buf.unfinished_index(), [4, 14])
ep_len, ep_rew = buf.add(obs=[1], act=[1], rew=[1], done=[1],
buffer_ids=[2])
assert np.allclose(ep_len, [3]) and np.allclose(ep_rew, [1])
assert np.allclose(buf.unfinished_index(), [4])
indice = list(sorted(buf.sample_index(0)))
assert np.allclose(indice, np.arange(len(buf)))
assert np.allclose(buf.prev(indice), [
0, 0, 1, 3, 3,
5, 5, 6, 8, 8,
14, 11, 11, 13, 13,
15, 16, 16, 18, 18,
])
assert np.allclose(buf.next(indice), [
1, 2, 2, 4, 4,
6, 7, 7, 9, 9,
10, 12, 12, 14, 10,
15, 17, 17, 19, 19,
])
# corner case: list, int and -1
assert buf.prev(-1) == buf.prev([buf.maxsize - 1])[0]
assert buf.next(-1) == buf.next([buf.maxsize - 1])[0]
batch = buf._meta
batch.info.n = np.ones(buf.maxsize)
buf.set_batch(batch)
assert np.allclose(buf.buffers[-1].info.n, [1] * 5)
assert buf.sample_index(-1).tolist() == []
assert np.array([ReplayBuffer(0, ignore_obs_next=True)]).dtype == np.object
def test_cachedbuffer():
buf = CachedReplayBuffer(ReplayBuffer(10), 4, 5)
assert buf.sample_index(0).tolist() == []
# check the normal function/usage/storage in CachedReplayBuffer
ep_len, ep_rew = buf.add(obs=[1], act=[1], rew=[1], done=[0],
cached_buffer_ids=[1])
obs = np.zeros(buf.maxsize)
obs[15] = 1
indice = buf.sample_index(0)
assert np.allclose(indice, [15])
assert np.allclose(buf.prev(indice), [15])
assert np.allclose(buf.next(indice), [15])
assert np.allclose(buf.obs, obs)
assert np.allclose(ep_len, [0]) and np.allclose(ep_rew, [0.0])
ep_len, ep_rew = buf.add(obs=[2], act=[2], rew=[2], done=[1],
cached_buffer_ids=[3])
obs[[0, 25]] = 2
indice = buf.sample_index(0)
assert np.allclose(indice, [0, 15])
assert np.allclose(buf.prev(indice), [0, 15])
assert np.allclose(buf.next(indice), [0, 15])
assert np.allclose(buf.obs, obs)
assert np.allclose(ep_len, [1]) and np.allclose(ep_rew, [2.0])
assert np.allclose(buf.unfinished_index(), [15])
assert np.allclose(buf.sample_index(0), [0, 15])
ep_len, ep_rew = buf.add(obs=[3, 4], act=[3, 4], rew=[3, 4],
done=[0, 1], cached_buffer_ids=[3, 1])
assert np.allclose(ep_len, [0, 2]) and np.allclose(ep_rew, [0, 5.0])
obs[[0, 1, 2, 15, 16, 25]] = [2, 1, 4, 1, 4, 3]
assert np.allclose(buf.obs, obs)
assert np.allclose(buf.unfinished_index(), [25])
indice = buf.sample_index(0)
assert np.allclose(indice, [0, 1, 2, 25])
assert np.allclose(buf.done[indice], [1, 0, 1, 0])
assert np.allclose(buf.prev(indice), [0, 1, 1, 25])
assert np.allclose(buf.next(indice), [0, 2, 2, 25])
indice = buf.sample_index(10000)
assert np.bincount(indice)[[0, 1, 2, 25]].min() > 2000 # uniform sample
# cached buffer with main_buffer size == 0 (no update)
# used in test_collector
buf = CachedReplayBuffer(ReplayBuffer(0, sample_avail=True), 4, 5)
data = np.zeros(4)
rew = np.ones([4, 4])
buf.add(obs=data, act=data, rew=rew, done=[0, 0, 1, 1], obs_next=data)
buf.add(obs=data, act=data, rew=rew, done=[0, 0, 0, 0], obs_next=data)
buf.add(obs=data, act=data, rew=rew, done=[1, 1, 1, 1], obs_next=data)
buf.add(obs=data, act=data, rew=rew, done=[0, 0, 0, 0], obs_next=data)
buf.add(obs=data, act=data, rew=rew, done=[0, 1, 0, 1], obs_next=data)
assert np.allclose(buf.done, [
0, 0, 1, 0, 0,
0, 1, 1, 0, 0,
0, 0, 0, 0, 0,
0, 1, 0, 0, 0,
])
indice = buf.sample_index(0)
assert np.allclose(indice, [0, 1, 10, 11])
assert np.allclose(buf.prev(indice), [0, 0, 10, 10])
assert np.allclose(buf.next(indice), [1, 1, 11, 11])
def test_multibuf_stack():
size = 5
bufsize = 9
stack_num = 4
cached_num = 3
env = MyTestEnv(size)
# test if CachedReplayBuffer can handle stack_num + ignore_obs_next
buf4 = CachedReplayBuffer(
ReplayBuffer(bufsize, stack_num=stack_num, ignore_obs_next=True),
cached_num, size)
# test if CachedReplayBuffer can handle super corner case:
# prio-buffer + stack_num + ignore_obs_next + sample_avail
buf5 = CachedReplayBuffer(
PrioritizedReplayBuffer(bufsize, 0.6, 0.4, stack_num=stack_num,
ignore_obs_next=True, sample_avail=True),
cached_num, size)
obs = env.reset(1)
for i in range(18):
obs_next, rew, done, info = env.step(1)
obs_list = np.array([obs + size * i for i in range(cached_num)])
act_list = [1] * cached_num
rew_list = [rew] * cached_num
done_list = [done] * cached_num
obs_next_list = -obs_list
info_list = [info] * cached_num
buf4.add(obs_list, act_list, rew_list, done_list,
obs_next_list, info_list)
buf5.add(obs_list, act_list, rew_list, done_list,
obs_next_list, info_list)
obs = obs_next
if done:
obs = env.reset(1)
# check the `add` order is correct
assert np.allclose(buf4.obs.reshape(-1), [
12, 13, 14, 4, 6, 7, 8, 9, 11, # main_buffer
1, 2, 3, 4, 0, # cached_buffer[0]
6, 7, 8, 9, 0, # cached_buffer[1]
11, 12, 13, 14, 0, # cached_buffer[2]
]), buf4.obs
assert np.allclose(buf4.done, [
0, 0, 1, 1, 0, 0, 0, 1, 0, # main_buffer
0, 0, 0, 1, 0, # cached_buffer[0]
0, 0, 0, 1, 0, # cached_buffer[1]
0, 0, 0, 1, 0, # cached_buffer[2]
]), buf4.done
assert np.allclose(buf4.unfinished_index(), [10, 15, 20])
indice = sorted(buf4.sample_index(0))
assert np.allclose(indice, list(range(bufsize)) + [9, 10, 14, 15, 19, 20])
assert np.allclose(buf4[indice].obs[..., 0], [
[11, 11, 11, 12], [11, 11, 12, 13], [11, 12, 13, 14],
[4, 4, 4, 4], [6, 6, 6, 6], [6, 6, 6, 7],
[6, 6, 7, 8], [6, 7, 8, 9], [11, 11, 11, 11],
[1, 1, 1, 1], [1, 1, 1, 2], [6, 6, 6, 6],
[6, 6, 6, 7], [11, 11, 11, 11], [11, 11, 11, 12],
])
assert np.allclose(buf4[indice].obs_next[..., 0], [
[11, 11, 12, 13], [11, 12, 13, 14], [11, 12, 13, 14],
[4, 4, 4, 4], [6, 6, 6, 7], [6, 6, 7, 8],
[6, 7, 8, 9], [6, 7, 8, 9], [11, 11, 11, 12],
[1, 1, 1, 2], [1, 1, 1, 2], [6, 6, 6, 7],
[6, 6, 6, 7], [11, 11, 11, 12], [11, 11, 11, 12],
])
assert np.all(buf4.done == buf5.done)
indice = buf5.sample_index(0)
assert np.allclose(sorted(indice), [2, 7])
assert np.all(np.isin(buf5.sample_index(100), indice))
# manually change the stack num
buf5.stack_num = 2
for buf in buf5.buffers:
buf.stack_num = 2
indice = buf5.sample_index(0)
assert np.allclose(sorted(indice), [0, 1, 2, 5, 6, 7, 10, 15, 20])
batch, _ = buf5.sample(0)
assert np.allclose(buf5[np.arange(buf5.maxsize)].weight, 1)
buf5.update_weight(indice, batch.weight * 0)
weight = buf5[np.arange(buf5.maxsize)].weight
modified_weight = weight[[0, 1, 2, 5, 6, 7]]
assert modified_weight.min() == modified_weight.max()
assert modified_weight.max() < 1
unmodified_weight = weight[[3, 4, 8]]
assert unmodified_weight.min() == unmodified_weight.max()
assert unmodified_weight.max() < 1
cached_weight = weight[9:]
assert cached_weight.min() == cached_weight.max() == 1
# test Atari with CachedReplayBuffer, save_only_last_obs + ignore_obs_next
buf6 = CachedReplayBuffer(
ReplayBuffer(bufsize, stack_num=stack_num,
save_only_last_obs=True, ignore_obs_next=True),
cached_num, size)
obs = np.random.rand(size, 4, 84, 84)
buf6.add(obs=[obs[2], obs[0]], act=[1, 1], rew=[0, 0], done=[0, 1],
obs_next=[obs[3], obs[1]], cached_buffer_ids=[1, 2])
assert buf6.obs.shape == (buf6.maxsize, 84, 84)
assert np.allclose(buf6.obs[0], obs[0, -1])
assert np.allclose(buf6.obs[14], obs[2, -1])
assert np.allclose(buf6.obs[19], obs[0, -1])
assert buf6[0].obs.shape == (4, 84, 84)
def test_multibuf_hdf5():
size = 100
buffers = {
"vector": ReplayBufferManager([ReplayBuffer(size) for i in range(4)]),
"cached": CachedReplayBuffer(ReplayBuffer(size), 4, size)
}
buffer_types = {k: b.__class__ for k, b in buffers.items()}
device = 'cuda' if torch.cuda.is_available() else 'cpu'
info_t = torch.tensor([1.]).to(device)
for i in range(4):
kwargs = {
'obs': Batch(index=np.array([i])),
'act': i,
'rew': np.array([1, 2]),
'done': i % 3 == 2,
'info': {"number": {"n": i, "t": info_t}, 'extra': None},
}
buffers["vector"].add(**Batch.cat([[kwargs], [kwargs], [kwargs]]),
buffer_ids=[0, 1, 2])
buffers["cached"].add(**Batch.cat([[kwargs], [kwargs], [kwargs]]),
cached_buffer_ids=[0, 1, 2])
# save
paths = {}
for k, buf in buffers.items():
f, path = tempfile.mkstemp(suffix='.hdf5')
os.close(f)
buf.save_hdf5(path)
paths[k] = path
# load replay buffer
_buffers = {k: buffer_types[k].load_hdf5(paths[k]) for k in paths.keys()}
# compare
for k in buffers.keys():
assert len(_buffers[k]) == len(buffers[k])
assert np.allclose(_buffers[k].act, buffers[k].act)
assert _buffers[k].stack_num == buffers[k].stack_num
assert _buffers[k].maxsize == buffers[k].maxsize
assert np.all(_buffers[k]._indices == buffers[k]._indices)
# check shallow copy in ReplayBufferManager
for k in ["vector", "cached"]:
buffers[k].info.number.n[0] = -100
assert buffers[k].buffers[0].info.number.n[0] == -100
# check if still behave normally
for k in ["vector", "cached"]:
kwargs = {
'obs': Batch(index=np.array([5])),
'act': 5,
'rew': np.array([2, 1]),
'done': False,
'info': {"number": {"n": i}, 'Timelimit.truncate': True},
}
buffers[k].add(**Batch.cat([[kwargs], [kwargs], [kwargs], [kwargs]]))
act = np.zeros(buffers[k].maxsize)
if k == "vector":
act[np.arange(5)] = np.array([0, 1, 2, 3, 5])
act[np.arange(5) + size] = np.array([0, 1, 2, 3, 5])
act[np.arange(5) + size * 2] = np.array([0, 1, 2, 3, 5])
act[size * 3] = 5
elif k == "cached":
act[np.arange(9)] = np.array([0, 1, 2, 0, 1, 2, 0, 1, 2])
act[np.arange(3) + size] = np.array([3, 5, 2])
act[np.arange(3) + size * 2] = np.array([3, 5, 2])
act[np.arange(3) + size * 3] = np.array([3, 5, 2])
act[size * 4] = 5
assert np.allclose(buffers[k].act, act)
for path in paths.values():
os.remove(path)
if __name__ == '__main__':
test_hdf5()
test_replaybuffer()
test_ignore_obs_next()
test_stack()
test_pickle()
test_segtree()
test_priortized_replaybuffer()
test_priortized_replaybuffer(233333, 200000)
test_update()
test_pickle()
test_hdf5()
test_replaybuffermanager()
test_cachedbuffer()
test_multibuf_stack()
test_multibuf_hdf5()

View File

@ -18,7 +18,7 @@ from tianshou.utils.net.continuous import ActorProb, Critic
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='Pendulum-v0')
parser.add_argument('--seed', type=int, default=1626)
parser.add_argument('--seed', type=int, default=1)
parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--gamma', type=float, default=0.99)

View File

@ -16,7 +16,7 @@ from tianshou.data import Collector, ReplayBuffer, PrioritizedReplayBuffer
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='CartPole-v0')
parser.add_argument('--seed', type=int, default=1626)
parser.add_argument('--seed', type=int, default=1)
parser.add_argument('--eps-test', type=float, default=0.05)
parser.add_argument('--eps-train', type=float, default=0.1)
parser.add_argument('--buffer-size', type=int, default=20000)

View File

@ -1,8 +1,8 @@
from tianshou.data.batch import Batch
from tianshou.data.utils.converter import to_numpy, to_torch, to_torch_as
from tianshou.data.utils.segtree import SegmentTree
from tianshou.data.buffer import ReplayBuffer, \
ListReplayBuffer, PrioritizedReplayBuffer
from tianshou.data.buffer import ReplayBuffer, ListReplayBuffer, \
PrioritizedReplayBuffer, ReplayBufferManager, CachedReplayBuffer
from tianshou.data.collector import Collector
__all__ = [
@ -14,5 +14,7 @@ __all__ = [
"ReplayBuffer",
"ListReplayBuffer",
"PrioritizedReplayBuffer",
"ReplayBufferManager",
"CachedReplayBuffer",
"Collector",
]

View File

@ -1,5 +1,6 @@
import h5py
import torch
import warnings
import numpy as np
from numbers import Number
from typing import Any, Dict, List, Tuple, Union, Optional
@ -13,121 +14,13 @@ class ReplayBuffer:
""":class:`~tianshou.data.ReplayBuffer` stores data generated from \
interaction between the policy and environment.
The current implementation of Tianshou typically use 7 reserved keys in
:class:`~tianshou.data.Batch`:
ReplayBuffer can be considered as a specialized form (or management) of
Batch. It stores all the data in a batch with circular-queue style.
* ``obs`` the observation of step :math:`t` ;
* ``act`` the action of step :math:`t` ;
* ``rew`` the reward of step :math:`t` ;
* ``done`` the done flag of step :math:`t` ;
* ``obs_next`` the observation of step :math:`t+1` ;
* ``info`` the info of step :math:`t` (in ``gym.Env``, the ``env.step()`` \
function returns 4 arguments, and the last one is ``info``);
* ``policy`` the data computed by policy in step :math:`t`;
For the example usage of ReplayBuffer, please check out Section Buffer in
:doc:`/tutorials/concepts`.
The following code snippet illustrates its usage:
::
>>> import pickle, numpy as np
>>> from tianshou.data import ReplayBuffer
>>> buf = ReplayBuffer(size=20)
>>> for i in range(3):
... buf.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
>>> buf.obs
# since we set size = 20, len(buf.obs) == 20.
array([0., 1., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0.])
>>> # but there are only three valid items, so len(buf) == 3.
>>> len(buf)
3
>>> # save to file "buf.pkl"
>>> pickle.dump(buf, open('buf.pkl', 'wb'))
>>> # save to HDF5 file
>>> buf.save_hdf5('buf.hdf5')
>>> buf2 = ReplayBuffer(size=10)
>>> for i in range(15):
... buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
>>> len(buf2)
10
>>> buf2.obs
# since its size = 10, it only stores the last 10 steps' result.
array([10., 11., 12., 13., 14., 5., 6., 7., 8., 9.])
>>> # move buf2's result into buf (meanwhile keep it chronologically)
>>> buf.update(buf2)
array([ 0., 1., 2., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14.,
0., 0., 0., 0., 0., 0., 0.])
>>> # get a random sample from buffer
>>> # the batch_data is equal to buf[indice].
>>> batch_data, indice = buf.sample(batch_size=4)
>>> batch_data.obs == buf[indice].obs
array([ True, True, True, True])
>>> len(buf)
13
>>> buf = pickle.load(open('buf.pkl', 'rb')) # load from "buf.pkl"
>>> len(buf)
3
>>> # load complete buffer from HDF5 file
>>> buf = ReplayBuffer.load_hdf5('buf.hdf5')
>>> len(buf)
3
>>> # load contents of HDF5 file into existing buffer
>>> # (only possible if size of buffer and data in file match)
>>> buf.load_contents_hdf5('buf.hdf5')
>>> len(buf)
3
:class:`~tianshou.data.ReplayBuffer` also supports frame_stack sampling
(typically for RNN usage, see issue#19), ignoring storing the next
observation (save memory in atari tasks), and multi-modal observation (see
issue#38):
::
>>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
>>> for i in range(16):
... done = i % 5 == 0
... buf.add(obs={'id': i}, act=i, rew=i, done=done,
... obs_next={'id': i + 1})
>>> print(buf) # you can see obs_next is not saved in buf
ReplayBuffer(
act: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
done: array([0., 1., 0., 0., 0., 0., 1., 0., 0.]),
info: Batch(),
obs: Batch(
id: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
),
policy: Batch(),
rew: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
)
>>> index = np.arange(len(buf))
>>> print(buf.get(index, 'obs').id)
[[ 7. 7. 8. 9.]
[ 7. 8. 9. 10.]
[11. 11. 11. 11.]
[11. 11. 11. 12.]
[11. 11. 12. 13.]
[11. 12. 13. 14.]
[12. 13. 14. 15.]
[ 7. 7. 7. 7.]
[ 7. 7. 7. 8.]]
>>> # here is another way to get the stacked data
>>> # (stack only for obs and obs_next)
>>> abs(buf.get(index, 'obs')['id'] - buf[index].obs.id).sum().sum()
0.0
>>> # we can get obs_next through __getitem__, even if it doesn't exist
>>> print(buf[:].obs_next.id)
[[ 7. 8. 9. 10.]
[ 7. 8. 9. 10.]
[11. 11. 11. 12.]
[11. 11. 12. 13.]
[11. 12. 13. 14.]
[12. 13. 14. 15.]
[12. 13. 14. 15.]
[ 7. 7. 7. 8.]
[ 7. 7. 8. 9.]]
:param int size: the size of replay buffer.
:param int size: the maximum size of replay buffer.
:param int stack_num: the frame-stack sampling argument, should be greater
than or equal to 1, defaults to 1 (no stacking).
:param bool ignore_obs_next: whether to store obs_next, defaults to False.
@ -136,9 +29,11 @@ class ReplayBuffer:
False.
:param bool sample_avail: the parameter indicating sampling only available
index when using frame-stack sampling method, defaults to False.
This feature is not supported in Prioritized Replay Buffer currently.
"""
_reserved_keys = ("obs", "act", "rew", "done",
"obs_next", "info", "policy")
def __init__(
self,
size: int,
@ -147,16 +42,20 @@ class ReplayBuffer:
save_only_last_obs: bool = False,
sample_avail: bool = False,
) -> None:
self.options: Dict[str, Any] = {
"stack_num": stack_num,
"ignore_obs_next": ignore_obs_next,
"save_only_last_obs": save_only_last_obs,
"sample_avail": sample_avail,
}
super().__init__()
self._maxsize = size
self._indices = np.arange(size)
self.maxsize = size
assert stack_num > 0, "stack_num should greater than 0"
self.stack_num = stack_num
self._avail = sample_avail and stack_num > 1
self._avail_index: List[int] = []
self._save_s_ = not ignore_obs_next
self._last_obs = save_only_last_obs
self._index = 0
self._size = 0
self._indices = np.arange(size)
self._save_obs_next = not ignore_obs_next
self._save_only_last_obs = save_only_last_obs
self._sample_avail = sample_avail
self._meta: Batch = Batch()
self.reset()
@ -181,20 +80,92 @@ class ReplayBuffer:
We need it because pickling buffer does not work out-of-the-box
("buffer.__getattr__" is customized).
"""
self._indices = np.arange(state["_maxsize"])
self.__dict__.update(state)
# compatible with version == 0.3.1's HDF5 data format
self._indices = np.arange(self.maxsize)
def __getstate__(self) -> dict:
exclude = {"_indices"}
state = {k: v for k, v in self.__dict__.items() if k not in exclude}
return state
def __setattr__(self, key: str, value: Any) -> None:
"""Set self.key = value."""
assert key not in self._reserved_keys, (
"key '{}' is reserved and cannot be assigned".format(key))
super().__setattr__(key, value)
def save_hdf5(self, path: str) -> None:
"""Save replay buffer to HDF5 file."""
with h5py.File(path, "w") as f:
to_hdf5(self.__dict__, f)
@classmethod
def load_hdf5(
cls, path: str, device: Optional[str] = None
) -> "ReplayBuffer":
"""Load replay buffer from HDF5 file."""
with h5py.File(path, "r") as f:
buf = cls.__new__(cls)
buf.__setstate__(from_hdf5(f, device=device))
return buf
def reset(self) -> None:
"""Clear all the data in replay buffer and episode statistics."""
self._index = self._size = 0
self._episode_length, self._episode_reward = 0, 0.0
def set_batch(self, batch: Batch) -> None:
"""Manually choose the batch you want the ReplayBuffer to manage."""
assert len(batch) == self.maxsize and \
set(batch.keys()).issubset(self._reserved_keys), \
"Input batch doesn't meet ReplayBuffer's data form requirement."
self._meta = batch
def unfinished_index(self) -> np.ndarray:
"""Return the index of unfinished episode."""
last = (self._index - 1) % self._size if self._size else 0
return np.array(
[last] if not self.done[last] and self._size else [], np.int)
def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
"""Return the index of previous transition.
The index won't be modified if it is the beginning of an episode.
"""
index = (index - 1) % self._size
end_flag = self.done[index] | np.isin(index, self.unfinished_index())
return (index + end_flag) % self._size
def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
"""Return the index of next transition.
The index won't be modified if it is the end of an episode.
"""
end_flag = self.done[index] | np.isin(index, self.unfinished_index())
return (index + (1 - end_flag)) % self._size
def update(self, buffer: "ReplayBuffer") -> None:
"""Move the data from the given buffer to current buffer."""
if len(buffer) == 0 or self.maxsize == 0:
return
stack_num, buffer.stack_num = buffer.stack_num, 1
save_only_last_obs = self._save_only_last_obs
self._save_only_last_obs = False
indices = buffer.sample_index(0) # get all available indices
for i in indices:
self.add(**buffer[i]) # type: ignore
buffer.stack_num = stack_num
self._save_only_last_obs = save_only_last_obs
def _buffer_allocator(self, key: List[str], value: Any) -> None:
"""Allocate memory on buffer._meta for new (key, value) pair."""
data = self._meta
for k in key[:-1]:
data = data[k]
data[key[-1]] = _create_value(value, self.maxsize)
def _add_to_buffer(self, name: str, inst: Any) -> None:
try:
value = self._meta.__dict__[name]
except KeyError:
self._meta.__dict__[name] = _create_value(inst, self._maxsize)
value = self._meta.__dict__[name]
self._buffer_allocator([name], inst)
value = self._meta[name]
if isinstance(inst, (torch.Tensor, np.ndarray)):
if inst.shape != value.shape[1:]:
raise ValueError(
@ -203,33 +174,10 @@ class ReplayBuffer:
)
try:
value[self._index] = inst
except KeyError:
for key in set(inst.keys()).difference(value.__dict__.keys()):
value.__dict__[key] = _create_value(inst[key], self._maxsize)
value[self._index] = inst
@property
def stack_num(self) -> int:
return self._stack
@stack_num.setter
def stack_num(self, num: int) -> None:
assert num > 0, "stack_num should greater than 0"
self._stack = num
def update(self, buffer: "ReplayBuffer") -> None:
"""Move the data from the given buffer to self."""
if len(buffer) == 0:
return
i = begin = buffer._index % len(buffer)
stack_num_orig = buffer.stack_num
buffer.stack_num = 1
while True:
self.add(**buffer[i]) # type: ignore
i = (i + 1) % len(buffer)
if i == begin:
break
buffer.stack_num = stack_num_orig
except KeyError: # inst is a dict/Batch
for key in set(inst.keys()).difference(value.keys()):
self._buffer_allocator([name, key], inst[key])
self._meta[name][self._index] = inst
def add(
self,
@ -241,117 +189,110 @@ class ReplayBuffer:
info: Optional[Union[dict, Batch]] = {},
policy: Optional[Union[dict, Batch]] = {},
**kwargs: Any,
) -> None:
"""Add a batch of data into replay buffer."""
) -> Tuple[int, Union[float, np.ndarray]]:
"""Add a batch of data into replay buffer.
Return (episode_length, episode_reward) if one episode is terminated,
otherwise return (0, 0.0).
"""
assert isinstance(
info, (dict, Batch)
), "You should return a dict in the last argument of env.step()."
if self._last_obs:
if self._save_only_last_obs:
obs = obs[-1]
self._add_to_buffer("obs", obs)
self._add_to_buffer("act", act)
# make sure the reward is a float instead of an int
self._add_to_buffer("rew", rew * 1.0) # type: ignore
self._add_to_buffer("done", done)
if self._save_s_:
# make sure the data type of reward is float instead of int
# but rew may be np.ndarray, so that we cannot use float(rew)
rew = rew * 1.0 # type: ignore
self._add_to_buffer("rew", rew)
self._add_to_buffer("done", bool(done)) # done should be a bool scalar
if self._save_obs_next:
if obs_next is None:
obs_next = Batch()
elif self._last_obs:
elif self._save_only_last_obs:
obs_next = obs_next[-1]
self._add_to_buffer("obs_next", obs_next)
self._add_to_buffer("info", info)
self._add_to_buffer("policy", policy)
# maintain available index for frame-stack sampling
if self._avail:
# update current frame
avail = sum(self.done[i] for i in range(
self._index - self.stack_num + 1, self._index)) == 0
if self._size < self.stack_num - 1:
avail = False
if avail and self._index not in self._avail_index:
self._avail_index.append(self._index)
elif not avail and self._index in self._avail_index:
self._avail_index.remove(self._index)
# remove the later available frame because of broken storage
t = (self._index + self.stack_num - 1) % self._maxsize
if t in self._avail_index:
self._avail_index.remove(t)
if self.maxsize > 0:
self._size = min(self._size + 1, self.maxsize)
self._index = (self._index + 1) % self.maxsize
else: # TODO: remove this after deleting ListReplayBuffer
self._size = self._index = self._size + 1
if self._maxsize > 0:
self._size = min(self._size + 1, self._maxsize)
self._index = (self._index + 1) % self._maxsize
self._episode_reward += rew
self._episode_length += 1
if done:
result = self._episode_length, self._episode_reward
self._episode_length, self._episode_reward = 0, 0.0
return result
else:
self._size = self._index = self._index + 1
return 0, self._episode_reward * 0.0
def reset(self) -> None:
"""Clear all the data in replay buffer."""
self._index = 0
self._size = 0
self._avail_index = []
def sample_index(self, batch_size: int) -> np.ndarray:
"""Get a random sample of index with size = batch_size.
Return all available indices in the buffer if batch_size is 0; return
an empty numpy array if batch_size < 0 or no available index can be
sampled.
"""
if self.stack_num == 1 or not self._sample_avail: # most often case
if batch_size > 0:
return np.random.choice(self._size, batch_size)
elif batch_size == 0: # construct current available indices
return np.concatenate([
np.arange(self._index, self._size),
np.arange(self._index)])
else:
return np.array([], np.int)
else:
if batch_size < 0:
return np.array([], np.int)
all_indices = prev_indices = np.concatenate([
np.arange(self._index, self._size), np.arange(self._index)])
for _ in range(self.stack_num - 2):
prev_indices = self.prev(prev_indices)
all_indices = all_indices[prev_indices != self.prev(prev_indices)]
if batch_size > 0:
return np.random.choice(all_indices, batch_size)
else:
return all_indices
def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
"""Get a random sample from buffer with size equal to batch_size.
"""Get a random sample from buffer with size = batch_size.
Return all the data in the buffer if batch_size is 0.
:return: Sample data and its corresponding index inside the buffer.
"""
if batch_size > 0:
_all = self._avail_index if self._avail else self._size
indice = np.random.choice(_all, batch_size)
else:
if self._avail:
indice = np.array(self._avail_index)
else:
indice = np.concatenate([
np.arange(self._index, self._size),
np.arange(0, self._index),
])
assert len(indice) > 0, "No available indice can be sampled."
return self[indice], indice
indices = self.sample_index(batch_size)
return self[indices], indices
def get(
self,
indice: Union[slice, int, np.integer, np.ndarray],
index: Union[int, np.integer, np.ndarray],
key: str,
stack_num: Optional[int] = None,
) -> Union[Batch, np.ndarray]:
"""Return the stacked result.
E.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t], where s is self.key, t is the
indice. The stack_num (here equals to 4) is given from buffer
initialization procedure.
index.
"""
if stack_num is None:
stack_num = self.stack_num
if stack_num == 1: # the most often case
if key != "obs_next" or self._save_s_:
val = self._meta.__dict__[key]
try:
return val[indice]
except IndexError as e:
if not (isinstance(val, Batch) and val.is_empty()):
raise e # val != Batch()
return Batch()
indice = self._indices[:self._size][indice]
done = self._meta.__dict__["done"]
if key == "obs_next" and not self._save_s_:
indice += 1 - done[indice].astype(np.int)
indice[indice == self._size] = 0
key = "obs"
val = self._meta.__dict__[key]
val = self._meta[key]
try:
if stack_num == 1:
return val[indice]
if stack_num == 1: # the most often case
return val[index]
stack: List[Any] = []
indice = np.asarray(index)
for _ in range(stack_num):
stack = [val[indice]] + stack
pre_indice = np.asarray(indice - 1)
pre_indice[pre_indice == -1] = self._size - 1
indice = np.asarray(
pre_indice + done[pre_indice].astype(np.int))
indice[indice == self._size] = 0
indice = self.prev(indice)
if isinstance(val, Batch):
return Batch.stack(stack, axis=indice.ndim)
else:
@ -369,31 +310,24 @@ class ReplayBuffer:
If stack_num is larger than 1, return the stacked obs and obs_next with
shape (batch, len, ...).
"""
if isinstance(index, slice): # change slice to np array
index = self._indices[:len(self)][index]
# raise KeyError first instead of AttributeError, to support np.array
obs = self.get(index, "obs")
if self._save_obs_next:
obs_next = self.get(index, "obs_next")
else:
obs_next = self.get(self.next(index), "obs")
return Batch(
obs=self.get(index, "obs"),
obs=obs,
act=self.act[index],
rew=self.rew[index],
done=self.done[index],
obs_next=self.get(index, "obs_next"),
obs_next=obs_next,
info=self.get(index, "info"),
policy=self.get(index, "policy"),
)
def save_hdf5(self, path: str) -> None:
"""Save replay buffer to HDF5 file."""
with h5py.File(path, "w") as f:
to_hdf5(self.__getstate__(), f)
@classmethod
def load_hdf5(
cls, path: str, device: Optional[str] = None
) -> "ReplayBuffer":
"""Load replay buffer from HDF5 file."""
with h5py.File(path, "r") as f:
buf = cls.__new__(cls)
buf.__setstate__(from_hdf5(f, device=device))
return buf
class ListReplayBuffer(ReplayBuffer):
"""List-based replay buffer.
@ -411,24 +345,27 @@ class ListReplayBuffer(ReplayBuffer):
"""
def __init__(self, **kwargs: Any) -> None:
warnings.warn("ListReplayBuffer will be replaced in version 0.4.0.")
super().__init__(size=0, ignore_obs_next=False, **kwargs)
def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
raise NotImplementedError("ListReplayBuffer cannot be sampled!")
def _add_to_buffer(
self, name: str, inst: Union[dict, Batch, np.ndarray, float, int, bool]
) -> None:
if self._meta.__dict__.get(name) is None:
def _add_to_buffer(self, name: str, inst: Any) -> None:
if self._meta.get(name) is None:
self._meta.__dict__[name] = []
self._meta.__dict__[name].append(inst)
self._meta[name].append(inst)
def reset(self) -> None:
self._index = self._size = 0
for k in list(self._meta.__dict__.keys()):
if isinstance(self._meta.__dict__[k], list):
super().reset()
for k in self._meta.keys():
if isinstance(self._meta[k], list):
self._meta.__dict__[k] = []
def update(self, buffer: ReplayBuffer) -> None:
"""The ListReplayBuffer cannot be updated by any buffer."""
raise NotImplementedError
class PrioritizedReplayBuffer(ReplayBuffer):
"""Implementation of Prioritized Experience Replay. arXiv:1511.05952.
@ -464,8 +401,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
policy: Optional[Union[dict, Batch]] = {},
weight: Optional[Union[Number, np.number]] = None,
**kwargs: Any,
) -> None:
"""Add a batch of data into replay buffer."""
) -> Tuple[int, Union[float, np.ndarray]]:
if weight is None:
weight = self._max_prio
else:
@ -473,60 +409,289 @@ class PrioritizedReplayBuffer(ReplayBuffer):
self._max_prio = max(self._max_prio, weight)
self._min_prio = min(self._min_prio, weight)
self.weight[self._index] = weight ** self._alpha
super().add(obs, act, rew, done, obs_next, info, policy, **kwargs)
return super().add(obs, act, rew, done, obs_next,
info, policy, **kwargs)
def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
"""Get a random sample from buffer with priority probability.
def sample_index(self, batch_size: int) -> np.ndarray:
if batch_size > 0 and self._size > 0:
scalar = np.random.rand(batch_size) * self.weight.reduce()
return self.weight.get_prefix_sum_idx(scalar)
else:
return super().sample_index(batch_size)
Return all the data in the buffer if batch_size is 0.
:return: Sample data and its corresponding index inside the buffer.
def get_weight(
self, index: Union[slice, int, np.integer, np.ndarray]
) -> np.ndarray:
"""Get the importance sampling weight.
The "weight" in the returned Batch is the weight on loss function
to de-bias the sampling process (some transition tuples are sampled
more often so their losses are weighted less).
"""
assert self._size > 0, "Cannot sample a buffer with 0 size!"
if batch_size == 0:
indice = np.concatenate([
np.arange(self._index, self._size),
np.arange(0, self._index),
])
else:
scalar = np.random.rand(batch_size) * self.weight.reduce()
indice = self.weight.get_prefix_sum_idx(scalar)
batch = self[indice]
# important sampling weight calculation
# original formula: ((p_j/p_sum*N)**(-beta))/((p_min/p_sum*N)**(-beta))
# simplified formula: (p_j/p_min)**(-beta)
batch.weight = (batch.weight / self._min_prio) ** (-self._beta)
return batch, indice
return (self.weight[index] / self._min_prio) ** (-self._beta)
def update_weight(
self,
indice: Union[np.ndarray],
new_weight: Union[np.ndarray, torch.Tensor]
index: np.ndarray,
new_weight: Union[np.ndarray, torch.Tensor],
) -> None:
"""Update priority weight by indice in this buffer.
"""Update priority weight by index in this buffer.
:param np.ndarray indice: indice you want to update weight.
:param np.ndarray index: index you want to update weight.
:param np.ndarray new_weight: new priority weight you want to update.
"""
weight = np.abs(to_numpy(new_weight)) + self.__eps
self.weight[indice] = weight ** self._alpha
self.weight[index] = weight ** self._alpha
self._max_prio = max(self._max_prio, weight.max())
self._min_prio = min(self._min_prio, weight.min())
def __getitem__(
self, index: Union[slice, int, np.integer, np.ndarray]
) -> Batch:
return Batch(
obs=self.get(index, "obs"),
act=self.act[index],
rew=self.rew[index],
done=self.done[index],
obs_next=self.get(index, "obs_next"),
info=self.get(index, "info"),
policy=self.get(index, "policy"),
weight=self.weight[index],
)
batch = super().__getitem__(index)
batch.weight = self.get_weight(index)
return batch
class ReplayBufferManager(ReplayBuffer):
"""ReplayBufferManager contains a list of ReplayBuffer.
These replay buffers have contiguous memory layout, and the storage space
each buffer has is a shallow copy of the topmost memory.
:param int buffer_list: a list of ReplayBuffers needed to be handled.
.. seealso::
Please refer to :class:`~tianshou.data.ReplayBuffer` for more detailed
explanation.
"""
def __init__(self, buffer_list: List[ReplayBuffer], **kwargs: Any) -> None:
self.buffer_num = len(buffer_list)
self.buffers = buffer_list
self._offset = []
offset = 0
for buf in self.buffers:
# overwrite sub-buffers' _buffer_allocator so that
# the top buffer can allocate new memory for all sub-buffers
buf._buffer_allocator = self._buffer_allocator # type: ignore
assert buf._meta.is_empty()
self._offset.append(offset)
offset += buf.maxsize
super().__init__(size=offset, **kwargs)
def __len__(self) -> int:
return sum([len(buf) for buf in self.buffers])
def reset(self) -> None:
for buf in self.buffers:
buf.reset()
def _set_batch_for_children(self) -> None:
for offset, buf in zip(self._offset, self.buffers):
buf.set_batch(self._meta[offset:offset + buf.maxsize])
def set_batch(self, batch: Batch) -> None:
super().set_batch(batch)
self._set_batch_for_children()
def unfinished_index(self) -> np.ndarray:
return np.concatenate([
buf.unfinished_index() + offset
for offset, buf in zip(self._offset, self.buffers)])
def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
index = np.asarray(index) % self.maxsize
prev_indices = np.zeros_like(index)
for offset, buf in zip(self._offset, self.buffers):
mask = (offset <= index) & (index < offset + buf.maxsize)
if np.any(mask):
prev_indices[mask] = buf.prev(index[mask] - offset) + offset
return prev_indices
def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
index = np.asarray(index) % self.maxsize
next_indices = np.zeros_like(index)
for offset, buf in zip(self._offset, self.buffers):
mask = (offset <= index) & (index < offset + buf.maxsize)
if np.any(mask):
next_indices[mask] = buf.next(index[mask] - offset) + offset
return next_indices
def update(self, buffer: ReplayBuffer) -> None:
"""The ReplayBufferManager cannot be updated by any buffer."""
raise NotImplementedError
def _buffer_allocator(self, key: List[str], value: Any) -> None:
super()._buffer_allocator(key, value)
self._set_batch_for_children()
def add( # type: ignore
self,
obs: Any,
act: Any,
rew: np.ndarray,
done: np.ndarray,
obs_next: Any = Batch(),
info: Optional[Batch] = Batch(),
policy: Optional[Batch] = Batch(),
buffer_ids: Optional[Union[np.ndarray, List[int]]] = None,
**kwargs: Any
) -> Tuple[np.ndarray, np.ndarray]:
"""Add a batch of data into ReplayBufferManager.
Each of the data's length (first dimension) must equal to the length of
buffer_ids. By default buffer_ids is [0, 1, ..., buffer_num - 1].
Return the array of episode_length and episode_reward with shape
(len(buffer_ids), ...), where (episode_length[i], episode_reward[i])
refers to the buffer_ids[i]'s corresponding episode result.
"""
if buffer_ids is None:
buffer_ids = np.arange(self.buffer_num)
# assume each element in buffer_ids is unique
assert np.bincount(buffer_ids).max() == 1
batch = Batch(obs=obs, act=act, rew=rew, done=done,
obs_next=obs_next, info=info, policy=policy)
assert len(buffer_ids) == len(batch)
episode_lengths = [] # (len(buffer_ids),)
episode_rewards = [] # (len(buffer_ids), ...)
for batch_idx, buffer_id in enumerate(buffer_ids):
length, reward = self.buffers[buffer_id].add(**batch[batch_idx])
episode_lengths.append(length)
episode_rewards.append(reward)
return np.stack(episode_lengths), np.stack(episode_rewards)
def sample_index(self, batch_size: int) -> np.ndarray:
if batch_size < 0:
return np.array([], np.int)
if self._sample_avail and self.stack_num > 1:
all_indices = np.concatenate([
buf.sample_index(0) + offset
for offset, buf in zip(self._offset, self.buffers)])
if batch_size == 0:
return all_indices
else:
return np.random.choice(all_indices, batch_size)
if batch_size == 0: # get all available indices
sample_num = np.zeros(self.buffer_num, np.int)
else:
buffer_lens = np.array([len(buf) for buf in self.buffers])
buffer_idx = np.random.choice(self.buffer_num, batch_size,
p=buffer_lens / buffer_lens.sum())
sample_num = np.bincount(buffer_idx, minlength=self.buffer_num)
# avoid batch_size > 0 and sample_num == 0 -> get child's all data
sample_num[sample_num == 0] = -1
return np.concatenate([
buf.sample_index(bsz) + offset
for offset, buf, bsz in zip(self._offset, self.buffers, sample_num)
])
class CachedReplayBuffer(ReplayBufferManager):
"""CachedReplayBuffer contains a given main buffer and n cached buffers, \
cached_buffer_num * ReplayBuffer(size=max_episode_length).
The memory layout is: ``| main_buffer | cached_buffers[0] |
cached_buffers[1] | ... | cached_buffers[cached_buffer_num - 1]``.
The data is first stored in cached buffers. When the episode is
terminated, the data will move to the main buffer and the corresponding
cached buffer will be reset.
:param ReplayBuffer main_buffer: the main buffer whose ``.update()``
function behaves normally.
:param int cached_buffer_num: number of ReplayBuffer needs to be created
for cached buffer.
:param int max_episode_length: the maximum length of one episode, used in
each cached buffer's maxsize.
.. seealso::
Please refer to :class:`~tianshou.data.ReplayBuffer` or
:class:`~tianshou.data.ReplayBufferManager` for more detailed
explanation.
"""
def __init__(
self,
main_buffer: ReplayBuffer,
cached_buffer_num: int,
max_episode_length: int,
) -> None:
assert cached_buffer_num > 0 and max_episode_length > 0
self._is_prioritized = isinstance(main_buffer, PrioritizedReplayBuffer)
kwargs = main_buffer.options
buffers = [main_buffer] + [ReplayBuffer(max_episode_length, **kwargs)
for _ in range(cached_buffer_num)]
super().__init__(buffer_list=buffers, **kwargs)
self.main_buffer = self.buffers[0]
self.cached_buffers = self.buffers[1:]
self.cached_buffer_num = cached_buffer_num
def add( # type: ignore
self,
obs: Any,
act: Any,
rew: np.ndarray,
done: np.ndarray,
obs_next: Any = Batch(),
info: Optional[Batch] = Batch(),
policy: Optional[Batch] = Batch(),
cached_buffer_ids: Optional[Union[np.ndarray, List[int]]] = None,
**kwargs: Any,
) -> Tuple[np.ndarray, np.ndarray]:
"""Add a batch of data into CachedReplayBuffer.
Each of the data's length (first dimension) must equal to the length of
cached_buffer_ids. By default the cached_buffer_ids is [0, 1, ...,
cached_buffer_num - 1].
Return the array of episode_length and episode_reward with shape
(len(cached_buffer_ids), ...), where (episode_length[i],
episode_reward[i]) refers to the cached_buffer_ids[i]th cached buffer's
corresponding episode result.
"""
if cached_buffer_ids is None:
cached_buffer_ids = np.arange(self.cached_buffer_num)
else: # make sure it is np.ndarray
cached_buffer_ids = np.asarray(cached_buffer_ids)
# in self.buffers, the first buffer is main_buffer
buffer_ids = cached_buffer_ids + 1 # type: ignore
result = super().add(obs, act, rew, done, obs_next, info,
policy, buffer_ids=buffer_ids, **kwargs)
# find the terminated episode, move data from cached buf to main buf
for buffer_idx in cached_buffer_ids[np.asarray(done, np.bool_)]:
self.main_buffer.update(self.cached_buffers[buffer_idx])
self.cached_buffers[buffer_idx].reset()
return result
def __getitem__(
self, index: Union[slice, int, np.integer, np.ndarray]
) -> Batch:
batch = super().__getitem__(index)
if self._is_prioritized:
indice = self._indices[index]
mask = indice < self.main_buffer.maxsize
batch.weight = np.ones(len(indice))
batch.weight[mask] = self.main_buffer.get_weight(indice[mask])
return batch
def update_weight(
self,
index: np.ndarray,
new_weight: Union[np.ndarray, torch.Tensor],
) -> None:
"""Update priority weight by index in main buffer.
:param np.ndarray index: index you want to update weight.
:param np.ndarray new_weight: new priority weight you want to update.
"""
if self._is_prioritized:
mask = index < self.main_buffer.maxsize
self.main_buffer.update_weight(index[mask], new_weight[mask])