Fix master branch test issues (#908)

This commit is contained in:
Anas BELFADIL 2023-08-09 19:27:18 +02:00 committed by GitHub
parent f8808d236f
commit cb8551f315
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 12 additions and 8 deletions

View File

@ -30,7 +30,7 @@ def get_install_requires() -> str:
def get_extras_require() -> str: def get_extras_require() -> str:
req = { req = {
"dev": [ "dev": [
"sphinx", "sphinx<7",
"sphinx_rtd_theme", "sphinx_rtd_theme",
"jinja2", "jinja2",
"sphinxcontrib-bibtex", "sphinxcontrib-bibtex",

View File

@ -38,6 +38,7 @@ def test_batch():
assert 'a' not in b assert 'a' not in b
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
Batch({1: 2}) Batch({1: 2})
batch = Batch(a=[torch.ones(3), torch.ones(3)])
assert Batch(a=[np.zeros((2, 3)), np.zeros((3, 3))]).a.dtype == object assert Batch(a=[np.zeros((2, 3)), np.zeros((3, 3))]).a.dtype == object
with pytest.raises(TypeError): with pytest.raises(TypeError):
Batch(a=[np.zeros((3, 2)), np.zeros((3, 3))]) Batch(a=[np.zeros((3, 2)), np.zeros((3, 3))])
@ -47,7 +48,6 @@ def test_batch():
Batch(a=[torch.zeros((3, 3)), np.zeros((3, 3))]) Batch(a=[torch.zeros((3, 3)), np.zeros((3, 3))])
with pytest.raises(TypeError): with pytest.raises(TypeError):
Batch(a=[1, np.zeros((3, 3)), torch.zeros((3, 3))]) Batch(a=[1, np.zeros((3, 3)), torch.zeros((3, 3))])
batch = Batch(a=[torch.ones(3), torch.ones(3)])
assert torch.allclose(batch.a, torch.ones(2, 3)) assert torch.allclose(batch.a, torch.ones(2, 3))
batch.cat_(batch) batch.cat_(batch)
assert torch.allclose(batch.a, torch.ones(4, 3)) assert torch.allclose(batch.a, torch.ones(4, 3))
@ -183,7 +183,7 @@ def test_batch():
g1.add_nodes_from(list(range(10))) g1.add_nodes_from(list(range(10)))
g2 = nx.Graph() g2 = nx.Graph()
g2.add_nodes_from(list(range(20))) g2.add_nodes_from(list(range(20)))
assert Batch(a=np.array([g1, g2])).a.dtype == object assert Batch(a=np.array([g1, g2], dtype=object)).a.dtype == object
def test_batch_over_batch(): def test_batch_over_batch():

View File

@ -32,7 +32,7 @@ def test_replaybuffer(size=10, bufsize=20):
buf = ReplayBuffer(bufsize) buf = ReplayBuffer(bufsize)
buf.update(buf) buf.update(buf)
assert str(buf) == buf.__class__.__name__ + '()' assert str(buf) == buf.__class__.__name__ + '()'
obs = env.reset() obs, _ = env.reset()
action_list = [1] * 5 + [0] * 10 + [1] * 10 action_list = [1] * 5 + [0] * 10 + [1] * 10
for i, act in enumerate(action_list): for i, act in enumerate(action_list):
obs_next, rew, terminated, truncated, info = env.step(act) obs_next, rew, terminated, truncated, info = env.step(act)

View File

@ -58,7 +58,10 @@ def _to_array_with_correct_type(obj: Any) -> np.ndarray:
# convert the value to np.ndarray # convert the value to np.ndarray
# convert to object obj type if neither bool nor number # convert to object obj type if neither bool nor number
# raises an exception if array's elements are tensors themselves # raises an exception if array's elements are tensors themselves
obj_array = np.asanyarray(obj) try:
obj_array = np.asanyarray(obj)
except ValueError:
obj_array = np.asanyarray(obj, dtype=object)
if not issubclass(obj_array.dtype.type, (np.bool_, np.number)): if not issubclass(obj_array.dtype.type, (np.bool_, np.number)):
obj_array = obj_array.astype(object) obj_array = obj_array.astype(object)
if obj_array.dtype == object: if obj_array.dtype == object:

View File

@ -34,7 +34,7 @@ class CachedReplayBuffer(ReplayBufferManager):
max_episode_length: int, max_episode_length: int,
) -> None: ) -> None:
assert cached_buffer_num > 0 and max_episode_length > 0 assert cached_buffer_num > 0 and max_episode_length > 0
assert type(main_buffer) == ReplayBuffer assert isinstance(main_buffer, ReplayBuffer)
kwargs = main_buffer.options kwargs = main_buffer.options
buffers = [main_buffer] + [ buffers = [main_buffer] + [
ReplayBuffer(max_episode_length, **kwargs) ReplayBuffer(max_episode_length, **kwargs)

View File

@ -9,6 +9,7 @@ import torch
from tianshou.data import ( from tianshou.data import (
Batch, Batch,
CachedReplayBuffer, CachedReplayBuffer,
PrioritizedReplayBuffer,
ReplayBuffer, ReplayBuffer,
ReplayBufferManager, ReplayBufferManager,
VectorReplayBuffer, VectorReplayBuffer,
@ -89,10 +90,10 @@ class Collector(object):
else: # ReplayBuffer or PrioritizedReplayBuffer else: # ReplayBuffer or PrioritizedReplayBuffer
assert buffer.maxsize > 0 assert buffer.maxsize > 0
if self.env_num > 1: if self.env_num > 1:
if type(buffer) == ReplayBuffer: if isinstance(buffer, ReplayBuffer):
buffer_type = "ReplayBuffer" buffer_type = "ReplayBuffer"
vector_type = "VectorReplayBuffer" vector_type = "VectorReplayBuffer"
else: if isinstance(buffer, PrioritizedReplayBuffer):
buffer_type = "PrioritizedReplayBuffer" buffer_type = "PrioritizedReplayBuffer"
vector_type = "PrioritizedVectorReplayBuffer" vector_type = "PrioritizedVectorReplayBuffer"
raise TypeError( raise TypeError(