add first test
This commit is contained in:
		
							parent
							
								
									5550aed0a1
								
							
						
					
					
						commit
						7533e5b0ac
					
				
							
								
								
									
										2
									
								
								.github/workflows/pytest.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/pytest.yml
									
									
									
									
										vendored
									
									
								
							| @ -30,7 +30,7 @@ jobs: | ||||
|     - name: Lint with flake8 | ||||
|       run: | | ||||
|         pip install flake8 | ||||
|         ./flake_check.sh | ||||
|         ./flake8.sh | ||||
|     - name: Test with pytest | ||||
|       run: | | ||||
|         pip install pytest | ||||
|  | ||||
							
								
								
									
										57
									
								
								test/test_env.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										57
									
								
								test/test_env.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,57 @@ | ||||
| import gym | ||||
| import time | ||||
| import numpy as np | ||||
| from tianshou.env import FrameStack, VectorEnv, SubprocVectorEnv, RayVectorEnv | ||||
| 
 | ||||
| 
 | ||||
| class MyTestEnv(gym.Env): | ||||
|     def __init__(self, size, sleep=0): | ||||
|         self.size = size | ||||
|         self.sleep = sleep | ||||
|         self.index = 0 | ||||
| 
 | ||||
|     def reset(self): | ||||
|         self.index = 0 | ||||
|         return self.index | ||||
| 
 | ||||
|     def step(self, action): | ||||
|         if self.sleep > 0: | ||||
|             time.sleep(self.sleep) | ||||
|         if self.index == self.size: | ||||
|             return self.index, 0, True, {} | ||||
|         if action == 0: | ||||
|             self.index = max(self.index - 1, 0) | ||||
|             return self.index, 0, False, {} | ||||
|         elif action == 1: | ||||
|             self.index += 1 | ||||
|             finished = self.index == self.size | ||||
|             return self.index, int(finished), finished, {} | ||||
| 
 | ||||
| def test_framestack(): | ||||
|     k = 4 | ||||
|     size = 10 | ||||
|     env = MyTestEnv(size=size) | ||||
|     fsenv = FrameStack(env, k) | ||||
|     fsenv.seed() | ||||
|     obs = fsenv.reset() | ||||
|     assert abs(obs - np.array([0, 0, 0, 0])).sum() == 0 | ||||
|     for i in range(5): | ||||
|         obs, rew, done, info = fsenv.step(1) | ||||
|     assert abs(obs - np.array([2, 3, 4, 5])).sum() == 0 | ||||
|     for i in range(10): | ||||
|         obs, rew, done, info = fsenv.step(0) | ||||
|     assert abs(obs - np.array([0, 0, 0, 0])).sum() == 0 | ||||
|     for i in range(9): | ||||
|         obs, rew, done, info = fsenv.step(1) | ||||
|     assert abs(obs - np.array([6, 7, 8, 9])).sum() == 0 | ||||
|     assert (rew, done) == (0, False) | ||||
|     obs, rew, done, info = fsenv.step(1) | ||||
|     assert abs(obs - np.array([7, 8, 9, 10])).sum() == 0 | ||||
|     assert (rew, done) == (1, True) | ||||
|     obs, rew, done, info = fsenv.step(0) | ||||
|     assert abs(obs - np.array([8, 9, 10, 10])).sum() == 0 | ||||
|     assert (rew, done) == (0, True) | ||||
|     fsenv.close() | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
|     test_framestack() | ||||
							
								
								
									
										12
									
								
								tianshou/env/wrapper.py
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										12
									
								
								tianshou/env/wrapper.py
									
									
									
									
										vendored
									
									
								
							| @ -50,7 +50,10 @@ class FrameStack(EnvWrapper): | ||||
|         return self._get_obs() | ||||
| 
 | ||||
|     def _get_obs(self): | ||||
|         return np.concatenate(self._frames, axis=-1) | ||||
|         try: | ||||
|             return np.concatenate(self._frames, axis=-1) | ||||
|         except ValueError: | ||||
|             return np.stack(self._frames, axis=-1) | ||||
| 
 | ||||
| 
 | ||||
| class VectorEnv(object): | ||||
| @ -177,11 +180,10 @@ class RayVectorEnv(object): | ||||
|         self.env_num = len(env_fns) | ||||
|         self._reset_after_done = kwargs.get('reset_after_done', False) | ||||
|         try: | ||||
|             import ray | ||||
|         except ImportError: | ||||
|             if not ray.is_initialized(): | ||||
|                 ray.init() | ||||
|         except NameError: | ||||
|             raise ImportError('Please install ray to support VectorEnv: pip3 install ray -U') | ||||
|         if not ray.is_initialized(): | ||||
|             ray.init() | ||||
|         self.envs = [ray.remote(EnvWrapper).options(num_cpus=0).remote(e()) for e in env_fns] | ||||
| 
 | ||||
|     def __len__(self): | ||||
|  | ||||
| @ -1,3 +1,4 @@ | ||||
| from tianshou.utils.cloudpicklewrapper import CloudpickleWrapper | ||||
| from tianshou.utils.config import tqdm_config | ||||
| 
 | ||||
| __all__ = ['CloudpickleWrapper'] | ||||
| __all__ = ['CloudpickleWrapper', 'tqdm_config'] | ||||
|  | ||||
							
								
								
									
										4
									
								
								tianshou/utils/config.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										4
									
								
								tianshou/utils/config.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,4 @@ | ||||
| tqdm_config = { | ||||
|     'dynamic_ncols': True, | ||||
|     'ascii': True, | ||||
| } | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user