| 
									
										
										
										
											2021-03-25 22:59:34 +08:00
										 |  |  | # see issue #322 for detail | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import copy | 
					
						
							|  |  |  | from collections import Counter | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-02-03 20:57:27 +01:00
										 |  |  | import gymnasium as gym | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  | import numpy as np | 
					
						
							| 
									
										
										
										
											2023-10-08 17:57:03 +02:00
										 |  |  | from gymnasium.spaces import Box | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  | from torch.utils.data import DataLoader, Dataset, DistributedSampler | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from tianshou.data import Batch, Collector | 
					
						
							| 
									
										
										
										
											2021-03-25 22:59:34 +08:00
										 |  |  | from tianshou.env import BaseVectorEnv, DummyVectorEnv, SubprocVectorEnv | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  | from tianshou.policy import BasePolicy | 
					
						
							| 
									
										
										
										
											2021-03-25 22:59:34 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class DummyDataset(Dataset): | 
					
						
							|  |  |  |     def __init__(self, length): | 
					
						
							|  |  |  |         self.length = length | 
					
						
							|  |  |  |         self.episodes = [3 * i % 5 + 1 for i in range(self.length)] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __getitem__(self, index): | 
					
						
							|  |  |  |         assert 0 <= index < self.length | 
					
						
							|  |  |  |         return index, self.episodes[index] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __len__(self): | 
					
						
							|  |  |  |         return self.length | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class FiniteEnv(gym.Env): | 
					
						
							|  |  |  |     def __init__(self, dataset, num_replicas, rank): | 
					
						
							|  |  |  |         self.dataset = dataset | 
					
						
							|  |  |  |         self.num_replicas = num_replicas | 
					
						
							|  |  |  |         self.rank = rank | 
					
						
							|  |  |  |         self.loader = DataLoader( | 
					
						
							|  |  |  |             dataset, | 
					
						
							|  |  |  |             sampler=DistributedSampler(dataset, num_replicas, rank), | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |             batch_size=None, | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |         ) | 
					
						
							| 
									
										
										
										
											2021-03-25 22:59:34 +08:00
										 |  |  |         self.iterator = None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def reset(self): | 
					
						
							|  |  |  |         if self.iterator is None: | 
					
						
							|  |  |  |             self.iterator = iter(self.loader) | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             self.current_sample, self.step_count = next(self.iterator) | 
					
						
							|  |  |  |             self.current_step = 0 | 
					
						
							| 
									
										
										
										
											2023-02-03 20:57:27 +01:00
										 |  |  |             return self.current_sample, {} | 
					
						
							| 
									
										
										
										
											2021-03-25 22:59:34 +08:00
										 |  |  |         except StopIteration: | 
					
						
							|  |  |  |             self.iterator = None | 
					
						
							| 
									
										
										
										
											2023-02-03 20:57:27 +01:00
										 |  |  |             return None, {} | 
					
						
							| 
									
										
										
										
											2021-03-25 22:59:34 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def step(self, action): | 
					
						
							|  |  |  |         self.current_step += 1 | 
					
						
							|  |  |  |         assert self.current_step <= self.step_count | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         return ( | 
					
						
							|  |  |  |             0, | 
					
						
							|  |  |  |             1.0, | 
					
						
							|  |  |  |             self.current_step >= self.step_count, | 
					
						
							|  |  |  |             False, | 
					
						
							|  |  |  |             {"sample": self.current_sample, "action": action, "metric": 2.0}, | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2021-03-25 22:59:34 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class FiniteVectorEnv(BaseVectorEnv): | 
					
						
							|  |  |  |     def __init__(self, env_fns, **kwargs): | 
					
						
							|  |  |  |         super().__init__(env_fns, **kwargs) | 
					
						
							|  |  |  |         self._alive_env_ids = set() | 
					
						
							|  |  |  |         self._reset_alive_envs() | 
					
						
							|  |  |  |         self._default_obs = self._default_info = None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _reset_alive_envs(self): | 
					
						
							|  |  |  |         if not self._alive_env_ids: | 
					
						
							|  |  |  |             # starting or running out | 
					
						
							|  |  |  |             self._alive_env_ids = set(range(self.env_num)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # to workaround with tianshou's buffer and batch | 
					
						
							|  |  |  |     def _set_default_obs(self, obs): | 
					
						
							|  |  |  |         if obs is not None and self._default_obs is None: | 
					
						
							|  |  |  |             self._default_obs = copy.deepcopy(obs) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _set_default_info(self, info): | 
					
						
							|  |  |  |         if info is not None and self._default_info is None: | 
					
						
							|  |  |  |             self._default_info = copy.deepcopy(info) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _get_default_obs(self): | 
					
						
							|  |  |  |         return copy.deepcopy(self._default_obs) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _get_default_info(self): | 
					
						
							|  |  |  |         return copy.deepcopy(self._default_info) | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-25 22:59:34 +08:00
										 |  |  |     # END | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def reset(self, id=None): | 
					
						
							|  |  |  |         id = self._wrap_id(id) | 
					
						
							|  |  |  |         self._reset_alive_envs() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # ask super to reset alive envs and remap to current index | 
					
						
							|  |  |  |         request_id = list(filter(lambda i: i in self._alive_env_ids, id)) | 
					
						
							|  |  |  |         obs = [None] * len(id) | 
					
						
							| 
									
										
										
										
											2023-02-03 20:57:27 +01:00
										 |  |  |         infos = [None] * len(id) | 
					
						
							| 
									
										
										
										
											2021-03-25 22:59:34 +08:00
										 |  |  |         id2idx = {i: k for k, i in enumerate(id)} | 
					
						
							|  |  |  |         if request_id: | 
					
						
							| 
									
										
										
										
											2023-09-05 23:34:23 +02:00
										 |  |  |             for k, o, info in zip(request_id, *super().reset(request_id), strict=True): | 
					
						
							| 
									
										
										
										
											2023-02-03 20:57:27 +01:00
										 |  |  |                 obs[id2idx[k]] = o | 
					
						
							|  |  |  |                 infos[id2idx[k]] = info | 
					
						
							| 
									
										
										
										
											2023-09-05 23:34:23 +02:00
										 |  |  |         for i, o in zip(id, obs, strict=True): | 
					
						
							| 
									
										
										
										
											2021-03-25 22:59:34 +08:00
										 |  |  |             if o is None and i in self._alive_env_ids: | 
					
						
							|  |  |  |                 self._alive_env_ids.remove(i) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # fill empty observation with default(fake) observation | 
					
						
							|  |  |  |         for o in obs: | 
					
						
							|  |  |  |             self._set_default_obs(o) | 
					
						
							| 
									
										
										
										
											2023-02-03 20:57:27 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-25 22:59:34 +08:00
										 |  |  |         for i in range(len(obs)): | 
					
						
							|  |  |  |             if obs[i] is None: | 
					
						
							|  |  |  |                 obs[i] = self._get_default_obs() | 
					
						
							| 
									
										
										
										
											2023-02-03 20:57:27 +01:00
										 |  |  |             if infos[i] is None: | 
					
						
							|  |  |  |                 infos[i] = self._get_default_info() | 
					
						
							| 
									
										
										
										
											2021-03-25 22:59:34 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         if not self._alive_env_ids: | 
					
						
							|  |  |  |             self.reset() | 
					
						
							|  |  |  |             raise StopIteration | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-02-03 20:57:27 +01:00
										 |  |  |         return np.stack(obs), infos | 
					
						
							| 
									
										
										
										
											2021-03-25 22:59:34 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def step(self, action, id=None): | 
					
						
							|  |  |  |         id = self._wrap_id(id) | 
					
						
							|  |  |  |         id2idx = {i: k for k, i in enumerate(id)} | 
					
						
							|  |  |  |         request_id = list(filter(lambda i: i in self._alive_env_ids, id)) | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         result = [[None, 0.0, False, False, None] for _ in range(len(id))] | 
					
						
							| 
									
										
										
										
											2021-03-25 22:59:34 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         # ask super to step alive envs and remap to current index | 
					
						
							|  |  |  |         if request_id: | 
					
						
							|  |  |  |             valid_act = np.stack([action[id2idx[i]] for i in request_id]) | 
					
						
							| 
									
										
										
										
											2023-09-05 23:34:23 +02:00
										 |  |  |             for i, r in zip( | 
					
						
							|  |  |  |                 request_id, | 
					
						
							|  |  |  |                 zip(*super().step(valid_act, request_id), strict=True), | 
					
						
							|  |  |  |                 strict=True, | 
					
						
							|  |  |  |             ): | 
					
						
							| 
									
										
										
										
											2021-03-25 22:59:34 +08:00
										 |  |  |                 result[id2idx[i]] = r | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # logging | 
					
						
							| 
									
										
										
										
											2023-09-05 23:34:23 +02:00
										 |  |  |         for i, r in zip(id, result, strict=True): | 
					
						
							| 
									
										
										
										
											2021-03-25 22:59:34 +08:00
										 |  |  |             if i in self._alive_env_ids: | 
					
						
							|  |  |  |                 self.tracker.log(*r) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # fill empty observation/info with default(fake) | 
					
						
							| 
									
										
										
										
											2023-02-03 20:57:27 +01:00
										 |  |  |         for _, __, ___, ____, i in result: | 
					
						
							| 
									
										
										
										
											2021-03-25 22:59:34 +08:00
										 |  |  |             self._set_default_info(i) | 
					
						
							|  |  |  |         for i in range(len(result)): | 
					
						
							|  |  |  |             if result[i][0] is None: | 
					
						
							|  |  |  |                 result[i][0] = self._get_default_obs() | 
					
						
							| 
									
										
										
										
											2023-02-03 20:57:27 +01:00
										 |  |  |             if result[i][-1] is None: | 
					
						
							|  |  |  |                 result[i][-1] = self._get_default_info() | 
					
						
							| 
									
										
										
										
											2021-03-25 22:59:34 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-09-05 23:34:23 +02:00
										 |  |  |         return list(map(np.stack, zip(*result, strict=True))) | 
					
						
							| 
									
										
										
										
											2021-03-25 22:59:34 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class FiniteDummyVectorEnv(FiniteVectorEnv, DummyVectorEnv): | 
					
						
							|  |  |  |     pass | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class FiniteSubprocVectorEnv(FiniteVectorEnv, SubprocVectorEnv): | 
					
						
							|  |  |  |     pass | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class AnyPolicy(BasePolicy): | 
					
						
							| 
									
										
										
										
											2023-10-08 17:57:03 +02:00
										 |  |  |     def __init__(self): | 
					
						
							|  |  |  |         super().__init__(action_space=Box(-1, 1, (1,))) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-25 22:59:34 +08:00
										 |  |  |     def forward(self, batch, state=None): | 
					
						
							|  |  |  |         return Batch(act=np.stack([1] * len(batch))) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def learn(self, batch): | 
					
						
							|  |  |  |         pass | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def _finite_env_factory(dataset, num_replicas, rank): | 
					
						
							|  |  |  |     return lambda: FiniteEnv(dataset, num_replicas, rank) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class MetricTracker: | 
					
						
							|  |  |  |     def __init__(self): | 
					
						
							|  |  |  |         self.counter = Counter() | 
					
						
							|  |  |  |         self.finished = set() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-02-03 20:57:27 +01:00
										 |  |  |     def log(self, obs, rew, terminated, truncated, info): | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         assert rew == 1.0 | 
					
						
							| 
									
										
										
										
											2023-02-03 20:57:27 +01:00
										 |  |  |         done = terminated or truncated | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         index = info["sample"] | 
					
						
							| 
									
										
										
										
											2021-03-25 22:59:34 +08:00
										 |  |  |         if done: | 
					
						
							|  |  |  |             assert index not in self.finished | 
					
						
							|  |  |  |             self.finished.add(index) | 
					
						
							|  |  |  |         self.counter[index] += 1 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def validate(self): | 
					
						
							|  |  |  |         assert len(self.finished) == 100 | 
					
						
							|  |  |  |         for k, v in self.counter.items(): | 
					
						
							|  |  |  |             assert v == k * 3 % 5 + 1 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def test_finite_dummy_vector_env(): | 
					
						
							|  |  |  |     dataset = DummyDataset(100) | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |     envs = FiniteSubprocVectorEnv([_finite_env_factory(dataset, 5, i) for i in range(5)]) | 
					
						
							| 
									
										
										
										
											2021-03-25 22:59:34 +08:00
										 |  |  |     policy = AnyPolicy() | 
					
						
							|  |  |  |     test_collector = Collector(policy, envs, exploration_noise=True) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     for _ in range(3): | 
					
						
							|  |  |  |         envs.tracker = MetricTracker() | 
					
						
							|  |  |  |         try: | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |             test_collector.collect(n_step=10**18) | 
					
						
							| 
									
										
										
										
											2021-03-25 22:59:34 +08:00
										 |  |  |         except StopIteration: | 
					
						
							|  |  |  |             envs.tracker.validate() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def test_finite_subproc_vector_env(): | 
					
						
							|  |  |  |     dataset = DummyDataset(100) | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |     envs = FiniteSubprocVectorEnv([_finite_env_factory(dataset, 5, i) for i in range(5)]) | 
					
						
							| 
									
										
										
										
											2021-03-25 22:59:34 +08:00
										 |  |  |     policy = AnyPolicy() | 
					
						
							|  |  |  |     test_collector = Collector(policy, envs, exploration_noise=True) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     for _ in range(3): | 
					
						
							|  |  |  |         envs.tracker = MetricTracker() | 
					
						
							|  |  |  |         try: | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |             test_collector.collect(n_step=10**18) | 
					
						
							| 
									
										
										
										
											2021-03-25 22:59:34 +08:00
										 |  |  |         except StopIteration: | 
					
						
							|  |  |  |             envs.tracker.validate() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  | if __name__ == "__main__": | 
					
						
							| 
									
										
										
										
											2021-03-25 22:59:34 +08:00
										 |  |  |     test_finite_dummy_vector_env() | 
					
						
							|  |  |  |     test_finite_subproc_vector_env() |