Fix master branch test issues (#908)

This commit is contained in:
Anas BELFADIL 2023-08-09 19:27:18 +02:00 committed by GitHub
parent f8808d236f
commit cb8551f315
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 12 additions and 8 deletions

View File

@ -30,7 +30,7 @@ def get_install_requires() -> str:
def get_extras_require() -> str:
req = {
"dev": [
"sphinx",
"sphinx<7",
"sphinx_rtd_theme",
"jinja2",
"sphinxcontrib-bibtex",

View File

@ -38,6 +38,7 @@ def test_batch():
assert 'a' not in b
with pytest.raises(AssertionError):
Batch({1: 2})
batch = Batch(a=[torch.ones(3), torch.ones(3)])
assert Batch(a=[np.zeros((2, 3)), np.zeros((3, 3))]).a.dtype == object
with pytest.raises(TypeError):
Batch(a=[np.zeros((3, 2)), np.zeros((3, 3))])
@ -47,7 +48,6 @@ def test_batch():
Batch(a=[torch.zeros((3, 3)), np.zeros((3, 3))])
with pytest.raises(TypeError):
Batch(a=[1, np.zeros((3, 3)), torch.zeros((3, 3))])
batch = Batch(a=[torch.ones(3), torch.ones(3)])
assert torch.allclose(batch.a, torch.ones(2, 3))
batch.cat_(batch)
assert torch.allclose(batch.a, torch.ones(4, 3))
@ -183,7 +183,7 @@ def test_batch():
g1.add_nodes_from(list(range(10)))
g2 = nx.Graph()
g2.add_nodes_from(list(range(20)))
assert Batch(a=np.array([g1, g2])).a.dtype == object
assert Batch(a=np.array([g1, g2], dtype=object)).a.dtype == object
def test_batch_over_batch():

View File

@ -32,7 +32,7 @@ def test_replaybuffer(size=10, bufsize=20):
buf = ReplayBuffer(bufsize)
buf.update(buf)
assert str(buf) == buf.__class__.__name__ + '()'
obs = env.reset()
obs, _ = env.reset()
action_list = [1] * 5 + [0] * 10 + [1] * 10
for i, act in enumerate(action_list):
obs_next, rew, terminated, truncated, info = env.step(act)

View File

@ -58,7 +58,10 @@ def _to_array_with_correct_type(obj: Any) -> np.ndarray:
# convert the value to np.ndarray
# convert to object obj type if neither bool nor number
# raises an exception if array's elements are tensors themselves
obj_array = np.asanyarray(obj)
try:
obj_array = np.asanyarray(obj)
except ValueError:
obj_array = np.asanyarray(obj, dtype=object)
if not issubclass(obj_array.dtype.type, (np.bool_, np.number)):
obj_array = obj_array.astype(object)
if obj_array.dtype == object:

View File

@ -34,7 +34,7 @@ class CachedReplayBuffer(ReplayBufferManager):
max_episode_length: int,
) -> None:
assert cached_buffer_num > 0 and max_episode_length > 0
assert type(main_buffer) == ReplayBuffer
assert isinstance(main_buffer, ReplayBuffer)
kwargs = main_buffer.options
buffers = [main_buffer] + [
ReplayBuffer(max_episode_length, **kwargs)

View File

@ -9,6 +9,7 @@ import torch
from tianshou.data import (
Batch,
CachedReplayBuffer,
PrioritizedReplayBuffer,
ReplayBuffer,
ReplayBufferManager,
VectorReplayBuffer,
@ -89,10 +90,10 @@ class Collector(object):
else: # ReplayBuffer or PrioritizedReplayBuffer
assert buffer.maxsize > 0
if self.env_num > 1:
if type(buffer) == ReplayBuffer:
if isinstance(buffer, ReplayBuffer):
buffer_type = "ReplayBuffer"
vector_type = "VectorReplayBuffer"
else:
if isinstance(buffer, PrioritizedReplayBuffer):
buffer_type = "PrioritizedReplayBuffer"
vector_type = "PrioritizedVectorReplayBuffer"
raise TypeError(