diff --git a/test/base/test_batch.py b/test/base/test_batch.py index be88a5a..b7adeb0 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -54,6 +54,9 @@ def test_batch_over_batch(): assert batch3.c == [6, 7, 8, 6, 7, 8] assert batch3.b.a == [3, 4, 5, 3, 4, 5] assert batch3.b.b == [4, 5, 6, 4, 5, 6] + batch4 = Batch(({'a': {'b': np.array([1.0])}},)) + assert batch4.a.b.ndim == 2 + assert batch4.a.b[0, 0] == 1.0 def test_batch_cat_and_stack(): diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 9572db4..e0fe605 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -3,7 +3,7 @@ import copy import pprint import warnings import numpy as np -from typing import Any, List, Union, Iterator, Optional +from typing import Any, List, Tuple, Union, Iterator, Optional # Disable pickle warning related to torch, since it has been removed # on torch master branch. See Pull Request #39003 for details: @@ -76,29 +76,28 @@ class Batch: def __init__(self, batch_dict: Optional[ - Union[dict, List[dict], np.ndarray]] = None, + Union[dict, Tuple[dict], List[dict], np.ndarray]] = None, **kwargs) -> None: - if isinstance(batch_dict, (list, np.ndarray)) \ + if isinstance(batch_dict, (list, tuple, np.ndarray)) \ and len(batch_dict) > 0 and isinstance(batch_dict[0], dict): for k, v in zip(batch_dict[0].keys(), zip(*[e.values() for e in batch_dict])): - if isinstance(v, (list, np.ndarray)) \ - and len(v) > 0 and isinstance(v[0], dict): - self.__dict__[k] = Batch.stack([Batch(v_) for v_ in v]) + if isinstance(v[0], dict) \ + or (isinstance(v, (list, tuple, np.ndarray)) + and len(v) > 0 and isinstance(v[0], dict)): + self.__dict__[k] = Batch(v) elif isinstance(v[0], np.ndarray): self.__dict__[k] = np.stack(v, axis=0) elif isinstance(v[0], torch.Tensor): self.__dict__[k] = torch.stack(v, dim=0) elif isinstance(v[0], Batch): self.__dict__[k] = Batch.stack(v) - elif isinstance(v[0], dict): - self.__dict__[k] = Batch(v) else: self.__dict__[k] = list(v) elif isinstance(batch_dict, dict): for k, v in batch_dict.items(): if isinstance(v, dict) \ - or (isinstance(v, (list, np.ndarray)) + or (isinstance(v, (list, tuple, np.ndarray)) and len(v) > 0 and isinstance(v[0], dict)): self.__dict__[k] = Batch(v) else: