buffer update bug fix (#154)
* buffer update bug fix * some fix in buffer update * polish Co-authored-by: n+e <463003665@qq.com>
This commit is contained in:
		
							parent
							
								
									fe5555d2a1
								
							
						
					
					
						commit
						d09b69e594
					
				| @ -1,5 +1,6 @@ | ||||
| import numpy as np | ||||
| from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer | ||||
| 
 | ||||
| from tianshou.data import Batch, PrioritizedReplayBuffer, ReplayBuffer | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
|     from env import MyTestEnv | ||||
| @ -10,7 +11,6 @@ else:  # pytest | ||||
| def test_replaybuffer(size=10, bufsize=20): | ||||
|     env = MyTestEnv(size) | ||||
|     buf = ReplayBuffer(bufsize) | ||||
|     buf2 = ReplayBuffer(bufsize) | ||||
|     obs = env.reset() | ||||
|     action_list = [1] * 5 + [0] * 10 + [1] * 10 | ||||
|     for i, a in enumerate(action_list): | ||||
| @ -22,11 +22,6 @@ def test_replaybuffer(size=10, bufsize=20): | ||||
|     assert (indice < len(buf)).all() | ||||
|     assert (data.obs < size).all() | ||||
|     assert (0 <= data.done).all() and (data.done <= 1).all() | ||||
|     assert len(buf) > len(buf2) | ||||
|     buf2.update(buf) | ||||
|     assert len(buf) == len(buf2) | ||||
|     assert buf2[0].obs == buf[5].obs | ||||
|     assert buf2[-1].obs == buf[4].obs | ||||
|     b = ReplayBuffer(size=10) | ||||
|     b.add(1, 1, 1, 'str', 1, {'a': 3, 'b': {'c': 5.0}}) | ||||
|     assert b.obs[0] == 1 | ||||
| @ -104,8 +99,22 @@ def test_priortized_replaybuffer(size=32, bufsize=15): | ||||
|         buf.weight[indice], np.abs(-data.weight / 2) ** buf._alpha) | ||||
| 
 | ||||
| 
 | ||||
| def test_update(): | ||||
|     buf1 = ReplayBuffer(4, stack_num=2) | ||||
|     buf2 = ReplayBuffer(4, stack_num=2) | ||||
|     for i in range(5): | ||||
|         buf1.add(obs=np.array([i]), act=float(i), rew=i * i, | ||||
|                  done=False, info={'incident': 'found'}) | ||||
|     assert len(buf1) > len(buf2) | ||||
|     buf2.update(buf1) | ||||
|     assert len(buf1) == len(buf2) | ||||
|     assert (buf2[0].obs == buf1[1].obs).all() | ||||
|     assert (buf2[-1].obs == buf1[0].obs).all() | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
|     test_replaybuffer() | ||||
|     test_ignore_obs_next() | ||||
|     test_stack() | ||||
|     test_priortized_replaybuffer(233333, 200000) | ||||
|     test_update() | ||||
|  | ||||
| @ -157,16 +157,25 @@ class ReplayBuffer: | ||||
|                 value.__dict__[key] = _create_value(inst[key], self._maxsize) | ||||
|             value[self._index] = inst | ||||
| 
 | ||||
|     def _get_stack_num(self): | ||||
|         return self._stack | ||||
| 
 | ||||
|     def _set_stack_num(self, num): | ||||
|         self._stack = num | ||||
| 
 | ||||
|     def update(self, buffer: 'ReplayBuffer') -> None: | ||||
|         """Move the data from the given buffer to self.""" | ||||
|         if len(buffer) == 0: | ||||
|             return | ||||
|         i = begin = buffer._index % len(buffer) | ||||
|         origin = buffer._get_stack_num() | ||||
|         buffer._set_stack_num(0) | ||||
|         while True: | ||||
|             self.add(**buffer[i]) | ||||
|             i = (i + 1) % len(buffer) | ||||
|             if i == begin: | ||||
|                 break | ||||
|         buffer._set_stack_num(origin) | ||||
| 
 | ||||
|     def add(self, | ||||
|             obs: Union[dict, Batch, np.ndarray], | ||||
| @ -408,7 +417,7 @@ class PrioritizedReplayBuffer(ReplayBuffer): | ||||
|                 replace=self._replace) | ||||
|             p = p[indice]  # weight of each sample | ||||
|         elif batch_size == 0: | ||||
|             p = np.full(shape=self._size, fill_value=1.0/self._size) | ||||
|             p = np.full(shape=self._size, fill_value=1.0 / self._size) | ||||
|             indice = np.concatenate([ | ||||
|                 np.arange(self._index, self._size), | ||||
|                 np.arange(0, self._index), | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user