Advanced Batch slicing & minor fix of RNN support (#106)

* add shape property and modify __getitem__

* change Batch.size to Batch.shape

* setattr

* Batch.empty

* remove scalar in advanced slicing

* modify empty_ and __getitem__

* missing testcase

* fix empty
n+e authored 2020-06-30 18:02:44 +08:00, committed by GitHub
parent c639446c66
commit db0e2e5cd2
4 changed files with 183 additions and 101 deletions
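
As a quick orientation before the diffs: the headline features of this PR are multi-dimensional ("advanced") slicing and a shape property on Batch. A minimal sketch of the new usage, assuming tianshou at this commit (Batch, shape and the batch[:, i] syntax all appear in the diff below):

    import numpy as np
    from tianshou.data import Batch

    b = Batch(a=np.array([[1.0, 2.0]]), b={'c': np.zeros([3, 2, 1])})
    print(b.shape)          # [1, 2], the elementwise minimum over the value shapes
    b[:, -1] += 1           # the tuple index is forwarded to every stored array
    print(b.a)              # [[1. 3.]]
    print(b.b.c[:, -1, 0])  # [1. 1. 1.]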

File 1 of 4: contributors list

@@ -5,3 +5,4 @@ We always welcome contributions to help make Tianshou better. Below are an incom
 * Jiayi Weng (`Trinkle23897 <https://github.com/Trinkle23897>`_)
 * Minghao Zhang (`Mehooz <https://github.com/Mehooz>`_)
+* Alexis Duburcq (`duburcqa <https://github.com/duburcqa>`_)

File 2 of 4: Batch unit tests

@@ -13,9 +13,9 @@ def test_batch():
     batch.obs = [1]
     assert batch.obs == [1]
     batch.cat_(batch)
-    assert batch.obs == [1, 1]
+    assert np.allclose(batch.obs, [1, 1])
     assert batch.np.shape == (6, 4)
-    assert batch[0].obs == batch[1].obs
+    assert np.allclose(batch[0].obs, batch[1].obs)
     batch.obs = np.arange(5)
     for i, b in enumerate(batch.split(1, shuffle=False)):
         if i != 5:
@@ -39,14 +39,14 @@ def test_batch():
                     'c': np.zeros(1),
                     'd': Batch(e=np.array(3.0))}])
     assert len(batch2) == 1
-    assert Batch().size == 0
-    assert batch2.size == 1
+    assert Batch().shape == []
+    assert batch2.shape[0] == 1
     with pytest.raises(IndexError):
         batch2[-2]
     with pytest.raises(IndexError):
         batch2[1]
-    assert batch2[0].size == 1
-    with pytest.raises(TypeError):
+    assert batch2[0].shape == []
+    with pytest.raises(IndexError):
         batch2[0][0]
     with pytest.raises(TypeError):
         len(batch2[0])
@@ -87,24 +87,36 @@ def test_batch_over_batch():
     batch2.b.b[-1] = 0
     print(batch2)
     for k, v in batch2.items():
-        assert batch2[k] == v
+        assert np.all(batch2[k] == v)
     assert batch2[-1].b.b == 0
     batch2.cat_(Batch(c=[6, 7, 8], b=batch))
-    assert batch2.c == [6, 7, 8, 6, 7, 8]
-    assert batch2.b.a == [3, 4, 5, 3, 4, 5]
-    assert batch2.b.b == [4, 5, 0, 4, 5, 0]
+    assert np.allclose(batch2.c, [6, 7, 8, 6, 7, 8])
+    assert np.allclose(batch2.b.a, [3, 4, 5, 3, 4, 5])
+    assert np.allclose(batch2.b.b, [4, 5, 0, 4, 5, 0])
     d = {'a': [3, 4, 5], 'b': [4, 5, 6]}
     batch3 = Batch(c=[6, 7, 8], b=d)
     batch3.cat_(Batch(c=[6, 7, 8], b=d))
-    assert batch3.c == [6, 7, 8, 6, 7, 8]
-    assert batch3.b.a == [3, 4, 5, 3, 4, 5]
-    assert batch3.b.b == [4, 5, 6, 4, 5, 6]
+    assert np.allclose(batch3.c, [6, 7, 8, 6, 7, 8])
+    assert np.allclose(batch3.b.a, [3, 4, 5, 3, 4, 5])
+    assert np.allclose(batch3.b.b, [4, 5, 6, 4, 5, 6])
     batch4 = Batch(({'a': {'b': np.array([1.0])}},))
     assert batch4.a.b.ndim == 2
     assert batch4.a.b[0, 0] == 1.0
+    # advanced slicing
+    batch5 = Batch(a=[[1, 2]], b={'c': np.zeros([3, 2, 1])})
+    assert batch5.shape == [1, 2]
+    with pytest.raises(IndexError):
+        batch5[2]
+    with pytest.raises(IndexError):
+        batch5[:, 3]
+    with pytest.raises(IndexError):
+        batch5[:, :, -1]
+    batch5[:, -1] += 1
+    assert np.allclose(batch5.a, [1, 3])
+    assert np.allclose(batch5.b.c.squeeze(), [[0, 1]] * 3)
 
 
-def test_batch_cat_and_stack():
+def test_batch_cat_and_stack_and_empty():
     b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])
     b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}])
     b12_cat_out = Batch.cat((b1, b2))
@@ -133,6 +145,24 @@ def test_batch_cat_and_stack():
     assert np.all(b5.b.c == np.stack([e['b']['c'] for e in b5_dict], axis=0))
     assert b5.b.d[0] == b5_dict[0]['b']['d']
     assert b5.b.d[1] == 0.0
+    b5[1] = Batch.empty(b5[0])
+    assert np.allclose(b5.a, [False, False])
+    assert np.allclose(b5.b.c, [2, 0])
+    assert np.allclose(b5.b.d, [1, 0])
+    data = Batch(a=[False, True],
+                 b={'c': [2., 'st'], 'd': [1, None], 'e': [2., float('nan')]},
+                 c=np.array([1, 3, 4], dtype=np.int),
+                 t=torch.tensor([4, 5, 6, 7.]))
+    data[-1] = Batch.empty(data[1])
+    assert np.allclose(data.c, [1, 3, 0])
+    assert np.allclose(data.a, [False, False])
+    assert list(data.b.c) == ['2.0', '']
+    assert list(data.b.d) == [1, None]
+    assert np.allclose(data.b.e, [2, 0])
+    assert torch.allclose(data.t, torch.tensor([4, 5, 6, 0.]))
+    b0 = Batch()
+    b0.empty_()
+    assert b0.shape == []
 
 
 def test_batch_over_batch_to_torch():
@@ -215,3 +245,5 @@ if __name__ == '__main__':
     test_utils_to_torch()
     test_batch_pickle()
     test_batch_from_to_numpy_without_copy()
+    test_batch_numpy_compatibility()
+    test_batch_cat_and_stack_and_empty()
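
The new test cases above pin down what "emptying" a Batch means for each storage type. A standalone sketch of those semantics (assuming the same tianshou revision; the per-dtype behaviour is read off the empty_ implementation further down in this commit):

    import numpy as np
    import torch
    from tianshou.data import Batch

    b = Batch(a=np.array([1.0, float('nan')]),
              b=np.array([None, 'done'], dtype=object),
              c=torch.tensor([4.0, 5.0]))
    b.empty_()
    # b.a -> array([0., 0.])      float arrays are zeroed, NaN cleared via nan_to_num
    # b.b -> array([None, None])  object arrays are filled with None
    # b.c -> tensor([0., 0.])     torch tensors are replaced by zeros_like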

File 3 of 4: the Batch data structure

@@ -74,8 +74,9 @@ class Batch:
         >>> import numpy as np
         >>> from tianshou.data import Batch
         >>> data = Batch(a=4, b=[5, 5], c='2312312')
+        >>> # the list will automatically be converted to numpy array
         >>> data.b
-        [5, 5]
+        array([5, 5])
         >>> data.b = np.array([3, 4, 5])
         >>> print(data)
         Batch(
@@ -104,8 +105,6 @@ class Batch:
     together:
     ::
 
-        >>> import numpy as np
-        >>> from tianshou.data import Batch
         >>> data = Batch([{'a': {'b': [0.0, "info"]}}])
         >>> print(data[0])
         Batch(
@@ -119,7 +118,6 @@ class Batch:
     key, or iterate over stored data:
     ::
 
-        >>> from tianshou.data import Batch
         >>> data = Batch(a=4, b=[5, 5])
         >>> print(data["a"])
         4
@@ -130,28 +128,36 @@ class Batch:
     :class:`~tianshou.data.Batch` is also reproduce partially the Numpy API for
-    arrays. You can access or iterate over the individual samples, if any:
+    arrays. It also supports the advanced slicing method, such as batch[:, i],
+    if the index is valid. You can access or iterate over the individual
+    samples, if any:
     ::
 
-        >>> import numpy as np
-        >>> from tianshou.data import Batch
-        >>> data = Batch(a=np.array([[0.0, 2.0], [1.0, 3.0]]), b=[5, -5])
+        >>> data = Batch(a=np.array([[0.0, 2.0], [1.0, 3.0]]), b=[[5, -5]])
         >>> print(data[0])
         Batch(
-            a: np.array([0.0, 2.0])
-            b: 5
+            a: array([0., 2.])
+            b: array([ 5, -5]),
         )
         >>> for sample in data:
         >>> print(sample.a)
-        [0.0, 2.0]
-        [1.0, 3.0]
+        [0., 2.]
+        [1., 3.]
+        >>> print(data.shape)
+        [1, 2]
+        >>> data[:, 1] += 1
+        >>> print(data)
+        Batch(
+            a: array([[0., 3.],
+                      [1., 4.]]),
+            b: array([[ 5, -4]]),
+        )
 
     Similarly, one can also perform simple algebra on it, and stack, split or
     concatenate multiple instances:
     ::
 
-        >>> import numpy as np
-        >>> from tianshou.data import Batch
         >>> data_1 = Batch(a=np.array([0.0, 2.0]), b=5)
         >>> data_2 = Batch(a=np.array([1.0, 3.0]), b=-5)
         >>> data = Batch.stack((data_1, data_2))
@@ -169,11 +175,10 @@ class Batch:
         >>> data_split = list(data.split(1, False))
         >>> print(list(data.split(1, False)))
         [Batch(
-            b: [5],
+            b: array([5]),
             a: array([[0., 2.]]),
-        ),
-        Batch(
-            b: [-5],
+        ), Batch(
+            b: array([-5]),
             a: array([[1., 3.]]),
         )]
         >>> data_cat = Batch.cat(data_split)
@@ -188,8 +193,6 @@ class Batch:
     None is added in list or :class:`np.ndarray` of objects, 0 otherwise.
     ::
 
-        >>> import numpy as np
-        >>> from tianshou.data import Batch
         >>> data_1 = Batch(a=np.array([0.0, 2.0]))
         >>> data_2 = Batch(a=np.array([1.0, 3.0]), b='done')
         >>> data = Batch.stack((data_1, data_2))
@@ -200,23 +203,40 @@ class Batch:
             b: array([None, 'done'], dtype=object),
         )
 
-    :meth:`~tianshou.data.Batch.size` and :meth:`~tianshou.data.Batch.__len__`
-    methods are also provided to respectively get the length and the size of
-    a :class:`Batch` instance. It mimics the Numpy API for Numpy arrays, which
-    means that getting the length of a scalar Batch raises an exception, while
-    the size is 1. The size is only 0 if empty. Note that the size and length
-    are the identical if multiple samples are stored:
+    Also with method empty (which will set to 0 or ``None`` (with np.object))
+    ::
+
+        >>> data.empty_()
+        >>> print(data)
+        Batch(
+            a: array([[0., 0.],
+                      [0., 0.]]),
+            b: array([None, None], dtype=object),
+        )
+        >>> data = Batch(a=[False, True], b={'c': [2., 'st'], 'd': [1., 0.]})
+        >>> data[0] = Batch.empty(data[1])
+        >>> data
+        Batch(
+            a: array([False, True]),
+            b: Batch(
+                   c: array([0., 3.]),
+                   d: array([0., 0.]),
+               ),
+        )
+
+    :meth:`~tianshou.data.Batch.shape` and :meth:`~tianshou.data.Batch.__len__`
+    methods are also provided to respectively get the shape and the length of
+    a :class:`Batch` instance. It mimics the Numpy API for Numpy arrays, which
+    means that getting the length of a scalar Batch raises an exception.
     ::
 
-        >>> import numpy as np
-        >>> from tianshou.data import Batch
         >>> data = Batch(a=[5., 4.], b=np.zeros((2, 3, 4)))
-        >>> data.size
-        2
+        >>> data.shape
+        [2]
         >>> len(data)
         2
-        >>> data[0].size
-        1
+        >>> data[0].shape
+        []
         >>> len(data[0])
         TypeError: Object of type 'Batch' has no len()
@@ -240,13 +260,26 @@ class Batch:
                 if isinstance(v, dict) or _is_batch_set(v):
                     self.__dict__[k] = Batch(v)
                 else:
+                    if isinstance(v, list):
+                        v = np.array(v)
                     self.__dict__[k] = v
         if len(kwargs) > 0:
             self.__init__(kwargs)
 
+    def __setattr__(self, key: str, value: Any):
+        """self[key] = value"""
+        if isinstance(value, list):
+            if _is_batch_set(value):
+                value = Batch(value)
+            else:
+                value = np.array(value)
+        elif isinstance(value, dict):
+            value = Batch(value)
+        self.__dict__[key] = value
+
     def __getstate__(self):
-        """Pickling interface. Only the actual data are serialized
-        for both efficiency and simplicity.
+        """Pickling interface. Only the actual data are serialized for both
+        efficiency and simplicity.
         """
         state = {}
         for k, v in self.items():
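
The new __setattr__ above is also what forces the earlier test changes from list equality to np.allclose: any list or dict assigned as an attribute is converted on the way in. A short sketch, assuming the same revision:

    import numpy as np
    from tianshou.data import Batch

    b = Batch()
    b.obs = [1, 2, 3]                   # plain list -> np.ndarray
    assert isinstance(b.obs, np.ndarray)
    b.info = {'done': False}            # dict -> nested Batch
    assert isinstance(b.info, Batch)
    b.steps = [{'r': 1.0}, {'r': 2.0}]  # list of dicts -> Batch (via _is_batch_set)
    assert isinstance(b.steps, Batch)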
@@ -256,9 +289,9 @@ class Batch:
         return state
 
     def __setstate__(self, state):
-        """Unpickling interface. At this point, self is an empty Batch
-        instance that has not been initialized, so it can safely be
-        initialized by the pickle state.
+        """Unpickling interface. At this point, self is an empty Batch instance
+        that has not been initialized, so it can safely be initialized by the
+        pickle state.
         """
         self.__init__(**state)
@@ -267,26 +300,18 @@ class Batch:
         """Return self[index]."""
         if isinstance(index, str):
             return self.__dict__[index]
-        if not _valid_bounds(len(self), index):
-            raise IndexError(
-                f"Index {index} out of bounds for Batch of len {len(self)}.")
-        else:
-            b = Batch()
-            is_index_scalar = isinstance(index, (int, np.integer)) or \
-                (isinstance(index, np.ndarray) and index.ndim == 0)
-            for k, v in self.items():
-                if isinstance(v, Batch) and len(v.__dict__) == 0:
-                    b.__dict__[k] = Batch()
-                elif is_index_scalar or not isinstance(v, list):
-                    b.__dict__[k] = v[index]
-                else:
-                    b.__dict__[k] = [v[i] for i in index]
-            return b
+        b = Batch()
+        for k, v in self.items():
+            if isinstance(v, Batch) and len(v.__dict__) == 0:
+                b.__dict__[k] = Batch()
+            else:
+                b.__dict__[k] = v[index]
+        return b
 
-    def __setitem__(self, index: Union[
-            str, slice, int, np.integer, np.ndarray, List[int]],
-            value: Any) -> None:
+    def __setitem__(
+            self,
+            index: Union[str, slice, int, np.integer, np.ndarray, List[int]],
+            value: Any) -> None:
         """Assign value to self[index]."""
         if isinstance(index, str):
             self.__dict__[index] = value
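
The rewritten __getitem__ above no longer special-cases scalar indices or list values: whatever index it receives is forwarded verbatim to every stored value, with empty sub-Batches passed through as Batch(). That is the whole mechanism behind batch[:, i], since numpy arrays and torch tensors already understand tuple indices. A sketch, assuming the same revision:

    import numpy as np
    from tianshou.data import Batch

    b = Batch(obs=np.arange(12).reshape(3, 4), info=Batch())
    row = b[1]        # integer index -> obs[1], info stays an empty Batch
    assert np.allclose(row.obs, [4, 5, 6, 7])
    col = b[:, 2]     # tuple index -> obs[:, 2]
    assert np.allclose(col.obs, [2, 6, 10])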
@@ -319,8 +344,6 @@ class Batch:
                               other.__dict__.values()):
             if r is None:
                 continue
-            elif isinstance(r, list):
-                self.__dict__[k] = [r_ + v_ for r_, v_ in zip(r, v)]
             else:
                 self.__dict__[k] += v
         return self
@@ -328,8 +351,6 @@ class Batch:
         for k, r in self.items():
             if r is None:
                 continue
-            elif isinstance(r, list):
-                self.__dict__[k] = [r_ + other for r_ in r]
             else:
                 self.__dict__[k] += other
         return self
@@ -440,13 +461,14 @@ class Batch:
                 v.to_torch(dtype, device)
 
     def append(self, batch: 'Batch') -> None:
-        warnings.warn('Method append will be removed soon, please use '
+        warnings.warn('Method :meth:`~tianshou.data.Batch.append` will be '
+                      'removed soon, please use '
                       ':meth:`~tianshou.data.Batch.cat`')
         return self.cat_(batch)
 
     def cat_(self, batch: 'Batch') -> None:
-        """Concatenate a :class:`~tianshou.data.Batch` object into
-        current batch.
+        """Concatenate a :class:`~tianshou.data.Batch` object into current
+        batch.
         """
         assert isinstance(batch, Batch), \
             'Only Batch is allowed to be concatenated in-place!'
@@ -459,8 +481,6 @@ class Batch:
                 self.__dict__[k] = np.concatenate([self.__dict__[k], v])
             elif isinstance(v, torch.Tensor):
                 self.__dict__[k] = torch.cat([self.__dict__[k], v])
-            elif isinstance(v, list):
-                self.__dict__[k] += copy.deepcopy(v)
             elif isinstance(v, Batch):
                 self.__dict__[k].cat_(v)
             else:
@@ -468,12 +488,12 @@ class Batch:
                     f'{type(v)} in class Batch.'
                 raise TypeError(s)
 
-    @classmethod
-    def cat(cls, batches: List['Batch']) -> 'Batch':
-        """Concatenate a :class:`~tianshou.data.Batch` object into a
-        single new batch.
+    @staticmethod
+    def cat(batches: List['Batch']) -> 'Batch':
+        """Concatenate a :class:`~tianshou.data.Batch` object into a single
+        new batch.
         """
-        batch = cls()
+        batch = Batch()
         for batch_ in batches:
             batch.cat_(batch_)
         return batch
@@ -481,8 +501,7 @@ class Batch:
     def stack_(self,
                batches: List[Union[dict, 'Batch']],
                axis: int = 0) -> None:
-        """Stack a :class:`~tianshou.data.Batch` object i into current
-        batch.
+        """Stack a :class:`~tianshou.data.Batch` object i into current batch.
         """
         if len(self.__dict__) > 0:
             batches = [self] + list(batches)
@@ -511,13 +530,42 @@ class Batch:
     @staticmethod
     def stack(batches: List['Batch'], axis: int = 0) -> 'Batch':
-        """Stack a :class:`~tianshou.data.Batch` object into a
-        single new batch.
+        """Stack a :class:`~tianshou.data.Batch` object into a single new
+        batch.
         """
         batch = Batch()
         batch.stack_(batches, axis)
         return batch
 
+    def empty_(self) -> 'Batch':
+        """Return an empty a :class:`~tianshou.data.Batch` object with 0 or
+        ``None`` filled.
+        """
+        for k, v in self.items():
+            if v is None:
+                continue
+            if isinstance(v, Batch):
+                self.__dict__[k].empty_()
+            elif isinstance(v, np.ndarray) and v.dtype == np.object:
+                self.__dict__[k].fill(None)
+            elif isinstance(v, torch.Tensor):  # cannot apply fill_ directly
+                self.__dict__[k] = torch.zeros_like(self.__dict__[k])
+            else:  # np
+                self.__dict__[k] *= 0
+                if hasattr(v, 'dtype') and v.dtype.kind in 'fc':
+                    self.__dict__[k] = np.nan_to_num(self.__dict__[k])
+        return self
+
+    @staticmethod
+    def empty(batch: 'Batch') -> 'Batch':
+        """Return an empty :class:`~tianshou.data.Batch` object with 0 or
+        ``None`` filled, the shape is the same as the given
+        :class:`~tianshou.data.Batch`.
+        """
+        batch = Batch(**batch)
+        batch.empty_()
+        return batch
+
     def __len__(self) -> int:
         """Return len(self)."""
         r = []
@@ -534,21 +582,20 @@ class Batch:
         return min(r)
 
     @property
-    def size(self) -> int:
-        """Return self.size."""
+    def shape(self) -> List[int]:
+        """Return self.shape."""
         if len(self.__dict__.keys()) == 0:
-            return 0
+            return []
         else:
-            r = []
+            data_shape = []
             for v in self.__dict__.values():
-                if isinstance(v, Batch):
-                    r.append(v.size)
-                elif hasattr(v, '__len__') and (not isinstance(
-                        v, (np.ndarray, torch.Tensor)) or v.ndim > 0):
-                    r.append(len(v))
-                else:
-                    r.append(1)
-            return min(r) if len(r) > 0 else 0
+                try:
+                    data_shape.append(v.shape)
+                except AttributeError:
+                    raise TypeError("No support for 'shape' method with "
+                                    f"type {type(v)} in class Batch.")
+            return list(map(min, zip(*data_shape))) if len(data_shape) > 1 \
+                else data_shape[0]
 
     def split(self, size: Optional[int] = None,
               shuffle: bool = True) -> Iterator['Batch']:
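
The shape property above collects v.shape from every value and, when there is more than one key, reduces the shapes with an elementwise minimum (zip stops at the shortest shape), so Batch.shape is the largest index range valid for every stored array. A sketch of that reduction, assuming the same revision:

    import numpy as np
    from tianshou.data import Batch

    b = Batch(a=np.zeros((4, 5)), c=np.zeros((3, 6, 2)))
    print(b.shape)  # [3, 5]: min(4, 3) and min(5, 6); the trailing dim 2 is dropped
    print(len(b))   # 3, the shared leading dimension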

File 4 of 4: the Collector

@@ -200,8 +200,10 @@ class Collector(object):
             return
         if isinstance(self.state, list):
             self.state[id] = None
-        elif isinstance(self.state, (Batch, torch.Tensor, np.ndarray)):
+        elif isinstance(self.state, (torch.Tensor, np.ndarray)):
             self.state[id] *= 0
+        else:  # Batch
+            self.state[id].empty_()
 
     def collect(self,
                 n_step: int = 0,
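
The Collector change above is the "minor fix of RNN support" from the title: when a recurrent policy keeps its hidden state in a Batch, resetting the entry of a finished environment now calls empty_() on the indexed sub-Batch instead of treating it like a raw array. A sketch of the idea with a hypothetical per-environment hidden state (plain numpy arrays, so the integer index yields a view and the in-place reset is visible in the original):

    import numpy as np
    from tianshou.data import Batch

    state = Batch(h=np.ones((8, 128)), c=np.ones((8, 128)))  # e.g. one row per env
    done_env_id = 3
    state[done_env_id].empty_()  # zero only that environment's slice, in place
    assert np.allclose(state.h[done_env_id], 0)
    assert np.allclose(state.h[0], 1)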