Add CachedReplayBuffer and ReplayBufferManager (#278)
This is the second commit of 6 commits mentioned in #274, which features minor refactor of ReplayBuffer and adding two new ReplayBuffer classes called CachedReplayBuffer and ReplayBufferManager. You can check #274 for more detail. 1. Add ReplayBufferManager (handle a list of buffers) and CachedReplayBuffer; 2. Make sure the reserved keys cannot be edited by methods like `buffer.done = xxx`; 3. Add `set_batch` method for manually choosing the batch the ReplayBuffer wants to handle; 4. Add `sample_index` method, same as `sample` but only return index instead of both index and batch data; 5. Add `prev` (one-step previous transition index), `next` (one-step next transition index) and `unfinished_index` (the last modified index whose done==False); 6. Separate `alloc_fn` method for allocating new memory for `self._meta` when a new `(key, value)` pair comes in; 7. Move buffer's documentation to `docs/tutorials/concepts.rst`. Co-authored-by: n+e <trinkle23897@gmail.com>
This commit is contained in:
parent
1eb6137645
commit
f0129f4ca7
@ -53,11 +53,163 @@ In short, you can define a :class:`~tianshou.data.Batch` with any key-value pair
|
||||
Buffer
|
||||
------
|
||||
|
||||
.. automodule:: tianshou.data.ReplayBuffer
|
||||
:members:
|
||||
:noindex:
|
||||
:class:`~tianshou.data.ReplayBuffer` stores data generated from interaction between the policy and environment. ReplayBuffer can be considered as a specialized form (or management) of Batch. It stores all the data in a batch with circular-queue style.
|
||||
|
||||
Tianshou provides other type of data buffer such as :class:`~tianshou.data.ListReplayBuffer` (based on list), :class:`~tianshou.data.PrioritizedReplayBuffer` (based on Segment Tree and ``numpy.ndarray``). Check out :class:`~tianshou.data.ReplayBuffer` for more detail.
|
||||
The current implementation of Tianshou typically use 7 reserved keys in
|
||||
:class:`~tianshou.data.Batch`:
|
||||
|
||||
* ``obs`` the observation of step :math:`t` ;
|
||||
* ``act`` the action of step :math:`t` ;
|
||||
* ``rew`` the reward of step :math:`t` ;
|
||||
* ``done`` the done flag of step :math:`t` ;
|
||||
* ``obs_next`` the observation of step :math:`t+1` ;
|
||||
* ``info`` the info of step :math:`t` (in ``gym.Env``, the ``env.step()`` function returns 4 arguments, and the last one is ``info``);
|
||||
* ``policy`` the data computed by policy in step :math:`t`;
|
||||
|
||||
The following code snippet illustrates its usage, including:
|
||||
|
||||
- the basic data storage: ``add()``;
|
||||
- get attribute, get slicing data, ...;
|
||||
- sample from buffer: ``sample_index(batch_size)`` and ``sample(batch_size)``;
|
||||
- get previous/next transition index within episodes: ``prev(index)`` and ``next(index)``;
|
||||
- save/load data from buffer: pickle and HDF5;
|
||||
|
||||
::
|
||||
|
||||
>>> import pickle, numpy as np
|
||||
>>> from tianshou.data import ReplayBuffer
|
||||
>>> buf = ReplayBuffer(size=20)
|
||||
>>> for i in range(3):
|
||||
... buf.add(obs=i, act=i, rew=i, done=0, obs_next=i + 1, info={})
|
||||
|
||||
>>> buf.obs
|
||||
# since we set size = 20, len(buf.obs) == 20.
|
||||
array([0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
|
||||
>>> # but there are only three valid items, so len(buf) == 3.
|
||||
>>> len(buf)
|
||||
3
|
||||
>>> # save to file "buf.pkl"
|
||||
>>> pickle.dump(buf, open('buf.pkl', 'wb'))
|
||||
>>> # save to HDF5 file
|
||||
>>> buf.save_hdf5('buf.hdf5')
|
||||
|
||||
>>> buf2 = ReplayBuffer(size=10)
|
||||
>>> for i in range(15):
|
||||
... done = i % 4 == 0
|
||||
... buf2.add(obs=i, act=i, rew=i, done=done, obs_next=i + 1, info={})
|
||||
>>> len(buf2)
|
||||
10
|
||||
>>> buf2.obs
|
||||
# since its size = 10, it only stores the last 10 steps' result.
|
||||
array([10, 11, 12, 13, 14, 5, 6, 7, 8, 9])
|
||||
|
||||
>>> # move buf2's result into buf (meanwhile keep it chronologically)
|
||||
>>> buf.update(buf2)
|
||||
>>> buf.obs
|
||||
array([ 0, 1, 2, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 0, 0, 0,
|
||||
0, 0, 0, 0])
|
||||
|
||||
>>> # get all available index by using batch_size = 0
|
||||
>>> indice = buf.sample_index(0)
|
||||
>>> indice
|
||||
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
|
||||
>>> # get one step previous/next transition
|
||||
>>> buf.prev(indice)
|
||||
array([ 0, 0, 1, 2, 3, 4, 5, 7, 7, 8, 9, 11, 11])
|
||||
>>> buf.next(indice)
|
||||
array([ 1, 2, 3, 4, 5, 6, 6, 8, 9, 10, 10, 12, 12])
|
||||
|
||||
>>> # get a random sample from buffer
|
||||
>>> # the batch_data is equal to buf[indice].
|
||||
>>> batch_data, indice = buf.sample(batch_size=4)
|
||||
>>> batch_data.obs == buf[indice].obs
|
||||
array([ True, True, True, True])
|
||||
>>> len(buf)
|
||||
13
|
||||
|
||||
>>> buf = pickle.load(open('buf.pkl', 'rb')) # load from "buf.pkl"
|
||||
>>> len(buf)
|
||||
3
|
||||
>>> # load complete buffer from HDF5 file
|
||||
>>> buf = ReplayBuffer.load_hdf5('buf.hdf5')
|
||||
>>> len(buf)
|
||||
3
|
||||
|
||||
:class:`~tianshou.data.ReplayBuffer` also supports frame_stack sampling (typically for RNN usage, see issue#19), ignoring storing the next observation (save memory in Atari tasks), and multi-modal observation (see issue#38):
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<details>
|
||||
<summary>Advance usage of ReplayBuffer</summary>
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
>>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
|
||||
>>> for i in range(16):
|
||||
... done = i % 5 == 0
|
||||
... ep_len, ep_rew = buf.add(obs={'id': i}, act=i, rew=i,
|
||||
... done=done, obs_next={'id': i + 1})
|
||||
... print(i, ep_len, ep_rew)
|
||||
0 1 0.0
|
||||
1 0 0.0
|
||||
2 0 0.0
|
||||
3 0 0.0
|
||||
4 0 0.0
|
||||
5 5 15.0
|
||||
6 0 0.0
|
||||
7 0 0.0
|
||||
8 0 0.0
|
||||
9 0 0.0
|
||||
10 5 40.0
|
||||
11 0 0.0
|
||||
12 0 0.0
|
||||
13 0 0.0
|
||||
14 0 0.0
|
||||
15 5 65.0
|
||||
>>> print(buf) # you can see obs_next is not saved in buf
|
||||
ReplayBuffer(
|
||||
obs: Batch(
|
||||
id: array([ 9, 10, 11, 12, 13, 14, 15, 7, 8]),
|
||||
),
|
||||
act: array([ 9, 10, 11, 12, 13, 14, 15, 7, 8]),
|
||||
rew: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
|
||||
done: array([False, True, False, False, False, False, True, False,
|
||||
False]),
|
||||
info: Batch(),
|
||||
policy: Batch(),
|
||||
)
|
||||
>>> index = np.arange(len(buf))
|
||||
>>> print(buf.get(index, 'obs').id)
|
||||
[[ 7 7 8 9]
|
||||
[ 7 8 9 10]
|
||||
[11 11 11 11]
|
||||
[11 11 11 12]
|
||||
[11 11 12 13]
|
||||
[11 12 13 14]
|
||||
[12 13 14 15]
|
||||
[ 7 7 7 7]
|
||||
[ 7 7 7 8]]
|
||||
>>> # here is another way to get the stacked data
|
||||
>>> # (stack only for obs and obs_next)
|
||||
>>> abs(buf.get(index, 'obs')['id'] - buf[index].obs.id).sum().sum()
|
||||
0
|
||||
>>> # we can get obs_next through __getitem__, even if it doesn't exist
|
||||
>>> print(buf[:].obs_next.id)
|
||||
[[ 7 8 9 10]
|
||||
[ 7 8 9 10]
|
||||
[11 11 11 12]
|
||||
[11 11 12 13]
|
||||
[11 12 13 14]
|
||||
[12 13 14 15]
|
||||
[12 13 14 15]
|
||||
[ 7 7 7 8]
|
||||
[ 7 7 8 9]]
|
||||
|
||||
.. raw:: html
|
||||
|
||||
</details><br>
|
||||
|
||||
Tianshou provides other type of data buffer such as :class:`~tianshou.data.ListReplayBuffer` (based on list), :class:`~tianshou.data.PrioritizedReplayBuffer` (based on Segment Tree and ``numpy.ndarray``), :class:`~tianshou.data.CachedReplayBuffer` (add different episodes' data but without losing chronological order). Check out :class:`~tianshou.data.ReplayBuffer` for more detail.
|
||||
|
||||
|
||||
Policy
|
||||
|
||||
@ -7,9 +7,10 @@ import h5py
|
||||
import numpy as np
|
||||
from timeit import timeit
|
||||
|
||||
from tianshou.data import Batch, SegmentTree, \
|
||||
ReplayBuffer, ListReplayBuffer, PrioritizedReplayBuffer
|
||||
from tianshou.data.utils.converter import to_hdf5
|
||||
from tianshou.data import Batch, SegmentTree, ReplayBuffer
|
||||
from tianshou.data import ListReplayBuffer, PrioritizedReplayBuffer
|
||||
from tianshou.data import ReplayBufferManager, CachedReplayBuffer
|
||||
|
||||
if __name__ == '__main__':
|
||||
from env import MyTestEnv
|
||||
@ -38,11 +39,14 @@ def test_replaybuffer(size=10, bufsize=20):
|
||||
assert (data.obs < size).all()
|
||||
assert (0 <= data.done).all() and (data.done <= 1).all()
|
||||
b = ReplayBuffer(size=10)
|
||||
b.add(1, 1, 1, 'str', 1, {'a': 3, 'b': {'c': 5.0}})
|
||||
# neg bsz should return empty index
|
||||
assert b.sample_index(-1).tolist() == []
|
||||
b.add(1, 1, 1, 1, 'str', {'a': 3, 'b': {'c': 5.0}})
|
||||
assert b.obs[0] == 1
|
||||
assert b.done[0] == 'str'
|
||||
assert b.done[0]
|
||||
assert b.obs_next[0] == 'str'
|
||||
assert np.all(b.obs[1:] == 0)
|
||||
assert np.all(b.done[1:] == np.array(None))
|
||||
assert np.all(b.obs_next[1:] == np.array(None))
|
||||
assert b.info.a[0] == 3 and b.info.a.dtype == np.integer
|
||||
assert np.all(b.info.a[1:] == 0)
|
||||
assert b.info.b.c[0] == 5.0 and b.info.b.c.dtype == np.inexact
|
||||
@ -91,7 +95,7 @@ def test_ignore_obs_next(size=10):
|
||||
assert data.obs_next
|
||||
|
||||
|
||||
def test_stack(size=5, bufsize=9, stack_num=4):
|
||||
def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3):
|
||||
env = MyTestEnv(size)
|
||||
buf = ReplayBuffer(bufsize, stack_num=stack_num)
|
||||
buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True)
|
||||
@ -115,7 +119,9 @@ def test_stack(size=5, bufsize=9, stack_num=4):
|
||||
_, indice = buf2.sample(0)
|
||||
assert indice.tolist() == [2, 6]
|
||||
_, indice = buf2.sample(1)
|
||||
assert indice in [2, 6]
|
||||
assert indice[0] in [2, 6]
|
||||
batch, indice = buf2.sample(-1) # neg bsz -> no data
|
||||
assert indice.tolist() == [] and len(batch) == 0
|
||||
with pytest.raises(IndexError):
|
||||
buf[bufsize * 2]
|
||||
|
||||
@ -152,6 +158,12 @@ def test_update():
|
||||
assert len(buf1) == len(buf2)
|
||||
assert (buf2[0].obs == buf1[1].obs).all()
|
||||
assert (buf2[-1].obs == buf1[0].obs).all()
|
||||
b = ListReplayBuffer()
|
||||
with pytest.raises(NotImplementedError):
|
||||
b.update(b)
|
||||
b = CachedReplayBuffer(ReplayBuffer(10), 4, 5)
|
||||
with pytest.raises(NotImplementedError):
|
||||
b.update(b)
|
||||
|
||||
|
||||
def test_segtree():
|
||||
@ -260,8 +272,7 @@ def test_pickle():
|
||||
vbuf = ReplayBuffer(size, stack_num=2)
|
||||
lbuf = ListReplayBuffer()
|
||||
pbuf = PrioritizedReplayBuffer(size, 0.6, 0.4)
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
rew = torch.tensor([1.]).to(device)
|
||||
rew = np.array([1, 1])
|
||||
for i in range(4):
|
||||
vbuf.add(obs=Batch(index=np.array([i])), act=0, rew=rew, done=0)
|
||||
for i in range(3):
|
||||
@ -287,18 +298,18 @@ def test_hdf5():
|
||||
buffers = {
|
||||
"array": ReplayBuffer(size, stack_num=2),
|
||||
"list": ListReplayBuffer(),
|
||||
"prioritized": PrioritizedReplayBuffer(size, 0.6, 0.4)
|
||||
"prioritized": PrioritizedReplayBuffer(size, 0.6, 0.4),
|
||||
}
|
||||
buffer_types = {k: b.__class__ for k, b in buffers.items()}
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
rew = torch.tensor([1.]).to(device)
|
||||
info_t = torch.tensor([1.]).to(device)
|
||||
for i in range(4):
|
||||
kwargs = {
|
||||
'obs': Batch(index=np.array([i])),
|
||||
'act': i,
|
||||
'rew': rew,
|
||||
'done': 0,
|
||||
'info': {"number": {"n": i}, 'extra': None},
|
||||
'rew': np.array([1, 2]),
|
||||
'done': i % 3 == 2,
|
||||
'info': {"number": {"n": i, "t": info_t}, 'extra': None},
|
||||
}
|
||||
buffers["array"].add(**kwargs)
|
||||
buffers["list"].add(**kwargs)
|
||||
@ -320,10 +331,10 @@ def test_hdf5():
|
||||
assert len(_buffers[k]) == len(buffers[k])
|
||||
assert np.allclose(_buffers[k].act, buffers[k].act)
|
||||
assert _buffers[k].stack_num == buffers[k].stack_num
|
||||
assert _buffers[k]._maxsize == buffers[k]._maxsize
|
||||
assert _buffers[k]._index == buffers[k]._index
|
||||
assert _buffers[k].maxsize == buffers[k].maxsize
|
||||
assert np.all(_buffers[k]._indices == buffers[k]._indices)
|
||||
for k in ["array", "prioritized"]:
|
||||
assert _buffers[k]._index == buffers[k]._index
|
||||
assert isinstance(buffers[k].get(0, "info"), Batch)
|
||||
assert isinstance(_buffers[k].get(0, "info"), Batch)
|
||||
for k in ["array"]:
|
||||
@ -332,9 +343,6 @@ def test_hdf5():
|
||||
assert np.all(
|
||||
buffers[k][:].info.extra == _buffers[k][:].info.extra)
|
||||
|
||||
for path in paths.values():
|
||||
os.remove(path)
|
||||
|
||||
# raise exception when value cannot be pickled
|
||||
data = {"not_supported": lambda x: x * x}
|
||||
grp = h5py.Group
|
||||
@ -347,13 +355,338 @@ def test_hdf5():
|
||||
to_hdf5(data, grp)
|
||||
|
||||
|
||||
def test_replaybuffermanager():
|
||||
buf = ReplayBufferManager([ReplayBuffer(size=5) for i in range(4)])
|
||||
ep_len, ep_rew = buf.add(obs=[1, 2, 3], act=[1, 2, 3], rew=[1, 2, 3],
|
||||
done=[0, 0, 1], buffer_ids=[0, 1, 2])
|
||||
assert np.allclose(ep_len, [0, 0, 1]) and np.allclose(ep_rew, [0, 0, 3])
|
||||
with pytest.raises(NotImplementedError):
|
||||
# ReplayBufferManager cannot be updated
|
||||
buf.update(buf)
|
||||
# sample index / prev / next / unfinished_index
|
||||
indice = buf.sample_index(11000)
|
||||
assert np.bincount(indice)[[0, 5, 10]].min() >= 3000 # uniform sample
|
||||
batch, indice = buf.sample(0)
|
||||
assert np.allclose(indice, [0, 5, 10])
|
||||
indice_prev = buf.prev(indice)
|
||||
assert np.allclose(indice_prev, indice), indice_prev
|
||||
indice_next = buf.next(indice)
|
||||
assert np.allclose(indice_next, indice), indice_next
|
||||
assert np.allclose(buf.unfinished_index(), [0, 5])
|
||||
buf.add(obs=[4], act=[4], rew=[4], done=[1], buffer_ids=[3])
|
||||
assert np.allclose(buf.unfinished_index(), [0, 5])
|
||||
batch, indice = buf.sample(10)
|
||||
batch, indice = buf.sample(0)
|
||||
assert np.allclose(indice, [0, 5, 10, 15])
|
||||
indice_prev = buf.prev(indice)
|
||||
assert np.allclose(indice_prev, indice), indice_prev
|
||||
indice_next = buf.next(indice)
|
||||
assert np.allclose(indice_next, indice), indice_next
|
||||
data = np.array([0, 0, 0, 0])
|
||||
buf.add(obs=data, act=data, rew=data, done=data, buffer_ids=[0, 1, 2, 3])
|
||||
buf.add(obs=data, act=data, rew=data, done=1 - data,
|
||||
buffer_ids=[0, 1, 2, 3])
|
||||
assert len(buf) == 12
|
||||
buf.add(obs=data, act=data, rew=data, done=data, buffer_ids=[0, 1, 2, 3])
|
||||
buf.add(obs=data, act=data, rew=data, done=[0, 1, 0, 1],
|
||||
buffer_ids=[0, 1, 2, 3])
|
||||
assert len(buf) == 20
|
||||
indice = buf.sample_index(120000)
|
||||
assert np.bincount(indice).min() >= 5000
|
||||
batch, indice = buf.sample(10)
|
||||
indice = buf.sample_index(0)
|
||||
assert np.allclose(indice, np.arange(len(buf)))
|
||||
# check the actual data stored in buf._meta
|
||||
assert np.allclose(buf.done, [
|
||||
0, 0, 1, 0, 0,
|
||||
0, 0, 1, 0, 1,
|
||||
1, 0, 1, 0, 0,
|
||||
1, 0, 1, 0, 1,
|
||||
])
|
||||
assert np.allclose(buf.prev(indice), [
|
||||
0, 0, 1, 3, 3,
|
||||
5, 5, 6, 8, 8,
|
||||
10, 11, 11, 13, 13,
|
||||
15, 16, 16, 18, 18,
|
||||
])
|
||||
assert np.allclose(buf.next(indice), [
|
||||
1, 2, 2, 4, 4,
|
||||
6, 7, 7, 9, 9,
|
||||
10, 12, 12, 14, 14,
|
||||
15, 17, 17, 19, 19,
|
||||
])
|
||||
assert np.allclose(buf.unfinished_index(), [4, 14])
|
||||
ep_len, ep_rew = buf.add(obs=[1], act=[1], rew=[1], done=[1],
|
||||
buffer_ids=[2])
|
||||
assert np.allclose(ep_len, [3]) and np.allclose(ep_rew, [1])
|
||||
assert np.allclose(buf.unfinished_index(), [4])
|
||||
indice = list(sorted(buf.sample_index(0)))
|
||||
assert np.allclose(indice, np.arange(len(buf)))
|
||||
assert np.allclose(buf.prev(indice), [
|
||||
0, 0, 1, 3, 3,
|
||||
5, 5, 6, 8, 8,
|
||||
14, 11, 11, 13, 13,
|
||||
15, 16, 16, 18, 18,
|
||||
])
|
||||
assert np.allclose(buf.next(indice), [
|
||||
1, 2, 2, 4, 4,
|
||||
6, 7, 7, 9, 9,
|
||||
10, 12, 12, 14, 10,
|
||||
15, 17, 17, 19, 19,
|
||||
])
|
||||
# corner case: list, int and -1
|
||||
assert buf.prev(-1) == buf.prev([buf.maxsize - 1])[0]
|
||||
assert buf.next(-1) == buf.next([buf.maxsize - 1])[0]
|
||||
batch = buf._meta
|
||||
batch.info.n = np.ones(buf.maxsize)
|
||||
buf.set_batch(batch)
|
||||
assert np.allclose(buf.buffers[-1].info.n, [1] * 5)
|
||||
assert buf.sample_index(-1).tolist() == []
|
||||
assert np.array([ReplayBuffer(0, ignore_obs_next=True)]).dtype == np.object
|
||||
|
||||
|
||||
def test_cachedbuffer():
|
||||
buf = CachedReplayBuffer(ReplayBuffer(10), 4, 5)
|
||||
assert buf.sample_index(0).tolist() == []
|
||||
# check the normal function/usage/storage in CachedReplayBuffer
|
||||
ep_len, ep_rew = buf.add(obs=[1], act=[1], rew=[1], done=[0],
|
||||
cached_buffer_ids=[1])
|
||||
obs = np.zeros(buf.maxsize)
|
||||
obs[15] = 1
|
||||
indice = buf.sample_index(0)
|
||||
assert np.allclose(indice, [15])
|
||||
assert np.allclose(buf.prev(indice), [15])
|
||||
assert np.allclose(buf.next(indice), [15])
|
||||
assert np.allclose(buf.obs, obs)
|
||||
assert np.allclose(ep_len, [0]) and np.allclose(ep_rew, [0.0])
|
||||
ep_len, ep_rew = buf.add(obs=[2], act=[2], rew=[2], done=[1],
|
||||
cached_buffer_ids=[3])
|
||||
obs[[0, 25]] = 2
|
||||
indice = buf.sample_index(0)
|
||||
assert np.allclose(indice, [0, 15])
|
||||
assert np.allclose(buf.prev(indice), [0, 15])
|
||||
assert np.allclose(buf.next(indice), [0, 15])
|
||||
assert np.allclose(buf.obs, obs)
|
||||
assert np.allclose(ep_len, [1]) and np.allclose(ep_rew, [2.0])
|
||||
assert np.allclose(buf.unfinished_index(), [15])
|
||||
assert np.allclose(buf.sample_index(0), [0, 15])
|
||||
ep_len, ep_rew = buf.add(obs=[3, 4], act=[3, 4], rew=[3, 4],
|
||||
done=[0, 1], cached_buffer_ids=[3, 1])
|
||||
assert np.allclose(ep_len, [0, 2]) and np.allclose(ep_rew, [0, 5.0])
|
||||
obs[[0, 1, 2, 15, 16, 25]] = [2, 1, 4, 1, 4, 3]
|
||||
assert np.allclose(buf.obs, obs)
|
||||
assert np.allclose(buf.unfinished_index(), [25])
|
||||
indice = buf.sample_index(0)
|
||||
assert np.allclose(indice, [0, 1, 2, 25])
|
||||
assert np.allclose(buf.done[indice], [1, 0, 1, 0])
|
||||
assert np.allclose(buf.prev(indice), [0, 1, 1, 25])
|
||||
assert np.allclose(buf.next(indice), [0, 2, 2, 25])
|
||||
indice = buf.sample_index(10000)
|
||||
assert np.bincount(indice)[[0, 1, 2, 25]].min() > 2000 # uniform sample
|
||||
# cached buffer with main_buffer size == 0 (no update)
|
||||
# used in test_collector
|
||||
buf = CachedReplayBuffer(ReplayBuffer(0, sample_avail=True), 4, 5)
|
||||
data = np.zeros(4)
|
||||
rew = np.ones([4, 4])
|
||||
buf.add(obs=data, act=data, rew=rew, done=[0, 0, 1, 1], obs_next=data)
|
||||
buf.add(obs=data, act=data, rew=rew, done=[0, 0, 0, 0], obs_next=data)
|
||||
buf.add(obs=data, act=data, rew=rew, done=[1, 1, 1, 1], obs_next=data)
|
||||
buf.add(obs=data, act=data, rew=rew, done=[0, 0, 0, 0], obs_next=data)
|
||||
buf.add(obs=data, act=data, rew=rew, done=[0, 1, 0, 1], obs_next=data)
|
||||
assert np.allclose(buf.done, [
|
||||
0, 0, 1, 0, 0,
|
||||
0, 1, 1, 0, 0,
|
||||
0, 0, 0, 0, 0,
|
||||
0, 1, 0, 0, 0,
|
||||
])
|
||||
indice = buf.sample_index(0)
|
||||
assert np.allclose(indice, [0, 1, 10, 11])
|
||||
assert np.allclose(buf.prev(indice), [0, 0, 10, 10])
|
||||
assert np.allclose(buf.next(indice), [1, 1, 11, 11])
|
||||
|
||||
|
||||
def test_multibuf_stack():
|
||||
size = 5
|
||||
bufsize = 9
|
||||
stack_num = 4
|
||||
cached_num = 3
|
||||
env = MyTestEnv(size)
|
||||
# test if CachedReplayBuffer can handle stack_num + ignore_obs_next
|
||||
buf4 = CachedReplayBuffer(
|
||||
ReplayBuffer(bufsize, stack_num=stack_num, ignore_obs_next=True),
|
||||
cached_num, size)
|
||||
# test if CachedReplayBuffer can handle super corner case:
|
||||
# prio-buffer + stack_num + ignore_obs_next + sample_avail
|
||||
buf5 = CachedReplayBuffer(
|
||||
PrioritizedReplayBuffer(bufsize, 0.6, 0.4, stack_num=stack_num,
|
||||
ignore_obs_next=True, sample_avail=True),
|
||||
cached_num, size)
|
||||
obs = env.reset(1)
|
||||
for i in range(18):
|
||||
obs_next, rew, done, info = env.step(1)
|
||||
obs_list = np.array([obs + size * i for i in range(cached_num)])
|
||||
act_list = [1] * cached_num
|
||||
rew_list = [rew] * cached_num
|
||||
done_list = [done] * cached_num
|
||||
obs_next_list = -obs_list
|
||||
info_list = [info] * cached_num
|
||||
buf4.add(obs_list, act_list, rew_list, done_list,
|
||||
obs_next_list, info_list)
|
||||
buf5.add(obs_list, act_list, rew_list, done_list,
|
||||
obs_next_list, info_list)
|
||||
obs = obs_next
|
||||
if done:
|
||||
obs = env.reset(1)
|
||||
# check the `add` order is correct
|
||||
assert np.allclose(buf4.obs.reshape(-1), [
|
||||
12, 13, 14, 4, 6, 7, 8, 9, 11, # main_buffer
|
||||
1, 2, 3, 4, 0, # cached_buffer[0]
|
||||
6, 7, 8, 9, 0, # cached_buffer[1]
|
||||
11, 12, 13, 14, 0, # cached_buffer[2]
|
||||
]), buf4.obs
|
||||
assert np.allclose(buf4.done, [
|
||||
0, 0, 1, 1, 0, 0, 0, 1, 0, # main_buffer
|
||||
0, 0, 0, 1, 0, # cached_buffer[0]
|
||||
0, 0, 0, 1, 0, # cached_buffer[1]
|
||||
0, 0, 0, 1, 0, # cached_buffer[2]
|
||||
]), buf4.done
|
||||
assert np.allclose(buf4.unfinished_index(), [10, 15, 20])
|
||||
indice = sorted(buf4.sample_index(0))
|
||||
assert np.allclose(indice, list(range(bufsize)) + [9, 10, 14, 15, 19, 20])
|
||||
assert np.allclose(buf4[indice].obs[..., 0], [
|
||||
[11, 11, 11, 12], [11, 11, 12, 13], [11, 12, 13, 14],
|
||||
[4, 4, 4, 4], [6, 6, 6, 6], [6, 6, 6, 7],
|
||||
[6, 6, 7, 8], [6, 7, 8, 9], [11, 11, 11, 11],
|
||||
[1, 1, 1, 1], [1, 1, 1, 2], [6, 6, 6, 6],
|
||||
[6, 6, 6, 7], [11, 11, 11, 11], [11, 11, 11, 12],
|
||||
])
|
||||
assert np.allclose(buf4[indice].obs_next[..., 0], [
|
||||
[11, 11, 12, 13], [11, 12, 13, 14], [11, 12, 13, 14],
|
||||
[4, 4, 4, 4], [6, 6, 6, 7], [6, 6, 7, 8],
|
||||
[6, 7, 8, 9], [6, 7, 8, 9], [11, 11, 11, 12],
|
||||
[1, 1, 1, 2], [1, 1, 1, 2], [6, 6, 6, 7],
|
||||
[6, 6, 6, 7], [11, 11, 11, 12], [11, 11, 11, 12],
|
||||
])
|
||||
assert np.all(buf4.done == buf5.done)
|
||||
indice = buf5.sample_index(0)
|
||||
assert np.allclose(sorted(indice), [2, 7])
|
||||
assert np.all(np.isin(buf5.sample_index(100), indice))
|
||||
# manually change the stack num
|
||||
buf5.stack_num = 2
|
||||
for buf in buf5.buffers:
|
||||
buf.stack_num = 2
|
||||
indice = buf5.sample_index(0)
|
||||
assert np.allclose(sorted(indice), [0, 1, 2, 5, 6, 7, 10, 15, 20])
|
||||
batch, _ = buf5.sample(0)
|
||||
assert np.allclose(buf5[np.arange(buf5.maxsize)].weight, 1)
|
||||
buf5.update_weight(indice, batch.weight * 0)
|
||||
weight = buf5[np.arange(buf5.maxsize)].weight
|
||||
modified_weight = weight[[0, 1, 2, 5, 6, 7]]
|
||||
assert modified_weight.min() == modified_weight.max()
|
||||
assert modified_weight.max() < 1
|
||||
unmodified_weight = weight[[3, 4, 8]]
|
||||
assert unmodified_weight.min() == unmodified_weight.max()
|
||||
assert unmodified_weight.max() < 1
|
||||
cached_weight = weight[9:]
|
||||
assert cached_weight.min() == cached_weight.max() == 1
|
||||
# test Atari with CachedReplayBuffer, save_only_last_obs + ignore_obs_next
|
||||
buf6 = CachedReplayBuffer(
|
||||
ReplayBuffer(bufsize, stack_num=stack_num,
|
||||
save_only_last_obs=True, ignore_obs_next=True),
|
||||
cached_num, size)
|
||||
obs = np.random.rand(size, 4, 84, 84)
|
||||
buf6.add(obs=[obs[2], obs[0]], act=[1, 1], rew=[0, 0], done=[0, 1],
|
||||
obs_next=[obs[3], obs[1]], cached_buffer_ids=[1, 2])
|
||||
assert buf6.obs.shape == (buf6.maxsize, 84, 84)
|
||||
assert np.allclose(buf6.obs[0], obs[0, -1])
|
||||
assert np.allclose(buf6.obs[14], obs[2, -1])
|
||||
assert np.allclose(buf6.obs[19], obs[0, -1])
|
||||
assert buf6[0].obs.shape == (4, 84, 84)
|
||||
|
||||
|
||||
def test_multibuf_hdf5():
|
||||
size = 100
|
||||
buffers = {
|
||||
"vector": ReplayBufferManager([ReplayBuffer(size) for i in range(4)]),
|
||||
"cached": CachedReplayBuffer(ReplayBuffer(size), 4, size)
|
||||
}
|
||||
buffer_types = {k: b.__class__ for k, b in buffers.items()}
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
info_t = torch.tensor([1.]).to(device)
|
||||
for i in range(4):
|
||||
kwargs = {
|
||||
'obs': Batch(index=np.array([i])),
|
||||
'act': i,
|
||||
'rew': np.array([1, 2]),
|
||||
'done': i % 3 == 2,
|
||||
'info': {"number": {"n": i, "t": info_t}, 'extra': None},
|
||||
}
|
||||
buffers["vector"].add(**Batch.cat([[kwargs], [kwargs], [kwargs]]),
|
||||
buffer_ids=[0, 1, 2])
|
||||
buffers["cached"].add(**Batch.cat([[kwargs], [kwargs], [kwargs]]),
|
||||
cached_buffer_ids=[0, 1, 2])
|
||||
|
||||
# save
|
||||
paths = {}
|
||||
for k, buf in buffers.items():
|
||||
f, path = tempfile.mkstemp(suffix='.hdf5')
|
||||
os.close(f)
|
||||
buf.save_hdf5(path)
|
||||
paths[k] = path
|
||||
|
||||
# load replay buffer
|
||||
_buffers = {k: buffer_types[k].load_hdf5(paths[k]) for k in paths.keys()}
|
||||
|
||||
# compare
|
||||
for k in buffers.keys():
|
||||
assert len(_buffers[k]) == len(buffers[k])
|
||||
assert np.allclose(_buffers[k].act, buffers[k].act)
|
||||
assert _buffers[k].stack_num == buffers[k].stack_num
|
||||
assert _buffers[k].maxsize == buffers[k].maxsize
|
||||
assert np.all(_buffers[k]._indices == buffers[k]._indices)
|
||||
# check shallow copy in ReplayBufferManager
|
||||
for k in ["vector", "cached"]:
|
||||
buffers[k].info.number.n[0] = -100
|
||||
assert buffers[k].buffers[0].info.number.n[0] == -100
|
||||
# check if still behave normally
|
||||
for k in ["vector", "cached"]:
|
||||
kwargs = {
|
||||
'obs': Batch(index=np.array([5])),
|
||||
'act': 5,
|
||||
'rew': np.array([2, 1]),
|
||||
'done': False,
|
||||
'info': {"number": {"n": i}, 'Timelimit.truncate': True},
|
||||
}
|
||||
buffers[k].add(**Batch.cat([[kwargs], [kwargs], [kwargs], [kwargs]]))
|
||||
act = np.zeros(buffers[k].maxsize)
|
||||
if k == "vector":
|
||||
act[np.arange(5)] = np.array([0, 1, 2, 3, 5])
|
||||
act[np.arange(5) + size] = np.array([0, 1, 2, 3, 5])
|
||||
act[np.arange(5) + size * 2] = np.array([0, 1, 2, 3, 5])
|
||||
act[size * 3] = 5
|
||||
elif k == "cached":
|
||||
act[np.arange(9)] = np.array([0, 1, 2, 0, 1, 2, 0, 1, 2])
|
||||
act[np.arange(3) + size] = np.array([3, 5, 2])
|
||||
act[np.arange(3) + size * 2] = np.array([3, 5, 2])
|
||||
act[np.arange(3) + size * 3] = np.array([3, 5, 2])
|
||||
act[size * 4] = 5
|
||||
assert np.allclose(buffers[k].act, act)
|
||||
|
||||
for path in paths.values():
|
||||
os.remove(path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_hdf5()
|
||||
test_replaybuffer()
|
||||
test_ignore_obs_next()
|
||||
test_stack()
|
||||
test_pickle()
|
||||
test_segtree()
|
||||
test_priortized_replaybuffer()
|
||||
test_priortized_replaybuffer(233333, 200000)
|
||||
test_update()
|
||||
test_pickle()
|
||||
test_hdf5()
|
||||
test_replaybuffermanager()
|
||||
test_cachedbuffer()
|
||||
test_multibuf_stack()
|
||||
test_multibuf_hdf5()
|
||||
|
||||
@ -18,7 +18,7 @@ from tianshou.utils.net.continuous import ActorProb, Critic
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--task', type=str, default='Pendulum-v0')
|
||||
parser.add_argument('--seed', type=int, default=1626)
|
||||
parser.add_argument('--seed', type=int, default=1)
|
||||
parser.add_argument('--buffer-size', type=int, default=20000)
|
||||
parser.add_argument('--lr', type=float, default=1e-3)
|
||||
parser.add_argument('--gamma', type=float, default=0.99)
|
||||
|
||||
@ -16,7 +16,7 @@ from tianshou.data import Collector, ReplayBuffer, PrioritizedReplayBuffer
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--task', type=str, default='CartPole-v0')
|
||||
parser.add_argument('--seed', type=int, default=1626)
|
||||
parser.add_argument('--seed', type=int, default=1)
|
||||
parser.add_argument('--eps-test', type=float, default=0.05)
|
||||
parser.add_argument('--eps-train', type=float, default=0.1)
|
||||
parser.add_argument('--buffer-size', type=int, default=20000)
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
from tianshou.data.batch import Batch
|
||||
from tianshou.data.utils.converter import to_numpy, to_torch, to_torch_as
|
||||
from tianshou.data.utils.segtree import SegmentTree
|
||||
from tianshou.data.buffer import ReplayBuffer, \
|
||||
ListReplayBuffer, PrioritizedReplayBuffer
|
||||
from tianshou.data.buffer import ReplayBuffer, ListReplayBuffer, \
|
||||
PrioritizedReplayBuffer, ReplayBufferManager, CachedReplayBuffer
|
||||
from tianshou.data.collector import Collector
|
||||
|
||||
__all__ = [
|
||||
@ -14,5 +14,7 @@ __all__ = [
|
||||
"ReplayBuffer",
|
||||
"ListReplayBuffer",
|
||||
"PrioritizedReplayBuffer",
|
||||
"ReplayBufferManager",
|
||||
"CachedReplayBuffer",
|
||||
"Collector",
|
||||
]
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import h5py
|
||||
import torch
|
||||
import warnings
|
||||
import numpy as np
|
||||
from numbers import Number
|
||||
from typing import Any, Dict, List, Tuple, Union, Optional
|
||||
@ -13,121 +14,13 @@ class ReplayBuffer:
|
||||
""":class:`~tianshou.data.ReplayBuffer` stores data generated from \
|
||||
interaction between the policy and environment.
|
||||
|
||||
The current implementation of Tianshou typically use 7 reserved keys in
|
||||
:class:`~tianshou.data.Batch`:
|
||||
ReplayBuffer can be considered as a specialized form (or management) of
|
||||
Batch. It stores all the data in a batch with circular-queue style.
|
||||
|
||||
* ``obs`` the observation of step :math:`t` ;
|
||||
* ``act`` the action of step :math:`t` ;
|
||||
* ``rew`` the reward of step :math:`t` ;
|
||||
* ``done`` the done flag of step :math:`t` ;
|
||||
* ``obs_next`` the observation of step :math:`t+1` ;
|
||||
* ``info`` the info of step :math:`t` (in ``gym.Env``, the ``env.step()`` \
|
||||
function returns 4 arguments, and the last one is ``info``);
|
||||
* ``policy`` the data computed by policy in step :math:`t`;
|
||||
For the example usage of ReplayBuffer, please check out Section Buffer in
|
||||
:doc:`/tutorials/concepts`.
|
||||
|
||||
The following code snippet illustrates its usage:
|
||||
::
|
||||
|
||||
>>> import pickle, numpy as np
|
||||
>>> from tianshou.data import ReplayBuffer
|
||||
>>> buf = ReplayBuffer(size=20)
|
||||
>>> for i in range(3):
|
||||
... buf.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
|
||||
>>> buf.obs
|
||||
# since we set size = 20, len(buf.obs) == 20.
|
||||
array([0., 1., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
|
||||
0., 0., 0., 0.])
|
||||
>>> # but there are only three valid items, so len(buf) == 3.
|
||||
>>> len(buf)
|
||||
3
|
||||
>>> # save to file "buf.pkl"
|
||||
>>> pickle.dump(buf, open('buf.pkl', 'wb'))
|
||||
>>> # save to HDF5 file
|
||||
>>> buf.save_hdf5('buf.hdf5')
|
||||
>>> buf2 = ReplayBuffer(size=10)
|
||||
>>> for i in range(15):
|
||||
... buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
|
||||
>>> len(buf2)
|
||||
10
|
||||
>>> buf2.obs
|
||||
# since its size = 10, it only stores the last 10 steps' result.
|
||||
array([10., 11., 12., 13., 14., 5., 6., 7., 8., 9.])
|
||||
|
||||
>>> # move buf2's result into buf (meanwhile keep it chronologically)
|
||||
>>> buf.update(buf2)
|
||||
array([ 0., 1., 2., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14.,
|
||||
0., 0., 0., 0., 0., 0., 0.])
|
||||
|
||||
>>> # get a random sample from buffer
|
||||
>>> # the batch_data is equal to buf[indice].
|
||||
>>> batch_data, indice = buf.sample(batch_size=4)
|
||||
>>> batch_data.obs == buf[indice].obs
|
||||
array([ True, True, True, True])
|
||||
>>> len(buf)
|
||||
13
|
||||
>>> buf = pickle.load(open('buf.pkl', 'rb')) # load from "buf.pkl"
|
||||
>>> len(buf)
|
||||
3
|
||||
>>> # load complete buffer from HDF5 file
|
||||
>>> buf = ReplayBuffer.load_hdf5('buf.hdf5')
|
||||
>>> len(buf)
|
||||
3
|
||||
>>> # load contents of HDF5 file into existing buffer
|
||||
>>> # (only possible if size of buffer and data in file match)
|
||||
>>> buf.load_contents_hdf5('buf.hdf5')
|
||||
>>> len(buf)
|
||||
3
|
||||
|
||||
:class:`~tianshou.data.ReplayBuffer` also supports frame_stack sampling
|
||||
(typically for RNN usage, see issue#19), ignoring storing the next
|
||||
observation (save memory in atari tasks), and multi-modal observation (see
|
||||
issue#38):
|
||||
::
|
||||
|
||||
>>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
|
||||
>>> for i in range(16):
|
||||
... done = i % 5 == 0
|
||||
... buf.add(obs={'id': i}, act=i, rew=i, done=done,
|
||||
... obs_next={'id': i + 1})
|
||||
>>> print(buf) # you can see obs_next is not saved in buf
|
||||
ReplayBuffer(
|
||||
act: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
|
||||
done: array([0., 1., 0., 0., 0., 0., 1., 0., 0.]),
|
||||
info: Batch(),
|
||||
obs: Batch(
|
||||
id: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
|
||||
),
|
||||
policy: Batch(),
|
||||
rew: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
|
||||
)
|
||||
>>> index = np.arange(len(buf))
|
||||
>>> print(buf.get(index, 'obs').id)
|
||||
[[ 7. 7. 8. 9.]
|
||||
[ 7. 8. 9. 10.]
|
||||
[11. 11. 11. 11.]
|
||||
[11. 11. 11. 12.]
|
||||
[11. 11. 12. 13.]
|
||||
[11. 12. 13. 14.]
|
||||
[12. 13. 14. 15.]
|
||||
[ 7. 7. 7. 7.]
|
||||
[ 7. 7. 7. 8.]]
|
||||
>>> # here is another way to get the stacked data
|
||||
>>> # (stack only for obs and obs_next)
|
||||
>>> abs(buf.get(index, 'obs')['id'] - buf[index].obs.id).sum().sum()
|
||||
0.0
|
||||
>>> # we can get obs_next through __getitem__, even if it doesn't exist
|
||||
>>> print(buf[:].obs_next.id)
|
||||
[[ 7. 8. 9. 10.]
|
||||
[ 7. 8. 9. 10.]
|
||||
[11. 11. 11. 12.]
|
||||
[11. 11. 12. 13.]
|
||||
[11. 12. 13. 14.]
|
||||
[12. 13. 14. 15.]
|
||||
[12. 13. 14. 15.]
|
||||
[ 7. 7. 7. 8.]
|
||||
[ 7. 7. 8. 9.]]
|
||||
|
||||
:param int size: the size of replay buffer.
|
||||
:param int size: the maximum size of replay buffer.
|
||||
:param int stack_num: the frame-stack sampling argument, should be greater
|
||||
than or equal to 1, defaults to 1 (no stacking).
|
||||
:param bool ignore_obs_next: whether to store obs_next, defaults to False.
|
||||
@ -136,9 +29,11 @@ class ReplayBuffer:
|
||||
False.
|
||||
:param bool sample_avail: the parameter indicating sampling only available
|
||||
index when using frame-stack sampling method, defaults to False.
|
||||
This feature is not supported in Prioritized Replay Buffer currently.
|
||||
"""
|
||||
|
||||
_reserved_keys = ("obs", "act", "rew", "done",
|
||||
"obs_next", "info", "policy")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size: int,
|
||||
@ -147,16 +42,20 @@ class ReplayBuffer:
|
||||
save_only_last_obs: bool = False,
|
||||
sample_avail: bool = False,
|
||||
) -> None:
|
||||
self.options: Dict[str, Any] = {
|
||||
"stack_num": stack_num,
|
||||
"ignore_obs_next": ignore_obs_next,
|
||||
"save_only_last_obs": save_only_last_obs,
|
||||
"sample_avail": sample_avail,
|
||||
}
|
||||
super().__init__()
|
||||
self._maxsize = size
|
||||
self._indices = np.arange(size)
|
||||
self.maxsize = size
|
||||
assert stack_num > 0, "stack_num should greater than 0"
|
||||
self.stack_num = stack_num
|
||||
self._avail = sample_avail and stack_num > 1
|
||||
self._avail_index: List[int] = []
|
||||
self._save_s_ = not ignore_obs_next
|
||||
self._last_obs = save_only_last_obs
|
||||
self._index = 0
|
||||
self._size = 0
|
||||
self._indices = np.arange(size)
|
||||
self._save_obs_next = not ignore_obs_next
|
||||
self._save_only_last_obs = save_only_last_obs
|
||||
self._sample_avail = sample_avail
|
||||
self._meta: Batch = Batch()
|
||||
self.reset()
|
||||
|
||||
@ -181,20 +80,92 @@ class ReplayBuffer:
|
||||
We need it because pickling buffer does not work out-of-the-box
|
||||
("buffer.__getattr__" is customized).
|
||||
"""
|
||||
self._indices = np.arange(state["_maxsize"])
|
||||
self.__dict__.update(state)
|
||||
# compatible with version == 0.3.1's HDF5 data format
|
||||
self._indices = np.arange(self.maxsize)
|
||||
|
||||
def __getstate__(self) -> dict:
|
||||
exclude = {"_indices"}
|
||||
state = {k: v for k, v in self.__dict__.items() if k not in exclude}
|
||||
return state
|
||||
def __setattr__(self, key: str, value: Any) -> None:
|
||||
"""Set self.key = value."""
|
||||
assert key not in self._reserved_keys, (
|
||||
"key '{}' is reserved and cannot be assigned".format(key))
|
||||
super().__setattr__(key, value)
|
||||
|
||||
def save_hdf5(self, path: str) -> None:
|
||||
"""Save replay buffer to HDF5 file."""
|
||||
with h5py.File(path, "w") as f:
|
||||
to_hdf5(self.__dict__, f)
|
||||
|
||||
@classmethod
|
||||
def load_hdf5(
|
||||
cls, path: str, device: Optional[str] = None
|
||||
) -> "ReplayBuffer":
|
||||
"""Load replay buffer from HDF5 file."""
|
||||
with h5py.File(path, "r") as f:
|
||||
buf = cls.__new__(cls)
|
||||
buf.__setstate__(from_hdf5(f, device=device))
|
||||
return buf
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Clear all the data in replay buffer and episode statistics."""
|
||||
self._index = self._size = 0
|
||||
self._episode_length, self._episode_reward = 0, 0.0
|
||||
|
||||
def set_batch(self, batch: Batch) -> None:
|
||||
"""Manually choose the batch you want the ReplayBuffer to manage."""
|
||||
assert len(batch) == self.maxsize and \
|
||||
set(batch.keys()).issubset(self._reserved_keys), \
|
||||
"Input batch doesn't meet ReplayBuffer's data form requirement."
|
||||
self._meta = batch
|
||||
|
||||
def unfinished_index(self) -> np.ndarray:
|
||||
"""Return the index of unfinished episode."""
|
||||
last = (self._index - 1) % self._size if self._size else 0
|
||||
return np.array(
|
||||
[last] if not self.done[last] and self._size else [], np.int)
|
||||
|
||||
def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
|
||||
"""Return the index of previous transition.
|
||||
|
||||
The index won't be modified if it is the beginning of an episode.
|
||||
"""
|
||||
index = (index - 1) % self._size
|
||||
end_flag = self.done[index] | np.isin(index, self.unfinished_index())
|
||||
return (index + end_flag) % self._size
|
||||
|
||||
def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
|
||||
"""Return the index of next transition.
|
||||
|
||||
The index won't be modified if it is the end of an episode.
|
||||
"""
|
||||
end_flag = self.done[index] | np.isin(index, self.unfinished_index())
|
||||
return (index + (1 - end_flag)) % self._size
|
||||
|
||||
def update(self, buffer: "ReplayBuffer") -> None:
|
||||
"""Move the data from the given buffer to current buffer."""
|
||||
if len(buffer) == 0 or self.maxsize == 0:
|
||||
return
|
||||
stack_num, buffer.stack_num = buffer.stack_num, 1
|
||||
save_only_last_obs = self._save_only_last_obs
|
||||
self._save_only_last_obs = False
|
||||
indices = buffer.sample_index(0) # get all available indices
|
||||
for i in indices:
|
||||
self.add(**buffer[i]) # type: ignore
|
||||
buffer.stack_num = stack_num
|
||||
self._save_only_last_obs = save_only_last_obs
|
||||
|
||||
def _buffer_allocator(self, key: List[str], value: Any) -> None:
|
||||
"""Allocate memory on buffer._meta for new (key, value) pair."""
|
||||
data = self._meta
|
||||
for k in key[:-1]:
|
||||
data = data[k]
|
||||
data[key[-1]] = _create_value(value, self.maxsize)
|
||||
|
||||
def _add_to_buffer(self, name: str, inst: Any) -> None:
|
||||
try:
|
||||
value = self._meta.__dict__[name]
|
||||
except KeyError:
|
||||
self._meta.__dict__[name] = _create_value(inst, self._maxsize)
|
||||
value = self._meta.__dict__[name]
|
||||
self._buffer_allocator([name], inst)
|
||||
value = self._meta[name]
|
||||
if isinstance(inst, (torch.Tensor, np.ndarray)):
|
||||
if inst.shape != value.shape[1:]:
|
||||
raise ValueError(
|
||||
@ -203,33 +174,10 @@ class ReplayBuffer:
|
||||
)
|
||||
try:
|
||||
value[self._index] = inst
|
||||
except KeyError:
|
||||
for key in set(inst.keys()).difference(value.__dict__.keys()):
|
||||
value.__dict__[key] = _create_value(inst[key], self._maxsize)
|
||||
value[self._index] = inst
|
||||
|
||||
@property
|
||||
def stack_num(self) -> int:
|
||||
return self._stack
|
||||
|
||||
@stack_num.setter
|
||||
def stack_num(self, num: int) -> None:
|
||||
assert num > 0, "stack_num should greater than 0"
|
||||
self._stack = num
|
||||
|
||||
def update(self, buffer: "ReplayBuffer") -> None:
|
||||
"""Move the data from the given buffer to self."""
|
||||
if len(buffer) == 0:
|
||||
return
|
||||
i = begin = buffer._index % len(buffer)
|
||||
stack_num_orig = buffer.stack_num
|
||||
buffer.stack_num = 1
|
||||
while True:
|
||||
self.add(**buffer[i]) # type: ignore
|
||||
i = (i + 1) % len(buffer)
|
||||
if i == begin:
|
||||
break
|
||||
buffer.stack_num = stack_num_orig
|
||||
except KeyError: # inst is a dict/Batch
|
||||
for key in set(inst.keys()).difference(value.keys()):
|
||||
self._buffer_allocator([name, key], inst[key])
|
||||
self._meta[name][self._index] = inst
|
||||
|
||||
def add(
|
||||
self,
|
||||
@ -241,117 +189,110 @@ class ReplayBuffer:
|
||||
info: Optional[Union[dict, Batch]] = {},
|
||||
policy: Optional[Union[dict, Batch]] = {},
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Add a batch of data into replay buffer."""
|
||||
) -> Tuple[int, Union[float, np.ndarray]]:
|
||||
"""Add a batch of data into replay buffer.
|
||||
|
||||
Return (episode_length, episode_reward) if one episode is terminated,
|
||||
otherwise return (0, 0.0).
|
||||
"""
|
||||
assert isinstance(
|
||||
info, (dict, Batch)
|
||||
), "You should return a dict in the last argument of env.step()."
|
||||
if self._last_obs:
|
||||
if self._save_only_last_obs:
|
||||
obs = obs[-1]
|
||||
self._add_to_buffer("obs", obs)
|
||||
self._add_to_buffer("act", act)
|
||||
# make sure the reward is a float instead of an int
|
||||
self._add_to_buffer("rew", rew * 1.0) # type: ignore
|
||||
self._add_to_buffer("done", done)
|
||||
if self._save_s_:
|
||||
# make sure the data type of reward is float instead of int
|
||||
# but rew may be np.ndarray, so that we cannot use float(rew)
|
||||
rew = rew * 1.0 # type: ignore
|
||||
self._add_to_buffer("rew", rew)
|
||||
self._add_to_buffer("done", bool(done)) # done should be a bool scalar
|
||||
if self._save_obs_next:
|
||||
if obs_next is None:
|
||||
obs_next = Batch()
|
||||
elif self._last_obs:
|
||||
elif self._save_only_last_obs:
|
||||
obs_next = obs_next[-1]
|
||||
self._add_to_buffer("obs_next", obs_next)
|
||||
self._add_to_buffer("info", info)
|
||||
self._add_to_buffer("policy", policy)
|
||||
|
||||
# maintain available index for frame-stack sampling
|
||||
if self._avail:
|
||||
# update current frame
|
||||
avail = sum(self.done[i] for i in range(
|
||||
self._index - self.stack_num + 1, self._index)) == 0
|
||||
if self._size < self.stack_num - 1:
|
||||
avail = False
|
||||
if avail and self._index not in self._avail_index:
|
||||
self._avail_index.append(self._index)
|
||||
elif not avail and self._index in self._avail_index:
|
||||
self._avail_index.remove(self._index)
|
||||
# remove the later available frame because of broken storage
|
||||
t = (self._index + self.stack_num - 1) % self._maxsize
|
||||
if t in self._avail_index:
|
||||
self._avail_index.remove(t)
|
||||
if self.maxsize > 0:
|
||||
self._size = min(self._size + 1, self.maxsize)
|
||||
self._index = (self._index + 1) % self.maxsize
|
||||
else: # TODO: remove this after deleting ListReplayBuffer
|
||||
self._size = self._index = self._size + 1
|
||||
|
||||
if self._maxsize > 0:
|
||||
self._size = min(self._size + 1, self._maxsize)
|
||||
self._index = (self._index + 1) % self._maxsize
|
||||
self._episode_reward += rew
|
||||
self._episode_length += 1
|
||||
|
||||
if done:
|
||||
result = self._episode_length, self._episode_reward
|
||||
self._episode_length, self._episode_reward = 0, 0.0
|
||||
return result
|
||||
else:
|
||||
self._size = self._index = self._index + 1
|
||||
return 0, self._episode_reward * 0.0
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Clear all the data in replay buffer."""
|
||||
self._index = 0
|
||||
self._size = 0
|
||||
self._avail_index = []
|
||||
def sample_index(self, batch_size: int) -> np.ndarray:
|
||||
"""Get a random sample of index with size = batch_size.
|
||||
|
||||
Return all available indices in the buffer if batch_size is 0; return
|
||||
an empty numpy array if batch_size < 0 or no available index can be
|
||||
sampled.
|
||||
"""
|
||||
if self.stack_num == 1 or not self._sample_avail: # most often case
|
||||
if batch_size > 0:
|
||||
return np.random.choice(self._size, batch_size)
|
||||
elif batch_size == 0: # construct current available indices
|
||||
return np.concatenate([
|
||||
np.arange(self._index, self._size),
|
||||
np.arange(self._index)])
|
||||
else:
|
||||
return np.array([], np.int)
|
||||
else:
|
||||
if batch_size < 0:
|
||||
return np.array([], np.int)
|
||||
all_indices = prev_indices = np.concatenate([
|
||||
np.arange(self._index, self._size), np.arange(self._index)])
|
||||
for _ in range(self.stack_num - 2):
|
||||
prev_indices = self.prev(prev_indices)
|
||||
all_indices = all_indices[prev_indices != self.prev(prev_indices)]
|
||||
if batch_size > 0:
|
||||
return np.random.choice(all_indices, batch_size)
|
||||
else:
|
||||
return all_indices
|
||||
|
||||
def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
|
||||
"""Get a random sample from buffer with size equal to batch_size.
|
||||
"""Get a random sample from buffer with size = batch_size.
|
||||
|
||||
Return all the data in the buffer if batch_size is 0.
|
||||
|
||||
:return: Sample data and its corresponding index inside the buffer.
|
||||
"""
|
||||
if batch_size > 0:
|
||||
_all = self._avail_index if self._avail else self._size
|
||||
indice = np.random.choice(_all, batch_size)
|
||||
else:
|
||||
if self._avail:
|
||||
indice = np.array(self._avail_index)
|
||||
else:
|
||||
indice = np.concatenate([
|
||||
np.arange(self._index, self._size),
|
||||
np.arange(0, self._index),
|
||||
])
|
||||
assert len(indice) > 0, "No available indice can be sampled."
|
||||
return self[indice], indice
|
||||
indices = self.sample_index(batch_size)
|
||||
return self[indices], indices
|
||||
|
||||
def get(
|
||||
self,
|
||||
indice: Union[slice, int, np.integer, np.ndarray],
|
||||
index: Union[int, np.integer, np.ndarray],
|
||||
key: str,
|
||||
stack_num: Optional[int] = None,
|
||||
) -> Union[Batch, np.ndarray]:
|
||||
"""Return the stacked result.
|
||||
|
||||
E.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t], where s is self.key, t is the
|
||||
indice. The stack_num (here equals to 4) is given from buffer
|
||||
initialization procedure.
|
||||
index.
|
||||
"""
|
||||
if stack_num is None:
|
||||
stack_num = self.stack_num
|
||||
val = self._meta[key]
|
||||
try:
|
||||
if stack_num == 1: # the most often case
|
||||
if key != "obs_next" or self._save_s_:
|
||||
val = self._meta.__dict__[key]
|
||||
try:
|
||||
return val[indice]
|
||||
except IndexError as e:
|
||||
if not (isinstance(val, Batch) and val.is_empty()):
|
||||
raise e # val != Batch()
|
||||
return Batch()
|
||||
indice = self._indices[:self._size][indice]
|
||||
done = self._meta.__dict__["done"]
|
||||
if key == "obs_next" and not self._save_s_:
|
||||
indice += 1 - done[indice].astype(np.int)
|
||||
indice[indice == self._size] = 0
|
||||
key = "obs"
|
||||
val = self._meta.__dict__[key]
|
||||
try:
|
||||
if stack_num == 1:
|
||||
return val[indice]
|
||||
return val[index]
|
||||
stack: List[Any] = []
|
||||
indice = np.asarray(index)
|
||||
for _ in range(stack_num):
|
||||
stack = [val[indice]] + stack
|
||||
pre_indice = np.asarray(indice - 1)
|
||||
pre_indice[pre_indice == -1] = self._size - 1
|
||||
indice = np.asarray(
|
||||
pre_indice + done[pre_indice].astype(np.int))
|
||||
indice[indice == self._size] = 0
|
||||
indice = self.prev(indice)
|
||||
if isinstance(val, Batch):
|
||||
return Batch.stack(stack, axis=indice.ndim)
|
||||
else:
|
||||
@ -369,31 +310,24 @@ class ReplayBuffer:
|
||||
If stack_num is larger than 1, return the stacked obs and obs_next with
|
||||
shape (batch, len, ...).
|
||||
"""
|
||||
if isinstance(index, slice): # change slice to np array
|
||||
index = self._indices[:len(self)][index]
|
||||
# raise KeyError first instead of AttributeError, to support np.array
|
||||
obs = self.get(index, "obs")
|
||||
if self._save_obs_next:
|
||||
obs_next = self.get(index, "obs_next")
|
||||
else:
|
||||
obs_next = self.get(self.next(index), "obs")
|
||||
return Batch(
|
||||
obs=self.get(index, "obs"),
|
||||
obs=obs,
|
||||
act=self.act[index],
|
||||
rew=self.rew[index],
|
||||
done=self.done[index],
|
||||
obs_next=self.get(index, "obs_next"),
|
||||
obs_next=obs_next,
|
||||
info=self.get(index, "info"),
|
||||
policy=self.get(index, "policy"),
|
||||
)
|
||||
|
||||
def save_hdf5(self, path: str) -> None:
|
||||
"""Save replay buffer to HDF5 file."""
|
||||
with h5py.File(path, "w") as f:
|
||||
to_hdf5(self.__getstate__(), f)
|
||||
|
||||
@classmethod
|
||||
def load_hdf5(
|
||||
cls, path: str, device: Optional[str] = None
|
||||
) -> "ReplayBuffer":
|
||||
"""Load replay buffer from HDF5 file."""
|
||||
with h5py.File(path, "r") as f:
|
||||
buf = cls.__new__(cls)
|
||||
buf.__setstate__(from_hdf5(f, device=device))
|
||||
return buf
|
||||
|
||||
|
||||
class ListReplayBuffer(ReplayBuffer):
|
||||
"""List-based replay buffer.
|
||||
@ -411,24 +345,27 @@ class ListReplayBuffer(ReplayBuffer):
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
warnings.warn("ListReplayBuffer will be replaced in version 0.4.0.")
|
||||
super().__init__(size=0, ignore_obs_next=False, **kwargs)
|
||||
|
||||
def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
|
||||
raise NotImplementedError("ListReplayBuffer cannot be sampled!")
|
||||
|
||||
def _add_to_buffer(
|
||||
self, name: str, inst: Union[dict, Batch, np.ndarray, float, int, bool]
|
||||
) -> None:
|
||||
if self._meta.__dict__.get(name) is None:
|
||||
def _add_to_buffer(self, name: str, inst: Any) -> None:
|
||||
if self._meta.get(name) is None:
|
||||
self._meta.__dict__[name] = []
|
||||
self._meta.__dict__[name].append(inst)
|
||||
self._meta[name].append(inst)
|
||||
|
||||
def reset(self) -> None:
|
||||
self._index = self._size = 0
|
||||
for k in list(self._meta.__dict__.keys()):
|
||||
if isinstance(self._meta.__dict__[k], list):
|
||||
super().reset()
|
||||
for k in self._meta.keys():
|
||||
if isinstance(self._meta[k], list):
|
||||
self._meta.__dict__[k] = []
|
||||
|
||||
def update(self, buffer: ReplayBuffer) -> None:
|
||||
"""The ListReplayBuffer cannot be updated by any buffer."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
"""Implementation of Prioritized Experience Replay. arXiv:1511.05952.
|
||||
@ -464,8 +401,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
policy: Optional[Union[dict, Batch]] = {},
|
||||
weight: Optional[Union[Number, np.number]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Add a batch of data into replay buffer."""
|
||||
) -> Tuple[int, Union[float, np.ndarray]]:
|
||||
if weight is None:
|
||||
weight = self._max_prio
|
||||
else:
|
||||
@ -473,60 +409,289 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
self._max_prio = max(self._max_prio, weight)
|
||||
self._min_prio = min(self._min_prio, weight)
|
||||
self.weight[self._index] = weight ** self._alpha
|
||||
super().add(obs, act, rew, done, obs_next, info, policy, **kwargs)
|
||||
return super().add(obs, act, rew, done, obs_next,
|
||||
info, policy, **kwargs)
|
||||
|
||||
def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
|
||||
"""Get a random sample from buffer with priority probability.
|
||||
def sample_index(self, batch_size: int) -> np.ndarray:
|
||||
if batch_size > 0 and self._size > 0:
|
||||
scalar = np.random.rand(batch_size) * self.weight.reduce()
|
||||
return self.weight.get_prefix_sum_idx(scalar)
|
||||
else:
|
||||
return super().sample_index(batch_size)
|
||||
|
||||
Return all the data in the buffer if batch_size is 0.
|
||||
|
||||
:return: Sample data and its corresponding index inside the buffer.
|
||||
def get_weight(
|
||||
self, index: Union[slice, int, np.integer, np.ndarray]
|
||||
) -> np.ndarray:
|
||||
"""Get the importance sampling weight.
|
||||
|
||||
The "weight" in the returned Batch is the weight on loss function
|
||||
to de-bias the sampling process (some transition tuples are sampled
|
||||
more often so their losses are weighted less).
|
||||
"""
|
||||
assert self._size > 0, "Cannot sample a buffer with 0 size!"
|
||||
if batch_size == 0:
|
||||
indice = np.concatenate([
|
||||
np.arange(self._index, self._size),
|
||||
np.arange(0, self._index),
|
||||
])
|
||||
else:
|
||||
scalar = np.random.rand(batch_size) * self.weight.reduce()
|
||||
indice = self.weight.get_prefix_sum_idx(scalar)
|
||||
batch = self[indice]
|
||||
# important sampling weight calculation
|
||||
# original formula: ((p_j/p_sum*N)**(-beta))/((p_min/p_sum*N)**(-beta))
|
||||
# simplified formula: (p_j/p_min)**(-beta)
|
||||
batch.weight = (batch.weight / self._min_prio) ** (-self._beta)
|
||||
return batch, indice
|
||||
return (self.weight[index] / self._min_prio) ** (-self._beta)
|
||||
|
||||
def update_weight(
|
||||
self,
|
||||
indice: Union[np.ndarray],
|
||||
new_weight: Union[np.ndarray, torch.Tensor]
|
||||
index: np.ndarray,
|
||||
new_weight: Union[np.ndarray, torch.Tensor],
|
||||
) -> None:
|
||||
"""Update priority weight by indice in this buffer.
|
||||
"""Update priority weight by index in this buffer.
|
||||
|
||||
:param np.ndarray indice: indice you want to update weight.
|
||||
:param np.ndarray index: index you want to update weight.
|
||||
:param np.ndarray new_weight: new priority weight you want to update.
|
||||
"""
|
||||
weight = np.abs(to_numpy(new_weight)) + self.__eps
|
||||
self.weight[indice] = weight ** self._alpha
|
||||
self.weight[index] = weight ** self._alpha
|
||||
self._max_prio = max(self._max_prio, weight.max())
|
||||
self._min_prio = min(self._min_prio, weight.min())
|
||||
|
||||
def __getitem__(
|
||||
self, index: Union[slice, int, np.integer, np.ndarray]
|
||||
) -> Batch:
|
||||
return Batch(
|
||||
obs=self.get(index, "obs"),
|
||||
act=self.act[index],
|
||||
rew=self.rew[index],
|
||||
done=self.done[index],
|
||||
obs_next=self.get(index, "obs_next"),
|
||||
info=self.get(index, "info"),
|
||||
policy=self.get(index, "policy"),
|
||||
weight=self.weight[index],
|
||||
)
|
||||
batch = super().__getitem__(index)
|
||||
batch.weight = self.get_weight(index)
|
||||
return batch
|
||||
|
||||
|
||||
class ReplayBufferManager(ReplayBuffer):
|
||||
"""ReplayBufferManager contains a list of ReplayBuffer.
|
||||
|
||||
These replay buffers have contiguous memory layout, and the storage space
|
||||
each buffer has is a shallow copy of the topmost memory.
|
||||
|
||||
:param int buffer_list: a list of ReplayBuffers needed to be handled.
|
||||
|
||||
.. seealso::
|
||||
|
||||
Please refer to :class:`~tianshou.data.ReplayBuffer` for more detailed
|
||||
explanation.
|
||||
"""
|
||||
|
||||
def __init__(self, buffer_list: List[ReplayBuffer], **kwargs: Any) -> None:
|
||||
self.buffer_num = len(buffer_list)
|
||||
self.buffers = buffer_list
|
||||
self._offset = []
|
||||
offset = 0
|
||||
for buf in self.buffers:
|
||||
# overwrite sub-buffers' _buffer_allocator so that
|
||||
# the top buffer can allocate new memory for all sub-buffers
|
||||
buf._buffer_allocator = self._buffer_allocator # type: ignore
|
||||
assert buf._meta.is_empty()
|
||||
self._offset.append(offset)
|
||||
offset += buf.maxsize
|
||||
super().__init__(size=offset, **kwargs)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return sum([len(buf) for buf in self.buffers])
|
||||
|
||||
def reset(self) -> None:
|
||||
for buf in self.buffers:
|
||||
buf.reset()
|
||||
|
||||
def _set_batch_for_children(self) -> None:
|
||||
for offset, buf in zip(self._offset, self.buffers):
|
||||
buf.set_batch(self._meta[offset:offset + buf.maxsize])
|
||||
|
||||
def set_batch(self, batch: Batch) -> None:
|
||||
super().set_batch(batch)
|
||||
self._set_batch_for_children()
|
||||
|
||||
def unfinished_index(self) -> np.ndarray:
|
||||
return np.concatenate([
|
||||
buf.unfinished_index() + offset
|
||||
for offset, buf in zip(self._offset, self.buffers)])
|
||||
|
||||
def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
|
||||
index = np.asarray(index) % self.maxsize
|
||||
prev_indices = np.zeros_like(index)
|
||||
for offset, buf in zip(self._offset, self.buffers):
|
||||
mask = (offset <= index) & (index < offset + buf.maxsize)
|
||||
if np.any(mask):
|
||||
prev_indices[mask] = buf.prev(index[mask] - offset) + offset
|
||||
return prev_indices
|
||||
|
||||
def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
|
||||
index = np.asarray(index) % self.maxsize
|
||||
next_indices = np.zeros_like(index)
|
||||
for offset, buf in zip(self._offset, self.buffers):
|
||||
mask = (offset <= index) & (index < offset + buf.maxsize)
|
||||
if np.any(mask):
|
||||
next_indices[mask] = buf.next(index[mask] - offset) + offset
|
||||
return next_indices
|
||||
|
||||
def update(self, buffer: ReplayBuffer) -> None:
|
||||
"""The ReplayBufferManager cannot be updated by any buffer."""
|
||||
raise NotImplementedError
|
||||
|
||||
def _buffer_allocator(self, key: List[str], value: Any) -> None:
|
||||
super()._buffer_allocator(key, value)
|
||||
self._set_batch_for_children()
|
||||
|
||||
def add( # type: ignore
|
||||
self,
|
||||
obs: Any,
|
||||
act: Any,
|
||||
rew: np.ndarray,
|
||||
done: np.ndarray,
|
||||
obs_next: Any = Batch(),
|
||||
info: Optional[Batch] = Batch(),
|
||||
policy: Optional[Batch] = Batch(),
|
||||
buffer_ids: Optional[Union[np.ndarray, List[int]]] = None,
|
||||
**kwargs: Any
|
||||
) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Add a batch of data into ReplayBufferManager.
|
||||
|
||||
Each of the data's length (first dimension) must equal to the length of
|
||||
buffer_ids. By default buffer_ids is [0, 1, ..., buffer_num - 1].
|
||||
|
||||
Return the array of episode_length and episode_reward with shape
|
||||
(len(buffer_ids), ...), where (episode_length[i], episode_reward[i])
|
||||
refers to the buffer_ids[i]'s corresponding episode result.
|
||||
"""
|
||||
if buffer_ids is None:
|
||||
buffer_ids = np.arange(self.buffer_num)
|
||||
# assume each element in buffer_ids is unique
|
||||
assert np.bincount(buffer_ids).max() == 1
|
||||
batch = Batch(obs=obs, act=act, rew=rew, done=done,
|
||||
obs_next=obs_next, info=info, policy=policy)
|
||||
assert len(buffer_ids) == len(batch)
|
||||
episode_lengths = [] # (len(buffer_ids),)
|
||||
episode_rewards = [] # (len(buffer_ids), ...)
|
||||
for batch_idx, buffer_id in enumerate(buffer_ids):
|
||||
length, reward = self.buffers[buffer_id].add(**batch[batch_idx])
|
||||
episode_lengths.append(length)
|
||||
episode_rewards.append(reward)
|
||||
return np.stack(episode_lengths), np.stack(episode_rewards)
|
||||
|
||||
def sample_index(self, batch_size: int) -> np.ndarray:
|
||||
if batch_size < 0:
|
||||
return np.array([], np.int)
|
||||
if self._sample_avail and self.stack_num > 1:
|
||||
all_indices = np.concatenate([
|
||||
buf.sample_index(0) + offset
|
||||
for offset, buf in zip(self._offset, self.buffers)])
|
||||
if batch_size == 0:
|
||||
return all_indices
|
||||
else:
|
||||
return np.random.choice(all_indices, batch_size)
|
||||
if batch_size == 0: # get all available indices
|
||||
sample_num = np.zeros(self.buffer_num, np.int)
|
||||
else:
|
||||
buffer_lens = np.array([len(buf) for buf in self.buffers])
|
||||
buffer_idx = np.random.choice(self.buffer_num, batch_size,
|
||||
p=buffer_lens / buffer_lens.sum())
|
||||
sample_num = np.bincount(buffer_idx, minlength=self.buffer_num)
|
||||
# avoid batch_size > 0 and sample_num == 0 -> get child's all data
|
||||
sample_num[sample_num == 0] = -1
|
||||
|
||||
return np.concatenate([
|
||||
buf.sample_index(bsz) + offset
|
||||
for offset, buf, bsz in zip(self._offset, self.buffers, sample_num)
|
||||
])
|
||||
|
||||
|
||||
class CachedReplayBuffer(ReplayBufferManager):
|
||||
"""CachedReplayBuffer contains a given main buffer and n cached buffers, \
|
||||
cached_buffer_num * ReplayBuffer(size=max_episode_length).
|
||||
|
||||
The memory layout is: ``| main_buffer | cached_buffers[0] |
|
||||
cached_buffers[1] | ... | cached_buffers[cached_buffer_num - 1]``.
|
||||
|
||||
The data is first stored in cached buffers. When the episode is
|
||||
terminated, the data will move to the main buffer and the corresponding
|
||||
cached buffer will be reset.
|
||||
|
||||
:param ReplayBuffer main_buffer: the main buffer whose ``.update()``
|
||||
function behaves normally.
|
||||
:param int cached_buffer_num: number of ReplayBuffer needs to be created
|
||||
for cached buffer.
|
||||
:param int max_episode_length: the maximum length of one episode, used in
|
||||
each cached buffer's maxsize.
|
||||
|
||||
.. seealso::
|
||||
|
||||
Please refer to :class:`~tianshou.data.ReplayBuffer` or
|
||||
:class:`~tianshou.data.ReplayBufferManager` for more detailed
|
||||
explanation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
main_buffer: ReplayBuffer,
|
||||
cached_buffer_num: int,
|
||||
max_episode_length: int,
|
||||
) -> None:
|
||||
assert cached_buffer_num > 0 and max_episode_length > 0
|
||||
self._is_prioritized = isinstance(main_buffer, PrioritizedReplayBuffer)
|
||||
kwargs = main_buffer.options
|
||||
buffers = [main_buffer] + [ReplayBuffer(max_episode_length, **kwargs)
|
||||
for _ in range(cached_buffer_num)]
|
||||
super().__init__(buffer_list=buffers, **kwargs)
|
||||
self.main_buffer = self.buffers[0]
|
||||
self.cached_buffers = self.buffers[1:]
|
||||
self.cached_buffer_num = cached_buffer_num
|
||||
|
||||
def add( # type: ignore
|
||||
self,
|
||||
obs: Any,
|
||||
act: Any,
|
||||
rew: np.ndarray,
|
||||
done: np.ndarray,
|
||||
obs_next: Any = Batch(),
|
||||
info: Optional[Batch] = Batch(),
|
||||
policy: Optional[Batch] = Batch(),
|
||||
cached_buffer_ids: Optional[Union[np.ndarray, List[int]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Add a batch of data into CachedReplayBuffer.
|
||||
|
||||
Each of the data's length (first dimension) must equal to the length of
|
||||
cached_buffer_ids. By default the cached_buffer_ids is [0, 1, ...,
|
||||
cached_buffer_num - 1].
|
||||
|
||||
Return the array of episode_length and episode_reward with shape
|
||||
(len(cached_buffer_ids), ...), where (episode_length[i],
|
||||
episode_reward[i]) refers to the cached_buffer_ids[i]th cached buffer's
|
||||
corresponding episode result.
|
||||
"""
|
||||
if cached_buffer_ids is None:
|
||||
cached_buffer_ids = np.arange(self.cached_buffer_num)
|
||||
else: # make sure it is np.ndarray
|
||||
cached_buffer_ids = np.asarray(cached_buffer_ids)
|
||||
# in self.buffers, the first buffer is main_buffer
|
||||
buffer_ids = cached_buffer_ids + 1 # type: ignore
|
||||
result = super().add(obs, act, rew, done, obs_next, info,
|
||||
policy, buffer_ids=buffer_ids, **kwargs)
|
||||
# find the terminated episode, move data from cached buf to main buf
|
||||
for buffer_idx in cached_buffer_ids[np.asarray(done, np.bool_)]:
|
||||
self.main_buffer.update(self.cached_buffers[buffer_idx])
|
||||
self.cached_buffers[buffer_idx].reset()
|
||||
return result
|
||||
|
||||
def __getitem__(
|
||||
self, index: Union[slice, int, np.integer, np.ndarray]
|
||||
) -> Batch:
|
||||
batch = super().__getitem__(index)
|
||||
if self._is_prioritized:
|
||||
indice = self._indices[index]
|
||||
mask = indice < self.main_buffer.maxsize
|
||||
batch.weight = np.ones(len(indice))
|
||||
batch.weight[mask] = self.main_buffer.get_weight(indice[mask])
|
||||
return batch
|
||||
|
||||
def update_weight(
|
||||
self,
|
||||
index: np.ndarray,
|
||||
new_weight: Union[np.ndarray, torch.Tensor],
|
||||
) -> None:
|
||||
"""Update priority weight by index in main buffer.
|
||||
|
||||
:param np.ndarray index: index you want to update weight.
|
||||
:param np.ndarray new_weight: new priority weight you want to update.
|
||||
"""
|
||||
if self._is_prioritized:
|
||||
mask = index < self.main_buffer.maxsize
|
||||
self.main_buffer.update_weight(index[mask], new_weight[mask])
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user