diff --git a/test/base/test_batch.py b/test/base/test_batch.py
index b7adeb0..0503cd5 100644
--- a/test/base/test_batch.py
+++ b/test/base/test_batch.py
@@ -34,6 +34,18 @@ def test_batch():
     assert batch_item.a.c == batch_dict['c']
     assert isinstance(batch_item.a.d, torch.Tensor)
     assert batch_item.a.d == batch_dict['d']
+    batch2 = Batch(a=[{
+        'b': np.float64(1.0),
+        'c': np.zeros(1),
+        'd': Batch(e=np.array(3.0))}])
+    assert len(batch2) == 1
+    assert list(batch2[1].keys()) == ['a']
+    assert len(batch2[-2].a.d.keys()) == 0
+    assert len(batch2[-1].keys()) > 0
+    assert batch2[0][0].a.c == 0.0
+    assert isinstance(batch2[0].a.c, np.ndarray)
+    assert isinstance(batch2[0].a.b, np.float64)
+    assert isinstance(batch2[0].a.d.e, np.float64)
 
 
 def test_batch_over_batch():
@@ -60,15 +72,17 @@ def test_batch_over_batch():
 
 
 def test_batch_cat_and_stack():
-    b1 = Batch(a=[{'b': np.array([1.0]), 'd': Batch(e=np.array([3.0]))}])
-    b2 = Batch(a=[{'b': np.array([4.0]), 'd': Batch(e=np.array([6.0]))}])
+    b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])
+    b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}])
     b_cat_out = Batch.cat((b1, b2))
     b_cat_in = copy.deepcopy(b1)
     b_cat_in.cat_(b2)
     assert np.all(b_cat_in.a.d.e == b_cat_out.a.d.e)
-    assert b_cat_in.a.d.e.ndim == 2
+    assert isinstance(b_cat_in.a.d.e, np.ndarray)
+    assert b_cat_in.a.d.e.ndim == 1
     b_stack = Batch.stack((b1, b2))
-    assert b_stack.a.d.e.ndim == 3
+    assert isinstance(b_stack.a.d.e, np.ndarray)
+    assert b_stack.a.d.e.ndim == 2
 
 
 def test_batch_over_batch_to_torch():
diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py
index e0fe605..a255b8d 100644
--- a/tianshou/data/batch.py
+++ b/tianshou/data/batch.py
@@ -78,15 +78,23 @@ class Batch:
                  batch_dict: Optional[
                      Union[dict, Tuple[dict], List[dict], np.ndarray]] = None,
                  **kwargs) -> None:
-        if isinstance(batch_dict, (list, tuple, np.ndarray)) \
-                and len(batch_dict) > 0 and isinstance(batch_dict[0], dict):
+        def _is_batch_set(data: Any) -> bool:
+            if isinstance(data, (list, tuple)):
+                if len(data) > 0 and isinstance(data[0], dict):
+                    return True
+            elif isinstance(data, np.ndarray):
+                if isinstance(data.item(0), dict):
+                    return True
+            return False
+
+        if isinstance(batch_dict, np.ndarray) and batch_dict.ndim == 0:
+            batch_dict = batch_dict[()]
+        if _is_batch_set(batch_dict):
             for k, v in zip(batch_dict[0].keys(),
                             zip(*[e.values() for e in batch_dict])):
-                if isinstance(v[0], dict) \
-                        or (isinstance(v, (list, tuple, np.ndarray))
-                            and len(v) > 0 and isinstance(v[0], dict)):
+                if isinstance(v[0], dict) or _is_batch_set(v[0]):
                     self.__dict__[k] = Batch(v)
-                elif isinstance(v[0], np.ndarray):
+                elif isinstance(v[0], (np.generic, np.ndarray)):
                     self.__dict__[k] = np.stack(v, axis=0)
                 elif isinstance(v[0], torch.Tensor):
                     self.__dict__[k] = torch.stack(v, dim=0)
@@ -96,9 +104,7 @@ class Batch:
                     self.__dict__[k] = list(v)
         elif isinstance(batch_dict, dict):
             for k, v in batch_dict.items():
-                if isinstance(v, dict) \
-                        or (isinstance(v, (list, tuple, np.ndarray))
-                            and len(v) > 0 and isinstance(v[0], dict)):
+                if isinstance(v, dict) or _is_batch_set(v):
                     self.__dict__[k] = Batch(v)
                 else:
                     self.__dict__[k] = v
@@ -124,18 +130,35 @@ class Batch:
         """
        self.__init__(**state)
 
-    def __getitem__(self, index: Union[str, slice]) -> Union['Batch', dict]:
+    def __getitem__(self, index: Union[
+            str, slice, int, np.integer, np.ndarray, List[int]]) -> 'Batch':
         """Return self[index]."""
+        def _valid_bounds(length: int, index: Union[
+                slice, int, np.integer, np.ndarray, List[int]]) -> bool:
+            if isinstance(index, (int, np.integer)):
+                return -length <= index and index < length
+            elif isinstance(index, (list, np.ndarray)):
+                return _valid_bounds(length, min(index)) and \
+                    _valid_bounds(length, max(index))
+            elif isinstance(index, slice):
+                # a None start/stop (e.g. batch[:n]) is always in bounds
+                return (index.start is None or
+                        _valid_bounds(length, index.start)) and \
+                    (index.stop is None or
+                     _valid_bounds(length, index.stop - 1))
+
         if isinstance(index, str):
             return self.__getattr__(index)
-        b = Batch()
-        for k, v in self.__dict__.items():
-            if hasattr(v, '__len__'):
-                try:
-                    b.__dict__.update(**{k: v[index]})
-                except IndexError:
-                    continue
-        return b
+        else:
+            b = Batch()
+            for k, v in self.__dict__.items():
+                if isinstance(v, Batch):
+                    b.__dict__[k] = v[index]
+                elif hasattr(v, '__len__') and (not isinstance(
+                        v, (np.ndarray, torch.Tensor)) or v.ndim > 0):
+                    if _valid_bounds(len(v), index):
+                        b.__dict__[k] = v[index]
+            return b
 
     def __getattr__(self, key: str) -> Union['Batch', Any]:
         """Return self.key"""
@@ -198,7 +218,7 @@ class Batch:
             device = torch.device(device)
 
         for k, v in self.__dict__.items():
-            if isinstance(v, np.ndarray):
+            if isinstance(v, (np.generic, np.ndarray)):
                 v = torch.from_numpy(v).to(device)
                 if dtype is not None:
                     v = v.type(dtype)
@@ -236,7 +256,7 @@ class Batch:
                 continue
             if not hasattr(self, k) or self.__dict__[k] is None:
                 self.__dict__[k] = copy.deepcopy(v)
-            elif isinstance(v, np.ndarray):
+            elif isinstance(v, np.ndarray) and v.ndim > 0:
                 self.__dict__[k] = np.concatenate([self.__dict__[k], v])
             elif isinstance(v, torch.Tensor):
                 self.__dict__[k] = torch.cat([self.__dict__[k], v])
@@ -274,7 +294,11 @@ class Batch:
 
     def __len__(self) -> int:
         """Return len(self)."""
-        r = [len(v) for k, v in self.__dict__.items() if hasattr(v, '__len__')]
+        r = []
+        for v in self.__dict__.values():
+            if hasattr(v, '__len__') and (not isinstance(
+                    v, (np.ndarray, torch.Tensor)) or v.ndim > 0):
+                r.append(len(v))
         return max(r) if len(r) > 0 else 0
 
     def split(self, size: Optional[int] = None,
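
For reviewers: a minimal sketch of the construction and indexing semantics this patch enables, distilled from the new assertions in test_batch.py. It assumes the patched tianshou.data.Batch is importable; the variable names are illustrative only.

import numpy as np

from tianshou.data import Batch

# np.float64 scalars and 0-d arrays are now valid leaf values; building a
# Batch from a list of dicts stacks them along a new leading batch axis.
batch = Batch(a=[{
    'b': np.float64(1.0),
    'c': np.zeros(1),
    'd': Batch(e=np.array(3.0))}])
assert len(batch) == 1

# Integer indexing recurses into nested Batch objects and unwraps the
# stacked scalars back to np.float64.
item = batch[0]
assert isinstance(item.a.b, np.float64)
assert isinstance(item.a.d.e, np.float64)

# Out-of-bounds indices no longer raise IndexError: keys whose values
# cannot be indexed are skipped, leaving empty sub-batches behind.
assert len(batch[-2].a.d.keys()) == 0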
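Likewise, a sketch of the corrected cat/stack dimensionality checked by the updated test_batch_cat_and_stack(), under the same assumption:

import numpy as np

from tianshou.data import Batch

b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])
b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}])

# cat joins along the existing batch axis: (1,) + (1,) -> (2,), so ndim
# stays 1 rather than the 2 the old test expected.
assert Batch.cat((b1, b2)).a.d.e.ndim == 1

# stack adds a new leading axis: two (1,) arrays -> (2, 1), i.e. ndim 2.
assert Batch.stack((b1, b2)).a.d.e.ndim == 2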