Fix master branch test issues (#908)
This commit is contained in:
parent
f8808d236f
commit
cb8551f315
2
setup.py
2
setup.py
@ -30,7 +30,7 @@ def get_install_requires() -> str:
|
||||
def get_extras_require() -> str:
|
||||
req = {
|
||||
"dev": [
|
||||
"sphinx",
|
||||
"sphinx<7",
|
||||
"sphinx_rtd_theme",
|
||||
"jinja2",
|
||||
"sphinxcontrib-bibtex",
|
||||
|
@ -38,6 +38,7 @@ def test_batch():
|
||||
assert 'a' not in b
|
||||
with pytest.raises(AssertionError):
|
||||
Batch({1: 2})
|
||||
batch = Batch(a=[torch.ones(3), torch.ones(3)])
|
||||
assert Batch(a=[np.zeros((2, 3)), np.zeros((3, 3))]).a.dtype == object
|
||||
with pytest.raises(TypeError):
|
||||
Batch(a=[np.zeros((3, 2)), np.zeros((3, 3))])
|
||||
@ -47,7 +48,6 @@ def test_batch():
|
||||
Batch(a=[torch.zeros((3, 3)), np.zeros((3, 3))])
|
||||
with pytest.raises(TypeError):
|
||||
Batch(a=[1, np.zeros((3, 3)), torch.zeros((3, 3))])
|
||||
batch = Batch(a=[torch.ones(3), torch.ones(3)])
|
||||
assert torch.allclose(batch.a, torch.ones(2, 3))
|
||||
batch.cat_(batch)
|
||||
assert torch.allclose(batch.a, torch.ones(4, 3))
|
||||
@ -183,7 +183,7 @@ def test_batch():
|
||||
g1.add_nodes_from(list(range(10)))
|
||||
g2 = nx.Graph()
|
||||
g2.add_nodes_from(list(range(20)))
|
||||
assert Batch(a=np.array([g1, g2])).a.dtype == object
|
||||
assert Batch(a=np.array([g1, g2], dtype=object)).a.dtype == object
|
||||
|
||||
|
||||
def test_batch_over_batch():
|
||||
|
@ -32,7 +32,7 @@ def test_replaybuffer(size=10, bufsize=20):
|
||||
buf = ReplayBuffer(bufsize)
|
||||
buf.update(buf)
|
||||
assert str(buf) == buf.__class__.__name__ + '()'
|
||||
obs = env.reset()
|
||||
obs, _ = env.reset()
|
||||
action_list = [1] * 5 + [0] * 10 + [1] * 10
|
||||
for i, act in enumerate(action_list):
|
||||
obs_next, rew, terminated, truncated, info = env.step(act)
|
||||
|
@ -58,7 +58,10 @@ def _to_array_with_correct_type(obj: Any) -> np.ndarray:
|
||||
# convert the value to np.ndarray
|
||||
# convert to object obj type if neither bool nor number
|
||||
# raises an exception if array's elements are tensors themselves
|
||||
obj_array = np.asanyarray(obj)
|
||||
try:
|
||||
obj_array = np.asanyarray(obj)
|
||||
except ValueError:
|
||||
obj_array = np.asanyarray(obj, dtype=object)
|
||||
if not issubclass(obj_array.dtype.type, (np.bool_, np.number)):
|
||||
obj_array = obj_array.astype(object)
|
||||
if obj_array.dtype == object:
|
||||
|
@ -34,7 +34,7 @@ class CachedReplayBuffer(ReplayBufferManager):
|
||||
max_episode_length: int,
|
||||
) -> None:
|
||||
assert cached_buffer_num > 0 and max_episode_length > 0
|
||||
assert type(main_buffer) == ReplayBuffer
|
||||
assert isinstance(main_buffer, ReplayBuffer)
|
||||
kwargs = main_buffer.options
|
||||
buffers = [main_buffer] + [
|
||||
ReplayBuffer(max_episode_length, **kwargs)
|
||||
|
@ -9,6 +9,7 @@ import torch
|
||||
from tianshou.data import (
|
||||
Batch,
|
||||
CachedReplayBuffer,
|
||||
PrioritizedReplayBuffer,
|
||||
ReplayBuffer,
|
||||
ReplayBufferManager,
|
||||
VectorReplayBuffer,
|
||||
@ -89,10 +90,10 @@ class Collector(object):
|
||||
else: # ReplayBuffer or PrioritizedReplayBuffer
|
||||
assert buffer.maxsize > 0
|
||||
if self.env_num > 1:
|
||||
if type(buffer) == ReplayBuffer:
|
||||
if isinstance(buffer, ReplayBuffer):
|
||||
buffer_type = "ReplayBuffer"
|
||||
vector_type = "VectorReplayBuffer"
|
||||
else:
|
||||
if isinstance(buffer, PrioritizedReplayBuffer):
|
||||
buffer_type = "PrioritizedReplayBuffer"
|
||||
vector_type = "PrioritizedVectorReplayBuffer"
|
||||
raise TypeError(
|
||||
|
Loading…
x
Reference in New Issue
Block a user