Make replay buffer pickle-compatible and improve buffer.get (#182)

Fix #84 and make the buffer more efficient.
n+e 2020-08-16 16:26:23 +08:00 committed by GitHub
parent 7f3b817b24
commit 311a2beafb
3 changed files with 110 additions and 53 deletions
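
The gist of the change: ReplayBuffer delegates attribute access to its internal _meta batch, which broke pickling out of the box, so the patch makes __getattr__ raise AttributeError instead of KeyError and adds a __setstate__ hook. A minimal sketch of the intended round-trip, using only names that appear in the diff below (the stored values are arbitrary):

import pickle
import numpy as np
from tianshou.data import Batch, ReplayBuffer

# Fill a small buffer, then round-trip it through pickle.
buf = ReplayBuffer(size=20, stack_num=2)
for i in range(4):
    buf.add(obs=Batch(index=np.array([i])), act=0, rew=float(i), done=0)

data = pickle.dumps(buf)   # serialization works with the customized __getattr__
buf2 = pickle.loads(data)  # deserialization goes through the new __setstate__

assert len(buf2) == len(buf)
assert buf2.stack_num == buf.stack_num  # meta attributes survive the round-trip
assert np.allclose(buf2.act, buf.act)   # stored data survives as well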

View File

@ -1,9 +1,11 @@
import torch
import pickle
import pytest
import numpy as np
from timeit import timeit
from tianshou.data import Batch, PrioritizedReplayBuffer, \
ReplayBuffer, SegmentTree
from tianshou.data import Batch, SegmentTree, \
ReplayBuffer, ListReplayBuffer, PrioritizedReplayBuffer
if __name__ == '__main__':
from env import MyTestEnv
@ -39,10 +41,11 @@ def test_replaybuffer(size=10, bufsize=20):
def test_ignore_obs_next(size=10):
# Issue 82
buf = ReplayBuffer(size, ignore_obs_net=True)
buf = ReplayBuffer(size, ignore_obs_next=True)
for i in range(size):
buf.add(obs={'mask1': np.array([i, 1, 1, 0, 0]),
'mask2': np.array([i + 4, 0, 1, 0, 0])},
'mask2': np.array([i + 4, 0, 1, 0, 0]),
'mask': i},
act={'act_id': i,
'position_id': i + 3},
rew=i,
@ -55,6 +58,22 @@ def test_ignore_obs_next(size=10):
assert isinstance(data, Batch)
assert isinstance(data2, Batch)
assert np.allclose(indice, orig)
assert np.allclose(data.obs_next.mask, data2.obs_next.mask)
assert np.allclose(data.obs_next.mask, [0, 2, 3, 3, 5, 6, 6, 8, 9, 9])
buf.stack_num = 4
data = buf[indice]
data2 = buf[indice]
assert np.allclose(data.obs_next.mask, data2.obs_next.mask)
assert np.allclose(data.obs_next.mask, np.array([
[0, 0, 0, 0], [1, 1, 1, 2], [1, 1, 2, 3], [1, 1, 2, 3],
[4, 4, 4, 5], [4, 4, 5, 6], [4, 4, 5, 6],
[7, 7, 7, 8], [7, 7, 8, 9], [7, 7, 8, 9]]))
assert np.allclose(data.info['if'], data2.info['if'])
assert np.allclose(data.info['if'], np.array([
[0, 0, 0, 0], [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
[4, 4, 4, 4], [4, 4, 4, 5], [4, 4, 5, 6],
[7, 7, 7, 7], [7, 7, 7, 8], [7, 7, 8, 9]]))
assert data.obs_next
def test_stack(size=5, bufsize=9, stack_num=4):
@ -62,7 +81,7 @@ def test_stack(size=5, bufsize=9, stack_num=4):
buf = ReplayBuffer(bufsize, stack_num=stack_num)
buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True)
obs = env.reset(1)
for i in range(15):
for i in range(16):
obs_next, rew, done, info = env.step(1)
buf.add(obs, 1, rew, done, None, info)
buf2.add(obs, 1, rew, done, None, info)
@ -73,12 +92,11 @@ def test_stack(size=5, bufsize=9, stack_num=4):
assert np.allclose(buf.get(indice, 'obs'), np.expand_dims(
[[1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4],
[1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
[3, 3, 3, 3], [3, 3, 3, 4], [1, 1, 1, 1]], axis=-1))
print(buf)
[1, 2, 3, 4], [4, 4, 4, 4], [1, 1, 1, 1]], axis=-1))
_, indice = buf2.sample(0)
assert indice == [2]
assert indice.tolist() == [2, 6]
_, indice = buf2.sample(1)
assert indice.sum() == 2
assert indice in [2, 6]
def test_priortized_replaybuffer(size=32, bufsize=15):
@ -107,7 +125,7 @@ def test_update():
buf2 = ReplayBuffer(4, stack_num=2)
for i in range(5):
buf1.add(obs=np.array([i]), act=float(i), rew=i * i,
done=False, info={'incident': 'found'})
done=i % 2 == 0, info={'incident': 'found'})
assert len(buf1) > len(buf2)
buf2.update(buf1)
assert len(buf1) == len(buf2)
@ -214,10 +232,38 @@ def test_segtree():
print('tree', timeit(sample_tree, setup=sample_tree, number=1000))
def test_pickle():
size = 100
vbuf = ReplayBuffer(size, stack_num=2)
lbuf = ListReplayBuffer()
pbuf = PrioritizedReplayBuffer(size, 0.6, 0.4)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
rew = torch.tensor([1.]).to(device)
for i in range(4):
vbuf.add(obs=Batch(index=np.array([i])), act=0, rew=rew, done=0)
for i in range(3):
lbuf.add(obs=Batch(index=np.array([i])), act=1, rew=rew, done=0)
for i in range(5):
pbuf.add(obs=Batch(index=np.array([i])),
act=2, rew=rew, done=0, weight=np.random.rand())
# save & load
_vbuf = pickle.loads(pickle.dumps(vbuf))
_lbuf = pickle.loads(pickle.dumps(lbuf))
_pbuf = pickle.loads(pickle.dumps(pbuf))
assert len(_vbuf) == len(vbuf) and np.allclose(_vbuf.act, vbuf.act)
assert len(_lbuf) == len(lbuf) and np.allclose(_lbuf.act, lbuf.act)
assert len(_pbuf) == len(pbuf) and np.allclose(_pbuf.act, pbuf.act)
# make sure the meta var is identical
assert _vbuf.stack_num == vbuf.stack_num
assert np.allclose(_pbuf.weight[np.arange(len(_pbuf))],
pbuf.weight[np.arange(len(pbuf))])
if __name__ == '__main__':
test_replaybuffer()
test_ignore_obs_next()
test_stack()
test_pickle()
test_segtree()
test_priortized_replaybuffer()
test_priortized_replaybuffer(233333, 200000)

View File

@ -23,7 +23,7 @@ class ReplayBuffer:
The following code snippet illustrates its usage:
::
>>> import numpy as np
>>> import pickle, numpy as np
>>> from tianshou.data import ReplayBuffer
>>> buf = ReplayBuffer(size=20)
>>> for i in range(3):
@ -35,6 +35,7 @@ class ReplayBuffer:
>>> # but there are only three valid items, so len(buf) == 3.
>>> len(buf)
3
>>> pickle.dump(buf, open('buf.pkl', 'wb')) # save to file "buf.pkl"
>>> buf2 = ReplayBuffer(size=10)
>>> for i in range(15):
... buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
@ -54,6 +55,11 @@ class ReplayBuffer:
>>> batch_data, indice = buf.sample(batch_size=4)
>>> batch_data.obs == buf[indice].obs
array([ True, True, True, True])
>>> len(buf)
13
>>> buf = pickle.load(open('buf.pkl', 'rb')) # load from "buf.pkl"
>>> len(buf)
3
:class:`~tianshou.data.ReplayBuffer` also supports frame_stack sampling
(typically for RNN usage, see issue#19), ignoring storing the next
@ -119,6 +125,7 @@ class ReplayBuffer:
sample_avail: bool = False, **kwargs) -> None:
super().__init__()
self._maxsize = size
self._indices = np.arange(size)
self._stack = None
self.stack_num = stack_num
self._avail = sample_avail and stack_num > 1
@ -137,9 +144,18 @@ class ReplayBuffer:
"""Return str(self)."""
return self.__class__.__name__ + self._meta.__repr__()[5:]
def __getattr__(self, key: str) -> Union['Batch', Any]:
def __getattr__(self, key: str) -> Any:
"""Return self.key"""
return self._meta.__dict__[key]
try:
return self._meta[key]
except KeyError as e:
raise AttributeError from e
def __setstate__(self, state):
"""Unpickling interface. We need it because pickling buffer does not
work out-of-the-box (``buffer.__getattr__`` is customized).
"""
self.__dict__.update(state)
def _add_to_buffer(self, name: str, inst: Any) -> None:
try:
@ -149,9 +165,8 @@ class ReplayBuffer:
value = self._meta.__dict__[name]
if isinstance(inst, np.ndarray) and value.shape[1:] != inst.shape:
raise ValueError(
"Cannot add data to a buffer with different shape, key: "
f"{name}, expect shape: {value.shape[1:]}, "
f"given shape: {inst.shape}.")
"Cannot add data to a buffer with different shape, with key "
f"{name}, expect {value.shape[1:]}, given {inst.shape}.")
try:
value[self._index] = inst
except KeyError:
@ -261,47 +276,42 @@ class ReplayBuffer:
"""
if stack_num is None:
stack_num = self.stack_num
if isinstance(indice, slice):
indice = np.arange(
0 if indice.start is None
else self._size - indice.start if indice.start < 0
else indice.start,
self._size if indice.stop is None
else self._size - indice.stop if indice.stop < 0
else indice.stop,
1 if indice.step is None else indice.step)
else:
indice = np.array(indice, copy=True)
# set last frame done to True
last_index = (self._index - 1 + self._size) % self._size
last_done, self.done[last_index] = self.done[last_index], True
if key == 'obs_next' and (not self._save_s_ or self.obs_next is None):
indice += 1 - self.done[indice].astype(np.int)
if stack_num == 1: # the most often case
if key != 'obs_next' or self._save_s_:
val = self._meta.__dict__[key]
try:
return val[indice]
except IndexError as e:
if not (isinstance(val, Batch) and val.is_empty()):
raise e # val != Batch()
return Batch()
indice = self._indices[:self._size][indice]
done = self._meta.__dict__['done']
if key == 'obs_next' and not self._save_s_:
indice += 1 - done[indice].astype(np.int)
indice[indice == self._size] = 0
key = 'obs'
val = self._meta.__dict__[key]
try:
if stack_num > 1:
stack = []
for _ in range(stack_num):
stack = [val[indice]] + stack
pre_indice = np.asarray(indice - 1)
pre_indice[pre_indice == -1] = self._size - 1
indice = np.asarray(
pre_indice + self.done[pre_indice].astype(np.int))
indice[indice == self._size] = 0
if isinstance(val, Batch):
stack = Batch.stack(stack, axis=indice.ndim)
else:
stack = np.stack(stack, axis=indice.ndim)
if stack_num == 1:
return val[indice]
stack = []
for _ in range(stack_num):
stack = [val[indice]] + stack
pre_indice = np.asarray(indice - 1)
pre_indice[pre_indice == -1] = self._size - 1
indice = np.asarray(
pre_indice + done[pre_indice].astype(np.int))
indice[indice == self._size] = 0
if isinstance(val, Batch):
stack = Batch.stack(stack, axis=indice.ndim)
else:
stack = val[indice]
stack = np.stack(stack, axis=indice.ndim)
return stack
except IndexError as e:
stack = Batch()
if not isinstance(val, Batch) or len(val.__dict__) > 0:
raise e
self.done[last_index] = last_done
return stack
if not (isinstance(val, Batch) and val.is_empty()):
raise e # val != Batch()
return Batch()
def __getitem__(self, index: Union[
slice, int, np.integer, np.ndarray]) -> Batch:
@ -380,7 +390,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
"""Return self.key"""
if key == 'weight':
return self._weight
return self._meta.__dict__[key]
return super().__getattr__(key)
def add(self,
obs: Union[dict, np.ndarray],
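
As a side note, the frame-stacking walk-back that the rewritten get performs can be sketched in plain NumPy: for each sampled index it prepends the current frame, steps one slot back in the circular buffer, and uses the previous slot's done flag to stay in place at an episode boundary. A simplified stand-alone sketch, assuming a full buffer; the helper name and the toy done array are illustrative, not part of tianshou:

import numpy as np

def stacked_indices(indice, done, size, stack_num):
    """Return a (len(indice), stack_num) array of indices, oldest frame first.

    Mirrors the walk-back in ReplayBuffer.get: step one slot back in the ring
    buffer; if that previous slot ended an episode, undo the step.
    """
    indice = np.asarray(indice).copy()
    stack = []
    for _ in range(stack_num):
        stack = [indice] + stack              # prepend: oldest ends up first
        pre = indice - 1
        pre[pre == -1] = size - 1             # wrap around the ring buffer
        indice = pre + done[pre].astype(int)  # done -> step forward again
        indice[indice == size] = 0
    return np.stack(stack, axis=-1)

# Two 3-step episodes stored in slots 0..5; done marks slots 2 and 5.
done = np.array([0, 0, 1, 0, 0, 1])
print(stacked_indices(np.arange(6), done, size=6, stack_num=3))
# index 3 (first step of the second episode) stacks to [3, 3, 3], not [1, 2, 3]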

View File

@ -231,7 +231,8 @@ class BasePolicy(ABC, nn.Module):
usage is to update the sampling weight in prioritized experience
replay. Check out :ref:`policy_concept` for more information.
"""
if isinstance(buffer, PrioritizedReplayBuffer):
if isinstance(buffer, PrioritizedReplayBuffer) \
and hasattr(batch, 'weight'):
buffer.update_weight(indice, batch.weight)
def update(self, batch_size: int, buffer: ReplayBuffer, *args, **kwargs):
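
For completeness, the guarded post_process_fn means priority updates only run when the processed batch actually carries a weight field. A rough usage sketch with the names from the diff (PrioritizedReplayBuffer, update_weight); the sizes and the stand-in priorities are made up:

import numpy as np
from tianshou.data import PrioritizedReplayBuffer

# Prioritized buffer constructed as in test_pickle: (size, alpha, beta).
buf = PrioritizedReplayBuffer(16, 0.6, 0.4)
for i in range(8):
    buf.add(obs=np.array([i]), act=0, rew=float(i), done=0,
            weight=np.random.rand())

batch, indice = buf.sample(4)
# A policy that recomputes priorities attaches them to the batch ...
batch.weight = np.abs(np.random.randn(4))  # stand-in for new TD errors
buf.update_weight(indice, batch.weight)    # what post_process_fn forwards to
# ... while a batch without a weight attribute is now skipped instead of
# failing on the missing attribute inside post_process_fn.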