| 
									
										
										
										
											2021-06-26 18:08:41 +08:00
										 |  |  | import os | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-06-26 18:08:41 +08:00
										 |  |  | import cv2 | 
					
						
							| 
									
										
										
										
											2023-02-03 20:57:27 +01:00
										 |  |  | import gymnasium as gym | 
					
						
							| 
									
										
										
										
											2021-06-26 18:08:41 +08:00
										 |  |  | import numpy as np | 
					
						
							|  |  |  | import vizdoom as vzd | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-05-08 12:42:16 -04:00
										 |  |  | from tianshou.env import ShmemVectorEnv | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | try: | 
					
						
							|  |  |  |     import envpool | 
					
						
							|  |  |  | except ImportError: | 
					
						
							|  |  |  |     envpool = None | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-06-26 18:08:41 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | def normal_button_comb(): | 
					
						
							|  |  |  |     actions = [] | 
					
						
							|  |  |  |     m_forward = [[0.0], [1.0]] | 
					
						
							|  |  |  |     t_left_right = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0]] | 
					
						
							|  |  |  |     for i in m_forward: | 
					
						
							|  |  |  |         for j in t_left_right: | 
					
						
							|  |  |  |             actions.append(i + j) | 
					
						
							|  |  |  |     return actions | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def battle_button_comb(): | 
					
						
							|  |  |  |     actions = [] | 
					
						
							|  |  |  |     m_forward_backward = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0]] | 
					
						
							|  |  |  |     m_left_right = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0]] | 
					
						
							|  |  |  |     t_left_right = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0]] | 
					
						
							|  |  |  |     attack = [[0.0], [1.0]] | 
					
						
							|  |  |  |     speed = [[0.0], [1.0]] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     for m in attack: | 
					
						
							|  |  |  |         for n in speed: | 
					
						
							|  |  |  |             for j in m_left_right: | 
					
						
							|  |  |  |                 for i in m_forward_backward: | 
					
						
							|  |  |  |                     for k in t_left_right: | 
					
						
							|  |  |  |                         actions.append(i + j + k + m + n) | 
					
						
							|  |  |  |     return actions | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class Env(gym.Env): | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     def __init__(self, cfg_path, frameskip=4, res=(4, 40, 60), save_lmp=False): | 
					
						
							| 
									
										
										
										
											2021-06-26 18:08:41 +08:00
										 |  |  |         super().__init__() | 
					
						
							|  |  |  |         self.save_lmp = save_lmp | 
					
						
							|  |  |  |         self.health_setting = "battle" in cfg_path | 
					
						
							|  |  |  |         if save_lmp: | 
					
						
							|  |  |  |             os.makedirs("lmps", exist_ok=True) | 
					
						
							|  |  |  |         self.res = res | 
					
						
							|  |  |  |         self.skip = frameskip | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         self.observation_space = gym.spaces.Box(low=0, high=255, shape=res, dtype=np.float32) | 
					
						
							| 
									
										
										
										
											2021-06-26 18:08:41 +08:00
										 |  |  |         self.game = vzd.DoomGame() | 
					
						
							|  |  |  |         self.game.load_config(cfg_path) | 
					
						
							|  |  |  |         self.game.init() | 
					
						
							|  |  |  |         if "battle" in cfg_path: | 
					
						
							|  |  |  |             self.available_actions = battle_button_comb() | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             self.available_actions = normal_button_comb() | 
					
						
							|  |  |  |         self.action_num = len(self.available_actions) | 
					
						
							|  |  |  |         self.action_space = gym.spaces.Discrete(self.action_num) | 
					
						
							|  |  |  |         self.spec = gym.envs.registration.EnvSpec("vizdoom-v0") | 
					
						
							|  |  |  |         self.count = 0 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def get_obs(self): | 
					
						
							|  |  |  |         state = self.game.get_state() | 
					
						
							|  |  |  |         if state is None: | 
					
						
							|  |  |  |             return | 
					
						
							|  |  |  |         obs = state.screen_buffer | 
					
						
							|  |  |  |         self.obs_buffer[:-1] = self.obs_buffer[1:] | 
					
						
							|  |  |  |         self.obs_buffer[-1] = cv2.resize(obs, (self.res[-1], self.res[-2])) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def reset(self): | 
					
						
							|  |  |  |         if self.save_lmp: | 
					
						
							|  |  |  |             self.game.new_episode(f"lmps/episode_{self.count}.lmp") | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             self.game.new_episode() | 
					
						
							|  |  |  |         self.count += 1 | 
					
						
							|  |  |  |         self.obs_buffer = np.zeros(self.res, dtype=np.uint8) | 
					
						
							|  |  |  |         self.get_obs() | 
					
						
							|  |  |  |         self.health = self.game.get_game_variable(vzd.GameVariable.HEALTH) | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |         self.killcount = self.game.get_game_variable(vzd.GameVariable.KILLCOUNT) | 
					
						
							| 
									
										
										
										
											2021-06-26 18:08:41 +08:00
										 |  |  |         self.ammo2 = self.game.get_game_variable(vzd.GameVariable.AMMO2) | 
					
						
							|  |  |  |         return self.obs_buffer | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def step(self, action): | 
					
						
							|  |  |  |         self.game.make_action(self.available_actions[action], self.skip) | 
					
						
							|  |  |  |         reward = 0.0 | 
					
						
							|  |  |  |         self.get_obs() | 
					
						
							|  |  |  |         health = self.game.get_game_variable(vzd.GameVariable.HEALTH) | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         if self.health_setting or health > self.health:  # positive health reward only for d1/d2 | 
					
						
							| 
									
										
										
										
											2021-06-26 18:08:41 +08:00
										 |  |  |             reward += health - self.health | 
					
						
							|  |  |  |         self.health = health | 
					
						
							|  |  |  |         killcount = self.game.get_game_variable(vzd.GameVariable.KILLCOUNT) | 
					
						
							|  |  |  |         reward += 20 * (killcount - self.killcount) | 
					
						
							|  |  |  |         self.killcount = killcount | 
					
						
							|  |  |  |         ammo2 = self.game.get_game_variable(vzd.GameVariable.AMMO2) | 
					
						
							|  |  |  |         # if ammo2 > self.ammo2: | 
					
						
							|  |  |  |         reward += ammo2 - self.ammo2 | 
					
						
							|  |  |  |         self.ammo2 = ammo2 | 
					
						
							|  |  |  |         done = False | 
					
						
							|  |  |  |         info = {} | 
					
						
							|  |  |  |         if self.game.is_player_dead() or self.game.get_state() is None: | 
					
						
							|  |  |  |             done = True | 
					
						
							|  |  |  |         elif self.game.is_episode_finished(): | 
					
						
							|  |  |  |             done = True | 
					
						
							|  |  |  |             info["TimeLimit.truncated"] = True | 
					
						
							|  |  |  |         return self.obs_buffer, reward, done, info | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def render(self): | 
					
						
							|  |  |  |         pass | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def close(self): | 
					
						
							|  |  |  |         self.game.close() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-05-08 12:42:16 -04:00
										 |  |  | def make_vizdoom_env(task, frame_skip, res, save_lmp, seed, training_num, test_num): | 
					
						
							|  |  |  |     test_num = min(os.cpu_count() - 1, test_num) | 
					
						
							|  |  |  |     if envpool is not None: | 
					
						
							|  |  |  |         task_id = "".join([i.capitalize() for i in task.split("_")]) + "-v1" | 
					
						
							|  |  |  |         lmp_save_dir = "lmps/" if save_lmp else "" | 
					
						
							|  |  |  |         reward_config = { | 
					
						
							|  |  |  |             "KILLCOUNT": [20.0, -20.0], | 
					
						
							|  |  |  |             "HEALTH": [1.0, 0.0], | 
					
						
							|  |  |  |             "AMMO2": [1.0, -1.0], | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         if "battle" in task: | 
					
						
							|  |  |  |             reward_config["HEALTH"] = [1.0, -1.0] | 
					
						
							| 
									
										
										
										
											2023-02-03 20:57:27 +01:00
										 |  |  |         env = train_envs = envpool.make_gymnasium( | 
					
						
							| 
									
										
										
										
											2022-05-08 12:42:16 -04:00
										 |  |  |             task_id, | 
					
						
							|  |  |  |             frame_skip=frame_skip, | 
					
						
							|  |  |  |             stack_num=res[0], | 
					
						
							|  |  |  |             seed=seed, | 
					
						
							|  |  |  |             num_envs=training_num, | 
					
						
							|  |  |  |             reward_config=reward_config, | 
					
						
							|  |  |  |             use_combined_action=True, | 
					
						
							|  |  |  |             max_episode_steps=2625, | 
					
						
							|  |  |  |             use_inter_area_resize=False, | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2023-02-03 20:57:27 +01:00
										 |  |  |         test_envs = envpool.make_gymnasium( | 
					
						
							| 
									
										
										
										
											2022-05-08 12:42:16 -04:00
										 |  |  |             task_id, | 
					
						
							|  |  |  |             frame_skip=frame_skip, | 
					
						
							|  |  |  |             stack_num=res[0], | 
					
						
							|  |  |  |             lmp_save_dir=lmp_save_dir, | 
					
						
							|  |  |  |             seed=seed, | 
					
						
							|  |  |  |             num_envs=test_num, | 
					
						
							|  |  |  |             reward_config=reward_config, | 
					
						
							|  |  |  |             use_combined_action=True, | 
					
						
							|  |  |  |             max_episode_steps=2625, | 
					
						
							|  |  |  |             use_inter_area_resize=False, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         cfg_path = f"maps/{task}.cfg" | 
					
						
							|  |  |  |         env = Env(cfg_path, frame_skip, res) | 
					
						
							|  |  |  |         train_envs = ShmemVectorEnv( | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |             [lambda: Env(cfg_path, frame_skip, res) for _ in range(training_num)], | 
					
						
							| 
									
										
										
										
											2022-05-08 12:42:16 -04:00
										 |  |  |         ) | 
					
						
							|  |  |  |         test_envs = ShmemVectorEnv( | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |             [lambda: Env(cfg_path, frame_skip, res, save_lmp) for _ in range(test_num)], | 
					
						
							| 
									
										
										
										
											2022-05-08 12:42:16 -04:00
										 |  |  |         ) | 
					
						
							|  |  |  |         train_envs.seed(seed) | 
					
						
							|  |  |  |         test_envs.seed(seed) | 
					
						
							|  |  |  |     return env, train_envs, test_envs | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  | if __name__ == "__main__": | 
					
						
							| 
									
										
										
										
											2021-06-26 18:08:41 +08:00
										 |  |  |     # env = Env("maps/D1_basic.cfg", 4, (4, 84, 84)) | 
					
						
							|  |  |  |     env = Env("maps/D3_battle.cfg", 4, (4, 84, 84)) | 
					
						
							|  |  |  |     print(env.available_actions) | 
					
						
							|  |  |  |     action_num = env.action_space.n | 
					
						
							|  |  |  |     obs = env.reset() | 
					
						
							|  |  |  |     print(env.spec.reward_threshold) | 
					
						
							|  |  |  |     print(obs.shape, action_num) | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     for _ in range(4000): | 
					
						
							| 
									
										
										
										
											2023-11-16 18:27:53 +01:00
										 |  |  |         obs, rew, terminated, truncated, info = env.step(0) | 
					
						
							|  |  |  |         if terminated or truncated: | 
					
						
							| 
									
										
										
										
											2021-06-26 18:08:41 +08:00
										 |  |  |             env.reset() | 
					
						
							| 
									
										
										
										
											2023-11-16 18:27:53 +01:00
										 |  |  |     print(obs.shape, rew, terminated, truncated) | 
					
						
							| 
									
										
										
										
											2021-06-26 18:08:41 +08:00
										 |  |  |     cv2.imwrite("test.png", obs.transpose(1, 2, 0)[..., :3]) |