Fix tuple support. (#88)
This commit is contained in:
parent
ec270759ab
commit
d7dd3105bc
@ -54,6 +54,9 @@ def test_batch_over_batch():
|
||||
assert batch3.c == [6, 7, 8, 6, 7, 8]
|
||||
assert batch3.b.a == [3, 4, 5, 3, 4, 5]
|
||||
assert batch3.b.b == [4, 5, 6, 4, 5, 6]
|
||||
batch4 = Batch(({'a': {'b': np.array([1.0])}},))
|
||||
assert batch4.a.b.ndim == 2
|
||||
assert batch4.a.b[0, 0] == 1.0
|
||||
|
||||
|
||||
def test_batch_cat_and_stack():
|
||||
|
@ -3,7 +3,7 @@ import copy
|
||||
import pprint
|
||||
import warnings
|
||||
import numpy as np
|
||||
from typing import Any, List, Union, Iterator, Optional
|
||||
from typing import Any, List, Tuple, Union, Iterator, Optional
|
||||
|
||||
# Disable pickle warning related to torch, since it has been removed
|
||||
# on torch master branch. See Pull Request #39003 for details:
|
||||
@ -76,29 +76,28 @@ class Batch:
|
||||
|
||||
def __init__(self,
|
||||
batch_dict: Optional[
|
||||
Union[dict, List[dict], np.ndarray]] = None,
|
||||
Union[dict, Tuple[dict], List[dict], np.ndarray]] = None,
|
||||
**kwargs) -> None:
|
||||
if isinstance(batch_dict, (list, np.ndarray)) \
|
||||
if isinstance(batch_dict, (list, tuple, np.ndarray)) \
|
||||
and len(batch_dict) > 0 and isinstance(batch_dict[0], dict):
|
||||
for k, v in zip(batch_dict[0].keys(),
|
||||
zip(*[e.values() for e in batch_dict])):
|
||||
if isinstance(v, (list, np.ndarray)) \
|
||||
and len(v) > 0 and isinstance(v[0], dict):
|
||||
self.__dict__[k] = Batch.stack([Batch(v_) for v_ in v])
|
||||
if isinstance(v[0], dict) \
|
||||
or (isinstance(v, (list, tuple, np.ndarray))
|
||||
and len(v) > 0 and isinstance(v[0], dict)):
|
||||
self.__dict__[k] = Batch(v)
|
||||
elif isinstance(v[0], np.ndarray):
|
||||
self.__dict__[k] = np.stack(v, axis=0)
|
||||
elif isinstance(v[0], torch.Tensor):
|
||||
self.__dict__[k] = torch.stack(v, dim=0)
|
||||
elif isinstance(v[0], Batch):
|
||||
self.__dict__[k] = Batch.stack(v)
|
||||
elif isinstance(v[0], dict):
|
||||
self.__dict__[k] = Batch(v)
|
||||
else:
|
||||
self.__dict__[k] = list(v)
|
||||
elif isinstance(batch_dict, dict):
|
||||
for k, v in batch_dict.items():
|
||||
if isinstance(v, dict) \
|
||||
or (isinstance(v, (list, np.ndarray))
|
||||
or (isinstance(v, (list, tuple, np.ndarray))
|
||||
and len(v) > 0 and isinstance(v[0], dict)):
|
||||
self.__dict__[k] = Batch(v)
|
||||
else:
|
||||
|
Loading…
x
Reference in New Issue
Block a user