Raise exception for Batch __getitem__. (#119)
* Raise exception for Batch __getitem__. * Try fixing access to reserved key. * Simpler patch. * Add unit test to check indexing empty Batch raises an exception. Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
This commit is contained in:
parent
7f9a1f1328
commit
aa3c453f42
@ -277,7 +277,7 @@ def test_batch_empty():
|
||||
assert b0.shape == []
|
||||
|
||||
|
||||
def test_batch_numpy_compatibility():
|
||||
def test_batch_standard_compatibility():
|
||||
batch = Batch(a=np.array([[1.0, 2.0], [3.0, 4.0]]),
|
||||
b=Batch(),
|
||||
c=np.array([5.0, 6.0]))
|
||||
@ -288,6 +288,8 @@ def test_batch_numpy_compatibility():
|
||||
len(batch_mean)
|
||||
assert np.all(batch_mean.a == np.mean(batch.a, axis=0))
|
||||
assert batch_mean.c == np.mean(batch.c, axis=0)
|
||||
with pytest.raises(IndexError):
|
||||
Batch()[0]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
@ -297,7 +299,7 @@ if __name__ == '__main__':
|
||||
test_utils_to_torch()
|
||||
test_batch_pickle()
|
||||
test_batch_from_to_numpy_without_copy()
|
||||
test_batch_numpy_compatibility()
|
||||
test_batch_standard_compatibility()
|
||||
test_batch_cat_and_stack()
|
||||
test_batch_copy()
|
||||
test_batch_empty()
|
||||
|
@ -24,25 +24,6 @@ def _is_batch_set(data: Any) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _valid_bounds(length: int, index: Union[
|
||||
slice, int, np.integer, np.ndarray, List[int]]) -> bool:
|
||||
if isinstance(index, (int, np.integer)):
|
||||
return -length <= index and index < length
|
||||
elif isinstance(index, (list, np.ndarray)):
|
||||
return _valid_bounds(length, np.min(index)) and \
|
||||
_valid_bounds(length, np.max(index))
|
||||
elif isinstance(index, slice):
|
||||
if index.start is not None:
|
||||
start_valid = _valid_bounds(length, index.start)
|
||||
else:
|
||||
start_valid = True
|
||||
if index.stop is not None:
|
||||
stop_valid = _valid_bounds(length, index.stop - 1)
|
||||
else:
|
||||
stop_valid = True
|
||||
return start_valid and stop_valid
|
||||
|
||||
|
||||
def _create_value(inst: Any, size: int) -> Union[
|
||||
'Batch', np.ndarray, torch.Tensor]:
|
||||
if isinstance(inst, np.ndarray):
|
||||
@ -333,13 +314,17 @@ class Batch:
|
||||
"""Return self[index]."""
|
||||
if isinstance(index, str):
|
||||
return self.__dict__[index]
|
||||
batch_items = self.items()
|
||||
if len(batch_items) > 0:
|
||||
b = Batch()
|
||||
for k, v in self.items():
|
||||
for k, v in batch_items:
|
||||
if isinstance(v, Batch) and len(v.__dict__) == 0:
|
||||
b.__dict__[k] = Batch()
|
||||
else:
|
||||
b.__dict__[k] = v[index]
|
||||
return b
|
||||
else:
|
||||
raise IndexError("Cannot access item from empty Batch object.")
|
||||
|
||||
def __setitem__(
|
||||
self,
|
||||
@ -452,14 +437,6 @@ class Batch:
|
||||
"""Return self[k] if k in self else d. d defaults to None."""
|
||||
return self.__dict__.get(k, d)
|
||||
|
||||
def __iter__(self):
|
||||
try:
|
||||
length = len(self)
|
||||
except Exception:
|
||||
length = 0
|
||||
for i in range(length):
|
||||
yield self[i]
|
||||
|
||||
def to_numpy(self) -> None:
|
||||
"""Change all torch.Tensor to numpy.ndarray. This is an in-place
|
||||
operation.
|
||||
|
@ -272,8 +272,10 @@ class ReplayBuffer:
|
||||
stack = np.stack(stack, axis=indice.ndim)
|
||||
else:
|
||||
stack = val[indice]
|
||||
except TypeError:
|
||||
except IndexError as e:
|
||||
stack = Batch()
|
||||
if not isinstance(val, Batch) or len(val.__dict__) > 0:
|
||||
raise e
|
||||
self.done[last_index] = last_done
|
||||
return stack
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user