add test_buffer
This commit is contained in:
parent
04557fdb82
commit
6632e47b9d
5
.github/workflows/pytest.yml
vendored
5
.github/workflows/pytest.yml
vendored
@@ -30,7 +30,10 @@ jobs:
|
|||||||
- name: Lint with flake8
|
- name: Lint with flake8
|
||||||
run: |
|
run: |
|
||||||
pip install flake8
|
pip install flake8
|
||||||
./flake8.sh
|
# stop the build if there are Python syntax errors or undefined names
|
||||||
|
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
|
||||||
|
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
|
||||||
|
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
|
||||||
- name: Test with pytest
|
- name: Test with pytest
|
||||||
run: |
|
run: |
|
||||||
pip install pytest pytest-cov
|
pip install pytest pytest-cov
|
||||||
|
1
setup.py
1
setup.py
@@ -32,6 +32,7 @@ setup(
|
|||||||
# that you indicate whether you support Python 2, Python 3 or both.
|
# that you indicate whether you support Python 2, Python 3 or both.
|
||||||
'Programming Language :: Python :: 3.6',
|
'Programming Language :: Python :: 3.6',
|
||||||
'Programming Language :: Python :: 3.7',
|
'Programming Language :: Python :: 3.7',
|
||||||
|
'Programming Language :: Python :: 3.8',
|
||||||
],
|
],
|
||||||
keywords='reinforcement learning platform',
|
keywords='reinforcement learning platform',
|
||||||
# You can just specify the packages manually here if your project is
|
# You can just specify the packages manually here if your project is
|
||||||
|
0
test/__init__.py
Normal file
0
test/__init__.py
Normal file
19
test/test_buffer.py
Normal file
19
test/test_buffer.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
from tianshou.data import ReplayBuffer
|
||||||
|
from test.test_env import MyTestEnv
|
||||||
|
|
||||||
|
|
||||||
|
def test_replaybuffer(bufsize=20):
|
||||||
|
env = MyTestEnv(10)
|
||||||
|
buf = ReplayBuffer(bufsize)
|
||||||
|
obs = env.reset()
|
||||||
|
action_list = [1] * 5 + [0] * 10 + [1] * 9
|
||||||
|
for i, a in enumerate(action_list):
|
||||||
|
obs_next, rew, done, info = env.step(a)
|
||||||
|
buf.add(obs, a, rew, done, obs_next, info)
|
||||||
|
assert len(buf) == min(bufsize, i + 1), print(len(buf), i)
|
||||||
|
indice = buf.sample_indice(4)
|
||||||
|
data = buf.sample(4)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_replaybuffer()
|
@@ -80,6 +80,8 @@ def test_vecenv(verbose=False, size=10, num=8, sleep=0.001):
|
|||||||
e.step([a] * num)
|
e.step([a] * num)
|
||||||
t[i] = time.time() - t[i]
|
t[i] = time.time() - t[i]
|
||||||
print(f'VectorEnv: {t[0]:.6f}s\nSubprocVectorEnv: {t[1]:.6f}s\nRayVectorEnv: {t[2]:.6f}s')
|
print(f'VectorEnv: {t[0]:.6f}s\nSubprocVectorEnv: {t[1]:.6f}s\nRayVectorEnv: {t[2]:.6f}s')
|
||||||
|
for v in venv:
|
||||||
|
v.close()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@@ -1,6 +1,5 @@
|
|||||||
class Batch(object):
|
class Batch(object):
|
||||||
"""Suggested keys: [obs, act, rew, done, obs_next, info]"""
|
"""Suggested keys: [obs, act, rew, done, obs_next, info]"""
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.__dict__.update(kwargs)
|
self.__dict__.update(kwargs)
|
||||||
|
@@ -7,7 +7,7 @@ class ReplayBuffer(object):
|
|||||||
def __init__(self, size):
|
def __init__(self, size):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._maxsize = size
|
self._maxsize = size
|
||||||
self._index = self._size = 0
|
self.reset()
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self._size
|
return self._size
|
||||||
@@ -45,7 +45,7 @@ class ReplayBuffer(object):
|
|||||||
return np.random.choice(self._size, batch_size)
|
return np.random.choice(self._size, batch_size)
|
||||||
|
|
||||||
def sample(self, batch_size):
|
def sample(self, batch_size):
|
||||||
indice = self.sample_index(batch_size)
|
indice = self.sample_indice(batch_size)
|
||||||
return Batch(
|
return Batch(
|
||||||
obs=self.obs[indice],
|
obs=self.obs[indice],
|
||||||
act=self.act[indice],
|
act=self.act[indice],
|
||||||
|
Loading…
x
Reference in New Issue
Block a user