Improve Batch (#126)

* make sure the key type of Batch is string, and add unit tests

* add is_empty() function and unit tests

* enable cat to accept a mix of dict and Batch, just like stack
This commit is contained in:
youkaichao 2020-07-11 09:44:47 +08:00 committed by n+e
parent 47e8e2686c
commit 2564e989fb
2 changed files with 21 additions and 3 deletions

View File

@ -9,6 +9,10 @@ from tianshou.data import Batch, to_torch
def test_batch():
assert list(Batch()) == []
assert Batch().is_empty()
assert not Batch(a=[1, 2, 3]).is_empty()
with pytest.raises(AssertionError):
Batch({1: 2})
batch = Batch(a=[torch.ones(3), torch.ones(3)])
assert torch.allclose(batch.a, torch.ones(2, 3))
batch = Batch(obs=[0], np=np.zeros([3, 4]))

View File

@ -50,6 +50,12 @@ def _create_value(inst: Any, size: int) -> Union[
return np.array([None for _ in range(size)])
def _assert_type_keys(keys):
    """Assert that every element of ``keys`` is a ``str``.

    :param keys: any iterable of keys; it is materialized into a list first
        so that one-shot iterables can be both checked and reported.
    :raises AssertionError: if at least one key is not a string; the message
        lists all of the keys that were passed in.
    """
    key_list = list(keys)
    # Collect offenders rather than asking "are all of them strings".
    non_string = [key for key in key_list if not isinstance(key, str)]
    assert not non_string, \
        f"keys should all be string, but got {key_list}"
class Batch:
"""Tianshou provides :class:`~tianshou.data.Batch` as the internal data
structure to pass any kind of data to other methods, for example, a
@ -247,6 +253,7 @@ class Batch:
batch_dict = deepcopy(batch_dict)
if batch_dict is not None:
if isinstance(batch_dict, (dict, Batch)):
_assert_type_keys(batch_dict.keys())
for k, v in batch_dict.items():
if isinstance(v, (list, tuple, np.ndarray)):
v_ = None
@ -511,12 +518,14 @@ class Batch:
raise TypeError(s)
@staticmethod
def cat(batches: List[Union[dict, 'Batch']]) -> 'Batch':
    """Concatenate a list of :class:`~tianshou.data.Batch` objects (or plain
    dicts) into a single new batch.

    :param batches: the items to concatenate, in order. Each ``dict`` item
        is converted with ``Batch(item)`` first; any other item is passed
        on unchanged.
    :return: a new :class:`~tianshou.data.Batch`, obtained by starting from
        an empty batch and calling ``cat_`` on it once per item.
    """
    batch = Batch()
    for batch_ in batches:
        if isinstance(batch_, dict):
            # Wrap plain dicts in a Batch before handing them to ``cat_``.
            batch_ = Batch(batch_)
        batch.cat_(batch_)
    return batch
@ -531,6 +540,7 @@ class Batch:
keys_shared = set.intersection(*keys_map)
values_shared = [
[e[k] for e in batches] for k in keys_shared]
_assert_type_keys(keys_shared)
for k, v in zip(keys_shared, values_shared):
if all(isinstance(e, (dict, Batch)) for e in v):
self.__dict__[k] = Batch.stack(v, axis)
@ -542,6 +552,7 @@ class Batch:
v = v.astype(np.object)
self.__dict__[k] = v
keys_partial = reduce(set.symmetric_difference, keys_map)
_assert_type_keys(keys_partial)
for k in keys_partial:
for i, e in enumerate(batches):
val = e.get(k, None)
@ -554,7 +565,7 @@ class Batch:
self.__dict__[k][i] = val
@staticmethod
def stack(batches: List['Batch'], axis: int = 0) -> 'Batch':
def stack(batches: List[Union[dict, 'Batch']], axis: int = 0) -> 'Batch':
"""Stack a :class:`~tianshou.data.Batch` object into a single new
batch.
"""
@ -615,6 +626,9 @@ class Batch:
raise TypeError("Object of type 'Batch' has no len()")
return min(r)
def is_empty(self) -> bool:
    """Return ``True`` if this object's ``__dict__`` holds no keys."""
    # An empty dict is falsy, so there is no need to count its keys.
    return not self.__dict__
@property
def shape(self) -> List[int]:
"""Return self.shape."""