n+e db0e2e5cd2
Advanced Batch slicing & minor fix of RNN support (#106)
* add shape property and modify __getitem__

* change Batch.size to Batch.shape

* setattr

* Batch.empty

* remove scalar in advanced slicing

* modify empty_ and __getitem__

* missing testcase

* fix empty
2020-06-30 18:02:44 +08:00

619 lines
22 KiB
Python

import torch
import copy
import pprint
import warnings
import numpy as np
from functools import reduce
from numbers import Number
from typing import Any, List, Tuple, Union, Iterator, Optional
# Disable pickle warning related to torch, since it has been removed
# on torch master branch. See Pull Request #39003 for details:
# https://github.com/pytorch/pytorch/pull/39003
warnings.filterwarnings(
"ignore", message="pickle support for Storage will be removed in 1.5.")
def _is_batch_set(data: Any) -> bool:
if isinstance(data, (list, tuple)):
if len(data) > 0 and isinstance(data[0], (dict, Batch)):
return True
elif isinstance(data, np.ndarray):
if isinstance(data.item(0), (dict, Batch)):
return True
return False
def _valid_bounds(length: int, index: Union[
slice, int, np.integer, np.ndarray, List[int]]) -> bool:
if isinstance(index, (int, np.integer)):
return -length <= index and index < length
elif isinstance(index, (list, np.ndarray)):
return _valid_bounds(length, np.min(index)) and \
_valid_bounds(length, np.max(index))
elif isinstance(index, slice):
if index.start is not None:
start_valid = _valid_bounds(length, index.start)
else:
start_valid = True
if index.stop is not None:
stop_valid = _valid_bounds(length, index.stop - 1)
else:
stop_valid = True
return start_valid and stop_valid
def _create_value(inst: Any, size: int) -> Union['Batch', np.ndarray]:
if isinstance(inst, np.ndarray):
return np.full((size, *inst.shape),
fill_value=None if inst.dtype == np.object else 0,
dtype=inst.dtype)
elif isinstance(inst, torch.Tensor):
return torch.full((size, *inst.shape),
fill_value=None if inst.dtype == np.object else 0,
device=inst.device,
dtype=inst.dtype)
elif isinstance(inst, (dict, Batch)):
zero_batch = Batch()
for key, val in inst.items():
zero_batch.__dict__[key] = _create_value(val, size)
return zero_batch
elif isinstance(inst, (np.generic, Number)):
return _create_value(np.asarray(inst), size)
else: # fall back to np.object
return np.array([None for _ in range(size)])
class Batch:
"""Tianshou provides :class:`~tianshou.data.Batch` as the internal data
structure to pass any kind of data to other methods, for example, a
collector gives a :class:`~tianshou.data.Batch` to policy for learning.
Here is the usage:
::
>>> import numpy as np
>>> from tianshou.data import Batch
>>> data = Batch(a=4, b=[5, 5], c='2312312')
>>> # the list will automatically be converted to numpy array
>>> data.b
array([5, 5])
>>> data.b = np.array([3, 4, 5])
>>> print(data)
Batch(
a: 4,
b: array([3, 4, 5]),
c: '2312312',
)
In short, you can define a :class:`Batch` with any key-value pair. The
current implementation of Tianshou typically use 7 reserved keys in
:class:`~tianshou.data.Batch`:
* ``obs`` the observation of step :math:`t` ;
* ``act`` the action of step :math:`t` ;
* ``rew`` the reward of step :math:`t` ;
* ``done`` the done flag of step :math:`t` ;
* ``obs_next`` the observation of step :math:`t+1` ;
* ``info`` the info of step :math:`t` (in ``gym.Env``, the ``env.step()``\
function return 4 arguments, and the last one is ``info``);
* ``policy`` the data computed by policy in step :math:`t`;
:class:`~tianshou.data.Batch` object can be initialized using wide variety
of arguments, starting with the key/value pairs or dictionary, but also
list and Numpy arrays of :class:`dict` or Batch instances. In which case,
each element is considered as an individual sample and get stacked
together:
::
>>> data = Batch([{'a': {'b': [0.0, "info"]}}])
>>> print(data[0])
Batch(
a: Batch(
b: array(['0.0', 'info'], dtype='<U32'),
),
)
:class:`~tianshou.data.Batch` has the same API as a native Python
:class:`dict`. In this regard, one can access to stored data using string
key, or iterate over stored data:
::
>>> data = Batch(a=4, b=[5, 5])
>>> print(data["a"])
4
>>> for key, value in data.items():
>>> print(f"{key}: {value}")
a: 4
b: [5, 5]
:class:`~tianshou.data.Batch` is also reproduce partially the Numpy API for
arrays. It also supports the advanced slicing method, such as batch[:, i],
if the index is valid. You can access or iterate over the individual
samples, if any:
::
>>> data = Batch(a=np.array([[0.0, 2.0], [1.0, 3.0]]), b=[[5, -5]])
>>> print(data[0])
Batch(
a: array([0., 2.])
b: array([ 5, -5]),
)
>>> for sample in data:
>>> print(sample.a)
[0., 2.]
[1., 3.]
>>> print(data.shape)
[1, 2]
>>> data[:, 1] += 1
>>> print(data)
Batch(
a: array([[0., 3.],
[1., 4.]]),
b: array([[ 5, -4]]),
)
Similarly, one can also perform simple algebra on it, and stack, split or
concatenate multiple instances:
::
>>> data_1 = Batch(a=np.array([0.0, 2.0]), b=5)
>>> data_2 = Batch(a=np.array([1.0, 3.0]), b=-5)
>>> data = Batch.stack((data_1, data_2))
>>> print(data)
Batch(
b: array([ 5, -5]),
a: array([[0., 2.],
[1., 3.]]),
)
>>> print(np.mean(data))
Batch(
b: 0.0,
a: array([0.5, 2.5]),
)
>>> data_split = list(data.split(1, False))
>>> print(list(data.split(1, False)))
[Batch(
b: array([5]),
a: array([[0., 2.]]),
), Batch(
b: array([-5]),
a: array([[1., 3.]]),
)]
>>> data_cat = Batch.cat(data_split)
>>> print(data_cat)
Batch(
b: array([ 5, -5]),
a: array([[0., 2.],
[1., 3.]]),
)
Note that stacking of inconsistent data is also supported. In which case,
None is added in list or :class:`np.ndarray` of objects, 0 otherwise.
::
>>> data_1 = Batch(a=np.array([0.0, 2.0]))
>>> data_2 = Batch(a=np.array([1.0, 3.0]), b='done')
>>> data = Batch.stack((data_1, data_2))
>>> print(data)
Batch(
a: array([[0., 2.],
[1., 3.]]),
b: array([None, 'done'], dtype=object),
)
Also with method empty (which will set to 0 or ``None`` (with np.object))
::
>>> data.empty_()
>>> print(data)
Batch(
a: array([[0., 0.],
[0., 0.]]),
b: array([None, None], dtype=object),
)
>>> data = Batch(a=[False, True], b={'c': [2., 'st'], 'd': [1., 0.]})
>>> data[0] = Batch.empty(data[1])
>>> data
Batch(
a: array([False, True]),
b: Batch(
c: array([0., 3.]),
d: array([0., 0.]),
),
)
:meth:`~tianshou.data.Batch.shape` and :meth:`~tianshou.data.Batch.__len__`
methods are also provided to respectively get the shape and the length of
a :class:`Batch` instance. It mimics the Numpy API for Numpy arrays, which
means that getting the length of a scalar Batch raises an exception.
::
>>> data = Batch(a=[5., 4.], b=np.zeros((2, 3, 4)))
>>> data.shape
[2]
>>> len(data)
2
>>> data[0].shape
[]
>>> len(data[0])
TypeError: Object of type 'Batch' has no len()
Convenience helpers are available to convert in-place the stored data into
Numpy arrays or Torch tensors.
Finally, note that :class:`~tianshou.data.Batch` instance are serializable
and therefore Pickle compatible. This is especially important for
distributed sampling.
"""
def __init__(self,
batch_dict: Optional[Union[
dict, 'Batch', Tuple[Union[dict, 'Batch']],
List[Union[dict, 'Batch']], np.ndarray]] = None,
**kwargs) -> None:
if _is_batch_set(batch_dict):
self.stack_(batch_dict)
elif isinstance(batch_dict, (dict, Batch)):
for k, v in batch_dict.items():
if isinstance(v, dict) or _is_batch_set(v):
self.__dict__[k] = Batch(v)
else:
if isinstance(v, list):
v = np.array(v)
self.__dict__[k] = v
if len(kwargs) > 0:
self.__init__(kwargs)
def __setattr__(self, key: str, value: Any):
"""self[key] = value"""
if isinstance(value, list):
if _is_batch_set(value):
value = Batch(value)
else:
value = np.array(value)
elif isinstance(value, dict):
value = Batch(value)
self.__dict__[key] = value
def __getstate__(self):
"""Pickling interface. Only the actual data are serialized for both
efficiency and simplicity.
"""
state = {}
for k, v in self.items():
if isinstance(v, Batch):
v = v.__getstate__()
state[k] = v
return state
def __setstate__(self, state):
"""Unpickling interface. At this point, self is an empty Batch instance
that has not been initialized, so it can safely be initialized by the
pickle state.
"""
self.__init__(**state)
def __getitem__(self, index: Union[
str, slice, int, np.integer, np.ndarray, List[int]]) -> 'Batch':
"""Return self[index]."""
if isinstance(index, str):
return self.__dict__[index]
b = Batch()
for k, v in self.items():
if isinstance(v, Batch) and len(v.__dict__) == 0:
b.__dict__[k] = Batch()
else:
b.__dict__[k] = v[index]
return b
def __setitem__(
self,
index: Union[str, slice, int, np.integer, np.ndarray, List[int]],
value: Any) -> None:
"""Assign value to self[index]."""
if isinstance(index, str):
self.__dict__[index] = value
return
if not isinstance(value, (dict, Batch)):
raise TypeError("Batch does not supported value type "
f"{type(value)} for item assignment.")
if not set(value.keys()).issubset(self.__dict__.keys()):
raise KeyError(
"Creating keys is not supported by item assignment.")
for key, val in self.items():
try:
self.__dict__[key][index] = value[key]
except KeyError:
if isinstance(val, Batch):
self.__dict__[key][index] = Batch()
elif isinstance(val, np.ndarray) and \
val.dtype == np.integer:
# Fallback for np.array of integer,
# since neither None or nan is supported.
self.__dict__[key][index] = 0
else:
self.__dict__[key][index] = None
def __iadd__(self, other: Union['Batch', Number]):
"""Algebraic addition with another :class:`~tianshou.data.Batch`
instance in-place."""
if isinstance(other, Batch):
for (k, r), v in zip(self.__dict__.items(),
other.__dict__.values()):
if r is None:
continue
else:
self.__dict__[k] += v
return self
elif isinstance(other, Number):
for k, r in self.items():
if r is None:
continue
else:
self.__dict__[k] += other
return self
else:
raise TypeError("Only addition of Batch or number is supported.")
def __add__(self, other: Union['Batch', Number]):
"""Algebraic addition with another :class:`~tianshou.data.Batch`
instance out-of-place."""
return copy.deepcopy(self).__iadd__(other)
def __imul__(self, val: Number):
"""Algebraic multiplication with a scalar value in-place."""
assert isinstance(val, Number), \
"Only multiplication by a number is supported."
for k in self.__dict__.keys():
self.__dict__[k] *= val
return self
def __mul__(self, val: Number):
"""Algebraic multiplication with a scalar value out-of-place."""
return copy.deepcopy(self).__imul__(val)
def __itruediv__(self, val: Number):
"""Algebraic division wibyth a scalar value in-place."""
assert isinstance(val, Number), \
"Only division by a number is supported."
for k in self.__dict__.keys():
self.__dict__[k] /= val
return self
def __truediv__(self, val: Number):
"""Algebraic division wibyth a scalar value out-of-place."""
return copy.deepcopy(self).__itruediv__(val)
def __repr__(self) -> str:
"""Return str(self)."""
s = self.__class__.__name__ + '(\n'
flag = False
for k, v in self.items():
rpl = '\n' + ' ' * (6 + len(k))
obj = pprint.pformat(v).replace('\n', rpl)
s += f' {k}: {obj},\n'
flag = True
if flag:
s += ')'
else:
s = self.__class__.__name__ + '()'
return s
def keys(self) -> List[str]:
"""Return self.keys()."""
return self.__dict__.keys()
def values(self) -> List[Any]:
"""Return self.values()."""
return self.__dict__.values()
def items(self) -> List[Tuple[str, Any]]:
"""Return self.items()."""
return self.__dict__.items()
def get(self, k: str, d: Optional[Any] = None) -> Union['Batch', Any]:
"""Return self[k] if k in self else d. d defaults to None."""
return self.__dict__.get(k, d)
def to_numpy(self) -> None:
"""Change all torch.Tensor to numpy.ndarray. This is an in-place
operation.
"""
for k, v in self.items():
if isinstance(v, torch.Tensor):
self.__dict__[k] = v.detach().cpu().numpy()
elif isinstance(v, Batch):
v.to_numpy()
def to_torch(self,
dtype: Optional[torch.dtype] = None,
device: Union[str, int, torch.device] = 'cpu'
) -> None:
"""Change all numpy.ndarray to torch.Tensor. This is an in-place
operation.
"""
if not isinstance(device, torch.device):
device = torch.device(device)
for k, v in self.items():
if isinstance(v, (np.generic, np.ndarray)):
v = torch.from_numpy(v).to(device)
if dtype is not None:
v = v.type(dtype)
self.__dict__[k] = v
if isinstance(v, torch.Tensor):
if dtype is not None and v.dtype != dtype:
must_update_tensor = True
elif v.device.type != device.type:
must_update_tensor = True
elif device.index is not None and \
device.index != v.device.index:
must_update_tensor = True
else:
must_update_tensor = False
if must_update_tensor:
if dtype is not None:
v = v.type(dtype)
self.__dict__[k] = v.to(device)
elif isinstance(v, Batch):
v.to_torch(dtype, device)
def append(self, batch: 'Batch') -> None:
warnings.warn('Method :meth:`~tianshou.data.Batch.append` will be '
'removed soon, please use '
':meth:`~tianshou.data.Batch.cat`')
return self.cat_(batch)
def cat_(self, batch: 'Batch') -> None:
"""Concatenate a :class:`~tianshou.data.Batch` object into current
batch.
"""
assert isinstance(batch, Batch), \
'Only Batch is allowed to be concatenated in-place!'
for k, v in batch.items():
if v is None:
continue
if not hasattr(self, k) or self.__dict__[k] is None:
self.__dict__[k] = copy.deepcopy(v)
elif isinstance(v, np.ndarray) and v.ndim > 0:
self.__dict__[k] = np.concatenate([self.__dict__[k], v])
elif isinstance(v, torch.Tensor):
self.__dict__[k] = torch.cat([self.__dict__[k], v])
elif isinstance(v, Batch):
self.__dict__[k].cat_(v)
else:
s = 'No support for method "cat" with type '\
f'{type(v)} in class Batch.'
raise TypeError(s)
@staticmethod
def cat(batches: List['Batch']) -> 'Batch':
"""Concatenate a :class:`~tianshou.data.Batch` object into a single
new batch.
"""
batch = Batch()
for batch_ in batches:
batch.cat_(batch_)
return batch
def stack_(self,
batches: List[Union[dict, 'Batch']],
axis: int = 0) -> None:
"""Stack a :class:`~tianshou.data.Batch` object i into current batch.
"""
if len(self.__dict__) > 0:
batches = [self] + list(batches)
keys_map = list(map(lambda e: set(e.keys()), batches))
keys_shared = set.intersection(*keys_map)
values_shared = [
[e[k] for e in batches] for k in keys_shared]
for k, v in zip(keys_shared, values_shared):
if isinstance(v[0], (dict, Batch)):
self.__dict__[k] = Batch.stack(v, axis)
elif isinstance(v[0], torch.Tensor):
self.__dict__[k] = torch.stack(v, axis)
else:
self.__dict__[k] = np.stack(v, axis)
keys_partial = reduce(set.symmetric_difference, keys_map)
for k in keys_partial:
for i, e in enumerate(batches):
val = e.get(k, None)
if val is not None:
try:
self.__dict__[k][i] = val
except KeyError:
self.__dict__[k] = \
_create_value(val, len(batches))
self.__dict__[k][i] = val
@staticmethod
def stack(batches: List['Batch'], axis: int = 0) -> 'Batch':
"""Stack a :class:`~tianshou.data.Batch` object into a single new
batch.
"""
batch = Batch()
batch.stack_(batches, axis)
return batch
def empty_(self) -> 'Batch':
"""Return an empty a :class:`~tianshou.data.Batch` object with 0 or
``None`` filled.
"""
for k, v in self.items():
if v is None:
continue
if isinstance(v, Batch):
self.__dict__[k].empty_()
elif isinstance(v, np.ndarray) and v.dtype == np.object:
self.__dict__[k].fill(None)
elif isinstance(v, torch.Tensor): # cannot apply fill_ directly
self.__dict__[k] = torch.zeros_like(self.__dict__[k])
else: # np
self.__dict__[k] *= 0
if hasattr(v, 'dtype') and v.dtype.kind in 'fc':
self.__dict__[k] = np.nan_to_num(self.__dict__[k])
return self
@staticmethod
def empty(batch: 'Batch') -> 'Batch':
"""Return an empty :class:`~tianshou.data.Batch` object with 0 or
``None`` filled, the shape is the same as the given
:class:`~tianshou.data.Batch`.
"""
batch = Batch(**batch)
batch.empty_()
return batch
def __len__(self) -> int:
"""Return len(self)."""
r = []
for v in self.__dict__.values():
if isinstance(v, Batch) and len(v.__dict__) == 0:
continue
elif hasattr(v, '__len__') and (not isinstance(
v, (np.ndarray, torch.Tensor)) or v.ndim > 0):
r.append(len(v))
else:
raise TypeError("Object of type 'Batch' has no len()")
if len(r) == 0:
raise TypeError("Object of type 'Batch' has no len()")
return min(r)
@property
def shape(self) -> List[int]:
"""Return self.shape."""
if len(self.__dict__.keys()) == 0:
return []
else:
data_shape = []
for v in self.__dict__.values():
try:
data_shape.append(v.shape)
except AttributeError:
raise TypeError("No support for 'shape' method with "
f"type {type(v)} in class Batch.")
return list(map(min, zip(*data_shape))) if len(data_shape) > 1 \
else data_shape[0]
def split(self, size: Optional[int] = None,
shuffle: bool = True) -> Iterator['Batch']:
"""Split whole data into multiple small batch.
:param int size: if it is ``None``, it does not split the data batch;
otherwise it will divide the data batch with the given size.
Default to ``None``.
:param bool shuffle: randomly shuffle the entire data batch if it is
``True``, otherwise remain in the same. Default to ``True``.
"""
length = len(self)
if size is None:
size = length
if shuffle:
indices = np.random.permutation(length)
else:
indices = np.arange(length)
for idx in np.arange(0, length, size):
yield self[indices[idx:(idx + size)]]