Pickle compatible for replay buffer and improve buffer.get (#182)
fix #84 and make buffer more efficient
This commit is contained in:
parent
7f3b817b24
commit
311a2beafb
@ -1,9 +1,11 @@
|
|||||||
|
import torch
|
||||||
|
import pickle
|
||||||
import pytest
|
import pytest
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from timeit import timeit
|
from timeit import timeit
|
||||||
|
|
||||||
from tianshou.data import Batch, PrioritizedReplayBuffer, \
|
from tianshou.data import Batch, SegmentTree, \
|
||||||
ReplayBuffer, SegmentTree
|
ReplayBuffer, ListReplayBuffer, PrioritizedReplayBuffer
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
from env import MyTestEnv
|
from env import MyTestEnv
|
||||||
@ -39,10 +41,11 @@ def test_replaybuffer(size=10, bufsize=20):
|
|||||||
|
|
||||||
def test_ignore_obs_next(size=10):
|
def test_ignore_obs_next(size=10):
|
||||||
# Issue 82
|
# Issue 82
|
||||||
buf = ReplayBuffer(size, ignore_obs_net=True)
|
buf = ReplayBuffer(size, ignore_obs_next=True)
|
||||||
for i in range(size):
|
for i in range(size):
|
||||||
buf.add(obs={'mask1': np.array([i, 1, 1, 0, 0]),
|
buf.add(obs={'mask1': np.array([i, 1, 1, 0, 0]),
|
||||||
'mask2': np.array([i + 4, 0, 1, 0, 0])},
|
'mask2': np.array([i + 4, 0, 1, 0, 0]),
|
||||||
|
'mask': i},
|
||||||
act={'act_id': i,
|
act={'act_id': i,
|
||||||
'position_id': i + 3},
|
'position_id': i + 3},
|
||||||
rew=i,
|
rew=i,
|
||||||
@ -55,6 +58,22 @@ def test_ignore_obs_next(size=10):
|
|||||||
assert isinstance(data, Batch)
|
assert isinstance(data, Batch)
|
||||||
assert isinstance(data2, Batch)
|
assert isinstance(data2, Batch)
|
||||||
assert np.allclose(indice, orig)
|
assert np.allclose(indice, orig)
|
||||||
|
assert np.allclose(data.obs_next.mask, data2.obs_next.mask)
|
||||||
|
assert np.allclose(data.obs_next.mask, [0, 2, 3, 3, 5, 6, 6, 8, 9, 9])
|
||||||
|
buf.stack_num = 4
|
||||||
|
data = buf[indice]
|
||||||
|
data2 = buf[indice]
|
||||||
|
assert np.allclose(data.obs_next.mask, data2.obs_next.mask)
|
||||||
|
assert np.allclose(data.obs_next.mask, np.array([
|
||||||
|
[0, 0, 0, 0], [1, 1, 1, 2], [1, 1, 2, 3], [1, 1, 2, 3],
|
||||||
|
[4, 4, 4, 5], [4, 4, 5, 6], [4, 4, 5, 6],
|
||||||
|
[7, 7, 7, 8], [7, 7, 8, 9], [7, 7, 8, 9]]))
|
||||||
|
assert np.allclose(data.info['if'], data2.info['if'])
|
||||||
|
assert np.allclose(data.info['if'], np.array([
|
||||||
|
[0, 0, 0, 0], [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
|
||||||
|
[4, 4, 4, 4], [4, 4, 4, 5], [4, 4, 5, 6],
|
||||||
|
[7, 7, 7, 7], [7, 7, 7, 8], [7, 7, 8, 9]]))
|
||||||
|
assert data.obs_next
|
||||||
|
|
||||||
|
|
||||||
def test_stack(size=5, bufsize=9, stack_num=4):
|
def test_stack(size=5, bufsize=9, stack_num=4):
|
||||||
@ -62,7 +81,7 @@ def test_stack(size=5, bufsize=9, stack_num=4):
|
|||||||
buf = ReplayBuffer(bufsize, stack_num=stack_num)
|
buf = ReplayBuffer(bufsize, stack_num=stack_num)
|
||||||
buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True)
|
buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True)
|
||||||
obs = env.reset(1)
|
obs = env.reset(1)
|
||||||
for i in range(15):
|
for i in range(16):
|
||||||
obs_next, rew, done, info = env.step(1)
|
obs_next, rew, done, info = env.step(1)
|
||||||
buf.add(obs, 1, rew, done, None, info)
|
buf.add(obs, 1, rew, done, None, info)
|
||||||
buf2.add(obs, 1, rew, done, None, info)
|
buf2.add(obs, 1, rew, done, None, info)
|
||||||
@ -73,12 +92,11 @@ def test_stack(size=5, bufsize=9, stack_num=4):
|
|||||||
assert np.allclose(buf.get(indice, 'obs'), np.expand_dims(
|
assert np.allclose(buf.get(indice, 'obs'), np.expand_dims(
|
||||||
[[1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4],
|
[[1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4],
|
||||||
[1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
|
[1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
|
||||||
[3, 3, 3, 3], [3, 3, 3, 4], [1, 1, 1, 1]], axis=-1))
|
[1, 2, 3, 4], [4, 4, 4, 4], [1, 1, 1, 1]], axis=-1))
|
||||||
print(buf)
|
|
||||||
_, indice = buf2.sample(0)
|
_, indice = buf2.sample(0)
|
||||||
assert indice == [2]
|
assert indice.tolist() == [2, 6]
|
||||||
_, indice = buf2.sample(1)
|
_, indice = buf2.sample(1)
|
||||||
assert indice.sum() == 2
|
assert indice in [2, 6]
|
||||||
|
|
||||||
|
|
||||||
def test_priortized_replaybuffer(size=32, bufsize=15):
|
def test_priortized_replaybuffer(size=32, bufsize=15):
|
||||||
@ -107,7 +125,7 @@ def test_update():
|
|||||||
buf2 = ReplayBuffer(4, stack_num=2)
|
buf2 = ReplayBuffer(4, stack_num=2)
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
buf1.add(obs=np.array([i]), act=float(i), rew=i * i,
|
buf1.add(obs=np.array([i]), act=float(i), rew=i * i,
|
||||||
done=False, info={'incident': 'found'})
|
done=i % 2 == 0, info={'incident': 'found'})
|
||||||
assert len(buf1) > len(buf2)
|
assert len(buf1) > len(buf2)
|
||||||
buf2.update(buf1)
|
buf2.update(buf1)
|
||||||
assert len(buf1) == len(buf2)
|
assert len(buf1) == len(buf2)
|
||||||
@ -214,10 +232,38 @@ def test_segtree():
|
|||||||
print('tree', timeit(sample_tree, setup=sample_tree, number=1000))
|
print('tree', timeit(sample_tree, setup=sample_tree, number=1000))
|
||||||
|
|
||||||
|
|
||||||
|
def test_pickle():
|
||||||
|
size = 100
|
||||||
|
vbuf = ReplayBuffer(size, stack_num=2)
|
||||||
|
lbuf = ListReplayBuffer()
|
||||||
|
pbuf = PrioritizedReplayBuffer(size, 0.6, 0.4)
|
||||||
|
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
|
rew = torch.tensor([1.]).to(device)
|
||||||
|
for i in range(4):
|
||||||
|
vbuf.add(obs=Batch(index=np.array([i])), act=0, rew=rew, done=0)
|
||||||
|
for i in range(3):
|
||||||
|
lbuf.add(obs=Batch(index=np.array([i])), act=1, rew=rew, done=0)
|
||||||
|
for i in range(5):
|
||||||
|
pbuf.add(obs=Batch(index=np.array([i])),
|
||||||
|
act=2, rew=rew, done=0, weight=np.random.rand())
|
||||||
|
# save & load
|
||||||
|
_vbuf = pickle.loads(pickle.dumps(vbuf))
|
||||||
|
_lbuf = pickle.loads(pickle.dumps(lbuf))
|
||||||
|
_pbuf = pickle.loads(pickle.dumps(pbuf))
|
||||||
|
assert len(_vbuf) == len(vbuf) and np.allclose(_vbuf.act, vbuf.act)
|
||||||
|
assert len(_lbuf) == len(lbuf) and np.allclose(_lbuf.act, lbuf.act)
|
||||||
|
assert len(_pbuf) == len(pbuf) and np.allclose(_pbuf.act, pbuf.act)
|
||||||
|
# make sure the meta var is identical
|
||||||
|
assert _vbuf.stack_num == vbuf.stack_num
|
||||||
|
assert np.allclose(_pbuf.weight[np.arange(len(_pbuf))],
|
||||||
|
pbuf.weight[np.arange(len(pbuf))])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_replaybuffer()
|
test_replaybuffer()
|
||||||
test_ignore_obs_next()
|
test_ignore_obs_next()
|
||||||
test_stack()
|
test_stack()
|
||||||
|
test_pickle()
|
||||||
test_segtree()
|
test_segtree()
|
||||||
test_priortized_replaybuffer()
|
test_priortized_replaybuffer()
|
||||||
test_priortized_replaybuffer(233333, 200000)
|
test_priortized_replaybuffer(233333, 200000)
|
||||||
|
@ -23,7 +23,7 @@ class ReplayBuffer:
|
|||||||
The following code snippet illustrates its usage:
|
The following code snippet illustrates its usage:
|
||||||
::
|
::
|
||||||
|
|
||||||
>>> import numpy as np
|
>>> import pickle, numpy as np
|
||||||
>>> from tianshou.data import ReplayBuffer
|
>>> from tianshou.data import ReplayBuffer
|
||||||
>>> buf = ReplayBuffer(size=20)
|
>>> buf = ReplayBuffer(size=20)
|
||||||
>>> for i in range(3):
|
>>> for i in range(3):
|
||||||
@ -35,6 +35,7 @@ class ReplayBuffer:
|
|||||||
>>> # but there are only three valid items, so len(buf) == 3.
|
>>> # but there are only three valid items, so len(buf) == 3.
|
||||||
>>> len(buf)
|
>>> len(buf)
|
||||||
3
|
3
|
||||||
|
>>> pickle.dump(buf, open('buf.pkl', 'wb')) # save to file "buf.pkl"
|
||||||
>>> buf2 = ReplayBuffer(size=10)
|
>>> buf2 = ReplayBuffer(size=10)
|
||||||
>>> for i in range(15):
|
>>> for i in range(15):
|
||||||
... buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
|
... buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
|
||||||
@ -54,6 +55,11 @@ class ReplayBuffer:
|
|||||||
>>> batch_data, indice = buf.sample(batch_size=4)
|
>>> batch_data, indice = buf.sample(batch_size=4)
|
||||||
>>> batch_data.obs == buf[indice].obs
|
>>> batch_data.obs == buf[indice].obs
|
||||||
array([ True, True, True, True])
|
array([ True, True, True, True])
|
||||||
|
>>> len(buf)
|
||||||
|
13
|
||||||
|
>>> buf = pickle.load(open('buf.pkl', 'rb')) # load from "buf.pkl"
|
||||||
|
>>> len(buf)
|
||||||
|
3
|
||||||
|
|
||||||
:class:`~tianshou.data.ReplayBuffer` also supports frame_stack sampling
|
:class:`~tianshou.data.ReplayBuffer` also supports frame_stack sampling
|
||||||
(typically for RNN usage, see issue#19), ignoring storing the next
|
(typically for RNN usage, see issue#19), ignoring storing the next
|
||||||
@ -119,6 +125,7 @@ class ReplayBuffer:
|
|||||||
sample_avail: bool = False, **kwargs) -> None:
|
sample_avail: bool = False, **kwargs) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._maxsize = size
|
self._maxsize = size
|
||||||
|
self._indices = np.arange(size)
|
||||||
self._stack = None
|
self._stack = None
|
||||||
self.stack_num = stack_num
|
self.stack_num = stack_num
|
||||||
self._avail = sample_avail and stack_num > 1
|
self._avail = sample_avail and stack_num > 1
|
||||||
@ -137,9 +144,18 @@ class ReplayBuffer:
|
|||||||
"""Return str(self)."""
|
"""Return str(self)."""
|
||||||
return self.__class__.__name__ + self._meta.__repr__()[5:]
|
return self.__class__.__name__ + self._meta.__repr__()[5:]
|
||||||
|
|
||||||
def __getattr__(self, key: str) -> Union['Batch', Any]:
|
def __getattr__(self, key: str) -> Any:
|
||||||
"""Return self.key"""
|
"""Return self.key"""
|
||||||
return self._meta.__dict__[key]
|
try:
|
||||||
|
return self._meta[key]
|
||||||
|
except KeyError as e:
|
||||||
|
raise AttributeError from e
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
"""Unpickling interface. We need it because pickling buffer does not
|
||||||
|
work out-of-the-box (``buffer.__getattr__`` is customized).
|
||||||
|
"""
|
||||||
|
self.__dict__.update(state)
|
||||||
|
|
||||||
def _add_to_buffer(self, name: str, inst: Any) -> None:
|
def _add_to_buffer(self, name: str, inst: Any) -> None:
|
||||||
try:
|
try:
|
||||||
@ -149,9 +165,8 @@ class ReplayBuffer:
|
|||||||
value = self._meta.__dict__[name]
|
value = self._meta.__dict__[name]
|
||||||
if isinstance(inst, np.ndarray) and value.shape[1:] != inst.shape:
|
if isinstance(inst, np.ndarray) and value.shape[1:] != inst.shape:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Cannot add data to a buffer with different shape, key: "
|
"Cannot add data to a buffer with different shape, with key "
|
||||||
f"{name}, expect shape: {value.shape[1:]}, "
|
f"{name}, expect {value.shape[1:]}, given {inst.shape}.")
|
||||||
f"given shape: {inst.shape}.")
|
|
||||||
try:
|
try:
|
||||||
value[self._index] = inst
|
value[self._index] = inst
|
||||||
except KeyError:
|
except KeyError:
|
||||||
@ -261,47 +276,42 @@ class ReplayBuffer:
|
|||||||
"""
|
"""
|
||||||
if stack_num is None:
|
if stack_num is None:
|
||||||
stack_num = self.stack_num
|
stack_num = self.stack_num
|
||||||
if isinstance(indice, slice):
|
if stack_num == 1: # the most often case
|
||||||
indice = np.arange(
|
if key != 'obs_next' or self._save_s_:
|
||||||
0 if indice.start is None
|
val = self._meta.__dict__[key]
|
||||||
else self._size - indice.start if indice.start < 0
|
try:
|
||||||
else indice.start,
|
return val[indice]
|
||||||
self._size if indice.stop is None
|
except IndexError as e:
|
||||||
else self._size - indice.stop if indice.stop < 0
|
if not (isinstance(val, Batch) and val.is_empty()):
|
||||||
else indice.stop,
|
raise e # val != Batch()
|
||||||
1 if indice.step is None else indice.step)
|
return Batch()
|
||||||
else:
|
indice = self._indices[:self._size][indice]
|
||||||
indice = np.array(indice, copy=True)
|
done = self._meta.__dict__['done']
|
||||||
# set last frame done to True
|
if key == 'obs_next' and not self._save_s_:
|
||||||
last_index = (self._index - 1 + self._size) % self._size
|
indice += 1 - done[indice].astype(np.int)
|
||||||
last_done, self.done[last_index] = self.done[last_index], True
|
|
||||||
if key == 'obs_next' and (not self._save_s_ or self.obs_next is None):
|
|
||||||
indice += 1 - self.done[indice].astype(np.int)
|
|
||||||
indice[indice == self._size] = 0
|
indice[indice == self._size] = 0
|
||||||
key = 'obs'
|
key = 'obs'
|
||||||
val = self._meta.__dict__[key]
|
val = self._meta.__dict__[key]
|
||||||
try:
|
try:
|
||||||
if stack_num > 1:
|
if stack_num == 1:
|
||||||
stack = []
|
return val[indice]
|
||||||
for _ in range(stack_num):
|
stack = []
|
||||||
stack = [val[indice]] + stack
|
for _ in range(stack_num):
|
||||||
pre_indice = np.asarray(indice - 1)
|
stack = [val[indice]] + stack
|
||||||
pre_indice[pre_indice == -1] = self._size - 1
|
pre_indice = np.asarray(indice - 1)
|
||||||
indice = np.asarray(
|
pre_indice[pre_indice == -1] = self._size - 1
|
||||||
pre_indice + self.done[pre_indice].astype(np.int))
|
indice = np.asarray(
|
||||||
indice[indice == self._size] = 0
|
pre_indice + done[pre_indice].astype(np.int))
|
||||||
if isinstance(val, Batch):
|
indice[indice == self._size] = 0
|
||||||
stack = Batch.stack(stack, axis=indice.ndim)
|
if isinstance(val, Batch):
|
||||||
else:
|
stack = Batch.stack(stack, axis=indice.ndim)
|
||||||
stack = np.stack(stack, axis=indice.ndim)
|
|
||||||
else:
|
else:
|
||||||
stack = val[indice]
|
stack = np.stack(stack, axis=indice.ndim)
|
||||||
|
return stack
|
||||||
except IndexError as e:
|
except IndexError as e:
|
||||||
stack = Batch()
|
if not (isinstance(val, Batch) and val.is_empty()):
|
||||||
if not isinstance(val, Batch) or len(val.__dict__) > 0:
|
raise e # val != Batch()
|
||||||
raise e
|
return Batch()
|
||||||
self.done[last_index] = last_done
|
|
||||||
return stack
|
|
||||||
|
|
||||||
def __getitem__(self, index: Union[
|
def __getitem__(self, index: Union[
|
||||||
slice, int, np.integer, np.ndarray]) -> Batch:
|
slice, int, np.integer, np.ndarray]) -> Batch:
|
||||||
@ -380,7 +390,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
|||||||
"""Return self.key"""
|
"""Return self.key"""
|
||||||
if key == 'weight':
|
if key == 'weight':
|
||||||
return self._weight
|
return self._weight
|
||||||
return self._meta.__dict__[key]
|
return super().__getattr__(key)
|
||||||
|
|
||||||
def add(self,
|
def add(self,
|
||||||
obs: Union[dict, np.ndarray],
|
obs: Union[dict, np.ndarray],
|
||||||
|
@ -231,7 +231,8 @@ class BasePolicy(ABC, nn.Module):
|
|||||||
usage is to update the sampling weight in prioritized experience
|
usage is to update the sampling weight in prioritized experience
|
||||||
replay. Check out :ref:`policy_concept` for more information.
|
replay. Check out :ref:`policy_concept` for more information.
|
||||||
"""
|
"""
|
||||||
if isinstance(buffer, PrioritizedReplayBuffer):
|
if isinstance(buffer, PrioritizedReplayBuffer) \
|
||||||
|
and hasattr(batch, 'weight'):
|
||||||
buffer.update_weight(indice, batch.weight)
|
buffer.update_weight(indice, batch.weight)
|
||||||
|
|
||||||
def update(self, batch_size: int, buffer: ReplayBuffer, *args, **kwargs):
|
def update(self, batch_size: int, buffer: ReplayBuffer, *args, **kwargs):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user