change batch.append to batch.cat

This commit is contained in:
Trinkle23897 2020-06-20 22:23:12 +08:00
parent aff0f9aee0
commit a655334d00
4 changed files with 18 additions and 10 deletions

View File

@ -11,7 +11,7 @@ def test_batch():
assert batch.obs == batch["obs"]
batch.obs = [1]
assert batch.obs == [1]
batch.append(batch)
batch.cat(batch)
assert batch.obs == [1, 1]
assert batch.np.shape == (6, 4)
assert batch[0].obs == batch[1].obs
@ -34,13 +34,13 @@ def test_batch_over_batch():
print(batch2)
assert batch2.values()[-1] == batch2.c
assert batch2[-1].b.b == 0
batch2.append(Batch(c=[6, 7, 8], b=batch))
batch2.cat(Batch(c=[6, 7, 8], b=batch))
assert batch2.c == [6, 7, 8, 6, 7, 8]
assert batch2.b.a == [3, 4, 5, 3, 4, 5]
assert batch2.b.b == [4, 5, 0, 4, 5, 0]
d = {'a': [3, 4, 5], 'b': [4, 5, 6]}
batch3 = Batch(c=[6, 7, 8], b=d)
batch3.append(Batch(c=[6, 7, 8], b=d))
batch3.cat(Batch(c=[6, 7, 8], b=d))
assert batch3.c == [6, 7, 8, 6, 7, 8]
assert batch3.b.a == [3, 4, 5, 3, 4, 5]
assert batch3.b.b == [4, 5, 6, 4, 5, 6]

View File

@ -1,6 +1,6 @@
import torch
import warnings
import pprint
import warnings
import numpy as np
from typing import Any, List, Union, Iterator, Optional
@ -218,8 +218,16 @@ class Batch:
v.to_torch(dtype, device)
def append(self, batch: 'Batch') -> None:
"""Append a :class:`~tianshou.data.Batch` object to current batch."""
assert isinstance(batch, Batch), 'Only append Batch is allowed!'
warnings.warn('Method append will be removed soon, please use '
':meth:`~tianshou.data.Batch.cat`')
return self.cat(batch)
def cat(self, batch: 'Batch') -> None:
"""Concatenate a :class:`~tianshou.data.Batch` object to current
batch.
"""
assert isinstance(batch, Batch), \
'Only Batch is allowed to be concatenated!'
for k, v in batch.__dict__.items():
if k == '_meta':
self._meta.update(batch._meta)
@ -235,9 +243,9 @@ class Batch:
elif isinstance(v, list):
self.__dict__[k] += v
elif isinstance(v, Batch):
self.__dict__[k].append(v)
self.__dict__[k].cat(v)
else:
s = f'No support for append with type \
s = f'No support for method "cat" with type \
{type(v)} in class Batch.'
raise TypeError(s)

View File

@ -413,7 +413,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
impt_weight=1 / np.power(
self._size * (batch.weight / self._weight_sum),
self._beta))
batch.append(impt_weight)
batch.cat(impt_weight)
self._check_weight_sum()
return batch, indice

View File

@ -406,7 +406,7 @@ class Collector(object):
if batch_size and cur_batch or batch_size <= 0:
batch, indice = b.sample(cur_batch)
batch = self.process_fn(batch, b, indice)
batch_data.append(batch)
batch_data.cat(batch)
else:
batch_data, indice = self.buffer.sample(batch_size)
batch_data = self.process_fn(batch_data, self.buffer, indice)