From aa3c453f425df22680f9bb7a3f1c8dc025a3f1af Mon Sep 17 00:00:00 2001 From: Alexis DUBURCQ Date: Wed, 8 Jul 2020 16:29:37 +0200 Subject: [PATCH] Raise exception for Batch __getitem__. (#119) * Raise exception for Batch __getitem__. * Try fixing access to reserved key. * Simpler patch. * Add unit test to check indexing empty Batch raises an exception. Co-authored-by: Alexis Duburcq --- test/base/test_batch.py | 6 ++++-- tianshou/data/batch.py | 45 ++++++++++------------------------------- tianshou/data/buffer.py | 4 +++- 3 files changed, 18 insertions(+), 37 deletions(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 1303fb0..e2390de 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -277,7 +277,7 @@ def test_batch_empty(): assert b0.shape == [] -def test_batch_numpy_compatibility(): +def test_batch_standard_compatibility(): batch = Batch(a=np.array([[1.0, 2.0], [3.0, 4.0]]), b=Batch(), c=np.array([5.0, 6.0])) @@ -288,6 +288,8 @@ def test_batch_numpy_compatibility(): len(batch_mean) assert np.all(batch_mean.a == np.mean(batch.a, axis=0)) assert batch_mean.c == np.mean(batch.c, axis=0) + with pytest.raises(IndexError): + Batch()[0] if __name__ == '__main__': @@ -297,7 +299,7 @@ if __name__ == '__main__': test_utils_to_torch() test_batch_pickle() test_batch_from_to_numpy_without_copy() - test_batch_numpy_compatibility() + test_batch_standard_compatibility() test_batch_cat_and_stack() test_batch_copy() test_batch_empty() diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 4647e7f..e14b98f 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -24,25 +24,6 @@ def _is_batch_set(data: Any) -> bool: return False -def _valid_bounds(length: int, index: Union[ - slice, int, np.integer, np.ndarray, List[int]]) -> bool: - if isinstance(index, (int, np.integer)): - return -length <= index and index < length - elif isinstance(index, (list, np.ndarray)): - return _valid_bounds(length, np.min(index)) and \ - _valid_bounds(length, np.max(index)) - elif isinstance(index, slice): - if index.start is not None: - start_valid = _valid_bounds(length, index.start) - else: - start_valid = True - if index.stop is not None: - stop_valid = _valid_bounds(length, index.stop - 1) - else: - stop_valid = True - return start_valid and stop_valid - - def _create_value(inst: Any, size: int) -> Union[ 'Batch', np.ndarray, torch.Tensor]: if isinstance(inst, np.ndarray): @@ -333,13 +314,17 @@ class Batch: """Return self[index].""" if isinstance(index, str): return self.__dict__[index] - b = Batch() - for k, v in self.items(): - if isinstance(v, Batch) and len(v.__dict__) == 0: - b.__dict__[k] = Batch() - else: - b.__dict__[k] = v[index] - return b + batch_items = self.items() + if len(batch_items) > 0: + b = Batch() + for k, v in batch_items: + if isinstance(v, Batch) and len(v.__dict__) == 0: + b.__dict__[k] = Batch() + else: + b.__dict__[k] = v[index] + return b + else: + raise IndexError("Cannot access item from empty Batch object.") def __setitem__( self, @@ -452,14 +437,6 @@ class Batch: """Return self[k] if k in self else d. d defaults to None.""" return self.__dict__.get(k, d) - def __iter__(self): - try: - length = len(self) - except Exception: - length = 0 - for i in range(length): - yield self[i] - def to_numpy(self) -> None: """Change all torch.Tensor to numpy.ndarray. This is an in-place operation. diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 53c158e..5c0bcad 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -272,8 +272,10 @@ class ReplayBuffer: stack = np.stack(stack, axis=indice.ndim) else: stack = val[indice] - except TypeError: + except IndexError as e: stack = Batch() + if not isinstance(val, Batch) or len(val.__dict__) > 0: + raise e self.done[last_index] = last_done return stack