fix Batch init for types other than number and bool (#115)
* fix Batch init for types other than number and bool * change doc to involve bool type * use type check * Batch type check complete
This commit is contained in:
parent
dbbb859ec5
commit
f5e007932f
@ -45,9 +45,13 @@ def _valid_bounds(length: int, index: Union[
|
|||||||
|
|
||||||
def _create_value(inst: Any, size: int) -> Union['Batch', np.ndarray]:
|
def _create_value(inst: Any, size: int) -> Union['Batch', np.ndarray]:
|
||||||
if isinstance(inst, np.ndarray):
|
if isinstance(inst, np.ndarray):
|
||||||
|
if issubclass(inst.dtype.type, (np.bool_, np.number)):
|
||||||
|
target_type = inst.dtype.type
|
||||||
|
else:
|
||||||
|
target_type = np.object
|
||||||
return np.full((size, *inst.shape),
|
return np.full((size, *inst.shape),
|
||||||
fill_value=None if inst.dtype == np.object else 0,
|
fill_value=None if target_type == np.object else 0,
|
||||||
dtype=inst.dtype)
|
dtype=target_type)
|
||||||
elif isinstance(inst, torch.Tensor):
|
elif isinstance(inst, torch.Tensor):
|
||||||
return torch.full((size, *inst.shape),
|
return torch.full((size, *inst.shape),
|
||||||
fill_value=None if inst.dtype == np.object else 0,
|
fill_value=None if inst.dtype == np.object else 0,
|
||||||
@ -87,9 +91,9 @@ class Batch:
|
|||||||
|
|
||||||
In short, you can define a :class:`Batch` with any key-value pair.
|
In short, you can define a :class:`Batch` with any key-value pair.
|
||||||
|
|
||||||
For Numpy arrays, only data types with ``np.object`` and numbers are
|
For Numpy arrays, only data types with ``np.object``, bool, and number
|
||||||
supported. For strings or other data types, however, they can be held
|
are supported. For strings or other data types, however, they can be
|
||||||
in ``np.object`` arrays.
|
held in ``np.object`` arrays.
|
||||||
|
|
||||||
The current implementation of Tianshou typically use 7 reserved keys in
|
The current implementation of Tianshou typically use 7 reserved keys in
|
||||||
:class:`~tianshou.data.Batch`:
|
:class:`~tianshou.data.Batch`:
|
||||||
@ -113,7 +117,7 @@ class Batch:
|
|||||||
>>> print(data[0])
|
>>> print(data[0])
|
||||||
Batch(
|
Batch(
|
||||||
a: Batch(
|
a: Batch(
|
||||||
b: array(['0.0', 'info'], dtype='<U32'),
|
b: array([0.0, 'info'], dtype=object),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -222,7 +226,7 @@ class Batch:
|
|||||||
Batch(
|
Batch(
|
||||||
a: array([False, True]),
|
a: array([False, True]),
|
||||||
b: Batch(
|
b: Batch(
|
||||||
c: array([0., 3.]),
|
c: array([None, 'st']),
|
||||||
d: array([0., 0.]),
|
d: array([0., 0.]),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@ -268,6 +272,8 @@ class Batch:
|
|||||||
else:
|
else:
|
||||||
if isinstance(v, list):
|
if isinstance(v, list):
|
||||||
v = np.array(v)
|
v = np.array(v)
|
||||||
|
if not issubclass(v.dtype.type, (np.bool_, np.number)):
|
||||||
|
v = v.astype(np.object)
|
||||||
self.__dict__[k] = v
|
self.__dict__[k] = v
|
||||||
if len(kwargs) > 0:
|
if len(kwargs) > 0:
|
||||||
self.__init__(kwargs, copy=copy)
|
self.__init__(kwargs, copy=copy)
|
||||||
@ -279,6 +285,8 @@ class Batch:
|
|||||||
value = Batch(value)
|
value = Batch(value)
|
||||||
else:
|
else:
|
||||||
value = np.array(value)
|
value = np.array(value)
|
||||||
|
if not issubclass(value.dtype.type, (np.bool_, np.number)):
|
||||||
|
value = value.astype(np.object)
|
||||||
elif isinstance(value, dict):
|
elif isinstance(value, dict):
|
||||||
value = Batch(value)
|
value = Batch(value)
|
||||||
self.__dict__[key] = value
|
self.__dict__[key] = value
|
||||||
@ -319,6 +327,9 @@ class Batch:
|
|||||||
index: Union[str, slice, int, np.integer, np.ndarray, List[int]],
|
index: Union[str, slice, int, np.integer, np.ndarray, List[int]],
|
||||||
value: Any) -> None:
|
value: Any) -> None:
|
||||||
"""Assign value to self[index]."""
|
"""Assign value to self[index]."""
|
||||||
|
if isinstance(value, np.ndarray):
|
||||||
|
if not issubclass(value.dtype.type, (np.bool_, np.number)):
|
||||||
|
value = value.astype(np.object)
|
||||||
if isinstance(index, str):
|
if isinstance(index, str):
|
||||||
self.__dict__[index] = value
|
self.__dict__[index] = value
|
||||||
return
|
return
|
||||||
@ -335,25 +346,24 @@ class Batch:
|
|||||||
if isinstance(val, Batch):
|
if isinstance(val, Batch):
|
||||||
self.__dict__[key][index] = Batch()
|
self.__dict__[key][index] = Batch()
|
||||||
elif isinstance(val, np.ndarray) and \
|
elif isinstance(val, np.ndarray) and \
|
||||||
val.dtype == np.integer:
|
issubclass(val.dtype.type, (np.bool_, np.number)):
|
||||||
# Fallback for np.array of integer,
|
|
||||||
# since neither None or nan is supported.
|
|
||||||
self.__dict__[key][index] = 0
|
self.__dict__[key][index] = 0
|
||||||
else:
|
else:
|
||||||
self.__dict__[key][index] = None
|
self.__dict__[key][index] = None
|
||||||
|
|
||||||
def __iadd__(self, other: Union['Batch', Number]):
|
def __iadd__(self, other: Union['Batch', Number, np.number]):
|
||||||
"""Algebraic addition with another :class:`~tianshou.data.Batch`
|
"""Algebraic addition with another :class:`~tianshou.data.Batch`
|
||||||
instance in-place."""
|
instance in-place."""
|
||||||
if isinstance(other, Batch):
|
if isinstance(other, Batch):
|
||||||
for (k, r), v in zip(self.__dict__.items(),
|
for (k, r), v in zip(self.__dict__.items(),
|
||||||
other.__dict__.values()):
|
other.__dict__.values()):
|
||||||
|
# TODO are keys consistent?
|
||||||
if r is None:
|
if r is None:
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
self.__dict__[k] += v
|
self.__dict__[k] += v
|
||||||
return self
|
return self
|
||||||
elif isinstance(other, Number):
|
elif isinstance(other, (Number, np.number)):
|
||||||
for k, r in self.items():
|
for k, r in self.items():
|
||||||
if r is None:
|
if r is None:
|
||||||
continue
|
continue
|
||||||
@ -363,33 +373,33 @@ class Batch:
|
|||||||
else:
|
else:
|
||||||
raise TypeError("Only addition of Batch or number is supported.")
|
raise TypeError("Only addition of Batch or number is supported.")
|
||||||
|
|
||||||
def __add__(self, other: Union['Batch', Number]):
|
def __add__(self, other: Union['Batch', Number, np.number]):
|
||||||
"""Algebraic addition with another :class:`~tianshou.data.Batch`
|
"""Algebraic addition with another :class:`~tianshou.data.Batch`
|
||||||
instance out-of-place."""
|
instance out-of-place."""
|
||||||
return deepcopy(self).__iadd__(other)
|
return deepcopy(self).__iadd__(other)
|
||||||
|
|
||||||
def __imul__(self, val: Number):
|
def __imul__(self, val: Union[Number, np.number]):
|
||||||
"""Algebraic multiplication with a scalar value in-place."""
|
"""Algebraic multiplication with a scalar value in-place."""
|
||||||
assert isinstance(val, Number), \
|
assert isinstance(val, (Number, np.number)), \
|
||||||
"Only multiplication by a number is supported."
|
"Only multiplication by a number is supported."
|
||||||
for k in self.__dict__.keys():
|
for k in self.__dict__.keys():
|
||||||
self.__dict__[k] *= val
|
self.__dict__[k] *= val
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __mul__(self, val: Number):
|
def __mul__(self, val: Union[Number, np.number]):
|
||||||
"""Algebraic multiplication with a scalar value out-of-place."""
|
"""Algebraic multiplication with a scalar value out-of-place."""
|
||||||
return deepcopy(self).__imul__(val)
|
return deepcopy(self).__imul__(val)
|
||||||
|
|
||||||
def __itruediv__(self, val: Number):
|
def __itruediv__(self, val: Union[Number, np.number]):
|
||||||
"""Algebraic division wibyth a scalar value in-place."""
|
"""Algebraic division with a scalar value in-place."""
|
||||||
assert isinstance(val, Number), \
|
assert isinstance(val, (Number, np.number)), \
|
||||||
"Only division by a number is supported."
|
"Only division by a number is supported."
|
||||||
for k in self.__dict__.keys():
|
for k in self.__dict__.keys():
|
||||||
self.__dict__[k] /= val
|
self.__dict__[k] /= val
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __truediv__(self, val: Number):
|
def __truediv__(self, val: Union[Number, np.number]):
|
||||||
"""Algebraic division wibyth a scalar value out-of-place."""
|
"""Algebraic division with a scalar value out-of-place."""
|
||||||
return deepcopy(self).__itruediv__(val)
|
return deepcopy(self).__itruediv__(val)
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
@ -523,7 +533,10 @@ class Batch:
|
|||||||
elif isinstance(v[0], torch.Tensor):
|
elif isinstance(v[0], torch.Tensor):
|
||||||
self.__dict__[k] = torch.stack(v, axis)
|
self.__dict__[k] = torch.stack(v, axis)
|
||||||
else:
|
else:
|
||||||
self.__dict__[k] = np.stack(v, axis)
|
v = np.stack(v, axis)
|
||||||
|
if not issubclass(v.dtype.type, (np.bool_, np.number)):
|
||||||
|
v = v.astype(np.object)
|
||||||
|
self.__dict__[k] = v
|
||||||
keys_partial = reduce(set.symmetric_difference, keys_map)
|
keys_partial = reduce(set.symmetric_difference, keys_map)
|
||||||
for k in keys_partial:
|
for k in keys_partial:
|
||||||
for i, e in enumerate(batches):
|
for i, e in enumerate(batches):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user