Advanced Batch slicing & minor fix of RNN support (#106)
* add shape property and modify __getitem__
* change Batch.size to Batch.shape
* setattr
* Batch.empty
* remove scalar in advanced slicing
* modify empty_ and __getitem__
* missing testcase
* fix empty
This commit is contained in:
parent c639446c66
commit db0e2e5cd2
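Taken together, the changes below replace the `Batch.size` property with a list-valued `Batch.shape`, add `Batch.empty`/`Batch.empty_`, normalize values assigned through `__setattr__`, and teach `__getitem__` to handle advanced indices such as `batch[:, i]`. The following is a minimal sketch of the new surface, not an excerpt from the repository; the values mirror the tests in this diff:

    import numpy as np
    from tianshou.data import Batch

    # shape is now a list derived from the per-key shapes
    batch = Batch(a=[[1, 2]], b={'c': np.zeros([3, 2, 1])})
    assert batch.shape == [1, 2]

    # advanced slicing is forwarded to every stored array
    batch[:, -1] += 1
    assert np.allclose(batch.a, [1, 3])

    # Batch.empty builds a 0/None-filled Batch with the same structure
    blank_row = Batch.empty(batch[0])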
@@ -5,3 +5,4 @@ We always welcome contributions to help make Tianshou better. Below are an incom
* Jiayi Weng (`Trinkle23897 <https://github.com/Trinkle23897>`_)
* Minghao Zhang (`Mehooz <https://github.com/Mehooz>`_)
* Alexis Duburcq (`duburcqa <https://github.com/duburcqa>`_)
@@ -13,9 +13,9 @@ def test_batch():
batch.obs = [1]
assert batch.obs == [1]
batch.cat_(batch)
assert batch.obs == [1, 1]
assert np.allclose(batch.obs, [1, 1])
assert batch.np.shape == (6, 4)
assert batch[0].obs == batch[1].obs
assert np.allclose(batch[0].obs, batch[1].obs)
batch.obs = np.arange(5)
for i, b in enumerate(batch.split(1, shuffle=False)):
if i != 5:
@@ -39,14 +39,14 @@ def test_batch():
'c': np.zeros(1),
'd': Batch(e=np.array(3.0))}])
assert len(batch2) == 1
assert Batch().size == 0
assert batch2.size == 1
assert Batch().shape == []
assert batch2.shape[0] == 1
with pytest.raises(IndexError):
batch2[-2]
with pytest.raises(IndexError):
batch2[1]
assert batch2[0].size == 1
with pytest.raises(TypeError):
assert batch2[0].shape == []
with pytest.raises(IndexError):
batch2[0][0]
with pytest.raises(TypeError):
len(batch2[0])
@@ -87,24 +87,36 @@ def test_batch_over_batch():
batch2.b.b[-1] = 0
print(batch2)
for k, v in batch2.items():
assert batch2[k] == v
assert np.all(batch2[k] == v)
assert batch2[-1].b.b == 0
batch2.cat_(Batch(c=[6, 7, 8], b=batch))
assert batch2.c == [6, 7, 8, 6, 7, 8]
assert batch2.b.a == [3, 4, 5, 3, 4, 5]
assert batch2.b.b == [4, 5, 0, 4, 5, 0]
assert np.allclose(batch2.c, [6, 7, 8, 6, 7, 8])
assert np.allclose(batch2.b.a, [3, 4, 5, 3, 4, 5])
assert np.allclose(batch2.b.b, [4, 5, 0, 4, 5, 0])
d = {'a': [3, 4, 5], 'b': [4, 5, 6]}
batch3 = Batch(c=[6, 7, 8], b=d)
batch3.cat_(Batch(c=[6, 7, 8], b=d))
assert batch3.c == [6, 7, 8, 6, 7, 8]
assert batch3.b.a == [3, 4, 5, 3, 4, 5]
assert batch3.b.b == [4, 5, 6, 4, 5, 6]
assert np.allclose(batch3.c, [6, 7, 8, 6, 7, 8])
assert np.allclose(batch3.b.a, [3, 4, 5, 3, 4, 5])
assert np.allclose(batch3.b.b, [4, 5, 6, 4, 5, 6])
batch4 = Batch(({'a': {'b': np.array([1.0])}},))
assert batch4.a.b.ndim == 2
assert batch4.a.b[0, 0] == 1.0
# advanced slicing
batch5 = Batch(a=[[1, 2]], b={'c': np.zeros([3, 2, 1])})
assert batch5.shape == [1, 2]
with pytest.raises(IndexError):
batch5[2]
with pytest.raises(IndexError):
batch5[:, 3]
with pytest.raises(IndexError):
batch5[:, :, -1]
batch5[:, -1] += 1
assert np.allclose(batch5.a, [1, 3])
assert np.allclose(batch5.b.c.squeeze(), [[0, 1]] * 3)

def test_batch_cat_and_stack():
def test_batch_cat_and_stack_and_empty():
b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])
b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}])
b12_cat_out = Batch.cat((b1, b2))
@@ -133,6 +145,24 @@ def test_batch_cat_and_stack():
assert np.all(b5.b.c == np.stack([e['b']['c'] for e in b5_dict], axis=0))
assert b5.b.d[0] == b5_dict[0]['b']['d']
assert b5.b.d[1] == 0.0
b5[1] = Batch.empty(b5[0])
assert np.allclose(b5.a, [False, False])
assert np.allclose(b5.b.c, [2, 0])
assert np.allclose(b5.b.d, [1, 0])
data = Batch(a=[False, True],
b={'c': [2., 'st'], 'd': [1, None], 'e': [2., float('nan')]},
c=np.array([1, 3, 4], dtype=np.int),
t=torch.tensor([4, 5, 6, 7.]))
data[-1] = Batch.empty(data[1])
assert np.allclose(data.c, [1, 3, 0])
assert np.allclose(data.a, [False, False])
assert list(data.b.c) == ['2.0', '']
assert list(data.b.d) == [1, None]
assert np.allclose(data.b.e, [2, 0])
assert torch.allclose(data.t, torch.tensor([4, 5, 6, 0.]))
b0 = Batch()
b0.empty_()
assert b0.shape == []

def test_batch_over_batch_to_torch():

@@ -215,3 +245,5 @@ if __name__ == '__main__':
test_utils_to_torch()
test_batch_pickle()
test_batch_from_to_numpy_without_copy()
test_batch_numpy_compatibility()
test_batch_cat_and_stack_and_empty()
@@ -74,8 +74,9 @@ class Batch:
>>> import numpy as np
>>> from tianshou.data import Batch
>>> data = Batch(a=4, b=[5, 5], c='2312312')
>>> # the list will automatically be converted to numpy array
>>> data.b
[5, 5]
array([5, 5])
>>> data.b = np.array([3, 4, 5])
>>> print(data)
Batch(
@@ -104,8 +105,6 @@ class Batch:
together:
::

>>> import numpy as np
>>> from tianshou.data import Batch
>>> data = Batch([{'a': {'b': [0.0, "info"]}}])
>>> print(data[0])
Batch(
@@ -119,7 +118,6 @@ class Batch:
key, or iterate over stored data:
::

>>> from tianshou.data import Batch
>>> data = Batch(a=4, b=[5, 5])
>>> print(data["a"])
4
@@ -130,28 +128,36 @@ class Batch:

:class:`~tianshou.data.Batch` is also reproduce partially the Numpy API for
arrays. You can access or iterate over the individual samples, if any:
arrays. It also supports the advanced slicing method, such as batch[:, i],
if the index is valid. You can access or iterate over the individual
samples, if any:
::

>>> import numpy as np
>>> from tianshou.data import Batch
>>> data = Batch(a=np.array([[0.0, 2.0], [1.0, 3.0]]), b=[5, -5])
>>> data = Batch(a=np.array([[0.0, 2.0], [1.0, 3.0]]), b=[[5, -5]])
>>> print(data[0])
Batch(
a: np.array([0.0, 2.0])
b: 5
a: array([0., 2.])
b: array([ 5, -5]),
)
>>> for sample in data:
>>> print(sample.a)
[0.0, 2.0]
[1.0, 3.0]
[0., 2.]
[1., 3.]

>>> print(data.shape)
[1, 2]
>>> data[:, 1] += 1
>>> print(data)
Batch(
a: array([[0., 3.],
[1., 4.]]),
b: array([[ 5, -4]]),
)

Similarly, one can also perform simple algebra on it, and stack, split or
concatenate multiple instances:
::

>>> import numpy as np
>>> from tianshou.data import Batch
>>> data_1 = Batch(a=np.array([0.0, 2.0]), b=5)
>>> data_2 = Batch(a=np.array([1.0, 3.0]), b=-5)
>>> data = Batch.stack((data_1, data_2))
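As a hedged illustration of the paragraph above, the column index is applied to every leaf array, including arrays held by nested Batch objects; the key names below are made up for the example and assume this commit's `__getitem__`/`__setitem__`:

    import numpy as np
    from tianshou.data import Batch

    data = Batch(a=np.zeros((2, 3)), b={'c': np.ones((2, 3))})
    col = data[:, 1]    # a Batch holding one column of `a` and of `b.c`
    data[:, 1] += 1     # updates that column in place for every leaf array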
@@ -169,11 +175,10 @@ class Batch:
>>> data_split = list(data.split(1, False))
>>> print(list(data.split(1, False)))
[Batch(
b: [5],
b: array([5]),
a: array([[0., 2.]]),
),
Batch(
b: [-5],
), Batch(
b: array([-5]),
a: array([[1., 3.]]),
)]
>>> data_cat = Batch.cat(data_split)
@@ -188,8 +193,6 @@ class Batch:
None is added in list or :class:`np.ndarray` of objects, 0 otherwise.
::

>>> import numpy as np
>>> from tianshou.data import Batch
>>> data_1 = Batch(a=np.array([0.0, 2.0]))
>>> data_2 = Batch(a=np.array([1.0, 3.0]), b='done')
>>> data = Batch.stack((data_1, data_2))
@@ -200,23 +203,40 @@ class Batch:
b: array([None, 'done'], dtype=object),
)

:meth:`~tianshou.data.Batch.size` and :meth:`~tianshou.data.Batch.__len__`
methods are also provided to respectively get the length and the size of
a :class:`Batch` instance. It mimics the Numpy API for Numpy arrays, which
means that getting the length of a scalar Batch raises an exception, while
the size is 1. The size is only 0 if empty. Note that the size and length
are the identical if multiple samples are stored:
Also with method empty (which will set to 0 or ``None`` (with np.object))
::

>>> data.empty_()
>>> print(data)
Batch(
a: array([[0., 0.],
[0., 0.]]),
b: array([None, None], dtype=object),
)
>>> data = Batch(a=[False, True], b={'c': [2., 'st'], 'd': [1., 0.]})
>>> data[0] = Batch.empty(data[1])
>>> data
Batch(
a: array([False, True]),
b: Batch(
c: array([0., 3.]),
d: array([0., 0.]),
),
)

:meth:`~tianshou.data.Batch.shape` and :meth:`~tianshou.data.Batch.__len__`
methods are also provided to respectively get the shape and the length of
a :class:`Batch` instance. It mimics the Numpy API for Numpy arrays, which
means that getting the length of a scalar Batch raises an exception.
::

>>> import numpy as np
>>> from tianshou.data import Batch
>>> data = Batch(a=[5., 4.], b=np.zeros((2, 3, 4)))
>>> data.size
2
>>> data.shape
[2]
>>> len(data)
2
>>> data[0].size
1
>>> data[0].shape
[]
>>> len(data[0])
TypeError: Object of type 'Batch' has no len()
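To make the `empty`/`empty_` description above concrete, here is a small sketch (not from the repository) of the fill values used per type: numeric arrays and tensors become zeros, object arrays are filled with ``None``, and the static `Batch.empty` applies the same rule to a copy so that a single sample can be blanked out by assignment:

    import numpy as np
    import torch
    from tianshou.data import Batch

    data = Batch(a=np.array([1.0, 2.0]),
                 b=np.array([None, 'done'], dtype=object),
                 t=torch.tensor([4., 5.]))
    data.empty_()                      # a -> [0., 0.], b -> [None, None], t -> zeros
    row_filler = Batch.empty(data[0])  # same fill values, on a copied single sample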
@@ -240,13 +260,26 @@ class Batch:
if isinstance(v, dict) or _is_batch_set(v):
self.__dict__[k] = Batch(v)
else:
if isinstance(v, list):
v = np.array(v)
self.__dict__[k] = v
if len(kwargs) > 0:
self.__init__(kwargs)

def __setattr__(self, key: str, value: Any):
"""self[key] = value"""
if isinstance(value, list):
if _is_batch_set(value):
value = Batch(value)
else:
value = np.array(value)
elif isinstance(value, dict):
value = Batch(value)
self.__dict__[key] = value

def __getstate__(self):
"""Pickling interface. Only the actual data are serialized
for both efficiency and simplicity.
"""Pickling interface. Only the actual data are serialized for both
efficiency and simplicity.
"""
state = {}
for k, v in self.items():
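The new `__setattr__` above normalizes assigned values, which is what the `batch.obs = [1]` change in the test diff relies on. A short sketch of that behaviour, assuming this commit's implementation:

    import numpy as np
    from tianshou.data import Batch

    batch = Batch()
    batch.obs = [1, 2]              # a plain list is converted to a numpy array
    assert isinstance(batch.obs, np.ndarray)
    batch.info = {'done': [0, 1]}   # a dict becomes a nested Batch
    assert isinstance(batch.info, Batch)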
@@ -256,9 +289,9 @@ class Batch:
return state

def __setstate__(self, state):
"""Unpickling interface. At this point, self is an empty Batch
instance that has not been initialized, so it can safely be
initialized by the pickle state.
"""Unpickling interface. At this point, self is an empty Batch instance
that has not been initialized, so it can safely be initialized by the
pickle state.
"""
self.__init__(**state)
@@ -267,26 +300,18 @@ class Batch:
"""Return self[index]."""
if isinstance(index, str):
return self.__dict__[index]
b = Batch()
for k, v in self.items():
if isinstance(v, Batch) and len(v.__dict__) == 0:
b.__dict__[k] = Batch()
else:
b.__dict__[k] = v[index]
return b

if not _valid_bounds(len(self), index):
raise IndexError(
f"Index {index} out of bounds for Batch of len {len(self)}.")
else:
b = Batch()
is_index_scalar = isinstance(index, (int, np.integer)) or \
(isinstance(index, np.ndarray) and index.ndim == 0)
for k, v in self.items():
if isinstance(v, Batch) and len(v.__dict__) == 0:
b.__dict__[k] = Batch()
elif is_index_scalar or not isinstance(v, list):
b.__dict__[k] = v[index]
else:
b.__dict__[k] = [v[i] for i in index]
return b

def __setitem__(self, index: Union[
str, slice, int, np.integer, np.ndarray, List[int]],
value: Any) -> None:
def __setitem__(
self,
index: Union[str, slice, int, np.integer, np.ndarray, List[int]],
value: Any) -> None:
"""Assign value to self[index]."""
if isinstance(index, str):
self.__dict__[index] = value
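The reworked `__getitem__` above first validates the index with `_valid_bounds` and then distinguishes scalar indices from slice/array indices. A hedged sketch of the resulting behaviour; the keys are illustrative:

    import numpy as np
    from tianshou.data import Batch

    data = Batch(a=np.arange(3), b={'c': np.zeros((3, 2))})
    row = data[1]                  # scalar index: every value is indexed with v[1]
    sub = data[np.array([0, 2])]   # array index: a two-sample sub-batch
    try:
        data[3]                    # rejected up front by the _valid_bounds check
    except IndexError:
        pass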
@@ -319,8 +344,6 @@ class Batch:
other.__dict__.values()):
if r is None:
continue
elif isinstance(r, list):
self.__dict__[k] = [r_ + v_ for r_, v_ in zip(r, v)]
else:
self.__dict__[k] += v
return self

@@ -328,8 +351,6 @@ class Batch:
for k, r in self.items():
if r is None:
continue
elif isinstance(r, list):
self.__dict__[k] = [r_ + other for r_ in r]
else:
self.__dict__[k] += other
return self
@@ -440,13 +461,14 @@ class Batch:
v.to_torch(dtype, device)

def append(self, batch: 'Batch') -> None:
warnings.warn('Method append will be removed soon, please use '
warnings.warn('Method :meth:`~tianshou.data.Batch.append` will be '
'removed soon, please use '
':meth:`~tianshou.data.Batch.cat`')
return self.cat_(batch)

def cat_(self, batch: 'Batch') -> None:
"""Concatenate a :class:`~tianshou.data.Batch` object into
current batch.
"""Concatenate a :class:`~tianshou.data.Batch` object into current
batch.
"""
assert isinstance(batch, Batch), \
'Only Batch is allowed to be concatenated in-place!'

@@ -459,8 +481,6 @@ class Batch:
self.__dict__[k] = np.concatenate([self.__dict__[k], v])
elif isinstance(v, torch.Tensor):
self.__dict__[k] = torch.cat([self.__dict__[k], v])
elif isinstance(v, list):
self.__dict__[k] += copy.deepcopy(v)
elif isinstance(v, Batch):
self.__dict__[k].cat_(v)
else:
@@ -468,12 +488,12 @@ class Batch:
f'{type(v)} in class Batch.'
raise TypeError(s)

@classmethod
def cat(cls, batches: List['Batch']) -> 'Batch':
"""Concatenate a :class:`~tianshou.data.Batch` object into a
single new batch.
@staticmethod
def cat(batches: List['Batch']) -> 'Batch':
"""Concatenate a :class:`~tianshou.data.Batch` object into a single
new batch.
"""
batch = cls()
batch = Batch()
for batch_ in batches:
batch.cat_(batch_)
return batch

@@ -481,8 +501,7 @@ class Batch:
def stack_(self,
batches: List[Union[dict, 'Batch']],
axis: int = 0) -> None:
"""Stack a :class:`~tianshou.data.Batch` object i into current
batch.
"""Stack a :class:`~tianshou.data.Batch` object i into current batch.
"""
if len(self.__dict__) > 0:
batches = [self] + list(batches)
@@ -511,13 +530,42 @@ class Batch:

@staticmethod
def stack(batches: List['Batch'], axis: int = 0) -> 'Batch':
"""Stack a :class:`~tianshou.data.Batch` object into a
single new batch.
"""Stack a :class:`~tianshou.data.Batch` object into a single new
batch.
"""
batch = Batch()
batch.stack_(batches, axis)
return batch

def empty_(self) -> 'Batch':
"""Return an empty a :class:`~tianshou.data.Batch` object with 0 or
``None`` filled.
"""
for k, v in self.items():
if v is None:
continue
if isinstance(v, Batch):
self.__dict__[k].empty_()
elif isinstance(v, np.ndarray) and v.dtype == np.object:
self.__dict__[k].fill(None)
elif isinstance(v, torch.Tensor):  # cannot apply fill_ directly
self.__dict__[k] = torch.zeros_like(self.__dict__[k])
else:  # np
self.__dict__[k] *= 0
if hasattr(v, 'dtype') and v.dtype.kind in 'fc':
self.__dict__[k] = np.nan_to_num(self.__dict__[k])
return self

@staticmethod
def empty(batch: 'Batch') -> 'Batch':
"""Return an empty :class:`~tianshou.data.Batch` object with 0 or
``None`` filled, the shape is the same as the given
:class:`~tianshou.data.Batch`.
"""
batch = Batch(**batch)
batch.empty_()
return batch

def __len__(self) -> int:
"""Return len(self)."""
r = []
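One detail worth spelling out in `empty_` above: zeroing a float array with `*= 0` turns NaN and inf entries into NaN, which is why the float/complex branch follows up with `np.nan_to_num` (the `'e': [2., float('nan')]` case in the test diff exercises exactly this). A tiny illustration:

    import numpy as np

    v = np.array([2.0, float('nan'), float('inf')])
    v *= 0                 # nan * 0 -> nan, inf * 0 -> nan
    v = np.nan_to_num(v)   # -> array([0., 0., 0.])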
@@ -534,21 +582,20 @@ class Batch:
return min(r)

@property
def size(self) -> int:
"""Return self.size."""
def shape(self) -> List[int]:
"""Return self.shape."""
if len(self.__dict__.keys()) == 0:
return 0
return []
else:
r = []
data_shape = []
for v in self.__dict__.values():
if isinstance(v, Batch):
r.append(v.size)
elif hasattr(v, '__len__') and (not isinstance(
v, (np.ndarray, torch.Tensor)) or v.ndim > 0):
r.append(len(v))
else:
r.append(1)
return min(r) if len(r) > 0 else 0
try:
data_shape.append(v.shape)
except AttributeError:
raise TypeError("No support for 'shape' method with "
f"type {type(v)} in class Batch.")
return list(map(min, zip(*data_shape))) if len(data_shape) > 1 \
else data_shape[0]

def split(self, size: Optional[int] = None,
shuffle: bool = True) -> Iterator['Batch']:
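The new `shape` property above collects the shape of every stored value, recursing into nested Batch objects, and reduces them element-wise with `min`, so trailing dimensions survive only if every key has them. A sketch matching the `batch5` test earlier in this diff:

    import numpy as np
    from tianshou.data import Batch

    data = Batch(a=np.zeros((1, 2)), b={'c': np.zeros((3, 2, 1))})
    # per-key shapes are (1, 2) and (3, 2, 1); zip stops at the shorter one
    assert data.shape == [1, 2]
    assert Batch().shape == []   # an empty Batch reports an empty shape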
@@ -200,8 +200,10 @@ class Collector(object):
return
if isinstance(self.state, list):
self.state[id] = None
elif isinstance(self.state, (Batch, torch.Tensor, np.ndarray)):
elif isinstance(self.state, (torch.Tensor, np.ndarray)):
self.state[id] *= 0
else:  # Batch
self.state[id].empty_()

def collect(self,
n_step: int = 0,
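This Collector hunk is the "minor fix of RNN support" from the title: a recurrent hidden state stored as a Batch can no longer be cleared with `*= 0`, so the reset falls through to `empty_()`. A hedged sketch of the idea with a made-up hidden-state layout (the actual change only touches the branch shown above):

    import torch
    from tianshou.data import Batch

    # hypothetical per-environment recurrent state tracked by a collector
    state = Batch(hidden=torch.zeros(4, 1, 128))   # 4 environments
    env_id = 2
    state[env_id] = Batch.empty(state[env_id])     # zero out one environment's state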