From affeec13de764c46e09588b3242cd6a9cbde87ea Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 11 Jul 2020 21:46:01 +0800 Subject: [PATCH] Improve Batch (#128) * minor polish * improve and implement Batch.cat_ * bugfix for buffer.sample with field impt_weight * restore the usage of a.cat_(b) * fix 2 bugs in batch and add corresponding unittest * code fix for update * update is_empty to recognize empty over empty; bugfix for len * bugfix for update and add testcase * add testcase of update * fix docs * fix docs * fix docs [ci skip] * fix docs [ci skip] Co-authored-by: Trinkle23897 <463003665@qq.com> --- .github/workflows/pytest.yml | 1 + test/base/test_batch.py | 43 ++++++++- tianshou/data/batch.py | 167 +++++++++++++++++++++++------------ tianshou/data/buffer.py | 21 ++--- 4 files changed, 162 insertions(+), 70 deletions(-) diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index eb14a87..c1e3604 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -5,6 +5,7 @@ on: [push, pull_request] jobs: build: runs-on: ubuntu-latest + if: "!contains(github.event.head_commit.message, 'ci skip')" strategy: matrix: python-version: [3.6, 3.7, 3.8] diff --git a/test/base/test_batch.py b/test/base/test_batch.py index a9f2cdd..03031ff 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -10,7 +10,17 @@ from tianshou.data import Batch, to_torch def test_batch(): assert list(Batch()) == [] assert Batch().is_empty() + assert Batch(b={'c': {}}).is_empty() + assert len(Batch(a=[1, 2, 3], b={'c': {}})) == 3 assert not Batch(a=[1, 2, 3]).is_empty() + b = Batch() + b.update() + assert b.is_empty() + b.update(c=[3, 5]) + assert np.allclose(b.c, [3, 5]) + # mimic the behavior of dict.update, where kwargs can overwrite keys + b.update({'a': 2}, a=3) + assert b.a == 3 with pytest.raises(AssertionError): Batch({1: 2}) batch = Batch(a=[torch.ones(3), torch.ones(3)]) @@ -86,6 +96,18 @@ def test_batch(): assert batch3.a.d.f[0] == 5.0 with pytest.raises(KeyError): batch3.a.d[0] = Batch(f=5.0, g=0.0) + # auto convert + batch4 = Batch(a=np.array(['a', 'b'])) + assert batch4.a.dtype == np.object # auto convert to np.object + batch4.update(a=np.array(['c', 'd'])) + assert list(batch4.a) == ['c', 'd'] + assert batch4.a.dtype == np.object # auto convert to np.object + batch5 = Batch(a=np.array([{'index': 0}])) + assert isinstance(batch5.a, Batch) + assert np.allclose(batch5.a.index, [0]) + batch5.b = np.array([{'index': 1}]) + assert isinstance(batch5.b, Batch) + assert np.allclose(batch5.b.index, [1]) def test_batch_over_batch(): @@ -100,6 +122,11 @@ def test_batch_over_batch(): assert np.allclose(batch2.c, [6, 7, 8, 6, 7, 8]) assert np.allclose(batch2.b.a, [3, 4, 5, 3, 4, 5]) assert np.allclose(batch2.b.b, [4, 5, 0, 4, 5, 0]) + batch2.update(batch2.b, six=[6, 6, 6]) + assert np.allclose(batch2.c, [6, 7, 8, 6, 7, 8]) + assert np.allclose(batch2.a, [3, 4, 5, 3, 4, 5]) + assert np.allclose(batch2.b, [4, 5, 0, 4, 5, 0]) + assert np.allclose(batch2.six, [6, 6, 6]) d = {'a': [3, 4, 5], 'b': [4, 5, 6]} batch3 = Batch(c=[6, 7, 8], b=d) batch3.cat_(Batch(c=[6, 7, 8], b=d)) @@ -124,18 +151,32 @@ def test_batch_over_batch(): def test_batch_cat_and_stack(): + # test cat with compatible keys b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}]) b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}]) - b12_cat_out = Batch.cat((b1, b2)) + b12_cat_out = Batch.cat([b1, b2]) b12_cat_in = copy.deepcopy(b1) b12_cat_in.cat_(b2) assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e) assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e) assert isinstance(b12_cat_in.a.d.e, np.ndarray) assert b12_cat_in.a.d.e.ndim == 1 + b12_stack = Batch.stack((b1, b2)) assert isinstance(b12_stack.a.d.e, np.ndarray) assert b12_stack.a.d.e.ndim == 2 + + # test batch with incompatible keys + b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5))) + b2 = Batch(b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5))) + test = Batch.cat([b1, b2]) + ans = Batch(a=np.concatenate([b1.a, np.zeros((4, 4))]), + b=torch.cat([torch.zeros(3, 3), b2.b]), + common=Batch(c=np.concatenate([b1.common.c, b2.common.c]))) + assert np.allclose(test.a, ans.a) + assert torch.allclose(test.b, ans.b) + assert np.allclose(test.common.c, ans.common.c) + b3 = Batch(a=np.zeros((3, 4)), b=torch.ones((2, 5)), c=Batch(d=[[1], [2]])) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 6b10517..1240cfe 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -259,8 +259,7 @@ class Batch: v_ = None if not isinstance(v, np.ndarray) and \ all(isinstance(e, torch.Tensor) for e in v): - v_ = torch.stack(v) - self.__dict__[k] = v_ + self.__dict__[k] = torch.stack(v) continue else: v_ = np.asanyarray(v) @@ -294,7 +293,8 @@ class Batch: value = np.array(value) if not issubclass(value.dtype.type, (np.bool_, np.number)): value = value.astype(np.object) - elif isinstance(value, dict): + elif isinstance(value, dict) or isinstance(value, np.ndarray) \ + and value.dtype == np.object and _is_batch_set(value): value = Batch(value) self.__dict__[key] = value @@ -333,9 +333,8 @@ class Batch: else: raise IndexError("Cannot access item from empty Batch object.") - def __setitem__( - self, - index: Union[str, slice, int, np.integer, np.ndarray, List[int]], + def __setitem__(self, index: Union[ + str, slice, int, np.integer, np.ndarray, List[int]], value: Any) -> None: """Assign value to self[index].""" if isinstance(value, np.ndarray): @@ -454,10 +453,8 @@ class Batch: elif isinstance(v, Batch): v.to_numpy() - def to_torch(self, - dtype: Optional[torch.dtype] = None, - device: Union[str, int, torch.device] = 'cpu' - ) -> None: + def to_torch(self, dtype: Optional[torch.dtype] = None, + device: Union[str, int, torch.device] = 'cpu') -> None: """Change all numpy.ndarray to torch.Tensor. This is an in-place operation. """ @@ -473,66 +470,111 @@ class Batch: v = v.type(dtype) self.__dict__[k] = v elif isinstance(v, torch.Tensor): - if dtype is not None and v.dtype != dtype: - must_update_tensor = True - elif v.device.type != device.type: - must_update_tensor = True - elif device.index is not None and \ + if dtype is not None and v.dtype != dtype or \ + v.device.type != device.type or \ + device.index is not None and \ device.index != v.device.index: - must_update_tensor = True - else: - must_update_tensor = False - if must_update_tensor: if dtype is not None: v = v.type(dtype) self.__dict__[k] = v.to(device) elif isinstance(v, Batch): v.to_torch(dtype, device) - def append(self, batch: 'Batch') -> None: - warnings.warn('Method :meth:`~tianshou.data.Batch.append` will be ' - 'removed soon, please use ' - ':meth:`~tianshou.data.Batch.cat`') - return self.cat_(batch) - - def cat_(self, batch: 'Batch') -> None: - """Concatenate a :class:`~tianshou.data.Batch` object into current - batch. + def cat_(self, + batches: Union['Batch', List[Union[dict, 'Batch']]]) -> None: + """Concatenate a list of (or one) :class:`~tianshou.data.Batch` objects + into current batch. """ - assert isinstance(batch, Batch), \ - 'Only Batch is allowed to be concatenated in-place!' - for k, v in batch.items(): - if v is None: - continue - if not hasattr(self, k) or self.__dict__[k] is None: - self.__dict__[k] = deepcopy(v) - elif isinstance(v, np.ndarray) and v.ndim > 0: - self.__dict__[k] = np.concatenate([self.__dict__[k], v]) - elif isinstance(v, torch.Tensor): - self.__dict__[k] = torch.cat([self.__dict__[k], v]) - elif isinstance(v, Batch): - self.__dict__[k].cat_(v) + if isinstance(batches, Batch): + batches = [batches] + if len(batches) == 0: + return + batches = [x if isinstance(x, Batch) else Batch(x) for x in batches] + if len(self.__dict__) > 0: + batches = [self] + list(batches) + # partial keys will be padded by zeros + # with the shape of [len, rest_shape] + lens = [len(x) for x in batches] + keys_map = list(map(lambda e: set(e.keys()), batches)) + keys_shared = set.intersection(*keys_map) + values_shared = [ + [e[k] for e in batches] for k in keys_shared] + _assert_type_keys(keys_shared) + for k, v in zip(keys_shared, values_shared): + if all(isinstance(e, (dict, Batch)) for e in v): + self.__dict__[k] = Batch.cat(v) + elif all(isinstance(e, torch.Tensor) for e in v): + self.__dict__[k] = torch.cat(v) else: - s = 'No support for method "cat" with type '\ - f'{type(v)} in class Batch.' - raise TypeError(s) + v = np.concatenate(v) + if not issubclass(v.dtype.type, (np.bool_, np.number)): + v = v.astype(np.object) + self.__dict__[k] = v + keys_partial = set.union(*keys_map) - keys_shared + _assert_type_keys(keys_partial) + for k in keys_partial: + is_dict = False + value = None + for i, e in enumerate(batches): + val = e.get(k, None) + if val is not None: + if isinstance(val, (dict, Batch)): + is_dict = True + else: # np.ndarray or torch.Tensor + value = val + break + if is_dict: + self.__dict__[k] = Batch.cat( + [e.get(k, Batch()) for e in batches]) + else: + if isinstance(value, np.ndarray): + arrs = [] + for i, e in enumerate(batches): + shape = [lens[i]] + list(value.shape[1:]) + pad = np.zeros(shape, dtype=value.dtype) + arrs.append(e.get(k, pad)) + self.__dict__[k] = np.concatenate(arrs) + elif isinstance(value, torch.Tensor): + arrs = [] + for i, e in enumerate(batches): + shape = [lens[i]] + list(value.shape[1:]) + pad = torch.zeros(shape, + dtype=value.dtype, + device=value.device) + arrs.append(e.get(k, pad)) + self.__dict__[k] = torch.cat(arrs) + else: + raise TypeError( + f"cannot cat value with type {type(value)}, we only " + "support dict, Batch, np.ndarray, and torch.Tensor") @staticmethod def cat(batches: List[Union[dict, 'Batch']]) -> 'Batch': - """Concatenate a list of :class:`~tianshou.data.Batch` object into a single - new batch. + """Concatenate a list of :class:`~tianshou.data.Batch` object into a + single new batch. For keys that are not shared across all batches, + batches that do not have these keys will be padded by zeros with + appropriate shapes. E.g. + :: + + >>> a = Batch(a=np.zeros([3, 4]), common=Batch(c=np.zeros([3, 5]))) + >>> b = Batch(b=np.zeros([4, 3]), common=Batch(c=np.zeros([4, 5]))) + >>> c = Batch.cat([a, b]) + >>> c.a.shape + (7, 4) + >>> c.b.shape + (7, 3) + >>> c.common.c.shape + (7, 5) """ batch = Batch() - for batch_ in batches: - if isinstance(batch_, dict): - batch_ = Batch(batch_) - batch.cat_(batch_) + batch.cat_(batches) return batch def stack_(self, batches: List[Union[dict, 'Batch']], axis: int = 0) -> None: - """Stack a :class:`~tianshou.data.Batch` object i into current batch. + """Stack a list of :class:`~tianshou.data.Batch` object into current + batch. """ if len(self.__dict__) > 0: batches = [self] + list(batches) @@ -566,8 +608,8 @@ class Batch: @staticmethod def stack(batches: List[Union[dict, 'Batch']], axis: int = 0) -> 'Batch': - """Stack a :class:`~tianshou.data.Batch` object into a single new - batch. + """Stack a list of :class:`~tianshou.data.Batch` object into a single + new batch. """ batch = Batch() batch.stack_(batches, axis) @@ -611,11 +653,24 @@ class Batch: """ return deepcopy(batch).empty_(index) + def update(self, batch: Optional[Union[dict, 'Batch']] = None, + **kwargs) -> None: + """Update this batch from another dict/Batch.""" + if batch is None: + self.update(kwargs) + return + if isinstance(batch, dict): + batch = Batch(batch) + for k, v in batch.items(): + self.__dict__[k] = v + if kwargs: + self.update(kwargs) + def __len__(self) -> int: """Return len(self).""" r = [] for v in self.__dict__.values(): - if isinstance(v, Batch) and len(v.__dict__) == 0: + if isinstance(v, Batch) and v.is_empty(): continue elif hasattr(v, '__len__') and (not isinstance( v, (np.ndarray, torch.Tensor)) or v.ndim > 0): @@ -627,7 +682,9 @@ class Batch: return min(r) def is_empty(self): - return len(self.__dict__.keys()) == 0 + return not any( + not x.is_empty() if isinstance(x, Batch) + else hasattr(x, '__len__') and len(x) > 0 for x in self.values()) @property def shape(self) -> List[int]: diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 33d3178..f593d2a 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -108,8 +108,7 @@ class ReplayBuffer: super().__init__() self._maxsize = size self._stack = stack_num - assert stack_num != 1, \ - 'stack_num should greater than 1' + assert stack_num != 1, 'stack_num should greater than 1' self._avail = sample_avail and stack_num > 1 self._avail_index = [] self._save_s_ = not ignore_obs_next @@ -136,12 +135,11 @@ class ReplayBuffer: except KeyError: self._meta.__dict__[name] = _create_value(inst, self._maxsize) value = self._meta.__dict__[name] - if isinstance(inst, np.ndarray) and \ - value.shape[1:] != inst.shape: + if isinstance(inst, np.ndarray) and value.shape[1:] != inst.shape: raise ValueError( "Cannot add data to a buffer with different shape, key: " - f"{name}, expect shape: {value.shape[1:]}" - f", given shape: {inst.shape}.") + f"{name}, expect shape: {value.shape[1:]}, " + f"given shape: {inst.shape}.") try: value[self._index] = inst except KeyError: @@ -357,7 +355,7 @@ class PrioritizedReplayBuffer(ReplayBuffer): self._weight_sum = 0.0 self._amortization_freq = 50 self._replace = replace - self._meta.__dict__['weight'] = np.zeros(size, dtype=np.float64) + self._meta.weight = np.zeros(size, dtype=np.float64) def add(self, obs: Union[dict, np.ndarray], @@ -372,7 +370,7 @@ class PrioritizedReplayBuffer(ReplayBuffer): """Add a batch of data into replay buffer.""" # we have to sacrifice some convenience for speed self._weight_sum += np.abs(weight) ** self._alpha - \ - self._meta.__dict__['weight'][self._index] + self._meta.weight[self._index] self._add_to_buffer('weight', np.abs(weight) ** self._alpha) super().add(obs, act, rew, done, obs_next, info, policy) @@ -410,14 +408,9 @@ class PrioritizedReplayBuffer(ReplayBuffer): f"batch_size should be less than {len(self)}, \ or set replace=True") batch = self[indice] - impt_weight = Batch( - impt_weight=(self._size * p) ** (-self._beta)) - batch.cat_(impt_weight) + batch["impt_weight"] = (self._size * p) ** (-self._beta) return batch, indice - def reset(self) -> None: - super().reset() - def update_weight(self, indice: Union[slice, np.ndarray], new_weight: np.ndarray) -> None: """Update priority weight by indice in this buffer.