From d2b2fa87c073e4095d89193372d0b5ffd920d7a3 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Fri, 29 May 2020 08:03:37 +0800 Subject: [PATCH] fix #56 --- test/base/test_batch.py | 3 ++- tianshou/data/batch.py | 4 ++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index ab176b1..7929387 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -22,9 +22,10 @@ def test_batch(): def test_batch_over_batch(): batch = Batch(a=[3, 4, 5], b=[4, 5, 6]) - batch2 = Batch(b=batch, c=[6, 7, 8]) + batch2 = Batch(c=[6, 7, 8], b=batch) batch2.b.b[-1] = 0 print(batch2) + assert batch2.values()[-1] == batch2.c assert batch2[-1].b.b == 0 diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 1cbf772..7da2584 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -131,6 +131,10 @@ class Batch(object): return sorted([ i for i in self.__dict__ if i[0] != '_'] + list(self._meta)) + def values(self) -> List[Any]: + """Return self.values().""" + return [self[k] for k in self.keys()] + def get(self, k: str, d: Optional[Any] = None) -> Union['Batch', Any]: """Return self[k] if k in self else d. d defaults to None.""" if k in self.__dict__ or k in self._meta: