buffer update bug fix (#154)

* buffer update bug fix

* some fix in buffer update

* polish

Co-authored-by: n+e <463003665@qq.com>
This commit is contained in:
ChenDRAG 2020-07-20 22:12:57 +08:00 committed by GitHub
parent fe5555d2a1
commit d09b69e594
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 26 additions and 8 deletions

View File

@ -1,5 +1,6 @@
import numpy as np
from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer
from tianshou.data import Batch, PrioritizedReplayBuffer, ReplayBuffer
if __name__ == '__main__':
from env import MyTestEnv
@ -10,7 +11,6 @@ else: # pytest
def test_replaybuffer(size=10, bufsize=20):
env = MyTestEnv(size)
buf = ReplayBuffer(bufsize)
buf2 = ReplayBuffer(bufsize)
obs = env.reset()
action_list = [1] * 5 + [0] * 10 + [1] * 10
for i, a in enumerate(action_list):
@ -22,11 +22,6 @@ def test_replaybuffer(size=10, bufsize=20):
assert (indice < len(buf)).all()
assert (data.obs < size).all()
assert (0 <= data.done).all() and (data.done <= 1).all()
assert len(buf) > len(buf2)
buf2.update(buf)
assert len(buf) == len(buf2)
assert buf2[0].obs == buf[5].obs
assert buf2[-1].obs == buf[4].obs
b = ReplayBuffer(size=10)
b.add(1, 1, 1, 'str', 1, {'a': 3, 'b': {'c': 5.0}})
assert b.obs[0] == 1
@ -104,8 +99,22 @@ def test_priortized_replaybuffer(size=32, bufsize=15):
buf.weight[indice], np.abs(-data.weight / 2) ** buf._alpha)
def test_update():
buf1 = ReplayBuffer(4, stack_num=2)
buf2 = ReplayBuffer(4, stack_num=2)
for i in range(5):
buf1.add(obs=np.array([i]), act=float(i), rew=i * i,
done=False, info={'incident': 'found'})
assert len(buf1) > len(buf2)
buf2.update(buf1)
assert len(buf1) == len(buf2)
assert (buf2[0].obs == buf1[1].obs).all()
assert (buf2[-1].obs == buf1[0].obs).all()
if __name__ == '__main__':
test_replaybuffer()
test_ignore_obs_next()
test_stack()
test_priortized_replaybuffer(233333, 200000)
test_update()

View File

@ -157,16 +157,25 @@ class ReplayBuffer:
value.__dict__[key] = _create_value(inst[key], self._maxsize)
value[self._index] = inst
def _get_stack_num(self):
return self._stack
def _set_stack_num(self, num):
self._stack = num
def update(self, buffer: 'ReplayBuffer') -> None:
"""Move the data from the given buffer to self."""
if len(buffer) == 0:
return
i = begin = buffer._index % len(buffer)
origin = buffer._get_stack_num()
buffer._set_stack_num(0)
while True:
self.add(**buffer[i])
i = (i + 1) % len(buffer)
if i == begin:
break
buffer._set_stack_num(origin)
def add(self,
obs: Union[dict, Batch, np.ndarray],
@ -408,7 +417,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
replace=self._replace)
p = p[indice] # weight of each sample
elif batch_size == 0:
p = np.full(shape=self._size, fill_value=1.0/self._size)
p = np.full(shape=self._size, fill_value=1.0 / self._size)
indice = np.concatenate([
np.arange(self._index, self._size),
np.arange(0, self._index),