diff --git a/setup.py b/setup.py
index 4b22472..f8736fa 100644
--- a/setup.py
+++ b/setup.py
@@ -47,7 +47,7 @@ setup(
     install_requires=[
         "gym>=0.15.4",
         "tqdm",
-        "numpy!=1.16.0",  # https://github.com/numpy/numpy/issues/12793
+        "numpy!=1.16.0,<1.20.0",  # https://github.com/numpy/numpy/issues/12793
         "tensorboard",
         "torch>=1.4.0",
         "numba>=0.51.0",
diff --git a/test/base/test_batch.py b/test/base/test_batch.py
index 650d560..4553edf 100644
--- a/test/base/test_batch.py
+++ b/test/base/test_batch.py
@@ -19,6 +19,8 @@ def test_batch():
     assert not Batch(a=np.float64(1.0)).is_empty()
     assert len(Batch(a=[1, 2, 3], b={'c': {}})) == 3
     assert not Batch(a=[1, 2, 3]).is_empty()
+    b = Batch({'a': [4, 4], 'b': [5, 5]}, c=[None, None])
+    assert b.c.dtype == np.object
     b = Batch()
     b.update()
     assert b.is_empty()
@@ -143,8 +145,10 @@ def test_batch():
     assert batch3.a.d.e[0] == 4.0
     batch3.a.d[0] = Batch(f=5.0)
     assert batch3.a.d.f[0] == 5.0
-    with pytest.raises(KeyError):
+    with pytest.raises(ValueError):
         batch3.a.d[0] = Batch(f=5.0, g=0.0)
+    with pytest.raises(ValueError):
+        batch3[0] = Batch(a={"c": 2, "e": 1})
     # auto convert
     batch4 = Batch(a=np.array(['a', 'b']))
     assert batch4.a.dtype == np.object  # auto convert to np.object
@@ -333,6 +337,12 @@ def test_batch_cat_and_stack():
     assert torch.allclose(test.b, ans.b)
     assert np.allclose(test.common.c, ans.common.c)
 
+    # test with illegal input format
+    with pytest.raises(ValueError):
+        Batch.cat([[Batch(a=1)], [Batch(a=1)]])
+    with pytest.raises(ValueError):
+        Batch.stack([[Batch(a=1)], [Batch(a=1)]])
+
     # exceptions
     assert Batch.cat([]).is_empty()
     assert Batch.stack([]).is_empty()
diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py
index 04b1928..fba2007 100644
--- a/test/base/test_buffer.py
+++ b/test/base/test_buffer.py
@@ -621,9 +621,9 @@ def test_multibuf_hdf5():
             'done': i % 3 == 2,
             'info': {"number": {"n": i, "t": info_t}, 'extra': None},
         }
-        buffers["vector"].add(**Batch.cat([[kwargs], [kwargs], [kwargs]]),
+        buffers["vector"].add(**Batch.stack([kwargs, kwargs, kwargs]),
                               buffer_ids=[0, 1, 2])
-        buffers["cached"].add(**Batch.cat([[kwargs], [kwargs], [kwargs]]),
+        buffers["cached"].add(**Batch.stack([kwargs, kwargs, kwargs]),
                               cached_buffer_ids=[0, 1, 2])
 
     # save
@@ -657,7 +657,7 @@ def test_multibuf_hdf5():
             'done': False,
             'info': {"number": {"n": i}, 'Timelimit.truncate': True},
         }
-        buffers[k].add(**Batch.cat([[kwargs], [kwargs], [kwargs], [kwargs]]))
+        buffers[k].add(**Batch.stack([kwargs, kwargs, kwargs, kwargs]))
     act = np.zeros(buffers[k].maxsize)
     if k == "vector":
         act[np.arange(5)] = np.array([0, 1, 2, 3, 5])
diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py
index fe35160..4f15622 100644
--- a/tianshou/data/batch.py
+++ b/tianshou/data/batch.py
@@ -8,12 +8,6 @@ from collections.abc import Collection
 from typing import Any, List, Dict, Union, Iterator, Optional, Iterable, \
     Sequence
 
-# Disable pickle warning related to torch, since it has been removed
-# on torch master branch. See Pull Request #39003 for details:
-# https://github.com/pytorch/pytorch/pull/39003
-warnings.filterwarnings(
-    "ignore", message="pickle support for Storage will be removed in 1.5.")
-
 
 def _is_batch_set(data: Any) -> bool:
     # Batch set is a list/tuple of dict/Batch objects,
@@ -91,6 +85,9 @@ def _create_value(
     has_shape = isinstance(inst, (np.ndarray, torch.Tensor))
     is_scalar = _is_scalar(inst)
     if not stack and is_scalar:
+        # _create_value(Batch(a={}, b=[1, 2, 3]), 10, False) will fail here
+        if isinstance(inst, Batch) and inst.is_empty(recurse=True):
+            return inst
         # should never hit since it has already checked in Batch.cat_
         # here we do not consider scalar types, following the behavior of numpy
         # which does not support concatenation of zero-dimensional arrays
@@ -257,7 +254,7 @@ class Batch:
             raise ValueError("Batch does not supported tensor assignment. "
                              "Use a compatible Batch or dict instead.")
         if not set(value.keys()).issubset(self.__dict__.keys()):
-            raise KeyError(
+            raise ValueError(
                 "Creating keys is not supported by item assignment.")
         for key, val in self.items():
             try:
@@ -449,12 +446,21 @@ class Batch:
         """Concatenate a list of (or one) Batch objects into current batch."""
         if isinstance(batches, Batch):
             batches = [batches]
-        if len(batches) == 0:
+        # check input format
+        batch_list = []
+        for b in batches:
+            if isinstance(b, dict):
+                if len(b) > 0:
+                    batch_list.append(Batch(b))
+            elif isinstance(b, Batch):
+                # x.is_empty() means that x is Batch() and should be ignored
+                if not b.is_empty():
+                    batch_list.append(b)
+            else:
+                raise ValueError(f"Cannot concatenate {type(b)} in Batch.cat_")
+        if len(batch_list) == 0:
             return
-        batches = [x if isinstance(x, Batch) else Batch(x) for x in batches]
-
-        # x.is_empty() means that x is Batch() and should be ignored
-        batches = [x for x in batches if not x.is_empty()]
+        batches = batch_list
         try:
             # x.is_empty(recurse=True) here means x is a nested empty batch
             # like Batch(a=Batch), and we have to treat it as length zero and
@@ -496,9 +502,22 @@ class Batch:
         self, batches: Sequence[Union[dict, "Batch"]], axis: int = 0
     ) -> None:
         """Stack a list of Batch object into current batch."""
-        if len(batches) == 0:
+        # check input format
+        batch_list = []
+        for b in batches:
+            if isinstance(b, dict):
+                if len(b) > 0:
+                    batch_list.append(Batch(b))
+            elif isinstance(b, Batch):
+                # x.is_empty() means that x is Batch() and should be ignored
+                if not b.is_empty():
+                    batch_list.append(b)
+            else:
+                raise ValueError(
+                    f"Cannot concatenate {type(b)} in Batch.stack_")
+        if len(batch_list) == 0:
             return
-        batches = [x if isinstance(x, Batch) else Batch(x) for x in batches]
+        batches = batch_list
         if not self.is_empty():
             batches = [self] + batches
         # collect non-empty keys
diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py
index 4c1ea14..7bbf099 100644
--- a/tianshou/data/buffer.py
+++ b/tianshou/data/buffer.py
@@ -174,7 +174,7 @@ class ReplayBuffer:
            )
            try:
                value[self._index] = inst
-            except KeyError:  # inst is a dict/Batch
+            except ValueError:  # inst is a dict/Batch
                for key in set(inst.keys()).difference(value.keys()):
                    self._buffer_allocator([name, key], inst[key])
                self._meta[name][self._index] = inst
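
For illustration, a minimal usage sketch of the behavior enforced by these changes (it mirrors the new test cases above and is not part of the patch itself):

import numpy as np
from tianshou.data import Batch

# Stacking a sequence of dicts/Batches still works; dicts are converted to Batch.
b = Batch.stack([{"a": 1}, {"a": 2}])  # b.a == np.array([1, 2])

# Nested lists are now rejected with ValueError instead of being silently coerced.
try:
    Batch.cat([[Batch(a=1)], [Batch(a=1)]])
except ValueError:
    pass

# Item assignment that would create new keys raises ValueError (previously KeyError).
b2 = Batch(a=np.zeros(3))
try:
    b2[0] = Batch(a=1.0, c=2.0)  # 'c' is not an existing key of b2
except ValueError:
    pass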