buffer update bug fix (#154)
* buffer update bug fix * some fix in buffer update * polish Co-authored-by: n+e <463003665@qq.com>
This commit is contained in:
parent
fe5555d2a1
commit
d09b69e594
@ -1,5 +1,6 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer
|
|
||||||
|
from tianshou.data import Batch, PrioritizedReplayBuffer, ReplayBuffer
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
from env import MyTestEnv
|
from env import MyTestEnv
|
||||||
@ -10,7 +11,6 @@ else: # pytest
|
|||||||
def test_replaybuffer(size=10, bufsize=20):
|
def test_replaybuffer(size=10, bufsize=20):
|
||||||
env = MyTestEnv(size)
|
env = MyTestEnv(size)
|
||||||
buf = ReplayBuffer(bufsize)
|
buf = ReplayBuffer(bufsize)
|
||||||
buf2 = ReplayBuffer(bufsize)
|
|
||||||
obs = env.reset()
|
obs = env.reset()
|
||||||
action_list = [1] * 5 + [0] * 10 + [1] * 10
|
action_list = [1] * 5 + [0] * 10 + [1] * 10
|
||||||
for i, a in enumerate(action_list):
|
for i, a in enumerate(action_list):
|
||||||
@ -22,11 +22,6 @@ def test_replaybuffer(size=10, bufsize=20):
|
|||||||
assert (indice < len(buf)).all()
|
assert (indice < len(buf)).all()
|
||||||
assert (data.obs < size).all()
|
assert (data.obs < size).all()
|
||||||
assert (0 <= data.done).all() and (data.done <= 1).all()
|
assert (0 <= data.done).all() and (data.done <= 1).all()
|
||||||
assert len(buf) > len(buf2)
|
|
||||||
buf2.update(buf)
|
|
||||||
assert len(buf) == len(buf2)
|
|
||||||
assert buf2[0].obs == buf[5].obs
|
|
||||||
assert buf2[-1].obs == buf[4].obs
|
|
||||||
b = ReplayBuffer(size=10)
|
b = ReplayBuffer(size=10)
|
||||||
b.add(1, 1, 1, 'str', 1, {'a': 3, 'b': {'c': 5.0}})
|
b.add(1, 1, 1, 'str', 1, {'a': 3, 'b': {'c': 5.0}})
|
||||||
assert b.obs[0] == 1
|
assert b.obs[0] == 1
|
||||||
@ -104,8 +99,22 @@ def test_priortized_replaybuffer(size=32, bufsize=15):
|
|||||||
buf.weight[indice], np.abs(-data.weight / 2) ** buf._alpha)
|
buf.weight[indice], np.abs(-data.weight / 2) ** buf._alpha)
|
||||||
|
|
||||||
|
|
||||||
|
def test_update():
|
||||||
|
buf1 = ReplayBuffer(4, stack_num=2)
|
||||||
|
buf2 = ReplayBuffer(4, stack_num=2)
|
||||||
|
for i in range(5):
|
||||||
|
buf1.add(obs=np.array([i]), act=float(i), rew=i * i,
|
||||||
|
done=False, info={'incident': 'found'})
|
||||||
|
assert len(buf1) > len(buf2)
|
||||||
|
buf2.update(buf1)
|
||||||
|
assert len(buf1) == len(buf2)
|
||||||
|
assert (buf2[0].obs == buf1[1].obs).all()
|
||||||
|
assert (buf2[-1].obs == buf1[0].obs).all()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_replaybuffer()
|
test_replaybuffer()
|
||||||
test_ignore_obs_next()
|
test_ignore_obs_next()
|
||||||
test_stack()
|
test_stack()
|
||||||
test_priortized_replaybuffer(233333, 200000)
|
test_priortized_replaybuffer(233333, 200000)
|
||||||
|
test_update()
|
||||||
|
@ -157,16 +157,25 @@ class ReplayBuffer:
|
|||||||
value.__dict__[key] = _create_value(inst[key], self._maxsize)
|
value.__dict__[key] = _create_value(inst[key], self._maxsize)
|
||||||
value[self._index] = inst
|
value[self._index] = inst
|
||||||
|
|
||||||
|
def _get_stack_num(self):
|
||||||
|
return self._stack
|
||||||
|
|
||||||
|
def _set_stack_num(self, num):
|
||||||
|
self._stack = num
|
||||||
|
|
||||||
def update(self, buffer: 'ReplayBuffer') -> None:
|
def update(self, buffer: 'ReplayBuffer') -> None:
|
||||||
"""Move the data from the given buffer to self."""
|
"""Move the data from the given buffer to self."""
|
||||||
if len(buffer) == 0:
|
if len(buffer) == 0:
|
||||||
return
|
return
|
||||||
i = begin = buffer._index % len(buffer)
|
i = begin = buffer._index % len(buffer)
|
||||||
|
origin = buffer._get_stack_num()
|
||||||
|
buffer._set_stack_num(0)
|
||||||
while True:
|
while True:
|
||||||
self.add(**buffer[i])
|
self.add(**buffer[i])
|
||||||
i = (i + 1) % len(buffer)
|
i = (i + 1) % len(buffer)
|
||||||
if i == begin:
|
if i == begin:
|
||||||
break
|
break
|
||||||
|
buffer._set_stack_num(origin)
|
||||||
|
|
||||||
def add(self,
|
def add(self,
|
||||||
obs: Union[dict, Batch, np.ndarray],
|
obs: Union[dict, Batch, np.ndarray],
|
||||||
@ -408,7 +417,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
|||||||
replace=self._replace)
|
replace=self._replace)
|
||||||
p = p[indice] # weight of each sample
|
p = p[indice] # weight of each sample
|
||||||
elif batch_size == 0:
|
elif batch_size == 0:
|
||||||
p = np.full(shape=self._size, fill_value=1.0/self._size)
|
p = np.full(shape=self._size, fill_value=1.0 / self._size)
|
||||||
indice = np.concatenate([
|
indice = np.concatenate([
|
||||||
np.arange(self._index, self._size),
|
np.arange(self._index, self._size),
|
||||||
np.arange(0, self._index),
|
np.arange(0, self._index),
|
||||||
|
Loading…
x
Reference in New Issue
Block a user