- Refactor code to remove duplication - Enable async simulation for all vector envs - Remove `collector.close` and rename `VectorEnv` to `DummyVectorEnv`. The abstraction of vector envs has changed. Prior to this PR, each vector env was almost independent. After this PR, each env is wrapped into a worker, and vector envs differ only in their worker type. In fact, users can just use `BaseVectorEnv` with different workers; `SubprocVectorEnv` and `ShmemVectorEnv` are kept for backward compatibility. Co-authored-by: n+e <463003665@qq.com> Co-authored-by: magicly <magicly007@gmail.com>
92 lines
2.3 KiB
Python
92 lines
2.3 KiB
Python
import pytest
|
|
import numpy as np
|
|
|
|
from tianshou.data import (ListReplayBuffer, PrioritizedReplayBuffer,
|
|
ReplayBuffer, SegmentTree)
|
|
|
|
|
|
@pytest.fixture(scope="module")
def data():
    """Build the shared inputs for the buffer profiling tests.

    The fixture is module-scoped, so every test in this module receives the
    same dict and the same buffer objects. Mutations made by one test
    (e.g. ``buffer.add`` calls) are visible to tests that run after it.

    Returns:
        dict with keys:
            ``'add_data'`` -- keyword arguments passed to ``buffer.add``;
            ``'buffer'``   -- ReplayBuffer(int(1e3), stack_num=100);
            ``'buffer2'``  -- ReplayBuffer(int(1e4), stack_num=100);
            ``'slice'``    -- a slice object used for ``buffer[...]``;
            ``'indexes'``  -- 3 distinct ints drawn from [0, 1000).
    """
    # Seed numpy's global RNG so every run profiles identical inputs.
    np.random.seed(0)

    observation = {'observable': np.random.rand(100, 100),
                   'hidden': np.random.randint(1000, size=200)}
    extra_info = {'policy': "dqn", 'base': np.arange(10)}
    # The same observation dict is reused as obs_next.
    transition = {'obs': observation, 'rew': 1., 'act': np.random.rand(30),
                  'done': False, 'obs_next': observation, 'info': extra_info}

    small_capacity = int(1e3)
    small_buffer = ReplayBuffer(small_capacity, stack_num=100)
    large_buffer = ReplayBuffer(int(1e4), stack_num=100)
    picked = np.random.choice(small_capacity, size=3, replace=False)

    # NOTE(review): under standard Python slice semantics this slice selects
    # nothing from a 1000-entry sequence (slice.indices(1000) == (0, 0, 2)).
    # Confirm how ReplayBuffer.__getitem__ interprets it before relying on
    # test_getitem_slice timings.
    window = slice(-3000, -1000, 2)

    return {
        'add_data': transition,
        'buffer': small_buffer,
        'buffer2': large_buffer,
        'slice': window,
        'indexes': picked,
    }
|
|
|
|
|
|
def test_init():
    """Profile repeated construction of each buffer type.

    Builds 1e5 instances each of ReplayBuffer, PrioritizedReplayBuffer and
    ListReplayBuffer. Every object is discarded immediately.
    """
    n_rounds = int(1e5)
    # Use an integer capacity for every buffer. The float literal 1e5 was
    # previously passed to ReplayBuffer, unlike PrioritizedReplayBuffer below.
    # A float size is presumably invalid once storage gets allocated.
    capacity = int(1e5)
    for _ in range(n_rounds):
        ReplayBuffer(capacity)
        PrioritizedReplayBuffer(
            size=capacity, alpha=0.5,
            beta=0.5, repeat_sample=True)
        ListReplayBuffer()
|
|
|
|
|
|
def test_add(data):
    """Profile 1e5 consecutive ``add`` calls on the shared ``buffer``.

    The module-scoped buffer is mutated in place, so tests that run
    afterwards see the added entries.
    """
    buf = data['buffer']
    transition = data['add_data']
    for _ in range(100000):
        buf.add(**transition)
|
|
|
|
|
|
def test_update(data):
    """Profile 100 calls of ``buffer2.update(buffer)``."""
    source, target = data['buffer'], data['buffer2']
    for _ in range(100):
        target.update(source)
|
|
|
|
|
|
def test_getitem_slice(data):
    """Profile 1000 slice lookups ``buffer[data['slice']]``."""
    window, buf = data['slice'], data['buffer']
    for _ in range(1000):
        _ = buf[window]  # result discarded; only the lookup is timed
|
|
|
|
|
|
def test_getitem_indexes(data):
    """Profile 100 fancy-index lookups ``buffer[data['indexes']]``."""
    picked, buf = data['indexes'], data['buffer']
    for _ in range(100):
        _ = buf[picked]  # result discarded; only the lookup is timed
|
|
|
|
|
|
def test_get(data):
    """Profile ``buffer.get`` for four fields, 300 rounds.

    Each round queries 'obs', 'rew', 'done' and 'info', in that order, at
    the fixture's ``indexes``.
    """
    picked = data['indexes']
    buf = data['buffer']
    fields = ('obs', 'rew', 'done', 'info')
    for _ in range(300):
        for field in fields:
            buf.get(picked, field)
|
|
|
|
|
|
def test_sample(data):
    """Profile 10 calls of ``buffer.sample(100)``."""
    buf = data['buffer']
    batch_size = int(1e2)
    for _ in range(10):
        buf.sample(batch_size)
|
|
|
|
|
|
def test_segtree(data):
    """Profile ``SegmentTree.get_prefix_sum_idx`` on a 100000-leaf tree.

    Every leaf is filled with a uniform random value. Then 1e5 rounds each
    query 64 random targets scaled by ``tree.reduce()`` (presumably the total
    of all leaves -- confirm). The ``data`` fixture argument is not used by
    this body.
    """
    n_leaves = 100000
    tree = SegmentTree(n_leaves)
    tree[np.arange(n_leaves)] = np.random.rand(n_leaves)

    for _ in range(100000):
        # reduce() is deliberately called per round so its cost is profiled too.
        targets = np.random.rand(64) * tree.reduce()
        tree.get_prefix_sum_idx(targets)
|
|
|
|
|
|
if __name__ == '__main__':
    # Run exactly this module's tests wherever the script is launched from.
    # The previous form relied on the current working directory and on a
    # "-k buffer_profile" match against the module name. Flags:
    #   -s              show captured stdout/stderr
    #   --durations=0   report the duration of every test (the profile)
    #   -v              verbose test names
    pytest.main([__file__, "-s", "--durations=0", "-v"])
|