diff --git a/test/base/test_batch.py b/test/base/test_batch.py index e2390de..a9f2cdd 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -9,6 +9,10 @@ from tianshou.data import Batch, to_torch def test_batch(): assert list(Batch()) == [] + assert Batch().is_empty() + assert not Batch(a=[1, 2, 3]).is_empty() + with pytest.raises(AssertionError): + Batch({1: 2}) batch = Batch(a=[torch.ones(3), torch.ones(3)]) assert torch.allclose(batch.a, torch.ones(2, 3)) batch = Batch(obs=[0], np=np.zeros([3, 4])) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index e14b98f..6b10517 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -50,6 +50,12 @@ def _create_value(inst: Any, size: int) -> Union[ return np.array([None for _ in range(size)]) +def _assert_type_keys(keys): + keys = list(keys) + assert all(isinstance(e, str) for e in keys), \ + f"keys should all be string, but got {keys}" + + class Batch: """Tianshou provides :class:`~tianshou.data.Batch` as the internal data structure to pass any kind of data to other methods, for example, a @@ -247,6 +253,7 @@ class Batch: batch_dict = deepcopy(batch_dict) if batch_dict is not None: if isinstance(batch_dict, (dict, Batch)): + _assert_type_keys(batch_dict.keys()) for k, v in batch_dict.items(): if isinstance(v, (list, tuple, np.ndarray)): v_ = None @@ -511,12 +518,14 @@ class Batch: raise TypeError(s) @staticmethod - def cat(batches: List['Batch']) -> 'Batch': - """Concatenate a :class:`~tianshou.data.Batch` object into a single + def cat(batches: List[Union[dict, 'Batch']]) -> 'Batch': + """Concatenate a list of :class:`~tianshou.data.Batch` object into a single new batch. """ batch = Batch() for batch_ in batches: + if isinstance(batch_, dict): + batch_ = Batch(batch_) batch.cat_(batch_) return batch @@ -531,6 +540,7 @@ class Batch: keys_shared = set.intersection(*keys_map) values_shared = [ [e[k] for e in batches] for k in keys_shared] + _assert_type_keys(keys_shared) for k, v in zip(keys_shared, values_shared): if all(isinstance(e, (dict, Batch)) for e in v): self.__dict__[k] = Batch.stack(v, axis) @@ -542,6 +552,7 @@ class Batch: v = v.astype(np.object) self.__dict__[k] = v keys_partial = reduce(set.symmetric_difference, keys_map) + _assert_type_keys(keys_partial) for k in keys_partial: for i, e in enumerate(batches): val = e.get(k, None) @@ -554,7 +565,7 @@ class Batch: self.__dict__[k][i] = val @staticmethod - def stack(batches: List['Batch'], axis: int = 0) -> 'Batch': + def stack(batches: List[Union[dict, 'Batch']], axis: int = 0) -> 'Batch': """Stack a :class:`~tianshou.data.Batch` object into a single new batch. """ @@ -615,6 +626,9 @@ class Batch: raise TypeError("Object of type 'Batch' has no len()") return min(r) + def is_empty(self): + return len(self.__dict__.keys()) == 0 + @property def shape(self) -> List[int]: """Return self.shape."""