* add sum_tree.py * add prioritized replay buffer * del sum_tree.py * fix some format issues * fix weight_update bug * simply replace replaybuffer in test_dqn without weight update * weight default set to 1 * fix sampling bug when buffer is not full * rename parameter * fix formula error, add accuracy check * add PrioritizedDQN test * add test_pdqn.py * add update_weight() doc * add ref of prio dqn in readme.md and index.rst * restore test_dqn.py, fix args of test_pdqn.py
38 lines
1.3 KiB
Python
38 lines
1.3 KiB
Python
import numpy as np
|
|
from tianshou.data import PrioritizedReplayBuffer
|
|
|
|
# MyTestEnv lives next to this test. Under pytest (the branch the original
# marked "# pytest") it is imported through the package path; when this file
# is run directly as a script the sibling module name is used instead
# (presumably with test/base as the working directory -- confirm).
if __name__ != '__main__':  # pytest
    from test.base.env import MyTestEnv
else:
    from env import MyTestEnv
|
|
|
|
|
|
def test_replaybuffer(size=32, bufsize=15):
    """Exercise ``PrioritizedReplayBuffer`` add / sample / update_weight.

    Steps a ``MyTestEnv`` through a fixed action sequence and adds each
    transition with a random (possibly negative) priority weight. After every
    insertion it checks that:

    * the normalised weights of the stored entries sum to 1,
    * ``sample(n)`` returns ``n`` items, or the whole buffer when ``n == 0``,
    * the buffer length is ``min(bufsize, steps taken)``.

    It then samples once more and checks that ``update_weight`` stores
    ``|w| ** alpha`` for the sampled indices. It also checks that
    ``_weight_sum`` stays equal to ``weight.sum()``.

    :param size: size argument forwarded to ``MyTestEnv``.
    :param bufsize: capacity of the buffer under test.
    """
    env = MyTestEnv(size)
    # The two 0.5 arguments are presumably the prioritisation exponents
    # (alpha, beta) -- ``buf._alpha`` is read back at the end of the test.
    buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5)
    obs = env.reset()
    action_list = [1] * 5 + [0] * 10 + [1] * 10
    for i, a in enumerate(action_list):
        obs_next, rew, done, info = env.step(a)
        # Random priority, deliberately allowed to be negative.
        buf.add(obs, a, rew, done, obs_next, info, np.random.randn() - 0.5)
        obs = obs_next
        normalised = (buf.weight / buf._weight_sum)[:buf._size]
        assert np.isclose(np.sum(normalised), 1, rtol=1e-12)
        batch_size = len(buf) // 2
        data, _ = buf.sample(batch_size)
        # A batch size of 0 is expected to return the whole buffer.
        assert len(data) == (batch_size or len(buf))
        # A real message: ``print(...)`` as the assert message evaluates to
        # None, so the AssertionError would carry no information.
        assert len(buf) == min(bufsize, i + 1), (
            f"len(buf)={len(buf)}, i={i}")
    assert np.isclose(buf._weight_sum, buf.weight.sum())
    data, indice = buf.sample(len(buf) // 2)
    buf.update_weight(indice, -data.weight / 2)
    # update_weight is expected to store |w| ** alpha, so the negative
    # weights passed above must be folded to their magnitude.
    assert np.allclose(buf.weight[indice],
                       np.power(np.abs(-data.weight / 2), buf._alpha))
    assert np.isclose(buf._weight_sum, buf.weight.sum())
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual run with a far larger environment size and buffer capacity
    # than the pytest defaults (32 and 15).
    env_size, capacity = 233333, 200000
    test_replaybuffer(size=env_size, bufsize=capacity)
    print("pass")
|