Daniel Plop 8a0629ded6
Fix mypy issues in tests and examples (#1077)
Closes #952 

- `SamplingConfig` supports `batch_size=None`. #1077
- tests and examples are covered by `mypy`. #1077
- `NetBase` is used more widely and is now generic, allowing stricter typing. #1077
- `utils.net.common.Recurrent` now receives and returns a
`RecurrentStateBatch` instead of a dict. #1077

---------

Co-authored-by: Michael Panchenko <m.panchenko@appliedai.de>
2024-04-03 18:07:51 +02:00

205 lines
6.7 KiB
Python

import os
from collections.abc import Sequence
from typing import Any
import cv2
import gymnasium as gym
import numpy as np
import vizdoom as vzd
from numpy.typing import NDArray
from tianshou.env import ShmemVectorEnv
try:
import envpool
except ImportError:
envpool = None
def normal_button_comb() -> list:
    """Build the discrete action table used by the non-battle scenarios.

    Every action is a flat list of three button values: one value from the
    "forward" options followed by two values from the "turn" options (the
    option names come from the original variable names; the button order must
    match the ``.cfg`` file — not verified here). All 2 * 3 = 6 combinations
    are produced, with the forward option varying slowest.

    :return: a new list of fresh ``list[float]`` actions.
    """
    forward_options = ([0.0], [1.0])
    # Pressing both turn buttons at once is deliberately not offered.
    turn_options = ([0.0, 0.0], [0.0, 1.0], [1.0, 0.0])
    return [fwd + turn for fwd in forward_options for turn in turn_options]
def battle_button_comb() -> list:
    """Build the discrete action table used by the battle scenarios.

    Each action is a flat list of eight button values, concatenated in the
    order: forward/backward (2), strafe left/right (2), turn left/right (2),
    attack (1), speed (1). The group names come from the original variable
    names; the button order must match the ``.cfg`` file — not verified here.

    The enumeration order is, from slowest- to fastest-varying: attack, speed,
    strafe, forward/backward, turn. This gives 2 * 2 * 3 * 3 * 3 = 108 actions.

    :return: a new list of fresh ``list[float]`` actions.
    """
    # Every two-button group excludes pressing both buttons simultaneously.
    three_way = ([0.0, 0.0], [0.0, 1.0], [1.0, 0.0])
    on_off = ([0.0], [1.0])
    return [
        move + strafe + turn + fire + run
        for fire in on_off
        for run in on_off
        for strafe in three_way
        for move in three_way
        for turn in three_way
    ]
class Env(gym.Env):
    """Gymnasium wrapper around a single ``vizdoom.DoomGame`` instance.

    Observations are a rolling stack of the last ``res[0]`` screen frames. Each
    frame is resized to ``res[-2] x res[-1]`` and stored as ``uint8``. Actions
    index into a fixed table of button combinations (``battle_button_comb`` or
    ``normal_button_comb``). Rewards are shaped from changes in the HEALTH,
    KILLCOUNT and AMMO2 game variables.

    :param cfg_path: path to the ViZDoom ``.cfg`` file. A path containing
        ``"battle"`` selects the battle action table and also penalizes health
        loss.
    :param frameskip: repeat count passed to ``DoomGame.make_action`` on every
        step.
    :param res: observation shape ``(stack, height, width)``.
    :param save_lmp: if True, every episode is recorded to
        ``lmps/episode_<n>.lmp``.
    """

    def __init__(
        self,
        cfg_path: str,
        frameskip: int = 4,
        res: Sequence[int] = (4, 40, 60),
        save_lmp: bool = False,
    ) -> None:
        super().__init__()
        self.save_lmp = save_lmp
        # Battle maps use the larger action table and penalize health loss too.
        self.health_setting = "battle" in cfg_path
        if save_lmp:
            os.makedirs("lmps", exist_ok=True)
        self.res = res
        self.skip = frameskip
        # NOTE(review): frames are stored as uint8 (see reset) but the space is
        # declared float32. Vector envs presumably size their shared buffers
        # from this space, so confirm downstream expectations before changing it.
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=res, dtype=np.float32)
        self.game = vzd.DoomGame()
        self.game.load_config(cfg_path)
        self.game.init()
        if self.health_setting:
            self.available_actions = battle_button_comb()
        else:
            self.available_actions = normal_button_comb()
        self.action_num = len(self.available_actions)
        self.action_space = gym.spaces.Discrete(self.action_num)
        self.spec = gym.envs.registration.EnvSpec("vizdoom-v0")
        # Number of episodes started so far; also names the recorded .lmp files.
        self.count = 0

    def get_obs(self) -> None:
        """Push the current screen frame onto the observation stack.

        Shifts ``obs_buffer`` one frame towards index 0 and writes the resized
        current screen into the last slot. Does nothing when ViZDoom reports no
        state (presumably once the episode has finished), so the previous stack
        is kept unchanged.
        """
        state = self.game.get_state()
        if state is None:
            return
        obs = state.screen_buffer
        self.obs_buffer[:-1] = self.obs_buffer[1:]
        # cv2.resize takes the target size as (width, height).
        # NOTE(review): writing into a 2-D slot assumes a single-channel screen
        # format in the .cfg (e.g. GRAY8) — confirm.
        self.obs_buffer[-1] = cv2.resize(obs, (self.res[-1], self.res[-2]))

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[NDArray[np.uint8], dict[str, Any]]:
        """Start a new episode and return the initial observation stack.

        :param seed: forwarded to ``gym.Env.reset`` to seed ``self.np_random``.
            NOTE(review): the ViZDoom engine itself is not reseeded here.
        :param options: unused.
        :return: ``(obs, info)``. ``obs`` has shape ``res``; every slot except
            the newest frame is zero-filled.
        """
        # Honour the gymnasium reset contract: seed self.np_random when requested.
        super().reset(seed=seed)
        if self.save_lmp:
            self.game.new_episode(f"lmps/episode_{self.count}.lmp")
        else:
            self.game.new_episode()
        self.count += 1
        self.obs_buffer = np.zeros(self.res, dtype=np.uint8)
        self.get_obs()
        # Baselines for the difference-based reward shaping done in step().
        self.health = self.game.get_game_variable(vzd.GameVariable.HEALTH)
        self.killcount = self.game.get_game_variable(vzd.GameVariable.KILLCOUNT)
        self.ammo2 = self.game.get_game_variable(vzd.GameVariable.AMMO2)
        return self.obs_buffer, {"TimeLimit.truncated": False}

    def step(self, action: int) -> tuple[NDArray[np.uint8], float, bool, bool, dict[str, Any]]:
        """Apply one action and return the gymnasium ``(obs, reward, terminated, truncated, info)``.

        The reward is the sum of three terms:

        * the change in HEALTH. Only gains count, except on battle maps
          (``health_setting``), where losses are penalized too.
        * 20 times the change in KILLCOUNT.
        * the change in AMMO2.

        ``terminated`` is True when the player died (or when ViZDoom
        unexpectedly reports no state). ``truncated`` is True when the episode
        ended for any other reason, such as the episode timeout.
        NOTE(review): any non-death episode end counts as truncation, including
        a map exit if the scenario has one.

        :param action: index into ``self.available_actions``.
        """
        self.game.make_action(self.available_actions[action], self.skip)
        reward = 0.0
        self.get_obs()

        health = self.game.get_game_variable(vzd.GameVariable.HEALTH)
        # Non-battle maps only reward health gains; battle maps also penalize loss.
        if self.health_setting or health > self.health:
            reward += health - self.health
        self.health = health

        killcount = self.game.get_game_variable(vzd.GameVariable.KILLCOUNT)
        reward += 20 * (killcount - self.killcount)
        self.killcount = killcount

        # Ammo changes are rewarded in both directions (same as the envpool
        # AMMO2 reward config in make_vizdoom_env).
        ammo2 = self.game.get_game_variable(vzd.GameVariable.AMMO2)
        reward += ammo2 - self.ammo2
        self.ammo2 = ammo2

        # Check death first, then episode end. ViZDoom yields no state once an
        # episode is finished. Testing for a missing state before
        # is_episode_finished() would hide every time-limit ending and report it
        # as a termination.
        terminated = False
        truncated = False
        if self.game.is_player_dead():
            terminated = True
        elif self.game.is_episode_finished():
            truncated = True
        elif self.game.get_state() is None:
            terminated = True
        info: dict[str, Any] = {"TimeLimit.truncated": truncated}
        return self.obs_buffer, reward, terminated, truncated, info

    def render(self) -> None:
        """No-op: rendering is controlled by the ViZDoom config file."""

    def close(self) -> None:
        """Shut down the underlying ViZDoom game instance."""
        self.game.close()
def make_vizdoom_env(
    task: str,
    frame_skip: int,
    res: Sequence[int],
    save_lmp: bool = False,
    seed: int | None = None,
    training_num: int = 10,
    test_num: int = 10,
) -> tuple[Env, ShmemVectorEnv, ShmemVectorEnv]:
    """Create a single env plus training and test vector envs for ``task``.

    Uses envpool when it is importable. Otherwise falls back to the local
    ``Env`` wrapped in ``ShmemVectorEnv``.

    :param task: scenario name such as ``"D1_basic"``. The fallback loads
        ``maps/<task>.cfg``; the envpool path converts the name to a task id
        such as ``"D1Basic-v1"``.
    :param frame_skip: action repeat passed to each environment.
    :param res: observation shape ``(stack, height, width)``.
        NOTE(review): the envpool path only uses ``res[0]`` (the stack size);
        the frame size is left at envpool's default.
    :param save_lmp: record test episodes as ``.lmp`` files.
    :param seed: seed for the created envs.
    :param training_num: number of training envs.
    :param test_num: number of test envs. When the CPU count is known, this is
        capped at ``cpu_count - 1``, but the cap itself is never below 1.
    :return: ``(env, train_envs, test_envs)``. With envpool, ``env`` is the
        training env pool itself.
    """
    cpu_count = os.cpu_count()
    if cpu_count is not None:
        # Leave a core free, but don't shrink to zero test envs on a 1-CPU box.
        test_num = min(max(cpu_count - 1, 1), test_num)
    if envpool is not None:
        task_id = "".join([i.capitalize() for i in task.split("_")]) + "-v1"
        lmp_save_dir = "lmps/" if save_lmp else ""
        # Mirrors the shaping in Env.step: [reward on increase, reward on decrease]
        # per game variable (presumably envpool's convention — confirm).
        reward_config = {
            "KILLCOUNT": [20.0, -20.0],
            "HEALTH": [1.0, 0.0],
            "AMMO2": [1.0, -1.0],
        }
        if "battle" in task:
            reward_config["HEALTH"] = [1.0, -1.0]
        env = train_envs = envpool.make_gymnasium(
            task_id,
            frame_skip=frame_skip,
            stack_num=res[0],
            seed=seed,
            num_envs=training_num,
            reward_config=reward_config,
            use_combined_action=True,
            max_episode_steps=2625,
            use_inter_area_resize=False,
        )
        test_envs = envpool.make_gymnasium(
            task_id,
            frame_skip=frame_skip,
            stack_num=res[0],
            lmp_save_dir=lmp_save_dir,
            seed=seed,
            num_envs=test_num,
            reward_config=reward_config,
            use_combined_action=True,
            max_episode_steps=2625,
            use_inter_area_resize=False,
        )
    else:
        cfg_path = f"maps/{task}.cfg"
        env = Env(cfg_path, frame_skip, res)
        # The lambdas capture only loop-invariant names, so late binding is safe.
        train_envs = ShmemVectorEnv(
            [lambda: Env(cfg_path, frame_skip, res) for _ in range(training_num)],
        )
        test_envs = ShmemVectorEnv(
            [lambda: Env(cfg_path, frame_skip, res, save_lmp) for _ in range(test_num)],
        )
        # envpool envs are seeded through their constructor; only the local
        # vector envs need an explicit seed call.
        train_envs.seed(seed)
        test_envs.seed(seed)
    return env, train_envs, test_envs
if __name__ == "__main__":
    # Smoke test: run the battle scenario with a fixed action and dump the last
    # observation to test.png. Use "maps/D1_basic.cfg" to try the basic map.
    env = Env("maps/D3_battle.cfg", 4, (4, 84, 84))
    try:
        print(env.available_actions)
        assert isinstance(env.action_space, gym.spaces.Discrete)
        action_num = env.action_space.n
        obs, _ = env.reset()
        if env.spec:
            print(env.spec.reward_threshold)
        print(obs.shape, action_num)
        for _ in range(4000):
            obs, rew, terminated, truncated, info = env.step(0)
            if terminated or truncated:
                env.reset()
        print(obs.shape, rew, terminated, truncated)
        # Write the first three of the four stacked frames as the channels of one image.
        cv2.imwrite("test.png", obs.transpose(1, 2, 0)[..., :3])
    finally:
        # Always shut down the ViZDoom game instance, even if the smoke test fails.
        env.close()