Fix support of 0-dim numpy array (#89)
* Fix support of 0-dim numpy array.
* Do not raise an exception when a Batch index is out of bounds, since raising breaks existing code.

Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
parent d7dd3105bc
commit ebc551a25e
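The diff is easier to read with numpy's 0-dim semantics in mind; a quick refresher in plain numpy, independent of this repository:

import numpy as np

# A 0-dim array wraps a single scalar: empty shape, ndim 0, and it
# supports neither len() nor integer indexing.
arr = np.array(3.0)
assert arr.ndim == 0 and arr.shape == ()

# Indexing with the empty tuple unwraps the scalar; __init__ below uses
# exactly this (batch_dict[()]) to unwrap a 0-dim object array.
assert isinstance(arr[()], np.float64)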
@@ -34,6 +34,18 @@ def test_batch():
     assert batch_item.a.c == batch_dict['c']
     assert isinstance(batch_item.a.d, torch.Tensor)
     assert batch_item.a.d == batch_dict['d']
+    batch2 = Batch(a=[{
+        'b': np.float64(1.0),
+        'c': np.zeros(1),
+        'd': Batch(e=np.array(3.0))}])
+    assert len(batch2) == 1
+    assert list(batch2[1].keys()) == ['a']
+    assert len(batch2[-2].a.d.keys()) == 0
+    assert len(batch2[-1].keys()) > 0
+    assert batch2[0][0].a.c == 0.0
+    assert isinstance(batch2[0].a.c, np.ndarray)
+    assert isinstance(batch2[0].a.b, np.float64)
+    assert isinstance(batch2[0].a.d.e, np.float64)


 def test_batch_over_batch():
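The isinstance assertions pin down how __init__ stores the values once it stacks the list of dicts (see the batch.py hunks below): stacking np.zeros(1) gives 'c' shape (1, 1), while stacking the np.float64 scalar gives 'b' shape (1,). The same effect in plain numpy:

import numpy as np

c = np.stack([np.zeros(1)])      # shape (1, 1)
b = np.stack([np.float64(1.0)])  # shape (1,); np.stack promotes the scalar
assert isinstance(c[0], np.ndarray)   # one index leaves a 1-dim array
assert isinstance(b[0], np.float64)   # one index unwraps back to a scalar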
@@ -60,15 +72,18 @@ def test_batch_over_batch():


 def test_batch_cat_and_stack():
-    b1 = Batch(a=[{'b': np.array([1.0]), 'd': Batch(e=np.array([3.0]))}])
-    b2 = Batch(a=[{'b': np.array([4.0]), 'd': Batch(e=np.array([6.0]))}])
+    b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])
+    b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}])
     b_cat_out = Batch.cat((b1, b2))
     b_cat_in = copy.deepcopy(b1)
     b_cat_in.cat_(b2)
     assert np.all(b_cat_in.a.d.e == b_cat_out.a.d.e)
-    assert b_cat_in.a.d.e.ndim == 2
+    assert np.all(b_cat_in.a.d.e == b_cat_out.a.d.e)
+    assert isinstance(b_cat_in.a.d.e, np.ndarray)
+    assert b_cat_in.a.d.e.ndim == 1
     b_stack = Batch.stack((b1, b2))
-    assert b_stack.a.d.e.ndim == 3
+    assert isinstance(b_stack.a.d.e, np.ndarray)
+    assert b_stack.a.d.e.ndim == 2


 def test_batch_over_batch_to_torch():
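Why every expected ndim dropped by one: 'e' is now a 0-dim array, so the per-element np.stack in __init__ stores it with shape (1,) instead of (1, 1). Concatenation keeps that rank while stacking adds an axis. The same arithmetic in plain numpy (cat_ uses np.concatenate and stacking uses np.stack, per the batch.py hunks below):

import numpy as np

# Each 0-dim 'e' becomes shape (1,) once __init__ stacks the list of dicts.
e1 = np.stack([np.array(3.0)])
e2 = np.stack([np.array(6.0)])
assert e1.shape == (1,)

# cat keeps the rank (b_cat_in.a.d.e.ndim == 1) ...
assert np.concatenate([e1, e2]).ndim == 1
# ... while stack adds a leading axis (b_stack.a.d.e.ndim == 2).
assert np.stack([e1, e2]).ndim == 2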
@@ -78,15 +78,23 @@ class Batch:
                  batch_dict: Optional[
                      Union[dict, Tuple[dict], List[dict], np.ndarray]] = None,
                  **kwargs) -> None:
-        if isinstance(batch_dict, (list, tuple, np.ndarray)) \
-                and len(batch_dict) > 0 and isinstance(batch_dict[0], dict):
+        def _is_batch_set(data: Any) -> bool:
+            if isinstance(data, (list, tuple)):
+                if len(data) > 0 and isinstance(data[0], dict):
+                    return True
+            elif isinstance(data, np.ndarray):
+                if isinstance(data.item(0), dict):
+                    return True
+            return False
+
+        if isinstance(batch_dict, np.ndarray) and batch_dict.ndim == 0:
+            batch_dict = batch_dict[()]
+        if _is_batch_set(batch_dict):
             for k, v in zip(batch_dict[0].keys(),
                             zip(*[e.values() for e in batch_dict])):
-                if isinstance(v[0], dict) \
-                        or (isinstance(v, (list, tuple, np.ndarray))
-                            and len(v) > 0 and isinstance(v[0], dict)):
+                if isinstance(v[0], dict) or _is_batch_set(v[0]):
                     self.__dict__[k] = Batch(v)
-                elif isinstance(v[0], np.ndarray):
+                elif isinstance(v[0], (np.generic, np.ndarray)):
                     self.__dict__[k] = np.stack(v, axis=0)
                 elif isinstance(v[0], torch.Tensor):
                     self.__dict__[k] = torch.stack(v, dim=0)
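_is_batch_set centralizes the "collection of dicts" test that was previously inlined twice. The ndarray branch uses data.item(0) rather than data[0], which also works on a 0-dim object array. A standalone copy for illustration:

import numpy as np
from typing import Any

def _is_batch_set(data: Any) -> bool:
    # Same body as the helper added above.
    if isinstance(data, (list, tuple)):
        if len(data) > 0 and isinstance(data[0], dict):
            return True
    elif isinstance(data, np.ndarray):
        if isinstance(data.item(0), dict):
            return True
    return False

assert _is_batch_set([{'a': 1}])
assert _is_batch_set(np.array([{'a': 1}]))  # 1-dim object array
assert _is_batch_set(np.array({'a': 1}))    # 0-dim object array: item(0) still works
assert not _is_batch_set([])
assert not _is_batch_set({'a': 1})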
@@ -96,9 +104,7 @@ class Batch:
                     self.__dict__[k] = list(v)
         elif isinstance(batch_dict, dict):
             for k, v in batch_dict.items():
-                if isinstance(v, dict) \
-                        or (isinstance(v, (list, tuple, np.ndarray))
-                            and len(v) > 0 and isinstance(v[0], dict)):
+                if isinstance(v, dict) or _is_batch_set(v):
                     self.__dict__[k] = Batch(v)
                 else:
                     self.__dict__[k] = v
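With the dict branch delegating to the same helper, a dict value holding a collection of dicts is wrapped recursively. A sketch of that path, assuming Batch is importable from tianshou.data and the constructor behaves as in the hunks above:

import numpy as np
from tianshou.data import Batch

# An object ndarray of dicts satisfies _is_batch_set via item(0), so the
# value is turned into a nested Batch; its 0-dim 'c' arrays get stacked.
b = Batch(a=np.array([{'c': np.array(1.0)}, {'c': np.array(2.0)}]))
assert isinstance(b.a, Batch)
assert b.a.c.shape == (2,)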
@@ -124,17 +130,31 @@ class Batch:
         """
         self.__init__(**state)

-    def __getitem__(self, index: Union[str, slice]) -> Union['Batch', dict]:
+    def __getitem__(self, index: Union[
+            str, slice, int, np.integer, np.ndarray, List[int]]) -> 'Batch':
         """Return self[index]."""
+        def _valid_bounds(length: int, index: Union[
+                slice, int, np.integer, np.ndarray, List[int]]) -> bool:
+            if isinstance(index, (int, np.integer)):
+                return -length <= index and index < length
+            elif isinstance(index, (list, np.ndarray)):
+                return _valid_bounds(length, min(index)) and \
+                    _valid_bounds(length, max(index))
+            elif isinstance(index, slice):
+                return _valid_bounds(length, index.start) and \
+                    _valid_bounds(length, index.stop - 1)
+
         if isinstance(index, str):
             return self.__getattr__(index)
-        b = Batch()
-        for k, v in self.__dict__.items():
-            if hasattr(v, '__len__'):
-                try:
-                    b.__dict__.update(**{k: v[index]})
-                except IndexError:
-                    continue
-        return b
+        else:
+            b = Batch()
+            for k, v in self.__dict__.items():
+                if isinstance(v, Batch):
+                    b.__dict__[k] = v[index]
+                elif hasattr(v, '__len__') and (not isinstance(
+                        v, (np.ndarray, torch.Tensor)) or v.ndim > 0):
+                    if _valid_bounds(len(v), index):
+                        b.__dict__[k] = v[index]
+            return b

     def __getattr__(self, key: str) -> Union['Batch', Any]:
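_valid_bounds backs the relaxed indexing the commit message mentions: instead of raising, __getitem__ simply skips keys whose value cannot honor the index, which is why the new tests can probe batch2[1] and batch2[-2] on a length-1 Batch. Copied standalone to make the semantics concrete:

import numpy as np
from typing import List, Union

def _valid_bounds(length: int, index: Union[
        slice, int, np.integer, np.ndarray, List[int]]) -> bool:
    # Same body as the helper added above.
    if isinstance(index, (int, np.integer)):
        return -length <= index and index < length
    elif isinstance(index, (list, np.ndarray)):
        return _valid_bounds(length, min(index)) and \
            _valid_bounds(length, max(index))
    elif isinstance(index, slice):
        return _valid_bounds(length, index.start) and \
            _valid_bounds(length, index.stop - 1)

assert _valid_bounds(3, -3) and not _valid_bounds(3, 3)  # ints: -len .. len-1
assert _valid_bounds(3, [0, 2]) and not _valid_bounds(3, [0, 3])
assert _valid_bounds(3, slice(0, 3))
# A slice with start=None matches no branch, so the function implicitly
# returns None (falsy) and open-ended slices are treated as out of bounds.
assert not _valid_bounds(3, slice(None, 2))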
@@ -198,7 +218,7 @@ class Batch:
             device = torch.device(device)

         for k, v in self.__dict__.items():
-            if isinstance(v, np.ndarray):
+            if isinstance(v, (np.generic, np.ndarray)):
                 v = torch.from_numpy(v).to(device)
                 if dtype is not None:
                     v = v.type(dtype)
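numpy scalars such as np.float64 pass isinstance(v, np.generic) but not isinstance(v, np.ndarray), so the old check left them untouched by to_torch. A quick check of the distinction (np.asarray is used here only as a version-portable way to hand the scalar to from_numpy; the hunk itself passes v directly):

import numpy as np
import torch

x = np.float64(1.0)
assert isinstance(x, np.generic) and not isinstance(x, np.ndarray)

# Wrapping the scalar in a 0-dim array makes the torch conversion explicit.
t = torch.from_numpy(np.asarray(x))
assert t.ndim == 0 and t.dtype == torch.float64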
@@ -236,7 +256,7 @@ class Batch:
                 continue
             if not hasattr(self, k) or self.__dict__[k] is None:
                 self.__dict__[k] = copy.deepcopy(v)
-            elif isinstance(v, np.ndarray):
+            elif isinstance(v, np.ndarray) and v.ndim > 0:
                 self.__dict__[k] = np.concatenate([self.__dict__[k], v])
             elif isinstance(v, torch.Tensor):
                 self.__dict__[k] = torch.cat([self.__dict__[k], v])
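The added v.ndim > 0 guard matters because numpy refuses to concatenate 0-dim arrays, which a Batch can now contain:

import numpy as np

try:
    np.concatenate([np.array(3.0), np.array(6.0)])
except ValueError as err:
    print(err)  # zero-dimensional arrays cannot be concatenated

# Anything 1-dim or higher concatenates fine, which is the case cat_ keeps.
assert np.concatenate([np.array([3.0]), np.array([6.0])]).shape == (2,)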
@@ -274,7 +294,11 @@ class Batch:

     def __len__(self) -> int:
         """Return len(self)."""
-        r = [len(v) for k, v in self.__dict__.items() if hasattr(v, '__len__')]
+        r = []
+        for v in self.__dict__.values():
+            if hasattr(v, '__len__') and (not isinstance(
+                    v, (np.ndarray, torch.Tensor)) or v.ndim > 0):
+                r.append(len(v))
         return max(r) if len(r) > 0 else 0

     def split(self, size: Optional[int] = None,
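hasattr(v, '__len__') alone was not enough: 0-dim arrays and tensors inherit __len__ from their class, but calling it raises TypeError. The new guard, spelled out standalone:

import numpy as np
import torch

def _counts_toward_len(v) -> bool:
    # Mirror of the condition used in __len__ above.
    return hasattr(v, '__len__') and (not isinstance(
        v, (np.ndarray, torch.Tensor)) or v.ndim > 0)

assert hasattr(np.array(3.0), '__len__')       # hasattr says yes...
assert not _counts_toward_len(np.array(3.0))   # ...but len() would raise
assert not _counts_toward_len(torch.tensor(3.0))
assert _counts_toward_len(np.array([3.0]))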