Improve Batch (#126)
* make sure the key type of Batch is string, and add unit tests * add is_empty() function and unit tests * enable cat of mixing dict and Batch, just like stack
This commit is contained in:
parent
47e8e2686c
commit
2564e989fb
@ -9,6 +9,10 @@ from tianshou.data import Batch, to_torch
|
|||||||
|
|
||||||
def test_batch():
|
def test_batch():
|
||||||
assert list(Batch()) == []
|
assert list(Batch()) == []
|
||||||
|
assert Batch().is_empty()
|
||||||
|
assert not Batch(a=[1, 2, 3]).is_empty()
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
Batch({1: 2})
|
||||||
batch = Batch(a=[torch.ones(3), torch.ones(3)])
|
batch = Batch(a=[torch.ones(3), torch.ones(3)])
|
||||||
assert torch.allclose(batch.a, torch.ones(2, 3))
|
assert torch.allclose(batch.a, torch.ones(2, 3))
|
||||||
batch = Batch(obs=[0], np=np.zeros([3, 4]))
|
batch = Batch(obs=[0], np=np.zeros([3, 4]))
|
||||||
|
@ -50,6 +50,12 @@ def _create_value(inst: Any, size: int) -> Union[
|
|||||||
return np.array([None for _ in range(size)])
|
return np.array([None for _ in range(size)])
|
||||||
|
|
||||||
|
|
||||||
|
def _assert_type_keys(keys):
|
||||||
|
keys = list(keys)
|
||||||
|
assert all(isinstance(e, str) for e in keys), \
|
||||||
|
f"keys should all be string, but got {keys}"
|
||||||
|
|
||||||
|
|
||||||
class Batch:
|
class Batch:
|
||||||
"""Tianshou provides :class:`~tianshou.data.Batch` as the internal data
|
"""Tianshou provides :class:`~tianshou.data.Batch` as the internal data
|
||||||
structure to pass any kind of data to other methods, for example, a
|
structure to pass any kind of data to other methods, for example, a
|
||||||
@ -247,6 +253,7 @@ class Batch:
|
|||||||
batch_dict = deepcopy(batch_dict)
|
batch_dict = deepcopy(batch_dict)
|
||||||
if batch_dict is not None:
|
if batch_dict is not None:
|
||||||
if isinstance(batch_dict, (dict, Batch)):
|
if isinstance(batch_dict, (dict, Batch)):
|
||||||
|
_assert_type_keys(batch_dict.keys())
|
||||||
for k, v in batch_dict.items():
|
for k, v in batch_dict.items():
|
||||||
if isinstance(v, (list, tuple, np.ndarray)):
|
if isinstance(v, (list, tuple, np.ndarray)):
|
||||||
v_ = None
|
v_ = None
|
||||||
@ -511,12 +518,14 @@ class Batch:
|
|||||||
raise TypeError(s)
|
raise TypeError(s)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def cat(batches: List['Batch']) -> 'Batch':
|
def cat(batches: List[Union[dict, 'Batch']]) -> 'Batch':
|
||||||
"""Concatenate a :class:`~tianshou.data.Batch` object into a single
|
"""Concatenate a list of :class:`~tianshou.data.Batch` object into a single
|
||||||
new batch.
|
new batch.
|
||||||
"""
|
"""
|
||||||
batch = Batch()
|
batch = Batch()
|
||||||
for batch_ in batches:
|
for batch_ in batches:
|
||||||
|
if isinstance(batch_, dict):
|
||||||
|
batch_ = Batch(batch_)
|
||||||
batch.cat_(batch_)
|
batch.cat_(batch_)
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
@ -531,6 +540,7 @@ class Batch:
|
|||||||
keys_shared = set.intersection(*keys_map)
|
keys_shared = set.intersection(*keys_map)
|
||||||
values_shared = [
|
values_shared = [
|
||||||
[e[k] for e in batches] for k in keys_shared]
|
[e[k] for e in batches] for k in keys_shared]
|
||||||
|
_assert_type_keys(keys_shared)
|
||||||
for k, v in zip(keys_shared, values_shared):
|
for k, v in zip(keys_shared, values_shared):
|
||||||
if all(isinstance(e, (dict, Batch)) for e in v):
|
if all(isinstance(e, (dict, Batch)) for e in v):
|
||||||
self.__dict__[k] = Batch.stack(v, axis)
|
self.__dict__[k] = Batch.stack(v, axis)
|
||||||
@ -542,6 +552,7 @@ class Batch:
|
|||||||
v = v.astype(np.object)
|
v = v.astype(np.object)
|
||||||
self.__dict__[k] = v
|
self.__dict__[k] = v
|
||||||
keys_partial = reduce(set.symmetric_difference, keys_map)
|
keys_partial = reduce(set.symmetric_difference, keys_map)
|
||||||
|
_assert_type_keys(keys_partial)
|
||||||
for k in keys_partial:
|
for k in keys_partial:
|
||||||
for i, e in enumerate(batches):
|
for i, e in enumerate(batches):
|
||||||
val = e.get(k, None)
|
val = e.get(k, None)
|
||||||
@ -554,7 +565,7 @@ class Batch:
|
|||||||
self.__dict__[k][i] = val
|
self.__dict__[k][i] = val
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def stack(batches: List['Batch'], axis: int = 0) -> 'Batch':
|
def stack(batches: List[Union[dict, 'Batch']], axis: int = 0) -> 'Batch':
|
||||||
"""Stack a :class:`~tianshou.data.Batch` object into a single new
|
"""Stack a :class:`~tianshou.data.Batch` object into a single new
|
||||||
batch.
|
batch.
|
||||||
"""
|
"""
|
||||||
@ -615,6 +626,9 @@ class Batch:
|
|||||||
raise TypeError("Object of type 'Batch' has no len()")
|
raise TypeError("Object of type 'Batch' has no len()")
|
||||||
return min(r)
|
return min(r)
|
||||||
|
|
||||||
|
def is_empty(self):
|
||||||
|
return len(self.__dict__.keys()) == 0
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def shape(self) -> List[int]:
|
def shape(self) -> List[int]:
|
||||||
"""Return self.shape."""
|
"""Return self.shape."""
|
||||||
|
Loading…
x
Reference in New Issue
Block a user