Fix padding of inconsistent keys with Batch.stack and Batch.cat (#130)

* re-implement Batch.stack and add test cases

* add doc for Batch.stack

* reuse _create_values and refactor stack_ & cat_

* fix pep8

* fix docs

* raise exception for stacking with partial keys and axis!=0

* minor fix

* minor fix

Co-authored-by: Trinkle23897 <463003665@qq.com>
youkaichao 2020-07-12 23:45:42 +08:00 committed by n+e
parent affeec13de
commit 5599a6d1a6
2 changed files with 82 additions and 45 deletions
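
In short: after this change, Batch.stack and Batch.cat zero-pad the entries of keys that only some of the input batches carry, instead of failing. A minimal doctest-style sketch of the new stack behavior, mirroring the test cases in the diff below:

>>> import numpy as np
>>> from tianshou.data import Batch
>>> a = Batch(a=1, b=2, c=3)
>>> b = Batch(a=4, b=5, d=6)
>>> d = Batch.stack([a, b])
>>> assert np.allclose(d.a, [1, 4])  # shared key, stacked as usual
>>> assert np.allclose(d.d, [0, 6])  # missing in `a`, zero-padded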


@@ -166,7 +166,7 @@ def test_batch_cat_and_stack():
assert isinstance(b12_stack.a.d.e, np.ndarray)
assert b12_stack.a.d.e.ndim == 2
# test batch with incompatible keys
# test cat with incompatible keys
b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5)))
b2 = Batch(b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5)))
test = Batch.cat([b1, b2])
@@ -177,6 +177,7 @@ def test_batch_cat_and_stack():
assert torch.allclose(test.b, ans.b)
assert np.allclose(test.common.c, ans.common.c)
# test stack with compatible keys
b3 = Batch(a=np.zeros((3, 4)),
b=torch.ones((2, 5)),
c=Batch(d=[[1], [2]]))
@@ -194,6 +195,26 @@ def test_batch_cat_and_stack():
assert b5.b.d[0] == b5_dict[0]['b']['d']
assert b5.b.d[1] == 0.0
# test stack with incompatible keys
a = Batch(a=1, b=2, c=3)
b = Batch(a=4, b=5, d=6)
c = Batch(c=7, b=6, d=9)
d = Batch.stack([a, b, c])
assert np.allclose(d.a, [1, 4, 0])
assert np.allclose(d.b, [2, 5, 6])
assert np.allclose(d.c, [3, 0, 7])
assert np.allclose(d.d, [0, 6, 9])
b1 = Batch(a=np.random.rand(4, 4), common=Batch(c=np.random.rand(4, 5)))
b2 = Batch(b=torch.rand(4, 6), common=Batch(c=np.random.rand(4, 5)))
test = Batch.stack([b1, b2])
ans = Batch(a=np.stack([b1.a, np.zeros((4, 4))]),
b=torch.stack([torch.zeros(4, 6), b2.b]),
common=Batch(c=np.stack([b1.common.c, b2.common.c])))
assert np.allclose(test.a, ans.a)
assert torch.allclose(test.b, ans.b)
assert np.allclose(test.common.c, ans.common.c)
def test_batch_over_batch_to_torch():
batch = Batch(
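
For the concatenation case exercised above, a short sketch of the padded result (shapes follow the test; this is an illustration, not additional committed code):

>>> b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5)))
>>> b2 = Batch(b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5)))
>>> test = Batch.cat([b1, b2])
>>> test.a.shape  # last 4 rows zero-padded, since b2 has no key "a"
(7, 4)
>>> test.b.shape  # first 3 rows zero-padded, since b1 has no key "b"
torch.Size([7, 3])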


@@ -3,7 +3,6 @@ import pprint
import warnings
import numpy as np
from copy import deepcopy
from functools import reduce
from numbers import Number
from typing import Any, List, Tuple, Union, Iterator, Optional
@@ -24,28 +23,45 @@ def _is_batch_set(data: Any) -> bool:
return False
def _create_value(inst: Any, size: int) -> Union[
def _create_value(inst: Any, size: int, stack: bool = True) -> Union[
'Batch', np.ndarray, torch.Tensor]:
"""
:param bool stack: whether to stack or to concatenate. E.g. if inst has
shape of (3, 5), size = 10, stack=True returns an np.ndarry with shape
of (10, 3, 5), otherwise (10, 5)
"""
has_shape = isinstance(inst, (np.ndarray, torch.Tensor))
is_scalar = \
isinstance(inst, Number) or \
issubclass(inst.__class__, np.generic) or \
(has_shape and not inst.shape)
if not stack and is_scalar:
# here we do not consider scalar types, following the behavior of
# numpy, which does not support concatenation of zero-dimensional
# arrays (scalars)
raise TypeError(f"cannot concatenate {inst}, which is a scalar")
if has_shape:
shape = (size, *inst.shape) if stack else (size, *inst.shape[1:])
if isinstance(inst, np.ndarray):
if issubclass(inst.dtype.type, (np.bool_, np.number)):
target_type = inst.dtype.type
else:
target_type = np.object
return np.full((size, *inst.shape),
return np.full(shape,
fill_value=None if target_type == np.object else 0,
dtype=target_type)
elif isinstance(inst, torch.Tensor):
return torch.full((size, *inst.shape),
return torch.full(shape,
fill_value=0,
device=inst.device,
dtype=inst.dtype)
elif isinstance(inst, (dict, Batch)):
zero_batch = Batch()
for key, val in inst.items():
zero_batch.__dict__[key] = _create_value(val, size)
zero_batch.__dict__[key] = _create_value(val, size, stack=stack)
return zero_batch
elif isinstance(inst, (np.generic, Number)):
return _create_value(np.asarray(inst), size)
elif is_scalar:
return _create_value(np.asarray(inst), size, stack=stack)
else: # fall back to np.object
return np.array([None for _ in range(size)])
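
To make the new stack flag concrete, a hedged sketch of how _create_value sizes its zero-filled placeholder (it is an internal helper, so this only illustrates the docstring above):

>>> inst = np.ones((3, 5))
>>> _create_value(inst, 10, stack=True).shape  # (size, *inst.shape)
(10, 3, 5)
>>> _create_value(inst, 10, stack=False).shape  # (size, *inst.shape[1:])
(10, 5)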
@@ -495,10 +511,12 @@ class Batch:
# partial keys will be padded by zeros
# with the shape of [len, rest_shape]
lens = [len(x) for x in batches]
sum_lens = [0]
for x in lens:
sum_lens.append(sum_lens[-1] + x)
keys_map = list(map(lambda e: set(e.keys()), batches))
keys_shared = set.intersection(*keys_map)
values_shared = [
[e[k] for e in batches] for k in keys_shared]
values_shared = [[e[k] for e in batches] for k in keys_shared]
_assert_type_keys(keys_shared)
for k, v in zip(keys_shared, values_shared):
if all(isinstance(e, (dict, Batch)) for e in v):
@@ -513,40 +531,15 @@ class Batch:
keys_partial = set.union(*keys_map) - keys_shared
_assert_type_keys(keys_partial)
for k in keys_partial:
is_dict = False
value = None
for i, e in enumerate(batches):
val = e.get(k, None)
if val is not None:
if isinstance(val, (dict, Batch)):
is_dict = True
else: # np.ndarray or torch.Tensor
value = val
break
if is_dict:
self.__dict__[k] = Batch.cat(
[e.get(k, Batch()) for e in batches])
else:
if isinstance(value, np.ndarray):
arrs = []
for i, e in enumerate(batches):
shape = [lens[i]] + list(value.shape[1:])
pad = np.zeros(shape, dtype=value.dtype)
arrs.append(e.get(k, pad))
self.__dict__[k] = np.concatenate(arrs)
elif isinstance(value, torch.Tensor):
arrs = []
for i, e in enumerate(batches):
shape = [lens[i]] + list(value.shape[1:])
pad = torch.zeros(shape,
dtype=value.dtype,
device=value.device)
arrs.append(e.get(k, pad))
self.__dict__[k] = torch.cat(arrs)
else:
raise TypeError(
f"cannot cat value with type {type(value)}, we only "
"support dict, Batch, np.ndarray, and torch.Tensor")
try:
self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val
except KeyError:
self.__dict__[k] = \
_create_value(val, sum_lens[-1], stack=False)
self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val
@staticmethod
def cat(batches: List[Union[dict, 'Batch']]) -> 'Batch':
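
The concatenation path above keeps prefix sums of the batch lengths so each batch's slice can be written in place; a zero-filled holder is created the first time a partial key is seen. A standalone numpy sketch of that indexing scheme (hypothetical toy values, independent of the Batch internals):

import numpy as np

lens = [3, 4]                  # lengths of the batches being concatenated
sum_lens = [0]
for n in lens:
    sum_lens.append(sum_lens[-1] + n)   # prefix sums: [0, 3, 7]

holder = np.zeros((sum_lens[-1], 5))   # zero-filled holder for a partial key
val = np.ones((4, 5))                  # value present only in batch i=1
i = 1
holder[sum_lens[i]:sum_lens[i + 1]] = val   # rows 3..6 receive the value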
@@ -576,12 +569,14 @@ class Batch:
"""Stack a list of :class:`~tianshou.data.Batch` object into current
batch.
"""
if len(batches) == 0:
return
batches = [x if isinstance(x, Batch) else Batch(x) for x in batches]
if len(self.__dict__) > 0:
batches = [self] + list(batches)
keys_map = list(map(lambda e: set(e.keys()), batches))
keys_shared = set.intersection(*keys_map)
values_shared = [
[e[k] for e in batches] for k in keys_shared]
values_shared = [[e[k] for e in batches] for k in keys_shared]
_assert_type_keys(keys_shared)
for k, v in zip(keys_shared, values_shared):
if all(isinstance(e, (dict, Batch)) for e in v):
@@ -593,7 +588,11 @@ class Batch:
if not issubclass(v.dtype.type, (np.bool_, np.number)):
v = v.astype(np.object)
self.__dict__[k] = v
keys_partial = reduce(set.symmetric_difference, keys_map)
keys_partial = set.difference(set.union(*keys_map), keys_shared)
if keys_partial and axis != 0:
raise ValueError(
f"Stack of Batch with non-shared keys {keys_partial} "
f"is only supported with axis=0, but got axis={axis}!")
_assert_type_keys(keys_partial)
for k in keys_partial:
for i, e in enumerate(batches):
@@ -609,7 +608,24 @@ class Batch:
@staticmethod
def stack(batches: List[Union[dict, 'Batch']], axis: int = 0) -> 'Batch':
"""Stack a list of :class:`~tianshou.data.Batch` object into a single
new batch.
new batch. For keys that are not shared across all batches,
batches that do not have these keys will be padded by zeros. E.g.
::
>>> a = Batch(a=np.zeros([4, 4]), common=Batch(c=np.zeros([4, 5])))
>>> b = Batch(b=np.zeros([4, 6]), common=Batch(c=np.zeros([4, 5])))
>>> c = Batch.stack([a, b])
>>> c.a.shape
(2, 4, 4)
>>> c.b.shape
(2, 4, 6)
>>> c.common.c.shape
(2, 4, 5)
.. note::
If there are keys that are not shared across all batches, ``stack``
with ``axis != 0`` is not supported and raises a ValueError.
"""
batch = Batch()
batch.stack_(batches, axis)
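
Finally, a doctest-style sketch of the new axis != 0 guard described in the note above (the key-set ordering inside the message may vary):

>>> a = Batch(a=np.zeros((3, 4)))
>>> b = Batch(b=np.zeros((3, 4)))
>>> Batch.stack([a, b], axis=1)
Traceback (most recent call last):
    ...
ValueError: Stack of Batch with non-shared keys {'a', 'b'} is only supported with axis=0, but got axis=1!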