From cb8551f315e5878895ba51c8e2345922b6e18e51 Mon Sep 17 00:00:00 2001 From: Anas BELFADIL <56280198+BFAnas@users.noreply.github.com> Date: Wed, 9 Aug 2023 19:27:18 +0200 Subject: [PATCH] Fix master branch test issues (#908) --- setup.py | 2 +- test/base/test_batch.py | 4 ++-- test/base/test_buffer.py | 2 +- tianshou/data/batch.py | 5 ++++- tianshou/data/buffer/cached.py | 2 +- tianshou/data/collector.py | 5 +++-- 6 files changed, 12 insertions(+), 8 deletions(-) diff --git a/setup.py b/setup.py index 8ae0050..175a0a6 100644 --- a/setup.py +++ b/setup.py @@ -30,7 +30,7 @@ def get_install_requires() -> str: def get_extras_require() -> str: req = { "dev": [ - "sphinx", + "sphinx<7", "sphinx_rtd_theme", "jinja2", "sphinxcontrib-bibtex", diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 53ee8ff..f4e7770 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -38,6 +38,7 @@ def test_batch(): assert 'a' not in b with pytest.raises(AssertionError): Batch({1: 2}) + batch = Batch(a=[torch.ones(3), torch.ones(3)]) assert Batch(a=[np.zeros((2, 3)), np.zeros((3, 3))]).a.dtype == object with pytest.raises(TypeError): Batch(a=[np.zeros((3, 2)), np.zeros((3, 3))]) @@ -47,7 +48,6 @@ def test_batch(): Batch(a=[torch.zeros((3, 3)), np.zeros((3, 3))]) with pytest.raises(TypeError): Batch(a=[1, np.zeros((3, 3)), torch.zeros((3, 3))]) - batch = Batch(a=[torch.ones(3), torch.ones(3)]) assert torch.allclose(batch.a, torch.ones(2, 3)) batch.cat_(batch) assert torch.allclose(batch.a, torch.ones(4, 3)) @@ -183,7 +183,7 @@ def test_batch(): g1.add_nodes_from(list(range(10))) g2 = nx.Graph() g2.add_nodes_from(list(range(20))) - assert Batch(a=np.array([g1, g2])).a.dtype == object + assert Batch(a=np.array([g1, g2], dtype=object)).a.dtype == object def test_batch_over_batch(): diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index c4d3c6b..203d8e4 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -32,7 +32,7 @@ def test_replaybuffer(size=10, bufsize=20): buf = ReplayBuffer(bufsize) buf.update(buf) assert str(buf) == buf.__class__.__name__ + '()' - obs = env.reset() + obs, _ = env.reset() action_list = [1] * 5 + [0] * 10 + [1] * 10 for i, act in enumerate(action_list): obs_next, rew, terminated, truncated, info = env.step(act) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index e204151..2e676ef 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -58,7 +58,10 @@ def _to_array_with_correct_type(obj: Any) -> np.ndarray: # convert the value to np.ndarray # convert to object obj type if neither bool nor number # raises an exception if array's elements are tensors themselves - obj_array = np.asanyarray(obj) + try: + obj_array = np.asanyarray(obj) + except ValueError: + obj_array = np.asanyarray(obj, dtype=object) if not issubclass(obj_array.dtype.type, (np.bool_, np.number)): obj_array = obj_array.astype(object) if obj_array.dtype == object: diff --git a/tianshou/data/buffer/cached.py b/tianshou/data/buffer/cached.py index 5063750..299ba43 100644 --- a/tianshou/data/buffer/cached.py +++ b/tianshou/data/buffer/cached.py @@ -34,7 +34,7 @@ class CachedReplayBuffer(ReplayBufferManager): max_episode_length: int, ) -> None: assert cached_buffer_num > 0 and max_episode_length > 0 - assert type(main_buffer) == ReplayBuffer + assert isinstance(main_buffer, ReplayBuffer) kwargs = main_buffer.options buffers = [main_buffer] + [ ReplayBuffer(max_episode_length, **kwargs) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 3ff1ab2..0532a25 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -9,6 +9,7 @@ import torch from tianshou.data import ( Batch, CachedReplayBuffer, + PrioritizedReplayBuffer, ReplayBuffer, ReplayBufferManager, VectorReplayBuffer, @@ -89,10 +90,10 @@ class Collector(object): else: # ReplayBuffer or PrioritizedReplayBuffer assert buffer.maxsize > 0 if self.env_num > 1: - if type(buffer) == ReplayBuffer: + if isinstance(buffer, ReplayBuffer): buffer_type = "ReplayBuffer" vector_type = "VectorReplayBuffer" - else: + if isinstance(buffer, PrioritizedReplayBuffer): buffer_type = "PrioritizedReplayBuffer" vector_type = "PrioritizedVectorReplayBuffer" raise TypeError(