Fix tuple support. (#88)
This commit is contained in:
parent
ec270759ab
commit
d7dd3105bc
@ -54,6 +54,9 @@ def test_batch_over_batch():
|
|||||||
assert batch3.c == [6, 7, 8, 6, 7, 8]
|
assert batch3.c == [6, 7, 8, 6, 7, 8]
|
||||||
assert batch3.b.a == [3, 4, 5, 3, 4, 5]
|
assert batch3.b.a == [3, 4, 5, 3, 4, 5]
|
||||||
assert batch3.b.b == [4, 5, 6, 4, 5, 6]
|
assert batch3.b.b == [4, 5, 6, 4, 5, 6]
|
||||||
|
batch4 = Batch(({'a': {'b': np.array([1.0])}},))
|
||||||
|
assert batch4.a.b.ndim == 2
|
||||||
|
assert batch4.a.b[0, 0] == 1.0
|
||||||
|
|
||||||
|
|
||||||
def test_batch_cat_and_stack():
|
def test_batch_cat_and_stack():
|
||||||
|
|||||||
@ -3,7 +3,7 @@ import copy
|
|||||||
import pprint
|
import pprint
|
||||||
import warnings
|
import warnings
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from typing import Any, List, Union, Iterator, Optional
|
from typing import Any, List, Tuple, Union, Iterator, Optional
|
||||||
|
|
||||||
# Disable pickle warning related to torch, since it has been removed
|
# Disable pickle warning related to torch, since it has been removed
|
||||||
# on torch master branch. See Pull Request #39003 for details:
|
# on torch master branch. See Pull Request #39003 for details:
|
||||||
@ -76,29 +76,28 @@ class Batch:
|
|||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
batch_dict: Optional[
|
batch_dict: Optional[
|
||||||
Union[dict, List[dict], np.ndarray]] = None,
|
Union[dict, Tuple[dict], List[dict], np.ndarray]] = None,
|
||||||
**kwargs) -> None:
|
**kwargs) -> None:
|
||||||
if isinstance(batch_dict, (list, np.ndarray)) \
|
if isinstance(batch_dict, (list, tuple, np.ndarray)) \
|
||||||
and len(batch_dict) > 0 and isinstance(batch_dict[0], dict):
|
and len(batch_dict) > 0 and isinstance(batch_dict[0], dict):
|
||||||
for k, v in zip(batch_dict[0].keys(),
|
for k, v in zip(batch_dict[0].keys(),
|
||||||
zip(*[e.values() for e in batch_dict])):
|
zip(*[e.values() for e in batch_dict])):
|
||||||
if isinstance(v, (list, np.ndarray)) \
|
if isinstance(v[0], dict) \
|
||||||
and len(v) > 0 and isinstance(v[0], dict):
|
or (isinstance(v, (list, tuple, np.ndarray))
|
||||||
self.__dict__[k] = Batch.stack([Batch(v_) for v_ in v])
|
and len(v) > 0 and isinstance(v[0], dict)):
|
||||||
|
self.__dict__[k] = Batch(v)
|
||||||
elif isinstance(v[0], np.ndarray):
|
elif isinstance(v[0], np.ndarray):
|
||||||
self.__dict__[k] = np.stack(v, axis=0)
|
self.__dict__[k] = np.stack(v, axis=0)
|
||||||
elif isinstance(v[0], torch.Tensor):
|
elif isinstance(v[0], torch.Tensor):
|
||||||
self.__dict__[k] = torch.stack(v, dim=0)
|
self.__dict__[k] = torch.stack(v, dim=0)
|
||||||
elif isinstance(v[0], Batch):
|
elif isinstance(v[0], Batch):
|
||||||
self.__dict__[k] = Batch.stack(v)
|
self.__dict__[k] = Batch.stack(v)
|
||||||
elif isinstance(v[0], dict):
|
|
||||||
self.__dict__[k] = Batch(v)
|
|
||||||
else:
|
else:
|
||||||
self.__dict__[k] = list(v)
|
self.__dict__[k] = list(v)
|
||||||
elif isinstance(batch_dict, dict):
|
elif isinstance(batch_dict, dict):
|
||||||
for k, v in batch_dict.items():
|
for k, v in batch_dict.items():
|
||||||
if isinstance(v, dict) \
|
if isinstance(v, dict) \
|
||||||
or (isinstance(v, (list, np.ndarray))
|
or (isinstance(v, (list, tuple, np.ndarray))
|
||||||
and len(v) > 0 and isinstance(v[0], dict)):
|
and len(v) > 0 and isinstance(v[0], dict)):
|
||||||
self.__dict__[k] = Batch(v)
|
self.__dict__[k] = Batch(v)
|
||||||
else:
|
else:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user