Enable partial stacking at Batch level (#100)

* Enable stacking of partially matching Batch instances.

* Fix list support for getitem.

* Fix Batch 'size' method.

* Update Batch documentation.
This commit is contained in:
Alexis DUBURCQ 2020-06-27 03:06:40 +02:00 committed by GitHub
parent 70aa7bf93e
commit a951a32487
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 217 additions and 89 deletions

View File

@ -126,6 +126,13 @@ def test_batch_cat_and_stack():
b34_stack = Batch.stack((b3, b4), axis=1)
assert np.all(b34_stack.a == np.stack((b3.a, b4.a), axis=1))
assert np.all(b34_stack.c.d == list(map(list, zip(b3.c.d, b4.c.d))))
b5_dict = np.array([{'a': False, 'b': {'c': 2.0, 'd': 1.0}},
{'a': True, 'b': {'c': 3.0}}])
b5 = Batch(b5_dict)
assert b5.a[0] == np.array(False) and b5.a[1] == np.array(True)
assert np.all(b5.b.c == np.stack([e['b']['c'] for e in b5_dict], axis=0))
assert b5.b.d[0] == b5_dict[0]['b']['d']
assert b5.b.d[1] == 0.0
def test_batch_over_batch_to_torch():

View File

@ -36,7 +36,7 @@ def test_replaybuffer(size=10, bufsize=20):
assert b.info.a[0] == 3 and b.info.a.dtype == np.integer
assert np.all(b.info.a[1:] == 0)
assert b.info.b.c[0] == 5.0 and b.info.b.c.dtype == np.inexact
assert np.all(np.isnan(b.info.b.c[1:]))
assert np.all(b.info.b.c[1:] == 0.0)
def test_ignore_obs_next(size=10):

View File

@ -3,6 +3,7 @@ import copy
import pprint
import warnings
import numpy as np
from functools import reduce
from numbers import Number
from typing import Any, List, Tuple, Union, Iterator, Optional
@ -42,6 +43,27 @@ def _valid_bounds(length: int, index: Union[
return start_valid and stop_valid
def _create_value(
        inst: Any, size: int) -> Union['Batch', np.ndarray, torch.Tensor]:
    """Allocate blank storage able to hold ``size`` stacked copies of
    ``inst``.

    Only the type, shape, dtype and (for tensors) device of ``inst`` are
    used; its values are never read.

    :param inst: template sample the storage is modelled on.
    :param size: length of the new leading dimension.
    :return: for an array or tensor, a ``(size, *inst.shape)`` container of
        the same dtype filled with 0 (None for arrays of objects); for a
        ``dict`` or :class:`Batch`, a :class:`Batch` built recursively key
        by key; for a Numpy scalar or Python number, the same as for the
        equivalent 0-d array; for anything else, a length-``size`` array of
        objects filled with None.
    """
    if isinstance(inst, np.ndarray):
        # Compare against the builtin ``object``: the ``np.object`` alias
        # has been removed from recent Numpy releases.
        return np.full((size, *inst.shape),
                       fill_value=None if inst.dtype == object else 0,
                       dtype=inst.dtype)
    elif isinstance(inst, torch.Tensor):
        # Tensors cannot hold Python objects, so zero is the only
        # meaningful filler here.
        return torch.zeros((size, *inst.shape),
                           device=inst.device,
                           dtype=inst.dtype)
    elif isinstance(inst, (dict, Batch)):
        zero_batch = Batch()
        for key, val in inst.items():
            zero_batch.__dict__[key] = _create_value(val, size)
        return zero_batch
    elif isinstance(inst, (np.generic, Number)):
        return _create_value(np.asarray(inst), size)
    else:
        # Fall back to an array of objects. The dtype is forced so that
        # ``size == 0`` does not silently produce a float array.
        return np.full(size, None, dtype=object)
class Batch:
"""Tianshou provides :class:`~tianshou.data.Batch` as the internal data
structure to pass any kind of data to other methods, for example, a
@ -75,33 +97,133 @@ class Batch:
function return 4 arguments, and the last one is ``info``);
* ``policy`` the data computed by policy in step :math:`t`;
:class:`~tianshou.data.Batch` has other methods, including
:meth:`~tianshou.data.Batch.__getitem__`,
:meth:`~tianshou.data.Batch.__len__`,
:meth:`~tianshou.data.Batch.append`,
and :meth:`~tianshou.data.Batch.split`:
A :class:`Batch` object can be initialized using a wide variety of
arguments: key/value pairs or a dictionary, but also a list or Numpy
array of :class:`dict` or Batch instances. In the latter case, each
element is considered as an individual sample and they get stacked together:
::
>>> data = Batch(obs=np.array([0, 11, 22]), rew=np.array([6, 6, 6]))
>>> # here we test __getitem__
>>> index = [2, 1]
>>> data[index].obs
array([22, 11])
>>> import numpy as np
>>> from tianshou.data import Batch
>>> data = Batch([{'a': {'b': [0.0, "info"]}}])
>>> print(data[0])
Batch(
a: Batch(
b: array(['0.0', 'info'], dtype='<U32'),
),
)
>>> # here we test __len__
:class:`Batch` has the same API as a native Python :class:`dict`. In this
regard, one can access stored data using a string key, or iterate over
stored data:
::
>>> from tianshou.data import Batch
>>> data = Batch(a=4, b=[5, 5])
>>> print(data["a"])
4
>>> for key, value in data.items():
>>> print(f"{key}: {value}")
a: 4
b: [5, 5]
:class:`Batch` also partially reproduces the Numpy API for arrays. You
can access or iterate over the individual samples, if any:
::
>>> import numpy as np
>>> from tianshou.data import Batch
>>> data = Batch(a=np.array([[0.0, 2.0], [1.0, 3.0]]), b=[5, -5])
>>> print(data[0])
Batch(
a: np.array([0.0, 2.0])
b: 5
)
>>> for sample in data:
>>> print(sample.a)
[0.0, 2.0]
[1.0, 3.0]
Similarly, one can also perform simple algebra on it, and stack, split or
concatenate multiple instances:
::
>>> import numpy as np
>>> from tianshou.data import Batch
>>> data_1 = Batch(a=np.array([0.0, 2.0]), b=5)
>>> data_2 = Batch(a=np.array([1.0, 3.0]), b=-5)
>>> data = Batch.stack((data_1, data_2))
>>> print(data)
Batch(
b: array([ 5, -5]),
a: array([[0., 2.],
[1., 3.]]),
)
>>> print(np.mean(data))
Batch(
b: 0.0,
a: array([0.5, 2.5]),
)
>>> data_split = list(data.split(1, False))
>>> print(list(data.split(1, False)))
[Batch(
b: [5],
a: array([[0., 2.]]),
),
Batch(
b: [-5],
a: array([[1., 3.]]),
)]
>>> data_cat = Batch.cat(data_split)
>>> print(data_cat)
Batch(
b: array([ 5, -5]),
a: array([[0., 2.],
[1., 3.]]),
)
Note that stacking of inconsistent data is also supported. In that case,
None is added in lists or :class:`np.ndarray` of objects, 0 otherwise.
::
>>> import numpy as np
>>> from tianshou.data import Batch
>>> data_1 = Batch(a=np.array([0.0, 2.0]))
>>> data_2 = Batch(a=np.array([1.0, 3.0]), b='done')
>>> data = Batch.stack((data_1, data_2))
>>> print(data)
Batch(
a: array([[0., 2.],
[1., 3.]]),
b: array([None, 'done'], dtype=object),
)
:meth:`~tianshou.data.Batch.size` and :meth:`~tianshou.data.Batch.__len__`
methods are also provided to respectively get the size and the length of
a :class:`Batch` instance. It mimics the Numpy API for Numpy arrays, which
means that getting the length of a scalar Batch raises an exception, while
the size is 1. The size is only 0 if empty. Note that the size and length
are identical if multiple samples are stored:
::
>>> import numpy as np
>>> from tianshou.data import Batch
>>> data = Batch(a=[5., 4.], b=np.zeros((2, 3, 4)))
>>> data.size
2
>>> len(data)
3
2
>>> data[0].size
1
>>> len(data[0])
TypeError: Object of type 'Batch' has no len()
>>> data.append(data) # similar to list.append
>>> data.obs
array([0, 11, 22, 0, 11, 22])
Convenience helpers are available to convert in-place the
stored data into Numpy arrays or Torch tensors.
>>> # split whole data into multiple small batch
>>> for d in data.split(size=2, shuffle=False):
... print(d.obs, d.rew)
[ 0 11] [6 6]
[22 0] [6 6]
[11 22] [6 6]
Finally, note that Batch instances are serializable and therefore Pickle
compatible. This is especially important for distributed sampling.
"""
def __init__(self,
@ -110,18 +232,7 @@ class Batch:
List[Union[dict, 'Batch']], np.ndarray]] = None,
**kwargs) -> None:
if _is_batch_set(batch_dict):
for k, v in zip(batch_dict[0].keys(),
zip(*[e.values() for e in batch_dict])):
if isinstance(v[0], dict) or _is_batch_set(v[0]):
self.__dict__[k] = Batch(v)
elif isinstance(v[0], (np.generic, np.ndarray)):
self.__dict__[k] = np.stack(v, axis=0)
elif isinstance(v[0], torch.Tensor):
self.__dict__[k] = torch.stack(v, dim=0)
elif isinstance(v[0], Batch):
self.__dict__[k] = Batch.stack(v)
else:
self.__dict__[k] = np.array(v)
self.stack_(batch_dict)
elif isinstance(batch_dict, (dict, Batch)):
for k, v in batch_dict.items():
if isinstance(v, dict) or _is_batch_set(v):
@ -160,16 +271,21 @@ class Batch:
f"Index {index} out of bounds for Batch of len {len(self)}.")
else:
b = Batch()
is_index_scalar = isinstance(index, (int, np.integer)) or \
(isinstance(index, np.ndarray) and index.ndim == 0)
for k, v in self.items():
if isinstance(v, Batch) and len(v.__dict__) == 0:
b.__dict__[k] = Batch()
else:
elif is_index_scalar or not isinstance(v, list):
b.__dict__[k] = v[index]
else:
b.__dict__[k] = [v[i] for i in index]
return b
def __setitem__(self, index: Union[
str, slice, int, np.integer, np.ndarray, List[int]],
value: Any) -> None:
"""Assign value to self[index]."""
if isinstance(index, str):
self.__dict__[index] = value
return
@ -193,10 +309,12 @@ class Batch:
else:
self.__dict__[key][index] = None
def __iadd__(self, val: Union['Batch', Number]):
if isinstance(val, Batch):
def __iadd__(self, other: Union['Batch', Number]):
"""Algebraic addition with another :class:`~tianshou.data.Batch`
instance in-place."""
if isinstance(other, Batch):
for (k, r), v in zip(self.__dict__.items(),
val.__dict__.values()):
other.__dict__.values()):
if r is None:
continue
elif isinstance(r, list):
@ -204,22 +322,25 @@ class Batch:
else:
self.__dict__[k] += v
return self
elif isinstance(val, Number):
elif isinstance(other, Number):
for k, r in self.items():
if r is None:
continue
elif isinstance(r, list):
self.__dict__[k] = [r_ + val for r_ in r]
self.__dict__[k] = [r_ + other for r_ in r]
else:
self.__dict__[k] += val
self.__dict__[k] += other
return self
else:
raise TypeError("Only addition of Batch or number is supported.")
def __add__(self, val: Union['Batch', Number]):
return copy.deepcopy(self).__iadd__(val)
def __add__(self, other: Union['Batch', Number]):
"""Algebraic addition with another :class:`~tianshou.data.Batch`
instance out-of-place."""
return copy.deepcopy(self).__iadd__(other)
def __imul__(self, val: Number):
"""Algebraic multiplication with a scalar value in-place."""
assert isinstance(val, Number), \
"Only multiplication by a number is supported."
for k in self.__dict__.keys():
@ -227,9 +348,11 @@ class Batch:
return self
def __mul__(self, val: Number):
    """Return a deep copy of this batch scaled by ``val``; the original
    batch is left untouched."""
    scaled = copy.deepcopy(self)
    return scaled.__imul__(val)
def __itruediv__(self, val: Number):
"""Algebraic division by a scalar value in-place."""
assert isinstance(val, Number), \
"Only division by a number is supported."
for k in self.__dict__.keys():
@ -237,6 +360,7 @@ class Batch:
return self
def __truediv__(self, val: Number):
    """Return a deep copy of this batch divided by the scalar ``val``; the
    original batch is left untouched."""
    quotient = copy.deepcopy(self)
    return quotient.__itruediv__(val)
def __repr__(self) -> str:
@ -319,8 +443,8 @@ class Batch:
return self.cat_(batch)
def cat_(self, batch: 'Batch') -> None:
"""Concatenate a :class:`~tianshou.data.Batch` object to current
batch.
"""Concatenate a :class:`~tianshou.data.Batch` object into
current batch.
"""
assert isinstance(batch, Batch), \
'Only Batch is allowed to be concatenated in-place!'
@ -347,39 +471,50 @@ class Batch:
"""Concatenate a :class:`~tianshou.data.Batch` object into a
single new batch.
"""
assert isinstance(batches, (tuple, list)), \
'Only list of Batch instances is allowed to be '\
'concatenated out-of-place!'
batch = cls()
for batch_ in batches:
batch.cat_(batch_)
return batch
@classmethod
def stack(cls, batches: List['Batch'], axis: int = 0) -> 'Batch':
def stack_(self,
           batches: List[Union[dict, 'Batch']],
           axis: int = 0) -> None:
    """Stack a list of :class:`~tianshou.data.Batch` objects (or dicts)
    into the current batch, in-place.

    If the current batch already holds data, it is used as the first
    element of the stack. Keys present in every element are stacked along
    ``axis``. Keys missing from at least one element are stored in blank
    storage created by ``_create_value``, so missing entries keep the
    filler value chosen there.

    :param batches: elements to stack; each one must provide ``keys()``,
        ``get()`` and item access by key.
    :param axis: axis along which values of shared keys are stacked.
    """
    batches = list(batches)
    if len(self.__dict__) > 0:
        batches = [self] + batches
    if not batches:
        # Nothing to stack; ``set.intersection(*[])`` would raise.
        return
    keys_map = [set(e.keys()) for e in batches]
    keys_shared = set.intersection(*keys_map)
    for k in keys_shared:
        # Values are gathered before the assignment below because
        # ``self`` may itself be one of the stacked elements.
        v = [e[k] for e in batches]
        if isinstance(v[0], (dict, Batch)):
            self.__dict__[k] = Batch.stack(v, axis)
        elif isinstance(v[0], torch.Tensor):
            self.__dict__[k] = torch.stack(v, axis)
        else:
            self.__dict__[k] = np.stack(v, axis)
    # Keys held by some but not all elements. A chained symmetric
    # difference is not equivalent for more than two sets: it keeps the
    # keys that appear an odd number of times.
    keys_partial = set.union(*keys_map) - keys_shared
    for k in keys_partial:
        # Fill a fresh container and only bind it at the end, so that a
        # value already stored under ``k`` in ``self`` is read as a
        # sample and not written into.
        # NOTE(review): samples are placed along the leading dimension
        # whatever ``axis`` is -- confirm this is intended for axis != 0.
        stacked = None
        for i, e in enumerate(batches):
            val = e.get(k, None)
            if val is None:
                continue
            if stacked is None:
                stacked = _create_value(val, len(batches))
            stacked[i] = val
        if stacked is not None:
            self.__dict__[k] = stacked
@staticmethod
def stack(batches: List['Batch'], axis: int = 0) -> 'Batch':
"""Stack a :class:`~tianshou.data.Batch` object into a
single new batch.
"""
assert isinstance(batches, (tuple, list)), \
'Only list of Batch instances is allowed to be '\
'stacked out-of-place!'
if axis == 0:
return cls(batches)
else:
batch = Batch()
for k, v in zip(batches[0].keys(),
zip(*[e.values() for e in batches])):
if isinstance(v[0], (np.generic, np.ndarray, list)):
batch.__dict__[k] = np.stack(v, axis)
elif isinstance(v[0], torch.Tensor):
batch.__dict__[k] = torch.stack(v, axis)
elif isinstance(v[0], Batch):
batch.__dict__[k] = Batch.stack(v, axis)
else:
s = 'No support for method "stack" with type '\
f'{type(v[0])} in class Batch and axis != 0.'
raise TypeError(s)
return batch
batch = Batch()
batch.stack_(batches, axis)
return batch
def __len__(self) -> int:
"""Return len(self)."""
@ -409,7 +544,9 @@ class Batch:
elif hasattr(v, '__len__') and (not isinstance(
v, (np.ndarray, torch.Tensor)) or v.ndim > 0):
r.append(len(v))
return max(1, min(r) if len(r) > 0 else 0)
else:
r.append(1)
return min(r) if len(r) > 0 else 0
def split(self, size: Optional[int] = None,
shuffle: bool = True) -> Iterator['Batch']:

View File

@ -1,24 +1,7 @@
import numpy as np
from numbers import Number
from typing import Any, Tuple, Union, Optional
from .batch import Batch
def _create_value(inst: Any, size: int) -> Union['Batch', np.ndarray]:
if isinstance(inst, np.ndarray):
return np.full(shape=(size, *inst.shape),
fill_value=None if inst.dtype == np.inexact else 0,
dtype=inst.dtype)
elif isinstance(inst, (dict, Batch)):
zero_batch = Batch()
for key, val in inst.items():
zero_batch.__dict__[key] = _create_value(val, size)
return zero_batch
elif isinstance(inst, (np.generic, Number)):
return _create_value(np.asarray(inst), size)
else: # fall back to np.object
return np.array([None for _ in range(size)])
from .batch import Batch, _create_value
class ReplayBuffer:
@ -125,6 +108,7 @@ class ReplayBuffer:
return self._size
def __repr__(self) -> str:
    """Return str(self): the class name followed by the representation
    of the stored data with its first five characters dropped."""
    meta_repr = self._meta.__repr__()
    return f"{self.__class__.__name__}{meta_repr[5:]}"
def __getattr__(self, key: str) -> Union['Batch', Any]: