Fix padding of inconsistent keys with Batch.stack and Batch.cat (#130)

* re-implement Batch.stack and add test cases

* add doc for Batch.stack

* reuse _create_value and refactor stack_ & cat_

* fix pep8

* fix docs

* raise exception for stacking with partial keys and axis!=0

* minor fix

* minor fix

Co-authored-by: Trinkle23897 <463003665@qq.com>
youkaichao 2020-07-12 23:45:42 +08:00 committed by n+e
parent affeec13de
commit 5599a6d1a6
2 changed files with 82 additions and 45 deletions
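
In short, after this change both Batch.cat and Batch.stack zero-pad keys that only some of the input batches carry, instead of failing. A minimal sketch of the new stack behavior, mirroring the tests added below:

    import numpy as np
    from tianshou.data import Batch

    a = Batch(a=1, b=2, c=3)
    b = Batch(a=4, b=5, d=6)
    c = Batch(c=7, b=6, d=9)
    d = Batch.stack([a, b, c])
    # keys missing from a given batch are filled with zeros:
    # d.a == [1, 4, 0], d.c == [3, 0, 7], d.d == [0, 6, 9]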


@@ -166,7 +166,7 @@ def test_batch_cat_and_stack():
     assert isinstance(b12_stack.a.d.e, np.ndarray)
     assert b12_stack.a.d.e.ndim == 2
-    # test batch with incompatible keys
+    # test cat with incompatible keys
     b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5)))
     b2 = Batch(b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5)))
     test = Batch.cat([b1, b2])
@@ -177,6 +177,7 @@ def test_batch_cat_and_stack():
     assert torch.allclose(test.b, ans.b)
     assert np.allclose(test.common.c, ans.common.c)
+    # test stack with compatible keys
     b3 = Batch(a=np.zeros((3, 4)),
                b=torch.ones((2, 5)),
                c=Batch(d=[[1], [2]]))
@@ -194,6 +195,26 @@ def test_batch_cat_and_stack():
     assert b5.b.d[0] == b5_dict[0]['b']['d']
     assert b5.b.d[1] == 0.0
+    # test stack with incompatible keys
+    a = Batch(a=1, b=2, c=3)
+    b = Batch(a=4, b=5, d=6)
+    c = Batch(c=7, b=6, d=9)
+    d = Batch.stack([a, b, c])
+    assert np.allclose(d.a, [1, 4, 0])
+    assert np.allclose(d.b, [2, 5, 6])
+    assert np.allclose(d.c, [3, 0, 7])
+    assert np.allclose(d.d, [0, 6, 9])
+    b1 = Batch(a=np.random.rand(4, 4), common=Batch(c=np.random.rand(4, 5)))
+    b2 = Batch(b=torch.rand(4, 6), common=Batch(c=np.random.rand(4, 5)))
+    test = Batch.stack([b1, b2])
+    ans = Batch(a=np.stack([b1.a, np.zeros((4, 4))]),
+                b=torch.stack([torch.zeros(4, 6), b2.b]),
+                common=Batch(c=np.stack([b1.common.c, b2.common.c])))
+    assert np.allclose(test.a, ans.a)
+    assert torch.allclose(test.b, ans.b)
+    assert np.allclose(test.common.c, ans.common.c)


 def test_batch_over_batch_to_torch():
     batch = Batch(


@@ -3,7 +3,6 @@ import pprint
 import warnings
 import numpy as np
 from copy import deepcopy
-from functools import reduce
 from numbers import Number
 from typing import Any, List, Tuple, Union, Iterator, Optional
@@ -24,28 +23,45 @@ def _is_batch_set(data: Any) -> bool:
     return False


-def _create_value(inst: Any, size: int) -> Union[
+def _create_value(inst: Any, size: int, stack=True) -> Union[
         'Batch', np.ndarray, torch.Tensor]:
+    """
+    :param bool stack: whether to stack or to concatenate. E.g. if inst has
+        shape (3, 5) and size = 10, stack=True returns an np.ndarray with
+        shape (10, 3, 5), otherwise (10, 5).
+    """
+    has_shape = isinstance(inst, (np.ndarray, torch.Tensor))
+    is_scalar = \
+        isinstance(inst, Number) or \
+        issubclass(inst.__class__, np.generic) or \
+        (has_shape and not inst.shape)
+    if not stack and is_scalar:
+        # here we do not consider scalar types, following the behavior of
+        # numpy which does not support concatenation of zero-dimensional
+        # arrays (scalars)
+        raise TypeError(f"cannot cat {inst} which is a scalar")
+    if has_shape:
+        shape = (size, *inst.shape) if stack else (size, *inst.shape[1:])
     if isinstance(inst, np.ndarray):
         if issubclass(inst.dtype.type, (np.bool_, np.number)):
             target_type = inst.dtype.type
         else:
             target_type = np.object
-        return np.full((size, *inst.shape),
+        return np.full(shape,
                        fill_value=None if target_type == np.object else 0,
                        dtype=target_type)
     elif isinstance(inst, torch.Tensor):
-        return torch.full((size, *inst.shape),
+        return torch.full(shape,
                           fill_value=0,
                           device=inst.device,
                           dtype=inst.dtype)
     elif isinstance(inst, (dict, Batch)):
         zero_batch = Batch()
         for key, val in inst.items():
-            zero_batch.__dict__[key] = _create_value(val, size)
+            zero_batch.__dict__[key] = _create_value(val, size, stack=stack)
         return zero_batch
-    elif isinstance(inst, (np.generic, Number)):
-        return _create_value(np.asarray(inst), size)
+    elif is_scalar:
+        return _create_value(np.asarray(inst), size, stack=stack)
     else:  # fall back to np.object
         return np.array([None for _ in range(size)])
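
A quick illustration of the helper's two modes, as described in the new docstring (a sketch in plain numpy, not the helper itself): with stack=True the placeholder gains a new leading axis; with stack=False the leading axis is replaced by the total concatenated length.

    import numpy as np

    inst = np.ones((3, 5))
    size = 10
    # stack=True: reserve `size` slots of the full shape -> (10, 3, 5)
    stacked = np.full((size, *inst.shape), 0, dtype=inst.dtype)
    # stack=False: reserve `size` rows of the trailing shape -> (10, 5)
    catted = np.full((size, *inst.shape[1:]), 0, dtype=inst.dtype)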
@@ -495,10 +511,12 @@ class Batch:
         # partial keys will be padded by zeros
         # with the shape of [len, rest_shape]
         lens = [len(x) for x in batches]
+        sum_lens = [0]
+        for x in lens:
+            sum_lens.append(sum_lens[-1] + x)
         keys_map = list(map(lambda e: set(e.keys()), batches))
         keys_shared = set.intersection(*keys_map)
-        values_shared = [
-            [e[k] for e in batches] for k in keys_shared]
+        values_shared = [[e[k] for e in batches] for k in keys_shared]
         _assert_type_keys(keys_shared)
         for k, v in zip(keys_shared, values_shared):
             if all(isinstance(e, (dict, Batch)) for e in v):
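
The new sum_lens is a plain prefix sum over the batch lengths; batch i then owns the slice sum_lens[i]:sum_lens[i + 1] of every concatenated value. For example:

    lens = [3, 4]                 # lengths of the two input batches
    sum_lens = [0]
    for x in lens:
        sum_lens.append(sum_lens[-1] + x)
    # sum_lens == [0, 3, 7]: batch 0 fills rows 0:3, batch 1 fills rows 3:7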
@@ -513,40 +531,15 @@ class Batch:
         keys_partial = set.union(*keys_map) - keys_shared
         _assert_type_keys(keys_partial)
         for k in keys_partial:
-            is_dict = False
-            value = None
             for i, e in enumerate(batches):
                 val = e.get(k, None)
                 if val is not None:
-                    if isinstance(val, (dict, Batch)):
-                        is_dict = True
-                    else:  # np.ndarray or torch.Tensor
-                        value = val
-                    break
-            if is_dict:
-                self.__dict__[k] = Batch.cat(
-                    [e.get(k, Batch()) for e in batches])
-            else:
-                if isinstance(value, np.ndarray):
-                    arrs = []
-                    for i, e in enumerate(batches):
-                        shape = [lens[i]] + list(value.shape[1:])
-                        pad = np.zeros(shape, dtype=value.dtype)
-                        arrs.append(e.get(k, pad))
-                    self.__dict__[k] = np.concatenate(arrs)
-                elif isinstance(value, torch.Tensor):
-                    arrs = []
-                    for i, e in enumerate(batches):
-                        shape = [lens[i]] + list(value.shape[1:])
-                        pad = torch.zeros(shape,
-                                          dtype=value.dtype,
-                                          device=value.device)
-                        arrs.append(e.get(k, pad))
-                    self.__dict__[k] = torch.cat(arrs)
-                else:
-                    raise TypeError(
-                        f"cannot cat value with type {type(value)}, we only "
-                        "support dict, Batch, np.ndarray, and torch.Tensor")
+                    try:
+                        self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val
+                    except KeyError:
+                        self.__dict__[k] = \
+                            _create_value(val, sum_lens[-1], stack=False)
+                        self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val

     @staticmethod
     def cat(batches: List[Union[dict, 'Batch']]) -> 'Batch':
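
The rewritten partial-key path allocates the zero buffer lazily: the first time a key is seen, indexing self.__dict__[k] raises KeyError, a full-length zero value is created via _create_value(..., stack=False), and only the slice belonging to that batch is overwritten. A sketch of the pattern in plain numpy (the sum_lens values and the 4x6 value are assumed for illustration):

    import numpy as np

    sum_lens = [0, 3, 7]          # prefix sums over batch lengths
    val = np.random.rand(4, 6)    # value present only in batch 1
    buf = np.zeros((sum_lens[-1], *val.shape[1:]))  # full-length zero pad
    buf[sum_lens[1]:sum_lens[2]] = val              # fill batch 1's slice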
@@ -576,12 +569,14 @@ class Batch:
         """Stack a list of :class:`~tianshou.data.Batch` object into current
         batch.
         """
+        if len(batches) == 0:
+            return
+        batches = [x if isinstance(x, Batch) else Batch(x) for x in batches]
         if len(self.__dict__) > 0:
             batches = [self] + list(batches)
         keys_map = list(map(lambda e: set(e.keys()), batches))
         keys_shared = set.intersection(*keys_map)
-        values_shared = [
-            [e[k] for e in batches] for k in keys_shared]
+        values_shared = [[e[k] for e in batches] for k in keys_shared]
         _assert_type_keys(keys_shared)
         for k, v in zip(keys_shared, values_shared):
             if all(isinstance(e, (dict, Batch)) for e in v):
@@ -593,7 +588,11 @@ class Batch:
                 if not issubclass(v.dtype.type, (np.bool_, np.number)):
                     v = v.astype(np.object)
                 self.__dict__[k] = v
-        keys_partial = reduce(set.symmetric_difference, keys_map)
+        keys_partial = set.difference(set.union(*keys_map), keys_shared)
+        if keys_partial and axis != 0:
+            raise ValueError(
+                f"Stack of Batch with non-shared keys {keys_partial} "
+                f"is only supported with axis=0, but got axis={axis}!")
         _assert_type_keys(keys_partial)
         for k in keys_partial:
             for i, e in enumerate(batches):
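
With this guard, stacking batches whose keys only partially overlap is rejected for any axis other than 0, since zero-padding along a non-leading axis is ambiguous. A sketch of the guarded case, mirroring the error message above:

    import numpy as np
    from tianshou.data import Batch

    a = Batch(a=np.zeros((3, 4)))
    b = Batch(b=np.zeros((3, 4)))
    Batch.stack([a, b])           # fine: missing keys are zero-padded
    Batch.stack([a, b], axis=1)   # raises ValueError (non-shared keys, axis != 0)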
@@ -609,7 +608,24 @@ class Batch:
     @staticmethod
     def stack(batches: List[Union[dict, 'Batch']], axis: int = 0) -> 'Batch':
         """Stack a list of :class:`~tianshou.data.Batch` object into a single
-        new batch.
+        new batch. For keys that are not shared across all batches,
+        batches that do not have these keys will be padded by zeros. E.g.
+        ::
+
+            >>> a = Batch(a=np.zeros([4, 4]), common=Batch(c=np.zeros([4, 5])))
+            >>> b = Batch(b=np.zeros([4, 6]), common=Batch(c=np.zeros([4, 5])))
+            >>> c = Batch.stack([a, b])
+            >>> c.a.shape
+            (2, 4, 4)
+            >>> c.b.shape
+            (2, 4, 6)
+            >>> c.common.c.shape
+            (2, 4, 5)
+
+        .. note::
+
+            If there are keys that are not shared across all batches,
+            ``stack`` with ``axis != 0`` is undefined, and will cause an
+            exception.
         """
         batch = Batch()
         batch.stack_(batches, axis)