diff --git a/test/base/test_batch.py b/test/base/test_batch.py index ab886dd..0d30a0a 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -11,7 +11,7 @@ def test_batch(): assert batch.obs == batch["obs"] batch.obs = [1] assert batch.obs == [1] - batch.append(batch) + batch.cat(batch) assert batch.obs == [1, 1] assert batch.np.shape == (6, 4) assert batch[0].obs == batch[1].obs @@ -34,13 +34,13 @@ def test_batch_over_batch(): print(batch2) assert batch2.values()[-1] == batch2.c assert batch2[-1].b.b == 0 - batch2.append(Batch(c=[6, 7, 8], b=batch)) + batch2.cat(Batch(c=[6, 7, 8], b=batch)) assert batch2.c == [6, 7, 8, 6, 7, 8] assert batch2.b.a == [3, 4, 5, 3, 4, 5] assert batch2.b.b == [4, 5, 0, 4, 5, 0] d = {'a': [3, 4, 5], 'b': [4, 5, 6]} batch3 = Batch(c=[6, 7, 8], b=d) - batch3.append(Batch(c=[6, 7, 8], b=d)) + batch3.cat(Batch(c=[6, 7, 8], b=d)) assert batch3.c == [6, 7, 8, 6, 7, 8] assert batch3.b.a == [3, 4, 5, 3, 4, 5] assert batch3.b.b == [4, 5, 6, 4, 5, 6] diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 61d23d9..f6c1df6 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -1,6 +1,6 @@ import torch -import warnings import pprint +import warnings import numpy as np from typing import Any, List, Union, Iterator, Optional @@ -218,8 +218,16 @@ class Batch: v.to_torch(dtype, device) def append(self, batch: 'Batch') -> None: - """Append a :class:`~tianshou.data.Batch` object to current batch.""" - assert isinstance(batch, Batch), 'Only append Batch is allowed!' + warnings.warn('Method append will be removed soon, please use ' + ':meth:`~tianshou.data.Batch.cat`') + return self.cat(batch) + + def cat(self, batch: 'Batch') -> None: + """Concatenate a :class:`~tianshou.data.Batch` object to current + batch. + """ + assert isinstance(batch, Batch), \ + 'Only Batch is allowed to be concatenated!' for k, v in batch.__dict__.items(): if k == '_meta': self._meta.update(batch._meta) @@ -235,9 +243,9 @@ class Batch: elif isinstance(v, list): self.__dict__[k] += v elif isinstance(v, Batch): - self.__dict__[k].append(v) + self.__dict__[k].cat(v) else: - s = f'No support for append with type \ + s = f'No support for method "cat" with type \ {type(v)} in class Batch.' raise TypeError(s) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 592be5f..92756d5 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -413,7 +413,7 @@ class PrioritizedReplayBuffer(ReplayBuffer): impt_weight=1 / np.power( self._size * (batch.weight / self._weight_sum), self._beta)) - batch.append(impt_weight) + batch.cat(impt_weight) self._check_weight_sum() return batch, indice diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index ed75b4e..4ff0997 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -406,7 +406,7 @@ class Collector(object): if batch_size and cur_batch or batch_size <= 0: batch, indice = b.sample(cur_batch) batch = self.process_fn(batch, b, indice) - batch_data.append(batch) + batch_data.cat(batch) else: batch_data, indice = self.buffer.sample(batch_size) batch_data = self.process_fn(batch_data, self.buffer, indice)