change batch.append to batch.cat

This commit is contained in:
Trinkle23897 2020-06-20 22:23:12 +08:00
parent aff0f9aee0
commit a655334d00
4 changed files with 18 additions and 10 deletions

View File

@ -11,7 +11,7 @@ def test_batch():
assert batch.obs == batch["obs"] assert batch.obs == batch["obs"]
batch.obs = [1] batch.obs = [1]
assert batch.obs == [1] assert batch.obs == [1]
batch.append(batch) batch.cat(batch)
assert batch.obs == [1, 1] assert batch.obs == [1, 1]
assert batch.np.shape == (6, 4) assert batch.np.shape == (6, 4)
assert batch[0].obs == batch[1].obs assert batch[0].obs == batch[1].obs
@ -34,13 +34,13 @@ def test_batch_over_batch():
print(batch2) print(batch2)
assert batch2.values()[-1] == batch2.c assert batch2.values()[-1] == batch2.c
assert batch2[-1].b.b == 0 assert batch2[-1].b.b == 0
batch2.append(Batch(c=[6, 7, 8], b=batch)) batch2.cat(Batch(c=[6, 7, 8], b=batch))
assert batch2.c == [6, 7, 8, 6, 7, 8] assert batch2.c == [6, 7, 8, 6, 7, 8]
assert batch2.b.a == [3, 4, 5, 3, 4, 5] assert batch2.b.a == [3, 4, 5, 3, 4, 5]
assert batch2.b.b == [4, 5, 0, 4, 5, 0] assert batch2.b.b == [4, 5, 0, 4, 5, 0]
d = {'a': [3, 4, 5], 'b': [4, 5, 6]} d = {'a': [3, 4, 5], 'b': [4, 5, 6]}
batch3 = Batch(c=[6, 7, 8], b=d) batch3 = Batch(c=[6, 7, 8], b=d)
batch3.append(Batch(c=[6, 7, 8], b=d)) batch3.cat(Batch(c=[6, 7, 8], b=d))
assert batch3.c == [6, 7, 8, 6, 7, 8] assert batch3.c == [6, 7, 8, 6, 7, 8]
assert batch3.b.a == [3, 4, 5, 3, 4, 5] assert batch3.b.a == [3, 4, 5, 3, 4, 5]
assert batch3.b.b == [4, 5, 6, 4, 5, 6] assert batch3.b.b == [4, 5, 6, 4, 5, 6]

View File

@ -1,6 +1,6 @@
import torch import torch
import warnings
import pprint import pprint
import warnings
import numpy as np import numpy as np
from typing import Any, List, Union, Iterator, Optional from typing import Any, List, Union, Iterator, Optional
@ -218,8 +218,16 @@ class Batch:
v.to_torch(dtype, device) v.to_torch(dtype, device)
def append(self, batch: 'Batch') -> None: def append(self, batch: 'Batch') -> None:
"""Append a :class:`~tianshou.data.Batch` object to current batch.""" warnings.warn('Method append will be removed soon, please use '
assert isinstance(batch, Batch), 'Only append Batch is allowed!' ':meth:`~tianshou.data.Batch.cat`')
return self.cat(batch)
def cat(self, batch: 'Batch') -> None:
"""Concatenate a :class:`~tianshou.data.Batch` object to current
batch.
"""
assert isinstance(batch, Batch), \
'Only Batch is allowed to be concatenated!'
for k, v in batch.__dict__.items(): for k, v in batch.__dict__.items():
if k == '_meta': if k == '_meta':
self._meta.update(batch._meta) self._meta.update(batch._meta)
@ -235,9 +243,9 @@ class Batch:
elif isinstance(v, list): elif isinstance(v, list):
self.__dict__[k] += v self.__dict__[k] += v
elif isinstance(v, Batch): elif isinstance(v, Batch):
self.__dict__[k].append(v) self.__dict__[k].cat(v)
else: else:
s = f'No support for append with type \ s = f'No support for method "cat" with type \
{type(v)} in class Batch.' {type(v)} in class Batch.'
raise TypeError(s) raise TypeError(s)

View File

@ -413,7 +413,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
impt_weight=1 / np.power( impt_weight=1 / np.power(
self._size * (batch.weight / self._weight_sum), self._size * (batch.weight / self._weight_sum),
self._beta)) self._beta))
batch.append(impt_weight) batch.cat(impt_weight)
self._check_weight_sum() self._check_weight_sum()
return batch, indice return batch, indice

View File

@ -406,7 +406,7 @@ class Collector(object):
if batch_size and cur_batch or batch_size <= 0: if batch_size and cur_batch or batch_size <= 0:
batch, indice = b.sample(cur_batch) batch, indice = b.sample(cur_batch)
batch = self.process_fn(batch, b, indice) batch = self.process_fn(batch, b, indice)
batch_data.append(batch) batch_data.cat(batch)
else: else:
batch_data, indice = self.buffer.sample(batch_size) batch_data, indice = self.buffer.sample(batch_size)
batch_data = self.process_fn(batch_data, self.buffer, indice) batch_data = self.process_fn(batch_data, self.buffer, indice)