fix Batch init for types other than number and bool (#115)
* fix Batch init for types other than number and bool * change doc to involve bool type * use type check * Batch type check complete
This commit is contained in:
parent
dbbb859ec5
commit
f5e007932f
@ -45,9 +45,13 @@ def _valid_bounds(length: int, index: Union[
|
||||
|
||||
def _create_value(inst: Any, size: int) -> Union['Batch', np.ndarray]:
|
||||
if isinstance(inst, np.ndarray):
|
||||
if issubclass(inst.dtype.type, (np.bool_, np.number)):
|
||||
target_type = inst.dtype.type
|
||||
else:
|
||||
target_type = np.object
|
||||
return np.full((size, *inst.shape),
|
||||
fill_value=None if inst.dtype == np.object else 0,
|
||||
dtype=inst.dtype)
|
||||
fill_value=None if target_type == np.object else 0,
|
||||
dtype=target_type)
|
||||
elif isinstance(inst, torch.Tensor):
|
||||
return torch.full((size, *inst.shape),
|
||||
fill_value=None if inst.dtype == np.object else 0,
|
||||
@ -87,9 +91,9 @@ class Batch:
|
||||
|
||||
In short, you can define a :class:`Batch` with any key-value pair.
|
||||
|
||||
For Numpy arrays, only data types with ``np.object`` and numbers are
|
||||
supported. For strings or other data types, however, they can be held
|
||||
in ``np.object`` arrays.
|
||||
For Numpy arrays, only data types with ``np.object``, bool, and number
|
||||
are supported. For strings or other data types, however, they can be
|
||||
held in ``np.object`` arrays.
|
||||
|
||||
The current implementation of Tianshou typically use 7 reserved keys in
|
||||
:class:`~tianshou.data.Batch`:
|
||||
@ -113,7 +117,7 @@ class Batch:
|
||||
>>> print(data[0])
|
||||
Batch(
|
||||
a: Batch(
|
||||
b: array(['0.0', 'info'], dtype='<U32'),
|
||||
b: array([0.0, 'info'], dtype=object),
|
||||
),
|
||||
)
|
||||
|
||||
@ -222,7 +226,7 @@ class Batch:
|
||||
Batch(
|
||||
a: array([False, True]),
|
||||
b: Batch(
|
||||
c: array([0., 3.]),
|
||||
c: array([None, 'st']),
|
||||
d: array([0., 0.]),
|
||||
),
|
||||
)
|
||||
@ -268,6 +272,8 @@ class Batch:
|
||||
else:
|
||||
if isinstance(v, list):
|
||||
v = np.array(v)
|
||||
if not issubclass(v.dtype.type, (np.bool_, np.number)):
|
||||
v = v.astype(np.object)
|
||||
self.__dict__[k] = v
|
||||
if len(kwargs) > 0:
|
||||
self.__init__(kwargs, copy=copy)
|
||||
@ -279,6 +285,8 @@ class Batch:
|
||||
value = Batch(value)
|
||||
else:
|
||||
value = np.array(value)
|
||||
if not issubclass(value.dtype.type, (np.bool_, np.number)):
|
||||
value = value.astype(np.object)
|
||||
elif isinstance(value, dict):
|
||||
value = Batch(value)
|
||||
self.__dict__[key] = value
|
||||
@ -319,6 +327,9 @@ class Batch:
|
||||
index: Union[str, slice, int, np.integer, np.ndarray, List[int]],
|
||||
value: Any) -> None:
|
||||
"""Assign value to self[index]."""
|
||||
if isinstance(value, np.ndarray):
|
||||
if not issubclass(value.dtype.type, (np.bool_, np.number)):
|
||||
value = value.astype(np.object)
|
||||
if isinstance(index, str):
|
||||
self.__dict__[index] = value
|
||||
return
|
||||
@ -335,25 +346,24 @@ class Batch:
|
||||
if isinstance(val, Batch):
|
||||
self.__dict__[key][index] = Batch()
|
||||
elif isinstance(val, np.ndarray) and \
|
||||
val.dtype == np.integer:
|
||||
# Fallback for np.array of integer,
|
||||
# since neither None or nan is supported.
|
||||
issubclass(val.dtype.type, (np.bool_, np.number)):
|
||||
self.__dict__[key][index] = 0
|
||||
else:
|
||||
self.__dict__[key][index] = None
|
||||
|
||||
def __iadd__(self, other: Union['Batch', Number]):
|
||||
def __iadd__(self, other: Union['Batch', Number, np.number]):
|
||||
"""Algebraic addition with another :class:`~tianshou.data.Batch`
|
||||
instance in-place."""
|
||||
if isinstance(other, Batch):
|
||||
for (k, r), v in zip(self.__dict__.items(),
|
||||
other.__dict__.values()):
|
||||
# TODO are keys consistent?
|
||||
if r is None:
|
||||
continue
|
||||
else:
|
||||
self.__dict__[k] += v
|
||||
return self
|
||||
elif isinstance(other, Number):
|
||||
elif isinstance(other, (Number, np.number)):
|
||||
for k, r in self.items():
|
||||
if r is None:
|
||||
continue
|
||||
@ -363,33 +373,33 @@ class Batch:
|
||||
else:
|
||||
raise TypeError("Only addition of Batch or number is supported.")
|
||||
|
||||
def __add__(self, other: Union['Batch', Number]):
|
||||
def __add__(self, other: Union['Batch', Number, np.number]):
|
||||
"""Algebraic addition with another :class:`~tianshou.data.Batch`
|
||||
instance out-of-place."""
|
||||
return deepcopy(self).__iadd__(other)
|
||||
|
||||
def __imul__(self, val: Number):
|
||||
def __imul__(self, val: Union[Number, np.number]):
|
||||
"""Algebraic multiplication with a scalar value in-place."""
|
||||
assert isinstance(val, Number), \
|
||||
assert isinstance(val, (Number, np.number)), \
|
||||
"Only multiplication by a number is supported."
|
||||
for k in self.__dict__.keys():
|
||||
self.__dict__[k] *= val
|
||||
return self
|
||||
|
||||
def __mul__(self, val: Number):
|
||||
def __mul__(self, val: Union[Number, np.number]):
|
||||
"""Algebraic multiplication with a scalar value out-of-place."""
|
||||
return deepcopy(self).__imul__(val)
|
||||
|
||||
def __itruediv__(self, val: Number):
|
||||
"""Algebraic division wibyth a scalar value in-place."""
|
||||
assert isinstance(val, Number), \
|
||||
def __itruediv__(self, val: Union[Number, np.number]):
|
||||
"""Algebraic division with a scalar value in-place."""
|
||||
assert isinstance(val, (Number, np.number)), \
|
||||
"Only division by a number is supported."
|
||||
for k in self.__dict__.keys():
|
||||
self.__dict__[k] /= val
|
||||
return self
|
||||
|
||||
def __truediv__(self, val: Number):
|
||||
"""Algebraic division wibyth a scalar value out-of-place."""
|
||||
def __truediv__(self, val: Union[Number, np.number]):
|
||||
"""Algebraic division with a scalar value out-of-place."""
|
||||
return deepcopy(self).__itruediv__(val)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
@ -523,7 +533,10 @@ class Batch:
|
||||
elif isinstance(v[0], torch.Tensor):
|
||||
self.__dict__[k] = torch.stack(v, axis)
|
||||
else:
|
||||
self.__dict__[k] = np.stack(v, axis)
|
||||
v = np.stack(v, axis)
|
||||
if not issubclass(v.dtype.type, (np.bool_, np.number)):
|
||||
v = v.astype(np.object)
|
||||
self.__dict__[k] = v
|
||||
keys_partial = reduce(set.symmetric_difference, keys_map)
|
||||
for k in keys_partial:
|
||||
for i, e in enumerate(batches):
|
||||
|
Loading…
x
Reference in New Issue
Block a user