Fix master branch test issues (#908)
This commit is contained in:
parent
f8808d236f
commit
cb8551f315
2
setup.py
2
setup.py
@ -30,7 +30,7 @@ def get_install_requires() -> str:
|
|||||||
def get_extras_require() -> str:
|
def get_extras_require() -> str:
|
||||||
req = {
|
req = {
|
||||||
"dev": [
|
"dev": [
|
||||||
"sphinx",
|
"sphinx<7",
|
||||||
"sphinx_rtd_theme",
|
"sphinx_rtd_theme",
|
||||||
"jinja2",
|
"jinja2",
|
||||||
"sphinxcontrib-bibtex",
|
"sphinxcontrib-bibtex",
|
||||||
|
@ -38,6 +38,7 @@ def test_batch():
|
|||||||
assert 'a' not in b
|
assert 'a' not in b
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
Batch({1: 2})
|
Batch({1: 2})
|
||||||
|
batch = Batch(a=[torch.ones(3), torch.ones(3)])
|
||||||
assert Batch(a=[np.zeros((2, 3)), np.zeros((3, 3))]).a.dtype == object
|
assert Batch(a=[np.zeros((2, 3)), np.zeros((3, 3))]).a.dtype == object
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError):
|
||||||
Batch(a=[np.zeros((3, 2)), np.zeros((3, 3))])
|
Batch(a=[np.zeros((3, 2)), np.zeros((3, 3))])
|
||||||
@ -47,7 +48,6 @@ def test_batch():
|
|||||||
Batch(a=[torch.zeros((3, 3)), np.zeros((3, 3))])
|
Batch(a=[torch.zeros((3, 3)), np.zeros((3, 3))])
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError):
|
||||||
Batch(a=[1, np.zeros((3, 3)), torch.zeros((3, 3))])
|
Batch(a=[1, np.zeros((3, 3)), torch.zeros((3, 3))])
|
||||||
batch = Batch(a=[torch.ones(3), torch.ones(3)])
|
|
||||||
assert torch.allclose(batch.a, torch.ones(2, 3))
|
assert torch.allclose(batch.a, torch.ones(2, 3))
|
||||||
batch.cat_(batch)
|
batch.cat_(batch)
|
||||||
assert torch.allclose(batch.a, torch.ones(4, 3))
|
assert torch.allclose(batch.a, torch.ones(4, 3))
|
||||||
@ -183,7 +183,7 @@ def test_batch():
|
|||||||
g1.add_nodes_from(list(range(10)))
|
g1.add_nodes_from(list(range(10)))
|
||||||
g2 = nx.Graph()
|
g2 = nx.Graph()
|
||||||
g2.add_nodes_from(list(range(20)))
|
g2.add_nodes_from(list(range(20)))
|
||||||
assert Batch(a=np.array([g1, g2])).a.dtype == object
|
assert Batch(a=np.array([g1, g2], dtype=object)).a.dtype == object
|
||||||
|
|
||||||
|
|
||||||
def test_batch_over_batch():
|
def test_batch_over_batch():
|
||||||
|
@ -32,7 +32,7 @@ def test_replaybuffer(size=10, bufsize=20):
|
|||||||
buf = ReplayBuffer(bufsize)
|
buf = ReplayBuffer(bufsize)
|
||||||
buf.update(buf)
|
buf.update(buf)
|
||||||
assert str(buf) == buf.__class__.__name__ + '()'
|
assert str(buf) == buf.__class__.__name__ + '()'
|
||||||
obs = env.reset()
|
obs, _ = env.reset()
|
||||||
action_list = [1] * 5 + [0] * 10 + [1] * 10
|
action_list = [1] * 5 + [0] * 10 + [1] * 10
|
||||||
for i, act in enumerate(action_list):
|
for i, act in enumerate(action_list):
|
||||||
obs_next, rew, terminated, truncated, info = env.step(act)
|
obs_next, rew, terminated, truncated, info = env.step(act)
|
||||||
|
@ -58,7 +58,10 @@ def _to_array_with_correct_type(obj: Any) -> np.ndarray:
|
|||||||
# convert the value to np.ndarray
|
# convert the value to np.ndarray
|
||||||
# convert to object obj type if neither bool nor number
|
# convert to object obj type if neither bool nor number
|
||||||
# raises an exception if array's elements are tensors themselves
|
# raises an exception if array's elements are tensors themselves
|
||||||
obj_array = np.asanyarray(obj)
|
try:
|
||||||
|
obj_array = np.asanyarray(obj)
|
||||||
|
except ValueError:
|
||||||
|
obj_array = np.asanyarray(obj, dtype=object)
|
||||||
if not issubclass(obj_array.dtype.type, (np.bool_, np.number)):
|
if not issubclass(obj_array.dtype.type, (np.bool_, np.number)):
|
||||||
obj_array = obj_array.astype(object)
|
obj_array = obj_array.astype(object)
|
||||||
if obj_array.dtype == object:
|
if obj_array.dtype == object:
|
||||||
|
@ -34,7 +34,7 @@ class CachedReplayBuffer(ReplayBufferManager):
|
|||||||
max_episode_length: int,
|
max_episode_length: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
assert cached_buffer_num > 0 and max_episode_length > 0
|
assert cached_buffer_num > 0 and max_episode_length > 0
|
||||||
assert type(main_buffer) == ReplayBuffer
|
assert isinstance(main_buffer, ReplayBuffer)
|
||||||
kwargs = main_buffer.options
|
kwargs = main_buffer.options
|
||||||
buffers = [main_buffer] + [
|
buffers = [main_buffer] + [
|
||||||
ReplayBuffer(max_episode_length, **kwargs)
|
ReplayBuffer(max_episode_length, **kwargs)
|
||||||
|
@ -9,6 +9,7 @@ import torch
|
|||||||
from tianshou.data import (
|
from tianshou.data import (
|
||||||
Batch,
|
Batch,
|
||||||
CachedReplayBuffer,
|
CachedReplayBuffer,
|
||||||
|
PrioritizedReplayBuffer,
|
||||||
ReplayBuffer,
|
ReplayBuffer,
|
||||||
ReplayBufferManager,
|
ReplayBufferManager,
|
||||||
VectorReplayBuffer,
|
VectorReplayBuffer,
|
||||||
@ -89,10 +90,10 @@ class Collector(object):
|
|||||||
else: # ReplayBuffer or PrioritizedReplayBuffer
|
else: # ReplayBuffer or PrioritizedReplayBuffer
|
||||||
assert buffer.maxsize > 0
|
assert buffer.maxsize > 0
|
||||||
if self.env_num > 1:
|
if self.env_num > 1:
|
||||||
if type(buffer) == ReplayBuffer:
|
if isinstance(buffer, ReplayBuffer):
|
||||||
buffer_type = "ReplayBuffer"
|
buffer_type = "ReplayBuffer"
|
||||||
vector_type = "VectorReplayBuffer"
|
vector_type = "VectorReplayBuffer"
|
||||||
else:
|
if isinstance(buffer, PrioritizedReplayBuffer):
|
||||||
buffer_type = "PrioritizedReplayBuffer"
|
buffer_type = "PrioritizedReplayBuffer"
|
||||||
vector_type = "PrioritizedVectorReplayBuffer"
|
vector_type = "PrioritizedVectorReplayBuffer"
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user