diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py
index 4c6bc71..393534c 100644
--- a/test/base/test_buffer.py
+++ b/test/base/test_buffer.py
@@ -1,9 +1,11 @@
+import torch
+import pickle
 import pytest
 import numpy as np
 from timeit import timeit
 
-from tianshou.data import Batch, PrioritizedReplayBuffer, \
-    ReplayBuffer, SegmentTree
+from tianshou.data import Batch, SegmentTree, \
+    ReplayBuffer, ListReplayBuffer, PrioritizedReplayBuffer
 
 if __name__ == '__main__':
     from env import MyTestEnv
@@ -39,10 +41,11 @@ def test_replaybuffer(size=10, bufsize=20):
 
 def test_ignore_obs_next(size=10):
     # Issue 82
-    buf = ReplayBuffer(size, ignore_obs_net=True)
+    buf = ReplayBuffer(size, ignore_obs_next=True)
     for i in range(size):
         buf.add(obs={'mask1': np.array([i, 1, 1, 0, 0]),
-                     'mask2': np.array([i + 4, 0, 1, 0, 0])},
+                     'mask2': np.array([i + 4, 0, 1, 0, 0]),
+                     'mask': i},
                 act={'act_id': i,
                      'position_id': i + 3},
                 rew=i,
@@ -55,6 +58,22 @@ def test_ignore_obs_next(size=10):
     assert isinstance(data, Batch)
     assert isinstance(data2, Batch)
     assert np.allclose(indice, orig)
+    assert np.allclose(data.obs_next.mask, data2.obs_next.mask)
+    assert np.allclose(data.obs_next.mask, [0, 2, 3, 3, 5, 6, 6, 8, 9, 9])
+    buf.stack_num = 4
+    data = buf[indice]
+    data2 = buf[indice]
+    assert np.allclose(data.obs_next.mask, data2.obs_next.mask)
+    assert np.allclose(data.obs_next.mask, np.array([
+        [0, 0, 0, 0], [1, 1, 1, 2], [1, 1, 2, 3], [1, 1, 2, 3],
+        [4, 4, 4, 5], [4, 4, 5, 6], [4, 4, 5, 6],
+        [7, 7, 7, 8], [7, 7, 8, 9], [7, 7, 8, 9]]))
+    assert np.allclose(data.info['if'], data2.info['if'])
+    assert np.allclose(data.info['if'], np.array([
+        [0, 0, 0, 0], [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
+        [4, 4, 4, 4], [4, 4, 4, 5], [4, 4, 5, 6],
+        [7, 7, 7, 7], [7, 7, 7, 8], [7, 7, 8, 9]]))
+    assert data.obs_next
 
 
 def test_stack(size=5, bufsize=9, stack_num=4):
@@ -62,7 +81,7 @@ def test_stack(size=5, bufsize=9, stack_num=4):
     buf = ReplayBuffer(bufsize, stack_num=stack_num)
     buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True)
     obs = env.reset(1)
-    for i in range(15):
+    for i in range(16):
         obs_next, rew, done, info = env.step(1)
         buf.add(obs, 1, rew, done, None, info)
         buf2.add(obs, 1, rew, done, None, info)
@@ -73,12 +92,11 @@ def test_stack(size=5, bufsize=9, stack_num=4):
     assert np.allclose(buf.get(indice, 'obs'), np.expand_dims(
         [[1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4],
          [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
-         [3, 3, 3, 3], [3, 3, 3, 4], [1, 1, 1, 1]], axis=-1))
-    print(buf)
+         [1, 2, 3, 4], [4, 4, 4, 4], [1, 1, 1, 1]], axis=-1))
     _, indice = buf2.sample(0)
-    assert indice == [2]
+    assert indice.tolist() == [2, 6]
     _, indice = buf2.sample(1)
-    assert indice.sum() == 2
+    assert indice in [2, 6]
 
 
 def test_priortized_replaybuffer(size=32, bufsize=15):
@@ -107,7 +125,7 @@ def test_update():
     buf2 = ReplayBuffer(4, stack_num=2)
     for i in range(5):
         buf1.add(obs=np.array([i]), act=float(i), rew=i * i,
-                 done=False, info={'incident': 'found'})
+                 done=i % 2 == 0, info={'incident': 'found'})
     assert len(buf1) > len(buf2)
     buf2.update(buf1)
     assert len(buf1) == len(buf2)
@@ -214,10 +232,38 @@ def test_segtree():
     print('tree', timeit(sample_tree, setup=sample_tree, number=1000))
 
 
+def test_pickle():
+    size = 100
+    vbuf = ReplayBuffer(size, stack_num=2)
+    lbuf = ListReplayBuffer()
+    pbuf = PrioritizedReplayBuffer(size, 0.6, 0.4)
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    rew = torch.tensor([1.]).to(device)
+    for i in range(4):
+        vbuf.add(obs=Batch(index=np.array([i])), act=0, rew=rew, done=0)
+    for i in range(3):
+        lbuf.add(obs=Batch(index=np.array([i])), act=1, rew=rew, done=0)
+    for i in range(5):
+        pbuf.add(obs=Batch(index=np.array([i])),
+                 act=2, rew=rew, done=0, weight=np.random.rand())
+    # save & load
+    _vbuf = pickle.loads(pickle.dumps(vbuf))
+    _lbuf = pickle.loads(pickle.dumps(lbuf))
+    _pbuf = pickle.loads(pickle.dumps(pbuf))
+    assert len(_vbuf) == len(vbuf) and np.allclose(_vbuf.act, vbuf.act)
+    assert len(_lbuf) == len(lbuf) and np.allclose(_lbuf.act, lbuf.act)
+    assert len(_pbuf) == len(pbuf) and np.allclose(_pbuf.act, pbuf.act)
+    # make sure the meta var is identical
+    assert _vbuf.stack_num == vbuf.stack_num
+    assert np.allclose(_pbuf.weight[np.arange(len(_pbuf))],
+                       pbuf.weight[np.arange(len(pbuf))])
+
+
 if __name__ == '__main__':
     test_replaybuffer()
     test_ignore_obs_next()
     test_stack()
+    test_pickle()
     test_segtree()
     test_priortized_replaybuffer()
     test_priortized_replaybuffer(233333, 200000)
diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py
index 1d7a80f..24aeb0b 100644
--- a/tianshou/data/buffer.py
+++ b/tianshou/data/buffer.py
@@ -23,7 +23,7 @@ class ReplayBuffer:
     The following code snippet illustrates its usage:
     ::
 
-        >>> import numpy as np
+        >>> import pickle, numpy as np
         >>> from tianshou.data import ReplayBuffer
         >>> buf = ReplayBuffer(size=20)
        >>> for i in range(3):
@@ -35,6 +35,7 @@ class ReplayBuffer:
         >>> # but there are only three valid items, so len(buf) == 3.
         >>> len(buf)
         3
+        >>> pickle.dump(buf, open('buf.pkl', 'wb'))  # save to file "buf.pkl"
         >>> buf2 = ReplayBuffer(size=10)
         >>> for i in range(15):
         ...     buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
@@ -54,6 +55,11 @@ class ReplayBuffer:
         >>> batch_data, indice = buf.sample(batch_size=4)
         >>> batch_data.obs == buf[indice].obs
         array([ True,  True,  True,  True])
+        >>> len(buf)
+        13
+        >>> buf = pickle.load(open('buf.pkl', 'rb'))  # load from "buf.pkl"
+        >>> len(buf)
+        3
 
     :class:`~tianshou.data.ReplayBuffer` also supports frame_stack sampling
     (typically for RNN usage, see issue#19), ignoring storing the next
@@ -119,6 +125,7 @@ class ReplayBuffer:
                  sample_avail: bool = False, **kwargs) -> None:
         super().__init__()
         self._maxsize = size
+        self._indices = np.arange(size)
         self._stack = None
         self.stack_num = stack_num
         self._avail = sample_avail and stack_num > 1
@@ -137,9 +144,18 @@ class ReplayBuffer:
         """Return str(self)."""
         return self.__class__.__name__ + self._meta.__repr__()[5:]
 
-    def __getattr__(self, key: str) -> Union['Batch', Any]:
+    def __getattr__(self, key: str) -> Any:
         """Return self.key"""
-        return self._meta.__dict__[key]
+        try:
+            return self._meta[key]
+        except KeyError as e:
+            raise AttributeError from e
+
+    def __setstate__(self, state):
+        """Unpickling interface. We need it because pickling buffer does not
+        work out-of-the-box (``buffer.__getattr__`` is customized).
+        """
+        self.__dict__.update(state)
 
     def _add_to_buffer(self, name: str, inst: Any) -> None:
         try:
@@ -149,9 +165,8 @@ class ReplayBuffer:
             value = self._meta.__dict__[name]
         if isinstance(inst, np.ndarray) and value.shape[1:] != inst.shape:
             raise ValueError(
-                "Cannot add data to a buffer with different shape, key: "
-                f"{name}, expect shape: {value.shape[1:]}, "
-                f"given shape: {inst.shape}.")
+                "Cannot add data to a buffer with different shape, with key "
+                f"{name}, expect {value.shape[1:]}, given {inst.shape}.")
         try:
             value[self._index] = inst
         except KeyError:
@@ -261,47 +276,42 @@ class ReplayBuffer:
         """
         if stack_num is None:
             stack_num = self.stack_num
-        if isinstance(indice, slice):
-            indice = np.arange(
-                0 if indice.start is None
-                else self._size - indice.start if indice.start < 0
-                else indice.start,
-                self._size if indice.stop is None
-                else self._size - indice.stop if indice.stop < 0
-                else indice.stop,
-                1 if indice.step is None else indice.step)
-        else:
-            indice = np.array(indice, copy=True)
-        # set last frame done to True
-        last_index = (self._index - 1 + self._size) % self._size
-        last_done, self.done[last_index] = self.done[last_index], True
-        if key == 'obs_next' and (not self._save_s_ or self.obs_next is None):
-            indice += 1 - self.done[indice].astype(np.int)
+        if stack_num == 1:  # the most often case
+            if key != 'obs_next' or self._save_s_:
+                val = self._meta.__dict__[key]
+                try:
+                    return val[indice]
+                except IndexError as e:
+                    if not (isinstance(val, Batch) and val.is_empty()):
+                        raise e  # val != Batch()
+                    return Batch()
+        indice = self._indices[:self._size][indice]
+        done = self._meta.__dict__['done']
+        if key == 'obs_next' and not self._save_s_:
+            indice += 1 - done[indice].astype(np.int)
             indice[indice == self._size] = 0
             key = 'obs'
         val = self._meta.__dict__[key]
         try:
-            if stack_num > 1:
-                stack = []
-                for _ in range(stack_num):
-                    stack = [val[indice]] + stack
-                    pre_indice = np.asarray(indice - 1)
-                    pre_indice[pre_indice == -1] = self._size - 1
-                    indice = np.asarray(
-                        pre_indice + self.done[pre_indice].astype(np.int))
-                    indice[indice == self._size] = 0
-                if isinstance(val, Batch):
-                    stack = Batch.stack(stack, axis=indice.ndim)
-                else:
-                    stack = np.stack(stack, axis=indice.ndim)
+            if stack_num == 1:
+                return val[indice]
+            stack = []
+            for _ in range(stack_num):
+                stack = [val[indice]] + stack
+                pre_indice = np.asarray(indice - 1)
+                pre_indice[pre_indice == -1] = self._size - 1
+                indice = np.asarray(
+                    pre_indice + done[pre_indice].astype(np.int))
+                indice[indice == self._size] = 0
+            if isinstance(val, Batch):
+                stack = Batch.stack(stack, axis=indice.ndim)
             else:
-                stack = val[indice]
+                stack = np.stack(stack, axis=indice.ndim)
+            return stack
         except IndexError as e:
-            stack = Batch()
-            if not isinstance(val, Batch) or len(val.__dict__) > 0:
-                raise e
-        self.done[last_index] = last_done
-        return stack
+            if not (isinstance(val, Batch) and val.is_empty()):
+                raise e  # val != Batch()
+            return Batch()
 
     def __getitem__(self, index: Union[
             slice, int, np.integer, np.ndarray]) -> Batch:
@@ -380,7 +390,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
         """Return self.key"""
         if key == 'weight':
             return self._weight
-        return self._meta.__dict__[key]
+        return super().__getattr__(key)
 
     def add(self,
             obs: Union[dict, np.ndarray],
diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py
index a2f545e..01398ca 100644
--- a/tianshou/policy/base.py
+++ b/tianshou/policy/base.py
@@ -231,7 +231,8 @@ class BasePolicy(ABC, nn.Module):
         usage is to update the sampling weight in prioritized experience
         replay. Check out :ref:`policy_concept` for more information.
         """
-        if isinstance(buffer, PrioritizedReplayBuffer):
+        if isinstance(buffer, PrioritizedReplayBuffer) \
+                and hasattr(batch, 'weight'):
             buffer.update_weight(indice, batch.weight)
 
     def update(self, batch_size: int, buffer: ReplayBuffer, *args, **kwargs):
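A minimal usage sketch (not part of the patch) of the buffer pickling support added above, mirroring the docstring example and test_pickle; the buffer size and sample values here are illustrative only:

import pickle
import numpy as np
from tianshou.data import ReplayBuffer

# build a small buffer and round-trip it through pickle
buf = ReplayBuffer(size=20, stack_num=2)
for i in range(5):
    buf.add(obs=np.array([i]), act=i, rew=i, done=i % 2 == 0)

data = pickle.dumps(buf)   # works because __getattr__ now raises AttributeError for missing keys
buf2 = pickle.loads(data)  # __setstate__ restores the saved __dict__

assert len(buf2) == len(buf)
assert buf2.stack_num == buf.stack_num  # meta variables survive the round-trip
assert np.allclose(buf2.act, buf.act)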