Improve Batch (#126)
* make sure the key type of Batch is string, and add unit tests * add is_empty() function and unit tests * enable cat of mixing dict and Batch, just like stack
This commit is contained in:
parent
47e8e2686c
commit
2564e989fb
@ -9,6 +9,10 @@ from tianshou.data import Batch, to_torch
|
||||
|
||||
def test_batch():
|
||||
assert list(Batch()) == []
|
||||
assert Batch().is_empty()
|
||||
assert not Batch(a=[1, 2, 3]).is_empty()
|
||||
with pytest.raises(AssertionError):
|
||||
Batch({1: 2})
|
||||
batch = Batch(a=[torch.ones(3), torch.ones(3)])
|
||||
assert torch.allclose(batch.a, torch.ones(2, 3))
|
||||
batch = Batch(obs=[0], np=np.zeros([3, 4]))
|
||||
|
@ -50,6 +50,12 @@ def _create_value(inst: Any, size: int) -> Union[
|
||||
return np.array([None for _ in range(size)])
|
||||
|
||||
|
||||
def _assert_type_keys(keys):
|
||||
keys = list(keys)
|
||||
assert all(isinstance(e, str) for e in keys), \
|
||||
f"keys should all be string, but got {keys}"
|
||||
|
||||
|
||||
class Batch:
|
||||
"""Tianshou provides :class:`~tianshou.data.Batch` as the internal data
|
||||
structure to pass any kind of data to other methods, for example, a
|
||||
@ -247,6 +253,7 @@ class Batch:
|
||||
batch_dict = deepcopy(batch_dict)
|
||||
if batch_dict is not None:
|
||||
if isinstance(batch_dict, (dict, Batch)):
|
||||
_assert_type_keys(batch_dict.keys())
|
||||
for k, v in batch_dict.items():
|
||||
if isinstance(v, (list, tuple, np.ndarray)):
|
||||
v_ = None
|
||||
@ -511,12 +518,14 @@ class Batch:
|
||||
raise TypeError(s)
|
||||
|
||||
@staticmethod
|
||||
def cat(batches: List['Batch']) -> 'Batch':
|
||||
"""Concatenate a :class:`~tianshou.data.Batch` object into a single
|
||||
def cat(batches: List[Union[dict, 'Batch']]) -> 'Batch':
|
||||
"""Concatenate a list of :class:`~tianshou.data.Batch` object into a single
|
||||
new batch.
|
||||
"""
|
||||
batch = Batch()
|
||||
for batch_ in batches:
|
||||
if isinstance(batch_, dict):
|
||||
batch_ = Batch(batch_)
|
||||
batch.cat_(batch_)
|
||||
return batch
|
||||
|
||||
@ -531,6 +540,7 @@ class Batch:
|
||||
keys_shared = set.intersection(*keys_map)
|
||||
values_shared = [
|
||||
[e[k] for e in batches] for k in keys_shared]
|
||||
_assert_type_keys(keys_shared)
|
||||
for k, v in zip(keys_shared, values_shared):
|
||||
if all(isinstance(e, (dict, Batch)) for e in v):
|
||||
self.__dict__[k] = Batch.stack(v, axis)
|
||||
@ -542,6 +552,7 @@ class Batch:
|
||||
v = v.astype(np.object)
|
||||
self.__dict__[k] = v
|
||||
keys_partial = reduce(set.symmetric_difference, keys_map)
|
||||
_assert_type_keys(keys_partial)
|
||||
for k in keys_partial:
|
||||
for i, e in enumerate(batches):
|
||||
val = e.get(k, None)
|
||||
@ -554,7 +565,7 @@ class Batch:
|
||||
self.__dict__[k][i] = val
|
||||
|
||||
@staticmethod
|
||||
def stack(batches: List['Batch'], axis: int = 0) -> 'Batch':
|
||||
def stack(batches: List[Union[dict, 'Batch']], axis: int = 0) -> 'Batch':
|
||||
"""Stack a :class:`~tianshou.data.Batch` object into a single new
|
||||
batch.
|
||||
"""
|
||||
@ -615,6 +626,9 @@ class Batch:
|
||||
raise TypeError("Object of type 'Batch' has no len()")
|
||||
return min(r)
|
||||
|
||||
def is_empty(self):
|
||||
return len(self.__dict__.keys()) == 0
|
||||
|
||||
@property
|
||||
def shape(self) -> List[int]:
|
||||
"""Return self.shape."""
|
||||
|
Loading…
x
Reference in New Issue
Block a user