diff --git a/docs/contributor.rst b/docs/contributor.rst
index 03097c0..044ea23 100644
--- a/docs/contributor.rst
+++ b/docs/contributor.rst
@@ -5,3 +5,4 @@ We always welcome contributions to help make Tianshou better. Below are an incom
 * Jiayi Weng (`Trinkle23897 `_)
 * Minghao Zhang (`Mehooz `_)
+* Alexis Duburcq (`duburcqa `_)
diff --git a/test/base/test_batch.py b/test/base/test_batch.py
index 4b528a3..5f47723 100644
--- a/test/base/test_batch.py
+++ b/test/base/test_batch.py
@@ -13,9 +13,9 @@ def test_batch():
     batch.obs = [1]
     assert batch.obs == [1]
     batch.cat_(batch)
-    assert batch.obs == [1, 1]
+    assert np.allclose(batch.obs, [1, 1])
     assert batch.np.shape == (6, 4)
-    assert batch[0].obs == batch[1].obs
+    assert np.allclose(batch[0].obs, batch[1].obs)
     batch.obs = np.arange(5)
     for i, b in enumerate(batch.split(1, shuffle=False)):
         if i != 5:
@@ -39,14 +39,14 @@ def test_batch():
                      'c': np.zeros(1),
                      'd': Batch(e=np.array(3.0))}])
     assert len(batch2) == 1
-    assert Batch().size == 0
-    assert batch2.size == 1
+    assert Batch().shape == []
+    assert batch2.shape[0] == 1
     with pytest.raises(IndexError):
         batch2[-2]
     with pytest.raises(IndexError):
         batch2[1]
-    assert batch2[0].size == 1
-    with pytest.raises(TypeError):
+    assert batch2[0].shape == []
+    with pytest.raises(IndexError):
         batch2[0][0]
     with pytest.raises(TypeError):
         len(batch2[0])
@@ -87,24 +87,36 @@ def test_batch_over_batch():
     batch2.b.b[-1] = 0
     print(batch2)
     for k, v in batch2.items():
-        assert batch2[k] == v
+        assert np.all(batch2[k] == v)
     assert batch2[-1].b.b == 0
     batch2.cat_(Batch(c=[6, 7, 8], b=batch))
-    assert batch2.c == [6, 7, 8, 6, 7, 8]
-    assert batch2.b.a == [3, 4, 5, 3, 4, 5]
-    assert batch2.b.b == [4, 5, 0, 4, 5, 0]
+    assert np.allclose(batch2.c, [6, 7, 8, 6, 7, 8])
+    assert np.allclose(batch2.b.a, [3, 4, 5, 3, 4, 5])
+    assert np.allclose(batch2.b.b, [4, 5, 0, 4, 5, 0])
     d = {'a': [3, 4, 5], 'b': [4, 5, 6]}
     batch3 = Batch(c=[6, 7, 8], b=d)
     batch3.cat_(Batch(c=[6, 7, 8], b=d))
-    assert batch3.c == [6, 7, 8, 6, 7, 8]
-    assert batch3.b.a == [3, 4, 5, 3, 4, 5]
-    assert batch3.b.b == [4, 5, 6, 4, 5, 6]
+    assert np.allclose(batch3.c, [6, 7, 8, 6, 7, 8])
+    assert np.allclose(batch3.b.a, [3, 4, 5, 3, 4, 5])
+    assert np.allclose(batch3.b.b, [4, 5, 6, 4, 5, 6])
     batch4 = Batch(({'a': {'b': np.array([1.0])}},))
     assert batch4.a.b.ndim == 2
     assert batch4.a.b[0, 0] == 1.0
+    # advanced slicing
+    batch5 = Batch(a=[[1, 2]], b={'c': np.zeros([3, 2, 1])})
+    assert batch5.shape == [1, 2]
+    with pytest.raises(IndexError):
+        batch5[2]
+    with pytest.raises(IndexError):
+        batch5[:, 3]
+    with pytest.raises(IndexError):
+        batch5[:, :, -1]
+    batch5[:, -1] += 1
+    assert np.allclose(batch5.a, [1, 3])
+    assert np.allclose(batch5.b.c.squeeze(), [[0, 1]] * 3)
 
 
-def test_batch_cat_and_stack():
+def test_batch_cat_and_stack_and_empty():
     b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])
     b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}])
     b12_cat_out = Batch.cat((b1, b2))
@@ -133,6 +145,24 @@ def test_batch_cat_and_stack():
     assert np.all(b5.b.c == np.stack([e['b']['c'] for e in b5_dict], axis=0))
     assert b5.b.d[0] == b5_dict[0]['b']['d']
     assert b5.b.d[1] == 0.0
+    b5[1] = Batch.empty(b5[0])
+    assert np.allclose(b5.a, [False, False])
+    assert np.allclose(b5.b.c, [2, 0])
+    assert np.allclose(b5.b.d, [1, 0])
+    data = Batch(a=[False, True],
+                 b={'c': [2., 'st'], 'd': [1, None], 'e': [2., float('nan')]},
+                 c=np.array([1, 3, 4], dtype=np.int),
+                 t=torch.tensor([4, 5, 6, 7.]))
+    data[-1] = Batch.empty(data[1])
+    assert np.allclose(data.c, [1, 3, 0])
+    assert np.allclose(data.a, [False, False])
+    assert list(data.b.c) == ['2.0', '']
+    assert list(data.b.d) == [1, None]
+    assert np.allclose(data.b.e, [2, 0])
+    assert torch.allclose(data.t, torch.tensor([4, 5, 6, 0.]))
+    b0 = Batch()
+    b0.empty_()
+    assert b0.shape == []
 
 
 def test_batch_over_batch_to_torch():
@@ -215,3 +245,5 @@ if __name__ == '__main__':
     test_utils_to_torch()
     test_batch_pickle()
     test_batch_from_to_numpy_without_copy()
+    test_batch_numpy_compatibility()
+    test_batch_cat_and_stack_and_empty()
diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py
index 65c638a..bbf1c6a 100644
--- a/tianshou/data/batch.py
+++ b/tianshou/data/batch.py
@@ -74,8 +74,9 @@ class Batch:
         >>> import numpy as np
         >>> from tianshou.data import Batch
         >>> data = Batch(a=4, b=[5, 5], c='2312312')
+        >>> # the list will automatically be converted to a numpy array
         >>> data.b
-        [5, 5]
+        array([5, 5])
         >>> data.b = np.array([3, 4, 5])
         >>> print(data)
         Batch(
@@ -104,8 +105,6 @@ class Batch:
     together:
     ::
 
-        >>> import numpy as np
-        >>> from tianshou.data import Batch
         >>> data = Batch([{'a': {'b': [0.0, "info"]}}])
         >>> print(data[0])
         Batch(
@@ -119,7 +118,6 @@ class Batch:
     key, or iterate over stored data:
     ::
 
-        >>> from tianshou.data import Batch
         >>> data = Batch(a=4, b=[5, 5])
         >>> print(data["a"])
         4
@@ -130,28 +128,36 @@ class Batch:
     :class:`~tianshou.data.Batch` is also reproduce partially the Numpy API for
-    arrays. You can access or iterate over the individual samples, if any:
+    arrays. Advanced slicing, such as batch[:, i], is supported as long as
+    the index is valid. You can access or iterate over the individual
+    samples, if any:
     ::
 
-        >>> import numpy as np
-        >>> from tianshou.data import Batch
-        >>> data = Batch(a=np.array([[0.0, 2.0], [1.0, 3.0]]), b=[5, -5])
+        >>> data = Batch(a=np.array([[0.0, 2.0], [1.0, 3.0]]), b=[[5, -5]])
        >>> print(data[0])
         Batch(
-            a: np.array([0.0, 2.0])
-            b: 5
+            a: array([0., 2.])
+            b: array([ 5, -5]),
         )
         >>> for sample in data:
         >>>     print(sample.a)
-        [0.0, 2.0]
-        [1.0, 3.0]
+        [0., 2.]
+        [1., 3.]
+
+        >>> print(data.shape)
+        [1, 2]
+        >>> data[:, 1] += 1
+        >>> print(data)
+        Batch(
+            a: array([[0., 3.],
+                      [1., 4.]]),
+            b: array([[ 5, -4]]),
+        )
 
     Similarly, one can also perform simple algebra on it, and stack, split or
     concatenate multiple instances:
     ::
 
-        >>> import numpy as np
-        >>> from tianshou.data import Batch
         >>> data_1 = Batch(a=np.array([0.0, 2.0]), b=5)
         >>> data_2 = Batch(a=np.array([1.0, 3.0]), b=-5)
         >>> data = Batch.stack((data_1, data_2))
@@ -169,11 +175,10 @@ class Batch:
         >>> data_split = list(data.split(1, False))
         >>> print(list(data.split(1, False)))
         [Batch(
-            b: [5],
+            b: array([5]),
             a: array([[0., 2.]]),
-        ),
-        Batch(
-            b: [-5],
+        ), Batch(
+            b: array([-5]),
             a: array([[1., 3.]]),
         )]
         >>> data_cat = Batch.cat(data_split)
@@ -188,8 +193,6 @@ class Batch:
     None is added in list or :class:`np.ndarray` of objects, 0 otherwise.
     ::
 
-        >>> import numpy as np
-        >>> from tianshou.data import Batch
         >>> data_1 = Batch(a=np.array([0.0, 2.0]))
         >>> data_2 = Batch(a=np.array([1.0, 3.0]), b='done')
         >>> data = Batch.stack((data_1, data_2))
@@ -200,23 +203,40 @@ class Batch:
         b: array([None, 'done'], dtype=object),
     )
 
-    :meth:`~tianshou.data.Batch.size` and :meth:`~tianshou.data.Batch.__len__`
-    methods are also provided to respectively get the length and the size of
-    a :class:`Batch` instance. It mimics the Numpy API for Numpy arrays, which
-    means that getting the length of a scalar Batch raises an exception, while
-    the size is 1. The size is only 0 if empty. Note that the size and length
-    are the identical if multiple samples are stored:
+    Methods :meth:`~tianshou.data.Batch.empty_` and
+    :meth:`~tianshou.data.Batch.empty` reset the stored data to 0, or to
+    ``None`` for ``np.object`` arrays:
+    ::
+
+        >>> data.empty_()
+        >>> print(data)
+        Batch(
+            a: array([[0., 0.],
+                      [0., 0.]]),
+            b: array([None, None], dtype=object),
+        )
+        >>> data = Batch(a=[False, True], b={'c': [2., 3.], 'd': [1., 0.]})
+        >>> data[0] = Batch.empty(data[1])
+        >>> data
+        Batch(
+            a: array([False, True]),
+            b: Batch(
+                   c: array([0., 3.]),
+                   d: array([0., 0.]),
+               ),
+        )
+
+    :meth:`~tianshou.data.Batch.shape` and :meth:`~tianshou.data.Batch.__len__`
+    methods are also provided to respectively get the shape and the length of
+    a :class:`Batch` instance. It mimics the Numpy API for Numpy arrays, which
+    means that getting the length of a scalar Batch raises an exception.
     ::
 
-        >>> import numpy as np
-        >>> from tianshou.data import Batch
         >>> data = Batch(a=[5., 4.], b=np.zeros((2, 3, 4)))
-        >>> data.size
-        2
+        >>> data.shape
+        [2]
         >>> len(data)
         2
-        >>> data[0].size
-        1
+        >>> data[0].shape
+        []
         >>> len(data[0])
         TypeError: Object of type 'Batch' has no len()
@@ -240,13 +260,26 @@ class Batch:
             if isinstance(v, dict) or _is_batch_set(v):
                 self.__dict__[k] = Batch(v)
             else:
+                if isinstance(v, list):
+                    v = np.array(v)
                 self.__dict__[k] = v
         if len(kwargs) > 0:
             self.__init__(kwargs)
 
+    def __setattr__(self, key: str, value: Any):
+        """Set self.key = value."""
+        if isinstance(value, list):
+            if _is_batch_set(value):
+                value = Batch(value)
+            else:
+                value = np.array(value)
+        elif isinstance(value, dict):
+            value = Batch(value)
+        self.__dict__[key] = value
+
     def __getstate__(self):
-        """Pickling interface. Only the actual data are serialized
-        for both efficiency and simplicity.
+        """Pickling interface. Only the actual data are serialized for both
+        efficiency and simplicity.
         """
         state = {}
         for k, v in self.items():
@@ -256,9 +289,9 @@ class Batch:
         return state
 
     def __setstate__(self, state):
-        """Unpickling interface. At this point, self is an empty Batch
-        instance that has not been initialized, so it can safely be
-        initialized by the pickle state.
+        """Unpickling interface. At this point, self is an empty Batch instance
+        that has not been initialized, so it can safely be initialized by the
+        pickle state.
""" self.__init__(**state) @@ -267,26 +300,18 @@ class Batch: """Return self[index].""" if isinstance(index, str): return self.__dict__[index] + b = Batch() + for k, v in self.items(): + if isinstance(v, Batch) and len(v.__dict__) == 0: + b.__dict__[k] = Batch() + else: + b.__dict__[k] = v[index] + return b - if not _valid_bounds(len(self), index): - raise IndexError( - f"Index {index} out of bounds for Batch of len {len(self)}.") - else: - b = Batch() - is_index_scalar = isinstance(index, (int, np.integer)) or \ - (isinstance(index, np.ndarray) and index.ndim == 0) - for k, v in self.items(): - if isinstance(v, Batch) and len(v.__dict__) == 0: - b.__dict__[k] = Batch() - elif is_index_scalar or not isinstance(v, list): - b.__dict__[k] = v[index] - else: - b.__dict__[k] = [v[i] for i in index] - return b - - def __setitem__(self, index: Union[ - str, slice, int, np.integer, np.ndarray, List[int]], - value: Any) -> None: + def __setitem__( + self, + index: Union[str, slice, int, np.integer, np.ndarray, List[int]], + value: Any) -> None: """Assign value to self[index].""" if isinstance(index, str): self.__dict__[index] = value @@ -319,8 +344,6 @@ class Batch: other.__dict__.values()): if r is None: continue - elif isinstance(r, list): - self.__dict__[k] = [r_ + v_ for r_, v_ in zip(r, v)] else: self.__dict__[k] += v return self @@ -328,8 +351,6 @@ class Batch: for k, r in self.items(): if r is None: continue - elif isinstance(r, list): - self.__dict__[k] = [r_ + other for r_ in r] else: self.__dict__[k] += other return self @@ -440,13 +461,14 @@ class Batch: v.to_torch(dtype, device) def append(self, batch: 'Batch') -> None: - warnings.warn('Method append will be removed soon, please use ' + warnings.warn('Method :meth:`~tianshou.data.Batch.append` will be ' + 'removed soon, please use ' ':meth:`~tianshou.data.Batch.cat`') return self.cat_(batch) def cat_(self, batch: 'Batch') -> None: - """Concatenate a :class:`~tianshou.data.Batch` object into - current batch. + """Concatenate a :class:`~tianshou.data.Batch` object into current + batch. """ assert isinstance(batch, Batch), \ 'Only Batch is allowed to be concatenated in-place!' @@ -459,8 +481,6 @@ class Batch: self.__dict__[k] = np.concatenate([self.__dict__[k], v]) elif isinstance(v, torch.Tensor): self.__dict__[k] = torch.cat([self.__dict__[k], v]) - elif isinstance(v, list): - self.__dict__[k] += copy.deepcopy(v) elif isinstance(v, Batch): self.__dict__[k].cat_(v) else: @@ -468,12 +488,12 @@ class Batch: f'{type(v)} in class Batch.' raise TypeError(s) - @classmethod - def cat(cls, batches: List['Batch']) -> 'Batch': - """Concatenate a :class:`~tianshou.data.Batch` object into a - single new batch. + @staticmethod + def cat(batches: List['Batch']) -> 'Batch': + """Concatenate a :class:`~tianshou.data.Batch` object into a single + new batch. """ - batch = cls() + batch = Batch() for batch_ in batches: batch.cat_(batch_) return batch @@ -481,8 +501,7 @@ class Batch: def stack_(self, batches: List[Union[dict, 'Batch']], axis: int = 0) -> None: - """Stack a :class:`~tianshou.data.Batch` object i into current - batch. + """Stack a :class:`~tianshou.data.Batch` object i into current batch. """ if len(self.__dict__) > 0: batches = [self] + list(batches) @@ -511,13 +530,42 @@ class Batch: @staticmethod def stack(batches: List['Batch'], axis: int = 0) -> 'Batch': - """Stack a :class:`~tianshou.data.Batch` object into a - single new batch. + """Stack a :class:`~tianshou.data.Batch` object into a single new + batch. 
""" batch = Batch() batch.stack_(batches, axis) return batch + def empty_(self) -> 'Batch': + """Return an empty a :class:`~tianshou.data.Batch` object with 0 or + ``None`` filled. + """ + for k, v in self.items(): + if v is None: + continue + if isinstance(v, Batch): + self.__dict__[k].empty_() + elif isinstance(v, np.ndarray) and v.dtype == np.object: + self.__dict__[k].fill(None) + elif isinstance(v, torch.Tensor): # cannot apply fill_ directly + self.__dict__[k] = torch.zeros_like(self.__dict__[k]) + else: # np + self.__dict__[k] *= 0 + if hasattr(v, 'dtype') and v.dtype.kind in 'fc': + self.__dict__[k] = np.nan_to_num(self.__dict__[k]) + return self + + @staticmethod + def empty(batch: 'Batch') -> 'Batch': + """Return an empty :class:`~tianshou.data.Batch` object with 0 or + ``None`` filled, the shape is the same as the given + :class:`~tianshou.data.Batch`. + """ + batch = Batch(**batch) + batch.empty_() + return batch + def __len__(self) -> int: """Return len(self).""" r = [] @@ -534,21 +582,20 @@ class Batch: return min(r) @property - def size(self) -> int: - """Return self.size.""" + def shape(self) -> List[int]: + """Return self.shape.""" if len(self.__dict__.keys()) == 0: - return 0 + return [] else: - r = [] + data_shape = [] for v in self.__dict__.values(): - if isinstance(v, Batch): - r.append(v.size) - elif hasattr(v, '__len__') and (not isinstance( - v, (np.ndarray, torch.Tensor)) or v.ndim > 0): - r.append(len(v)) - else: - r.append(1) - return min(r) if len(r) > 0 else 0 + try: + data_shape.append(v.shape) + except AttributeError: + raise TypeError("No support for 'shape' method with " + f"type {type(v)} in class Batch.") + return list(map(min, zip(*data_shape))) if len(data_shape) > 1 \ + else data_shape[0] def split(self, size: Optional[int] = None, shuffle: bool = True) -> Iterator['Batch']: diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 4b72308..75c948a 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -200,8 +200,10 @@ class Collector(object): return if isinstance(self.state, list): self.state[id] = None - elif isinstance(self.state, (Batch, torch.Tensor, np.ndarray)): + elif isinstance(self.state, (torch.Tensor, np.ndarray)): self.state[id] *= 0 + else: # Batch + self.state[id].empty_() def collect(self, n_step: int = 0,