Fix tuple support. (#88)

This commit is contained in:
Alexis DUBURCQ 2020-06-23 17:37:26 +02:00 committed by GitHub
parent ec270759ab
commit d7dd3105bc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 9 deletions

View File

@ -54,6 +54,9 @@ def test_batch_over_batch():
assert batch3.c == [6, 7, 8, 6, 7, 8] assert batch3.c == [6, 7, 8, 6, 7, 8]
assert batch3.b.a == [3, 4, 5, 3, 4, 5] assert batch3.b.a == [3, 4, 5, 3, 4, 5]
assert batch3.b.b == [4, 5, 6, 4, 5, 6] assert batch3.b.b == [4, 5, 6, 4, 5, 6]
batch4 = Batch(({'a': {'b': np.array([1.0])}},))
assert batch4.a.b.ndim == 2
assert batch4.a.b[0, 0] == 1.0
def test_batch_cat_and_stack(): def test_batch_cat_and_stack():

View File

@ -3,7 +3,7 @@ import copy
import pprint import pprint
import warnings import warnings
import numpy as np import numpy as np
from typing import Any, List, Union, Iterator, Optional from typing import Any, List, Tuple, Union, Iterator, Optional
# Disable pickle warning related to torch, since it has been removed # Disable pickle warning related to torch, since it has been removed
# on torch master branch. See Pull Request #39003 for details: # on torch master branch. See Pull Request #39003 for details:
@ -76,29 +76,28 @@ class Batch:
def __init__(self, def __init__(self,
batch_dict: Optional[ batch_dict: Optional[
Union[dict, List[dict], np.ndarray]] = None, Union[dict, Tuple[dict], List[dict], np.ndarray]] = None,
**kwargs) -> None: **kwargs) -> None:
if isinstance(batch_dict, (list, np.ndarray)) \ if isinstance(batch_dict, (list, tuple, np.ndarray)) \
and len(batch_dict) > 0 and isinstance(batch_dict[0], dict): and len(batch_dict) > 0 and isinstance(batch_dict[0], dict):
for k, v in zip(batch_dict[0].keys(), for k, v in zip(batch_dict[0].keys(),
zip(*[e.values() for e in batch_dict])): zip(*[e.values() for e in batch_dict])):
if isinstance(v, (list, np.ndarray)) \ if isinstance(v[0], dict) \
and len(v) > 0 and isinstance(v[0], dict): or (isinstance(v, (list, tuple, np.ndarray))
self.__dict__[k] = Batch.stack([Batch(v_) for v_ in v]) and len(v) > 0 and isinstance(v[0], dict)):
self.__dict__[k] = Batch(v)
elif isinstance(v[0], np.ndarray): elif isinstance(v[0], np.ndarray):
self.__dict__[k] = np.stack(v, axis=0) self.__dict__[k] = np.stack(v, axis=0)
elif isinstance(v[0], torch.Tensor): elif isinstance(v[0], torch.Tensor):
self.__dict__[k] = torch.stack(v, dim=0) self.__dict__[k] = torch.stack(v, dim=0)
elif isinstance(v[0], Batch): elif isinstance(v[0], Batch):
self.__dict__[k] = Batch.stack(v) self.__dict__[k] = Batch.stack(v)
elif isinstance(v[0], dict):
self.__dict__[k] = Batch(v)
else: else:
self.__dict__[k] = list(v) self.__dict__[k] = list(v)
elif isinstance(batch_dict, dict): elif isinstance(batch_dict, dict):
for k, v in batch_dict.items(): for k, v in batch_dict.items():
if isinstance(v, dict) \ if isinstance(v, dict) \
or (isinstance(v, (list, np.ndarray)) or (isinstance(v, (list, tuple, np.ndarray))
and len(v) > 0 and isinstance(v[0], dict)): and len(v) > 0 and isinstance(v[0], dict)):
self.__dict__[k] = Batch(v) self.__dict__[k] = Batch(v)
else: else: