Raise exception for Batch __getitem__. (#119)
* Raise exception for Batch __getitem__.
* Try fixing access to reserved key.
* Simpler patch.
* Add unit test to check indexing empty Batch raises an exception.

Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
parent 7f9a1f1328
commit aa3c453f42
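
In effect, indexing an empty Batch now raises an IndexError instead of silently returning another empty Batch. A minimal sketch of the new behaviour, assuming only the public tianshou.data.Batch API that appears in the diff below:

    import numpy as np
    import pytest
    from tianshou.data import Batch

    # Indexing an empty Batch is now an error (this is what the new test checks).
    with pytest.raises(IndexError):
        Batch()[0]

    # Non-empty batches keep their previous indexing behaviour.
    b = Batch(a=np.array([1.0, 2.0, 3.0]))
    assert b[0].a == 1.0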
@@ -277,7 +277,7 @@ def test_batch_empty():
     assert b0.shape == []


-def test_batch_numpy_compatibility():
+def test_batch_standard_compatibility():
     batch = Batch(a=np.array([[1.0, 2.0], [3.0, 4.0]]),
                   b=Batch(),
                   c=np.array([5.0, 6.0]))
@@ -288,6 +288,8 @@ def test_batch_numpy_compatibility():
         len(batch_mean)
     assert np.all(batch_mean.a == np.mean(batch.a, axis=0))
     assert batch_mean.c == np.mean(batch.c, axis=0)
+    with pytest.raises(IndexError):
+        Batch()[0]


 if __name__ == '__main__':
@@ -297,7 +299,7 @@ if __name__ == '__main__':
     test_utils_to_torch()
     test_batch_pickle()
     test_batch_from_to_numpy_without_copy()
-    test_batch_numpy_compatibility()
+    test_batch_standard_compatibility()
     test_batch_cat_and_stack()
     test_batch_copy()
     test_batch_empty()
@@ -24,25 +24,6 @@ def _is_batch_set(data: Any) -> bool:
     return False


-def _valid_bounds(length: int, index: Union[
-        slice, int, np.integer, np.ndarray, List[int]]) -> bool:
-    if isinstance(index, (int, np.integer)):
-        return -length <= index and index < length
-    elif isinstance(index, (list, np.ndarray)):
-        return _valid_bounds(length, np.min(index)) and \
-            _valid_bounds(length, np.max(index))
-    elif isinstance(index, slice):
-        if index.start is not None:
-            start_valid = _valid_bounds(length, index.start)
-        else:
-            start_valid = True
-        if index.stop is not None:
-            stop_valid = _valid_bounds(length, index.stop - 1)
-        else:
-            stop_valid = True
-        return start_valid and stop_valid
-
-
 def _create_value(inst: Any, size: int) -> Union[
         'Batch', np.ndarray, torch.Tensor]:
     if isinstance(inst, np.ndarray):
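
The removed _valid_bounds helper pre-validated indices before they were applied. Presumably the "simpler patch" can drop it because numpy indexing already raises IndexError for out-of-range positions, which the rest of this commit now handles explicitly; a quick illustration of that plain-numpy behaviour (nothing tianshou-specific):

    import numpy as np

    a = np.arange(5)
    a[-5]          # fine: -length <= index < length
    try:
        a[5]       # out of bounds
    except IndexError as e:
        print(e)   # "index 5 is out of bounds for axis 0 with size 5"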
@@ -333,13 +314,17 @@ class Batch:
         """Return self[index]."""
         if isinstance(index, str):
             return self.__dict__[index]
-        b = Batch()
-        for k, v in self.items():
-            if isinstance(v, Batch) and len(v.__dict__) == 0:
-                b.__dict__[k] = Batch()
-            else:
-                b.__dict__[k] = v[index]
-        return b
+        batch_items = self.items()
+        if len(batch_items) > 0:
+            b = Batch()
+            for k, v in batch_items:
+                if isinstance(v, Batch) and len(v.__dict__) == 0:
+                    b.__dict__[k] = Batch()
+                else:
+                    b.__dict__[k] = v[index]
+            return b
+        else:
+            raise IndexError("Cannot access item from empty Batch object.")

     def __setitem__(
         self,
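
To make the new control flow easier to follow outside the diff, here is a self-contained sketch of the same guard pattern using a hypothetical MiniBatch class; it is illustrative only, not the tianshou implementation:

    class MiniBatch:
        # Minimal stand-in for Batch: keys live in __dict__, items() exposes them.
        def __init__(self, **kwargs):
            self.__dict__.update(kwargs)

        def items(self):
            return self.__dict__.items()

        def __getitem__(self, index):
            # String keys behave like dictionary access.
            if isinstance(index, str):
                return self.__dict__[index]
            items = list(self.items())
            # Numeric or slice indexing on an empty batch is rejected up front.
            if len(items) == 0:
                raise IndexError("Cannot access item from empty Batch object.")
            result = MiniBatch()
            for k, v in items:
                if isinstance(v, MiniBatch) and len(v.__dict__) == 0:
                    # Nested empty batches pass through untouched.
                    result.__dict__[k] = MiniBatch()
                else:
                    result.__dict__[k] = v[index]
            return result

    mb = MiniBatch(a=[10, 20, 30], b=MiniBatch())
    assert mb[1].a == 20   # per-key indexing, empty nested batch preserved
    try:
        MiniBatch()[0]     # empty batch -> IndexError, as in the new unit test
    except IndexError:
        pass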
@@ -452,14 +437,6 @@ class Batch:
         """Return self[k] if k in self else d. d defaults to None."""
         return self.__dict__.get(k, d)

-    def __iter__(self):
-        try:
-            length = len(self)
-        except Exception:
-            length = 0
-        for i in range(length):
-            yield self[i]
-
     def to_numpy(self) -> None:
         """Change all torch.Tensor to numpy.ndarray. This is an in-place
         operation.
@@ -272,8 +272,10 @@ class ReplayBuffer:
                     stack = np.stack(stack, axis=indice.ndim)
                 else:
                     stack = val[indice]
-            except TypeError:
+            except IndexError as e:
                 stack = Batch()
+                if not isinstance(val, Batch) or len(val.__dict__) > 0:
+                    raise e
         self.done[last_index] = last_done
         return stack

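
The narrowed except clause above only swallows an IndexError when the stored value is an empty Batch (a key with no data behind it); every other indexing failure is re-raised. A standalone sketch of that pattern, where the helper name _index_or_empty is illustrative and not part of the ReplayBuffer API:

    from tianshou.data import Batch

    def _index_or_empty(val, indice):
        # Mirror of the except-clause logic: tolerate empty Batch slots,
        # re-raise every other IndexError.
        try:
            return val[indice]
        except IndexError:
            if isinstance(val, Batch) and len(val.__dict__) == 0:
                return Batch()
            raise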