Improve Batch (#128)
* minor polish * improve and implement Batch.cat_ * bugfix for buffer.sample with field impt_weight * restore the usage of a.cat_(b) * fix 2 bugs in batch and add corresponding unittest * code fix for update * update is_empty to recognize empty over empty; bugfix for len * bugfix for update and add testcase * add testcase of update * fix docs * fix docs * fix docs [ci skip] * fix docs [ci skip] Co-authored-by: Trinkle23897 <463003665@qq.com>
This commit is contained in:
		
							parent
							
								
									2564e989fb
								
							
						
					
					
						commit
						affeec13de
					
				
							
								
								
									
										1
									
								
								.github/workflows/pytest.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.github/workflows/pytest.yml
									
									
									
									
										vendored
									
									
								
							| @ -5,6 +5,7 @@ on: [push, pull_request] | ||||
| jobs: | ||||
|   build: | ||||
|     runs-on: ubuntu-latest | ||||
|     if: "!contains(github.event.head_commit.message, 'ci skip')" | ||||
|     strategy: | ||||
|       matrix: | ||||
|         python-version: [3.6, 3.7, 3.8] | ||||
|  | ||||
| @ -10,7 +10,17 @@ from tianshou.data import Batch, to_torch | ||||
| def test_batch(): | ||||
|     assert list(Batch()) == [] | ||||
|     assert Batch().is_empty() | ||||
|     assert Batch(b={'c': {}}).is_empty() | ||||
|     assert len(Batch(a=[1, 2, 3], b={'c': {}})) == 3 | ||||
|     assert not Batch(a=[1, 2, 3]).is_empty() | ||||
|     b = Batch() | ||||
|     b.update() | ||||
|     assert b.is_empty() | ||||
|     b.update(c=[3, 5]) | ||||
|     assert np.allclose(b.c, [3, 5]) | ||||
|     # mimic the behavior of dict.update, where kwargs can overwrite keys | ||||
|     b.update({'a': 2}, a=3) | ||||
|     assert b.a == 3 | ||||
|     with pytest.raises(AssertionError): | ||||
|         Batch({1: 2}) | ||||
|     batch = Batch(a=[torch.ones(3), torch.ones(3)]) | ||||
| @ -86,6 +96,18 @@ def test_batch(): | ||||
|     assert batch3.a.d.f[0] == 5.0 | ||||
|     with pytest.raises(KeyError): | ||||
|         batch3.a.d[0] = Batch(f=5.0, g=0.0) | ||||
|     # auto convert | ||||
|     batch4 = Batch(a=np.array(['a', 'b'])) | ||||
|     assert batch4.a.dtype == np.object  # auto convert to np.object | ||||
|     batch4.update(a=np.array(['c', 'd'])) | ||||
|     assert list(batch4.a) == ['c', 'd'] | ||||
|     assert batch4.a.dtype == np.object  # auto convert to np.object | ||||
|     batch5 = Batch(a=np.array([{'index': 0}])) | ||||
|     assert isinstance(batch5.a, Batch) | ||||
|     assert np.allclose(batch5.a.index, [0]) | ||||
|     batch5.b = np.array([{'index': 1}]) | ||||
|     assert isinstance(batch5.b, Batch) | ||||
|     assert np.allclose(batch5.b.index, [1]) | ||||
| 
 | ||||
| 
 | ||||
| def test_batch_over_batch(): | ||||
| @ -100,6 +122,11 @@ def test_batch_over_batch(): | ||||
|     assert np.allclose(batch2.c, [6, 7, 8, 6, 7, 8]) | ||||
|     assert np.allclose(batch2.b.a, [3, 4, 5, 3, 4, 5]) | ||||
|     assert np.allclose(batch2.b.b, [4, 5, 0, 4, 5, 0]) | ||||
|     batch2.update(batch2.b, six=[6, 6, 6]) | ||||
|     assert np.allclose(batch2.c, [6, 7, 8, 6, 7, 8]) | ||||
|     assert np.allclose(batch2.a, [3, 4, 5, 3, 4, 5]) | ||||
|     assert np.allclose(batch2.b, [4, 5, 0, 4, 5, 0]) | ||||
|     assert np.allclose(batch2.six, [6, 6, 6]) | ||||
|     d = {'a': [3, 4, 5], 'b': [4, 5, 6]} | ||||
|     batch3 = Batch(c=[6, 7, 8], b=d) | ||||
|     batch3.cat_(Batch(c=[6, 7, 8], b=d)) | ||||
| @ -124,18 +151,32 @@ def test_batch_over_batch(): | ||||
| 
 | ||||
| 
 | ||||
| def test_batch_cat_and_stack(): | ||||
|     # test cat with compatible keys | ||||
|     b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}]) | ||||
|     b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}]) | ||||
|     b12_cat_out = Batch.cat((b1, b2)) | ||||
|     b12_cat_out = Batch.cat([b1, b2]) | ||||
|     b12_cat_in = copy.deepcopy(b1) | ||||
|     b12_cat_in.cat_(b2) | ||||
|     assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e) | ||||
|     assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e) | ||||
|     assert isinstance(b12_cat_in.a.d.e, np.ndarray) | ||||
|     assert b12_cat_in.a.d.e.ndim == 1 | ||||
| 
 | ||||
|     b12_stack = Batch.stack((b1, b2)) | ||||
|     assert isinstance(b12_stack.a.d.e, np.ndarray) | ||||
|     assert b12_stack.a.d.e.ndim == 2 | ||||
| 
 | ||||
|     # test batch with incompatible keys | ||||
|     b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5))) | ||||
|     b2 = Batch(b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5))) | ||||
|     test = Batch.cat([b1, b2]) | ||||
|     ans = Batch(a=np.concatenate([b1.a, np.zeros((4, 4))]), | ||||
|                 b=torch.cat([torch.zeros(3, 3), b2.b]), | ||||
|                 common=Batch(c=np.concatenate([b1.common.c, b2.common.c]))) | ||||
|     assert np.allclose(test.a, ans.a) | ||||
|     assert torch.allclose(test.b, ans.b) | ||||
|     assert np.allclose(test.common.c, ans.common.c) | ||||
| 
 | ||||
|     b3 = Batch(a=np.zeros((3, 4)), | ||||
|                b=torch.ones((2, 5)), | ||||
|                c=Batch(d=[[1], [2]])) | ||||
|  | ||||
| @ -259,8 +259,7 @@ class Batch: | ||||
|                         v_ = None | ||||
|                         if not isinstance(v, np.ndarray) and \ | ||||
|                                 all(isinstance(e, torch.Tensor) for e in v): | ||||
|                             v_ = torch.stack(v) | ||||
|                             self.__dict__[k] = v_ | ||||
|                             self.__dict__[k] = torch.stack(v) | ||||
|                             continue | ||||
|                         else: | ||||
|                             v_ = np.asanyarray(v) | ||||
| @ -294,7 +293,8 @@ class Batch: | ||||
|                 value = np.array(value) | ||||
|                 if not issubclass(value.dtype.type, (np.bool_, np.number)): | ||||
|                     value = value.astype(np.object) | ||||
|         elif isinstance(value, dict): | ||||
|         elif isinstance(value, dict) or isinstance(value, np.ndarray) \ | ||||
|                 and value.dtype == np.object and _is_batch_set(value): | ||||
|             value = Batch(value) | ||||
|         self.__dict__[key] = value | ||||
| 
 | ||||
| @ -333,9 +333,8 @@ class Batch: | ||||
|         else: | ||||
|             raise IndexError("Cannot access item from empty Batch object.") | ||||
| 
 | ||||
|     def __setitem__( | ||||
|             self, | ||||
|             index: Union[str, slice, int, np.integer, np.ndarray, List[int]], | ||||
|     def __setitem__(self, index: Union[ | ||||
|             str, slice, int, np.integer, np.ndarray, List[int]], | ||||
|             value: Any) -> None: | ||||
|         """Assign value to self[index].""" | ||||
|         if isinstance(value, np.ndarray): | ||||
| @ -454,10 +453,8 @@ class Batch: | ||||
|             elif isinstance(v, Batch): | ||||
|                 v.to_numpy() | ||||
| 
 | ||||
|     def to_torch(self, | ||||
|                  dtype: Optional[torch.dtype] = None, | ||||
|                  device: Union[str, int, torch.device] = 'cpu' | ||||
|                  ) -> None: | ||||
|     def to_torch(self, dtype: Optional[torch.dtype] = None, | ||||
|                  device: Union[str, int, torch.device] = 'cpu') -> None: | ||||
|         """Change all numpy.ndarray to torch.Tensor. This is an in-place | ||||
|         operation. | ||||
|         """ | ||||
| @ -473,66 +470,111 @@ class Batch: | ||||
|                     v = v.type(dtype) | ||||
|                 self.__dict__[k] = v | ||||
|             elif isinstance(v, torch.Tensor): | ||||
|                 if dtype is not None and v.dtype != dtype: | ||||
|                     must_update_tensor = True | ||||
|                 elif v.device.type != device.type: | ||||
|                     must_update_tensor = True | ||||
|                 elif device.index is not None and \ | ||||
|                 if dtype is not None and v.dtype != dtype or \ | ||||
|                         v.device.type != device.type or \ | ||||
|                         device.index is not None and \ | ||||
|                         device.index != v.device.index: | ||||
|                     must_update_tensor = True | ||||
|                 else: | ||||
|                     must_update_tensor = False | ||||
|                 if must_update_tensor: | ||||
|                     if dtype is not None: | ||||
|                         v = v.type(dtype) | ||||
|                     self.__dict__[k] = v.to(device) | ||||
|             elif isinstance(v, Batch): | ||||
|                 v.to_torch(dtype, device) | ||||
| 
 | ||||
|     def append(self, batch: 'Batch') -> None: | ||||
|         warnings.warn('Method :meth:`~tianshou.data.Batch.append` will be ' | ||||
|                       'removed soon, please use ' | ||||
|                       ':meth:`~tianshou.data.Batch.cat`') | ||||
|         return self.cat_(batch) | ||||
| 
 | ||||
|     def cat_(self, batch: 'Batch') -> None: | ||||
|         """Concatenate a :class:`~tianshou.data.Batch` object into current | ||||
|         batch. | ||||
|     def cat_(self, | ||||
|              batches: Union['Batch', List[Union[dict, 'Batch']]]) -> None: | ||||
|         """Concatenate a list of (or one) :class:`~tianshou.data.Batch` objects | ||||
|         into current batch. | ||||
|         """ | ||||
|         assert isinstance(batch, Batch), \ | ||||
|             'Only Batch is allowed to be concatenated in-place!' | ||||
|         for k, v in batch.items(): | ||||
|             if v is None: | ||||
|                 continue | ||||
|             if not hasattr(self, k) or self.__dict__[k] is None: | ||||
|                 self.__dict__[k] = deepcopy(v) | ||||
|             elif isinstance(v, np.ndarray) and v.ndim > 0: | ||||
|                 self.__dict__[k] = np.concatenate([self.__dict__[k], v]) | ||||
|             elif isinstance(v, torch.Tensor): | ||||
|                 self.__dict__[k] = torch.cat([self.__dict__[k], v]) | ||||
|             elif isinstance(v, Batch): | ||||
|                 self.__dict__[k].cat_(v) | ||||
|         if isinstance(batches, Batch): | ||||
|             batches = [batches] | ||||
|         if len(batches) == 0: | ||||
|             return | ||||
|         batches = [x if isinstance(x, Batch) else Batch(x) for x in batches] | ||||
|         if len(self.__dict__) > 0: | ||||
|             batches = [self] + list(batches) | ||||
|         # partial keys will be padded by zeros | ||||
|         # with the shape of [len, rest_shape] | ||||
|         lens = [len(x) for x in batches] | ||||
|         keys_map = list(map(lambda e: set(e.keys()), batches)) | ||||
|         keys_shared = set.intersection(*keys_map) | ||||
|         values_shared = [ | ||||
|             [e[k] for e in batches] for k in keys_shared] | ||||
|         _assert_type_keys(keys_shared) | ||||
|         for k, v in zip(keys_shared, values_shared): | ||||
|             if all(isinstance(e, (dict, Batch)) for e in v): | ||||
|                 self.__dict__[k] = Batch.cat(v) | ||||
|             elif all(isinstance(e, torch.Tensor) for e in v): | ||||
|                 self.__dict__[k] = torch.cat(v) | ||||
|             else: | ||||
|                 s = 'No support for method "cat" with type '\ | ||||
|                     f'{type(v)} in class Batch.' | ||||
|                 raise TypeError(s) | ||||
|                 v = np.concatenate(v) | ||||
|                 if not issubclass(v.dtype.type, (np.bool_, np.number)): | ||||
|                     v = v.astype(np.object) | ||||
|                 self.__dict__[k] = v | ||||
|         keys_partial = set.union(*keys_map) - keys_shared | ||||
|         _assert_type_keys(keys_partial) | ||||
|         for k in keys_partial: | ||||
|             is_dict = False | ||||
|             value = None | ||||
|             for i, e in enumerate(batches): | ||||
|                 val = e.get(k, None) | ||||
|                 if val is not None: | ||||
|                     if isinstance(val, (dict, Batch)): | ||||
|                         is_dict = True | ||||
|                     else:  # np.ndarray or torch.Tensor | ||||
|                         value = val | ||||
|                     break | ||||
|             if is_dict: | ||||
|                 self.__dict__[k] = Batch.cat( | ||||
|                     [e.get(k, Batch()) for e in batches]) | ||||
|             else: | ||||
|                 if isinstance(value, np.ndarray): | ||||
|                     arrs = [] | ||||
|                     for i, e in enumerate(batches): | ||||
|                         shape = [lens[i]] + list(value.shape[1:]) | ||||
|                         pad = np.zeros(shape, dtype=value.dtype) | ||||
|                         arrs.append(e.get(k, pad)) | ||||
|                     self.__dict__[k] = np.concatenate(arrs) | ||||
|                 elif isinstance(value, torch.Tensor): | ||||
|                     arrs = [] | ||||
|                     for i, e in enumerate(batches): | ||||
|                         shape = [lens[i]] + list(value.shape[1:]) | ||||
|                         pad = torch.zeros(shape, | ||||
|                                           dtype=value.dtype, | ||||
|                                           device=value.device) | ||||
|                         arrs.append(e.get(k, pad)) | ||||
|                     self.__dict__[k] = torch.cat(arrs) | ||||
|                 else: | ||||
|                     raise TypeError( | ||||
|                         f"cannot cat value with type {type(value)}, we only " | ||||
|                         "support dict, Batch, np.ndarray, and torch.Tensor") | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def cat(batches: List[Union[dict, 'Batch']]) -> 'Batch': | ||||
|         """Concatenate a list of :class:`~tianshou.data.Batch` object into a single | ||||
|         new batch. | ||||
|         """Concatenate a list of :class:`~tianshou.data.Batch` object into a | ||||
|         single new batch. For keys that are not shared across all batches, | ||||
|         batches that do not have these keys will be padded by zeros with | ||||
|         appropriate shapes. E.g. | ||||
|         :: | ||||
| 
 | ||||
|             >>> a = Batch(a=np.zeros([3, 4]), common=Batch(c=np.zeros([3, 5]))) | ||||
|             >>> b = Batch(b=np.zeros([4, 3]), common=Batch(c=np.zeros([4, 5]))) | ||||
|             >>> c = Batch.cat([a, b]) | ||||
|             >>> c.a.shape | ||||
|             (7, 4) | ||||
|             >>> c.b.shape | ||||
|             (7, 3) | ||||
|             >>> c.common.c.shape | ||||
|             (7, 5) | ||||
|         """ | ||||
|         batch = Batch() | ||||
|         for batch_ in batches: | ||||
|             if isinstance(batch_, dict): | ||||
|                 batch_ = Batch(batch_) | ||||
|             batch.cat_(batch_) | ||||
|         batch.cat_(batches) | ||||
|         return batch | ||||
| 
 | ||||
|     def stack_(self, | ||||
|                batches: List[Union[dict, 'Batch']], | ||||
|                axis: int = 0) -> None: | ||||
|         """Stack a :class:`~tianshou.data.Batch` object i into current batch. | ||||
|         """Stack a list of :class:`~tianshou.data.Batch` object into current | ||||
|         batch. | ||||
|         """ | ||||
|         if len(self.__dict__) > 0: | ||||
|             batches = [self] + list(batches) | ||||
| @ -566,8 +608,8 @@ class Batch: | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def stack(batches: List[Union[dict, 'Batch']], axis: int = 0) -> 'Batch': | ||||
|         """Stack a :class:`~tianshou.data.Batch` object into a single new | ||||
|         batch. | ||||
|         """Stack a list of :class:`~tianshou.data.Batch` object into a single | ||||
|         new batch. | ||||
|         """ | ||||
|         batch = Batch() | ||||
|         batch.stack_(batches, axis) | ||||
| @ -611,11 +653,24 @@ class Batch: | ||||
|         """ | ||||
|         return deepcopy(batch).empty_(index) | ||||
| 
 | ||||
|     def update(self, batch: Optional[Union[dict, 'Batch']] = None, | ||||
|                **kwargs) -> None: | ||||
|         """Update this batch from another dict/Batch.""" | ||||
|         if batch is None: | ||||
|             self.update(kwargs) | ||||
|             return | ||||
|         if isinstance(batch, dict): | ||||
|             batch = Batch(batch) | ||||
|         for k, v in batch.items(): | ||||
|             self.__dict__[k] = v | ||||
|         if kwargs: | ||||
|             self.update(kwargs) | ||||
| 
 | ||||
|     def __len__(self) -> int: | ||||
|         """Return len(self).""" | ||||
|         r = [] | ||||
|         for v in self.__dict__.values(): | ||||
|             if isinstance(v, Batch) and len(v.__dict__) == 0: | ||||
|             if isinstance(v, Batch) and v.is_empty(): | ||||
|                 continue | ||||
|             elif hasattr(v, '__len__') and (not isinstance( | ||||
|                     v, (np.ndarray, torch.Tensor)) or v.ndim > 0): | ||||
| @ -627,7 +682,9 @@ class Batch: | ||||
|         return min(r) | ||||
| 
 | ||||
|     def is_empty(self): | ||||
|         return len(self.__dict__.keys()) == 0 | ||||
|         return not any( | ||||
|             not x.is_empty() if isinstance(x, Batch) | ||||
|             else hasattr(x, '__len__') and len(x) > 0 for x in self.values()) | ||||
| 
 | ||||
|     @property | ||||
|     def shape(self) -> List[int]: | ||||
|  | ||||
| @ -108,8 +108,7 @@ class ReplayBuffer: | ||||
|         super().__init__() | ||||
|         self._maxsize = size | ||||
|         self._stack = stack_num | ||||
|         assert stack_num != 1, \ | ||||
|             'stack_num should greater than 1' | ||||
|         assert stack_num != 1, 'stack_num should greater than 1' | ||||
|         self._avail = sample_avail and stack_num > 1 | ||||
|         self._avail_index = [] | ||||
|         self._save_s_ = not ignore_obs_next | ||||
| @ -136,12 +135,11 @@ class ReplayBuffer: | ||||
|         except KeyError: | ||||
|             self._meta.__dict__[name] = _create_value(inst, self._maxsize) | ||||
|             value = self._meta.__dict__[name] | ||||
|         if isinstance(inst, np.ndarray) and \ | ||||
|                 value.shape[1:] != inst.shape: | ||||
|         if isinstance(inst, np.ndarray) and value.shape[1:] != inst.shape: | ||||
|             raise ValueError( | ||||
|                 "Cannot add data to a buffer with different shape, key: " | ||||
|                 f"{name}, expect shape: {value.shape[1:]}" | ||||
|                 f", given shape: {inst.shape}.") | ||||
|                 f"{name}, expect shape: {value.shape[1:]}, " | ||||
|                 f"given shape: {inst.shape}.") | ||||
|         try: | ||||
|             value[self._index] = inst | ||||
|         except KeyError: | ||||
| @ -357,7 +355,7 @@ class PrioritizedReplayBuffer(ReplayBuffer): | ||||
|         self._weight_sum = 0.0 | ||||
|         self._amortization_freq = 50 | ||||
|         self._replace = replace | ||||
|         self._meta.__dict__['weight'] = np.zeros(size, dtype=np.float64) | ||||
|         self._meta.weight = np.zeros(size, dtype=np.float64) | ||||
| 
 | ||||
|     def add(self, | ||||
|             obs: Union[dict, np.ndarray], | ||||
| @ -372,7 +370,7 @@ class PrioritizedReplayBuffer(ReplayBuffer): | ||||
|         """Add a batch of data into replay buffer.""" | ||||
|         # we have to sacrifice some convenience for speed | ||||
|         self._weight_sum += np.abs(weight) ** self._alpha - \ | ||||
|             self._meta.__dict__['weight'][self._index] | ||||
|             self._meta.weight[self._index] | ||||
|         self._add_to_buffer('weight', np.abs(weight) ** self._alpha) | ||||
|         super().add(obs, act, rew, done, obs_next, info, policy) | ||||
| 
 | ||||
| @ -410,14 +408,9 @@ class PrioritizedReplayBuffer(ReplayBuffer): | ||||
|                 f"batch_size should be less than {len(self)}, \ | ||||
|                     or set replace=True") | ||||
|         batch = self[indice] | ||||
|         impt_weight = Batch( | ||||
|             impt_weight=(self._size * p) ** (-self._beta)) | ||||
|         batch.cat_(impt_weight) | ||||
|         batch["impt_weight"] = (self._size * p) ** (-self._beta) | ||||
|         return batch, indice | ||||
| 
 | ||||
|     def reset(self) -> None: | ||||
|         super().reset() | ||||
| 
 | ||||
|     def update_weight(self, indice: Union[slice, np.ndarray], | ||||
|                       new_weight: np.ndarray) -> None: | ||||
|         """Update priority weight by indice in this buffer. | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user