Pickle compatible for replay buffer and improve buffer.get (#182)

Fix #84 and make the buffer more efficient.

parent  7f3b817b24
commit  311a2beafb
@@ -1,9 +1,11 @@
+import torch
+import pickle
 import pytest
 import numpy as np
 from timeit import timeit
 
-from tianshou.data import Batch, PrioritizedReplayBuffer, \
-    ReplayBuffer, SegmentTree
+from tianshou.data import Batch, SegmentTree, \
+    ReplayBuffer, ListReplayBuffer, PrioritizedReplayBuffer
 
 if __name__ == '__main__':
     from env import MyTestEnv
@@ -39,10 +41,11 @@ def test_replaybuffer(size=10, bufsize=20):
 
 def test_ignore_obs_next(size=10):
     # Issue 82
-    buf = ReplayBuffer(size, ignore_obs_net=True)
+    buf = ReplayBuffer(size, ignore_obs_next=True)
     for i in range(size):
         buf.add(obs={'mask1': np.array([i, 1, 1, 0, 0]),
-                     'mask2': np.array([i + 4, 0, 1, 0, 0])},
+                     'mask2': np.array([i + 4, 0, 1, 0, 0]),
+                     'mask': i},
                 act={'act_id': i,
                      'position_id': i + 3},
                 rew=i,
@@ -55,6 +58,22 @@ def test_ignore_obs_next(size=10):
     assert isinstance(data, Batch)
     assert isinstance(data2, Batch)
     assert np.allclose(indice, orig)
+    assert np.allclose(data.obs_next.mask, data2.obs_next.mask)
+    assert np.allclose(data.obs_next.mask, [0, 2, 3, 3, 5, 6, 6, 8, 9, 9])
+    buf.stack_num = 4
+    data = buf[indice]
+    data2 = buf[indice]
+    assert np.allclose(data.obs_next.mask, data2.obs_next.mask)
+    assert np.allclose(data.obs_next.mask, np.array([
+        [0, 0, 0, 0], [1, 1, 1, 2], [1, 1, 2, 3], [1, 1, 2, 3],
+        [4, 4, 4, 5], [4, 4, 5, 6], [4, 4, 5, 6],
+        [7, 7, 7, 8], [7, 7, 8, 9], [7, 7, 8, 9]]))
+    assert np.allclose(data.info['if'], data2.info['if'])
+    assert np.allclose(data.info['if'], np.array([
+        [0, 0, 0, 0], [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
+        [4, 4, 4, 4], [4, 4, 4, 5], [4, 4, 5, 6],
+        [7, 7, 7, 7], [7, 7, 7, 8], [7, 7, 8, 9]]))
+    assert data.obs_next
 
 
 def test_stack(size=5, bufsize=9, stack_num=4):
@@ -62,7 +81,7 @@ def test_stack(size=5, bufsize=9, stack_num=4):
     buf = ReplayBuffer(bufsize, stack_num=stack_num)
     buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True)
     obs = env.reset(1)
-    for i in range(15):
+    for i in range(16):
         obs_next, rew, done, info = env.step(1)
         buf.add(obs, 1, rew, done, None, info)
         buf2.add(obs, 1, rew, done, None, info)
@@ -73,12 +92,11 @@ def test_stack(size=5, bufsize=9, stack_num=4):
     assert np.allclose(buf.get(indice, 'obs'), np.expand_dims(
         [[1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4],
          [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
-         [3, 3, 3, 3], [3, 3, 3, 4], [1, 1, 1, 1]], axis=-1))
-    print(buf)
+         [1, 2, 3, 4], [4, 4, 4, 4], [1, 1, 1, 1]], axis=-1))
     _, indice = buf2.sample(0)
-    assert indice == [2]
+    assert indice.tolist() == [2, 6]
     _, indice = buf2.sample(1)
-    assert indice.sum() == 2
+    assert indice in [2, 6]
 
 
 def test_priortized_replaybuffer(size=32, bufsize=15):
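A note on the `sample_avail` assertions above: with `stack_num=4`, `buf2` only samples indices for which a full four-frame stack exists inside a single episode, and after the 16 steps above the test expects exactly indices 2 and 6. The following standalone sketch mimics that availability rule on a plain `done` array; it is an illustration of the assumed semantics, not tianshou's implementation, the helper name `available` is made up, and the `done` layout is reconstructed from the expected stacks in the hunk above.

import numpy as np

def available(done, stack_num):
    """Indices whose previous stack_num - 1 slots contain no episode end,
    i.e. positions where a full frame stack can be sampled."""
    done = np.asarray(done, dtype=bool)
    n = len(done)
    ok = []
    for i in range(n):
        prev = [(i - k) % n for k in range(1, stack_num)]
        if not done[prev].any():
            ok.append(i)
    return ok

# done flags of the 9-slot buffer after the 16 steps in test_stack
done = [0, 0, 1, 0, 0, 0, 1, 1, 0]
print(available(done, stack_num=4))  # [2, 6]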
@@ -107,7 +125,7 @@ def test_update():
     buf2 = ReplayBuffer(4, stack_num=2)
     for i in range(5):
         buf1.add(obs=np.array([i]), act=float(i), rew=i * i,
-                 done=False, info={'incident': 'found'})
+                 done=i % 2 == 0, info={'incident': 'found'})
     assert len(buf1) > len(buf2)
     buf2.update(buf1)
     assert len(buf1) == len(buf2)
@@ -214,10 +232,38 @@ def test_segtree():
     print('tree', timeit(sample_tree, setup=sample_tree, number=1000))
 
 
+def test_pickle():
+    size = 100
+    vbuf = ReplayBuffer(size, stack_num=2)
+    lbuf = ListReplayBuffer()
+    pbuf = PrioritizedReplayBuffer(size, 0.6, 0.4)
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    rew = torch.tensor([1.]).to(device)
+    for i in range(4):
+        vbuf.add(obs=Batch(index=np.array([i])), act=0, rew=rew, done=0)
+    for i in range(3):
+        lbuf.add(obs=Batch(index=np.array([i])), act=1, rew=rew, done=0)
+    for i in range(5):
+        pbuf.add(obs=Batch(index=np.array([i])),
+                 act=2, rew=rew, done=0, weight=np.random.rand())
+    # save & load
+    _vbuf = pickle.loads(pickle.dumps(vbuf))
+    _lbuf = pickle.loads(pickle.dumps(lbuf))
+    _pbuf = pickle.loads(pickle.dumps(pbuf))
+    assert len(_vbuf) == len(vbuf) and np.allclose(_vbuf.act, vbuf.act)
+    assert len(_lbuf) == len(lbuf) and np.allclose(_lbuf.act, lbuf.act)
+    assert len(_pbuf) == len(pbuf) and np.allclose(_pbuf.act, pbuf.act)
+    # make sure the meta var is identical
+    assert _vbuf.stack_num == vbuf.stack_num
+    assert np.allclose(_pbuf.weight[np.arange(len(_pbuf))],
+                       pbuf.weight[np.arange(len(pbuf))])
+
+
 if __name__ == '__main__':
     test_replaybuffer()
     test_ignore_obs_next()
     test_stack()
+    test_pickle()
     test_segtree()
     test_priortized_replaybuffer()
     test_priortized_replaybuffer(233333, 200000)
@@ -23,7 +23,7 @@ class ReplayBuffer:
     The following code snippet illustrates its usage:
     ::
 
-        >>> import numpy as np
+        >>> import pickle, numpy as np
        >>> from tianshou.data import ReplayBuffer
        >>> buf = ReplayBuffer(size=20)
        >>> for i in range(3):
@@ -35,6 +35,7 @@ class ReplayBuffer:
        >>> # but there are only three valid items, so len(buf) == 3.
        >>> len(buf)
        3
+        >>> pickle.dump(buf, open('buf.pkl', 'wb'))  # save to file "buf.pkl"
        >>> buf2 = ReplayBuffer(size=10)
        >>> for i in range(15):
        ...     buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
@@ -54,6 +55,11 @@ class ReplayBuffer:
        >>> batch_data, indice = buf.sample(batch_size=4)
        >>> batch_data.obs == buf[indice].obs
        array([ True,  True,  True,  True])
+        >>> len(buf)
+        13
+        >>> buf = pickle.load(open('buf.pkl', 'rb'))  # load from "buf.pkl"
+        >>> len(buf)
+        3
 
     :class:`~tianshou.data.ReplayBuffer` also supports frame_stack sampling
     (typically for RNN usage, see issue#19), ignoring storing the next
@@ -119,6 +125,7 @@ class ReplayBuffer:
                 sample_avail: bool = False, **kwargs) -> None:
        super().__init__()
        self._maxsize = size
+        self._indices = np.arange(size)
        self._stack = None
        self.stack_num = stack_num
        self._avail = sample_avail and stack_num > 1
@@ -137,9 +144,18 @@ class ReplayBuffer:
         """Return str(self)."""
         return self.__class__.__name__ + self._meta.__repr__()[5:]
 
-    def __getattr__(self, key: str) -> Union['Batch', Any]:
+    def __getattr__(self, key: str) -> Any:
         """Return self.key"""
-        return self._meta.__dict__[key]
+        try:
+            return self._meta[key]
+        except KeyError as e:
+            raise AttributeError from e
+
+    def __setstate__(self, state):
+        """Unpickling interface. We need it because pickling buffer does not
+        work out-of-the-box (``buffer.__getattr__`` is customized).
+        """
+        self.__dict__.update(state)
 
     def _add_to_buffer(self, name: str, inst: Any) -> None:
         try:
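Why both pieces are needed: when unpickling, `pickle` rebuilds the object without calling `__init__` and immediately probes attributes such as `__setstate__`; with a `__getattr__` that reads `self._meta`, that probe either recurses (nothing has been restored yet) or leaks a `KeyError` where an `AttributeError` is expected. Defining `__setstate__` restores `__dict__` before any delegation can happen, and converting `KeyError` to `AttributeError` keeps `hasattr`/`getattr` probes well behaved. A minimal standalone sketch with illustrative classes, not tianshou code:

import pickle


class Fragile:
    """Delegates attribute access to inner storage, like the old buffer."""
    def __init__(self):
        self._meta = {'act': [1, 2, 3]}

    def __getattr__(self, key):
        return self._meta[key]  # KeyError on misses, recursion while unpickling


class Robust:
    def __init__(self):
        self._meta = {'act': [1, 2, 3]}

    def __getattr__(self, key):
        try:
            return self._meta[key]
        except KeyError as e:
            raise AttributeError from e  # what getattr/hasattr expect

    def __setstate__(self, state):
        # restore __dict__ directly, so __getattr__ is never consulted
        # while the object is still half-built
        self.__dict__.update(state)


print(pickle.loads(pickle.dumps(Robust())).act)   # [1, 2, 3]
# pickle.loads(pickle.dumps(Fragile())) fails with RecursionError or KeyError,
# depending on the Python version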
@@ -149,9 +165,8 @@ class ReplayBuffer:
             value = self._meta.__dict__[name]
         if isinstance(inst, np.ndarray) and value.shape[1:] != inst.shape:
             raise ValueError(
-                "Cannot add data to a buffer with different shape, key: "
-                f"{name}, expect shape: {value.shape[1:]}, "
-                f"given shape: {inst.shape}.")
+                "Cannot add data to a buffer with different shape, with key "
+                f"{name}, expect {value.shape[1:]}, given {inst.shape}.")
         try:
             value[self._index] = inst
         except KeyError:
@@ -261,47 +276,42 @@ class ReplayBuffer:
         """
         if stack_num is None:
             stack_num = self.stack_num
-        if isinstance(indice, slice):
-            indice = np.arange(
-                0 if indice.start is None
-                else self._size - indice.start if indice.start < 0
-                else indice.start,
-                self._size if indice.stop is None
-                else self._size - indice.stop if indice.stop < 0
-                else indice.stop,
-                1 if indice.step is None else indice.step)
-        else:
-            indice = np.array(indice, copy=True)
-        # set last frame done to True
-        last_index = (self._index - 1 + self._size) % self._size
-        last_done, self.done[last_index] = self.done[last_index], True
-        if key == 'obs_next' and (not self._save_s_ or self.obs_next is None):
-            indice += 1 - self.done[indice].astype(np.int)
+        if stack_num == 1:  # the most often case
+            if key != 'obs_next' or self._save_s_:
+                val = self._meta.__dict__[key]
+                try:
+                    return val[indice]
+                except IndexError as e:
+                    if not (isinstance(val, Batch) and val.is_empty()):
+                        raise e  # val != Batch()
+                    return Batch()
+        indice = self._indices[:self._size][indice]
+        done = self._meta.__dict__['done']
+        if key == 'obs_next' and not self._save_s_:
+            indice += 1 - done[indice].astype(np.int)
             indice[indice == self._size] = 0
             key = 'obs'
         val = self._meta.__dict__[key]
         try:
-            if stack_num > 1:
-                stack = []
-                for _ in range(stack_num):
-                    stack = [val[indice]] + stack
-                    pre_indice = np.asarray(indice - 1)
-                    pre_indice[pre_indice == -1] = self._size - 1
-                    indice = np.asarray(
-                        pre_indice + self.done[pre_indice].astype(np.int))
-                    indice[indice == self._size] = 0
-                if isinstance(val, Batch):
-                    stack = Batch.stack(stack, axis=indice.ndim)
-                else:
-                    stack = np.stack(stack, axis=indice.ndim)
+            if stack_num == 1:
+                return val[indice]
+            stack = []
+            for _ in range(stack_num):
+                stack = [val[indice]] + stack
+                pre_indice = np.asarray(indice - 1)
+                pre_indice[pre_indice == -1] = self._size - 1
+                indice = np.asarray(
+                    pre_indice + done[pre_indice].astype(np.int))
+                indice[indice == self._size] = 0
+            if isinstance(val, Batch):
+                stack = Batch.stack(stack, axis=indice.ndim)
             else:
-                stack = val[indice]
+                stack = np.stack(stack, axis=indice.ndim)
+            return stack
         except IndexError as e:
-            stack = Batch()
-            if not isinstance(val, Batch) or len(val.__dict__) > 0:
-                raise e
-        self.done[last_index] = last_done
-        return stack
+            if not (isinstance(val, Batch) and val.is_empty()):
+                raise e  # val != Batch()
+            return Batch()
 
     def __getitem__(self, index: Union[
             slice, int, np.integer, np.ndarray]) -> Batch:
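Two things change in `get` above: a fast path returns `val[indice]` directly when `stack_num == 1` and `obs_next` is stored (the most common case), and index normalization now goes through the precomputed `self._indices[:self._size][indice]` instead of rebuilding an arange for every slice, which also removes the temporary "set last frame done to True" trick. The frame-stacking walk itself is unchanged in spirit: step the indices backwards `stack_num` times and, whenever the previous slot ends an episode, stay on the current frame so stacks never cross episode boundaries. A rough standalone sketch of that index arithmetic on plain numpy arrays (the helper name `stacked_frames` is made up):

import numpy as np


def stacked_frames(val, done, indice, stack_num, size):
    """Collect [s_{t-stack_num+1}, ..., s_t] for each index, stepping
    backwards through a ring buffer without crossing episode boundaries."""
    indice = np.asarray(indice).copy()
    stack = []
    for _ in range(stack_num):
        stack = [val[indice]] + stack              # prepend the current frame
        pre_indice = np.asarray(indice - 1)
        pre_indice[pre_indice == -1] = size - 1    # wrap around the ring
        # if the previous slot ended an episode, repeat the current frame
        indice = np.asarray(pre_indice + done[pre_indice].astype(int))
        indice[indice == size] = 0
    return np.stack(stack, axis=indice.ndim)


# frames 0-4 form one episode, 5-9 another
obs = np.arange(10)
done = np.zeros(10, dtype=bool)
done[4] = done[9] = True
print(stacked_frames(obs, done, [0, 5, 6], stack_num=4, size=10))
# [[0 0 0 0]
#  [5 5 5 5]
#  [5 5 5 6]]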
@@ -380,7 +390,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
         """Return self.key"""
         if key == 'weight':
             return self._weight
-        return self._meta.__dict__[key]
+        return super().__getattr__(key)
 
     def add(self,
             obs: Union[dict, np.ndarray],
@@ -231,7 +231,8 @@ class BasePolicy(ABC, nn.Module):
         usage is to update the sampling weight in prioritized experience
         replay. Check out :ref:`policy_concept` for more information.
         """
-        if isinstance(buffer, PrioritizedReplayBuffer):
+        if isinstance(buffer, PrioritizedReplayBuffer) \
+                and hasattr(batch, 'weight'):
             buffer.update_weight(indice, batch.weight)
 
     def update(self, batch_size: int, buffer: ReplayBuffer, *args, **kwargs):
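The extra `hasattr(batch, 'weight')` check makes the priority write-back opt-in: it only runs when the policy actually attached per-sample weights to the processed batch, instead of assuming every batch drawn from a `PrioritizedReplayBuffer` carries them. A rough usage sketch of that write-back pattern, assuming the buffer API of this era as exercised by the tests above (the weight values are made up):

import numpy as np
from tianshou.data import PrioritizedReplayBuffer

buf = PrioritizedReplayBuffer(8, 0.6, 0.4)
for i in range(8):
    buf.add(obs=i, act=i, rew=float(i), done=0, weight=1.0)
batch, indice = buf.sample(4)
# e.g. the policy's learn() attached fresh TD errors to the batch
batch.weight = np.abs(np.random.randn(len(indice)))
if isinstance(buf, PrioritizedReplayBuffer) and hasattr(batch, 'weight'):
    buf.update_weight(indice, batch.weight)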