change batch.append to batch.cat
This commit is contained in:
parent
aff0f9aee0
commit
a655334d00
@ -11,7 +11,7 @@ def test_batch():
|
||||
assert batch.obs == batch["obs"]
|
||||
batch.obs = [1]
|
||||
assert batch.obs == [1]
|
||||
batch.append(batch)
|
||||
batch.cat(batch)
|
||||
assert batch.obs == [1, 1]
|
||||
assert batch.np.shape == (6, 4)
|
||||
assert batch[0].obs == batch[1].obs
|
||||
@ -34,13 +34,13 @@ def test_batch_over_batch():
|
||||
print(batch2)
|
||||
assert batch2.values()[-1] == batch2.c
|
||||
assert batch2[-1].b.b == 0
|
||||
batch2.append(Batch(c=[6, 7, 8], b=batch))
|
||||
batch2.cat(Batch(c=[6, 7, 8], b=batch))
|
||||
assert batch2.c == [6, 7, 8, 6, 7, 8]
|
||||
assert batch2.b.a == [3, 4, 5, 3, 4, 5]
|
||||
assert batch2.b.b == [4, 5, 0, 4, 5, 0]
|
||||
d = {'a': [3, 4, 5], 'b': [4, 5, 6]}
|
||||
batch3 = Batch(c=[6, 7, 8], b=d)
|
||||
batch3.append(Batch(c=[6, 7, 8], b=d))
|
||||
batch3.cat(Batch(c=[6, 7, 8], b=d))
|
||||
assert batch3.c == [6, 7, 8, 6, 7, 8]
|
||||
assert batch3.b.a == [3, 4, 5, 3, 4, 5]
|
||||
assert batch3.b.b == [4, 5, 6, 4, 5, 6]
|
||||
|
@ -1,6 +1,6 @@
|
||||
import torch
|
||||
import warnings
|
||||
import pprint
|
||||
import warnings
|
||||
import numpy as np
|
||||
from typing import Any, List, Union, Iterator, Optional
|
||||
|
||||
@ -218,8 +218,16 @@ class Batch:
|
||||
v.to_torch(dtype, device)
|
||||
|
||||
def append(self, batch: 'Batch') -> None:
|
||||
"""Append a :class:`~tianshou.data.Batch` object to current batch."""
|
||||
assert isinstance(batch, Batch), 'Only append Batch is allowed!'
|
||||
warnings.warn('Method append will be removed soon, please use '
|
||||
':meth:`~tianshou.data.Batch.cat`')
|
||||
return self.cat(batch)
|
||||
|
||||
def cat(self, batch: 'Batch') -> None:
|
||||
"""Concatenate a :class:`~tianshou.data.Batch` object to current
|
||||
batch.
|
||||
"""
|
||||
assert isinstance(batch, Batch), \
|
||||
'Only Batch is allowed to be concatenated!'
|
||||
for k, v in batch.__dict__.items():
|
||||
if k == '_meta':
|
||||
self._meta.update(batch._meta)
|
||||
@ -235,9 +243,9 @@ class Batch:
|
||||
elif isinstance(v, list):
|
||||
self.__dict__[k] += v
|
||||
elif isinstance(v, Batch):
|
||||
self.__dict__[k].append(v)
|
||||
self.__dict__[k].cat(v)
|
||||
else:
|
||||
s = f'No support for append with type \
|
||||
s = f'No support for method "cat" with type \
|
||||
{type(v)} in class Batch.'
|
||||
raise TypeError(s)
|
||||
|
||||
|
@ -413,7 +413,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
impt_weight=1 / np.power(
|
||||
self._size * (batch.weight / self._weight_sum),
|
||||
self._beta))
|
||||
batch.append(impt_weight)
|
||||
batch.cat(impt_weight)
|
||||
self._check_weight_sum()
|
||||
return batch, indice
|
||||
|
||||
|
@ -406,7 +406,7 @@ class Collector(object):
|
||||
if batch_size and cur_batch or batch_size <= 0:
|
||||
batch, indice = b.sample(cur_batch)
|
||||
batch = self.process_fn(batch, b, indice)
|
||||
batch_data.append(batch)
|
||||
batch_data.cat(batch)
|
||||
else:
|
||||
batch_data, indice = self.buffer.sample(batch_size)
|
||||
batch_data = self.process_fn(batch_data, self.buffer, indice)
|
||||
|
Loading…
x
Reference in New Issue
Block a user