Enable partial stacking at Batch level (#100)
* Enable stacking of partially matching Batch instances.
* Fix list support for __getitem__.
* Fix Batch 'size' method.
* Update Batch documentation.
parent 70aa7bf93e
commit a951a32487
@@ -126,6 +126,13 @@ def test_batch_cat_and_stack():
     b34_stack = Batch.stack((b3, b4), axis=1)
     assert np.all(b34_stack.a == np.stack((b3.a, b4.a), axis=1))
     assert np.all(b34_stack.c.d == list(map(list, zip(b3.c.d, b4.c.d))))
+    b5_dict = np.array([{'a': False, 'b': {'c': 2.0, 'd': 1.0}},
+                        {'a': True, 'b': {'c': 3.0}}])
+    b5 = Batch(b5_dict)
+    assert b5.a[0] == np.array(False) and b5.a[1] == np.array(True)
+    assert np.all(b5.b.c == np.stack([e['b']['c'] for e in b5_dict], axis=0))
+    assert b5.b.d[0] == b5_dict[0]['b']['d']
+    assert b5.b.d[1] == 0.0


 def test_batch_over_batch_to_torch():
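The assertions added above construct a Batch from a Numpy object array of dicts whose nested keys only partially overlap. A minimal sketch of the behaviour being asserted, assuming tianshou at this commit is installed (variable names are illustrative)::

    import numpy as np
    from tianshou.data import Batch

    # two samples whose nested keys only partially match: 'd' is missing
    # from the second sample
    samples = np.array([{'a': False, 'b': {'c': 2.0, 'd': 1.0}},
                        {'a': True, 'b': {'c': 3.0}}])
    b5 = Batch(samples)
    print(b5.b.c)  # [2. 3.] -- shared key, stacked as usual
    print(b5.b.d)  # [1. 0.] -- missing float entry padded with 0, per the test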
@@ -36,7 +36,7 @@ def test_replaybuffer(size=10, bufsize=20):
     assert b.info.a[0] == 3 and b.info.a.dtype == np.integer
     assert np.all(b.info.a[1:] == 0)
     assert b.info.b.c[0] == 5.0 and b.info.b.c.dtype == np.inexact
-    assert np.all(np.isnan(b.info.b.c[1:]))
+    assert np.all(b.info.b.c[1:] == 0.0)


 def test_ignore_obs_next(size=10):
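The flipped assertion reflects the new padding rule: float slots that were never written are now filled with 0.0 instead of NaN. A rough sketch of the scenario this test covers, assuming the ReplayBuffer.add keyword signature of this version (treat the call as illustrative)::

    import numpy as np
    from tianshou.data import ReplayBuffer

    b = ReplayBuffer(size=10)
    # only the first slot ever receives a nested info dict
    b.add(obs=1, act=1, rew=1, done=0, info={'a': 3, 'b': {'c': 5.0}})
    print(b.info.b.c[0])                  # 5.0
    print(np.all(b.info.b.c[1:] == 0.0))  # True -- zero padding, not NaN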
@@ -3,6 +3,7 @@ import copy
 import pprint
 import warnings
 import numpy as np
+from functools import reduce
 from numbers import Number
 from typing import Any, List, Tuple, Union, Iterator, Optional

@@ -42,6 +43,27 @@ def _valid_bounds(length: int, index: Union[
     return start_valid and stop_valid


+def _create_value(inst: Any, size: int) -> Union['Batch', np.ndarray]:
+    if isinstance(inst, np.ndarray):
+        return np.full((size, *inst.shape),
+                       fill_value=None if inst.dtype == np.object else 0,
+                       dtype=inst.dtype)
+    elif isinstance(inst, torch.Tensor):
+        return torch.full((size, *inst.shape),
+                          fill_value=None if inst.dtype == np.object else 0,
+                          device=inst.device,
+                          dtype=inst.dtype)
+    elif isinstance(inst, (dict, Batch)):
+        zero_batch = Batch()
+        for key, val in inst.items():
+            zero_batch.__dict__[key] = _create_value(val, size)
+        return zero_batch
+    elif isinstance(inst, (np.generic, Number)):
+        return _create_value(np.asarray(inst), size)
+    else:  # fall back to np.object
+        return np.array([None for _ in range(size)])
+
+
 class Batch:
     """Tianshou provides :class:`~tianshou.data.Batch` as the internal data
     structure to pass any kind of data to other methods, for example, a
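The new _create_value helper pre-allocates a whole column of `size` entries whose per-element layout matches the example value `inst`: numeric arrays and tensors become zero-filled buffers of shape (size, *inst.shape), object arrays are filled with None, dicts and Batches are built recursively, and unrecognised types fall back to an object array of None. A plain-Numpy sketch of the array branch (the helper name below is illustrative, not part of the commit)::

    import numpy as np

    def _zero_column(inst: np.ndarray, size: int) -> np.ndarray:
        # object dtypes are padded with None, everything else with 0
        fill = None if inst.dtype == object else 0
        return np.full((size, *inst.shape), fill_value=fill, dtype=inst.dtype)

    print(_zero_column(np.array([1.0, 2.0]), 3).shape)     # (3, 2), all zeros
    print(_zero_column(np.array(['x'], dtype=object), 2))  # [[None] [None]]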
@@ -75,33 +97,133 @@ class Batch:
         function return 4 arguments, and the last one is ``info``);
     * ``policy`` the data computed by policy in step :math:`t`;

-    :class:`~tianshou.data.Batch` has other methods, including
-    :meth:`~tianshou.data.Batch.__getitem__`,
-    :meth:`~tianshou.data.Batch.__len__`,
-    :meth:`~tianshou.data.Batch.append`,
-    and :meth:`~tianshou.data.Batch.split`:
+    A :class:`Batch` object can be initialized with a wide variety of
+    arguments, starting with key/value pairs or a dictionary, but also with
+    lists and Numpy arrays of :class:`dict` or Batch instances. In that case,
+    each element is treated as an individual sample and the samples are
+    stacked together:
     ::

-        >>> data = Batch(obs=np.array([0, 11, 22]), rew=np.array([6, 6, 6]))
-        >>> # here we test __getitem__
-        >>> index = [2, 1]
-        >>> data[index].obs
-        array([22, 11])
+        >>> import numpy as np
+        >>> from tianshou.data import Batch
+        >>> data = Batch([{'a': {'b': [0.0, "info"]}}])
+        >>> print(data[0])
+        Batch(
+            a: Batch(
+                b: array(['0.0', 'info'], dtype='<U32'),
+            ),
+        )

-        >>> # here we test __len__
+    :class:`Batch` has the same API as a native Python :class:`dict`. In this
+    regard, one can access stored data using a string key, or iterate over
+    the stored data:
+    ::
+
+        >>> from tianshou.data import Batch
+        >>> data = Batch(a=4, b=[5, 5])
+        >>> print(data["a"])
+        4
+        >>> for key, value in data.items():
+        >>>     print(f"{key}: {value}")
+        a: 4
+        b: [5, 5]
+
+
+    :class:`Batch` also partially reproduces the Numpy API for arrays. One
+    can access or iterate over the individual samples, if any:
+    ::
+
+        >>> import numpy as np
+        >>> from tianshou.data import Batch
+        >>> data = Batch(a=np.array([[0.0, 2.0], [1.0, 3.0]]), b=[5, -5])
+        >>> print(data[0])
+        Batch(
+            a: np.array([0.0, 2.0])
+            b: 5
+        )
+        >>> for sample in data:
+        >>>     print(sample.a)
+        [0.0, 2.0]
+        [1.0, 3.0]
+
+    Similarly, one can perform simple algebra on it, and stack, split or
+    concatenate multiple instances:
+    ::
+
+        >>> import numpy as np
+        >>> from tianshou.data import Batch
+        >>> data_1 = Batch(a=np.array([0.0, 2.0]), b=5)
+        >>> data_2 = Batch(a=np.array([1.0, 3.0]), b=-5)
+        >>> data = Batch.stack((data_1, data_2))
+        >>> print(data)
+        Batch(
+            b: array([ 5, -5]),
+            a: array([[0., 2.],
+                      [1., 3.]]),
+        )
+        >>> print(np.mean(data))
+        Batch(
+            b: 0.0,
+            a: array([0.5, 2.5]),
+        )
+        >>> data_split = list(data.split(1, False))
+        >>> print(list(data.split(1, False)))
+        [Batch(
+            b: [5],
+            a: array([[0., 2.]]),
+        ),
+        Batch(
+            b: [-5],
+            a: array([[1., 3.]]),
+        )]
+        >>> data_cat = Batch.cat(data_split)
+        >>> print(data_cat)
+        Batch(
+            b: array([ 5, -5]),
+            a: array([[0., 2.],
+                      [1., 3.]]),
+        )
+
+    Note that stacking of inconsistent data is also supported. In that case,
+    missing entries are padded with ``None`` in lists and in
+    :class:`np.ndarray` of objects, and with 0 otherwise.
+    ::
+
+        >>> import numpy as np
+        >>> from tianshou.data import Batch
+        >>> data_1 = Batch(a=np.array([0.0, 2.0]))
+        >>> data_2 = Batch(a=np.array([1.0, 3.0]), b='done')
+        >>> data = Batch.stack((data_1, data_2))
+        >>> print(data)
+        Batch(
+            a: array([[0., 2.],
+                      [1., 3.]]),
+            b: array([None, 'done'], dtype=object),
+        )
+
+    The :meth:`~tianshou.data.Batch.size` and
+    :meth:`~tianshou.data.Batch.__len__` methods are also provided to get the
+    size and the length of a :class:`Batch` instance, respectively. They
+    mimic the Numpy API for Numpy arrays, which means that getting the length
+    of a scalar Batch raises an exception, while its size is 1. The size is 0
+    only if the Batch is empty. Note that size and length are identical if
+    multiple samples are stored:
+    ::
+
+        >>> import numpy as np
+        >>> from tianshou.data import Batch
+        >>> data = Batch(a=[5., 4.], b=np.zeros((2, 3, 4)))
+        >>> data.size
+        2
         >>> len(data)
-        3
+        2
+        >>> data[0].size
+        1
+        >>> len(data[0])
+        TypeError: Object of type 'Batch' has no len()

-        >>> data.append(data)  # similar to list.append
-        >>> data.obs
-        array([ 0, 11, 22,  0, 11, 22])
+    Convenience helpers are available to convert the stored data in-place
+    into Numpy arrays or Torch tensors.

-        >>> # split whole data into multiple small batch
-        >>> for d in data.split(size=2, shuffle=False):
-        ...     print(d.obs, d.rew)
-        [ 0 11] [6 6]
-        [22  0] [6 6]
-        [11 22] [6 6]
+    Finally, note that Batch instances are serializable and therefore Pickle
+    compatible. This is especially important for distributed sampling.
     """

     def __init__(self,
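The updated docstring ends by pointing out pickle compatibility without showing it; a minimal round-trip sketch under the API described above (assuming this commit is installed)::

    import pickle
    import numpy as np
    from tianshou.data import Batch

    data = Batch(a=np.array([0.0, 2.0]), b=5)
    restored = pickle.loads(pickle.dumps(data))
    print(np.all(restored.a == data.a) and restored.b == data.b)  # True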
@@ -110,18 +232,7 @@ class Batch:
                  List[Union[dict, 'Batch']], np.ndarray]] = None,
                  **kwargs) -> None:
         if _is_batch_set(batch_dict):
-            for k, v in zip(batch_dict[0].keys(),
-                            zip(*[e.values() for e in batch_dict])):
-                if isinstance(v[0], dict) or _is_batch_set(v[0]):
-                    self.__dict__[k] = Batch(v)
-                elif isinstance(v[0], (np.generic, np.ndarray)):
-                    self.__dict__[k] = np.stack(v, axis=0)
-                elif isinstance(v[0], torch.Tensor):
-                    self.__dict__[k] = torch.stack(v, dim=0)
-                elif isinstance(v[0], Batch):
-                    self.__dict__[k] = Batch.stack(v)
-                else:
-                    self.__dict__[k] = np.array(v)
+            self.stack_(batch_dict)
         elif isinstance(batch_dict, (dict, Batch)):
             for k, v in batch_dict.items():
                 if isinstance(v, dict) or _is_batch_set(v):
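With this change, constructing a Batch from a sequence of dicts or Batch instances simply delegates to stack_, so the two forms below are expected to produce the same result (a sketch, assuming this commit)::

    from tianshou.data import Batch

    samples = [{'a': 1.0, 'b': 2.0}, {'a': 3.0, 'b': 4.0}]
    via_init = Batch(samples)                             # goes through stack_
    via_stack = Batch.stack([Batch(s) for s in samples])  # explicit stacking
    print(via_init.a)   # [1. 3.]
    print(via_stack.a)  # [1. 3.]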
@@ -160,16 +271,21 @@ class Batch:
                 f"Index {index} out of bounds for Batch of len {len(self)}.")
         else:
             b = Batch()
+            is_index_scalar = isinstance(index, (int, np.integer)) or \
+                (isinstance(index, np.ndarray) and index.ndim == 0)
             for k, v in self.items():
                 if isinstance(v, Batch) and len(v.__dict__) == 0:
                     b.__dict__[k] = Batch()
-                else:
+                elif is_index_scalar or not isinstance(v, list):
                     b.__dict__[k] = v[index]
+                else:
+                    b.__dict__[k] = [v[i] for i in index]
             return b

     def __setitem__(self, index: Union[
             str, slice, int, np.integer, np.ndarray, List[int]],
             value: Any) -> None:
         """Assign value to self[index]."""
         if isinstance(index, str):
             self.__dict__[index] = value
             return
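The is_index_scalar guard added above keeps scalar indexing untouched while allowing keys that hold plain Python lists to be fancy-indexed with a list or array of indices, gathering the items one by one instead of failing. A sketch of the resulting behaviour (assuming this commit)::

    import numpy as np
    from tianshou.data import Batch

    data = Batch(a=np.array([10, 11, 12]), b=['x', 'y', 'z'])
    print(data[1].b)       # 'y'        -- scalar index: plain item lookup
    print(data[[2, 0]].b)  # ['z', 'x'] -- list value gathered element-wise
    print(data[[2, 0]].a)  # [12 10]    -- arrays still use Numpy fancy indexing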
@@ -193,10 +309,12 @@ class Batch:
         else:
             self.__dict__[key][index] = None

-    def __iadd__(self, val: Union['Batch', Number]):
-        if isinstance(val, Batch):
+    def __iadd__(self, other: Union['Batch', Number]):
+        """Algebraic addition with another :class:`~tianshou.data.Batch`
+        instance in-place."""
+        if isinstance(other, Batch):
             for (k, r), v in zip(self.__dict__.items(),
-                                 val.__dict__.values()):
+                                 other.__dict__.values()):
                 if r is None:
                     continue
                 elif isinstance(r, list):
@@ -204,22 +322,25 @@ class Batch:
                 else:
                     self.__dict__[k] += v
             return self
-        elif isinstance(val, Number):
+        elif isinstance(other, Number):
             for k, r in self.items():
                 if r is None:
                     continue
                 elif isinstance(r, list):
-                    self.__dict__[k] = [r_ + val for r_ in r]
+                    self.__dict__[k] = [r_ + other for r_ in r]
                 else:
-                    self.__dict__[k] += val
+                    self.__dict__[k] += other
             return self
         else:
             raise TypeError("Only addition of Batch or number is supported.")

-    def __add__(self, val: Union['Batch', Number]):
-        return copy.deepcopy(self).__iadd__(val)
+    def __add__(self, other: Union['Batch', Number]):
+        """Algebraic addition with another :class:`~tianshou.data.Batch`
+        instance out-of-place."""
+        return copy.deepcopy(self).__iadd__(other)

     def __imul__(self, val: Number):
+        """Algebraic multiplication with a scalar value in-place."""
         assert isinstance(val, Number), \
             "Only multiplication by a number is supported."
         for k in self.__dict__.keys():
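The renamed argument and the new docstrings spell out the semantics: adding two Batch instances adds the stored values key by key, while adding a number broadcasts it over every entry. A small sketch (assuming this commit)::

    import numpy as np
    from tianshou.data import Batch

    x = Batch(a=np.array([1.0, 2.0]), b=np.array([10.0]))
    y = Batch(a=np.array([0.5, 0.5]), b=np.array([-10.0]))
    print((x + y).a)  # [1.5 2.5] -- key-wise, out-of-place
    x += 1            # scalar broadcast, in-place
    print(x.a)        # [2. 3.]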
@@ -227,9 +348,11 @@ class Batch:
         return self

     def __mul__(self, val: Number):
+        """Algebraic multiplication with a scalar value out-of-place."""
         return copy.deepcopy(self).__imul__(val)

     def __itruediv__(self, val: Number):
+        """Algebraic division by a scalar value in-place."""
         assert isinstance(val, Number), \
             "Only division by a number is supported."
         for k in self.__dict__.keys():
@@ -237,6 +360,7 @@ class Batch:
         return self

     def __truediv__(self, val: Number):
+        """Algebraic division by a scalar value out-of-place."""
         return copy.deepcopy(self).__itruediv__(val)

     def __repr__(self) -> str:
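Scalar multiplication and division mirror the addition helpers, only restricted to numbers on the right-hand side; a quick sketch (assuming this commit)::

    import numpy as np
    from tianshou.data import Batch

    data = Batch(a=np.array([2.0, 4.0]))
    print((data * 2).a)  # [4. 8.]
    print((data / 2).a)  # [1. 2.]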
@@ -319,8 +443,8 @@ class Batch:
         return self.cat_(batch)

     def cat_(self, batch: 'Batch') -> None:
-        """Concatenate a :class:`~tianshou.data.Batch` object to current
-        batch.
+        """Concatenate a :class:`~tianshou.data.Batch` object into the
+        current batch.
         """
         assert isinstance(batch, Batch), \
             'Only Batch is allowed to be concatenated in-place!'
@@ -347,39 +471,50 @@ class Batch:
         """Concatenate a :class:`~tianshou.data.Batch` object into a
         single new batch.
         """
         assert isinstance(batches, (tuple, list)), \
             'Only list of Batch instances is allowed to be '\
             'concatenated out-of-place!'
         batch = cls()
         for batch_ in batches:
             batch.cat_(batch_)
         return batch

-    @classmethod
-    def stack(cls, batches: List['Batch'], axis: int = 0) -> 'Batch':
+    def stack_(self,
+               batches: List[Union[dict, 'Batch']],
+               axis: int = 0) -> None:
+        """Stack a list of :class:`~tianshou.data.Batch` objects into the
+        current batch.
+        """
+        if len(self.__dict__) > 0:
+            batches = [self] + list(batches)
+        keys_map = list(map(lambda e: set(e.keys()), batches))
+        keys_shared = set.intersection(*keys_map)
+        values_shared = [
+            [e[k] for e in batches] for k in keys_shared]
+        for k, v in zip(keys_shared, values_shared):
+            if isinstance(v[0], (dict, Batch)):
+                self.__dict__[k] = Batch.stack(v, axis)
+            elif isinstance(v[0], torch.Tensor):
+                self.__dict__[k] = torch.stack(v, axis)
+            else:
+                self.__dict__[k] = np.stack(v, axis)
+        keys_partial = reduce(set.symmetric_difference, keys_map)
+        for k in keys_partial:
+            for i, e in enumerate(batches):
+                val = e.get(k, None)
+                if val is not None:
+                    try:
+                        self.__dict__[k][i] = val
+                    except KeyError:
+                        self.__dict__[k] = \
+                            _create_value(val, len(batches))
+                        self.__dict__[k][i] = val
+
+    @staticmethod
+    def stack(batches: List['Batch'], axis: int = 0) -> 'Batch':
         """Stack a :class:`~tianshou.data.Batch` object into a
         single new batch.
         """
         assert isinstance(batches, (tuple, list)), \
             'Only list of Batch instances is allowed to be '\
             'stacked out-of-place!'
-        if axis == 0:
-            return cls(batches)
-        else:
-            batch = Batch()
-            for k, v in zip(batches[0].keys(),
-                            zip(*[e.values() for e in batches])):
-                if isinstance(v[0], (np.generic, np.ndarray, list)):
-                    batch.__dict__[k] = np.stack(v, axis)
-                elif isinstance(v[0], torch.Tensor):
-                    batch.__dict__[k] = torch.stack(v, axis)
-                elif isinstance(v[0], Batch):
-                    batch.__dict__[k] = Batch.stack(v, axis)
-                else:
-                    s = 'No support for method "stack" with type '\
-                        f'{type(v[0])} in class Batch and axis != 0.'
-                    raise TypeError(s)
-            return batch
+        batch = Batch()
+        batch.stack_(batches, axis)
+        return batch

     def __len__(self) -> int:
         """Return len(self)."""
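The heart of partial stacking is the key bookkeeping in stack_: keys present in every batch are stacked directly, while keys missing from some batches (collected via the symmetric difference) get a pre-allocated column from _create_value that is filled wherever a value exists. The set arithmetic itself is plain Python, and the user-visible effect matches the docstring example above (a sketch, assuming this commit)::

    from functools import reduce
    import numpy as np
    from tianshou.data import Batch

    keys_map = [{'a', 'b'}, {'a', 'c'}]                # keys of two batches
    print(set.intersection(*keys_map))                 # {'a'}
    print(reduce(set.symmetric_difference, keys_map))  # {'b', 'c'}

    data = Batch.stack((Batch(a=np.array([0.0, 2.0])),
                        Batch(a=np.array([1.0, 3.0]), b='done')))
    print(data.b)  # [None 'done'] -- the missing entry was padded with None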
@@ -409,7 +544,9 @@ class Batch:
                 elif hasattr(v, '__len__') and (not isinstance(
                         v, (np.ndarray, torch.Tensor)) or v.ndim > 0):
                     r.append(len(v))
-        return max(1, min(r) if len(r) > 0 else 0)
+                else:
+                    r.append(1)
+        return min(r) if len(r) > 0 else 0

     def split(self, size: Optional[int] = None,
               shuffle: bool = True) -> Iterator['Batch']:
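The size fix drops the old max(1, ...) clamp: scalar entries now contribute 1 to the candidate list and the property returns the plain minimum, so a Batch mixing a scalar with arrays reports size 1 rather than the array length. A sketch of the new behaviour (assuming this commit)::

    import numpy as np
    from tianshou.data import Batch

    print(Batch(a=5, b=np.zeros((2, 3))).size)  # 1 -- the scalar bounds the size
    print(Batch(a=[5.0, 4.0]).size)             # 2
    print(Batch().size)                         # 0 -- empty batch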
@@ -1,24 +1,7 @@
 import numpy as np
-from numbers import Number
 from typing import Any, Tuple, Union, Optional

-from .batch import Batch
-
-
-def _create_value(inst: Any, size: int) -> Union['Batch', np.ndarray]:
-    if isinstance(inst, np.ndarray):
-        return np.full(shape=(size, *inst.shape),
-                       fill_value=None if inst.dtype == np.inexact else 0,
-                       dtype=inst.dtype)
-    elif isinstance(inst, (dict, Batch)):
-        zero_batch = Batch()
-        for key, val in inst.items():
-            zero_batch.__dict__[key] = _create_value(val, size)
-        return zero_batch
-    elif isinstance(inst, (np.generic, Number)):
-        return _create_value(np.asarray(inst), size)
-    else:  # fall back to np.object
-        return np.array([None for _ in range(size)])
+from .batch import Batch, _create_value


 class ReplayBuffer:
@@ -125,6 +108,7 @@ class ReplayBuffer:
         return self._size

     def __repr__(self) -> str:
+        """Return str(self)."""
         return self.__class__.__name__ + self._meta.__repr__()[5:]

     def __getattr__(self, key: str) -> Union['Batch', Any]: