From a951a32487b31e4a789f6e890377b63d4b5c460e Mon Sep 17 00:00:00 2001 From: Alexis DUBURCQ Date: Sat, 27 Jun 2020 03:06:40 +0200 Subject: [PATCH] Enable partial stacking at Batch level (#100) * Enable stacking of partially matching Batch instances. * Fix list support for getitem. * Fix Batch 'size' method. * Update Batch documentation. --- test/base/test_batch.py | 7 + test/base/test_buffer.py | 2 +- tianshou/data/batch.py | 277 +++++++++++++++++++++++++++++---------- tianshou/data/buffer.py | 20 +-- 4 files changed, 217 insertions(+), 89 deletions(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 4854f70..4b528a3 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -126,6 +126,13 @@ def test_batch_cat_and_stack(): b34_stack = Batch.stack((b3, b4), axis=1) assert np.all(b34_stack.a == np.stack((b3.a, b4.a), axis=1)) assert np.all(b34_stack.c.d == list(map(list, zip(b3.c.d, b4.c.d)))) + b5_dict = np.array([{'a': False, 'b': {'c': 2.0, 'd': 1.0}}, + {'a': True, 'b': {'c': 3.0}}]) + b5 = Batch(b5_dict) + assert b5.a[0] == np.array(False) and b5.a[1] == np.array(True) + assert np.all(b5.b.c == np.stack([e['b']['c'] for e in b5_dict], axis=0)) + assert b5.b.d[0] == b5_dict[0]['b']['d'] + assert b5.b.d[1] == 0.0 def test_batch_over_batch_to_torch(): diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 6ab2c88..bc374b8 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -36,7 +36,7 @@ def test_replaybuffer(size=10, bufsize=20): assert b.info.a[0] == 3 and b.info.a.dtype == np.integer assert np.all(b.info.a[1:] == 0) assert b.info.b.c[0] == 5.0 and b.info.b.c.dtype == np.inexact - assert np.all(np.isnan(b.info.b.c[1:])) + assert np.all(b.info.b.c[1:] == 0.0) def test_ignore_obs_next(size=10): diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 21680fa..2d9b11b 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -3,6 +3,7 @@ import copy import pprint import warnings import numpy as np +from functools import reduce from numbers import Number from typing import Any, List, Tuple, Union, Iterator, Optional @@ -42,6 +43,27 @@ def _valid_bounds(length: int, index: Union[ return start_valid and stop_valid +def _create_value(inst: Any, size: int) -> Union['Batch', np.ndarray]: + if isinstance(inst, np.ndarray): + return np.full((size, *inst.shape), + fill_value=None if inst.dtype == np.object else 0, + dtype=inst.dtype) + elif isinstance(inst, torch.Tensor): + return torch.full((size, *inst.shape), + fill_value=None if inst.dtype == np.object else 0, + device=inst.device, + dtype=inst.dtype) + elif isinstance(inst, (dict, Batch)): + zero_batch = Batch() + for key, val in inst.items(): + zero_batch.__dict__[key] = _create_value(val, size) + return zero_batch + elif isinstance(inst, (np.generic, Number)): + return _create_value(np.asarray(inst), size) + else: # fall back to np.object + return np.array([None for _ in range(size)]) + + class Batch: """Tianshou provides :class:`~tianshou.data.Batch` as the internal data structure to pass any kind of data to other methods, for example, a @@ -75,33 +97,133 @@ class Batch: function return 4 arguments, and the last one is ``info``); * ``policy`` the data computed by policy in step :math:`t`; - :class:`~tianshou.data.Batch` has other methods, including - :meth:`~tianshou.data.Batch.__getitem__`, - :meth:`~tianshou.data.Batch.__len__`, - :meth:`~tianshou.data.Batch.append`, - and :meth:`~tianshou.data.Batch.split`: + :class:`Batch` object can be initialized using wide variety of arguments, + starting with the key/value pairs or dictionary, but also list and Numpy + arrays of :class:`dict` or Batch instances. In which case, each element + is considered as an individual sample and get stacked together: :: - >>> data = Batch(obs=np.array([0, 11, 22]), rew=np.array([6, 6, 6])) - >>> # here we test __getitem__ - >>> index = [2, 1] - >>> data[index].obs - array([22, 11]) + >>> import numpy as np + >>> from tianshou.data import Batch + >>> data = Batch([{'a': {'b': [0.0, "info"]}}]) + >>> print(data[0]) + Batch( + a: Batch( + b: array(['0.0', 'info'], dtype='>> # here we test __len__ + :class:`Batch` has the same API as a native Python :class:`dict`. In this + regard, one can access to stored data using string key, or iterate over + stored data: + :: + + >>> from tianshou.data import Batch + >>> data = Batch(a=4, b=[5, 5]) + >>> print(data["a"]) + 4 + >>> for key, value in data.items(): + >>> print(f"{key}: {value}") + a: 4 + b: [5, 5] + + + :class:`Batch` is also reproduce partially the Numpy API for arrays. You + can access or iterate over the individual samples, if any: + :: + + >>> import numpy as np + >>> from tianshou.data import Batch + >>> data = Batch(a=np.array([[0.0, 2.0], [1.0, 3.0]]), b=[5, -5]) + >>> print(data[0]) + Batch( + a: np.array([0.0, 2.0]) + b: 5 + ) + >>> for sample in data: + >>> print(sample.a) + [0.0, 2.0] + [1.0, 3.0] + + Similarly, one can also perform simple algebra on it, and stack, split or + concatenate multiple instances: + :: + + >>> import numpy as np + >>> from tianshou.data import Batch + >>> data_1 = Batch(a=np.array([0.0, 2.0]), b=5) + >>> data_2 = Batch(a=np.array([1.0, 3.0]), b=-5) + >>> data = Batch.stack((data_1, data_2)) + >>> print(data) + Batch( + b: array([ 5, -5]), + a: array([[0., 2.], + [1., 3.]]), + ) + >>> print(np.mean(data)) + Batch( + b: 0.0, + a: array([0.5, 2.5]), + ) + >>> data_split = list(data.split(1, False)) + >>> print(list(data.split(1, False))) + [Batch( + b: [5], + a: array([[0., 2.]]), + ), + Batch( + b: [-5], + a: array([[1., 3.]]), + )] + >>> data_cat = Batch.cat(data_split) + >>> print(data_cat) + Batch( + b: array([ 5, -5]), + a: array([[0., 2.], + [1., 3.]]), + ) + + Note that stacking of inconsistent data is also supported. In which case, + None is added in list or :class:`np.ndarray` of objects, 0 otherwise. + :: + + >>> import numpy as np + >>> from tianshou.data import Batch + >>> data_1 = Batch(a=np.array([0.0, 2.0])) + >>> data_2 = Batch(a=np.array([1.0, 3.0]), b='done') + >>> data = Batch.stack((data_1, data_2)) + >>> print(data) + Batch( + a: array([[0., 2.], + [1., 3.]]), + b: array([None, 'done'], dtype=object), + ) + + :meth:`~tianshou.data.Batch.size` and :meth:`~tianshou.data.Batch.__len__` + methods are also provided to respectively get the length and the size of + a :class:`Batch` instance. It mimics the Numpy API for Numpy arrays, which + means that getting the length of a scalar Batch raises an exception, while + the size is 1. The size is only 0 if empty. Note that the size and length + are the identical if multiple samples are stored: + :: + + >>> import numpy as np + >>> from tianshou.data import Batch + >>> data = Batch(a=[5., 4.], b=np.zeros((2, 3, 4))) + >>> data.size + 2 >>> len(data) - 3 + 2 + >>> data[0].size + 1 + >>> len(data[0]) + TypeError: Object of type 'Batch' has no len() - >>> data.append(data) # similar to list.append - >>> data.obs - array([0, 11, 22, 0, 11, 22]) + Convenience helpers are available to convert in-place the + stored data into Numpy arrays or Torch tensors. - >>> # split whole data into multiple small batch - >>> for d in data.split(size=2, shuffle=False): - ... print(d.obs, d.rew) - [ 0 11] [6 6] - [22 0] [6 6] - [11 22] [6 6] + Finally, note that Batch instance are serializable and therefore Pickle + compatible. This is especially important for distributed sampling. """ def __init__(self, @@ -110,18 +232,7 @@ class Batch: List[Union[dict, 'Batch']], np.ndarray]] = None, **kwargs) -> None: if _is_batch_set(batch_dict): - for k, v in zip(batch_dict[0].keys(), - zip(*[e.values() for e in batch_dict])): - if isinstance(v[0], dict) or _is_batch_set(v[0]): - self.__dict__[k] = Batch(v) - elif isinstance(v[0], (np.generic, np.ndarray)): - self.__dict__[k] = np.stack(v, axis=0) - elif isinstance(v[0], torch.Tensor): - self.__dict__[k] = torch.stack(v, dim=0) - elif isinstance(v[0], Batch): - self.__dict__[k] = Batch.stack(v) - else: - self.__dict__[k] = np.array(v) + self.stack_(batch_dict) elif isinstance(batch_dict, (dict, Batch)): for k, v in batch_dict.items(): if isinstance(v, dict) or _is_batch_set(v): @@ -160,16 +271,21 @@ class Batch: f"Index {index} out of bounds for Batch of len {len(self)}.") else: b = Batch() + is_index_scalar = isinstance(index, (int, np.integer)) or \ + (isinstance(index, np.ndarray) and index.ndim == 0) for k, v in self.items(): if isinstance(v, Batch) and len(v.__dict__) == 0: b.__dict__[k] = Batch() - else: + elif is_index_scalar or not isinstance(v, list): b.__dict__[k] = v[index] + else: + b.__dict__[k] = [v[i] for i in index] return b def __setitem__(self, index: Union[ str, slice, int, np.integer, np.ndarray, List[int]], value: Any) -> None: + """Assign value to self[index].""" if isinstance(index, str): self.__dict__[index] = value return @@ -193,10 +309,12 @@ class Batch: else: self.__dict__[key][index] = None - def __iadd__(self, val: Union['Batch', Number]): - if isinstance(val, Batch): + def __iadd__(self, other: Union['Batch', Number]): + """Algebraic addition with another :class:`~tianshou.data.Batch` + instance in-place.""" + if isinstance(other, Batch): for (k, r), v in zip(self.__dict__.items(), - val.__dict__.values()): + other.__dict__.values()): if r is None: continue elif isinstance(r, list): @@ -204,22 +322,25 @@ class Batch: else: self.__dict__[k] += v return self - elif isinstance(val, Number): + elif isinstance(other, Number): for k, r in self.items(): if r is None: continue elif isinstance(r, list): - self.__dict__[k] = [r_ + val for r_ in r] + self.__dict__[k] = [r_ + other for r_ in r] else: - self.__dict__[k] += val + self.__dict__[k] += other return self else: raise TypeError("Only addition of Batch or number is supported.") - def __add__(self, val: Union['Batch', Number]): - return copy.deepcopy(self).__iadd__(val) + def __add__(self, other: Union['Batch', Number]): + """Algebraic addition with another :class:`~tianshou.data.Batch` + instance out-of-place.""" + return copy.deepcopy(self).__iadd__(other) def __imul__(self, val: Number): + """Algebraic multiplication with a scalar value in-place.""" assert isinstance(val, Number), \ "Only multiplication by a number is supported." for k in self.__dict__.keys(): @@ -227,9 +348,11 @@ class Batch: return self def __mul__(self, val: Number): + """Algebraic multiplication with a scalar value out-of-place.""" return copy.deepcopy(self).__imul__(val) def __itruediv__(self, val: Number): + """Algebraic division wibyth a scalar value in-place.""" assert isinstance(val, Number), \ "Only division by a number is supported." for k in self.__dict__.keys(): @@ -237,6 +360,7 @@ class Batch: return self def __truediv__(self, val: Number): + """Algebraic division wibyth a scalar value out-of-place.""" return copy.deepcopy(self).__itruediv__(val) def __repr__(self) -> str: @@ -319,8 +443,8 @@ class Batch: return self.cat_(batch) def cat_(self, batch: 'Batch') -> None: - """Concatenate a :class:`~tianshou.data.Batch` object to current - batch. + """Concatenate a :class:`~tianshou.data.Batch` object into + current batch. """ assert isinstance(batch, Batch), \ 'Only Batch is allowed to be concatenated in-place!' @@ -347,39 +471,50 @@ class Batch: """Concatenate a :class:`~tianshou.data.Batch` object into a single new batch. """ - assert isinstance(batches, (tuple, list)), \ - 'Only list of Batch instances is allowed to be '\ - 'concatenated out-of-place!' batch = cls() for batch_ in batches: batch.cat_(batch_) return batch - @classmethod - def stack(cls, batches: List['Batch'], axis: int = 0) -> 'Batch': + def stack_(self, + batches: List[Union[dict, 'Batch']], + axis: int = 0) -> None: + """Stack a :class:`~tianshou.data.Batch` object i into current + batch. + """ + if len(self.__dict__) > 0: + batches = [self] + list(batches) + keys_map = list(map(lambda e: set(e.keys()), batches)) + keys_shared = set.intersection(*keys_map) + values_shared = [ + [e[k] for e in batches] for k in keys_shared] + for k, v in zip(keys_shared, values_shared): + if isinstance(v[0], (dict, Batch)): + self.__dict__[k] = Batch.stack(v, axis) + elif isinstance(v[0], torch.Tensor): + self.__dict__[k] = torch.stack(v, axis) + else: + self.__dict__[k] = np.stack(v, axis) + keys_partial = reduce(set.symmetric_difference, keys_map) + for k in keys_partial: + for i, e in enumerate(batches): + val = e.get(k, None) + if val is not None: + try: + self.__dict__[k][i] = val + except KeyError: + self.__dict__[k] = \ + _create_value(val, len(batches)) + self.__dict__[k][i] = val + + @staticmethod + def stack(batches: List['Batch'], axis: int = 0) -> 'Batch': """Stack a :class:`~tianshou.data.Batch` object into a single new batch. """ - assert isinstance(batches, (tuple, list)), \ - 'Only list of Batch instances is allowed to be '\ - 'stacked out-of-place!' - if axis == 0: - return cls(batches) - else: - batch = Batch() - for k, v in zip(batches[0].keys(), - zip(*[e.values() for e in batches])): - if isinstance(v[0], (np.generic, np.ndarray, list)): - batch.__dict__[k] = np.stack(v, axis) - elif isinstance(v[0], torch.Tensor): - batch.__dict__[k] = torch.stack(v, axis) - elif isinstance(v[0], Batch): - batch.__dict__[k] = Batch.stack(v, axis) - else: - s = 'No support for method "stack" with type '\ - f'{type(v[0])} in class Batch and axis != 0.' - raise TypeError(s) - return batch + batch = Batch() + batch.stack_(batches, axis) + return batch def __len__(self) -> int: """Return len(self).""" @@ -409,7 +544,9 @@ class Batch: elif hasattr(v, '__len__') and (not isinstance( v, (np.ndarray, torch.Tensor)) or v.ndim > 0): r.append(len(v)) - return max(1, min(r) if len(r) > 0 else 0) + else: + r.append(1) + return min(r) if len(r) > 0 else 0 def split(self, size: Optional[int] = None, shuffle: bool = True) -> Iterator['Batch']: diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 103b582..c050fd5 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -1,24 +1,7 @@ import numpy as np -from numbers import Number from typing import Any, Tuple, Union, Optional -from .batch import Batch - - -def _create_value(inst: Any, size: int) -> Union['Batch', np.ndarray]: - if isinstance(inst, np.ndarray): - return np.full(shape=(size, *inst.shape), - fill_value=None if inst.dtype == np.inexact else 0, - dtype=inst.dtype) - elif isinstance(inst, (dict, Batch)): - zero_batch = Batch() - for key, val in inst.items(): - zero_batch.__dict__[key] = _create_value(val, size) - return zero_batch - elif isinstance(inst, (np.generic, Number)): - return _create_value(np.asarray(inst), size) - else: # fall back to np.object - return np.array([None for _ in range(size)]) +from .batch import Batch, _create_value class ReplayBuffer: @@ -125,6 +108,7 @@ class ReplayBuffer: return self._size def __repr__(self) -> str: + """Return str(self).""" return self.__class__.__name__ + self._meta.__repr__()[5:] def __getattr__(self, key: str) -> Union['Batch', Any]: