add type check for each element rather than the first element (#112)
This PR does the following: - improvement: dramatic reduce of the call to _is_batch_set - bugfix: list(Batch()) fail; Batch(a=[torch.ones(3), torch.ones(3)]) fail; - misc: add type check for each element rather than the first element; add test case; _create_value with torch.Tensor does not have np.object type;
This commit is contained in:
parent
481015932c
commit
7f9a1f1328
@ -8,6 +8,9 @@ from tianshou.data import Batch, to_torch
|
|||||||
|
|
||||||
|
|
||||||
def test_batch():
|
def test_batch():
|
||||||
|
assert list(Batch()) == []
|
||||||
|
batch = Batch(a=[torch.ones(3), torch.ones(3)])
|
||||||
|
assert torch.allclose(batch.a, torch.ones(2, 3))
|
||||||
batch = Batch(obs=[0], np=np.zeros([3, 4]))
|
batch = Batch(obs=[0], np=np.zeros([3, 4]))
|
||||||
assert batch.obs == batch["obs"]
|
assert batch.obs == batch["obs"]
|
||||||
batch.obs = [1]
|
batch.obs = [1]
|
||||||
|
@ -16,10 +16,10 @@ warnings.filterwarnings(
|
|||||||
|
|
||||||
def _is_batch_set(data: Any) -> bool:
|
def _is_batch_set(data: Any) -> bool:
|
||||||
if isinstance(data, (list, tuple)):
|
if isinstance(data, (list, tuple)):
|
||||||
if len(data) > 0 and isinstance(data[0], (dict, Batch)):
|
if len(data) > 0 and all(isinstance(e, (dict, Batch)) for e in data):
|
||||||
return True
|
return True
|
||||||
elif isinstance(data, np.ndarray):
|
elif isinstance(data, np.ndarray) and data.dtype == np.object:
|
||||||
if isinstance(data.item(0), (dict, Batch)):
|
if all(isinstance(e, (dict, Batch)) for e in data.tolist()):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -43,7 +43,8 @@ def _valid_bounds(length: int, index: Union[
|
|||||||
return start_valid and stop_valid
|
return start_valid and stop_valid
|
||||||
|
|
||||||
|
|
||||||
def _create_value(inst: Any, size: int) -> Union['Batch', np.ndarray]:
|
def _create_value(inst: Any, size: int) -> Union[
|
||||||
|
'Batch', np.ndarray, torch.Tensor]:
|
||||||
if isinstance(inst, np.ndarray):
|
if isinstance(inst, np.ndarray):
|
||||||
if issubclass(inst.dtype.type, (np.bool_, np.number)):
|
if issubclass(inst.dtype.type, (np.bool_, np.number)):
|
||||||
target_type = inst.dtype.type
|
target_type = inst.dtype.type
|
||||||
@ -54,7 +55,7 @@ def _create_value(inst: Any, size: int) -> Union['Batch', np.ndarray]:
|
|||||||
dtype=target_type)
|
dtype=target_type)
|
||||||
elif isinstance(inst, torch.Tensor):
|
elif isinstance(inst, torch.Tensor):
|
||||||
return torch.full((size, *inst.shape),
|
return torch.full((size, *inst.shape),
|
||||||
fill_value=None if inst.dtype == np.object else 0,
|
fill_value=0,
|
||||||
device=inst.device,
|
device=inst.device,
|
||||||
dtype=inst.dtype)
|
dtype=inst.dtype)
|
||||||
elif isinstance(inst, (dict, Batch)):
|
elif isinstance(inst, (dict, Batch)):
|
||||||
@ -263,18 +264,36 @@ class Batch:
|
|||||||
**kwargs) -> None:
|
**kwargs) -> None:
|
||||||
if copy:
|
if copy:
|
||||||
batch_dict = deepcopy(batch_dict)
|
batch_dict = deepcopy(batch_dict)
|
||||||
if _is_batch_set(batch_dict):
|
if batch_dict is not None:
|
||||||
self.stack_(batch_dict)
|
if isinstance(batch_dict, (dict, Batch)):
|
||||||
elif isinstance(batch_dict, (dict, Batch)):
|
for k, v in batch_dict.items():
|
||||||
for k, v in batch_dict.items():
|
if isinstance(v, (list, tuple, np.ndarray)):
|
||||||
if isinstance(v, dict) or _is_batch_set(v):
|
v_ = None
|
||||||
self.__dict__[k] = Batch(v)
|
if not isinstance(v, np.ndarray) and \
|
||||||
else:
|
all(isinstance(e, torch.Tensor) for e in v):
|
||||||
if isinstance(v, list):
|
v_ = torch.stack(v)
|
||||||
v = np.array(v)
|
self.__dict__[k] = v_
|
||||||
if not issubclass(v.dtype.type, (np.bool_, np.number)):
|
continue
|
||||||
v = v.astype(np.object)
|
else:
|
||||||
self.__dict__[k] = v
|
v_ = np.asanyarray(v)
|
||||||
|
if v_.dtype != np.object:
|
||||||
|
v = v_ # normal data list, this is the main case
|
||||||
|
if not issubclass(v.dtype.type,
|
||||||
|
(np.bool_, np.number)):
|
||||||
|
v = v.astype(np.object)
|
||||||
|
else:
|
||||||
|
if _is_batch_set(v):
|
||||||
|
v = Batch(v) # list of dict / Batch
|
||||||
|
else:
|
||||||
|
# this is actually a data list with objects
|
||||||
|
v = v_
|
||||||
|
self.__dict__[k] = v
|
||||||
|
elif isinstance(v, dict):
|
||||||
|
self.__dict__[k] = Batch(v)
|
||||||
|
else:
|
||||||
|
self.__dict__[k] = v
|
||||||
|
elif _is_batch_set(batch_dict):
|
||||||
|
self.stack_(batch_dict)
|
||||||
if len(kwargs) > 0:
|
if len(kwargs) > 0:
|
||||||
self.__init__(kwargs, copy=copy)
|
self.__init__(kwargs, copy=copy)
|
||||||
|
|
||||||
@ -536,9 +555,9 @@ class Batch:
|
|||||||
values_shared = [
|
values_shared = [
|
||||||
[e[k] for e in batches] for k in keys_shared]
|
[e[k] for e in batches] for k in keys_shared]
|
||||||
for k, v in zip(keys_shared, values_shared):
|
for k, v in zip(keys_shared, values_shared):
|
||||||
if isinstance(v[0], (dict, Batch)):
|
if all(isinstance(e, (dict, Batch)) for e in v):
|
||||||
self.__dict__[k] = Batch.stack(v, axis)
|
self.__dict__[k] = Batch.stack(v, axis)
|
||||||
elif isinstance(v[0], torch.Tensor):
|
elif all(isinstance(e, torch.Tensor) for e in v):
|
||||||
self.__dict__[k] = torch.stack(v, axis)
|
self.__dict__[k] = torch.stack(v, axis)
|
||||||
else:
|
else:
|
||||||
v = np.stack(v, axis)
|
v = np.stack(v, axis)
|
||||||
|
@ -37,7 +37,7 @@ def to_torch(x: Union[torch.Tensor, dict, Batch, np.ndarray],
|
|||||||
elif isinstance(x, (np.number, np.bool_, Number)):
|
elif isinstance(x, (np.number, np.bool_, Number)):
|
||||||
x = to_torch(np.asanyarray(x), dtype, device)
|
x = to_torch(np.asanyarray(x), dtype, device)
|
||||||
elif isinstance(x, list) and len(x) > 0 and \
|
elif isinstance(x, list) and len(x) > 0 and \
|
||||||
isinstance(x[0], (np.number, np.bool_, Number)):
|
all(isinstance(e, (np.number, np.bool_, Number)) for e in x):
|
||||||
x = to_torch(np.asanyarray(x), dtype, device)
|
x = to_torch(np.asanyarray(x), dtype, device)
|
||||||
elif isinstance(x, np.ndarray) and \
|
elif isinstance(x, np.ndarray) and \
|
||||||
isinstance(x.item(0), (np.number, np.bool_, Number)):
|
isinstance(x.item(0), (np.number, np.bool_, Number)):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user