38 lines
1.3 KiB
Python
38 lines
1.3 KiB
Python
|
import numpy as np
|
||
|
from tianshou.data import PrioritizedReplayBuffer
|
||
|
|
||
|
if __name__ == '__main__':
|
||
|
from env import MyTestEnv
|
||
|
else: # pytest
|
||
|
from test.base.env import MyTestEnv
|
||
|
|
||
|
|
||
|
def test_replaybuffer(size=32, bufsize=15):
    """Check PrioritizedReplayBuffer add/sample/update_weight invariants.

    Despite its name, this test drives a ``PrioritizedReplayBuffer``. It fills
    the buffer with ``MyTestEnv`` transitions carrying random priority weights.
    After every insertion it checks that:

    * the normalised weights of the stored entries sum to 1,
    * ``sample(n)`` returns ``n`` items, or the whole buffer when ``n == 0``,
    * the buffer length grows by one per step until it saturates at
      ``bufsize``,
    * the cached ``_weight_sum`` matches the sum of ``weight``.

    Finally it rewrites the weights of a sampled batch via ``update_weight``.
    It then checks that the stored values equal ``|w| ** _alpha`` and that
    the cached sum is still consistent.

    Args:
        size: Size argument forwarded to ``MyTestEnv``.
        bufsize: Capacity of the prioritized replay buffer.
    """
    env = MyTestEnv(size)
    # NOTE(review): the two 0.5 arguments are presumably (alpha, beta);
    # confirm against the PrioritizedReplayBuffer signature.
    buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5)
    obs = env.reset()
    # 25 steps in total: more than the default bufsize (15), so the
    # capacity cap checked below (len(buf) == bufsize) is exercised.
    action_list = [1] * 5 + [0] * 10 + [1] * 10
    for i, act in enumerate(action_list):
        obs_next, rew, done, info = env.step(act)
        # Random (possibly negative) priority weight for this transition.
        buf.add(obs, act, rew, done, obs_next, info, np.random.randn() - 0.5)
        obs = obs_next

        # Only the first _size slots hold real entries. Note that
        # np.isclose's default atol=1e-8 still applies on top of rtol.
        normalized = (buf.weight / buf._weight_sum)[:buf._size]
        assert np.isclose(np.sum(normalized), 1, rtol=1e-12)

        batch_size = len(buf) // 2
        data, _ = buf.sample(batch_size)
        # A request for 0 samples is expected to return the whole buffer.
        assert len(data) == (batch_size or len(buf))

        # A real message: `print(...)` as the assert message evaluates to
        # None, so the failure would carry no information.
        assert len(buf) == min(bufsize, i + 1), (
            f"len(buf)={len(buf)}, step={i}")
        assert np.isclose(buf._weight_sum, buf.weight.sum())

    data, indices = buf.sample(len(buf) // 2)
    buf.update_weight(indices, -data.weight / 2)
    # update_weight is expected to store |w| ** alpha for each index.
    assert np.isclose(buf.weight[indices], np.power(
        np.abs(-data.weight / 2), buf._alpha)).all()
    assert np.isclose(buf._weight_sum, buf.weight.sum())
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Standalone run outside pytest, with a much larger env and buffer.
    test_replaybuffer(size=233333, bufsize=200000)
    print("pass")
|