youkaichao 3a08e27ed4 Standardized behavior of Batch.cat and misc code refactor (#137)
* code refactor; remove unused kwargs; add reward_normalization for dqn

* bugfix for __setitem__ with torch.Tensor; add Batch.condense

* minor fix

* support cat with empty Batch

* remove the dependency of is_empty on len; specify the semantic of empty Batch by test cases

* support stack with empty Batch

* remove condense

* refactor code to reflect the shared / partial / reserved categories of keys

* add is_empty(recursive=False)

* doc fix

* docfix and bugfix for _is_batch_set

* add doc for key reservation

* bugfix for algebra operators

* fix cat with lens hint

* code refactor

* bugfix for storing None

* use ValueError instead of exception

* hide lens away from users

* add comment for __cat

* move the computation of the initial value of lens in cat_ itself.

* change the place of doc string

* doc fix for Batch doc string

* change recursive to recurse

* doc string fix

* minor fix for batch doc
2020-07-20 15:54:18 +08:00

886 lines
34 KiB
Python

import torch
import pprint
import warnings
import numpy as np
from copy import deepcopy
from numbers import Number
from typing import Any, List, Tuple, Union, Iterator, Optional
# Disable pickle warning related to torch, since it has been removed
# on torch master branch. See Pull Request #39003 for details:
# https://github.com/pytorch/pytorch/pull/39003
warnings.filterwarnings(
"ignore", message="pickle support for Storage will be removed in 1.5.")
def _is_batch_set(data: Any) -> bool:
# Batch set is a list/tuple of dict/Batch objects,
# or 1-D np.ndarray with np.object type,
# where each element is a dict/Batch object
if isinstance(data, (list, tuple)):
if len(data) > 0 and all(isinstance(e, (dict, Batch)) for e in data):
return True
elif isinstance(data, np.ndarray) and data.dtype == np.object:
# ``for e in data`` will just unpack the first dimension,
# but data.tolist() will flatten ndarray of objects
# so do not use data.tolist()
if all(isinstance(e, (dict, Batch)) for e in data):
return True
return False
def _create_value(inst: Any, size: int, stack=True) -> Union[
'Batch', np.ndarray, torch.Tensor]:
"""
:param bool stack: whether to stack or to concatenate. E.g. if inst has
shape of (3, 5), size = 10, stack=True returns an np.ndarry with shape
of (10, 3, 5), otherwise (10, 5)
"""
has_shape = isinstance(inst, (np.ndarray, torch.Tensor))
is_scalar = \
isinstance(inst, Number) or \
issubclass(inst.__class__, np.generic) or \
(has_shape and not inst.shape)
if not stack and is_scalar:
# here we do not consider scalar types, following the
# behavior of numpy which does not support concatenation
# of zero-dimensional arrays (scalars)
raise TypeError(f"cannot concatenate with {inst} which is scalar")
if has_shape:
shape = (size, *inst.shape) if stack else (size, *inst.shape[1:])
if isinstance(inst, np.ndarray):
if issubclass(inst.dtype.type, (np.bool_, np.number)):
target_type = inst.dtype.type
else:
target_type = np.object
return np.full(shape,
fill_value=None if target_type == np.object else 0,
dtype=target_type)
elif isinstance(inst, torch.Tensor):
return torch.full(shape,
fill_value=0,
device=inst.device,
dtype=inst.dtype)
elif isinstance(inst, (dict, Batch)):
zero_batch = Batch()
for key, val in inst.items():
zero_batch.__dict__[key] = _create_value(val, size, stack=stack)
return zero_batch
elif is_scalar:
return _create_value(np.asarray(inst), size, stack=stack)
else: # fall back to np.object
return np.array([None for _ in range(size)])
def _assert_type_keys(keys):
keys = list(keys)
assert all(isinstance(e, str) for e in keys), \
f"keys should all be string, but got {keys}"
class Batch:
"""Tianshou provides :class:`~tianshou.data.Batch` as the internal data
structure to pass any kind of data to other methods, for example, a
collector gives a :class:`~tianshou.data.Batch` to policy for learning.
Here is the usage:
::
>>> import numpy as np
>>> from tianshou.data import Batch
>>> data = Batch(a=4, b=[5, 5], c='2312312')
>>> # the list will automatically be converted to numpy array
>>> data.b
array([5, 5])
>>> data.b = np.array([3, 4, 5])
>>> print(data)
Batch(
a: 4,
b: array([3, 4, 5]),
c: '2312312',
)
In short, you can define a :class:`Batch` with any key-value pair.
For Numpy arrays, only data types with ``np.object``, bool, and number are
supported. For strings or other data types, however, they can be held in
``np.object`` arrays.
The current implementation of Tianshou typically use 7 reserved keys in
:class:`~tianshou.data.Batch`:
* ``obs`` the observation of step :math:`t` ;
* ``act`` the action of step :math:`t` ;
* ``rew`` the reward of step :math:`t` ;
* ``done`` the done flag of step :math:`t` ;
* ``obs_next`` the observation of step :math:`t+1` ;
* ``info`` the info of step :math:`t` (in ``gym.Env``, the ``env.step()``\
function returns 4 arguments, and the last one is ``info``);
* ``policy`` the data computed by policy in step :math:`t`;
For convenience, :class:`~tianshou.data.Batch` supports the mechanism of
key reservation: one can specify a key without any value, which serves as
a placeholder for the Batch object. For example, you know there will be a
key named ``obs``, but do not know the value until the simulator runs. Then
you can reserve the key ``obs``. This is done by setting the value to
``Batch()``.
For a Batch object, we call it "incomplete" if: (i) it is ``Batch()``; (ii)
it has reserved keys; (iii) any of its sub-Batch is incomplete. Otherwise,
the Batch object is finalized.
Key reservation mechanism is convenient, but also causes some problem in
aggregation operators like ``stack`` or ``cat`` of Batch objects. We say
that Batch objects are compatible for aggregation with three cases:
1. finalized Batch objects are compatible if and only if their exists a \
way to extend keys so that their structures are exactly the same.
2. incomplete Batch objects and other finalized objects are compatible if \
their exists a way to extend keys so that incomplete Batch objects can \
have the same structure as finalized objects.
3. incomplete Batch objects themselevs are compatible if their exists a \
way to extend keys so that their structure can be the same.
In a word, incomplete Batch objects have a set of possible structures
in the future, but finalized Batch object only have a finalized structure.
Batch objects are compatible if and only if they share at least one
commonly possible structure by extending keys.
:class:`~tianshou.data.Batch` object can be initialized by a wide variety
of arguments, ranging from the key/value pairs or dictionary, to list and
Numpy arrays of :class:`dict` or Batch instances where each element is
considered as an individual sample and get stacked together:
::
>>> data = Batch([{'a': {'b': [0.0, "info"]}}])
>>> print(data[0])
Batch(
a: Batch(
b: array([0.0, 'info'], dtype=object),
),
)
:class:`~tianshou.data.Batch` has the same API as a native Python
:class:`dict`. In this regard, one can access stored data using string key,
or iterate over stored data:
::
>>> data = Batch(a=4, b=[5, 5])
>>> print(data["a"])
4
>>> for key, value in data.items():
>>> print(f"{key}: {value}")
a: 4
b: [5, 5]
:class:`~tianshou.data.Batch` also partially reproduces the Numpy API for
arrays. It also supports the advanced slicing method, such as batch[:, i],
if the index is valid. You can access or iterate over the individual
samples, if any:
::
>>> data = Batch(a=np.array([[0.0, 2.0], [1.0, 3.0]]), b=[[5, -5]])
>>> print(data[0])
Batch(
a: array([0., 2.])
b: array([ 5, -5]),
)
>>> for sample in data:
>>> print(sample.a)
[0. 2.]
>>> print(data.shape)
[1, 2]
>>> data[:, 1] += 1
>>> print(data)
Batch(
a: array([[0., 3.],
[1., 4.]]),
b: array([[ 5, -4]]),
)
Similarly, one can also perform simple algebra on it, and stack, split or
concatenate multiple instances:
::
>>> data_1 = Batch(a=np.array([0.0, 2.0]), b=5)
>>> data_2 = Batch(a=np.array([1.0, 3.0]), b=-5)
>>> data = Batch.stack((data_1, data_2))
>>> print(data)
Batch(
b: array([ 5, -5]),
a: array([[0., 2.],
[1., 3.]]),
)
>>> print(np.mean(data))
Batch(
b: 0.0,
a: array([0.5, 2.5]),
)
>>> data_split = list(data.split(1, False))
>>> print(list(data.split(1, False)))
[Batch(
b: array([5]),
a: array([[0., 2.]]),
), Batch(
b: array([-5]),
a: array([[1., 3.]]),
)]
>>> data_cat = Batch.cat(data_split)
>>> print(data_cat)
Batch(
b: array([ 5, -5]),
a: array([[0., 2.],
[1., 3.]]),
)
Note that stacking of inconsistent data is also supported. In which case,
``None`` is added in list or :class:`np.ndarray` of objects, 0 otherwise.
::
>>> data_1 = Batch(a=np.array([0.0, 2.0]))
>>> data_2 = Batch(a=np.array([1.0, 3.0]), b='done')
>>> data = Batch.stack((data_1, data_2))
>>> print(data)
Batch(
a: array([[0., 2.],
[1., 3.]]),
b: array([None, 'done'], dtype=object),
)
Method ``empty_`` sets elements to 0 or ``None`` for ``np.object``.
::
>>> data.empty_()
>>> print(data)
Batch(
a: array([[0., 0.],
[0., 0.]]),
b: array([None, None], dtype=object),
)
>>> data = Batch(a=[False, True], b={'c': [2., 'st'], 'd': [1., 0.]})
>>> data[0] = Batch.empty(data[1])
>>> data
Batch(
a: array([False, True]),
b: Batch(
c: array([None, 'st']),
d: array([0., 0.]),
),
)
:meth:`~tianshou.data.Batch.shape` and :meth:`~tianshou.data.Batch.__len__`
methods are also provided to respectively get the shape and the length of
a :class:`Batch` instance. It mimics the Numpy API for Numpy arrays, which
means that getting the length of a scalar Batch raises an exception.
::
>>> data = Batch(a=[5., 4.], b=np.zeros((2, 3, 4)))
>>> data.shape
[2]
>>> len(data)
2
>>> data[0].shape
[]
>>> len(data[0])
TypeError: Object of type 'Batch' has no len()
Convenience helpers are available to convert in-place the stored data into
Numpy arrays or Torch tensors.
Finally, note that :class:`~tianshou.data.Batch` is serializable and
therefore Pickle compatible. This is especially important for distributed
sampling.
"""
def __init__(self,
batch_dict: Optional[Union[
dict, 'Batch', Tuple[Union[dict, 'Batch']],
List[Union[dict, 'Batch']], np.ndarray]] = None,
copy: bool = False,
**kwargs) -> None:
if copy:
batch_dict = deepcopy(batch_dict)
if batch_dict is not None:
if isinstance(batch_dict, (dict, Batch)):
_assert_type_keys(batch_dict.keys())
for k, v in batch_dict.items():
if isinstance(v, (list, tuple, np.ndarray)):
v_ = None
if not isinstance(v, np.ndarray) and \
all(isinstance(e, torch.Tensor) for e in v):
self.__dict__[k] = torch.stack(v)
continue
else:
v_ = np.asanyarray(v)
if v_.dtype != np.object:
v = v_ # normal data list, this is the main case
if not issubclass(v.dtype.type,
(np.bool_, np.number)):
v = v.astype(np.object)
else:
if _is_batch_set(v):
v = Batch(v) # list of dict / Batch
else:
# this is actually a data list with objects
v = v_
self.__dict__[k] = v
elif isinstance(v, dict):
self.__dict__[k] = Batch(v)
else:
self.__dict__[k] = v
elif _is_batch_set(batch_dict):
self.stack_(batch_dict)
if len(kwargs) > 0:
self.__init__(kwargs, copy=copy)
def __setattr__(self, key: str, value: Any):
"""self[key] = value"""
if isinstance(value, list):
if _is_batch_set(value):
value = Batch(value)
else:
value = np.array(value)
if not issubclass(value.dtype.type, (np.bool_, np.number)):
value = value.astype(np.object)
elif isinstance(value, dict) or isinstance(value, np.ndarray) \
and value.dtype == np.object and _is_batch_set(value):
value = Batch(value)
self.__dict__[key] = value
def __getstate__(self):
"""Pickling interface. Only the actual data are serialized for both
efficiency and simplicity.
"""
state = {}
for k, v in self.items():
if isinstance(v, Batch):
v = v.__getstate__()
state[k] = v
return state
def __setstate__(self, state):
"""Unpickling interface. At this point, self is an empty Batch instance
that has not been initialized, so it can safely be initialized by the
pickle state.
"""
self.__init__(**state)
def __getitem__(self, index: Union[
str, slice, int, np.integer, np.ndarray, List[int]]) -> 'Batch':
"""Return self[index]."""
if isinstance(index, str):
return self.__dict__[index]
batch_items = self.items()
if len(batch_items) > 0:
b = Batch()
for k, v in batch_items:
if isinstance(v, Batch) and v.is_empty():
b.__dict__[k] = Batch()
else:
b.__dict__[k] = v[index]
return b
else:
raise IndexError("Cannot access item from empty Batch object.")
def __setitem__(self, index: Union[
str, slice, int, np.integer, np.ndarray, List[int]],
value: Any) -> None:
"""Assign value to self[index]."""
if isinstance(value, (list, tuple)):
value = np.asanyarray(value)
if isinstance(value, np.ndarray):
if not issubclass(value.dtype.type, (np.bool_, np.number)):
value = value.astype(np.object)
if isinstance(index, str):
self.__dict__[index] = value
return
if not isinstance(value, (dict, Batch)):
if _is_batch_set(value):
value = Batch(value)
else:
raise TypeError("Batch does not supported value type "
f"{type(value)} for item assignment.")
if not set(value.keys()).issubset(self.__dict__.keys()):
raise KeyError(
"Creating keys is not supported by item assignment.")
for key, val in self.items():
try:
self.__dict__[key][index] = value[key]
except KeyError:
if isinstance(val, Batch):
self.__dict__[key][index] = Batch()
elif isinstance(val, torch.Tensor) or \
(isinstance(val, np.ndarray) and
issubclass(val.dtype.type, (np.bool_, np.number))):
self.__dict__[key][index] = 0
else:
self.__dict__[key][index] = None
def __iadd__(self, other: Union['Batch', Number, np.number]):
"""Algebraic addition with another :class:`~tianshou.data.Batch`
instance in-place."""
if isinstance(other, Batch):
for (k, r), v in zip(self.__dict__.items(),
other.__dict__.values()):
# TODO are keys consistent?
if isinstance(r, Batch) and r.is_empty():
continue
else:
self.__dict__[k] += v
return self
elif isinstance(other, (Number, np.number)):
for k, r in self.items():
if isinstance(r, Batch) and r.is_empty():
continue
else:
self.__dict__[k] += other
return self
else:
raise TypeError("Only addition of Batch or number is supported.")
def __add__(self, other: Union['Batch', Number, np.number]):
"""Algebraic addition with another :class:`~tianshou.data.Batch`
instance out-of-place."""
return deepcopy(self).__iadd__(other)
def __imul__(self, val: Union[Number, np.number]):
"""Algebraic multiplication with a scalar value in-place."""
assert isinstance(val, (Number, np.number)), \
"Only multiplication by a number is supported."
for k, r in self.__dict__.items():
if isinstance(r, Batch) and r.is_empty():
continue
self.__dict__[k] *= val
return self
def __mul__(self, val: Union[Number, np.number]):
"""Algebraic multiplication with a scalar value out-of-place."""
return deepcopy(self).__imul__(val)
def __itruediv__(self, val: Union[Number, np.number]):
"""Algebraic division with a scalar value in-place."""
assert isinstance(val, (Number, np.number)), \
"Only division by a number is supported."
for k, r in self.__dict__.items():
if isinstance(r, Batch) and r.is_empty():
continue
self.__dict__[k] /= val
return self
def __truediv__(self, val: Union[Number, np.number]):
"""Algebraic division with a scalar value out-of-place."""
return deepcopy(self).__itruediv__(val)
def __repr__(self) -> str:
"""Return str(self)."""
s = self.__class__.__name__ + '(\n'
flag = False
for k, v in self.items():
rpl = '\n' + ' ' * (6 + len(k))
obj = pprint.pformat(v).replace('\n', rpl)
s += f' {k}: {obj},\n'
flag = True
if flag:
s += ')'
else:
s = self.__class__.__name__ + '()'
return s
def keys(self) -> List[str]:
"""Return self.keys()."""
return self.__dict__.keys()
def values(self) -> List[Any]:
"""Return self.values()."""
return self.__dict__.values()
def items(self) -> List[Tuple[str, Any]]:
"""Return self.items()."""
return self.__dict__.items()
def get(self, k: str, d: Optional[Any] = None) -> Union['Batch', Any]:
"""Return self[k] if k in self else d. d defaults to None."""
return self.__dict__.get(k, d)
def to_numpy(self) -> None:
"""Change all torch.Tensor to numpy.ndarray. This is an in-place
operation.
"""
for k, v in self.items():
if isinstance(v, torch.Tensor):
self.__dict__[k] = v.detach().cpu().numpy()
elif isinstance(v, Batch):
v.to_numpy()
def to_torch(self, dtype: Optional[torch.dtype] = None,
device: Union[str, int, torch.device] = 'cpu') -> None:
"""Change all numpy.ndarray to torch.Tensor. This is an in-place
operation.
"""
if not isinstance(device, torch.device):
device = torch.device(device)
for k, v in self.items():
if isinstance(v, (np.number, np.bool_, Number, np.ndarray)):
if isinstance(v, (np.number, np.bool_, Number)):
v = np.asanyarray(v)
v = torch.from_numpy(v).to(device)
if dtype is not None:
v = v.type(dtype)
self.__dict__[k] = v
elif isinstance(v, torch.Tensor):
if dtype is not None and v.dtype != dtype or \
v.device.type != device.type or \
device.index is not None and \
device.index != v.device.index:
if dtype is not None:
v = v.type(dtype)
self.__dict__[k] = v.to(device)
elif isinstance(v, Batch):
v.to_torch(dtype, device)
def __cat(self,
batches: Union['Batch', List[Union[dict, 'Batch']]],
lens: List[int]) -> None:
"""::
>>> a = Batch(a=np.random.randn(3, 4))
>>> x = Batch(a=a, b=np.random.randn(4, 4))
>>> y = Batch(a=Batch(a=Batch()), b=np.random.randn(4, 4))
If we want to concatenate x and y, we want to pad y.a.a with zeros.
Without ``lens`` as a hint, when we concatenate x.a and y.a, we would
not be able to know how to pad y.a. So ``Batch.cat_`` should compute
the ``lens`` to give ``Batch.__cat`` a hint.
::
>>> ans = Batch.cat([x, y])
>>> # this is equivalent to the following line
>>> ans = Batch(); ans.__cat([x, y], lens=[3, 4])
>>> # this lens is equal to [len(a), len(b)]
"""
# partial keys will be padded by zeros
# with the shape of [len, rest_shape]
sum_lens = [0]
for x in lens:
sum_lens.append(sum_lens[-1] + x)
# collect non-empty keys
keys_map = [
set(k for k, v in batch.items()
if not (isinstance(v, Batch) and v.is_empty()))
for batch in batches]
keys_shared = set.intersection(*keys_map)
values_shared = [[e[k] for e in batches] for k in keys_shared]
_assert_type_keys(keys_shared)
for k, v in zip(keys_shared, values_shared):
if all(isinstance(e, (dict, Batch)) for e in v):
batch_holder = Batch()
batch_holder.__cat(v, lens=lens)
self.__dict__[k] = batch_holder
elif all(isinstance(e, torch.Tensor) for e in v):
self.__dict__[k] = torch.cat(v)
else:
# cat Batch(a=np.zeros((3, 4))) and Batch(a=Batch(b=Batch()))
# will fail here
v = np.concatenate(v)
if not issubclass(v.dtype.type, (np.bool_, np.number)):
v = v.astype(np.object)
self.__dict__[k] = v
keys_total = set.union(*[set(b.keys()) for b in batches])
keys_reserve_or_partial = set.difference(keys_total, keys_shared)
_assert_type_keys(keys_reserve_or_partial)
# keys that are reserved in all batches
keys_reserve = set.difference(keys_total, set.union(*keys_map))
# keys that occur only in some batches, but not all
keys_partial = keys_reserve_or_partial.difference(keys_reserve)
for k in keys_reserve:
# reserved keys
self.__dict__[k] = Batch()
for k in keys_partial:
for i, e in enumerate(batches):
if k not in e.__dict__:
continue
val = e.get(k)
if isinstance(val, Batch) and val.is_empty():
continue
try:
self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val
except KeyError:
self.__dict__[k] = \
_create_value(val, sum_lens[-1], stack=False)
self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val
def cat_(self,
batches: Union['Batch', List[Union[dict, 'Batch']]]) -> None:
"""Concatenate a list of (or one) :class:`~tianshou.data.Batch` objects
into current batch.
"""
if isinstance(batches, Batch):
batches = [batches]
if len(batches) == 0:
return
batches = [x if isinstance(x, Batch) else Batch(x) for x in batches]
# x.is_empty() means that x is Batch() and should be ignored
batches = [x for x in batches if not x.is_empty()]
try:
# x.is_empty(recurse=True) here means x is a nested empty batch
# like Batch(a=Batch), and we have to treat it as length zero and
# keep it.
lens = [0 if x.is_empty(recurse=True) else len(x)
for x in batches]
except TypeError as e:
e2 = ValueError(
f'Batch.cat_ meets an exception. Maybe because there is '
f'any scalar in {batches} but Batch.cat_ does not support'
f'the concatenation of scalar.')
raise Exception([e, e2])
if not self.is_empty():
batches = [self] + list(batches)
lens = [0 if self.is_empty(recurse=True) else len(self)] + lens
return self.__cat(batches, lens)
@staticmethod
def cat(batches: List[Union[dict, 'Batch']]) -> 'Batch':
"""Concatenate a list of :class:`~tianshou.data.Batch` object into a
single new batch. For keys that are not shared across all batches,
batches that do not have these keys will be padded by zeros with
appropriate shapes. E.g.
::
>>> a = Batch(a=np.zeros([3, 4]), common=Batch(c=np.zeros([3, 5])))
>>> b = Batch(b=np.zeros([4, 3]), common=Batch(c=np.zeros([4, 5])))
>>> c = Batch.cat([a, b])
>>> c.a.shape
(7, 4)
>>> c.b.shape
(7, 3)
>>> c.common.c.shape
(7, 5)
"""
batch = Batch()
batch.cat_(batches)
return batch
def stack_(self,
batches: List[Union[dict, 'Batch']],
axis: int = 0) -> None:
"""Stack a list of :class:`~tianshou.data.Batch` object into current
batch.
"""
if len(batches) == 0:
return
batches = [x if isinstance(x, Batch) else Batch(x) for x in batches]
if not self.is_empty():
batches = [self] + list(batches)
# collect non-empty keys
keys_map = [
set(k for k, v in batch.items()
if not (isinstance(v, Batch) and v.is_empty()))
for batch in batches]
keys_shared = set.intersection(*keys_map)
values_shared = [[e[k] for e in batches] for k in keys_shared]
_assert_type_keys(keys_shared)
for k, v in zip(keys_shared, values_shared):
if all(isinstance(e, (dict, Batch)) for e in v):
self.__dict__[k] = Batch.stack(v, axis)
elif all(isinstance(e, torch.Tensor) for e in v):
self.__dict__[k] = torch.stack(v, axis)
else:
v = np.stack(v, axis)
if not issubclass(v.dtype.type, (np.bool_, np.number)):
v = v.astype(np.object)
self.__dict__[k] = v
# all the keys
keys_total = set.union(*[set(b.keys()) for b in batches])
# keys that are reserved in all batches
keys_reserve = set.difference(keys_total, set.union(*keys_map))
# keys that are either partial or reserved
keys_reserve_or_partial = set.difference(keys_total, keys_shared)
# keys that occur only in some batches, but not all
keys_partial = keys_reserve_or_partial.difference(keys_reserve)
if keys_partial and axis != 0:
raise ValueError(
f"Stack of Batch with non-shared keys {keys_partial} "
f"is only supported with axis=0, but got axis={axis}!")
_assert_type_keys(keys_reserve_or_partial)
for k in keys_reserve:
# reserved keys
self.__dict__[k] = Batch()
for k in keys_partial:
for i, e in enumerate(batches):
if k not in e.__dict__:
continue
val = e.get(k)
if isinstance(val, Batch) and val.is_empty():
continue
try:
self.__dict__[k][i] = val
except KeyError:
self.__dict__[k] = \
_create_value(val, len(batches))
self.__dict__[k][i] = val
@staticmethod
def stack(batches: List[Union[dict, 'Batch']], axis: int = 0) -> 'Batch':
"""Stack a list of :class:`~tianshou.data.Batch` object into a single
new batch. For keys that are not shared across all batches,
batches that do not have these keys will be padded by zeros. E.g.
::
>>> a = Batch(a=np.zeros([4, 4]), common=Batch(c=np.zeros([4, 5])))
>>> b = Batch(b=np.zeros([4, 6]), common=Batch(c=np.zeros([4, 5])))
>>> c = Batch.stack([a, b])
>>> c.a.shape
(2, 4, 4)
>>> c.b.shape
(2, 4, 6)
>>> c.common.c.shape
(2, 4, 5)
.. note::
If there are keys that are not shared across all batches, ``stack``
with ``axis != 0`` is undefined, and will cause an exception.
"""
batch = Batch()
batch.stack_(batches, axis)
return batch
def empty_(self, index: Union[
str, slice, int, np.integer, np.ndarray, List[int]] = None
) -> 'Batch':
"""Return an empty a :class:`~tianshou.data.Batch` object with 0 or
``None`` filled. If ``index`` is specified, it will only reset the
specific indexed-data.
"""
for k, v in self.items():
if v is None:
continue
if isinstance(v, Batch):
self.__dict__[k].empty_(index=index)
elif isinstance(v, torch.Tensor):
self.__dict__[k][index] = 0
elif isinstance(v, np.ndarray):
if v.dtype == np.object:
self.__dict__[k][index] = None
else:
self.__dict__[k][index] = 0
else: # scalar value
warnings.warn('You are calling Batch.empty on a NumPy scalar, '
'which may cause undefined behaviors.')
if isinstance(v, (np.number, np.bool_, Number)):
self.__dict__[k] = v.__class__(0)
else:
self.__dict__[k] = None
return self
@staticmethod
def empty(batch: 'Batch', index: Union[
str, slice, int, np.integer, np.ndarray, List[int]] = None
) -> 'Batch':
"""Return an empty :class:`~tianshou.data.Batch` object with 0 or
``None`` filled, the shape is the same as the given
:class:`~tianshou.data.Batch`.
"""
return deepcopy(batch).empty_(index)
def update(self, batch: Optional[Union[dict, 'Batch']] = None,
**kwargs) -> None:
"""Update this batch from another dict/Batch."""
if batch is None:
self.update(kwargs)
return
if isinstance(batch, dict):
batch = Batch(batch)
for k, v in batch.items():
self.__dict__[k] = v
if kwargs:
self.update(kwargs)
def __len__(self) -> int:
"""Return len(self)."""
r = []
for v in self.__dict__.values():
if isinstance(v, Batch) and v.is_empty(recurse=True):
continue
elif hasattr(v, '__len__') and (not isinstance(
v, (np.ndarray, torch.Tensor)) or v.ndim > 0):
r.append(len(v))
else:
raise TypeError(f"Object {v} in {self} has no len()")
if len(r) == 0:
raise TypeError(f"Object {self} has no len()")
return min(r)
def is_empty(self, recurse: bool = False):
"""
Test if a Batch is empty. If ``recurse=True``, it further tests the
values of the object; else it only tests the existence of any key.
``b.is_empty(recurse=True)`` is mainly used to distinguish
``Batch(a=Batch(a=Batch()))`` and ``Batch(a=1)``. They both raise
exceptions when applied to ``len()``, but the former can be used in
``cat``, while the latter is a scalar and cannot be used in ``cat``.
Another usage is in ``__len__``, where we have to skip checking the
length of recursely empty Batch.
::
>>> Batch().is_empty()
True
>>> Batch(a=Batch(), b=Batch(c=Batch())).is_empty()
False
>>> Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True)
True
>>> Batch(d=1).is_empty()
False
>>> Batch(a=np.float64(1.0)).is_empty()
False
"""
if len(self.__dict__) == 0:
return True
if not recurse:
return False
return all(False if not isinstance(x, Batch)
else x.is_empty(recurse=True) for x in self.values())
@property
def shape(self) -> List[int]:
"""Return self.shape."""
if self.is_empty():
return []
else:
data_shape = []
for v in self.__dict__.values():
try:
data_shape.append(v.shape)
except AttributeError:
raise TypeError("No support for 'shape' method with "
f"type {type(v)} in class Batch.")
return list(map(min, zip(*data_shape))) if len(data_shape) > 1 \
else data_shape[0]
def split(self, size: Optional[int] = None,
shuffle: bool = True) -> Iterator['Batch']:
"""Split whole data into multiple small batches.
:param int size: if it is ``None``, it does not split the data batch;
otherwise it will divide the data batch with the given size.
Default to ``None``.
:param bool shuffle: randomly shuffle the entire data batch if it is
``True``, otherwise remain in the same. Default to ``True``.
"""
length = len(self)
if size is None:
size = length
if shuffle:
indices = np.random.permutation(length)
else:
indices = np.arange(length)
for idx in np.arange(0, length, size):
yield self[indices[idx:(idx + size)]]