buffer update bug fix (#154)
* buffer update bug fix * some fix in buffer update * polish Co-authored-by: n+e <463003665@qq.com>
This commit is contained in:
parent
fe5555d2a1
commit
d09b69e594
@ -1,5 +1,6 @@
|
||||
import numpy as np
|
||||
from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer
|
||||
|
||||
from tianshou.data import Batch, PrioritizedReplayBuffer, ReplayBuffer
|
||||
|
||||
if __name__ == '__main__':
|
||||
from env import MyTestEnv
|
||||
@ -10,7 +11,6 @@ else: # pytest
|
||||
def test_replaybuffer(size=10, bufsize=20):
|
||||
env = MyTestEnv(size)
|
||||
buf = ReplayBuffer(bufsize)
|
||||
buf2 = ReplayBuffer(bufsize)
|
||||
obs = env.reset()
|
||||
action_list = [1] * 5 + [0] * 10 + [1] * 10
|
||||
for i, a in enumerate(action_list):
|
||||
@ -22,11 +22,6 @@ def test_replaybuffer(size=10, bufsize=20):
|
||||
assert (indice < len(buf)).all()
|
||||
assert (data.obs < size).all()
|
||||
assert (0 <= data.done).all() and (data.done <= 1).all()
|
||||
assert len(buf) > len(buf2)
|
||||
buf2.update(buf)
|
||||
assert len(buf) == len(buf2)
|
||||
assert buf2[0].obs == buf[5].obs
|
||||
assert buf2[-1].obs == buf[4].obs
|
||||
b = ReplayBuffer(size=10)
|
||||
b.add(1, 1, 1, 'str', 1, {'a': 3, 'b': {'c': 5.0}})
|
||||
assert b.obs[0] == 1
|
||||
@ -104,8 +99,22 @@ def test_priortized_replaybuffer(size=32, bufsize=15):
|
||||
buf.weight[indice], np.abs(-data.weight / 2) ** buf._alpha)
|
||||
|
||||
|
||||
def test_update():
|
||||
buf1 = ReplayBuffer(4, stack_num=2)
|
||||
buf2 = ReplayBuffer(4, stack_num=2)
|
||||
for i in range(5):
|
||||
buf1.add(obs=np.array([i]), act=float(i), rew=i * i,
|
||||
done=False, info={'incident': 'found'})
|
||||
assert len(buf1) > len(buf2)
|
||||
buf2.update(buf1)
|
||||
assert len(buf1) == len(buf2)
|
||||
assert (buf2[0].obs == buf1[1].obs).all()
|
||||
assert (buf2[-1].obs == buf1[0].obs).all()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_replaybuffer()
|
||||
test_ignore_obs_next()
|
||||
test_stack()
|
||||
test_priortized_replaybuffer(233333, 200000)
|
||||
test_update()
|
||||
|
@ -157,16 +157,25 @@ class ReplayBuffer:
|
||||
value.__dict__[key] = _create_value(inst[key], self._maxsize)
|
||||
value[self._index] = inst
|
||||
|
||||
def _get_stack_num(self):
|
||||
return self._stack
|
||||
|
||||
def _set_stack_num(self, num):
|
||||
self._stack = num
|
||||
|
||||
def update(self, buffer: 'ReplayBuffer') -> None:
|
||||
"""Move the data from the given buffer to self."""
|
||||
if len(buffer) == 0:
|
||||
return
|
||||
i = begin = buffer._index % len(buffer)
|
||||
origin = buffer._get_stack_num()
|
||||
buffer._set_stack_num(0)
|
||||
while True:
|
||||
self.add(**buffer[i])
|
||||
i = (i + 1) % len(buffer)
|
||||
if i == begin:
|
||||
break
|
||||
buffer._set_stack_num(origin)
|
||||
|
||||
def add(self,
|
||||
obs: Union[dict, Batch, np.ndarray],
|
||||
@ -408,7 +417,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
replace=self._replace)
|
||||
p = p[indice] # weight of each sample
|
||||
elif batch_size == 0:
|
||||
p = np.full(shape=self._size, fill_value=1.0/self._size)
|
||||
p = np.full(shape=self._size, fill_value=1.0 / self._size)
|
||||
indice = np.concatenate([
|
||||
np.arange(self._index, self._size),
|
||||
np.arange(0, self._index),
|
||||
|
Loading…
x
Reference in New Issue
Block a user