diff --git a/test/base/test_batch.py b/test/base/test_batch.py index ee66144..1303fb0 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -8,6 +8,9 @@ from tianshou.data import Batch, to_torch def test_batch(): + assert list(Batch()) == [] + batch = Batch(a=[torch.ones(3), torch.ones(3)]) + assert torch.allclose(batch.a, torch.ones(2, 3)) batch = Batch(obs=[0], np=np.zeros([3, 4])) assert batch.obs == batch["obs"] batch.obs = [1] diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 542564a..4647e7f 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -16,10 +16,10 @@ warnings.filterwarnings( def _is_batch_set(data: Any) -> bool: if isinstance(data, (list, tuple)): - if len(data) > 0 and isinstance(data[0], (dict, Batch)): + if len(data) > 0 and all(isinstance(e, (dict, Batch)) for e in data): return True - elif isinstance(data, np.ndarray): - if isinstance(data.item(0), (dict, Batch)): + elif isinstance(data, np.ndarray) and data.dtype == np.object: + if all(isinstance(e, (dict, Batch)) for e in data.tolist()): return True return False @@ -43,7 +43,8 @@ def _valid_bounds(length: int, index: Union[ return start_valid and stop_valid -def _create_value(inst: Any, size: int) -> Union['Batch', np.ndarray]: +def _create_value(inst: Any, size: int) -> Union[ + 'Batch', np.ndarray, torch.Tensor]: if isinstance(inst, np.ndarray): if issubclass(inst.dtype.type, (np.bool_, np.number)): target_type = inst.dtype.type @@ -54,7 +55,7 @@ def _create_value(inst: Any, size: int) -> Union['Batch', np.ndarray]: dtype=target_type) elif isinstance(inst, torch.Tensor): return torch.full((size, *inst.shape), - fill_value=None if inst.dtype == np.object else 0, + fill_value=0, device=inst.device, dtype=inst.dtype) elif isinstance(inst, (dict, Batch)): @@ -263,18 +264,36 @@ class Batch: **kwargs) -> None: if copy: batch_dict = deepcopy(batch_dict) - if _is_batch_set(batch_dict): - self.stack_(batch_dict) - elif isinstance(batch_dict, (dict, Batch)): - for k, v in batch_dict.items(): - if isinstance(v, dict) or _is_batch_set(v): - self.__dict__[k] = Batch(v) - else: - if isinstance(v, list): - v = np.array(v) - if not issubclass(v.dtype.type, (np.bool_, np.number)): - v = v.astype(np.object) - self.__dict__[k] = v + if batch_dict is not None: + if isinstance(batch_dict, (dict, Batch)): + for k, v in batch_dict.items(): + if isinstance(v, (list, tuple, np.ndarray)): + v_ = None + if not isinstance(v, np.ndarray) and \ + all(isinstance(e, torch.Tensor) for e in v): + v_ = torch.stack(v) + self.__dict__[k] = v_ + continue + else: + v_ = np.asanyarray(v) + if v_.dtype != np.object: + v = v_ # normal data list, this is the main case + if not issubclass(v.dtype.type, + (np.bool_, np.number)): + v = v.astype(np.object) + else: + if _is_batch_set(v): + v = Batch(v) # list of dict / Batch + else: + # this is actually a data list with objects + v = v_ + self.__dict__[k] = v + elif isinstance(v, dict): + self.__dict__[k] = Batch(v) + else: + self.__dict__[k] = v + elif _is_batch_set(batch_dict): + self.stack_(batch_dict) if len(kwargs) > 0: self.__init__(kwargs, copy=copy) @@ -536,9 +555,9 @@ class Batch: values_shared = [ [e[k] for e in batches] for k in keys_shared] for k, v in zip(keys_shared, values_shared): - if isinstance(v[0], (dict, Batch)): + if all(isinstance(e, (dict, Batch)) for e in v): self.__dict__[k] = Batch.stack(v, axis) - elif isinstance(v[0], torch.Tensor): + elif all(isinstance(e, torch.Tensor) for e in v): self.__dict__[k] = torch.stack(v, axis) else: v = np.stack(v, axis) diff --git a/tianshou/data/utils.py b/tianshou/data/utils.py index cf70ea7..bf4b3f6 100644 --- a/tianshou/data/utils.py +++ b/tianshou/data/utils.py @@ -37,7 +37,7 @@ def to_torch(x: Union[torch.Tensor, dict, Batch, np.ndarray], elif isinstance(x, (np.number, np.bool_, Number)): x = to_torch(np.asanyarray(x), dtype, device) elif isinstance(x, list) and len(x) > 0 and \ - isinstance(x[0], (np.number, np.bool_, Number)): + all(isinstance(e, (np.number, np.bool_, Number)) for e in x): x = to_torch(np.asanyarray(x), dtype, device) elif isinstance(x, np.ndarray) and \ isinstance(x.item(0), (np.number, np.bool_, Number)):