From f5e007932f8641c1ff5681a3a0928b1440be2c19 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 8 Jul 2020 13:45:29 +0800 Subject: [PATCH] fix Batch init for types other than number and bool (#115) * fix Batch init for types other than number and bool * change doc to involve bool type * use type check * Batch type check complete --- tianshou/data/batch.py | 57 ++++++++++++++++++++++++++---------------- 1 file changed, 35 insertions(+), 22 deletions(-) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 2fbe431..6fa0d60 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -45,9 +45,13 @@ def _valid_bounds(length: int, index: Union[ def _create_value(inst: Any, size: int) -> Union['Batch', np.ndarray]: if isinstance(inst, np.ndarray): + if issubclass(inst.dtype.type, (np.bool_, np.number)): + target_type = inst.dtype.type + else: + target_type = np.object return np.full((size, *inst.shape), - fill_value=None if inst.dtype == np.object else 0, - dtype=inst.dtype) + fill_value=None if target_type == np.object else 0, + dtype=target_type) elif isinstance(inst, torch.Tensor): return torch.full((size, *inst.shape), fill_value=None if inst.dtype == np.object else 0, @@ -87,9 +91,9 @@ class Batch: In short, you can define a :class:`Batch` with any key-value pair. - For Numpy arrays, only data types with ``np.object`` and numbers are - supported. For strings or other data types, however, they can be held - in ``np.object`` arrays. + For Numpy arrays, only data types with ``np.object``, bool, and number + are supported. For strings or other data types, however, they can be + held in ``np.object`` arrays. The current implementation of Tianshou typically use 7 reserved keys in :class:`~tianshou.data.Batch`: @@ -113,7 +117,7 @@ class Batch: >>> print(data[0]) Batch( a: Batch( - b: array(['0.0', 'info'], dtype=' 0: self.__init__(kwargs, copy=copy) @@ -279,6 +285,8 @@ class Batch: value = Batch(value) else: value = np.array(value) + if not issubclass(value.dtype.type, (np.bool_, np.number)): + value = value.astype(np.object) elif isinstance(value, dict): value = Batch(value) self.__dict__[key] = value @@ -319,6 +327,9 @@ class Batch: index: Union[str, slice, int, np.integer, np.ndarray, List[int]], value: Any) -> None: """Assign value to self[index].""" + if isinstance(value, np.ndarray): + if not issubclass(value.dtype.type, (np.bool_, np.number)): + value = value.astype(np.object) if isinstance(index, str): self.__dict__[index] = value return @@ -335,25 +346,24 @@ class Batch: if isinstance(val, Batch): self.__dict__[key][index] = Batch() elif isinstance(val, np.ndarray) and \ - val.dtype == np.integer: - # Fallback for np.array of integer, - # since neither None or nan is supported. + issubclass(val.dtype.type, (np.bool_, np.number)): self.__dict__[key][index] = 0 else: self.__dict__[key][index] = None - def __iadd__(self, other: Union['Batch', Number]): + def __iadd__(self, other: Union['Batch', Number, np.number]): """Algebraic addition with another :class:`~tianshou.data.Batch` instance in-place.""" if isinstance(other, Batch): for (k, r), v in zip(self.__dict__.items(), other.__dict__.values()): + # TODO are keys consistent? if r is None: continue else: self.__dict__[k] += v return self - elif isinstance(other, Number): + elif isinstance(other, (Number, np.number)): for k, r in self.items(): if r is None: continue @@ -363,33 +373,33 @@ class Batch: else: raise TypeError("Only addition of Batch or number is supported.") - def __add__(self, other: Union['Batch', Number]): + def __add__(self, other: Union['Batch', Number, np.number]): """Algebraic addition with another :class:`~tianshou.data.Batch` instance out-of-place.""" return deepcopy(self).__iadd__(other) - def __imul__(self, val: Number): + def __imul__(self, val: Union[Number, np.number]): """Algebraic multiplication with a scalar value in-place.""" - assert isinstance(val, Number), \ + assert isinstance(val, (Number, np.number)), \ "Only multiplication by a number is supported." for k in self.__dict__.keys(): self.__dict__[k] *= val return self - def __mul__(self, val: Number): + def __mul__(self, val: Union[Number, np.number]): """Algebraic multiplication with a scalar value out-of-place.""" return deepcopy(self).__imul__(val) - def __itruediv__(self, val: Number): - """Algebraic division wibyth a scalar value in-place.""" - assert isinstance(val, Number), \ + def __itruediv__(self, val: Union[Number, np.number]): + """Algebraic division with a scalar value in-place.""" + assert isinstance(val, (Number, np.number)), \ "Only division by a number is supported." for k in self.__dict__.keys(): self.__dict__[k] /= val return self - def __truediv__(self, val: Number): - """Algebraic division wibyth a scalar value out-of-place.""" + def __truediv__(self, val: Union[Number, np.number]): + """Algebraic division with a scalar value out-of-place.""" return deepcopy(self).__itruediv__(val) def __repr__(self) -> str: @@ -523,7 +533,10 @@ class Batch: elif isinstance(v[0], torch.Tensor): self.__dict__[k] = torch.stack(v, axis) else: - self.__dict__[k] = np.stack(v, axis) + v = np.stack(v, axis) + if not issubclass(v.dtype.type, (np.bool_, np.number)): + v = v.astype(np.object) + self.__dict__[k] = v keys_partial = reduce(set.symmetric_difference, keys_map) for k in keys_partial: for i, e in enumerate(batches):