Fix padding of inconsistent keys with Batch.stack and Batch.cat (#130)
* re-implement Batch.stack and add testcases
* add doc for Batch.stack
* reuse _create_values and refactor stack_ & cat_
* fix pep8
* fix docs
* raise exception for stacking with partial keys and axis!=0
* minor fix
* minor fix

Co-authored-by: Trinkle23897 <463003665@qq.com>
parent affeec13de
commit 5599a6d1a6
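A quick sketch of the fixed behavior, distilled from the new test cases in this diff: `Batch.stack` now zero-pads keys that are missing from some of the stacked batches (values mirror the added assertions below):

    import numpy as np
    from tianshou.data import Batch

    a = Batch(a=1, b=2, c=3)
    b = Batch(a=4, b=5, d=6)
    c = Batch(c=7, b=6, d=9)
    d = Batch.stack([a, b, c])
    print(d.a)  # [1 4 0] -- 'a' is missing from c, so its slot is zero
    print(d.d)  # [0 6 9] -- 'd' is missing from a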
@@ -166,7 +166,7 @@ def test_batch_cat_and_stack():
     assert isinstance(b12_stack.a.d.e, np.ndarray)
     assert b12_stack.a.d.e.ndim == 2

-    # test batch with incompatible keys
+    # test cat with incompatible keys
     b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5)))
     b2 = Batch(b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5)))
     test = Batch.cat([b1, b2])
@@ -177,6 +177,7 @@ def test_batch_cat_and_stack():
     assert torch.allclose(test.b, ans.b)
     assert np.allclose(test.common.c, ans.common.c)

+    # test stack with compatible keys
     b3 = Batch(a=np.zeros((3, 4)),
                b=torch.ones((2, 5)),
                c=Batch(d=[[1], [2]]))
@@ -194,6 +195,26 @@ def test_batch_cat_and_stack():
     assert b5.b.d[0] == b5_dict[0]['b']['d']
     assert b5.b.d[1] == 0.0

+    # test stack with incompatible keys
+    a = Batch(a=1, b=2, c=3)
+    b = Batch(a=4, b=5, d=6)
+    c = Batch(c=7, b=6, d=9)
+    d = Batch.stack([a, b, c])
+    assert np.allclose(d.a, [1, 4, 0])
+    assert np.allclose(d.b, [2, 5, 6])
+    assert np.allclose(d.c, [3, 0, 7])
+    assert np.allclose(d.d, [0, 6, 9])
+
+    b1 = Batch(a=np.random.rand(4, 4), common=Batch(c=np.random.rand(4, 5)))
+    b2 = Batch(b=torch.rand(4, 6), common=Batch(c=np.random.rand(4, 5)))
+    test = Batch.stack([b1, b2])
+    ans = Batch(a=np.stack([b1.a, np.zeros((4, 4))]),
+                b=torch.stack([torch.zeros(4, 6), b2.b]),
+                common=Batch(c=np.stack([b1.common.c, b2.common.c])))
+    assert np.allclose(test.a, ans.a)
+    assert torch.allclose(test.b, ans.b)
+    assert np.allclose(test.common.c, ans.common.c)
+

 def test_batch_over_batch_to_torch():
     batch = Batch(
@@ -3,7 +3,6 @@ import pprint
 import warnings
 import numpy as np
 from copy import deepcopy
-from functools import reduce
 from numbers import Number
 from typing import Any, List, Tuple, Union, Iterator, Optional

@@ -24,28 +23,45 @@ def _is_batch_set(data: Any) -> bool:
     return False


-def _create_value(inst: Any, size: int) -> Union[
+def _create_value(inst: Any, size: int, stack=True) -> Union[
         'Batch', np.ndarray, torch.Tensor]:
+    """
+    :param bool stack: whether to stack or to concatenate. E.g. if inst has
+        shape of (3, 5), size = 10, stack=True returns an np.ndarray with
+        shape of (10, 3, 5), otherwise (10, 5).
+    """
+    has_shape = isinstance(inst, (np.ndarray, torch.Tensor))
+    is_scalar = \
+        isinstance(inst, Number) or \
+        issubclass(inst.__class__, np.generic) or \
+        (has_shape and not inst.shape)
+    if not stack and is_scalar:
+        # we do not consider scalar types here, following the behavior of
+        # numpy, which does not support concatenation of zero-dimensional
+        # arrays (scalars)
+        raise TypeError(f"cannot cat {inst}, which is a scalar")
+    if has_shape:
+        shape = (size, *inst.shape) if stack else (size, *inst.shape[1:])
     if isinstance(inst, np.ndarray):
         if issubclass(inst.dtype.type, (np.bool_, np.number)):
             target_type = inst.dtype.type
         else:
             target_type = np.object
-        return np.full((size, *inst.shape),
+        return np.full(shape,
                        fill_value=None if target_type == np.object else 0,
                        dtype=target_type)
     elif isinstance(inst, torch.Tensor):
-        return torch.full((size, *inst.shape),
+        return torch.full(shape,
                           fill_value=0,
                           device=inst.device,
                           dtype=inst.dtype)
     elif isinstance(inst, (dict, Batch)):
         zero_batch = Batch()
         for key, val in inst.items():
-            zero_batch.__dict__[key] = _create_value(val, size)
+            zero_batch.__dict__[key] = _create_value(val, size, stack=stack)
         return zero_batch
-    elif isinstance(inst, (np.generic, Number)):
-        return _create_value(np.asarray(inst), size)
+    elif is_scalar:
+        return _create_value(np.asarray(inst), size, stack=stack)
     else:  # fall back to np.object
         return np.array([None for _ in range(size)])

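To make the new `stack` flag concrete, here is a minimal standalone sketch (plain NumPy, not a call into tianshou itself) of the two shapes `_create_value` allocates for an input of shape (3, 5) and size 10, as described in the docstring above:

    import numpy as np

    inst = np.zeros((3, 5))
    size = 10

    # stack=True: one zero block per stacked element -> (10, 3, 5)
    stacked = np.full((size, *inst.shape), 0, dtype=inst.dtype)
    # stack=False: rows to be filled in by concatenation -> (10, 5)
    catted = np.full((size, *inst.shape[1:]), 0, dtype=inst.dtype)

    assert stacked.shape == (10, 3, 5)
    assert catted.shape == (10, 5)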
@@ -495,10 +511,12 @@ class Batch:
         # partial keys will be padded by zeros
         # with the shape of [len, rest_shape]
         lens = [len(x) for x in batches]
+        sum_lens = [0]
+        for x in lens:
+            sum_lens.append(sum_lens[-1] + x)
         keys_map = list(map(lambda e: set(e.keys()), batches))
         keys_shared = set.intersection(*keys_map)
-        values_shared = [
-            [e[k] for e in batches] for k in keys_shared]
+        values_shared = [[e[k] for e in batches] for k in keys_shared]
         _assert_type_keys(keys_shared)
         for k, v in zip(keys_shared, values_shared):
             if all(isinstance(e, (dict, Batch)) for e in v):
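The cumulative offsets in `sum_lens` mark where each source batch's rows land in the concatenated result; a tiny illustration with made-up lengths:

    lens = [3, 4, 2]              # lengths of three batches being concatenated
    sum_lens = [0]
    for x in lens:
        sum_lens.append(sum_lens[-1] + x)
    print(sum_lens)               # [0, 3, 7, 9]
    # batch i occupies rows sum_lens[i]:sum_lens[i + 1] of the result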
@@ -513,40 +531,15 @@ class Batch:
         keys_partial = set.union(*keys_map) - keys_shared
         _assert_type_keys(keys_partial)
         for k in keys_partial:
-            is_dict = False
-            value = None
             for i, e in enumerate(batches):
                 val = e.get(k, None)
                 if val is not None:
-                    if isinstance(val, (dict, Batch)):
-                        is_dict = True
-                    else:  # np.ndarray or torch.Tensor
-                        value = val
-                    break
-            if is_dict:
-                self.__dict__[k] = Batch.cat(
-                    [e.get(k, Batch()) for e in batches])
-            else:
-                if isinstance(value, np.ndarray):
-                    arrs = []
-                    for i, e in enumerate(batches):
-                        shape = [lens[i]] + list(value.shape[1:])
-                        pad = np.zeros(shape, dtype=value.dtype)
-                        arrs.append(e.get(k, pad))
-                    self.__dict__[k] = np.concatenate(arrs)
-                elif isinstance(value, torch.Tensor):
-                    arrs = []
-                    for i, e in enumerate(batches):
-                        shape = [lens[i]] + list(value.shape[1:])
-                        pad = torch.zeros(shape,
-                                          dtype=value.dtype,
-                                          device=value.device)
-                        arrs.append(e.get(k, pad))
-                    self.__dict__[k] = torch.cat(arrs)
-                else:
-                    raise TypeError(
-                        f"cannot cat value with type {type(value)}, we only "
-                        "support dict, Batch, np.ndarray, and torch.Tensor")
+                    try:
+                        self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val
+                    except KeyError:
+                        self.__dict__[k] = \
+                            _create_value(val, sum_lens[-1], stack=False)
+                        self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val

     @staticmethod
     def cat(batches: List[Union[dict, 'Batch']]) -> 'Batch':
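This try/except path is what the `# test cat with incompatible keys` case above exercises; a sketch of the resulting padding (shapes taken from that test):

    import numpy as np
    import torch
    from tianshou.data import Batch

    b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5)))
    b2 = Batch(b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5)))
    test = Batch.cat([b1, b2])

    assert test.a.shape == (7, 4)              # 'a' only in b1; rows 3:7 zero-padded
    assert np.allclose(test.a[3:], 0)
    assert test.b.shape == torch.Size([7, 3])  # 'b' only in b2; rows 0:3 zero-padded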
@@ -576,12 +569,14 @@ class Batch:
         """Stack a list of :class:`~tianshou.data.Batch` object into current
         batch.
         """
+        if len(batches) == 0:
+            return
+        batches = [x if isinstance(x, Batch) else Batch(x) for x in batches]
         if len(self.__dict__) > 0:
             batches = [self] + list(batches)
         keys_map = list(map(lambda e: set(e.keys()), batches))
         keys_shared = set.intersection(*keys_map)
-        values_shared = [
-            [e[k] for e in batches] for k in keys_shared]
+        values_shared = [[e[k] for e in batches] for k in keys_shared]
         _assert_type_keys(keys_shared)
         for k, v in zip(keys_shared, values_shared):
             if all(isinstance(e, (dict, Batch)) for e in v):
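Two consequences of the new guard, sketched (reprs are approximate):

    from tianshou.data import Batch

    # raw dicts are promoted to Batch before stacking
    e = Batch.stack([{'a': 1}, {'a': 2}])
    print(e.a)               # [1 2]

    # stacking an empty list is a no-op, yielding an empty Batch
    print(Batch.stack([]))   # an empty Batch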
@@ -593,7 +588,11 @@ class Batch:
                 if not issubclass(v.dtype.type, (np.bool_, np.number)):
                     v = v.astype(np.object)
                 self.__dict__[k] = v
-        keys_partial = reduce(set.symmetric_difference, keys_map)
+        keys_partial = set.difference(set.union(*keys_map), keys_shared)
+        if keys_partial and axis != 0:
+            raise ValueError(
+                f"Stack of Batch with non-shared keys {keys_partial} "
+                f"is only supported with axis=0, but got axis={axis}!")
         _assert_type_keys(keys_partial)
         for k in keys_partial:
             for i, e in enumerate(batches):
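The new check turns a previously undefined case into an explicit error; for example (key names illustrative):

    import numpy as np
    from tianshou.data import Batch

    a = Batch(a=np.zeros((3, 4)), b=np.zeros((3, 4)))
    b = Batch(a=np.zeros((3, 4)))      # missing key 'b'

    Batch.stack([a, b], axis=0)        # ok: 'b' is zero-padded for the second batch
    try:
        Batch.stack([a, b], axis=1)    # non-shared key with axis != 0
    except ValueError as err:
        print(err)                     # Stack of Batch with non-shared keys {'b'} is
                                       # only supported with axis=0, but got axis=1!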
@@ -609,7 +608,24 @@ class Batch:
     @staticmethod
     def stack(batches: List[Union[dict, 'Batch']], axis: int = 0) -> 'Batch':
         """Stack a list of :class:`~tianshou.data.Batch` object into a single
-        new batch.
+        new batch. For keys that are not shared across all batches,
+        batches that do not have these keys will be padded by zeros. E.g.
+        ::
+
+            >>> a = Batch(a=np.zeros([4, 4]), common=Batch(c=np.zeros([4, 5])))
+            >>> b = Batch(b=np.zeros([4, 6]), common=Batch(c=np.zeros([4, 5])))
+            >>> c = Batch.stack([a, b])
+            >>> c.a.shape
+            (2, 4, 4)
+            >>> c.b.shape
+            (2, 4, 6)
+            >>> c.common.c.shape
+            (2, 4, 5)
+
+        .. note::
+
+            If there are keys that are not shared across all batches, ``stack``
+            with ``axis != 0`` is undefined, and will cause an exception.
         """
         batch = Batch()
         batch.stack_(batches, axis)
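As the implementation above shows, the static `stack` is a thin wrapper over the in-place `stack_`, so the two forms are interchangeable:

    import numpy as np
    from tianshou.data import Batch

    a = Batch(a=np.zeros([4, 4]))
    b = Batch(a=np.ones([4, 4]))

    c = Batch.stack([a, b])    # static form: new Batch, delegates to stack_
    d = Batch()
    d.stack_([a, b], axis=0)   # in-place form: equivalent result
    assert np.allclose(c.a, d.a)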