Closes #952 - `SamplingConfig` supports `batch_size=None`. #1077 - tests and examples are covered by `mypy`. #1077 - `NetBase` is more used, stricter typing by making it generic. #1077 - `utils.net.common.Recurrent` now receives and returns a `RecurrentStateBatch` instead of a dict. #1077 --------- Co-authored-by: Michael Panchenko <m.panchenko@appliedai.de>
205 lines
6.7 KiB
Python
205 lines
6.7 KiB
Python
import os
|
|
from collections.abc import Sequence
|
|
from typing import Any
|
|
|
|
import cv2
|
|
import gymnasium as gym
|
|
import numpy as np
|
|
import vizdoom as vzd
|
|
from numpy.typing import NDArray
|
|
|
|
from tianshou.env import ShmemVectorEnv
|
|
|
|
try:
|
|
import envpool
|
|
except ImportError:
|
|
envpool = None
|
|
|
|
|
|
def normal_button_comb() -> list:
    """Build the discrete action table used for the non-battle scenarios.

    Every action is a flat list of three floats: the one-element forward part
    followed by the two-element left/right-turn part. The forward part varies
    slowest, giving 2 * 3 = 6 actions in total.
    """
    forward_states = [[0.0], [1.0]]
    turn_states = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
    # List concatenation creates a fresh list per action, so entries never alias.
    return [forward + turn for forward in forward_states for turn in turn_states]
|
|
|
|
|
|
def battle_button_comb() -> list:
    """Build the discrete action table used for the battle scenarios.

    Every action is a flat list of eight floats laid out as
    ``forward/backward (2) + move left/right (2) + turn left/right (2) +
    attack (1) + speed (1)``. All 3 * 3 * 3 * 2 * 2 = 108 combinations are
    produced; the turn part varies fastest and the attack part slowest.
    """
    # The three two-button groups share the same states: none, second, first.
    pair_states = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
    toggle_states = [[0.0], [1.0]]

    # The order of the ``for`` clauses fixes the index of each action, so it
    # must stay: attack, speed, strafe, forward/backward, turn.
    return [
        move_fb + move_lr + turn + attack + speed
        for attack in toggle_states
        for speed in toggle_states
        for move_lr in pair_states
        for move_fb in pair_states
        for turn in pair_states
    ]
|
|
|
|
|
|
class Env(gym.Env):
    """Gymnasium environment around a ViZDoom game with a stacked-frame observation.

    The observation is a buffer of shape ``res`` that holds the most recent
    resized screen buffers (oldest first). The action space is discrete and
    indexes into a table of button combinations chosen from the config path.

    :param cfg_path: path to the ViZDoom ``.cfg`` file; if it contains the
        substring ``"battle"`` the battle action table and the two-sided
        health reward are used.
    :param frameskip: number of game tics each action is repeated for.
    :param res: observation shape; the last two entries are used as the
        resize target (height, width) and the whole tuple as the buffer shape.
    :param save_lmp: if True, every episode is recorded to
        ``lmps/episode_<n>.lmp``.
    """

    def __init__(
        self,
        cfg_path: str,
        frameskip: int = 4,
        res: Sequence[int] = (4, 40, 60),
        save_lmp: bool = False,
    ) -> None:
        super().__init__()
        self.save_lmp = save_lmp
        # Battle configs reward health changes in both directions (see `step`).
        self.health_setting = "battle" in cfg_path
        if save_lmp:
            os.makedirs("lmps", exist_ok=True)
        self.res = res
        self.skip = frameskip
        # NOTE(review): the space is declared float32 while the observations
        # returned by `reset`/`step` are uint8 — confirm this is intended.
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=res, dtype=np.float32)
        self.game = vzd.DoomGame()
        self.game.load_config(cfg_path)
        self.game.init()
        if self.health_setting:
            self.available_actions = battle_button_comb()
        else:
            self.available_actions = normal_button_comb()
        self.action_num = len(self.available_actions)
        self.action_space = gym.spaces.Discrete(self.action_num)
        self.spec = gym.envs.registration.EnvSpec("vizdoom-v0")
        # Number of episodes started so far; used to name recorded .lmp files.
        self.count = 0

    def get_obs(self) -> None:
        """Push the current screen buffer into ``self.obs_buffer`` in place.

        The buffer is shifted by one slot along its first axis and the resized
        frame is written to the last slot. If the game reports no state, the
        buffer is left unchanged.
        """
        state = self.game.get_state()
        if state is None:
            return
        obs = state.screen_buffer
        self.obs_buffer[:-1] = self.obs_buffer[1:]
        # cv2.resize expects the target size as (width, height).
        self.obs_buffer[-1] = cv2.resize(obs, (self.res[-1], self.res[-2]))

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[NDArray[np.uint8], dict[str, Any]]:
        """Start a new episode and return the initial observation and info dict.

        :param seed: if given, seeds both the Gymnasium RNG of this env and the
            ViZDoom game before the new episode is started.
        :param options: accepted for API compatibility; not used.
        :return: the zero-initialised frame buffer with the first frame pushed
            in, and ``{"TimeLimit.truncated": False}``.
        """
        # Honour the seed instead of silently ignoring it: Gymnasium's base
        # class seeds `self.np_random`, and ViZDoom uses its own seed for the
        # episodes that follow.
        super().reset(seed=seed)
        if seed is not None:
            self.game.set_seed(seed)

        if self.save_lmp:
            self.game.new_episode(f"lmps/episode_{self.count}.lmp")
        else:
            self.game.new_episode()
        self.count += 1
        self.obs_buffer = np.zeros(self.res, dtype=np.uint8)
        self.get_obs()
        # Baselines against which `step` computes reward deltas.
        self.health = self.game.get_game_variable(vzd.GameVariable.HEALTH)
        self.killcount = self.game.get_game_variable(vzd.GameVariable.KILLCOUNT)
        self.ammo2 = self.game.get_game_variable(vzd.GameVariable.AMMO2)
        return self.obs_buffer, {"TimeLimit.truncated": False}

    def step(self, action: int) -> tuple[NDArray[np.uint8], float, bool, bool, dict[str, Any]]:
        """Apply the action with the given index and return the transition.

        The reward is the sum of the health delta (negative deltas only count
        for battle configs), 20 times the kill-count delta and the AMMO2 delta.

        :param action: index into ``self.available_actions``.
        :return: ``(obs, reward, terminated, truncated, info)``.
        """
        self.game.make_action(self.available_actions[action], self.skip)
        reward = 0.0
        self.get_obs()

        health = self.game.get_game_variable(vzd.GameVariable.HEALTH)
        if self.health_setting or health > self.health:  # positive health reward only for d1/d2
            reward += health - self.health
        self.health = health

        killcount = self.game.get_game_variable(vzd.GameVariable.KILLCOUNT)
        reward += 20 * (killcount - self.killcount)
        self.killcount = killcount

        # The ammo delta counts in both directions, so spending ammo is penalised.
        ammo2 = self.game.get_game_variable(vzd.GameVariable.AMMO2)
        reward += ammo2 - self.ammo2
        self.ammo2 = ammo2

        done = False
        info: dict[str, Any] = {}
        if self.game.is_player_dead() or self.game.get_state() is None:
            done = True
        elif self.game.is_episode_finished():
            # NOTE(review): if get_state() is already None whenever the episode
            # has finished, this branch is never reached — confirm against the
            # ViZDoom API before relying on the truncation flag.
            done = True
            info["TimeLimit.truncated"] = True
        return self.obs_buffer, reward, done, info.get("TimeLimit.truncated", False), info

    def render(self) -> None:
        """Do nothing; rendering is not implemented for this environment."""

    def close(self) -> None:
        """Close the underlying ViZDoom game."""
        self.game.close()
|
|
|
|
|
|
def make_vizdoom_env(
    task: str,
    frame_skip: int,
    res: tuple[int, ...],
    save_lmp: bool = False,
    seed: int | None = None,
    training_num: int = 10,
    test_num: int = 10,
) -> tuple[Env, ShmemVectorEnv, ShmemVectorEnv]:
    """Create a single env plus vectorized training and test envs for a ViZDoom task.

    When ``envpool`` is importable, its Gymnasium envs are used (the single env
    and the training envs are then the same object); otherwise the local
    :class:`Env` is wrapped in :class:`ShmemVectorEnv` and seeded.

    :param task: task name, e.g. ``"D3_battle"``; used to derive the envpool
        task id and the ``maps/<task>.cfg`` config path.
    :param frame_skip: number of tics each action is repeated for.
    :param res: observation shape; ``res[0]`` is the number of stacked frames.
    :param save_lmp: whether the test envs record episodes under ``lmps/``.
    :param seed: seed handed to envpool or to the vector envs.
    :param training_num: number of training environments.
    :param test_num: requested number of test environments; capped by the
        number of CPUs minus one, but never capped below one.
    :return: ``(env, train_envs, test_envs)``.
    """
    cpu_count = os.cpu_count()
    if cpu_count is not None:
        # Keep one CPU free, but never let the cap drop to zero: on a
        # single-CPU host `cpu_count - 1` would otherwise request no test envs.
        test_num = min(max(cpu_count - 1, 1), test_num)

    if envpool is not None:
        # e.g. "d3_battle" -> "D3Battle-v1"
        task_id = "".join([i.capitalize() for i in task.split("_")]) + "-v1"
        lmp_save_dir = "lmps/" if save_lmp else ""
        reward_config = {
            "KILLCOUNT": [20.0, -20.0],
            "HEALTH": [1.0, 0.0],
            "AMMO2": [1.0, -1.0],
        }
        if "battle" in task:
            # Battle tasks also use a -1.0 factor for health, like the local Env.
            reward_config["HEALTH"] = [1.0, -1.0]
        # NOTE(review): in this branch the returned objects are envpool envs,
        # not Env / ShmemVectorEnv as the return annotation states.
        env = train_envs = envpool.make_gymnasium(
            task_id,
            frame_skip=frame_skip,
            stack_num=res[0],
            seed=seed,
            num_envs=training_num,
            reward_config=reward_config,
            use_combined_action=True,
            max_episode_steps=2625,
            use_inter_area_resize=False,
        )
        # Only the test envs receive `lmp_save_dir`.
        test_envs = envpool.make_gymnasium(
            task_id,
            frame_skip=frame_skip,
            stack_num=res[0],
            lmp_save_dir=lmp_save_dir,
            seed=seed,
            num_envs=test_num,
            reward_config=reward_config,
            use_combined_action=True,
            max_episode_steps=2625,
            use_inter_area_resize=False,
        )
    else:
        cfg_path = f"maps/{task}.cfg"
        env = Env(cfg_path, frame_skip, res)
        train_envs = ShmemVectorEnv(
            [lambda: Env(cfg_path, frame_skip, res) for _ in range(training_num)],
        )
        test_envs = ShmemVectorEnv(
            [lambda: Env(cfg_path, frame_skip, res, save_lmp) for _ in range(test_num)],
        )
        train_envs.seed(seed)
        test_envs.seed(seed)
    return env, train_envs, test_envs
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: run the battle scenario with a fixed action, print
    # each transition and dump the latest stacked frames to test.png.
    # To try the basic scenario instead, use Env("maps/D1_basic.cfg", 4, (4, 84, 84)).
    env = Env("maps/D3_battle.cfg", 4, (4, 84, 84))
    try:
        print(env.available_actions)
        assert isinstance(env.action_space, gym.spaces.Discrete)
        action_num = env.action_space.n
        obs, _ = env.reset()
        if env.spec:
            print(env.spec.reward_threshold)
        print(obs.shape, action_num)
        for _ in range(4000):
            obs, rew, terminated, truncated, _ = env.step(0)
            if terminated or truncated:
                env.reset()
            print(obs.shape, rew, terminated, truncated)
            # Move the frame axis last and keep three frames as image channels.
            cv2.imwrite("test.png", obs.transpose(1, 2, 0)[..., :3])
    finally:
        # Always shut the Doom game down, even if the loop raises.
        env.close()
|