fix Batch init for types other than number and bool (#115)

* fix Batch init for types other than number and bool

* change doc to involve bool type

* use type check

* Batch type check complete
This commit is contained in:
youkaichao 2020-07-08 13:45:29 +08:00 committed by GitHub
parent dbbb859ec5
commit f5e007932f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -45,9 +45,13 @@ def _valid_bounds(length: int, index: Union[
def _create_value(inst: Any, size: int) -> Union['Batch', np.ndarray]:
if isinstance(inst, np.ndarray):
if issubclass(inst.dtype.type, (np.bool_, np.number)):
target_type = inst.dtype.type
else:
target_type = np.object
return np.full((size, *inst.shape),
fill_value=None if inst.dtype == np.object else 0,
dtype=inst.dtype)
fill_value=None if target_type == np.object else 0,
dtype=target_type)
elif isinstance(inst, torch.Tensor):
return torch.full((size, *inst.shape),
fill_value=None if inst.dtype == np.object else 0,
@ -87,9 +91,9 @@ class Batch:
In short, you can define a :class:`Batch` with any key-value pair.
For Numpy arrays, only data types with ``np.object`` and numbers are
supported. For strings or other data types, however, they can be held
in ``np.object`` arrays.
For Numpy arrays, only data types with ``np.object``, bool, and number
are supported. For strings or other data types, however, they can be
held in ``np.object`` arrays.
The current implementation of Tianshou typically use 7 reserved keys in
:class:`~tianshou.data.Batch`:
@ -113,7 +117,7 @@ class Batch:
>>> print(data[0])
Batch(
a: Batch(
b: array(['0.0', 'info'], dtype='<U32'),
b: array([0.0, 'info'], dtype=object),
),
)
@ -222,7 +226,7 @@ class Batch:
Batch(
a: array([False, True]),
b: Batch(
c: array([0., 3.]),
c: array([None, 'st']),
d: array([0., 0.]),
),
)
@ -268,6 +272,8 @@ class Batch:
else:
if isinstance(v, list):
v = np.array(v)
if not issubclass(v.dtype.type, (np.bool_, np.number)):
v = v.astype(np.object)
self.__dict__[k] = v
if len(kwargs) > 0:
self.__init__(kwargs, copy=copy)
@ -279,6 +285,8 @@ class Batch:
value = Batch(value)
else:
value = np.array(value)
if not issubclass(value.dtype.type, (np.bool_, np.number)):
value = value.astype(np.object)
elif isinstance(value, dict):
value = Batch(value)
self.__dict__[key] = value
@ -319,6 +327,9 @@ class Batch:
index: Union[str, slice, int, np.integer, np.ndarray, List[int]],
value: Any) -> None:
"""Assign value to self[index]."""
if isinstance(value, np.ndarray):
if not issubclass(value.dtype.type, (np.bool_, np.number)):
value = value.astype(np.object)
if isinstance(index, str):
self.__dict__[index] = value
return
@ -335,25 +346,24 @@ class Batch:
if isinstance(val, Batch):
self.__dict__[key][index] = Batch()
elif isinstance(val, np.ndarray) and \
val.dtype == np.integer:
# Fallback for np.array of integer,
# since neither None or nan is supported.
issubclass(val.dtype.type, (np.bool_, np.number)):
self.__dict__[key][index] = 0
else:
self.__dict__[key][index] = None
def __iadd__(self, other: Union['Batch', Number]):
def __iadd__(self, other: Union['Batch', Number, np.number]):
"""Algebraic addition with another :class:`~tianshou.data.Batch`
instance in-place."""
if isinstance(other, Batch):
for (k, r), v in zip(self.__dict__.items(),
other.__dict__.values()):
# TODO are keys consistent?
if r is None:
continue
else:
self.__dict__[k] += v
return self
elif isinstance(other, Number):
elif isinstance(other, (Number, np.number)):
for k, r in self.items():
if r is None:
continue
@ -363,33 +373,33 @@ class Batch:
else:
raise TypeError("Only addition of Batch or number is supported.")
def __add__(self, other: Union['Batch', Number]):
def __add__(self, other: Union['Batch', Number, np.number]):
"""Algebraic addition with another :class:`~tianshou.data.Batch`
instance out-of-place."""
return deepcopy(self).__iadd__(other)
def __imul__(self, val: Number):
def __imul__(self, val: Union[Number, np.number]):
"""Algebraic multiplication with a scalar value in-place."""
assert isinstance(val, Number), \
assert isinstance(val, (Number, np.number)), \
"Only multiplication by a number is supported."
for k in self.__dict__.keys():
self.__dict__[k] *= val
return self
def __mul__(self, val: Number):
def __mul__(self, val: Union[Number, np.number]):
"""Algebraic multiplication with a scalar value out-of-place."""
return deepcopy(self).__imul__(val)
def __itruediv__(self, val: Number):
"""Algebraic division wibyth a scalar value in-place."""
assert isinstance(val, Number), \
def __itruediv__(self, val: Union[Number, np.number]):
"""Algebraic division with a scalar value in-place."""
assert isinstance(val, (Number, np.number)), \
"Only division by a number is supported."
for k in self.__dict__.keys():
self.__dict__[k] /= val
return self
def __truediv__(self, val: Number):
"""Algebraic division wibyth a scalar value out-of-place."""
def __truediv__(self, val: Union[Number, np.number]):
"""Algebraic division with a scalar value out-of-place."""
return deepcopy(self).__itruediv__(val)
def __repr__(self) -> str:
@ -523,7 +533,10 @@ class Batch:
elif isinstance(v[0], torch.Tensor):
self.__dict__[k] = torch.stack(v, axis)
else:
self.__dict__[k] = np.stack(v, axis)
v = np.stack(v, axis)
if not issubclass(v.dtype.type, (np.bool_, np.number)):
v = v.astype(np.object)
self.__dict__[k] = v
keys_partial = reduce(set.symmetric_difference, keys_map)
for k in keys_partial:
for i, e in enumerate(batches):