diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml
index 91789db..1715738 100644
--- a/.github/workflows/pytest.yml
+++ b/.github/workflows/pytest.yml
@@ -30,7 +30,10 @@ jobs:
     - name: Lint with flake8
       run: |
         pip install flake8
-        ./flake8.sh
+        # stop the build if there are Python syntax errors or undefined names
+        flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
+        # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
+        flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
     - name: Test with pytest
       run: |
         pip install pytest pytest-cov
diff --git a/setup.py b/setup.py
index 96f325f..87f41ce 100644
--- a/setup.py
+++ b/setup.py
@@ -32,6 +32,7 @@ setup(
         # that you indicate whether you support Python 2, Python 3 or both.
         'Programming Language :: Python :: 3.6',
         'Programming Language :: Python :: 3.7',
+        'Programming Language :: Python :: 3.8',
     ],
     keywords='reinforcement learning platform',
     # You can just specify the packages manually here if your project is
diff --git a/test/__init__.py b/test/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/test/test_buffer.py b/test/test_buffer.py
new file mode 100644
index 0000000..ea6c9b2
--- /dev/null
+++ b/test/test_buffer.py
@@ -0,0 +1,26 @@
+from tianshou.data import ReplayBuffer
+from test.test_env import MyTestEnv
+
+
+def test_replaybuffer(bufsize=20):
+    """Fill a ReplayBuffer step by step and check its reported length."""
+    env = MyTestEnv(10)
+    buf = ReplayBuffer(bufsize)
+    obs = env.reset()
+    action_list = [1] * 5 + [0] * 10 + [1] * 9
+    for i, a in enumerate(action_list):
+        obs_next, rew, done, info = env.step(a)
+        buf.add(obs, a, rew, done, obs_next, info)
+        # Carry the observation forward so the next stored transition
+        # starts from the state the environment actually reached.
+        obs = obs_next
+        expected = min(bufsize, i + 1)
+        assert len(buf) == expected, \
+            f'step {i}: len(buf) == {len(buf)}, expected {expected}'
+    # Smoke-test both sampling entry points; the results are not inspected.
+    buf.sample_indice(4)
+    buf.sample(4)
+
+
+if __name__ == '__main__':
+    test_replaybuffer()
diff --git a/test/test_env.py b/test/test_env.py
index 1c9cea1..67912e0 100644
--- a/test/test_env.py
+++ b/test/test_env.py
@@ -80,6 +80,8 @@ def test_vecenv(verbose=False, size=10, num=8, sleep=0.001):
                 e.step([a] * num)
             t[i] = time.time() - t[i]
         print(f'VectorEnv: {t[0]:.6f}s\nSubprocVectorEnv: {t[1]:.6f}s\nRayVectorEnv: {t[2]:.6f}s')
+    for v in venv:
+        v.close()
 
 
 if __name__ == '__main__':
diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py
index 937afa6..ee16eb0 100644
--- a/tianshou/data/batch.py
+++ b/tianshou/data/batch.py
@@ -1,6 +1,5 @@
 class Batch(object):
     """Suggested keys: [obs, act, rew, done, obs_next, info]"""
-
     def __init__(self, **kwargs):
         super().__init__()
         self.__dict__.update(kwargs)
diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py
index d294ae7..20e66d3 100644
--- a/tianshou/data/buffer.py
+++ b/tianshou/data/buffer.py
@@ -7,7 +7,7 @@ class ReplayBuffer(object):
     def __init__(self, size):
         super().__init__()
         self._maxsize = size
-        self._index = self._size = 0
+        self.reset()
 
     def __len__(self):
         return self._size
@@ -45,7 +45,7 @@ class ReplayBuffer(object):
         return np.random.choice(self._size, batch_size)
 
     def sample(self, batch_size):
-        indice = self.sample_index(batch_size)
+        indice = self.sample_indice(batch_size)
         return Batch(
             obs=self.obs[indice],
             act=self.act[indice],