add type check for each element rather than the first element (#112)
This PR does the following: - improvement: dramatic reduce of the call to _is_batch_set - bugfix: list(Batch()) fail; Batch(a=[torch.ones(3), torch.ones(3)]) fail; - misc: add type check for each element rather than the first element; add test case; _create_value with torch.Tensor does not have np.object type;
This commit is contained in:
parent
481015932c
commit
7f9a1f1328
@ -8,6 +8,9 @@ from tianshou.data import Batch, to_torch
|
||||
|
||||
|
||||
def test_batch():
|
||||
assert list(Batch()) == []
|
||||
batch = Batch(a=[torch.ones(3), torch.ones(3)])
|
||||
assert torch.allclose(batch.a, torch.ones(2, 3))
|
||||
batch = Batch(obs=[0], np=np.zeros([3, 4]))
|
||||
assert batch.obs == batch["obs"]
|
||||
batch.obs = [1]
|
||||
|
@ -16,10 +16,10 @@ warnings.filterwarnings(
|
||||
|
||||
def _is_batch_set(data: Any) -> bool:
|
||||
if isinstance(data, (list, tuple)):
|
||||
if len(data) > 0 and isinstance(data[0], (dict, Batch)):
|
||||
if len(data) > 0 and all(isinstance(e, (dict, Batch)) for e in data):
|
||||
return True
|
||||
elif isinstance(data, np.ndarray):
|
||||
if isinstance(data.item(0), (dict, Batch)):
|
||||
elif isinstance(data, np.ndarray) and data.dtype == np.object:
|
||||
if all(isinstance(e, (dict, Batch)) for e in data.tolist()):
|
||||
return True
|
||||
return False
|
||||
|
||||
@ -43,7 +43,8 @@ def _valid_bounds(length: int, index: Union[
|
||||
return start_valid and stop_valid
|
||||
|
||||
|
||||
def _create_value(inst: Any, size: int) -> Union['Batch', np.ndarray]:
|
||||
def _create_value(inst: Any, size: int) -> Union[
|
||||
'Batch', np.ndarray, torch.Tensor]:
|
||||
if isinstance(inst, np.ndarray):
|
||||
if issubclass(inst.dtype.type, (np.bool_, np.number)):
|
||||
target_type = inst.dtype.type
|
||||
@ -54,7 +55,7 @@ def _create_value(inst: Any, size: int) -> Union['Batch', np.ndarray]:
|
||||
dtype=target_type)
|
||||
elif isinstance(inst, torch.Tensor):
|
||||
return torch.full((size, *inst.shape),
|
||||
fill_value=None if inst.dtype == np.object else 0,
|
||||
fill_value=0,
|
||||
device=inst.device,
|
||||
dtype=inst.dtype)
|
||||
elif isinstance(inst, (dict, Batch)):
|
||||
@ -263,18 +264,36 @@ class Batch:
|
||||
**kwargs) -> None:
|
||||
if copy:
|
||||
batch_dict = deepcopy(batch_dict)
|
||||
if _is_batch_set(batch_dict):
|
||||
self.stack_(batch_dict)
|
||||
elif isinstance(batch_dict, (dict, Batch)):
|
||||
for k, v in batch_dict.items():
|
||||
if isinstance(v, dict) or _is_batch_set(v):
|
||||
self.__dict__[k] = Batch(v)
|
||||
else:
|
||||
if isinstance(v, list):
|
||||
v = np.array(v)
|
||||
if not issubclass(v.dtype.type, (np.bool_, np.number)):
|
||||
v = v.astype(np.object)
|
||||
self.__dict__[k] = v
|
||||
if batch_dict is not None:
|
||||
if isinstance(batch_dict, (dict, Batch)):
|
||||
for k, v in batch_dict.items():
|
||||
if isinstance(v, (list, tuple, np.ndarray)):
|
||||
v_ = None
|
||||
if not isinstance(v, np.ndarray) and \
|
||||
all(isinstance(e, torch.Tensor) for e in v):
|
||||
v_ = torch.stack(v)
|
||||
self.__dict__[k] = v_
|
||||
continue
|
||||
else:
|
||||
v_ = np.asanyarray(v)
|
||||
if v_.dtype != np.object:
|
||||
v = v_ # normal data list, this is the main case
|
||||
if not issubclass(v.dtype.type,
|
||||
(np.bool_, np.number)):
|
||||
v = v.astype(np.object)
|
||||
else:
|
||||
if _is_batch_set(v):
|
||||
v = Batch(v) # list of dict / Batch
|
||||
else:
|
||||
# this is actually a data list with objects
|
||||
v = v_
|
||||
self.__dict__[k] = v
|
||||
elif isinstance(v, dict):
|
||||
self.__dict__[k] = Batch(v)
|
||||
else:
|
||||
self.__dict__[k] = v
|
||||
elif _is_batch_set(batch_dict):
|
||||
self.stack_(batch_dict)
|
||||
if len(kwargs) > 0:
|
||||
self.__init__(kwargs, copy=copy)
|
||||
|
||||
@ -536,9 +555,9 @@ class Batch:
|
||||
values_shared = [
|
||||
[e[k] for e in batches] for k in keys_shared]
|
||||
for k, v in zip(keys_shared, values_shared):
|
||||
if isinstance(v[0], (dict, Batch)):
|
||||
if all(isinstance(e, (dict, Batch)) for e in v):
|
||||
self.__dict__[k] = Batch.stack(v, axis)
|
||||
elif isinstance(v[0], torch.Tensor):
|
||||
elif all(isinstance(e, torch.Tensor) for e in v):
|
||||
self.__dict__[k] = torch.stack(v, axis)
|
||||
else:
|
||||
v = np.stack(v, axis)
|
||||
|
@ -37,7 +37,7 @@ def to_torch(x: Union[torch.Tensor, dict, Batch, np.ndarray],
|
||||
elif isinstance(x, (np.number, np.bool_, Number)):
|
||||
x = to_torch(np.asanyarray(x), dtype, device)
|
||||
elif isinstance(x, list) and len(x) > 0 and \
|
||||
isinstance(x[0], (np.number, np.bool_, Number)):
|
||||
all(isinstance(e, (np.number, np.bool_, Number)) for e in x):
|
||||
x = to_torch(np.asanyarray(x), dtype, device)
|
||||
elif isinstance(x, np.ndarray) and \
|
||||
isinstance(x.item(0), (np.number, np.bool_, Number)):
|
||||
|
Loading…
x
Reference in New Issue
Block a user