Raise exception for Batch __getitem__. (#119)

* Raise exception for Batch __getitem__.

* Try fixing access to reserved key.

* Simpler patch.

* Add unit test to check indexing empty Batch raises an exception.

Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
This commit is contained in:
Alexis DUBURCQ 2020-07-08 16:29:37 +02:00 committed by GitHub
parent 7f9a1f1328
commit aa3c453f42
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 18 additions and 37 deletions

View File

@ -277,7 +277,7 @@ def test_batch_empty():
assert b0.shape == [] assert b0.shape == []
def test_batch_numpy_compatibility(): def test_batch_standard_compatibility():
batch = Batch(a=np.array([[1.0, 2.0], [3.0, 4.0]]), batch = Batch(a=np.array([[1.0, 2.0], [3.0, 4.0]]),
b=Batch(), b=Batch(),
c=np.array([5.0, 6.0])) c=np.array([5.0, 6.0]))
@ -288,6 +288,8 @@ def test_batch_numpy_compatibility():
len(batch_mean) len(batch_mean)
assert np.all(batch_mean.a == np.mean(batch.a, axis=0)) assert np.all(batch_mean.a == np.mean(batch.a, axis=0))
assert batch_mean.c == np.mean(batch.c, axis=0) assert batch_mean.c == np.mean(batch.c, axis=0)
with pytest.raises(IndexError):
Batch()[0]
if __name__ == '__main__': if __name__ == '__main__':
@ -297,7 +299,7 @@ if __name__ == '__main__':
test_utils_to_torch() test_utils_to_torch()
test_batch_pickle() test_batch_pickle()
test_batch_from_to_numpy_without_copy() test_batch_from_to_numpy_without_copy()
test_batch_numpy_compatibility() test_batch_standard_compatibility()
test_batch_cat_and_stack() test_batch_cat_and_stack()
test_batch_copy() test_batch_copy()
test_batch_empty() test_batch_empty()

View File

@ -24,25 +24,6 @@ def _is_batch_set(data: Any) -> bool:
return False return False
def _valid_bounds(length: int, index: Union[
slice, int, np.integer, np.ndarray, List[int]]) -> bool:
if isinstance(index, (int, np.integer)):
return -length <= index and index < length
elif isinstance(index, (list, np.ndarray)):
return _valid_bounds(length, np.min(index)) and \
_valid_bounds(length, np.max(index))
elif isinstance(index, slice):
if index.start is not None:
start_valid = _valid_bounds(length, index.start)
else:
start_valid = True
if index.stop is not None:
stop_valid = _valid_bounds(length, index.stop - 1)
else:
stop_valid = True
return start_valid and stop_valid
def _create_value(inst: Any, size: int) -> Union[ def _create_value(inst: Any, size: int) -> Union[
'Batch', np.ndarray, torch.Tensor]: 'Batch', np.ndarray, torch.Tensor]:
if isinstance(inst, np.ndarray): if isinstance(inst, np.ndarray):
@ -333,13 +314,17 @@ class Batch:
"""Return self[index].""" """Return self[index]."""
if isinstance(index, str): if isinstance(index, str):
return self.__dict__[index] return self.__dict__[index]
b = Batch() batch_items = self.items()
for k, v in self.items(): if len(batch_items) > 0:
if isinstance(v, Batch) and len(v.__dict__) == 0: b = Batch()
b.__dict__[k] = Batch() for k, v in batch_items:
else: if isinstance(v, Batch) and len(v.__dict__) == 0:
b.__dict__[k] = v[index] b.__dict__[k] = Batch()
return b else:
b.__dict__[k] = v[index]
return b
else:
raise IndexError("Cannot access item from empty Batch object.")
def __setitem__( def __setitem__(
self, self,
@ -452,14 +437,6 @@ class Batch:
"""Return self[k] if k in self else d. d defaults to None.""" """Return self[k] if k in self else d. d defaults to None."""
return self.__dict__.get(k, d) return self.__dict__.get(k, d)
def __iter__(self):
try:
length = len(self)
except Exception:
length = 0
for i in range(length):
yield self[i]
def to_numpy(self) -> None: def to_numpy(self) -> None:
"""Change all torch.Tensor to numpy.ndarray. This is an in-place """Change all torch.Tensor to numpy.ndarray. This is an in-place
operation. operation.

View File

@ -272,8 +272,10 @@ class ReplayBuffer:
stack = np.stack(stack, axis=indice.ndim) stack = np.stack(stack, axis=indice.ndim)
else: else:
stack = val[indice] stack = val[indice]
except TypeError: except IndexError as e:
stack = Batch() stack = Batch()
if not isinstance(val, Batch) or len(val.__dict__) > 0:
raise e
self.done[last_index] = last_done self.done[last_index] = last_done
return stack return stack