import torch
import copy
import pprint
import warnings
import numpy as np
from functools import reduce
from numbers import Number
from typing import Any, List, Tuple, Union, Iterator, Optional

# Disable pickle warning related to torch, since it has been removed
# on torch master branch. See Pull Request #39003 for details:
# https://github.com/pytorch/pytorch/pull/39003
warnings.filterwarnings(
    "ignore", message="pickle support for Storage will be removed in 1.5.")


def _is_batch_set(data: Any) -> bool:
    """Return True if ``data`` is a non-empty collection of dict/Batch.

    Only the first element is inspected. Lists, tuples and Numpy arrays are
    recognized; anything else (including empty containers) yields False.
    """
    if isinstance(data, (list, tuple)):
        return len(data) > 0 and isinstance(data[0], (dict, Batch))
    if isinstance(data, np.ndarray):
        # ``item(0)`` raises on an empty array, so check the size first.
        return data.size > 0 and isinstance(data.item(0), (dict, Batch))
    return False


def _valid_bounds(length: int, index: Union[
        slice, int, np.integer, np.ndarray, List[int]]) -> bool:
    """Check that ``index`` addresses positions inside ``[-length, length)``.

    :param int length: the length of the indexed container.
    :param index: a scalar integer, a list/array of integers or a slice.

    :return: True if the index is within bounds, False otherwise. Index
        types that are not handled here are reported as out of bounds.
    """
    if isinstance(index, (int, np.integer)):
        return -length <= index < length
    if isinstance(index, (list, np.ndarray)):
        if np.size(index) == 0:
            # Nothing is addressed, hence nothing can be out of bounds
            # (np.min / np.max would raise on an empty input).
            return True
        return _valid_bounds(length, np.min(index)) and \
            _valid_bounds(length, np.max(index))
    if isinstance(index, slice):
        start_valid = index.start is None or \
            _valid_bounds(length, index.start)
        # The stop bound is exclusive, so the last addressed position
        # is ``stop - 1``.
        stop_valid = index.stop is None or \
            _valid_bounds(length, index.stop - 1)
        return start_valid and stop_valid
    return False


def _create_value(inst: Any, size: int) -> Union['Batch', np.ndarray]:
    """Create a placeholder able to hold ``size`` items similar to ``inst``.

    Numpy arrays and Torch tensors produce a container of shape
    ``(size, *inst.shape)`` with the same dtype, filled with 0 (or None for
    Numpy object arrays). dict/Batch are handled recursively, numbers are
    treated as 0-dimensional arrays, and any other type falls back to a
    Numpy object array filled with None.
    """
    if isinstance(inst, np.ndarray):
        # ``np.object`` has been removed from Numpy, the builtin ``object``
        # is the supported spelling of the same dtype.
        return np.full((size, *inst.shape),
                       fill_value=None if inst.dtype == object else 0,
                       dtype=inst.dtype)
    elif isinstance(inst, torch.Tensor):
        # Torch has no object dtype, so the fill value is always 0.
        return torch.full((size, *inst.shape),
                          fill_value=0,
                          device=inst.device,
                          dtype=inst.dtype)
    elif isinstance(inst, (dict, Batch)):
        zero_batch = Batch()
        for key, val in inst.items():
            zero_batch.__dict__[key] = _create_value(val, size)
        return zero_batch
    elif isinstance(inst, (np.generic, Number)):
        return _create_value(np.asarray(inst), size)
    else:  # fall back to an array of objects
        return np.array([None for _ in range(size)])


class Batch:
    """Tianshou provides :class:`~tianshou.data.Batch` as the internal data
    structure to pass any kind of data to other methods, for example, a
    collector gives a :class:`~tianshou.data.Batch` to policy for learning.
    Here is the usage:
    ::

        >>> import numpy as np
        >>> from tianshou.data import Batch
        >>> data = Batch(a=4, b=[5, 5], c='2312312')
        >>> data.b
        [5, 5]
        >>> data.b = np.array([3, 4, 5])
        >>> print(data)
        Batch(
            a: 4,
            b: array([3, 4, 5]),
            c: '2312312',
        )

    In short, you can define a :class:`Batch` with any key-value pair. The
    current implementation of Tianshou typically uses 7 reserved keys in
    :class:`~tianshou.data.Batch`:

    * ``obs`` the observation of step :math:`t` ;
    * ``act`` the action of step :math:`t` ;
    * ``rew`` the reward of step :math:`t` ;
    * ``done`` the done flag of step :math:`t` ;
    * ``obs_next`` the observation of step :math:`t+1` ;
    * ``info`` the info of step :math:`t` (in ``gym.Env``, the
      ``env.step()`` function returns 4 arguments, and the last one is
      ``info``);
    * ``policy`` the data computed by policy in step :math:`t`;

    A :class:`~tianshou.data.Batch` object can be initialized using a wide
    variety of arguments, starting with the key/value pairs or dictionary,
    but also list and Numpy arrays of :class:`dict` or Batch instances. In
    which case, each element is considered as an individual sample and they
    get stacked together:
    ::

        >>> import numpy as np
        >>> from tianshou.data import Batch
        >>> data = Batch([{'a': {'b': [0.0, "info"]}}])
        >>> print(data[0])
        Batch(
            a: Batch(
                   b: array(['0.0', 'info'], dtype='<U32'),
               ),
        )

    :class:`~tianshou.data.Batch` also exposes a dict-like interface: the
    stored data can be accessed using a string key, and ``keys()``,
    ``values()``, ``items()`` and ``get()`` are available:
    ::

        >>> from tianshou.data import Batch
        >>> data = Batch(a=4, b=[5, 5])
        >>> print(data["a"])
        4
        >>> for key, value in data.items():
        >>>     print(f"{key}: {value}")
        a: 4
        b: [5, 5]

    :class:`~tianshou.data.Batch` also partially reproduces the Numpy API
    for arrays. You can access or iterate over the individual samples, if
    any:
    ::

        >>> import numpy as np
        >>> from tianshou.data import Batch
        >>> data = Batch(a=np.array([[0.0, 2.0], [1.0, 3.0]]), b=[5, -5])
        >>> print(data[0])
        Batch(
            a: np.array([0.0, 2.0])
            b: 5
        )
        >>> for sample in data:
        >>>     print(sample.a)
        [0.0, 2.0]
        [1.0, 3.0]

    Similarly, one can also perform simple algebra on it, and stack, split
    or concatenate multiple instances:
    ::

        >>> import numpy as np
        >>> from tianshou.data import Batch
        >>> data_1 = Batch(a=np.array([0.0, 2.0]), b=5)
        >>> data_2 = Batch(a=np.array([1.0, 3.0]), b=-5)
        >>> data = Batch.stack((data_1, data_2))
        >>> print(data)
        Batch(
            b: array([ 5, -5]),
            a: array([[0., 2.],
                      [1., 3.]]),
        )
        >>> print(np.mean(data))
        Batch(
            b: 0.0,
            a: array([0.5, 2.5]),
        )
        >>> data_split = list(data.split(1, False))
        >>> print(list(data.split(1, False)))
        [Batch(
            b: [5],
            a: array([[0., 2.]]),
        ), Batch(
            b: [-5],
            a: array([[1., 3.]]),
        )]
        >>> data_cat = Batch.cat(data_split)
        >>> print(data_cat)
        Batch(
            b: array([ 5, -5]),
            a: array([[0., 2.],
                      [1., 3.]]),
        )

    Note that stacking of inconsistent data is also supported. In which
    case, None is added in list or :class:`np.ndarray` of objects, 0
    otherwise.
    ::

        >>> import numpy as np
        >>> from tianshou.data import Batch
        >>> data_1 = Batch(a=np.array([0.0, 2.0]))
        >>> data_2 = Batch(a=np.array([1.0, 3.0]), b='done')
        >>> data = Batch.stack((data_1, data_2))
        >>> print(data)
        Batch(
            a: array([[0., 2.],
                      [1., 3.]]),
            b: array([None, 'done'], dtype=object),
        )

    :meth:`~tianshou.data.Batch.size` and
    :meth:`~tianshou.data.Batch.__len__` methods are also provided to
    respectively get the size and the length of a :class:`Batch` instance.
    It mimics the Numpy API for Numpy arrays, which means that getting the
    length of a scalar Batch raises an exception, while the size is 1. The
    size is only 0 if empty. Note that the size and the length are
    identical if multiple samples are stored:
    ::

        >>> import numpy as np
        >>> from tianshou.data import Batch
        >>> data = Batch(a=[5., 4.], b=np.zeros((2, 3, 4)))
        >>> data.size
        2
        >>> len(data)
        2
        >>> data[0].size
        1
        >>> len(data[0])
        TypeError: Object of type 'Batch' has no len()

    Convenience helpers are available to convert in-place the stored data
    into Numpy arrays or Torch tensors.

    Finally, note that :class:`~tianshou.data.Batch` instances are
    serializable and therefore Pickle compatible. This is especially
    important for distributed sampling.
    """

    def __init__(self,
                 batch_dict: Optional[Union[
                     dict, 'Batch', Tuple[Union[dict, 'Batch']],
                     List[Union[dict, 'Batch']], np.ndarray]] = None,
                 **kwargs) -> None:
        """Build a Batch from a dict/Batch, a collection of them, or kwargs.

        A collection of dict/Batch is stacked sample-wise. Nested dicts and
        nested collections of dict/Batch are converted to Batch. Keyword
        arguments are processed last, like an extra dictionary.
        """
        if _is_batch_set(batch_dict):
            self.stack_(batch_dict)
        elif isinstance(batch_dict, (dict, Batch)):
            for k, v in batch_dict.items():
                if isinstance(v, dict) or _is_batch_set(v):
                    self.__dict__[k] = Batch(v)
                else:
                    self.__dict__[k] = v
        if len(kwargs) > 0:
            self.__init__(kwargs)

    def __getstate__(self):
        """Pickling interface. Only the actual data are serialized for both
        efficiency and simplicity. Nested Batch are exported as plain dict.
        """
        state = {}
        for k, v in self.items():
            if isinstance(v, Batch):
                v = v.__getstate__()
            state[k] = v
        return state

    def __setstate__(self, state):
        """Unpickling interface. At this point, self is an empty Batch
        instance that has not been initialized, so it can safely be
        initialized by the pickle state.
        """
        # Pass the state positionally rather than as keyword arguments, so
        # that a stored key named like a parameter of __init__ cannot clash.
        self.__init__(state)

    def __getitem__(self, index: Union[
            str, slice, int, np.integer, np.ndarray, List[int]]) -> 'Batch':
        """Return self[index].

        A string index returns the value stored under that key. Any other
        index selects samples and returns a new Batch.

        :raises IndexError: if a non-string index is out of bounds.
        """
        if isinstance(index, str):
            return self.__dict__[index]
        length = len(self)
        if not _valid_bounds(length, index):
            raise IndexError(
                f"Index {index} out of bounds for Batch of len {length}.")
        b = Batch()
        # Scalars and slices are natively supported by lists, while a
        # list/array of indices has to be applied element by element.
        is_index_native = isinstance(index, (int, np.integer, slice)) or \
            (isinstance(index, np.ndarray) and index.ndim == 0)
        for k, v in self.items():
            if isinstance(v, Batch) and len(v.__dict__) == 0:
                b.__dict__[k] = Batch()
            elif is_index_native or not isinstance(v, list):
                b.__dict__[k] = v[index]
            else:
                b.__dict__[k] = [v[i] for i in index]
        return b

    def __setitem__(self, index: Union[
            str, slice, int, np.integer, np.ndarray, List[int]],
            value: Any) -> None:
        """Assign value to self[index].

        A string index stores ``value`` under that key. Any other index
        requires a dict/Batch whose keys already exist in this Batch. Keys
        that are missing from ``value`` are reset at this index: to an
        empty Batch, to 0 for integer arrays, to None otherwise.

        :raises TypeError: if ``value`` is not a dict/Batch.
        :raises KeyError: if ``value`` holds a key not present in self.
        """
        if isinstance(index, str):
            self.__dict__[index] = value
            return
        if not isinstance(value, (dict, Batch)):
            raise TypeError("Batch does not supported value type "
                            f"{type(value)} for item assignment.")
        if not set(value.keys()).issubset(self.__dict__.keys()):
            raise KeyError(
                "Creating keys is not supported by item assignment.")
        for key, val in self.items():
            try:
                self.__dict__[key][index] = value[key]
            except KeyError:
                if isinstance(val, Batch):
                    self.__dict__[key][index] = Batch()
                elif isinstance(val, np.ndarray) and \
                        np.issubdtype(val.dtype, np.integer):
                    # Fallback for np.array of integer,
                    # since neither None or nan is supported.
                    self.__dict__[key][index] = 0
                else:
                    self.__dict__[key][index] = None

    def __iadd__(self, other: Union['Batch', Number]):
        """Algebraic addition with another :class:`~tianshou.data.Batch`
        instance, or with a number, in-place.

        Fields are matched by key. Fields of self that are None, or that
        have no counterpart in ``other``, are left untouched. Lists are
        added element-wise.

        :raises TypeError: if ``other`` is neither a Batch nor a number.
        """
        if isinstance(other, Batch):
            for k, r in self.items():
                # Match by key and not by position: two Batch holding the
                # same keys do not necessarily store them in the same order.
                if r is None or k not in other.__dict__:
                    continue
                v = other.__dict__[k]
                if isinstance(r, list):
                    self.__dict__[k] = [r_ + v_ for r_, v_ in zip(r, v)]
                else:
                    self.__dict__[k] += v
            return self
        elif isinstance(other, Number):
            for k, r in self.items():
                if r is None:
                    continue
                elif isinstance(r, list):
                    self.__dict__[k] = [r_ + other for r_ in r]
                else:
                    self.__dict__[k] += other
            return self
        else:
            raise TypeError("Only addition of Batch or number is supported.")

    def __add__(self, other: Union['Batch', Number]):
        """Algebraic addition with another :class:`~tianshou.data.Batch`
        instance out-of-place."""
        return copy.deepcopy(self).__iadd__(other)

    def __imul__(self, val: Number):
        """Algebraic multiplication with a scalar value in-place."""
        assert isinstance(val, Number), \
            "Only multiplication by a number is supported."
        for k in self.__dict__.keys():
            self.__dict__[k] *= val
        return self

    def __mul__(self, val: Number):
        """Algebraic multiplication with a scalar value out-of-place."""
        return copy.deepcopy(self).__imul__(val)

    def __itruediv__(self, val: Number):
        """Algebraic division by a scalar value in-place."""
        assert isinstance(val, Number), \
            "Only division by a number is supported."
        for k in self.__dict__.keys():
            self.__dict__[k] /= val
        return self

    def __truediv__(self, val: Number):
        """Algebraic division by a scalar value out-of-place."""
        return copy.deepcopy(self).__itruediv__(val)

    def __repr__(self) -> str:
        """Return repr(self)."""
        s = self.__class__.__name__ + '(\n'
        flag = False
        for k, v in self.items():
            # Continuation lines are aligned with the value of the field:
            # 4 (indent) + len(k) + 2 (': ').
            rpl = '\n' + ' ' * (6 + len(k))
            obj = pprint.pformat(v).replace('\n', rpl)
            s += f'    {k}: {obj},\n'
            flag = True
        if flag:
            s += ')'
        else:
            s = self.__class__.__name__ + '()'
        return s

    def keys(self) -> List[str]:
        """Return self.keys(), a view of the stored keys."""
        return self.__dict__.keys()

    def values(self) -> List[Any]:
        """Return self.values(), a view of the stored values."""
        return self.__dict__.values()

    def items(self) -> List[Tuple[str, Any]]:
        """Return self.items(), a view of the stored key/value pairs."""
        return self.__dict__.items()

    def get(self, k: str, d: Optional[Any] = None) -> Union['Batch', Any]:
        """Return self[k] if k in self else d. d defaults to None."""
        return self.__dict__.get(k, d)

    def to_numpy(self) -> None:
        """Change all torch.Tensor to numpy.ndarray. This is an in-place
        operation applied recursively to nested Batch.
        """
        for k, v in self.items():
            if isinstance(v, torch.Tensor):
                self.__dict__[k] = v.detach().cpu().numpy()
            elif isinstance(v, Batch):
                v.to_numpy()

    def to_torch(self,
                 dtype: Optional[torch.dtype] = None,
                 device: Union[str, int, torch.device] = 'cpu'
                 ) -> None:
        """Change all numpy.ndarray to torch.Tensor. This is an in-place
        operation applied recursively to nested Batch.

        :param dtype: if not None, every tensor is converted to this type.
        :param device: the device on which the tensors are stored.

        Tensors already stored are moved and/or converted only if their
        dtype or device does not match the request.
        """
        if not isinstance(device, torch.device):
            device = torch.device(device)

        for k, v in self.items():
            if isinstance(v, (np.generic, np.ndarray)):
                # torch.from_numpy only accepts np.ndarray, so Numpy
                # scalars are first wrapped into 0-dimensional arrays.
                v = torch.from_numpy(np.asanyarray(v)).to(device)
                if dtype is not None:
                    v = v.type(dtype)
                self.__dict__[k] = v
            elif isinstance(v, torch.Tensor):
                if dtype is not None and v.dtype != dtype:
                    must_update_tensor = True
                elif v.device.type != device.type:
                    must_update_tensor = True
                elif device.index is not None and \
                        device.index != v.device.index:
                    must_update_tensor = True
                else:
                    must_update_tensor = False
                if must_update_tensor:
                    if dtype is not None:
                        v = v.type(dtype)
                    self.__dict__[k] = v.to(device)
            elif isinstance(v, Batch):
                v.to_torch(dtype, device)

    def append(self, batch: 'Batch') -> None:
        """Deprecated alias of :meth:`~tianshou.data.Batch.cat_`."""
        warnings.warn('Method append will be removed soon, please use '
                      ':meth:`~tianshou.data.Batch.cat`')
        return self.cat_(batch)

    def cat_(self, batch: 'Batch') -> None:
        """Concatenate a :class:`~tianshou.data.Batch` object into current
        batch. This is an in-place operation.

        Values of ``batch`` that are None are skipped. Keys that are not
        yet stored in self (or stored as None) receive a deep copy.

        :raises TypeError: if a value to concatenate is of an unsupported
            type.
        """
        assert isinstance(batch, Batch), \
            'Only Batch is allowed to be concatenated in-place!'
        for k, v in batch.items():
            if v is None:
                continue
            # Look the key up in the stored data only: ``hasattr`` would
            # also match methods and properties of the class.
            if self.__dict__.get(k) is None:
                self.__dict__[k] = copy.deepcopy(v)
            elif isinstance(v, np.ndarray) and v.ndim > 0:
                self.__dict__[k] = np.concatenate([self.__dict__[k], v])
            elif isinstance(v, torch.Tensor):
                self.__dict__[k] = torch.cat([self.__dict__[k], v])
            elif isinstance(v, list):
                self.__dict__[k] += copy.deepcopy(v)
            elif isinstance(v, Batch):
                self.__dict__[k].cat_(v)
            else:
                s = 'No support for method "cat" with type '\
                    f'{type(v)} in class Batch.'
                raise TypeError(s)

    @classmethod
    def cat(cls, batches: List['Batch']) -> 'Batch':
        """Concatenate a list of :class:`~tianshou.data.Batch` objects into
        a single new batch.
        """
        batch = cls()
        for batch_ in batches:
            batch.cat_(batch_)
        return batch

    def stack_(self,
               batches: List[Union[dict, 'Batch']],
               axis: int = 0) -> None:
        """Stack a list of :class:`~tianshou.data.Batch` objects into
        current batch. This is an in-place operation.

        If self already holds data, it is used as the first element.

        Keys shared by every element are stacked along ``axis``. Keys held
        by only some of the elements are stored along the first axis; the
        missing entries are filled with a placeholder (0, None or an empty
        Batch, depending on the type of the available values).
        """
        batches = list(batches)
        if len(self.__dict__) > 0:
            batches = [self] + batches
        if len(batches) == 0:
            return
        keys_map = [set(e.keys()) for e in batches]
        keys_shared = set.intersection(*keys_map)
        # Every key that is not shared by all elements is partial. A
        # chained symmetric difference would drop the keys present in an
        # even number of elements.
        keys_partial = set.union(*keys_map) - keys_shared
        # Gather the results first and store them at the end, since self
        # may be one of the elements being read.
        stacked = {}
        for k in keys_shared:
            v = [e[k] for e in batches]
            if isinstance(v[0], (dict, Batch)):
                stacked[k] = Batch.stack(v, axis)
            elif isinstance(v[0], torch.Tensor):
                stacked[k] = torch.stack(v, axis)
            else:
                stacked[k] = np.stack(v, axis)
        for k in keys_partial:
            holder = None
            for i, e in enumerate(batches):
                val = e.get(k, None)
                if val is None:
                    continue
                if holder is None:
                    # The first available value defines the placeholder.
                    holder = _create_value(val, len(batches))
                holder[i] = val
            if holder is not None:
                stacked[k] = holder
        self.__dict__.update(stacked)

    @staticmethod
    def stack(batches: List['Batch'], axis: int = 0) -> 'Batch':
        """Stack a list of :class:`~tianshou.data.Batch` objects into a
        single new batch.
        """
        batch = Batch()
        batch.stack_(batches, axis)
        return batch

    def __len__(self) -> int:
        """Return len(self), the smallest length among the stored values.

        Empty nested Batch are ignored.

        :raises TypeError: if a value has no length (e.g. scalar), or if
            there is no value to measure.
        """
        r = []
        for v in self.__dict__.values():
            if isinstance(v, Batch) and len(v.__dict__) == 0:
                continue
            elif hasattr(v, '__len__') and (not isinstance(
                    v, (np.ndarray, torch.Tensor)) or v.ndim > 0):
                r.append(len(v))
            else:
                raise TypeError("Object of type 'Batch' has no len()")
        if len(r) == 0:
            raise TypeError("Object of type 'Batch' has no len()")
        return min(r)

    @property
    def size(self) -> int:
        """Return self.size, the smallest size among the stored values.

        Values without a length count for 1. The size of an empty Batch
        is 0.
        """
        if len(self.__dict__.keys()) == 0:
            return 0
        else:
            r = []
            for v in self.__dict__.values():
                if isinstance(v, Batch):
                    r.append(v.size)
                elif hasattr(v, '__len__') and (not isinstance(
                        v, (np.ndarray, torch.Tensor)) or v.ndim > 0):
                    r.append(len(v))
                else:
                    r.append(1)
            return min(r) if len(r) > 0 else 0

    def split(self, size: Optional[int] = None,
              shuffle: bool = True) -> Iterator['Batch']:
        """Split whole data into multiple small batch.

        :param int size: if it is ``None``, it does not split the data
            batch; otherwise it will divide the data batch with the given
            size. Default to ``None``.
        :param bool shuffle: randomly shuffle the entire data batch if it
            is ``True``, otherwise remain in the same. Default to ``True``.
        """
        length = len(self)
        if size is None:
            size = length
        if shuffle:
            indices = np.random.permutation(length)
        else:
            indices = np.arange(length)
        for idx in np.arange(0, length, size):
            yield self[indices[idx:(idx + size)]]