import logging
import threading

import numpy as np
import gym


class MinecraftBase(gym.Env):
    """Gym wrapper around a MineRL environment with a named, discrete action set.

    Each discrete action index selects one of the named action dicts passed to
    ``__init__``; keys missing from an entry are filled from MineRL's no-op
    action. Optional sticky attack, sticky jump and a camera pitch limit are
    applied on top of the chosen action. Raw MineRL observations are flattened
    into a fixed dict ("image", "inventory", "inventory_max", "equipped",
    scaled life stats, episode flags and ``log_*`` diagnostics).
    """

    # Class attribute, so shared by every instance in the process: env
    # construction and reset are serialized through it (presumably because
    # launching/resetting MineRL concurrently is unsafe — confirm).
    _LOCK = threading.Lock()

    def __init__(
        self,
        actions,
        repeat=1,
        size=(64, 64),
        break_speed=100.0,
        gamma=10.0,
        sticky_attack=30,
        sticky_jump=10,
        pitch_limit=(-60, 60),
        logs=True,
    ):
        """Create the underlying MineRL env and the discrete action set.

        Args:
            actions: Mapping from action name to a (partial) MineRL action
                dict. Missing keys are filled from
                ``minecraft_minerl.NOOP_ACTION``; the mapping's iteration
                order defines the discrete action indices.
            repeat: Number of underlying env steps per ``step`` call.
            size: Image size tuple; the "image" space has shape
                ``size + (3,)``. Also forwarded to ``MineRLEnv``.
            break_speed: Forwarded to ``MineRLEnv``. Any value other than 1.0
                disables sticky attack.
            gamma: Forwarded to ``MineRLEnv``.
            sticky_attack: Number of steps to keep attacking after an attack
                action (0 disables).
            sticky_jump: Number of steps to keep jumping and moving forward
                after a jump action (0 disables).
            pitch_limit: ``(lo, hi)`` bounds on the accumulated camera pitch;
                a falsy value disables the limit.
            logs: If True, configure root logging at DEBUG level.
        """
        if logs:
            logging.basicConfig(level=logging.DEBUG)
        self._repeat = repeat
        self._size = size
        # NOTE(review): sticky attack is turned off whenever block breaking is
        # sped up — presumably because holding attack is then unnecessary.
        if break_speed != 1.0:
            sticky_attack = 0

        # Make env
        with self._LOCK:
            from . import minecraft_minerl

            self._env = minecraft_minerl.MineRLEnv(size, break_speed, gamma).make()
        self._inventory = {}

        # Observations. "inventory/log2" is excluded because ``_obs`` folds it
        # into "inventory/log".
        self._inv_keys = [
            k
            for k in self._flatten(self._env.observation_space.spaces)
            if k.startswith("inventory/")
            if k != "inventory/log2"
        ]
        self._step = 0
        self._max_inventory = None
        self._equip_enum = self._env.observation_space["equipped_items"]["mainhand"][
            "type"
        ].values.tolist()

        # Actions
        self._noop_action = minecraft_minerl.NOOP_ACTION
        actions = self._insert_defaults(actions)
        self._action_names = tuple(actions.keys())
        self._action_values = tuple(actions.values())
        message = f"Minecraft action space ({len(self._action_values)}):"
        print(message, ", ".join(self._action_names))
        self._sticky_attack_length = sticky_attack
        self._sticky_attack_counter = 0
        self._sticky_jump_length = sticky_jump
        self._sticky_jump_counter = 0
        self._pitch_limit = pitch_limit
        self._pitch = 0

    @property
    def observation_space(self):
        """Dict space describing the observations produced by ``_obs``.

        A new ``gym.spaces.Dict`` is built on every access.
        """
        return gym.spaces.Dict(
            {
                "image": gym.spaces.Box(0, 255, self._size + (3,), np.uint8),
                "inventory": gym.spaces.Box(
                    -np.inf, np.inf, (len(self._inv_keys),), dtype=np.float32
                ),
                "inventory_max": gym.spaces.Box(
                    -np.inf, np.inf, (len(self._inv_keys),), dtype=np.float32
                ),
                "equipped": gym.spaces.Box(
                    -np.inf, np.inf, (len(self._equip_enum),), dtype=np.float32
                ),
                "reward": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32),
                "health": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32),
                "hunger": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32),
                "breath": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32),
                "is_first": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8),
                "is_last": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8),
                "is_terminal": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8),
                **{
                    f"log_{k}": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.int64)
                    for k in self._inv_keys
                },
                "log_player_pos": gym.spaces.Box(
                    -np.inf, np.inf, (3,), dtype=np.float32
                ),
            }
        )

    @property
    def action_space(self):
        """Discrete space with one index per named action.

        The extra ``discrete = True`` attribute is set on the space
        (presumably read by agent code to detect discrete control — confirm).
        """
        space = gym.spaces.discrete.Discrete(len(self._action_values))
        space.discrete = True
        return space

    def step(self, action):
        """Execute discrete action ``action``; return ``(obs, reward, done, info)``.

        ``action`` is an integer index into the action set (Python ``int`` or
        NumPy integer scalar). The selected action is executed once, followed
        by ``repeat - 1`` steps that keep only its attack/movement keys. The
        rewards of all executed sub-steps are summed. If the underlying env
        reports an ``"error"`` in ``info``, the episode ends immediately.
        """
        # Copy the template: ``_action`` mutates its argument (sticky
        # attack/jump, pitch clamp) and must not alter the shared action set.
        action = self._action(dict(self._action_values[int(action)]))
        following = self._noop_action.copy()
        for key in ("attack", "forward", "back", "left", "right"):
            following[key] = action[key]
        total_reward = 0.0
        for act in [action] + ([following] * (self._repeat - 1)):
            obs, reward, done, info = self._env.step(act)
            total_reward += reward
            if "error" in info:
                done = True
                break
        # Set the episode flags after the loop so they are also present (with
        # ``is_last`` True) when the loop exits early on an error.
        obs["is_first"] = False
        obs["is_last"] = bool(done)
        obs["is_terminal"] = bool(info.get("is_terminal", done))

        obs = self._obs(obs)
        self._step += 1
        assert "pov" not in obs, list(obs.keys())
        return obs, total_reward, done, info

    @property
    def inventory(self):
        """Inventory counts from the latest observation, keyed by item name."""
        return self._inventory

    def reset(self):
        """Reset the underlying env and return the first observation."""
        # Inventory tracking is rebuilt from the first observation in ``_obs``.
        self._inventory = {}
        self._max_inventory = None

        with self._LOCK:
            obs = self._env.reset()
        obs["is_first"] = True
        obs["is_last"] = False
        obs["is_terminal"] = False
        obs = self._obs(obs)

        self._step = 0
        self._sticky_attack_counter = 0
        self._sticky_jump_counter = 0
        self._pitch = 0
        return obs

    def _obs(self, obs):
        """Convert a raw (nested) MineRL observation into this env's obs dict.

        ``obs`` must already carry the "is_first", "is_last" and "is_terminal"
        flags. As a side effect, updates ``self._inventory`` and the running
        per-item maximum ``self._max_inventory``.
        """
        obs = self._flatten(obs)
        # "inventory/log2" is folded into "inventory/log" (and excluded from
        # ``self._inv_keys``).
        obs["inventory/log"] += obs.pop("inventory/log2")
        self._inventory = {
            k.split("/", 1)[1]: obs[k] for k in self._inv_keys if k != "inventory/air"
        }
        inventory = np.array([obs[k] for k in self._inv_keys], np.float32)
        if self._max_inventory is None:
            self._max_inventory = inventory
        else:
            self._max_inventory = np.maximum(self._max_inventory, inventory)
        # One-hot encoding of the item type held in the main hand.
        index = self._equip_enum.index(obs["equipped_items/mainhand/type"])
        equipped = np.zeros(len(self._equip_enum), np.float32)
        equipped[index] = 1.0
        player_x = obs["location_stats/xpos"]
        player_y = obs["location_stats/ypos"]
        player_z = obs["location_stats/zpos"]
        obs = {
            "image": obs["pov"],
            "inventory": inventory,
            "inventory_max": self._max_inventory.copy(),
            "equipped": equipped,
            # Divided by 20, 20 and 300 (presumably the Minecraft maxima for
            # life, food and air — confirm).
            "health": np.float32(obs["life_stats/life"] / 20),
            "hunger": np.float32(obs["life_stats/food"] / 20),
            "breath": np.float32(obs["life_stats/air"] / 300),
            "reward": 0.0,
            "is_first": obs["is_first"],
            "is_last": obs["is_last"],
            "is_terminal": obs["is_terminal"],
            **{f"log_{k}": np.int64(obs[k]) for k in self._inv_keys},
            "log_player_pos": np.array([player_x, player_y, player_z], np.float32),
        }
        # Check that every emitted key is declared in the observation space.
        # The space is built once here (the property rebuilds it per access).
        # A full ``value in space`` test is not used: 0-d scalars such as
        # "reward" would not match their (1,)-shaped Boxes.
        spaces = self.observation_space.spaces
        undeclared = [key for key in obs if key not in spaces]
        assert not undeclared, undeclared
        return obs

    def _action(self, action):
        """Apply sticky attack, sticky jump and the pitch limit to ``action``.

        Mutates and returns ``action``. Also advances the sticky counters and
        the accumulated pitch ``self._pitch``.
        """
        if self._sticky_attack_length:
            if action["attack"]:
                self._sticky_attack_counter = self._sticky_attack_length
            if self._sticky_attack_counter > 0:
                action["attack"] = 1
                action["jump"] = 0
                self._sticky_attack_counter -= 1
        if self._sticky_jump_length:
            if action["jump"]:
                self._sticky_jump_counter = self._sticky_jump_length
            if self._sticky_jump_counter > 0:
                action["jump"] = 1
                action["forward"] = 1
                self._sticky_jump_counter -= 1
        # camera[0] is tracked as pitch: a move that would leave [lo, hi] has
        # its pitch component dropped (the yaw component is kept).
        if self._pitch_limit and action["camera"][0]:
            lo, hi = self._pitch_limit
            if not (lo <= self._pitch + action["camera"][0] <= hi):
                action["camera"] = (0, action["camera"][1])
            self._pitch += action["camera"][0]
        return action

    def _insert_defaults(self, actions):
        """Return shallow copies of ``actions`` with missing keys taken from the no-op action."""
        actions = {name: action.copy() for name, action in actions.items()}
        for key, default in self._noop_action.items():
            for action in actions.values():
                if key not in action:
                    action[key] = default
        return actions

    def _flatten(self, nest, prefix=None):
        """Flatten nested dicts / ``gym.spaces.Dict`` into ``"a/b/c"``-keyed entries."""
        result = {}
        for key, value in nest.items():
            key = prefix + "/" + key if prefix else key
            if isinstance(value, gym.spaces.Dict):
                value = value.spaces
            if isinstance(value, dict):
                result.update(self._flatten(value, key))
            else:
                result[key] = value
        return result

    def _unflatten(self, flat):
        """Inverse of ``_flatten``: rebuild nested dicts from ``"a/b/c"`` keys."""
        result = {}
        for key, value in flat.items():
            parts = key.split("/")
            node = result
            for part in parts[:-1]:
                if part not in node:
                    node[part] = {}
                node = node[part]
            node[parts[-1]] = value
        return result