diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 6fa0d60..542564a 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -433,6 +433,14 @@ class Batch: """Return self[k] if k in self else d. d defaults to None.""" return self.__dict__.get(k, d) + def __iter__(self): + try: + length = len(self) + except Exception: + length = 0 + for i in range(length): + yield self[i] + def to_numpy(self) -> None: """Change all torch.Tensor to numpy.ndarray. This is an in-place operation.