add type check for each element rather than the first element (#112)

This PR does the following:
- improvement: dramatically reduce the number of calls to _is_batch_set
- bugfix: list(Batch()) used to fail; Batch(a=[torch.ones(3), torch.ones(3)]) used to fail;
- misc: add a type check for every element rather than only the first element; add test cases; remove the np.object check from the torch.Tensor branch of _create_value, since a torch.Tensor never has the np.object dtype;
This commit is contained in:
youkaichao 2020-07-08 21:00:00 +08:00 committed by GitHub
parent 481015932c
commit 7f9a1f1328
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 42 additions and 20 deletions

View File

@ -8,6 +8,9 @@ from tianshou.data import Batch, to_torch
def test_batch():
assert list(Batch()) == []
batch = Batch(a=[torch.ones(3), torch.ones(3)])
assert torch.allclose(batch.a, torch.ones(2, 3))
batch = Batch(obs=[0], np=np.zeros([3, 4]))
assert batch.obs == batch["obs"]
batch.obs = [1]

View File

@ -16,10 +16,10 @@ warnings.filterwarnings(
def _is_batch_set(data: Any) -> bool:
if isinstance(data, (list, tuple)):
if len(data) > 0 and isinstance(data[0], (dict, Batch)):
if len(data) > 0 and all(isinstance(e, (dict, Batch)) for e in data):
return True
elif isinstance(data, np.ndarray):
if isinstance(data.item(0), (dict, Batch)):
elif isinstance(data, np.ndarray) and data.dtype == np.object:
if all(isinstance(e, (dict, Batch)) for e in data.tolist()):
return True
return False
@ -43,7 +43,8 @@ def _valid_bounds(length: int, index: Union[
return start_valid and stop_valid
def _create_value(inst: Any, size: int) -> Union['Batch', np.ndarray]:
def _create_value(inst: Any, size: int) -> Union[
'Batch', np.ndarray, torch.Tensor]:
if isinstance(inst, np.ndarray):
if issubclass(inst.dtype.type, (np.bool_, np.number)):
target_type = inst.dtype.type
@ -54,7 +55,7 @@ def _create_value(inst: Any, size: int) -> Union['Batch', np.ndarray]:
dtype=target_type)
elif isinstance(inst, torch.Tensor):
return torch.full((size, *inst.shape),
fill_value=None if inst.dtype == np.object else 0,
fill_value=0,
device=inst.device,
dtype=inst.dtype)
elif isinstance(inst, (dict, Batch)):
@ -263,18 +264,36 @@ class Batch:
**kwargs) -> None:
if copy:
batch_dict = deepcopy(batch_dict)
if _is_batch_set(batch_dict):
self.stack_(batch_dict)
elif isinstance(batch_dict, (dict, Batch)):
for k, v in batch_dict.items():
if isinstance(v, dict) or _is_batch_set(v):
self.__dict__[k] = Batch(v)
else:
if isinstance(v, list):
v = np.array(v)
if not issubclass(v.dtype.type, (np.bool_, np.number)):
v = v.astype(np.object)
self.__dict__[k] = v
if batch_dict is not None:
if isinstance(batch_dict, (dict, Batch)):
for k, v in batch_dict.items():
if isinstance(v, (list, tuple, np.ndarray)):
v_ = None
if not isinstance(v, np.ndarray) and \
all(isinstance(e, torch.Tensor) for e in v):
v_ = torch.stack(v)
self.__dict__[k] = v_
continue
else:
v_ = np.asanyarray(v)
if v_.dtype != np.object:
v = v_ # normal data list, this is the main case
if not issubclass(v.dtype.type,
(np.bool_, np.number)):
v = v.astype(np.object)
else:
if _is_batch_set(v):
v = Batch(v) # list of dict / Batch
else:
# this is actually a data list with objects
v = v_
self.__dict__[k] = v
elif isinstance(v, dict):
self.__dict__[k] = Batch(v)
else:
self.__dict__[k] = v
elif _is_batch_set(batch_dict):
self.stack_(batch_dict)
if len(kwargs) > 0:
self.__init__(kwargs, copy=copy)
@ -536,9 +555,9 @@ class Batch:
values_shared = [
[e[k] for e in batches] for k in keys_shared]
for k, v in zip(keys_shared, values_shared):
if isinstance(v[0], (dict, Batch)):
if all(isinstance(e, (dict, Batch)) for e in v):
self.__dict__[k] = Batch.stack(v, axis)
elif isinstance(v[0], torch.Tensor):
elif all(isinstance(e, torch.Tensor) for e in v):
self.__dict__[k] = torch.stack(v, axis)
else:
v = np.stack(v, axis)

View File

@ -37,7 +37,7 @@ def to_torch(x: Union[torch.Tensor, dict, Batch, np.ndarray],
elif isinstance(x, (np.number, np.bool_, Number)):
x = to_torch(np.asanyarray(x), dtype, device)
elif isinstance(x, list) and len(x) > 0 and \
isinstance(x[0], (np.number, np.bool_, Number)):
all(isinstance(e, (np.number, np.bool_, Number)) for e in x):
x = to_torch(np.asanyarray(x), dtype, device)
elif isinstance(x, np.ndarray) and \
isinstance(x.item(0), (np.number, np.bool_, Number)):