bugfix for update with empty buffer; remove duplicate variable _weight_sum in PrioritizedReplayBuffer (#120)
* bugfix for update with empty buffer; remove duplicate variable _weight_sum in PrioritizedReplayBuffer
* point out that ListReplayBuffer cannot be sampled
* remove useless _amortization_counter variable
This commit is contained in:
		
							parent
							
								
									e767de044b
								
							
						
					
					
						commit
						ff99662fe6
					
				| @ -92,20 +92,16 @@ def test_priortized_replaybuffer(size=32, bufsize=15): | ||||
|         obs_next, rew, done, info = env.step(a) | ||||
|         buf.add(obs, a, rew, done, obs_next, info, np.random.randn() - 0.5) | ||||
|         obs = obs_next | ||||
|         assert np.isclose(np.sum((buf.weight / buf._weight_sum)[:buf._size]), | ||||
|                           1, rtol=1e-12) | ||||
|         data, indice = buf.sample(len(buf) // 2) | ||||
|         if len(buf) // 2 == 0: | ||||
|             assert len(data) == len(buf) | ||||
|         else: | ||||
|             assert len(data) == len(buf) // 2 | ||||
|         assert len(buf) == min(bufsize, i + 1) | ||||
|         assert np.isclose(buf._weight_sum, (buf.weight).sum()) | ||||
|     data, indice = buf.sample(len(buf) // 2) | ||||
|     buf.update_weight(indice, -data.weight / 2) | ||||
|     assert np.isclose(buf.weight[indice], np.power( | ||||
|         np.abs(-data.weight / 2), buf._alpha)).all() | ||||
|     assert np.isclose(buf._weight_sum, (buf.weight).sum()) | ||||
|     assert np.allclose( | ||||
|         buf.weight[indice], np.abs(-data.weight / 2) ** buf._alpha) | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
|  | ||||
| @ -151,6 +151,8 @@ class ReplayBuffer: | ||||
| 
 | ||||
|     def update(self, buffer: 'ReplayBuffer') -> None: | ||||
|         """Move the data from the given buffer to self.""" | ||||
|         if len(buffer) == 0: | ||||
|             return | ||||
|         i = begin = buffer._index % len(buffer) | ||||
|         while True: | ||||
|             self.add(**buffer[i]) | ||||
| @ -298,7 +300,9 @@ class ReplayBuffer: | ||||
| class ListReplayBuffer(ReplayBuffer): | ||||
|     """The function of :class:`~tianshou.data.ListReplayBuffer` is almost the | ||||
|     same as :class:`~tianshou.data.ReplayBuffer`. The only difference is that | ||||
|     :class:`~tianshou.data.ListReplayBuffer` is based on ``list``. | ||||
|     :class:`~tianshou.data.ListReplayBuffer` is based on ``list``. Therefore, | ||||
|     it does not support advanced indexing, which means you cannot sample a | ||||
|     batch of data out of it. It is typically used for storing data. | ||||
| 
 | ||||
|     .. seealso:: | ||||
| 
 | ||||
| @ -309,6 +313,9 @@ class ListReplayBuffer(ReplayBuffer): | ||||
|     def __init__(self, **kwargs) -> None: | ||||
|         super().__init__(size=0, ignore_obs_next=False, **kwargs) | ||||
| 
 | ||||
|     def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: | ||||
|         raise NotImplementedError("ListReplayBuffer cannot be sampled!") | ||||
| 
 | ||||
|     def _add_to_buffer( | ||||
|             self, name: str, | ||||
|             inst: Union[dict, Batch, np.ndarray, float, int, bool]) -> None: | ||||
| @ -349,7 +356,6 @@ class PrioritizedReplayBuffer(ReplayBuffer): | ||||
|         self._beta = beta | ||||
|         self._weight_sum = 0.0 | ||||
|         self._amortization_freq = 50 | ||||
|         self._amortization_counter = 0 | ||||
|         self._replace = replace | ||||
|         self._meta.__dict__['weight'] = np.zeros(size, dtype=np.float64) | ||||
| 
 | ||||
| @ -369,7 +375,6 @@ class PrioritizedReplayBuffer(ReplayBuffer): | ||||
|             self._meta.__dict__['weight'][self._index] | ||||
|         self._add_to_buffer('weight', np.abs(weight) ** self._alpha) | ||||
|         super().add(obs, act, rew, done, obs_next, info, policy) | ||||
|         self._check_weight_sum() | ||||
| 
 | ||||
|     @property | ||||
|     def replace(self): | ||||
| @ -379,46 +384,38 @@ class PrioritizedReplayBuffer(ReplayBuffer): | ||||
|     def replace(self, v: bool): | ||||
|         self._replace = v | ||||
| 
 | ||||
|     def sample(self, batch_size: int, | ||||
|                importance_sample: bool = True | ||||
|                ) -> Tuple[Batch, np.ndarray]: | ||||
|     def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: | ||||
|         """Get a random sample from buffer with priority probability. \ | ||||
|         Return all the data in the buffer if batch_size is ``0``. | ||||
| 
 | ||||
|         :return: Sample data and its corresponding index inside the buffer. | ||||
|         """ | ||||
|         if batch_size > 0 and batch_size <= self._size: | ||||
|             # Multiple sampling of the same sample | ||||
|             # will cause weight update conflict | ||||
|         assert self._size > 0, 'cannot sample a buffer with size == 0 !' | ||||
|         p = None | ||||
|         if batch_size > 0 and (self._replace or batch_size <= self._size): | ||||
|             # sampling weight | ||||
|             p = (self.weight / self.weight.sum())[:self._size] | ||||
|             indice = np.random.choice( | ||||
|                 self._size, batch_size, | ||||
|                 p=(self.weight / self.weight.sum())[:self._size], | ||||
|                 self._size, batch_size, p=p, | ||||
|                 replace=self._replace) | ||||
|             # self._weight_sum is not work for the accuracy issue | ||||
|             # p=(self.weight/self._weight_sum)[:self._size], replace=False) | ||||
|             p = p[indice]  # weight of each sample | ||||
|         elif batch_size == 0: | ||||
|             p = np.full(shape=self._size, fill_value=1.0/self._size) | ||||
|             indice = np.concatenate([ | ||||
|                 np.arange(self._index, self._size), | ||||
|                 np.arange(0, self._index), | ||||
|             ]) | ||||
|         else: | ||||
|             # if batch_size larger than len(self), | ||||
|             # it will lead to a bug in update weight | ||||
|             raise ValueError( | ||||
|                 "batch_size should be less than len(self), \ | ||||
|                     or set replace=False") | ||||
|                 f"batch_size should be less than {len(self)}, \ | ||||
|                     or set replace=True") | ||||
|         batch = self[indice] | ||||
|         if importance_sample: | ||||
|             impt_weight = Batch( | ||||
|                 impt_weight=1 / np.power( | ||||
|                     self._size * (batch.weight / self._weight_sum), | ||||
|                     self._beta)) | ||||
|             batch.cat_(impt_weight) | ||||
|         self._check_weight_sum() | ||||
|         impt_weight = Batch( | ||||
|             impt_weight=(self._size * p) ** (-self._beta)) | ||||
|         batch.cat_(impt_weight) | ||||
|         return batch, indice | ||||
| 
 | ||||
|     def reset(self) -> None: | ||||
|         self._amortization_counter = 0 | ||||
|         super().reset() | ||||
| 
 | ||||
|     def update_weight(self, indice: Union[slice, np.ndarray], | ||||
| @ -436,8 +433,6 @@ class PrioritizedReplayBuffer(ReplayBuffer): | ||||
|             indice, unique_indice = np.unique( | ||||
|                 indice, return_index=True) | ||||
|             new_weight = new_weight[unique_indice] | ||||
|         self._weight_sum += np.power(np.abs(new_weight), self._alpha).sum() \ | ||||
|             - self.weight[indice].sum() | ||||
|         self.weight[indice] = np.power(np.abs(new_weight), self._alpha) | ||||
| 
 | ||||
|     def __getitem__(self, index: Union[ | ||||
| @ -452,10 +447,3 @@ class PrioritizedReplayBuffer(ReplayBuffer): | ||||
|             weight=self.weight[index], | ||||
|             policy=self.get(index, 'policy'), | ||||
|         ) | ||||
| 
 | ||||
|     def _check_weight_sum(self) -> None: | ||||
|         # keep an accurate _weight_sum | ||||
|         self._amortization_counter += 1 | ||||
|         if self._amortization_counter % self._amortization_freq == 0: | ||||
|             self._weight_sum = np.sum(self.weight) | ||||
|             self._amortization_counter = 0 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user