add test_buffer

This commit is contained in:
Trinkle23897 2020-03-11 17:28:51 +08:00
parent 04557fdb82
commit 6632e47b9d
7 changed files with 28 additions and 4 deletions

View File

@ -30,7 +30,10 @@ jobs:
- name: Lint with flake8
run: |
pip install flake8
./flake8.sh
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest
run: |
pip install pytest pytest-cov

View File

@ -32,6 +32,7 @@ setup(
# that you indicate whether you support Python 2, Python 3 or both.
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
],
keywords='reinforcement learning platform',
# You can just specify the packages manually here if your project is

0
test/__init__.py Normal file
View File

19
test/test_buffer.py Normal file
View File

@ -0,0 +1,19 @@
from tianshou.data import ReplayBuffer
from test.test_env import MyTestEnv
def test_replaybuffer(bufsize=20):
env = MyTestEnv(10)
buf = ReplayBuffer(bufsize)
obs = env.reset()
action_list = [1] * 5 + [0] * 10 + [1] * 9
for i, a in enumerate(action_list):
obs_next, rew, done, info = env.step(a)
buf.add(obs, a, rew, done, obs_next, info)
assert len(buf) == min(bufsize, i + 1), print(len(buf), i)
indice = buf.sample_indice(4)
data = buf.sample(4)
if __name__ == '__main__':
test_replaybuffer()

View File

@ -80,6 +80,8 @@ def test_vecenv(verbose=False, size=10, num=8, sleep=0.001):
e.step([a] * num)
t[i] = time.time() - t[i]
print(f'VectorEnv: {t[0]:.6f}s\nSubprocVectorEnv: {t[1]:.6f}s\nRayVectorEnv: {t[2]:.6f}s')
for v in venv:
v.close()
if __name__ == '__main__':

View File

@ -1,6 +1,5 @@
class Batch(object):
"""Suggested keys: [obs, act, rew, done, obs_next, info]"""
def __init__(self, **kwargs):
super().__init__()
self.__dict__.update(kwargs)

View File

@ -7,7 +7,7 @@ class ReplayBuffer(object):
def __init__(self, size):
super().__init__()
self._maxsize = size
self._index = self._size = 0
self.reset()
def __len__(self):
return self._size
@ -45,7 +45,7 @@ class ReplayBuffer(object):
return np.random.choice(self._size, batch_size)
def sample(self, batch_size):
indice = self.sample_index(batch_size)
indice = self.sample_indice(batch_size)
return Batch(
obs=self.obs[indice],
act=self.act[indice],