Improves typing in examples and tests, working towards making mypy pass on them. Introduces the SpaceInfo utility.
		
			
				
	
	
		
			180 lines
		
	
	
		
			5.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			180 lines
		
	
	
		
			5.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import os
 | |
| 
 | |
| import cv2
 | |
| import gymnasium as gym
 | |
| import numpy as np
 | |
| import vizdoom as vzd
 | |
| 
 | |
| from tianshou.env import ShmemVectorEnv
 | |
| 
 | |
| try:
 | |
|     import envpool
 | |
| except ImportError:
 | |
|     envpool = None
 | |
| 
 | |
| 
 | |
def normal_button_comb():
    """Build the discrete action table used by non-battle scenarios.

    Every action is a 3-element list of floats: one forward-move flag followed
    by a two-flag turn pair drawn from [0, 0], [0, 1], [1, 0] (named for
    turn left/right; which physical button each slot drives comes from the
    scenario .cfg). The turn pair varies fastest, yielding 2 * 3 = 6 actions.
    """
    forward_opts = [[0.0], [1.0]]
    turn_opts = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
    # List concatenation creates a fresh list per action, so no entries share storage.
    return [fwd + turn for fwd in forward_opts for turn in turn_opts]
 | |
| 
 | |
| 
 | |
def battle_button_comb():
    """Build the discrete action table used by battle scenarios.

    Every action is an 8-element list of floats laid out as
    [move fwd/back pair, move left/right pair, turn left/right pair, attack, speed].
    Each pair is one of [0, 0], [0, 1], [1, 0]; attack and speed are single flags.
    Enumeration order (slowest to fastest varying): attack, speed,
    move left/right, move fwd/back, turn — 2 * 2 * 3 * 3 * 3 = 108 actions.
    """
    pair_opts = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
    flag_opts = [[0.0], [1.0]]
    # Iteration order differs from concatenation order on purpose: the outer
    # loops pick attack/speed, but those flags are appended last.
    return [
        move_fb + move_lr + turn_lr + attack + speed
        for attack in flag_opts
        for speed in flag_opts
        for move_lr in pair_opts
        for move_fb in pair_opts
        for turn_lr in pair_opts
    ]
 | |
| 
 | |
| 
 | |
class Env(gym.Env):
    """Gymnasium environment wrapping a ViZDoom scenario loaded from a ``.cfg`` file.

    Observations are the most recent ``res[0]`` screen frames, each resized to
    ``(res[1], res[2])``, stacked into a ``uint8`` array of shape ``res``.
    The reward is shaped from changes in game variables (health, kill count,
    AMMO2); ViZDoom's own per-action reward is ignored. A config path that
    contains ``"battle"`` selects the battle action set and also penalizes
    health loss.
    """

    def __init__(
        self,
        cfg_path: str,
        frameskip: int = 4,
        res: tuple = (4, 40, 60),
        save_lmp: bool = False,
    ) -> None:
        """Load the scenario config and start the ViZDoom game.

        :param cfg_path: path to the ViZDoom scenario config.
        :param frameskip: number of tics each chosen action is repeated for.
        :param res: observation shape ``(stack, height, width)``.
        :param save_lmp: if True, record every episode to ``lmps/episode_<n>.lmp``.
        """
        super().__init__()
        self.save_lmp = save_lmp
        self.health_setting = "battle" in cfg_path
        if save_lmp:
            os.makedirs("lmps", exist_ok=True)
        self.res = res
        self.skip = frameskip
        # NOTE(review): observations are produced as uint8 but the space is
        # declared float32. Left unchanged because vectorized wrappers presumably
        # size their buffers from this dtype and consumers may rely on it — confirm.
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=res, dtype=np.float32)
        self.game = vzd.DoomGame()
        self.game.load_config(cfg_path)
        self.game.init()
        if self.health_setting:
            self.available_actions = battle_button_comb()
        else:
            self.available_actions = normal_button_comb()
        self.action_num = len(self.available_actions)
        self.action_space = gym.spaces.Discrete(self.action_num)
        self.spec = gym.envs.registration.EnvSpec("vizdoom-v0")
        self.count = 0
        # Per-episode state, re-initialized by reset(); defined here so the
        # attributes exist even before the first reset.
        self.obs_buffer = np.zeros(self.res, dtype=np.uint8)
        self.health = 0.0
        self.killcount = 0.0
        self.ammo2 = 0.0

    def get_obs(self) -> None:
        """Push the newest screen frame onto the end of the frame stack.

        Does nothing when ViZDoom returns no state (e.g. after the episode ended).
        """
        state = self.game.get_state()
        if state is None:
            return
        frame = state.screen_buffer
        # Drop the oldest frame by shifting the stack towards index 0.
        self.obs_buffer[:-1] = self.obs_buffer[1:]
        # cv2.resize takes (width, height). Assumes a single-channel screen
        # format (e.g. GRAY8) in the scenario config — TODO confirm.
        self.obs_buffer[-1] = cv2.resize(frame, (self.res[-1], self.res[-2]))

    def reset(self, *, seed=None, options=None):
        """Start a new episode and return ``(observation, info)`` per the Gymnasium API.

        :param seed: if given, seeds ``self.np_random`` and ViZDoom's episode RNG.
        :param options: accepted for Gymnasium API compatibility; unused.
        :return: a copy of the stacked observation and an empty info dict.
        """
        super().reset(seed=seed)
        if seed is not None:
            self.game.set_seed(seed)
        if self.save_lmp:
            self.game.new_episode(f"lmps/episode_{self.count}.lmp")
        else:
            self.game.new_episode()
        self.count += 1
        self.obs_buffer = np.zeros(self.res, dtype=np.uint8)
        self.get_obs()
        # Baselines for the delta-based reward computed in step().
        self.health = self.game.get_game_variable(vzd.GameVariable.HEALTH)
        self.killcount = self.game.get_game_variable(vzd.GameVariable.KILLCOUNT)
        self.ammo2 = self.game.get_game_variable(vzd.GameVariable.AMMO2)
        # Copy so later in-place frame shifts do not mutate the returned array.
        return self.obs_buffer.copy(), {}

    def step(self, action):
        """Apply one discrete action and return the Gymnasium 5-tuple.

        Reward components, all relative to the previous step:
        health change (gains only, unless this is a battle scenario),
        +20 per additional kill, and the AMMO2 change (gains and losses).

        :param action: index into ``self.available_actions``.
        :return: ``(observation, reward, terminated, truncated, info)``.
        """
        self.game.make_action(self.available_actions[action], self.skip)
        reward = 0.0
        self.get_obs()
        health = self.game.get_game_variable(vzd.GameVariable.HEALTH)
        if self.health_setting or health > self.health:  # positive health reward only for d1/d2
            reward += health - self.health
        self.health = health
        killcount = self.game.get_game_variable(vzd.GameVariable.KILLCOUNT)
        reward += 20 * (killcount - self.killcount)
        self.killcount = killcount
        ammo2 = self.game.get_game_variable(vzd.GameVariable.AMMO2)
        reward += ammo2 - self.ammo2
        self.ammo2 = ammo2
        terminated = False
        truncated = False
        info = {}
        if self.game.is_player_dead():
            terminated = True
        elif self.game.is_episode_finished():
            # Must be checked before get_state(): ViZDoom returns no state once the
            # episode is finished, so checking get_state() first made this branch
            # unreachable. Assumes these scenarios end only by death or timeout.
            truncated = True
            info["TimeLimit.truncated"] = True
        elif self.game.get_state() is None:
            terminated = True
        return self.obs_buffer.copy(), reward, terminated, truncated, info

    def render(self):
        """Rendering is not supported; this is a no-op."""

    def close(self):
        """Shut down the underlying ViZDoom game."""
        self.game.close()


def make_vizdoom_env(task, frame_skip, res, save_lmp, seed, training_num, test_num):
    """Create a single env plus vectorized training and test envs for a ViZDoom task.

    Uses envpool when it is importable; otherwise wraps ``Env`` instances for
    ``maps/<task>.cfg`` in ``ShmemVectorEnv``.

    :param task: scenario name such as ``"D1_basic"`` (envpool id ``"D1Basic-v1"``).
    :param frame_skip: tics per action.
    :param res: ``(stack, height, width)``; only ``res[0]`` is passed to envpool.
    :param save_lmp: record test episodes as ``.lmp`` demos.
    :param seed: seed for the vectorized envs.
    :param training_num: number of training envs.
    :param test_num: requested number of test envs, capped at ``max(cpu_count - 1, 1)``.
    :return: ``(env, train_envs, test_envs)``; on the envpool path ``env`` is the
        training vector env itself.
    """
    # os.cpu_count() may return None; never let the cap itself drop to zero.
    cpu_count = os.cpu_count() or 1
    test_num = min(max(cpu_count - 1, 1), test_num)
    if envpool is not None:
        task_id = "".join(part.capitalize() for part in task.split("_")) + "-v1"
        lmp_save_dir = "lmps/" if save_lmp else ""
        # Per variable: [weight for increases, weight for decreases] — presumably
        # envpool's reward_config convention; confirm against envpool docs.
        reward_config = {
            "KILLCOUNT": [20.0, -20.0],
            "HEALTH": [1.0, 0.0],
            "AMMO2": [1.0, -1.0],
        }
        if "battle" in task:
            reward_config["HEALTH"] = [1.0, -1.0]
        env = train_envs = envpool.make_gymnasium(
            task_id,
            frame_skip=frame_skip,
            stack_num=res[0],
            seed=seed,
            num_envs=training_num,
            reward_config=reward_config,
            use_combined_action=True,
            max_episode_steps=2625,
            use_inter_area_resize=False,
        )
        test_envs = envpool.make_gymnasium(
            task_id,
            frame_skip=frame_skip,
            stack_num=res[0],
            lmp_save_dir=lmp_save_dir,
            seed=seed,
            num_envs=test_num,
            reward_config=reward_config,
            use_combined_action=True,
            max_episode_steps=2625,
            use_inter_area_resize=False,
        )
    else:
        cfg_path = f"maps/{task}.cfg"
        env = Env(cfg_path, frame_skip, res)
        train_envs = ShmemVectorEnv(
            [lambda: Env(cfg_path, frame_skip, res) for _ in range(training_num)],
        )
        test_envs = ShmemVectorEnv(
            [lambda: Env(cfg_path, frame_skip, res, save_lmp) for _ in range(test_num)],
        )
        # Gymnasium-style envs are seeded through reset(seed=...), which Env supports.
        train_envs.seed(seed)
        test_envs.seed(seed)
    return env, train_envs, test_envs


if __name__ == "__main__":
    # Smoke test; swap in "maps/D1_basic.cfg" to exercise the non-battle action set.
    env = Env("maps/D3_battle.cfg", 4, (4, 84, 84))
    print(env.available_actions)
    action_num = env.action_space.n
    obs, _ = env.reset()
    print(env.spec.reward_threshold)
    print(obs.shape, action_num)
    for _ in range(4000):
        obs, rew, terminated, truncated, info = env.step(0)
        if terminated or truncated:
            env.reset()
    print(obs.shape, rew, terminated, truncated)
    # Save the first three stacked frames as the channels of one image.
    cv2.imwrite("test.png", obs.transpose(1, 2, 0)[..., :3])
    env.close()
 |