add test_buffer
This commit is contained in:
		
							parent
							
								
									04557fdb82
								
							
						
					
					
						commit
						6632e47b9d
					
				
							
								
								
									
										5
									
								
								.github/workflows/pytest.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										5
									
								
								.github/workflows/pytest.yml
									
									
									
									
										vendored
									
									
								
							| @ -30,7 +30,10 @@ jobs: | ||||
|     - name: Lint with flake8 | ||||
|       run: | | ||||
|         pip install flake8 | ||||
|         ./flake8.sh | ||||
|         # stop the build if there are Python syntax errors or undefined names | ||||
|         flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics | ||||
|         # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide | ||||
|         flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics | ||||
|     - name: Test with pytest | ||||
|       run: | | ||||
|         pip install pytest pytest-cov | ||||
|  | ||||
							
								
								
									
										1
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										1
									
								
								setup.py
									
									
									
									
									
								
							| @ -32,6 +32,7 @@ setup( | ||||
|         # that you indicate whether you support Python 2, Python 3 or both. | ||||
|         'Programming Language :: Python :: 3.6', | ||||
|         'Programming Language :: Python :: 3.7', | ||||
|         'Programming Language :: Python :: 3.8', | ||||
|     ], | ||||
|     keywords='reinforcement learning platform', | ||||
|     # You can just specify the packages manually here if your project is | ||||
|  | ||||
							
								
								
									
										0
									
								
								test/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								test/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										19
									
								
								test/test_buffer.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								test/test_buffer.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,19 @@ | ||||
| from tianshou.data import ReplayBuffer | ||||
| from test.test_env import MyTestEnv | ||||
| 
 | ||||
| 
 | ||||
| def test_replaybuffer(bufsize=20): | ||||
|     env = MyTestEnv(10) | ||||
|     buf = ReplayBuffer(bufsize) | ||||
|     obs = env.reset() | ||||
|     action_list = [1] * 5 + [0] * 10 + [1] * 9 | ||||
|     for i, a in enumerate(action_list): | ||||
|         obs_next, rew, done, info = env.step(a) | ||||
|         buf.add(obs, a, rew, done, obs_next, info) | ||||
|         assert len(buf) == min(bufsize, i + 1), print(len(buf), i) | ||||
|     indice = buf.sample_indice(4) | ||||
|     data = buf.sample(4) | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
|     test_replaybuffer() | ||||
| @ -80,6 +80,8 @@ def test_vecenv(verbose=False, size=10, num=8, sleep=0.001): | ||||
|                 e.step([a] * num) | ||||
|             t[i] = time.time() - t[i] | ||||
|         print(f'VectorEnv: {t[0]:.6f}s\nSubprocVectorEnv: {t[1]:.6f}s\nRayVectorEnv: {t[2]:.6f}s') | ||||
|     for v in venv: | ||||
|         v.close() | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
|  | ||||
| @ -1,6 +1,5 @@ | ||||
| class Batch(object): | ||||
|     """Suggested keys: [obs, act, rew, done, obs_next, info]""" | ||||
| 
 | ||||
|     def __init__(self, **kwargs): | ||||
|         super().__init__() | ||||
|         self.__dict__.update(kwargs) | ||||
|  | ||||
| @ -7,7 +7,7 @@ class ReplayBuffer(object): | ||||
|     def __init__(self, size): | ||||
|         super().__init__() | ||||
|         self._maxsize = size | ||||
|         self._index = self._size = 0 | ||||
|         self.reset() | ||||
| 
 | ||||
|     def __len__(self): | ||||
|         return self._size | ||||
| @ -45,7 +45,7 @@ class ReplayBuffer(object): | ||||
|         return np.random.choice(self._size, batch_size) | ||||
| 
 | ||||
|     def sample(self, batch_size): | ||||
|         indice = self.sample_index(batch_size) | ||||
|         indice = self.sample_indice(batch_size) | ||||
|         return Batch( | ||||
|             obs=self.obs[indice], | ||||
|             act=self.act[indice], | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user