| 
									
										
										
										
											2020-07-26 12:01:21 +02:00
										 |  |  | import random | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  | import time | 
					
						
							| 
									
										
										
										
											2021-04-25 15:23:46 +08:00
										 |  |  | from copy import deepcopy | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | import gym | 
					
						
							|  |  |  | import networkx as nx | 
					
						
							|  |  |  | import numpy as np | 
					
						
							|  |  |  | from gym.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple | 
					
						
							| 
									
										
										
										
											2020-03-21 10:58:01 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class MyTestEnv(gym.Env): | 
					
						
							| 
									
										
										
										
											2020-07-13 00:24:31 +08:00
										 |  |  |     """This is a "going right" task. The task is to go right ``size`` steps.
 | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     def __init__( | 
					
						
							|  |  |  |         self, | 
					
						
							|  |  |  |         size, | 
					
						
							|  |  |  |         sleep=0, | 
					
						
							|  |  |  |         dict_state=False, | 
					
						
							|  |  |  |         recurse_state=False, | 
					
						
							|  |  |  |         ma_rew=0, | 
					
						
							|  |  |  |         multidiscrete_action=False, | 
					
						
							|  |  |  |         random_sleep=False, | 
					
						
							|  |  |  |         array_state=False | 
					
						
							|  |  |  |     ): | 
					
						
							| 
									
										
										
										
											2021-02-19 10:33:49 +08:00
										 |  |  |         assert dict_state + recurse_state + array_state <= 1, \ | 
					
						
							|  |  |  |             "dict_state / recurse_state / array_state can be only one true" | 
					
						
							| 
									
										
										
										
											2020-03-21 10:58:01 +08:00
										 |  |  |         self.size = size | 
					
						
							|  |  |  |         self.sleep = sleep | 
					
						
							| 
									
										
										
										
											2020-07-26 12:01:21 +02:00
										 |  |  |         self.random_sleep = random_sleep | 
					
						
							| 
									
										
										
										
											2020-04-28 20:56:02 +08:00
										 |  |  |         self.dict_state = dict_state | 
					
						
							| 
									
										
										
										
											2020-08-04 13:39:05 +08:00
										 |  |  |         self.recurse_state = recurse_state | 
					
						
							| 
									
										
										
										
											2021-02-19 10:33:49 +08:00
										 |  |  |         self.array_state = array_state | 
					
						
							| 
									
										
										
										
											2020-07-13 00:24:31 +08:00
										 |  |  |         self.ma_rew = ma_rew | 
					
						
							| 
									
										
										
										
											2020-07-24 17:38:12 +08:00
										 |  |  |         self._md_action = multidiscrete_action | 
					
						
							| 
									
										
										
										
											2020-08-27 12:15:18 +08:00
										 |  |  |         # how many steps this env has stepped | 
					
						
							|  |  |  |         self.steps = 0 | 
					
						
							| 
									
										
										
										
											2020-08-04 13:39:05 +08:00
										 |  |  |         if dict_state: | 
					
						
							|  |  |  |             self.observation_space = Dict( | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |                 { | 
					
						
							|  |  |  |                     "index": Box(shape=(1, ), low=0, high=size - 1), | 
					
						
							|  |  |  |                     "rand": Box(shape=(1, ), low=0, high=1, dtype=np.float64) | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |             ) | 
					
						
							| 
									
										
										
										
											2020-08-04 13:39:05 +08:00
										 |  |  |         elif recurse_state: | 
					
						
							|  |  |  |             self.observation_space = Dict( | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |                 { | 
					
						
							|  |  |  |                     "index": | 
					
						
							|  |  |  |                     Box(shape=(1, ), low=0, high=size - 1), | 
					
						
							|  |  |  |                     "dict": | 
					
						
							|  |  |  |                     Dict( | 
					
						
							|  |  |  |                         { | 
					
						
							|  |  |  |                             "tuple": | 
					
						
							|  |  |  |                             Tuple( | 
					
						
							|  |  |  |                                 ( | 
					
						
							|  |  |  |                                     Discrete(2), | 
					
						
							|  |  |  |                                     Box(shape=(2, ), low=0, high=1, dtype=np.float64) | 
					
						
							|  |  |  |                                 ) | 
					
						
							|  |  |  |                             ), | 
					
						
							|  |  |  |                             "rand": | 
					
						
							|  |  |  |                             Box(shape=(1, 2), low=0, high=1, dtype=np.float64) | 
					
						
							|  |  |  |                         } | 
					
						
							|  |  |  |                     ) | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |             ) | 
					
						
							| 
									
										
										
										
											2021-02-19 10:33:49 +08:00
										 |  |  |         elif array_state: | 
					
						
							|  |  |  |             self.observation_space = Box(shape=(4, 84, 84), low=0, high=255) | 
					
						
							| 
									
										
										
										
											2020-08-04 13:39:05 +08:00
										 |  |  |         else: | 
					
						
							|  |  |  |             self.observation_space = Box(shape=(1, ), low=0, high=size - 1) | 
					
						
							| 
									
										
										
										
											2020-07-24 17:38:12 +08:00
										 |  |  |         if multidiscrete_action: | 
					
						
							|  |  |  |             self.action_space = MultiDiscrete([2, 2]) | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             self.action_space = Discrete(2) | 
					
						
							| 
									
										
										
										
											2020-08-04 13:39:05 +08:00
										 |  |  |         self.done = False | 
					
						
							|  |  |  |         self.index = 0 | 
					
						
							|  |  |  |         self.seed() | 
					
						
							| 
									
										
										
										
											2020-03-21 10:58:01 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-07-24 17:38:12 +08:00
										 |  |  |     def seed(self, seed=0): | 
					
						
							| 
									
										
										
										
											2020-08-04 13:39:05 +08:00
										 |  |  |         self.rng = np.random.RandomState(seed) | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  |         return [seed] | 
					
						
							| 
									
										
										
										
											2020-07-24 17:38:12 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-04-09 19:53:45 +08:00
										 |  |  |     def reset(self, state=0): | 
					
						
							| 
									
										
										
										
											2020-03-21 10:58:01 +08:00
										 |  |  |         self.done = False | 
					
						
							| 
									
										
										
										
											2020-04-09 19:53:45 +08:00
										 |  |  |         self.index = state | 
					
						
							| 
									
										
										
										
											2020-08-04 13:39:05 +08:00
										 |  |  |         return self._get_state() | 
					
						
							| 
									
										
										
										
											2020-07-13 00:24:31 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def _get_reward(self): | 
					
						
							|  |  |  |         """Generate a non-scalar reward if ma_rew is True.""" | 
					
						
							|  |  |  |         x = int(self.done) | 
					
						
							|  |  |  |         if self.ma_rew > 0: | 
					
						
							|  |  |  |             return [x] * self.ma_rew | 
					
						
							|  |  |  |         return x | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-08-04 13:39:05 +08:00
										 |  |  |     def _get_state(self): | 
					
						
							|  |  |  |         """Generate state(observation) of MyTestEnv""" | 
					
						
							|  |  |  |         if self.dict_state: | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |             return { | 
					
						
							|  |  |  |                 'index': np.array([self.index], dtype=np.float32), | 
					
						
							|  |  |  |                 'rand': self.rng.rand(1) | 
					
						
							|  |  |  |             } | 
					
						
							| 
									
										
										
										
											2020-08-04 13:39:05 +08:00
										 |  |  |         elif self.recurse_state: | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |             return { | 
					
						
							|  |  |  |                 'index': np.array([self.index], dtype=np.float32), | 
					
						
							|  |  |  |                 'dict': { | 
					
						
							|  |  |  |                     "tuple": (np.array([1], dtype=int), self.rng.rand(2)), | 
					
						
							|  |  |  |                     "rand": self.rng.rand(1, 2) | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |             } | 
					
						
							| 
									
										
										
										
											2021-02-19 10:33:49 +08:00
										 |  |  |         elif self.array_state: | 
					
						
							| 
									
										
										
										
											2021-03-31 15:14:22 +08:00
										 |  |  |             img = np.zeros([4, 84, 84], int) | 
					
						
							| 
									
										
										
										
											2021-02-19 10:33:49 +08:00
										 |  |  |             img[3, np.arange(84), np.arange(84)] = self.index | 
					
						
							|  |  |  |             img[2, np.arange(84)] = self.index | 
					
						
							|  |  |  |             img[1, :, np.arange(84)] = self.index | 
					
						
							|  |  |  |             img[0] = self.index | 
					
						
							|  |  |  |             return img | 
					
						
							| 
									
										
										
										
											2020-08-04 13:39:05 +08:00
										 |  |  |         else: | 
					
						
							|  |  |  |             return np.array([self.index], dtype=np.float32) | 
					
						
							| 
									
										
										
										
											2020-03-21 10:58:01 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def step(self, action): | 
					
						
							| 
									
										
										
										
											2020-08-27 12:15:18 +08:00
										 |  |  |         self.steps += 1 | 
					
						
							| 
									
										
										
										
											2020-07-24 17:38:12 +08:00
										 |  |  |         if self._md_action: | 
					
						
							|  |  |  |             action = action[0] | 
					
						
							| 
									
										
										
										
											2020-03-21 10:58:01 +08:00
										 |  |  |         if self.done: | 
					
						
							|  |  |  |             raise ValueError('step after done !!!') | 
					
						
							|  |  |  |         if self.sleep > 0: | 
					
						
							| 
									
										
										
										
											2020-07-26 12:01:21 +02:00
										 |  |  |             sleep_time = random.random() if self.random_sleep else 1 | 
					
						
							|  |  |  |             sleep_time *= self.sleep | 
					
						
							|  |  |  |             time.sleep(sleep_time) | 
					
						
							| 
									
										
										
										
											2020-03-21 10:58:01 +08:00
										 |  |  |         if self.index == self.size: | 
					
						
							|  |  |  |             self.done = True | 
					
						
							| 
									
										
										
										
											2020-08-04 13:39:05 +08:00
										 |  |  |             return self._get_state(), self._get_reward(), self.done, {} | 
					
						
							| 
									
										
										
										
											2020-03-21 10:58:01 +08:00
										 |  |  |         if action == 0: | 
					
						
							|  |  |  |             self.index = max(self.index - 1, 0) | 
					
						
							| 
									
										
										
										
											2020-08-04 13:39:05 +08:00
										 |  |  |             return self._get_state(), self._get_reward(), self.done, \ | 
					
						
							| 
									
										
										
										
											2020-07-13 00:24:31 +08:00
										 |  |  |                 {'key': 1, 'env': self} if self.dict_state else {} | 
					
						
							| 
									
										
										
										
											2020-03-21 10:58:01 +08:00
										 |  |  |         elif action == 1: | 
					
						
							|  |  |  |             self.index += 1 | 
					
						
							|  |  |  |             self.done = self.index == self.size | 
					
						
							| 
									
										
										
										
											2020-08-04 13:39:05 +08:00
										 |  |  |             return self._get_state(), self._get_reward(), \ | 
					
						
							| 
									
										
										
										
											2020-07-13 00:24:31 +08:00
										 |  |  |                 self.done, {'key': 1, 'env': self} | 
					
						
							| 
									
										
										
										
											2021-04-25 15:23:46 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class NXEnv(gym.Env): | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-25 15:23:46 +08:00
										 |  |  |     def __init__(self, size, obs_type, feat_dim=32): | 
					
						
							|  |  |  |         self.size = size | 
					
						
							|  |  |  |         self.feat_dim = feat_dim | 
					
						
							|  |  |  |         self.graph = nx.Graph() | 
					
						
							|  |  |  |         self.graph.add_nodes_from(list(range(size))) | 
					
						
							|  |  |  |         assert obs_type in ["array", "object"] | 
					
						
							|  |  |  |         self.obs_type = obs_type | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _encode_obs(self): | 
					
						
							|  |  |  |         if self.obs_type == "array": | 
					
						
							|  |  |  |             return np.stack([v["data"] for v in self.graph._node.values()]) | 
					
						
							|  |  |  |         return deepcopy(self.graph) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def reset(self): | 
					
						
							|  |  |  |         graph_state = np.random.rand(self.size, self.feat_dim) | 
					
						
							|  |  |  |         for i in range(self.size): | 
					
						
							|  |  |  |             self.graph.nodes[i]["data"] = graph_state[i] | 
					
						
							|  |  |  |         return self._encode_obs() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def step(self, action): | 
					
						
							|  |  |  |         next_graph_state = np.random.rand(self.size, self.feat_dim) | 
					
						
							|  |  |  |         for i in range(self.size): | 
					
						
							|  |  |  |             self.graph.nodes[i]["data"] = next_graph_state[i] | 
					
						
							|  |  |  |         return self._encode_obs(), 1.0, 0, {} |