Advanced Batch slicing & minor fix of RNN support (#106)
* add shape property and modify __getitem__
* change Batch.size to Batch.shape
* setattr
* Batch.empty
* remove scalar in advanced slicing
* modify empty_ and __getitem__
* missing testcase
* fix empty
parent c639446c66
commit db0e2e5cd2
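At a glance, a minimal sketch of the API surface this commit introduces, pieced together from the diff and the new tests below (it asserts nothing beyond what those tests already check):

    import numpy as np
    from tianshou.data import Batch

    # Batch.shape replaces Batch.size and aggregates the per-key shapes
    b = Batch(a=[[1, 2]], b={'c': np.zeros([3, 2, 1])})
    assert b.shape == [1, 2]

    # advanced slicing: the index tuple is forwarded to every stored array
    b[:, -1] += 1
    assert np.allclose(b.a, [1, 3])

    # Batch.empty / empty_ fill a Batch with 0 (numeric) or None (object)
    data = Batch(a=[False, True], b={'c': [2., 'st'], 'd': [1., 0.]})
    data[0] = Batch.empty(data[1])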
@@ -5,3 +5,4 @@ We always welcome contributions to help make Tianshou better. Below are an incom

 * Jiayi Weng (`Trinkle23897 <https://github.com/Trinkle23897>`_)
 * Minghao Zhang (`Mehooz <https://github.com/Mehooz>`_)
+* Alexis Duburcq (`duburcqa <https://github.com/duburcqa>`_)
@@ -13,9 +13,9 @@ def test_batch():
     batch.obs = [1]
     assert batch.obs == [1]
     batch.cat_(batch)
-    assert batch.obs == [1, 1]
+    assert np.allclose(batch.obs, [1, 1])
     assert batch.np.shape == (6, 4)
-    assert batch[0].obs == batch[1].obs
+    assert np.allclose(batch[0].obs, batch[1].obs)
     batch.obs = np.arange(5)
     for i, b in enumerate(batch.split(1, shuffle=False)):
         if i != 5:
@@ -39,14 +39,14 @@ def test_batch():
                            'c': np.zeros(1),
                            'd': Batch(e=np.array(3.0))}])
     assert len(batch2) == 1
-    assert Batch().size == 0
-    assert batch2.size == 1
+    assert Batch().shape == []
+    assert batch2.shape[0] == 1
     with pytest.raises(IndexError):
         batch2[-2]
     with pytest.raises(IndexError):
         batch2[1]
-    assert batch2[0].size == 1
-    with pytest.raises(TypeError):
+    assert batch2[0].shape == []
+    with pytest.raises(IndexError):
         batch2[0][0]
     with pytest.raises(TypeError):
         len(batch2[0])
@@ -87,24 +87,36 @@ def test_batch_over_batch():
     batch2.b.b[-1] = 0
     print(batch2)
     for k, v in batch2.items():
-        assert batch2[k] == v
+        assert np.all(batch2[k] == v)
     assert batch2[-1].b.b == 0
     batch2.cat_(Batch(c=[6, 7, 8], b=batch))
-    assert batch2.c == [6, 7, 8, 6, 7, 8]
-    assert batch2.b.a == [3, 4, 5, 3, 4, 5]
-    assert batch2.b.b == [4, 5, 0, 4, 5, 0]
+    assert np.allclose(batch2.c, [6, 7, 8, 6, 7, 8])
+    assert np.allclose(batch2.b.a, [3, 4, 5, 3, 4, 5])
+    assert np.allclose(batch2.b.b, [4, 5, 0, 4, 5, 0])
     d = {'a': [3, 4, 5], 'b': [4, 5, 6]}
     batch3 = Batch(c=[6, 7, 8], b=d)
     batch3.cat_(Batch(c=[6, 7, 8], b=d))
-    assert batch3.c == [6, 7, 8, 6, 7, 8]
-    assert batch3.b.a == [3, 4, 5, 3, 4, 5]
-    assert batch3.b.b == [4, 5, 6, 4, 5, 6]
+    assert np.allclose(batch3.c, [6, 7, 8, 6, 7, 8])
+    assert np.allclose(batch3.b.a, [3, 4, 5, 3, 4, 5])
+    assert np.allclose(batch3.b.b, [4, 5, 6, 4, 5, 6])
     batch4 = Batch(({'a': {'b': np.array([1.0])}},))
     assert batch4.a.b.ndim == 2
     assert batch4.a.b[0, 0] == 1.0
+    # advanced slicing
+    batch5 = Batch(a=[[1, 2]], b={'c': np.zeros([3, 2, 1])})
+    assert batch5.shape == [1, 2]
+    with pytest.raises(IndexError):
+        batch5[2]
+    with pytest.raises(IndexError):
+        batch5[:, 3]
+    with pytest.raises(IndexError):
+        batch5[:, :, -1]
+    batch5[:, -1] += 1
+    assert np.allclose(batch5.a, [1, 3])
+    assert np.allclose(batch5.b.c.squeeze(), [[0, 1]] * 3)


-def test_batch_cat_and_stack():
+def test_batch_cat_and_stack_and_empty():
     b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])
     b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}])
     b12_cat_out = Batch.cat((b1, b2))
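The `batch5` assertions above pin down the advanced-slicing semantics: the index tuple produced by `batch5[:, -1]` is applied to every stored array, so an invalid component raises IndexError from numpy itself rather than from Batch. A plain-numpy sketch of the per-key operations the new `__getitem__`/`__setitem__` delegate to, using the same shapes as the test:

    import numpy as np

    a = np.array([[1, 2]])          # batch5.a
    c = np.zeros([3, 2, 1])         # batch5.b.c
    index = (slice(None), -1)       # what batch5[:, -1] forwards to each key
    a[index] += 1                   # -> [[1, 3]]
    c[index] += 1                   # -> last column of every row becomes 1
    assert np.allclose(a, [[1, 3]])
    assert np.allclose(c.squeeze(), [[0, 1]] * 3)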
@@ -133,6 +145,24 @@ def test_batch_cat_and_stack():
     assert np.all(b5.b.c == np.stack([e['b']['c'] for e in b5_dict], axis=0))
     assert b5.b.d[0] == b5_dict[0]['b']['d']
     assert b5.b.d[1] == 0.0
+    b5[1] = Batch.empty(b5[0])
+    assert np.allclose(b5.a, [False, False])
+    assert np.allclose(b5.b.c, [2, 0])
+    assert np.allclose(b5.b.d, [1, 0])
+    data = Batch(a=[False, True],
+                 b={'c': [2., 'st'], 'd': [1, None], 'e': [2., float('nan')]},
+                 c=np.array([1, 3, 4], dtype=np.int),
+                 t=torch.tensor([4, 5, 6, 7.]))
+    data[-1] = Batch.empty(data[1])
+    assert np.allclose(data.c, [1, 3, 0])
+    assert np.allclose(data.a, [False, False])
+    assert list(data.b.c) == ['2.0', '']
+    assert list(data.b.d) == [1, None]
+    assert np.allclose(data.b.e, [2, 0])
+    assert torch.allclose(data.t, torch.tensor([4, 5, 6, 0.]))
+    b0 = Batch()
+    b0.empty_()
+    assert b0.shape == []


 def test_batch_over_batch_to_torch():
@@ -215,3 +245,5 @@ if __name__ == '__main__':
     test_utils_to_torch()
     test_batch_pickle()
     test_batch_from_to_numpy_without_copy()
+    test_batch_numpy_compatibility()
+    test_batch_cat_and_stack_and_empty()
@@ -74,8 +74,9 @@ class Batch:
         >>> import numpy as np
         >>> from tianshou.data import Batch
         >>> data = Batch(a=4, b=[5, 5], c='2312312')
+        >>> # the list will automatically be converted to numpy array
         >>> data.b
-        [5, 5]
+        array([5, 5])
         >>> data.b = np.array([3, 4, 5])
         >>> print(data)
         Batch(
@@ -104,8 +105,6 @@ class Batch:
     together:
     ::

-        >>> import numpy as np
-        >>> from tianshou.data import Batch
         >>> data = Batch([{'a': {'b': [0.0, "info"]}}])
         >>> print(data[0])
         Batch(
@@ -119,7 +118,6 @@ class Batch:
     key, or iterate over stored data:
     ::

-        >>> from tianshou.data import Batch
         >>> data = Batch(a=4, b=[5, 5])
         >>> print(data["a"])
         4
@@ -130,28 +128,36 @@ class Batch:


     :class:`~tianshou.data.Batch` is also reproduce partially the Numpy API for
-    arrays. You can access or iterate over the individual samples, if any:
+    arrays. It also supports the advanced slicing method, such as batch[:, i],
+    if the index is valid. You can access or iterate over the individual
+    samples, if any:
     ::

-        >>> import numpy as np
-        >>> from tianshou.data import Batch
-        >>> data = Batch(a=np.array([[0.0, 2.0], [1.0, 3.0]]), b=[5, -5])
+        >>> data = Batch(a=np.array([[0.0, 2.0], [1.0, 3.0]]), b=[[5, -5]])
         >>> print(data[0])
         Batch(
-            a: np.array([0.0, 2.0])
-            b: 5
+            a: array([0., 2.])
+            b: array([ 5, -5]),
         )
         >>> for sample in data:
         >>>     print(sample.a)
-        [0.0, 2.0]
-        [1.0, 3.0]
+        [0., 2.]
+        [1., 3.]

+        >>> print(data.shape)
+        [1, 2]
+        >>> data[:, 1] += 1
+        >>> print(data)
+        Batch(
+            a: array([[0., 3.],
+                      [1., 4.]]),
+            b: array([[ 5, -4]]),
+        )

     Similarly, one can also perform simple algebra on it, and stack, split or
     concatenate multiple instances:
     ::

-        >>> import numpy as np
-        >>> from tianshou.data import Batch
         >>> data_1 = Batch(a=np.array([0.0, 2.0]), b=5)
         >>> data_2 = Batch(a=np.array([1.0, 3.0]), b=-5)
         >>> data = Batch.stack((data_1, data_2))
@@ -169,11 +175,10 @@ class Batch:
         >>> data_split = list(data.split(1, False))
         >>> print(list(data.split(1, False)))
         [Batch(
-            b: [5],
+            b: array([5]),
             a: array([[0., 2.]]),
-        ),
-        Batch(
-            b: [-5],
+        ), Batch(
+            b: array([-5]),
             a: array([[1., 3.]]),
         )]
         >>> data_cat = Batch.cat(data_split)
@@ -188,8 +193,6 @@ class Batch:
     None is added in list or :class:`np.ndarray` of objects, 0 otherwise.
     ::

-        >>> import numpy as np
-        >>> from tianshou.data import Batch
         >>> data_1 = Batch(a=np.array([0.0, 2.0]))
         >>> data_2 = Batch(a=np.array([1.0, 3.0]), b='done')
         >>> data = Batch.stack((data_1, data_2))
@@ -200,23 +203,40 @@ class Batch:
             b: array([None, 'done'], dtype=object),
         )

-    :meth:`~tianshou.data.Batch.size` and :meth:`~tianshou.data.Batch.__len__`
-    methods are also provided to respectively get the length and the size of
-    a :class:`Batch` instance. It mimics the Numpy API for Numpy arrays, which
-    means that getting the length of a scalar Batch raises an exception, while
-    the size is 1. The size is only 0 if empty. Note that the size and length
-    are the identical if multiple samples are stored:
+    Also with method empty (which will set to 0 or ``None`` (with np.object))
+    ::
+
+        >>> data.empty_()
+        >>> print(data)
+        Batch(
+            a: array([[0., 0.],
+                      [0., 0.]]),
+            b: array([None, None], dtype=object),
+        )
+        >>> data = Batch(a=[False, True], b={'c': [2., 'st'], 'd': [1., 0.]})
+        >>> data[0] = Batch.empty(data[1])
+        >>> data
+        Batch(
+            a: array([False, True]),
+            b: Batch(
+                   c: array([0., 3.]),
+                   d: array([0., 0.]),
+               ),
+        )
+
+    :meth:`~tianshou.data.Batch.shape` and :meth:`~tianshou.data.Batch.__len__`
+    methods are also provided to respectively get the shape and the length of
+    a :class:`Batch` instance. It mimics the Numpy API for Numpy arrays, which
+    means that getting the length of a scalar Batch raises an exception.
     ::

-        >>> import numpy as np
-        >>> from tianshou.data import Batch
         >>> data = Batch(a=[5., 4.], b=np.zeros((2, 3, 4)))
-        >>> data.size
-        2
+        >>> data.shape
+        [2]
         >>> len(data)
         2
-        >>> data[0].size
-        1
+        >>> data[0].shape
+        []
         >>> len(data[0])
         TypeError: Object of type 'Batch' has no len()
@@ -240,13 +260,26 @@ class Batch:
                 if isinstance(v, dict) or _is_batch_set(v):
                     self.__dict__[k] = Batch(v)
                 else:
+                    if isinstance(v, list):
+                        v = np.array(v)
                     self.__dict__[k] = v
         if len(kwargs) > 0:
             self.__init__(kwargs)

+    def __setattr__(self, key: str, value: Any):
+        """self[key] = value"""
+        if isinstance(value, list):
+            if _is_batch_set(value):
+                value = Batch(value)
+            else:
+                value = np.array(value)
+        elif isinstance(value, dict):
+            value = Batch(value)
+        self.__dict__[key] = value
+
     def __getstate__(self):
-        """Pickling interface. Only the actual data are serialized
-        for both efficiency and simplicity.
+        """Pickling interface. Only the actual data are serialized for both
+        efficiency and simplicity.
         """
         state = {}
         for k, v in self.items():
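The new `__setattr__` applies the same conversions as the constructor, so values assigned after construction are normalized as well. A minimal usage sketch (the attribute names here are illustrative, not part of the API):

    import numpy as np
    from tianshou.data import Batch

    b = Batch()
    b.obs = [1, 2, 3]               # plain list -> np.ndarray
    assert isinstance(b.obs, np.ndarray)
    b.info = {'done': [False]}      # dict -> nested Batch
    assert isinstance(b.info, Batch)
    b.rows = [{'x': 1}, {'x': 2}]   # list of dicts (_is_batch_set) -> Batch
    assert isinstance(b.rows, Batch)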
@@ -256,9 +289,9 @@ class Batch:
         return state

     def __setstate__(self, state):
-        """Unpickling interface. At this point, self is an empty Batch
-        instance that has not been initialized, so it can safely be
-        initialized by the pickle state.
+        """Unpickling interface. At this point, self is an empty Batch instance
+        that has not been initialized, so it can safely be initialized by the
+        pickle state.
         """
         self.__init__(**state)

@@ -267,26 +300,18 @@ class Batch:
         """Return self[index]."""
         if isinstance(index, str):
             return self.__dict__[index]
-        if not _valid_bounds(len(self), index):
-            raise IndexError(
-                f"Index {index} out of bounds for Batch of len {len(self)}.")
-        else:
-            b = Batch()
-            is_index_scalar = isinstance(index, (int, np.integer)) or \
-                (isinstance(index, np.ndarray) and index.ndim == 0)
-            for k, v in self.items():
-                if isinstance(v, Batch) and len(v.__dict__) == 0:
-                    b.__dict__[k] = Batch()
-                elif is_index_scalar or not isinstance(v, list):
-                    b.__dict__[k] = v[index]
-                else:
-                    b.__dict__[k] = [v[i] for i in index]
-            return b
+        b = Batch()
+        for k, v in self.items():
+            if isinstance(v, Batch) and len(v.__dict__) == 0:
+                b.__dict__[k] = Batch()
+            else:
+                b.__dict__[k] = v[index]
+        return b

-    def __setitem__(self, index: Union[
-            str, slice, int, np.integer, np.ndarray, List[int]],
-            value: Any) -> None:
+    def __setitem__(
+            self,
+            index: Union[str, slice, int, np.integer, np.ndarray, List[int]],
+            value: Any) -> None:
         """Assign value to self[index]."""
         if isinstance(index, str):
             self.__dict__[index] = value
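After this change `__getitem__` no longer range-checks the index itself; it forwards `index` to every stored value and lets numpy (or the underlying container) raise the IndexError the updated tests expect. A conservative sketch, reusing a construction taken verbatim from the tests:

    import numpy as np
    import pytest
    from tianshou.data import Batch

    b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])
    sample = b1[0]                  # the index is forwarded to every key
    with pytest.raises(IndexError):
        b1[1]                       # raised by the stored arrays, not by Batch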
@@ -319,8 +344,6 @@ class Batch:
                              other.__dict__.values()):
             if r is None:
                 continue
-            elif isinstance(r, list):
-                self.__dict__[k] = [r_ + v_ for r_, v_ in zip(r, v)]
             else:
                 self.__dict__[k] += v
         return self
@@ -328,8 +351,6 @@ class Batch:
         for k, r in self.items():
             if r is None:
                 continue
-            elif isinstance(r, list):
-                self.__dict__[k] = [r_ + other for r_ in r]
             else:
                 self.__dict__[k] += other
         return self
@@ -440,13 +461,14 @@ class Batch:
                 v.to_torch(dtype, device)

     def append(self, batch: 'Batch') -> None:
-        warnings.warn('Method append will be removed soon, please use '
+        warnings.warn('Method :meth:`~tianshou.data.Batch.append` will be '
+                      'removed soon, please use '
                       ':meth:`~tianshou.data.Batch.cat`')
         return self.cat_(batch)

     def cat_(self, batch: 'Batch') -> None:
-        """Concatenate a :class:`~tianshou.data.Batch` object into
-        current batch.
+        """Concatenate a :class:`~tianshou.data.Batch` object into current
+        batch.
         """
         assert isinstance(batch, Batch), \
             'Only Batch is allowed to be concatenated in-place!'
@@ -459,8 +481,6 @@ class Batch:
                 self.__dict__[k] = np.concatenate([self.__dict__[k], v])
             elif isinstance(v, torch.Tensor):
                 self.__dict__[k] = torch.cat([self.__dict__[k], v])
-            elif isinstance(v, list):
-                self.__dict__[k] += copy.deepcopy(v)
             elif isinstance(v, Batch):
                 self.__dict__[k].cat_(v)
             else:
@@ -468,12 +488,12 @@ class Batch:
                     f'{type(v)} in class Batch.'
                 raise TypeError(s)

-    @classmethod
-    def cat(cls, batches: List['Batch']) -> 'Batch':
-        """Concatenate a :class:`~tianshou.data.Batch` object into a
-        single new batch.
+    @staticmethod
+    def cat(batches: List['Batch']) -> 'Batch':
+        """Concatenate a :class:`~tianshou.data.Batch` object into a single
+        new batch.
         """
-        batch = cls()
+        batch = Batch()
         for batch_ in batches:
             batch.cat_(batch_)
         return batch
@@ -481,8 +501,7 @@ class Batch:
     def stack_(self,
                batches: List[Union[dict, 'Batch']],
                axis: int = 0) -> None:
-        """Stack a :class:`~tianshou.data.Batch` object i into current
-        batch.
+        """Stack a :class:`~tianshou.data.Batch` object i into current batch.
         """
         if len(self.__dict__) > 0:
             batches = [self] + list(batches)
@@ -511,13 +530,42 @@ class Batch:

     @staticmethod
     def stack(batches: List['Batch'], axis: int = 0) -> 'Batch':
-        """Stack a :class:`~tianshou.data.Batch` object into a
-        single new batch.
+        """Stack a :class:`~tianshou.data.Batch` object into a single new
+        batch.
         """
         batch = Batch()
         batch.stack_(batches, axis)
         return batch

+    def empty_(self) -> 'Batch':
+        """Return an empty a :class:`~tianshou.data.Batch` object with 0 or
+        ``None`` filled.
+        """
+        for k, v in self.items():
+            if v is None:
+                continue
+            if isinstance(v, Batch):
+                self.__dict__[k].empty_()
+            elif isinstance(v, np.ndarray) and v.dtype == np.object:
+                self.__dict__[k].fill(None)
+            elif isinstance(v, torch.Tensor):  # cannot apply fill_ directly
+                self.__dict__[k] = torch.zeros_like(self.__dict__[k])
+            else:  # np
+                self.__dict__[k] *= 0
+                if hasattr(v, 'dtype') and v.dtype.kind in 'fc':
+                    self.__dict__[k] = np.nan_to_num(self.__dict__[k])
+        return self
+
+    @staticmethod
+    def empty(batch: 'Batch') -> 'Batch':
+        """Return an empty :class:`~tianshou.data.Batch` object with 0 or
+        ``None`` filled, the shape is the same as the given
+        :class:`~tianshou.data.Batch`.
+        """
+        batch = Batch(**batch)
+        batch.empty_()
+        return batch
+
     def __len__(self) -> int:
         """Return len(self)."""
         r = []
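`empty_()` wipes a Batch in place (numeric entries become 0, object entries become ``None``, torch tensors become `zeros_like`), and the static `Batch.empty()` builds a new Batch with the same keys and fills it the same way. A usage sketch copied from the updated docstring and tests:

    from tianshou.data import Batch

    data = Batch(a=[False, True], b={'c': [2., 'st'], 'd': [1., 0.]})
    data[0] = Batch.empty(data[1])  # overwrite sample 0 with a 0/None-filled sample
    data.empty_()                   # in-place variant: wipe the whole batch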
@@ -534,21 +582,20 @@ class Batch:
         return min(r)

     @property
-    def size(self) -> int:
-        """Return self.size."""
+    def shape(self) -> List[int]:
+        """Return self.shape."""
         if len(self.__dict__.keys()) == 0:
-            return 0
+            return []
         else:
-            r = []
+            data_shape = []
             for v in self.__dict__.values():
-                if isinstance(v, Batch):
-                    r.append(v.size)
-                elif hasattr(v, '__len__') and (not isinstance(
-                        v, (np.ndarray, torch.Tensor)) or v.ndim > 0):
-                    r.append(len(v))
-                else:
-                    r.append(1)
-            return min(r) if len(r) > 0 else 0
+                try:
+                    data_shape.append(v.shape)
+                except AttributeError:
+                    raise TypeError("No support for 'shape' method with "
+                                    f"type {type(v)} in class Batch.")
+            return list(map(min, zip(*data_shape))) if len(data_shape) > 1 \
+                else data_shape[0]

     def split(self, size: Optional[int] = None,
               shuffle: bool = True) -> Iterator['Batch']:
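`shape` replaces `size`: every stored value must itself expose `.shape`, and with several keys the property returns the elementwise minimum over the per-key shapes (`zip` truncates to the shortest rank), i.e. the largest index range valid for every key. A sketch with the shapes used in the docstring and tests:

    import numpy as np
    from tianshou.data import Batch

    data = Batch(a=[5., 4.], b=np.zeros((2, 3, 4)))
    assert data.shape == [2]        # min over (2,) and (2, 3, 4), truncated to rank 1
    assert len(data) == 2

    b5 = Batch(a=[[1, 2]], b={'c': np.zeros([3, 2, 1])})
    assert b5.shape == [1, 2]       # elementwise min of (1, 2) and (3, 2, 1)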
@@ -200,8 +200,10 @@ class Collector(object):
             return
         if isinstance(self.state, list):
             self.state[id] = None
-        elif isinstance(self.state, (Batch, torch.Tensor, np.ndarray)):
+        elif isinstance(self.state, (torch.Tensor, np.ndarray)):
             self.state[id] *= 0
+        else:  # Batch
+            self.state[id].empty_()

     def collect(self,
                 n_step: int = 0,
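With a recurrent policy the Collector's per-environment hidden state can itself be a Batch, so resetting a finished environment's slot now goes through `empty_()` rather than `*= 0`. A rough sketch of the branch, with `state` and `id` standing in for `self.state` and the environment index (the `h`/`c` layout is hypothetical):

    import numpy as np
    import torch
    from tianshou.data import Batch

    state = Batch(h=np.zeros((4, 128)), c=np.zeros((4, 128)))  # hypothetical RNN state
    id = 2
    if isinstance(state, list):
        state[id] = None
    elif isinstance(state, (torch.Tensor, np.ndarray)):
        state[id] *= 0
    else:  # Batch states are cleared via empty_() rather than *= 0
        state[id].empty_()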