change batch.append to batch.cat
parent aff0f9aee0
commit a655334d00
@@ -11,7 +11,7 @@ def test_batch():
     assert batch.obs == batch["obs"]
     batch.obs = [1]
     assert batch.obs == [1]
-    batch.append(batch)
+    batch.cat(batch)
     assert batch.obs == [1, 1]
     assert batch.np.shape == (6, 4)
     assert batch[0].obs == batch[1].obs
@@ -34,13 +34,13 @@ def test_batch_over_batch():
     print(batch2)
     assert batch2.values()[-1] == batch2.c
     assert batch2[-1].b.b == 0
-    batch2.append(Batch(c=[6, 7, 8], b=batch))
+    batch2.cat(Batch(c=[6, 7, 8], b=batch))
     assert batch2.c == [6, 7, 8, 6, 7, 8]
     assert batch2.b.a == [3, 4, 5, 3, 4, 5]
     assert batch2.b.b == [4, 5, 0, 4, 5, 0]
     d = {'a': [3, 4, 5], 'b': [4, 5, 6]}
     batch3 = Batch(c=[6, 7, 8], b=d)
-    batch3.append(Batch(c=[6, 7, 8], b=d))
+    batch3.cat(Batch(c=[6, 7, 8], b=d))
     assert batch3.c == [6, 7, 8, 6, 7, 8]
     assert batch3.b.a == [3, 4, 5, 3, 4, 5]
     assert batch3.b.b == [4, 5, 6, 4, 5, 6]
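
The two test hunks above only swap the method name; the concatenation semantics are unchanged. A minimal usage sketch of the renamed call, assuming the in-place Batch.cat(other) signature introduced by this commit (later tianshou releases reworked this API, so treat it as illustrative):

from tianshou.data import Batch

b = Batch(obs=[1])
b.cat(Batch(obs=[1]))   # previously: b.append(Batch(obs=[1]))
assert b.obs == [1, 1]  # list fields are concatenated, as the test asserts
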
@@ -1,6 +1,6 @@
 import torch
-import warnings
 import pprint
+import warnings
 import numpy as np
 from typing import Any, List, Union, Iterator, Optional
 
@@ -218,8 +218,16 @@ class Batch:
                 v.to_torch(dtype, device)
 
     def append(self, batch: 'Batch') -> None:
-        """Append a :class:`~tianshou.data.Batch` object to current batch."""
-        assert isinstance(batch, Batch), 'Only append Batch is allowed!'
+        warnings.warn('Method append will be removed soon, please use '
+                      ':meth:`~tianshou.data.Batch.cat`')
+        return self.cat(batch)
+
+    def cat(self, batch: 'Batch') -> None:
+        """Concatenate a :class:`~tianshou.data.Batch` object to current
+        batch.
+        """
+        assert isinstance(batch, Batch), \
+            'Only Batch is allowed to be concatenated!'
         for k, v in batch.__dict__.items():
             if k == '_meta':
                 self._meta.update(batch._meta)
@@ -235,9 +243,9 @@ class Batch:
             elif isinstance(v, list):
                 self.__dict__[k] += v
             elif isinstance(v, Batch):
-                self.__dict__[k].append(v)
+                self.__dict__[k].cat(v)
             else:
-                s = f'No support for append with type \
+                s = f'No support for method "cat" with type \
                     {type(v)} in class Batch.'
                 raise TypeError(s)
 
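
As the two hunks above show, append is kept only as a thin deprecation shim: it warns and forwards to cat, which carries the old per-field concatenation logic (list, nested Batch, and the other branches) under the new name. A general-purpose sketch of that shim pattern, with hypothetical class and method names and an explicit DeprecationWarning category (the commit's own warn call passes only a message):

import warnings


class Container:
    """Toy stand-in for a class that renames a public method."""

    def cat(self, other: 'Container') -> None:
        # new implementation lives under the new name
        ...

    def append(self, other: 'Container') -> None:
        # deprecated alias: warn, then delegate to the new name
        warnings.warn('append is deprecated, use cat instead',
                      DeprecationWarning, stacklevel=2)
        return self.cat(other)
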
@@ -413,7 +413,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
             impt_weight=1 / np.power(
                 self._size * (batch.weight / self._weight_sum),
                 self._beta))
-        batch.append(impt_weight)
+        batch.cat(impt_weight)
         self._check_weight_sum()
         return batch, indice
 
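
The field concatenated above is the usual prioritized-replay importance-sampling correction, w_i = (N * P(i))^(-beta) with P(i) = p_i / sum_j p_j, which is what the 1 / np.power(...) expression computes. A standalone NumPy sketch of the same formula (hypothetical function name, independent of the buffer class):

import numpy as np


def importance_weights(priorities: np.ndarray, beta: float) -> np.ndarray:
    # w_i = (N * P(i)) ** (-beta), where P(i) = p_i / sum_j p_j
    probs = priorities / priorities.sum()
    return 1.0 / np.power(len(priorities) * probs, beta)
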
@@ -406,7 +406,7 @@ class Collector(object):
                 if batch_size and cur_batch or batch_size <= 0:
                     batch, indice = b.sample(cur_batch)
                     batch = self.process_fn(batch, b, indice)
-                    batch_data.append(batch)
+                    batch_data.cat(batch)
         else:
             batch_data, indice = self.buffer.sample(batch_size)
             batch_data = self.process_fn(batch_data, self.buffer, indice)
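
The Collector call site follows the same accumulate-by-concatenation pattern: per-buffer samples are merged into a single Batch inside the loop. A minimal sketch of that pattern, assuming (as the surrounding Collector code does) that concatenating into a freshly created Batch simply adopts keys it does not yet hold:

from tianshou.data import Batch

parts = [Batch(obs=[1, 2]), Batch(obs=[3])]  # hypothetical per-buffer samples
merged = Batch()
for part in parts:
    merged.cat(part)  # in-place concatenation, as in the loop above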