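Fix a bug: with reset_after_done, VectorEnv.step and RayVectorEnv.step now return the post-reset observation for environments that just finished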
This commit is contained in:
		
							parent
							
								
									6632e47b9d
								
							
						
					
					
						commit
						4a1a7dd670
					
				| @ -1,18 +1,24 @@ | |||||||
| from tianshou.data import ReplayBuffer | from tianshou.data import ReplayBuffer | ||||||
| from test.test_env import MyTestEnv | if __name__ == '__main__': | ||||||
|  |     from test_env import MyTestEnv | ||||||
|  | else: | ||||||
|  |     from test.test_env import MyTestEnv | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def test_replaybuffer(bufsize=20): | def test_replaybuffer(size=10, bufsize=20): | ||||||
|     env = MyTestEnv(10) |     env = MyTestEnv(size) | ||||||
|     buf = ReplayBuffer(bufsize) |     buf = ReplayBuffer(bufsize) | ||||||
|     obs = env.reset() |     obs = env.reset() | ||||||
|     action_list = [1] * 5 + [0] * 10 + [1] * 9 |     action_list = [1] * 5 + [0] * 10 + [1] * 15 | ||||||
|     for i, a in enumerate(action_list): |     for i, a in enumerate(action_list): | ||||||
|         obs_next, rew, done, info = env.step(a) |         obs_next, rew, done, info = env.step(a) | ||||||
|         buf.add(obs, a, rew, done, obs_next, info) |         buf.add(obs, a, rew, done, obs_next, info) | ||||||
|         assert len(buf) == min(bufsize, i + 1), print(len(buf), i) |         assert len(buf) == min(bufsize, i + 1), print(len(buf), i) | ||||||
|     indice = buf.sample_indice(4) |     indice = buf.sample_indice(4) | ||||||
|     data = buf.sample(4) |     data = buf.sample(4) | ||||||
|  |     assert (indice < len(buf)).all() | ||||||
|  |     assert (data.obs < size).all() | ||||||
|  |     assert (0 <= data.done).all() and (data.done <= 1).all() | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|  | |||||||
| @ -59,11 +59,11 @@ def test_vecenv(verbose=False, size=10, num=8, sleep=0.001): | |||||||
|         VectorEnv(env_fns, reset_after_done=True), |         VectorEnv(env_fns, reset_after_done=True), | ||||||
|         SubprocVectorEnv(env_fns, reset_after_done=True), |         SubprocVectorEnv(env_fns, reset_after_done=True), | ||||||
|     ] |     ] | ||||||
|     if verbose: |     if __name__ == '__main__': | ||||||
|         venv.append(RayVectorEnv(env_fns, reset_after_done=True)) |         venv.append(RayVectorEnv(env_fns, reset_after_done=True)) | ||||||
|     for v in venv: |     for v in venv: | ||||||
|         v.seed() |         v.seed() | ||||||
|     action_list = [1] * 5 + [0] * 10 + [1] * 9 |     action_list = [1] * 5 + [0] * 10 + [1] * 15 | ||||||
|     if not verbose: |     if not verbose: | ||||||
|         o = [v.reset() for v in venv] |         o = [v.reset() for v in venv] | ||||||
|         for i, a in enumerate(action_list): |         for i, a in enumerate(action_list): | ||||||
|  | |||||||
							
								
								
									
										14
									
								
								tianshou/env/wrapper.py
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										14
									
								
								tianshou/env/wrapper.py
									
									
									
									
										vendored
									
									
								
							| @ -75,9 +75,9 @@ class VectorEnv(object): | |||||||
|         result = zip(*[e.step(a) for e, a in zip(self.envs, action)]) |         result = zip(*[e.step(a) for e, a in zip(self.envs, action)]) | ||||||
|         obs, rew, done, info = result |         obs, rew, done, info = result | ||||||
|         if self._reset_after_done and sum(done): |         if self._reset_after_done and sum(done): | ||||||
|             for i, e in enumerate(self.envs): |             obs = np.stack(obs) | ||||||
|                 if done[i]: |             for i in np.where(done)[0]: | ||||||
|                     e.reset() |                 obs[i] = self.envs[i].reset() | ||||||
|         return np.stack(obs), np.stack(rew), np.stack(done), np.stack(info) |         return np.stack(obs), np.stack(rew), np.stack(done), np.stack(info) | ||||||
| 
 | 
 | ||||||
|     def seed(self, seed=None): |     def seed(self, seed=None): | ||||||
| @ -198,6 +198,14 @@ class RayVectorEnv(object): | |||||||
|         assert len(action) == self.env_num |         assert len(action) == self.env_num | ||||||
|         result_obj = [e.step.remote(a) for e, a in zip(self.envs, action)] |         result_obj = [e.step.remote(a) for e, a in zip(self.envs, action)] | ||||||
|         obs, rew, done, info = zip(*[ray.get(r) for r in result_obj]) |         obs, rew, done, info = zip(*[ray.get(r) for r in result_obj]) | ||||||
|  |         if self._reset_after_done and sum(done): | ||||||
|  |             obs = np.stack(obs) | ||||||
|  |             index = np.where(done)[0] | ||||||
|  |             result_obj = [] | ||||||
|  |             for i in range(len(index)): | ||||||
|  |                 result_obj.append(self.envs[index[i]].reset.remote()) | ||||||
|  |             for i in range(len(index)): | ||||||
|  |                 obs[index[i]] = ray.get(result_obj[i]) | ||||||
|         return np.stack(obs), np.stack(rew), np.stack(done), np.stack(info) |         return np.stack(obs), np.stack(rew), np.stack(done), np.stack(info) | ||||||
| 
 | 
 | ||||||
|     def reset(self): |     def reset(self): | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user