Fix tuple support. (#88)

This commit is contained in:
Alexis DUBURCQ 2020-06-23 17:37:26 +02:00 committed by GitHub
parent ec270759ab
commit d7dd3105bc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 9 deletions

View File

@ -54,6 +54,9 @@ def test_batch_over_batch():
assert batch3.c == [6, 7, 8, 6, 7, 8]
assert batch3.b.a == [3, 4, 5, 3, 4, 5]
assert batch3.b.b == [4, 5, 6, 4, 5, 6]
batch4 = Batch(({'a': {'b': np.array([1.0])}},))
assert batch4.a.b.ndim == 2
assert batch4.a.b[0, 0] == 1.0
def test_batch_cat_and_stack():

View File

@ -3,7 +3,7 @@ import copy
import pprint
import warnings
import numpy as np
from typing import Any, List, Union, Iterator, Optional
from typing import Any, List, Tuple, Union, Iterator, Optional
# Disable pickle warning related to torch, since it has been removed
# on torch master branch. See Pull Request #39003 for details:
@ -76,29 +76,28 @@ class Batch:
def __init__(self,
batch_dict: Optional[
Union[dict, List[dict], np.ndarray]] = None,
Union[dict, Tuple[dict], List[dict], np.ndarray]] = None,
**kwargs) -> None:
if isinstance(batch_dict, (list, np.ndarray)) \
if isinstance(batch_dict, (list, tuple, np.ndarray)) \
and len(batch_dict) > 0 and isinstance(batch_dict[0], dict):
for k, v in zip(batch_dict[0].keys(),
zip(*[e.values() for e in batch_dict])):
if isinstance(v, (list, np.ndarray)) \
and len(v) > 0 and isinstance(v[0], dict):
self.__dict__[k] = Batch.stack([Batch(v_) for v_ in v])
if isinstance(v[0], dict) \
or (isinstance(v, (list, tuple, np.ndarray))
and len(v) > 0 and isinstance(v[0], dict)):
self.__dict__[k] = Batch(v)
elif isinstance(v[0], np.ndarray):
self.__dict__[k] = np.stack(v, axis=0)
elif isinstance(v[0], torch.Tensor):
self.__dict__[k] = torch.stack(v, dim=0)
elif isinstance(v[0], Batch):
self.__dict__[k] = Batch.stack(v)
elif isinstance(v[0], dict):
self.__dict__[k] = Batch(v)
else:
self.__dict__[k] = list(v)
elif isinstance(batch_dict, dict):
for k, v in batch_dict.items():
if isinstance(v, dict) \
or (isinstance(v, (list, np.ndarray))
or (isinstance(v, (list, tuple, np.ndarray))
and len(v) > 0 and isinstance(v[0], dict)):
self.__dict__[k] = Batch(v)
else: