Raise exception for Batch __getitem__. (#119)

* Raise exception for Batch __getitem__.

* Try fixing access to reserved key.

* Simpler patch.

* Add unit test to check indexing empty Batch raises an exception.

Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
This commit is contained in:
Alexis DUBURCQ 2020-07-08 16:29:37 +02:00 committed by GitHub
parent 7f9a1f1328
commit aa3c453f42
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 18 additions and 37 deletions

View File

@ -277,7 +277,7 @@ def test_batch_empty():
assert b0.shape == []
def test_batch_numpy_compatibility():
def test_batch_standard_compatibility():
batch = Batch(a=np.array([[1.0, 2.0], [3.0, 4.0]]),
b=Batch(),
c=np.array([5.0, 6.0]))
@ -288,6 +288,8 @@ def test_batch_numpy_compatibility():
len(batch_mean)
assert np.all(batch_mean.a == np.mean(batch.a, axis=0))
assert batch_mean.c == np.mean(batch.c, axis=0)
with pytest.raises(IndexError):
Batch()[0]
if __name__ == '__main__':
@ -297,7 +299,7 @@ if __name__ == '__main__':
test_utils_to_torch()
test_batch_pickle()
test_batch_from_to_numpy_without_copy()
test_batch_numpy_compatibility()
test_batch_standard_compatibility()
test_batch_cat_and_stack()
test_batch_copy()
test_batch_empty()

View File

@ -24,25 +24,6 @@ def _is_batch_set(data: Any) -> bool:
return False
def _valid_bounds(length: int, index: Union[
slice, int, np.integer, np.ndarray, List[int]]) -> bool:
if isinstance(index, (int, np.integer)):
return -length <= index and index < length
elif isinstance(index, (list, np.ndarray)):
return _valid_bounds(length, np.min(index)) and \
_valid_bounds(length, np.max(index))
elif isinstance(index, slice):
if index.start is not None:
start_valid = _valid_bounds(length, index.start)
else:
start_valid = True
if index.stop is not None:
stop_valid = _valid_bounds(length, index.stop - 1)
else:
stop_valid = True
return start_valid and stop_valid
def _create_value(inst: Any, size: int) -> Union[
'Batch', np.ndarray, torch.Tensor]:
if isinstance(inst, np.ndarray):
@ -333,13 +314,17 @@ class Batch:
"""Return self[index]."""
if isinstance(index, str):
return self.__dict__[index]
batch_items = self.items()
if len(batch_items) > 0:
b = Batch()
for k, v in self.items():
for k, v in batch_items:
if isinstance(v, Batch) and len(v.__dict__) == 0:
b.__dict__[k] = Batch()
else:
b.__dict__[k] = v[index]
return b
else:
raise IndexError("Cannot access item from empty Batch object.")
def __setitem__(
self,
@ -452,14 +437,6 @@ class Batch:
"""Return self[k] if k in self else d. d defaults to None."""
return self.__dict__.get(k, d)
def __iter__(self):
try:
length = len(self)
except Exception:
length = 0
for i in range(length):
yield self[i]
def to_numpy(self) -> None:
"""Change all torch.Tensor to numpy.ndarray. This is an in-place
operation.

View File

@ -272,8 +272,10 @@ class ReplayBuffer:
stack = np.stack(stack, axis=indice.ndim)
else:
stack = val[indice]
except TypeError:
except IndexError as e:
stack = Batch()
if not isinstance(val, Batch) or len(val.__dict__) > 0:
raise e
self.done[last_index] = last_done
return stack