fix Batch init for types other than number and bool (#115)

* fix Batch init for types other than number and bool

* change doc to involve bool type

* use type check

* Batch type check complete
This commit is contained in:
youkaichao 2020-07-08 13:45:29 +08:00 committed by GitHub
parent dbbb859ec5
commit f5e007932f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -45,9 +45,13 @@ def _valid_bounds(length: int, index: Union[
def _create_value(inst: Any, size: int) -> Union['Batch', np.ndarray]: def _create_value(inst: Any, size: int) -> Union['Batch', np.ndarray]:
if isinstance(inst, np.ndarray): if isinstance(inst, np.ndarray):
if issubclass(inst.dtype.type, (np.bool_, np.number)):
target_type = inst.dtype.type
else:
target_type = np.object
return np.full((size, *inst.shape), return np.full((size, *inst.shape),
fill_value=None if inst.dtype == np.object else 0, fill_value=None if target_type == np.object else 0,
dtype=inst.dtype) dtype=target_type)
elif isinstance(inst, torch.Tensor): elif isinstance(inst, torch.Tensor):
return torch.full((size, *inst.shape), return torch.full((size, *inst.shape),
fill_value=None if inst.dtype == np.object else 0, fill_value=None if inst.dtype == np.object else 0,
@ -87,9 +91,9 @@ class Batch:
In short, you can define a :class:`Batch` with any key-value pair. In short, you can define a :class:`Batch` with any key-value pair.
For Numpy arrays, only data types with ``np.object`` and numbers are For Numpy arrays, only data types with ``np.object``, bool, and number
supported. For strings or other data types, however, they can be held are supported. For strings or other data types, however, they can be
in ``np.object`` arrays. held in ``np.object`` arrays.
The current implementation of Tianshou typically uses 7 reserved keys in The current implementation of Tianshou typically uses 7 reserved keys in
:class:`~tianshou.data.Batch`: :class:`~tianshou.data.Batch`:
@ -113,7 +117,7 @@ class Batch:
>>> print(data[0]) >>> print(data[0])
Batch( Batch(
a: Batch( a: Batch(
b: array(['0.0', 'info'], dtype='<U32'), b: array([0.0, 'info'], dtype=object),
), ),
) )
@ -222,7 +226,7 @@ class Batch:
Batch( Batch(
a: array([False, True]), a: array([False, True]),
b: Batch( b: Batch(
c: array([0., 3.]), c: array([None, 'st']),
d: array([0., 0.]), d: array([0., 0.]),
), ),
) )
@ -268,6 +272,8 @@ class Batch:
else: else:
if isinstance(v, list): if isinstance(v, list):
v = np.array(v) v = np.array(v)
if not issubclass(v.dtype.type, (np.bool_, np.number)):
v = v.astype(np.object)
self.__dict__[k] = v self.__dict__[k] = v
if len(kwargs) > 0: if len(kwargs) > 0:
self.__init__(kwargs, copy=copy) self.__init__(kwargs, copy=copy)
@ -279,6 +285,8 @@ class Batch:
value = Batch(value) value = Batch(value)
else: else:
value = np.array(value) value = np.array(value)
if not issubclass(value.dtype.type, (np.bool_, np.number)):
value = value.astype(np.object)
elif isinstance(value, dict): elif isinstance(value, dict):
value = Batch(value) value = Batch(value)
self.__dict__[key] = value self.__dict__[key] = value
@ -319,6 +327,9 @@ class Batch:
index: Union[str, slice, int, np.integer, np.ndarray, List[int]], index: Union[str, slice, int, np.integer, np.ndarray, List[int]],
value: Any) -> None: value: Any) -> None:
"""Assign value to self[index].""" """Assign value to self[index]."""
if isinstance(value, np.ndarray):
if not issubclass(value.dtype.type, (np.bool_, np.number)):
value = value.astype(np.object)
if isinstance(index, str): if isinstance(index, str):
self.__dict__[index] = value self.__dict__[index] = value
return return
@ -335,25 +346,24 @@ class Batch:
if isinstance(val, Batch): if isinstance(val, Batch):
self.__dict__[key][index] = Batch() self.__dict__[key][index] = Batch()
elif isinstance(val, np.ndarray) and \ elif isinstance(val, np.ndarray) and \
val.dtype == np.integer: issubclass(val.dtype.type, (np.bool_, np.number)):
# Fallback for np.array of integer,
# since neither None or nan is supported.
self.__dict__[key][index] = 0 self.__dict__[key][index] = 0
else: else:
self.__dict__[key][index] = None self.__dict__[key][index] = None
def __iadd__(self, other: Union['Batch', Number]): def __iadd__(self, other: Union['Batch', Number, np.number]):
"""Algebraic addition with another :class:`~tianshou.data.Batch` """Algebraic addition with another :class:`~tianshou.data.Batch`
instance in-place.""" instance in-place."""
if isinstance(other, Batch): if isinstance(other, Batch):
for (k, r), v in zip(self.__dict__.items(), for (k, r), v in zip(self.__dict__.items(),
other.__dict__.values()): other.__dict__.values()):
# TODO are keys consistent?
if r is None: if r is None:
continue continue
else: else:
self.__dict__[k] += v self.__dict__[k] += v
return self return self
elif isinstance(other, Number): elif isinstance(other, (Number, np.number)):
for k, r in self.items(): for k, r in self.items():
if r is None: if r is None:
continue continue
@ -363,33 +373,33 @@ class Batch:
else: else:
raise TypeError("Only addition of Batch or number is supported.") raise TypeError("Only addition of Batch or number is supported.")
def __add__(self, other: Union['Batch', Number]): def __add__(self, other: Union['Batch', Number, np.number]):
"""Algebraic addition with another :class:`~tianshou.data.Batch` """Algebraic addition with another :class:`~tianshou.data.Batch`
instance out-of-place.""" instance out-of-place."""
return deepcopy(self).__iadd__(other) return deepcopy(self).__iadd__(other)
def __imul__(self, val: Number): def __imul__(self, val: Union[Number, np.number]):
"""Algebraic multiplication with a scalar value in-place.""" """Algebraic multiplication with a scalar value in-place."""
assert isinstance(val, Number), \ assert isinstance(val, (Number, np.number)), \
"Only multiplication by a number is supported." "Only multiplication by a number is supported."
for k in self.__dict__.keys(): for k in self.__dict__.keys():
self.__dict__[k] *= val self.__dict__[k] *= val
return self return self
def __mul__(self, val: Number): def __mul__(self, val: Union[Number, np.number]):
"""Algebraic multiplication with a scalar value out-of-place.""" """Algebraic multiplication with a scalar value out-of-place."""
return deepcopy(self).__imul__(val) return deepcopy(self).__imul__(val)
def __itruediv__(self, val: Number): def __itruediv__(self, val: Union[Number, np.number]):
"""Algebraic division wibyth a scalar value in-place.""" """Algebraic division with a scalar value in-place."""
assert isinstance(val, Number), \ assert isinstance(val, (Number, np.number)), \
"Only division by a number is supported." "Only division by a number is supported."
for k in self.__dict__.keys(): for k in self.__dict__.keys():
self.__dict__[k] /= val self.__dict__[k] /= val
return self return self
def __truediv__(self, val: Number): def __truediv__(self, val: Union[Number, np.number]):
"""Algebraic division wibyth a scalar value out-of-place.""" """Algebraic division with a scalar value out-of-place."""
return deepcopy(self).__itruediv__(val) return deepcopy(self).__itruediv__(val)
def __repr__(self) -> str: def __repr__(self) -> str:
@ -523,7 +533,10 @@ class Batch:
elif isinstance(v[0], torch.Tensor): elif isinstance(v[0], torch.Tensor):
self.__dict__[k] = torch.stack(v, axis) self.__dict__[k] = torch.stack(v, axis)
else: else:
self.__dict__[k] = np.stack(v, axis) v = np.stack(v, axis)
if not issubclass(v.dtype.type, (np.bool_, np.number)):
v = v.astype(np.object)
self.__dict__[k] = v
keys_partial = reduce(set.symmetric_difference, keys_map) keys_partial = reduce(set.symmetric_difference, keys_map)
for k in keys_partial: for k in keys_partial:
for i, e in enumerate(batches): for i, e in enumerate(batches):