diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 7fddda2..e204151 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -19,6 +19,8 @@ def _is_batch_set(obj: Any) -> bool: # "for element in obj" will just unpack the first dimension, # but obj.tolist() will flatten ndarray of objects # so do not use obj.tolist() + if obj.shape == (): + return False return obj.dtype == object and \ all(isinstance(element, (dict, Batch)) for element in obj) elif isinstance(obj, (list, tuple)):