Fix padding of inconsistent keys with Batch.stack and Batch.cat (#130)
* re-implement Batch.stack and add test cases
* add doc for Batch.stack
* reuse _create_value and refactor stack_ & cat_
* fix pep8
* fix docs
* raise exception for stacking with partial keys and axis != 0
* minor fixes

Co-authored-by: Trinkle23897 <463003665@qq.com>
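
The behavior described above, condensed from this commit's new test cases: keys missing from some batches are zero-padded instead of raising. A minimal sketch, assuming a tianshou build that includes this commit::

    import numpy as np
    from tianshou.data import Batch

    # each of `a`, `c`, `d` is missing from exactly one batch;
    # Batch.stack fills the holes with zeros (values from the new tests)
    a = Batch(a=1, b=2, c=3)
    b = Batch(a=4, b=5, d=6)
    c = Batch(c=7, b=6, d=9)
    d = Batch.stack([a, b, c])
    assert np.allclose(d.a, [1, 4, 0])  # batch `c` has no key `a`
    assert np.allclose(d.d, [0, 6, 9])  # batch `a` has no key `d`
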
parent affeec13de
commit 5599a6d1a6
@@ -166,7 +166,7 @@ def test_batch_cat_and_stack():
     assert isinstance(b12_stack.a.d.e, np.ndarray)
     assert b12_stack.a.d.e.ndim == 2
 
-    # test batch with incompatible keys
+    # test cat with incompatible keys
     b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5)))
     b2 = Batch(b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5)))
     test = Batch.cat([b1, b2])
@@ -177,6 +177,7 @@ def test_batch_cat_and_stack():
     assert torch.allclose(test.b, ans.b)
     assert np.allclose(test.common.c, ans.common.c)
 
+    # test stack with compatible keys
     b3 = Batch(a=np.zeros((3, 4)),
                b=torch.ones((2, 5)),
                c=Batch(d=[[1], [2]]))
@@ -194,6 +195,26 @@ def test_batch_cat_and_stack():
     assert b5.b.d[0] == b5_dict[0]['b']['d']
     assert b5.b.d[1] == 0.0
 
+    # test stack with incompatible keys
+    a = Batch(a=1, b=2, c=3)
+    b = Batch(a=4, b=5, d=6)
+    c = Batch(c=7, b=6, d=9)
+    d = Batch.stack([a, b, c])
+    assert np.allclose(d.a, [1, 4, 0])
+    assert np.allclose(d.b, [2, 5, 6])
+    assert np.allclose(d.c, [3, 0, 7])
+    assert np.allclose(d.d, [0, 6, 9])
+
+    b1 = Batch(a=np.random.rand(4, 4), common=Batch(c=np.random.rand(4, 5)))
+    b2 = Batch(b=torch.rand(4, 6), common=Batch(c=np.random.rand(4, 5)))
+    test = Batch.stack([b1, b2])
+    ans = Batch(a=np.stack([b1.a, np.zeros((4, 4))]),
+                b=torch.stack([torch.zeros(4, 6), b2.b]),
+                common=Batch(c=np.stack([b1.common.c, b2.common.c])))
+    assert np.allclose(test.a, ans.a)
+    assert torch.allclose(test.b, ans.b)
+    assert np.allclose(test.common.c, ans.common.c)
+
 
 def test_batch_over_batch_to_torch():
     batch = Batch(
@@ -3,7 +3,6 @@ import pprint
 import warnings
 import numpy as np
 from copy import deepcopy
-from functools import reduce
 from numbers import Number
 from typing import Any, List, Tuple, Union, Iterator, Optional
 
@@ -24,28 +23,45 @@ def _is_batch_set(data: Any) -> bool:
     return False
 
 
-def _create_value(inst: Any, size: int) -> Union[
+def _create_value(inst: Any, size: int, stack=True) -> Union[
         'Batch', np.ndarray, torch.Tensor]:
+    """
+    :param bool stack: whether to stack or to concatenate. E.g. if inst has
+        a shape of (3, 5) and size = 10, stack=True returns an np.ndarray
+        with shape (10, 3, 5), otherwise (10, 5).
+    """
+    has_shape = isinstance(inst, (np.ndarray, torch.Tensor))
+    is_scalar = \
+        isinstance(inst, Number) or \
+        issubclass(inst.__class__, np.generic) or \
+        (has_shape and not inst.shape)
+    if not stack and is_scalar:
+        # scalars are rejected here, following the behavior of numpy,
+        # which does not support concatenation of zero-dimensional
+        # arrays (scalars)
+        raise TypeError(f"cannot concatenate {inst}, which is a scalar")
+    if has_shape:
+        shape = (size, *inst.shape) if stack else (size, *inst.shape[1:])
     if isinstance(inst, np.ndarray):
         if issubclass(inst.dtype.type, (np.bool_, np.number)):
             target_type = inst.dtype.type
         else:
             target_type = np.object
-        return np.full((size, *inst.shape),
+        return np.full(shape,
                        fill_value=None if target_type == np.object else 0,
                        dtype=target_type)
     elif isinstance(inst, torch.Tensor):
-        return torch.full((size, *inst.shape),
+        return torch.full(shape,
                           fill_value=0,
                           device=inst.device,
                           dtype=inst.dtype)
     elif isinstance(inst, (dict, Batch)):
         zero_batch = Batch()
         for key, val in inst.items():
-            zero_batch.__dict__[key] = _create_value(val, size)
+            zero_batch.__dict__[key] = _create_value(val, size, stack=stack)
         return zero_batch
-    elif isinstance(inst, (np.generic, Number)):
-        return _create_value(np.asarray(inst), size)
+    elif is_scalar:
+        return _create_value(np.asarray(inst), size, stack=stack)
     else:  # fall back to np.object
         return np.array([None for _ in range(size)])
 
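
Shape semantics of the new ``stack`` flag, restated from the docstring above; a sketch against the private helper (internal API, so subject to change)::

    import numpy as np
    # _create_value is a private helper inside tianshou.data.batch
    from tianshou.data.batch import _create_value

    inst = np.zeros((3, 5))
    # stack=True reserves `size` slots shaped like the whole instance
    assert _create_value(inst, 10, stack=True).shape == (10, 3, 5)
    # stack=False reserves `size` rows shaped like one element of it
    assert _create_value(inst, 10, stack=False).shape == (10, 5)
    # concatenating scalars is rejected, mirroring numpy (raises TypeError)
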
@@ -495,10 +511,12 @@ class Batch:
         # partial keys will be padded by zeros
         # with the shape of [len, rest_shape]
         lens = [len(x) for x in batches]
+        sum_lens = [0]
+        for x in lens:
+            sum_lens.append(sum_lens[-1] + x)
         keys_map = list(map(lambda e: set(e.keys()), batches))
         keys_shared = set.intersection(*keys_map)
-        values_shared = [
-            [e[k] for e in batches] for k in keys_shared]
+        values_shared = [[e[k] for e in batches] for k in keys_shared]
         _assert_type_keys(keys_shared)
         for k, v in zip(keys_shared, values_shared):
             if all(isinstance(e, (dict, Batch)) for e in v):
@@ -513,40 +531,15 @@
         keys_partial = set.union(*keys_map) - keys_shared
         _assert_type_keys(keys_partial)
         for k in keys_partial:
-            is_dict = False
-            value = None
             for i, e in enumerate(batches):
                 val = e.get(k, None)
                 if val is not None:
-                    if isinstance(val, (dict, Batch)):
-                        is_dict = True
-                    else:  # np.ndarray or torch.Tensor
-                        value = val
-                    break
-            if is_dict:
-                self.__dict__[k] = Batch.cat(
-                    [e.get(k, Batch()) for e in batches])
-            else:
-                if isinstance(value, np.ndarray):
-                    arrs = []
-                    for i, e in enumerate(batches):
-                        shape = [lens[i]] + list(value.shape[1:])
-                        pad = np.zeros(shape, dtype=value.dtype)
-                        arrs.append(e.get(k, pad))
-                    self.__dict__[k] = np.concatenate(arrs)
-                elif isinstance(value, torch.Tensor):
-                    arrs = []
-                    for i, e in enumerate(batches):
-                        shape = [lens[i]] + list(value.shape[1:])
-                        pad = torch.zeros(shape,
-                                          dtype=value.dtype,
-                                          device=value.device)
-                        arrs.append(e.get(k, pad))
-                    self.__dict__[k] = torch.cat(arrs)
-                else:
-                    raise TypeError(
-                        f"cannot cat value with type {type(value)}, we only "
-                        "support dict, Batch, np.ndarray, and torch.Tensor")
+                    try:
+                        self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val
+                    except KeyError:
+                        self.__dict__[k] = \
+                            _create_value(val, sum_lens[-1], stack=False)
+                        self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val
 
     @staticmethod
     def cat(batches: List[Union[dict, 'Batch']]) -> 'Batch':
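
The rewritten loop above replaces the per-type padding branches with a single pattern: compute cumulative offsets, lazily allocate a zero buffer via ``_create_value(..., stack=False)`` the first time a key is seen, then slice-assign each batch's value into its offset range. A standalone numpy sketch of that pattern (hypothetical names, not the library code)::

    import numpy as np

    lens = [3, 4]                   # lengths of the batches being concatenated
    sum_lens = [0]                  # cumulative offsets -> [0, 3, 7]
    for x in lens:
        sum_lens.append(sum_lens[-1] + x)

    vals = [np.ones((3, 4)), None]  # key present only in the first batch
    buf = None
    for i, val in enumerate(vals):
        if val is None:
            continue
        if buf is None:             # first occurrence: allocate a zero buffer
            buf = np.zeros((sum_lens[-1], *val.shape[1:]), dtype=val.dtype)
        buf[sum_lens[i]:sum_lens[i + 1]] = val
    # buf[:3] now holds the first batch's data; buf[3:] stays zero-padded
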
@@ -576,12 +569,14 @@
         """Stack a list of :class:`~tianshou.data.Batch` object into current
         batch.
         """
+        if len(batches) == 0:
+            return
+        batches = [x if isinstance(x, Batch) else Batch(x) for x in batches]
         if len(self.__dict__) > 0:
             batches = [self] + list(batches)
         keys_map = list(map(lambda e: set(e.keys()), batches))
         keys_shared = set.intersection(*keys_map)
-        values_shared = [
-            [e[k] for e in batches] for k in keys_shared]
+        values_shared = [[e[k] for e in batches] for k in keys_shared]
         _assert_type_keys(keys_shared)
         for k, v in zip(keys_shared, values_shared):
             if all(isinstance(e, (dict, Batch)) for e in v):
@@ -593,7 +588,11 @@
                 if not issubclass(v.dtype.type, (np.bool_, np.number)):
                     v = v.astype(np.object)
                 self.__dict__[k] = v
-        keys_partial = reduce(set.symmetric_difference, keys_map)
+        keys_partial = set.difference(set.union(*keys_map), keys_shared)
+        if keys_partial and axis != 0:
+            raise ValueError(
+                f"Stack of Batch with non-shared keys {keys_partial} "
+                f"is only supported with axis=0, but got axis={axis}!")
         _assert_type_keys(keys_partial)
         for k in keys_partial:
             for i, e in enumerate(batches):
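
The new guard turns a previously ill-defined case into an immediate error: stacking batches whose key sets differ is only meaningful along ``axis=0``. What a caller sees (a sketch, assuming a tianshou build with this commit)::

    import numpy as np
    from tianshou.data import Batch

    b1 = Batch(a=np.zeros((4, 4)))
    b2 = Batch(b=np.zeros((4, 6)))  # key sets differ
    try:
        Batch.stack([b1, b2], axis=1)
    except ValueError as err:
        print(err)  # stacking non-shared keys is only supported with axis=0
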
@@ -609,7 +608,24 @@
     @staticmethod
     def stack(batches: List[Union[dict, 'Batch']], axis: int = 0) -> 'Batch':
         """Stack a list of :class:`~tianshou.data.Batch` object into a single
-        new batch.
+        new batch. For keys that are not shared across all batches,
+        batches that do not have these keys will be padded by zeros. E.g.
+        ::
+
+            >>> a = Batch(a=np.zeros([4, 4]), common=Batch(c=np.zeros([4, 5])))
+            >>> b = Batch(b=np.zeros([4, 6]), common=Batch(c=np.zeros([4, 5])))
+            >>> c = Batch.stack([a, b])
+            >>> c.a.shape
+            (2, 4, 4)
+            >>> c.b.shape
+            (2, 4, 6)
+            >>> c.common.c.shape
+            (2, 4, 5)
+
+        .. note::
+
+            If there are keys that are not shared across all batches, ``stack``
+            with ``axis != 0`` is undefined, and will cause an exception.
         """
         batch = Batch()
         batch.stack_(batches, axis)
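
For fully shared keys the ``axis`` argument keeps its usual meaning; a hedged usage sketch, assuming the shared-key branch delegates to ``np.stack`` along the given axis::

    import numpy as np
    from tianshou.data import Batch

    b1 = Batch(a=np.zeros((4, 5)))
    b2 = Batch(a=np.ones((4, 5)))
    stacked = Batch.stack([b1, b2], axis=1)
    # equivalent to np.stack([b1.a, b2.a], axis=1)
    assert stacked.a.shape == (4, 2, 5)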