add test_buffer
This commit is contained in:
parent
04557fdb82
commit
6632e47b9d
5
.github/workflows/pytest.yml
vendored
5
.github/workflows/pytest.yml
vendored
@ -30,7 +30,10 @@ jobs:
|
||||
- name: Lint with flake8
|
||||
run: |
|
||||
pip install flake8
|
||||
./flake8.sh
|
||||
# stop the build if there are Python syntax errors or undefined names
|
||||
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
|
||||
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
|
||||
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
pip install pytest pytest-cov
|
||||
|
1
setup.py
1
setup.py
@ -32,6 +32,7 @@ setup(
|
||||
# that you indicate whether you support Python 2, Python 3 or both.
|
||||
'Programming Language :: Python :: 3.6',
|
||||
'Programming Language :: Python :: 3.7',
|
||||
'Programming Language :: Python :: 3.8',
|
||||
],
|
||||
keywords='reinforcement learning platform',
|
||||
# You can just specify the packages manually here if your project is
|
||||
|
0
test/__init__.py
Normal file
0
test/__init__.py
Normal file
19
test/test_buffer.py
Normal file
19
test/test_buffer.py
Normal file
@ -0,0 +1,19 @@
|
||||
from tianshou.data import ReplayBuffer
|
||||
from test.test_env import MyTestEnv
|
||||
|
||||
|
||||
def test_replaybuffer(bufsize=20):
|
||||
env = MyTestEnv(10)
|
||||
buf = ReplayBuffer(bufsize)
|
||||
obs = env.reset()
|
||||
action_list = [1] * 5 + [0] * 10 + [1] * 9
|
||||
for i, a in enumerate(action_list):
|
||||
obs_next, rew, done, info = env.step(a)
|
||||
buf.add(obs, a, rew, done, obs_next, info)
|
||||
assert len(buf) == min(bufsize, i + 1), print(len(buf), i)
|
||||
indice = buf.sample_indice(4)
|
||||
data = buf.sample(4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_replaybuffer()
|
@ -80,6 +80,8 @@ def test_vecenv(verbose=False, size=10, num=8, sleep=0.001):
|
||||
e.step([a] * num)
|
||||
t[i] = time.time() - t[i]
|
||||
print(f'VectorEnv: {t[0]:.6f}s\nSubprocVectorEnv: {t[1]:.6f}s\nRayVectorEnv: {t[2]:.6f}s')
|
||||
for v in venv:
|
||||
v.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -1,6 +1,5 @@
|
||||
class Batch(object):
|
||||
"""Suggested keys: [obs, act, rew, done, obs_next, info]"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__()
|
||||
self.__dict__.update(kwargs)
|
||||
|
@ -7,7 +7,7 @@ class ReplayBuffer(object):
|
||||
def __init__(self, size):
|
||||
super().__init__()
|
||||
self._maxsize = size
|
||||
self._index = self._size = 0
|
||||
self.reset()
|
||||
|
||||
def __len__(self):
|
||||
return self._size
|
||||
@ -45,7 +45,7 @@ class ReplayBuffer(object):
|
||||
return np.random.choice(self._size, batch_size)
|
||||
|
||||
def sample(self, batch_size):
|
||||
indice = self.sample_index(batch_size)
|
||||
indice = self.sample_indice(batch_size)
|
||||
return Batch(
|
||||
obs=self.obs[indice],
|
||||
act=self.act[indice],
|
||||
|
Loading…
x
Reference in New Issue
Block a user