Enable partial stacking at Batch level (#100)

* Enable stacking of partially matching Batch instances.

* Fix list support for getitem.

* Fix Batch 'size' method.

* Update Batch documentation.
This commit is contained in:
Alexis DUBURCQ 2020-06-27 03:06:40 +02:00 committed by GitHub
parent 70aa7bf93e
commit a951a32487
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 217 additions and 89 deletions

View File

@ -126,6 +126,13 @@ def test_batch_cat_and_stack():
b34_stack = Batch.stack((b3, b4), axis=1) b34_stack = Batch.stack((b3, b4), axis=1)
assert np.all(b34_stack.a == np.stack((b3.a, b4.a), axis=1)) assert np.all(b34_stack.a == np.stack((b3.a, b4.a), axis=1))
assert np.all(b34_stack.c.d == list(map(list, zip(b3.c.d, b4.c.d)))) assert np.all(b34_stack.c.d == list(map(list, zip(b3.c.d, b4.c.d))))
b5_dict = np.array([{'a': False, 'b': {'c': 2.0, 'd': 1.0}},
{'a': True, 'b': {'c': 3.0}}])
b5 = Batch(b5_dict)
assert b5.a[0] == np.array(False) and b5.a[1] == np.array(True)
assert np.all(b5.b.c == np.stack([e['b']['c'] for e in b5_dict], axis=0))
assert b5.b.d[0] == b5_dict[0]['b']['d']
assert b5.b.d[1] == 0.0
def test_batch_over_batch_to_torch(): def test_batch_over_batch_to_torch():

View File

@ -36,7 +36,7 @@ def test_replaybuffer(size=10, bufsize=20):
assert b.info.a[0] == 3 and b.info.a.dtype == np.integer assert b.info.a[0] == 3 and b.info.a.dtype == np.integer
assert np.all(b.info.a[1:] == 0) assert np.all(b.info.a[1:] == 0)
assert b.info.b.c[0] == 5.0 and b.info.b.c.dtype == np.inexact assert b.info.b.c[0] == 5.0 and b.info.b.c.dtype == np.inexact
assert np.all(np.isnan(b.info.b.c[1:])) assert np.all(b.info.b.c[1:] == 0.0)
def test_ignore_obs_next(size=10): def test_ignore_obs_next(size=10):

View File

@ -3,6 +3,7 @@ import copy
import pprint import pprint
import warnings import warnings
import numpy as np import numpy as np
from functools import reduce
from numbers import Number from numbers import Number
from typing import Any, List, Tuple, Union, Iterator, Optional from typing import Any, List, Tuple, Union, Iterator, Optional
@ -42,6 +43,27 @@ def _valid_bounds(length: int, index: Union[
return start_valid and stop_valid return start_valid and stop_valid
def _create_value(inst: Any, size: int) -> Union['Batch', np.ndarray]:
if isinstance(inst, np.ndarray):
return np.full((size, *inst.shape),
fill_value=None if inst.dtype == np.object else 0,
dtype=inst.dtype)
elif isinstance(inst, torch.Tensor):
return torch.full((size, *inst.shape),
fill_value=None if inst.dtype == np.object else 0,
device=inst.device,
dtype=inst.dtype)
elif isinstance(inst, (dict, Batch)):
zero_batch = Batch()
for key, val in inst.items():
zero_batch.__dict__[key] = _create_value(val, size)
return zero_batch
elif isinstance(inst, (np.generic, Number)):
return _create_value(np.asarray(inst), size)
else: # fall back to np.object
return np.array([None for _ in range(size)])
class Batch: class Batch:
"""Tianshou provides :class:`~tianshou.data.Batch` as the internal data """Tianshou provides :class:`~tianshou.data.Batch` as the internal data
structure to pass any kind of data to other methods, for example, a structure to pass any kind of data to other methods, for example, a
@ -75,33 +97,133 @@ class Batch:
function return 4 arguments, and the last one is ``info``); function return 4 arguments, and the last one is ``info``);
* ``policy`` the data computed by policy in step :math:`t`; * ``policy`` the data computed by policy in step :math:`t`;
:class:`~tianshou.data.Batch` has other methods, including :class:`Batch` object can be initialized using wide variety of arguments,
:meth:`~tianshou.data.Batch.__getitem__`, starting with the key/value pairs or dictionary, but also list and Numpy
:meth:`~tianshou.data.Batch.__len__`, arrays of :class:`dict` or Batch instances. In which case, each element
:meth:`~tianshou.data.Batch.append`, is considered as an individual sample and get stacked together:
and :meth:`~tianshou.data.Batch.split`:
:: ::
>>> data = Batch(obs=np.array([0, 11, 22]), rew=np.array([6, 6, 6])) >>> import numpy as np
>>> # here we test __getitem__ >>> from tianshou.data import Batch
>>> index = [2, 1] >>> data = Batch([{'a': {'b': [0.0, "info"]}}])
>>> data[index].obs >>> print(data[0])
array([22, 11]) Batch(
a: Batch(
b: array(['0.0', 'info'], dtype='<U32'),
),
)
>>> # here we test __len__ :class:`Batch` has the same API as a native Python :class:`dict`. In this
regard, one can access to stored data using string key, or iterate over
stored data:
::
>>> from tianshou.data import Batch
>>> data = Batch(a=4, b=[5, 5])
>>> print(data["a"])
4
>>> for key, value in data.items():
>>> print(f"{key}: {value}")
a: 4
b: [5, 5]
:class:`Batch` also partially reproduces the Numpy API for arrays. You
can access or iterate over the individual samples, if any:
::
>>> import numpy as np
>>> from tianshou.data import Batch
>>> data = Batch(a=np.array([[0.0, 2.0], [1.0, 3.0]]), b=[5, -5])
>>> print(data[0])
Batch(
a: np.array([0.0, 2.0])
b: 5
)
>>> for sample in data:
>>> print(sample.a)
[0.0, 2.0]
[1.0, 3.0]
Similarly, one can also perform simple algebra on it, and stack, split or
concatenate multiple instances:
::
>>> import numpy as np
>>> from tianshou.data import Batch
>>> data_1 = Batch(a=np.array([0.0, 2.0]), b=5)
>>> data_2 = Batch(a=np.array([1.0, 3.0]), b=-5)
>>> data = Batch.stack((data_1, data_2))
>>> print(data)
Batch(
b: array([ 5, -5]),
a: array([[0., 2.],
[1., 3.]]),
)
>>> print(np.mean(data))
Batch(
b: 0.0,
a: array([0.5, 2.5]),
)
>>> data_split = list(data.split(1, False))
>>> print(list(data.split(1, False)))
[Batch(
b: [5],
a: array([[0., 2.]]),
),
Batch(
b: [-5],
a: array([[1., 3.]]),
)]
>>> data_cat = Batch.cat(data_split)
>>> print(data_cat)
Batch(
b: array([ 5, -5]),
a: array([[0., 2.],
[1., 3.]]),
)
Note that stacking of inconsistent data is also supported. In that case,
None is added in list or :class:`np.ndarray` of objects, 0 otherwise.
::
>>> import numpy as np
>>> from tianshou.data import Batch
>>> data_1 = Batch(a=np.array([0.0, 2.0]))
>>> data_2 = Batch(a=np.array([1.0, 3.0]), b='done')
>>> data = Batch.stack((data_1, data_2))
>>> print(data)
Batch(
a: array([[0., 2.],
[1., 3.]]),
b: array([None, 'done'], dtype=object),
)
:meth:`~tianshou.data.Batch.size` and :meth:`~tianshou.data.Batch.__len__`
methods are also provided to respectively get the size and the length of
a :class:`Batch` instance. It mimics the Numpy API for Numpy arrays, which
means that getting the length of a scalar Batch raises an exception, while
the size is 1. The size is only 0 if empty. Note that the size and length
are identical if multiple samples are stored:
::
>>> import numpy as np
>>> from tianshou.data import Batch
>>> data = Batch(a=[5., 4.], b=np.zeros((2, 3, 4)))
>>> data.size
2
>>> len(data) >>> len(data)
3 2
>>> data[0].size
1
>>> len(data[0])
TypeError: Object of type 'Batch' has no len()
>>> data.append(data) # similar to list.append Convenience helpers are available to convert in-place the
>>> data.obs stored data into Numpy arrays or Torch tensors.
array([0, 11, 22, 0, 11, 22])
>>> # split whole data into multiple small batch Finally, note that Batch instance are serializable and therefore Pickle
>>> for d in data.split(size=2, shuffle=False): compatible. This is especially important for distributed sampling.
... print(d.obs, d.rew)
[ 0 11] [6 6]
[22 0] [6 6]
[11 22] [6 6]
""" """
def __init__(self, def __init__(self,
@ -110,18 +232,7 @@ class Batch:
List[Union[dict, 'Batch']], np.ndarray]] = None, List[Union[dict, 'Batch']], np.ndarray]] = None,
**kwargs) -> None: **kwargs) -> None:
if _is_batch_set(batch_dict): if _is_batch_set(batch_dict):
for k, v in zip(batch_dict[0].keys(), self.stack_(batch_dict)
zip(*[e.values() for e in batch_dict])):
if isinstance(v[0], dict) or _is_batch_set(v[0]):
self.__dict__[k] = Batch(v)
elif isinstance(v[0], (np.generic, np.ndarray)):
self.__dict__[k] = np.stack(v, axis=0)
elif isinstance(v[0], torch.Tensor):
self.__dict__[k] = torch.stack(v, dim=0)
elif isinstance(v[0], Batch):
self.__dict__[k] = Batch.stack(v)
else:
self.__dict__[k] = np.array(v)
elif isinstance(batch_dict, (dict, Batch)): elif isinstance(batch_dict, (dict, Batch)):
for k, v in batch_dict.items(): for k, v in batch_dict.items():
if isinstance(v, dict) or _is_batch_set(v): if isinstance(v, dict) or _is_batch_set(v):
@ -160,16 +271,21 @@ class Batch:
f"Index {index} out of bounds for Batch of len {len(self)}.") f"Index {index} out of bounds for Batch of len {len(self)}.")
else: else:
b = Batch() b = Batch()
is_index_scalar = isinstance(index, (int, np.integer)) or \
(isinstance(index, np.ndarray) and index.ndim == 0)
for k, v in self.items(): for k, v in self.items():
if isinstance(v, Batch) and len(v.__dict__) == 0: if isinstance(v, Batch) and len(v.__dict__) == 0:
b.__dict__[k] = Batch() b.__dict__[k] = Batch()
else: elif is_index_scalar or not isinstance(v, list):
b.__dict__[k] = v[index] b.__dict__[k] = v[index]
else:
b.__dict__[k] = [v[i] for i in index]
return b return b
def __setitem__(self, index: Union[ def __setitem__(self, index: Union[
str, slice, int, np.integer, np.ndarray, List[int]], str, slice, int, np.integer, np.ndarray, List[int]],
value: Any) -> None: value: Any) -> None:
"""Assign value to self[index]."""
if isinstance(index, str): if isinstance(index, str):
self.__dict__[index] = value self.__dict__[index] = value
return return
@ -193,10 +309,12 @@ class Batch:
else: else:
self.__dict__[key][index] = None self.__dict__[key][index] = None
def __iadd__(self, val: Union['Batch', Number]): def __iadd__(self, other: Union['Batch', Number]):
if isinstance(val, Batch): """Algebraic addition with another :class:`~tianshou.data.Batch`
instance in-place."""
if isinstance(other, Batch):
for (k, r), v in zip(self.__dict__.items(), for (k, r), v in zip(self.__dict__.items(),
val.__dict__.values()): other.__dict__.values()):
if r is None: if r is None:
continue continue
elif isinstance(r, list): elif isinstance(r, list):
@ -204,22 +322,25 @@ class Batch:
else: else:
self.__dict__[k] += v self.__dict__[k] += v
return self return self
elif isinstance(val, Number): elif isinstance(other, Number):
for k, r in self.items(): for k, r in self.items():
if r is None: if r is None:
continue continue
elif isinstance(r, list): elif isinstance(r, list):
self.__dict__[k] = [r_ + val for r_ in r] self.__dict__[k] = [r_ + other for r_ in r]
else: else:
self.__dict__[k] += val self.__dict__[k] += other
return self return self
else: else:
raise TypeError("Only addition of Batch or number is supported.") raise TypeError("Only addition of Batch or number is supported.")
def __add__(self, val: Union['Batch', Number]): def __add__(self, other: Union['Batch', Number]):
return copy.deepcopy(self).__iadd__(val) """Algebraic addition with another :class:`~tianshou.data.Batch`
instance out-of-place."""
return copy.deepcopy(self).__iadd__(other)
def __imul__(self, val: Number): def __imul__(self, val: Number):
"""Algebraic multiplication with a scalar value in-place."""
assert isinstance(val, Number), \ assert isinstance(val, Number), \
"Only multiplication by a number is supported." "Only multiplication by a number is supported."
for k in self.__dict__.keys(): for k in self.__dict__.keys():
@ -227,9 +348,11 @@ class Batch:
return self return self
def __mul__(self, val: Number): def __mul__(self, val: Number):
"""Algebraic multiplication with a scalar value out-of-place."""
return copy.deepcopy(self).__imul__(val) return copy.deepcopy(self).__imul__(val)
def __itruediv__(self, val: Number): def __itruediv__(self, val: Number):
"""Algebraic division wibyth a scalar value in-place."""
assert isinstance(val, Number), \ assert isinstance(val, Number), \
"Only division by a number is supported." "Only division by a number is supported."
for k in self.__dict__.keys(): for k in self.__dict__.keys():
@ -237,6 +360,7 @@ class Batch:
return self return self
def __truediv__(self, val: Number): def __truediv__(self, val: Number):
"""Algebraic division wibyth a scalar value out-of-place."""
return copy.deepcopy(self).__itruediv__(val) return copy.deepcopy(self).__itruediv__(val)
def __repr__(self) -> str: def __repr__(self) -> str:
@ -319,8 +443,8 @@ class Batch:
return self.cat_(batch) return self.cat_(batch)
def cat_(self, batch: 'Batch') -> None: def cat_(self, batch: 'Batch') -> None:
"""Concatenate a :class:`~tianshou.data.Batch` object to current """Concatenate a :class:`~tianshou.data.Batch` object into
batch. current batch.
""" """
assert isinstance(batch, Batch), \ assert isinstance(batch, Batch), \
'Only Batch is allowed to be concatenated in-place!' 'Only Batch is allowed to be concatenated in-place!'
@ -347,39 +471,50 @@ class Batch:
"""Concatenate a :class:`~tianshou.data.Batch` object into a """Concatenate a :class:`~tianshou.data.Batch` object into a
single new batch. single new batch.
""" """
assert isinstance(batches, (tuple, list)), \
'Only list of Batch instances is allowed to be '\
'concatenated out-of-place!'
batch = cls() batch = cls()
for batch_ in batches: for batch_ in batches:
batch.cat_(batch_) batch.cat_(batch_)
return batch return batch
@classmethod def stack_(self,
def stack(cls, batches: List['Batch'], axis: int = 0) -> 'Batch': batches: List[Union[dict, 'Batch']],
axis: int = 0) -> None:
"""Stack a :class:`~tianshou.data.Batch` object i into current
batch.
"""
if len(self.__dict__) > 0:
batches = [self] + list(batches)
keys_map = list(map(lambda e: set(e.keys()), batches))
keys_shared = set.intersection(*keys_map)
values_shared = [
[e[k] for e in batches] for k in keys_shared]
for k, v in zip(keys_shared, values_shared):
if isinstance(v[0], (dict, Batch)):
self.__dict__[k] = Batch.stack(v, axis)
elif isinstance(v[0], torch.Tensor):
self.__dict__[k] = torch.stack(v, axis)
else:
self.__dict__[k] = np.stack(v, axis)
keys_partial = reduce(set.symmetric_difference, keys_map)
for k in keys_partial:
for i, e in enumerate(batches):
val = e.get(k, None)
if val is not None:
try:
self.__dict__[k][i] = val
except KeyError:
self.__dict__[k] = \
_create_value(val, len(batches))
self.__dict__[k][i] = val
@staticmethod
def stack(batches: List['Batch'], axis: int = 0) -> 'Batch':
"""Stack a :class:`~tianshou.data.Batch` object into a """Stack a :class:`~tianshou.data.Batch` object into a
single new batch. single new batch.
""" """
assert isinstance(batches, (tuple, list)), \ batch = Batch()
'Only list of Batch instances is allowed to be '\ batch.stack_(batches, axis)
'stacked out-of-place!' return batch
if axis == 0:
return cls(batches)
else:
batch = Batch()
for k, v in zip(batches[0].keys(),
zip(*[e.values() for e in batches])):
if isinstance(v[0], (np.generic, np.ndarray, list)):
batch.__dict__[k] = np.stack(v, axis)
elif isinstance(v[0], torch.Tensor):
batch.__dict__[k] = torch.stack(v, axis)
elif isinstance(v[0], Batch):
batch.__dict__[k] = Batch.stack(v, axis)
else:
s = 'No support for method "stack" with type '\
f'{type(v[0])} in class Batch and axis != 0.'
raise TypeError(s)
return batch
def __len__(self) -> int: def __len__(self) -> int:
"""Return len(self).""" """Return len(self)."""
@ -409,7 +544,9 @@ class Batch:
elif hasattr(v, '__len__') and (not isinstance( elif hasattr(v, '__len__') and (not isinstance(
v, (np.ndarray, torch.Tensor)) or v.ndim > 0): v, (np.ndarray, torch.Tensor)) or v.ndim > 0):
r.append(len(v)) r.append(len(v))
return max(1, min(r) if len(r) > 0 else 0) else:
r.append(1)
return min(r) if len(r) > 0 else 0
def split(self, size: Optional[int] = None, def split(self, size: Optional[int] = None,
shuffle: bool = True) -> Iterator['Batch']: shuffle: bool = True) -> Iterator['Batch']:

View File

@ -1,24 +1,7 @@
import numpy as np import numpy as np
from numbers import Number
from typing import Any, Tuple, Union, Optional from typing import Any, Tuple, Union, Optional
from .batch import Batch from .batch import Batch, _create_value
def _create_value(inst: Any, size: int) -> Union['Batch', np.ndarray]:
if isinstance(inst, np.ndarray):
return np.full(shape=(size, *inst.shape),
fill_value=None if inst.dtype == np.inexact else 0,
dtype=inst.dtype)
elif isinstance(inst, (dict, Batch)):
zero_batch = Batch()
for key, val in inst.items():
zero_batch.__dict__[key] = _create_value(val, size)
return zero_batch
elif isinstance(inst, (np.generic, Number)):
return _create_value(np.asarray(inst), size)
else: # fall back to np.object
return np.array([None for _ in range(size)])
class ReplayBuffer: class ReplayBuffer:
@ -125,6 +108,7 @@ class ReplayBuffer:
return self._size return self._size
def __repr__(self) -> str: def __repr__(self) -> str:
"""Return str(self)."""
return self.__class__.__name__ + self._meta.__repr__()[5:] return self.__class__.__name__ + self._meta.__repr__()[5:]
def __getattr__(self, key: str) -> Union['Batch', Any]: def __getattr__(self, key: str) -> Union['Batch', Any]: