Fix padding of inconsistent keys with Batch.stack and Batch.cat (#130)
* re-implement Batch.stack and add test cases
* add doc for Batch.stack
* reuse _create_value and refactor stack_ & cat_
* fix pep8
* fix docs
* raise exception for stacking with partial keys and axis != 0
* minor fix
* minor fix

Co-authored-by: Trinkle23897 <463003665@qq.com>
Parent: affeec13de
Commit: 5599a6d1a6
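The behavior this PR establishes, shown as a short usage sketch that mirrors the new test case added below (Batch is imported from tianshou.data; the zero-padding semantics are those implemented in this commit):

import numpy as np
from tianshou.data import Batch

# the three batches only partially share their keys
a = Batch(a=1, b=2, c=3)
b = Batch(a=4, b=5, d=6)
c = Batch(c=7, b=6, d=9)
d = Batch.stack([a, b, c])
# keys missing from a batch are padded with zeros:
# d.a == [1, 4, 0], d.b == [2, 5, 6], d.c == [3, 0, 7], d.d == [0, 6, 9]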
--- a/test/base/test_batch.py
+++ b/test/base/test_batch.py
@@ -166,7 +166,7 @@ def test_batch_cat_and_stack():
     assert isinstance(b12_stack.a.d.e, np.ndarray)
     assert b12_stack.a.d.e.ndim == 2

-    # test batch with incompatible keys
+    # test cat with incompatible keys
     b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5)))
     b2 = Batch(b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5)))
     test = Batch.cat([b1, b2])
@@ -177,6 +177,7 @@ def test_batch_cat_and_stack():
     assert torch.allclose(test.b, ans.b)
     assert np.allclose(test.common.c, ans.common.c)

+    # test stack with compatible keys
     b3 = Batch(a=np.zeros((3, 4)),
                b=torch.ones((2, 5)),
                c=Batch(d=[[1], [2]]))
@@ -194,6 +195,26 @@ def test_batch_cat_and_stack():
     assert b5.b.d[0] == b5_dict[0]['b']['d']
     assert b5.b.d[1] == 0.0

+    # test stack with incompatible keys
+    a = Batch(a=1, b=2, c=3)
+    b = Batch(a=4, b=5, d=6)
+    c = Batch(c=7, b=6, d=9)
+    d = Batch.stack([a, b, c])
+    assert np.allclose(d.a, [1, 4, 0])
+    assert np.allclose(d.b, [2, 5, 6])
+    assert np.allclose(d.c, [3, 0, 7])
+    assert np.allclose(d.d, [0, 6, 9])
+
+    b1 = Batch(a=np.random.rand(4, 4), common=Batch(c=np.random.rand(4, 5)))
+    b2 = Batch(b=torch.rand(4, 6), common=Batch(c=np.random.rand(4, 5)))
+    test = Batch.stack([b1, b2])
+    ans = Batch(a=np.stack([b1.a, np.zeros((4, 4))]),
+                b=torch.stack([torch.zeros(4, 6), b2.b]),
+                common=Batch(c=np.stack([b1.common.c, b2.common.c])))
+    assert np.allclose(test.a, ans.a)
+    assert torch.allclose(test.b, ans.b)
+    assert np.allclose(test.common.c, ans.common.c)
+

 def test_batch_over_batch_to_torch():
     batch = Batch(
--- a/tianshou/data/batch.py
+++ b/tianshou/data/batch.py
@@ -3,7 +3,6 @@ import pprint
 import warnings
 import numpy as np
 from copy import deepcopy
-from functools import reduce
 from numbers import Number
 from typing import Any, List, Tuple, Union, Iterator, Optional

@@ -24,28 +23,45 @@ def _is_batch_set(data: Any) -> bool:
     return False


-def _create_value(inst: Any, size: int) -> Union[
+def _create_value(inst: Any, size: int, stack=True) -> Union[
         'Batch', np.ndarray, torch.Tensor]:
+    """
+    :param bool stack: whether to stack or to concatenate. E.g. if inst has
+        shape of (3, 5), size = 10, stack=True returns an np.ndarray with
+        shape of (10, 3, 5), otherwise (10, 5).
+    """
+    has_shape = isinstance(inst, (np.ndarray, torch.Tensor))
+    is_scalar = \
+        isinstance(inst, Number) or \
+        issubclass(inst.__class__, np.generic) or \
+        (has_shape and not inst.shape)
+    if not stack and is_scalar:
+        # here we do not consider scalar types, following the
+        # behavior of numpy which does not support concatenation
+        # of zero-dimensional arrays (scalars)
+        raise TypeError(f"cannot cat {inst} which is a scalar")
+    if has_shape:
+        shape = (size, *inst.shape) if stack else (size, *inst.shape[1:])
     if isinstance(inst, np.ndarray):
         if issubclass(inst.dtype.type, (np.bool_, np.number)):
             target_type = inst.dtype.type
         else:
             target_type = np.object
-        return np.full((size, *inst.shape),
+        return np.full(shape,
                        fill_value=None if target_type == np.object else 0,
                        dtype=target_type)
     elif isinstance(inst, torch.Tensor):
-        return torch.full((size, *inst.shape),
+        return torch.full(shape,
                           fill_value=0,
                           device=inst.device,
                           dtype=inst.dtype)
     elif isinstance(inst, (dict, Batch)):
         zero_batch = Batch()
         for key, val in inst.items():
-            zero_batch.__dict__[key] = _create_value(val, size)
+            zero_batch.__dict__[key] = _create_value(val, size, stack=stack)
         return zero_batch
-    elif isinstance(inst, (np.generic, Number)):
-        return _create_value(np.asarray(inst), size)
+    elif is_scalar:
+        return _create_value(np.asarray(inst), size, stack=stack)
     else:  # fall back to np.object
         return np.array([None for _ in range(size)])

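As an aside, a minimal plain-NumPy sketch of the shape rule the new stack parameter encodes; this only illustrates the docstring above and does not call the private helper itself:

import numpy as np

inst = np.zeros((3, 5))  # template value taken from one of the batches
size = 10                # total number of elements to reserve
buf_stack = np.zeros((size, *inst.shape))    # (10, 3, 5): buffer used when stacking
buf_cat = np.zeros((size, *inst.shape[1:]))  # (10, 5): buffer used when concatenating
assert buf_stack.shape == (10, 3, 5) and buf_cat.shape == (10, 5)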
@@ -495,10 +511,12 @@ class Batch:
         # partial keys will be padded by zeros
         # with the shape of [len, rest_shape]
         lens = [len(x) for x in batches]
+        sum_lens = [0]
+        for x in lens:
+            sum_lens.append(sum_lens[-1] + x)
         keys_map = list(map(lambda e: set(e.keys()), batches))
         keys_shared = set.intersection(*keys_map)
-        values_shared = [
-            [e[k] for e in batches] for k in keys_shared]
+        values_shared = [[e[k] for e in batches] for k in keys_shared]
         _assert_type_keys(keys_shared)
         for k, v in zip(keys_shared, values_shared):
             if all(isinstance(e, (dict, Batch)) for e in v):
@@ -513,40 +531,15 @@ class Batch:
         keys_partial = set.union(*keys_map) - keys_shared
         _assert_type_keys(keys_partial)
         for k in keys_partial:
-            is_dict = False
-            value = None
             for i, e in enumerate(batches):
                 val = e.get(k, None)
                 if val is not None:
-                    if isinstance(val, (dict, Batch)):
-                        is_dict = True
-                    else:  # np.ndarray or torch.Tensor
-                        value = val
-                    break
-            if is_dict:
-                self.__dict__[k] = Batch.cat(
-                    [e.get(k, Batch()) for e in batches])
-            else:
-                if isinstance(value, np.ndarray):
-                    arrs = []
-                    for i, e in enumerate(batches):
-                        shape = [lens[i]] + list(value.shape[1:])
-                        pad = np.zeros(shape, dtype=value.dtype)
-                        arrs.append(e.get(k, pad))
-                    self.__dict__[k] = np.concatenate(arrs)
-                elif isinstance(value, torch.Tensor):
-                    arrs = []
-                    for i, e in enumerate(batches):
-                        shape = [lens[i]] + list(value.shape[1:])
-                        pad = torch.zeros(shape,
-                                          dtype=value.dtype,
-                                          device=value.device)
-                        arrs.append(e.get(k, pad))
-                    self.__dict__[k] = torch.cat(arrs)
-                else:
-                    raise TypeError(
-                        f"cannot cat value with type {type(value)}, we only "
-                        "support dict, Batch, np.ndarray, and torch.Tensor")
+                    try:
+                        self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val
+                    except KeyError:
+                        self.__dict__[k] = \
+                            _create_value(val, sum_lens[-1], stack=False)
+                        self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val

     @staticmethod
     def cat(batches: List[Union[dict, 'Batch']]) -> 'Batch':
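A usage sketch of the resulting cat_ behavior, reusing the setup of the updated test above; the commented shapes follow from the implementation (offsets come from sum_lens, and slices of batches that lack a key stay at the zeros created by _create_value):

import numpy as np
import torch
from tianshou.data import Batch

b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5)))
b2 = Batch(b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5)))
test = Batch.cat([b1, b2])
# shared key: ordinary concatenation -> test.common.c has shape (7, 5)
# non-shared keys are zero-padded over the slices of the batches that lack them:
#   test.a has shape (7, 4): rows 0..2 come from b1.a, rows 3..6 are zeros
#   test.b has shape (7, 3): rows 0..2 are zeros, rows 3..6 come from b2.b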
@@ -576,12 +569,14 @@ class Batch:
         """Stack a list of :class:`~tianshou.data.Batch` object into current
         batch.
         """
         if len(batches) == 0:
             return
         batches = [x if isinstance(x, Batch) else Batch(x) for x in batches]
         if len(self.__dict__) > 0:
             batches = [self] + list(batches)
         keys_map = list(map(lambda e: set(e.keys()), batches))
         keys_shared = set.intersection(*keys_map)
-        values_shared = [
-            [e[k] for e in batches] for k in keys_shared]
+        values_shared = [[e[k] for e in batches] for k in keys_shared]
         _assert_type_keys(keys_shared)
         for k, v in zip(keys_shared, values_shared):
             if all(isinstance(e, (dict, Batch)) for e in v):
@@ -593,7 +588,11 @@ class Batch:
                 if not issubclass(v.dtype.type, (np.bool_, np.number)):
                     v = v.astype(np.object)
                 self.__dict__[k] = v
-        keys_partial = reduce(set.symmetric_difference, keys_map)
+        keys_partial = set.difference(set.union(*keys_map), keys_shared)
+        if keys_partial and axis != 0:
+            raise ValueError(
+                f"Stack of Batch with non-shared keys {keys_partial} "
+                f"is only supported with axis=0, but got axis={axis}!")
         _assert_type_keys(keys_partial)
         for k in keys_partial:
             for i, e in enumerate(batches):
@@ -609,7 +608,24 @@ class Batch:
     @staticmethod
     def stack(batches: List[Union[dict, 'Batch']], axis: int = 0) -> 'Batch':
         """Stack a list of :class:`~tianshou.data.Batch` object into a single
-        new batch.
+        new batch. For keys that are not shared across all batches,
+        batches that do not have these keys will be padded by zeros. E.g.
+        ::
+
+            >>> a = Batch(a=np.zeros([4, 4]), common=Batch(c=np.zeros([4, 5])))
+            >>> b = Batch(b=np.zeros([4, 6]), common=Batch(c=np.zeros([4, 5])))
+            >>> c = Batch.stack([a, b])
+            >>> c.a.shape
+            (2, 4, 4)
+            >>> c.b.shape
+            (2, 4, 6)
+            >>> c.common.c.shape
+            (2, 4, 5)
+
+        .. note::
+
+            If there are keys that are not shared across all batches, ``stack``
+            with ``axis != 0`` is undefined, and will cause an exception.
         """
         batch = Batch()
         batch.stack_(batches, axis)
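Finally, a short sketch of the axis restriction documented in the note above, assuming the same Batch import as elsewhere in this commit: stacking batches with non-shared keys is only supported along axis 0, otherwise the new check raises ValueError.

import numpy as np
from tianshou.data import Batch

a = Batch(a=np.zeros([4, 4]), common=Batch(c=np.zeros([4, 5])))
b = Batch(b=np.zeros([4, 6]), common=Batch(c=np.zeros([4, 5])))
c = Batch.stack([a, b])  # fine: non-shared keys "a" and "b" are zero-padded along axis 0
try:
    Batch.stack([a, b], axis=1)  # non-shared keys with axis != 0
except ValueError as err:
    print(err)  # "Stack of Batch with non-shared keys {...} is only supported with axis=0 ..."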