# dreamerv3-torch/envs/minecraft_base.py
# 2023-07-23 22:02:06 +09:00
#
# 236 lines
# 8.3 KiB
# Python

import logging
import threading
import numpy as np
import gym
class MinecraftBase(gym.Env):
    """Gym environment wrapping a MineRL-backed Minecraft env.

    The wrapped env is built by ``minecraft_minerl.MineRLEnv(...).make()``.
    This class exposes a flat discrete action set (one entry per named action
    in ``actions``) and converts the nested MineRL observation into a flat
    dict of numpy values (see ``observation_space``).

    Args:
        actions: Mapping from action name to a (possibly partial) MineRL
            action dict. Keys missing from an entry are filled in from
            ``minecraft_minerl.NOOP_ACTION``.
        repeat: Number of wrapped-env steps per ``step`` call. The first uses
            the full action; the rest keep only attack and movement keys.
        size: Spatial size of the image observation; also forwarded to
            ``MineRLEnv``. ``size + (3,)`` is the image shape.
        break_speed: Forwarded to ``MineRLEnv``. Any value other than ``1.0``
            disables sticky attack.
        gamma: Forwarded to ``MineRLEnv`` (meaning defined there).
        sticky_attack: Number of consecutive steps attack is held once it has
            been triggered. ``0`` disables the behaviour.
        sticky_jump: Number of consecutive steps jump (and forward) is held
            once jump has been triggered. ``0`` disables the behaviour.
        pitch_limit: ``(lo, hi)`` bounds on the accumulated ``camera[0]``
            value; a falsy value disables the limit.
        logs: If true, call ``logging.basicConfig(level=logging.DEBUG)``.
    """

    # Shared by all instances; held while constructing and resetting the
    # wrapped environment.
    _LOCK = threading.Lock()

    def __init__(
        self,
        actions,
        repeat=1,
        size=(64, 64),
        break_speed=100.0,
        gamma=10.0,
        sticky_attack=30,
        sticky_jump=10,
        pitch_limit=(-60, 60),
        logs=True,
    ):
        if logs:
            logging.basicConfig(level=logging.DEBUG)
        self._repeat = repeat
        self._size = size
        if break_speed != 1.0:
            # NOTE(review): presumably holding attack is unnecessary when
            # blocks break faster than vanilla speed -- confirm.
            sticky_attack = 0

        # Make env
        with self._LOCK:
            from . import minecraft_minerl

            self._env = minecraft_minerl.MineRLEnv(size, break_speed, gamma).make()
        self._inventory = {}

        # Observations
        # "inventory/log2" is excluded because _obs folds it into
        # "inventory/log".
        self._inv_keys = [
            k
            for k in self._flatten(self._env.observation_space.spaces)
            if k.startswith("inventory/")
            if k != "inventory/log2"
        ]
        self._step = 0
        self._max_inventory = None
        self._equip_enum = self._env.observation_space["equipped_items"]["mainhand"][
            "type"
        ].values.tolist()

        # Actions
        self._noop_action = minecraft_minerl.NOOP_ACTION
        actions = self._insert_defaults(actions)
        self._action_names = tuple(actions.keys())
        self._action_values = tuple(actions.values())
        message = f"Minecraft action space ({len(self._action_values)}):"
        print(message, ", ".join(self._action_names))
        self._sticky_attack_length = sticky_attack
        self._sticky_attack_counter = 0
        self._sticky_jump_length = sticky_jump
        self._sticky_jump_counter = 0
        self._pitch_limit = pitch_limit
        self._pitch = 0

    @property
    def observation_space(self):
        """Return the ``gym.spaces.Dict`` describing the keys built by ``_obs``.

        The space is rebuilt on every access.
        """

        def unbounded(shape, dtype):
            return gym.spaces.Box(-np.inf, np.inf, shape, dtype=dtype)

        n_inv = len(self._inv_keys)
        return gym.spaces.Dict(
            {
                # tuple() lets a list-valued `size` work as well.
                "image": gym.spaces.Box(0, 255, tuple(self._size) + (3,), np.uint8),
                "inventory": unbounded((n_inv,), np.float32),
                "inventory_max": unbounded((n_inv,), np.float32),
                "equipped": unbounded((len(self._equip_enum),), np.float32),
                "reward": unbounded((1,), np.float32),
                "health": unbounded((1,), np.float32),
                "hunger": unbounded((1,), np.float32),
                "breath": unbounded((1,), np.float32),
                "is_first": unbounded((1,), np.uint8),
                "is_last": unbounded((1,), np.uint8),
                "is_terminal": unbounded((1,), np.uint8),
                **{f"log_{k}": unbounded((1,), np.int64) for k in self._inv_keys},
                "log_player_pos": unbounded((3,), np.float32),
            }
        )

    @property
    def action_space(self):
        """Return a ``Discrete`` space with one entry per named action."""
        space = gym.spaces.discrete.Discrete(len(self._action_values))
        # Extra attribute set on the space object for downstream consumers.
        space.discrete = True
        return space

    def step(self, action):
        """Apply the discrete action ``action`` and return the transition.

        Args:
            action: Integer index into the configured action set. Anything
                usable as a tuple index works (Python int, numpy integer).

        Returns:
            ``(obs, reward, done, info)`` where ``obs`` is the flat dict from
            ``_obs`` and ``reward``/``info`` come from the last wrapped-env
            step that was executed (rewards are not summed over repeats).
        """
        # _action returns a modified *copy*, so the stored action dicts in
        # self._action_values are never altered by sticky/pitch handling.
        action = self._action(self._action_values[action])

        # Repeated frames only carry over attack and movement keys; all other
        # keys take their no-op value.
        following = self._noop_action.copy()
        for key in ("attack", "forward", "back", "left", "right"):
            following[key] = action[key]

        for act in [action] + ([following] * (self._repeat - 1)):
            obs, reward, done, info = self._env.step(act)
            if "error" in info:
                done = True
            if done:
                # Do not keep stepping an episode that has already ended.
                break

        obs["is_first"] = False
        obs["is_last"] = bool(done)
        obs["is_terminal"] = bool(info.get("is_terminal", done))
        obs = self._obs(obs)
        self._step += 1
        assert "pov" not in obs, list(obs.keys())
        return obs, reward, done, info

    @property
    def inventory(self):
        """Return the latest inventory as ``{item_name: count}``.

        Updated by ``_obs``; excludes the ``air`` entry.
        """
        return self._inventory

    def reset(self):
        """Reset the wrapped env and all per-episode state.

        Returns:
            The first observation of the new episode (flat dict from ``_obs``).
        """
        # inventory will be added in _obs
        self._inventory = {}
        self._max_inventory = None
        with self._LOCK:
            obs = self._env.reset()
        obs["is_first"] = True
        obs["is_last"] = False
        obs["is_terminal"] = False
        obs = self._obs(obs)
        self._step = 0
        self._sticky_attack_counter = 0
        self._sticky_jump_counter = 0
        self._pitch = 0
        return obs

    def _obs(self, obs):
        """Convert a nested wrapped-env observation into the flat output dict.

        Also refreshes ``self._inventory`` and the running per-item maximum
        ``self._max_inventory``.

        Args:
            obs: Nested observation dict from the wrapped env, already
                carrying the ``is_first``/``is_last``/``is_terminal`` flags.

        Returns:
            Dict with the keys listed in ``observation_space``.
        """
        obs = self._flatten(obs)
        # Merge the second log counter into the first so that a single "log"
        # entry is reported.
        obs["inventory/log"] += obs.pop("inventory/log2")
        self._inventory = {
            k.split("/", 1)[1]: obs[k] for k in self._inv_keys if k != "inventory/air"
        }

        inventory = np.array([obs[k] for k in self._inv_keys], np.float32)
        if self._max_inventory is None:
            # Copy so the running maximum never aliases the array handed out
            # under the "inventory" key.
            self._max_inventory = inventory.copy()
        else:
            self._max_inventory = np.maximum(self._max_inventory, inventory)

        # One-hot encoding of the item type held in the main hand.
        index = self._equip_enum.index(obs["equipped_items/mainhand/type"])
        equipped = np.zeros(len(self._equip_enum), np.float32)
        equipped[index] = 1.0

        player_x = obs["location_stats/xpos"]
        player_y = obs["location_stats/ypos"]
        player_z = obs["location_stats/zpos"]

        # The original code ended with a per-key loop whose assertion tested
        # a non-empty tuple (always true) while rebuilding the observation
        # space for each key; it validated nothing and has been removed.
        return {
            "image": obs["pov"],
            "inventory": inventory,
            "inventory_max": self._max_inventory.copy(),
            "equipped": equipped,
            "health": np.float32(obs["life_stats/life"] / 20),
            "hunger": np.float32(obs["life_stats/food"] / 20),
            "breath": np.float32(obs["life_stats/air"] / 300),
            "reward": 0.0,
            "is_first": obs["is_first"],
            "is_last": obs["is_last"],
            "is_terminal": obs["is_terminal"],
            **{f"log_{k}": np.int64(obs[k]) for k in self._inv_keys},
            "log_player_pos": np.array([player_x, player_y, player_z], np.float32),
        }

    def _action(self, action):
        """Return a copy of ``action`` with sticky and pitch rules applied.

        The input dict is left untouched; the sticky counters and the tracked
        pitch on ``self`` are updated.

        Args:
            action: Complete action dict (must contain ``attack``, ``jump``,
                ``forward`` and ``camera``).

        Returns:
            A new dict to send to the wrapped env.
        """
        # Shallow copy is enough: every override below rebinds a key rather
        # than mutating a value in place.
        action = dict(action)

        if self._sticky_attack_length:
            if action["attack"]:
                self._sticky_attack_counter = self._sticky_attack_length
            if self._sticky_attack_counter > 0:
                # Keep attacking and suppress jumping while the counter runs.
                action["attack"] = 1
                action["jump"] = 0
                self._sticky_attack_counter -= 1

        if self._sticky_jump_length:
            if action["jump"]:
                self._sticky_jump_counter = self._sticky_jump_length
            if self._sticky_jump_counter > 0:
                # Keep jumping and moving forward while the counter runs.
                action["jump"] = 1
                action["forward"] = 1
                self._sticky_jump_counter -= 1

        if self._pitch_limit and action["camera"][0]:
            lo, hi = self._pitch_limit
            if not (lo <= self._pitch + action["camera"][0] <= hi):
                # Drop the pitch component that would leave the allowed range.
                action["camera"] = (0, action["camera"][1])
            self._pitch += action["camera"][0]
        return action

    def _insert_defaults(self, actions):
        """Return copies of ``actions`` completed with no-op default values.

        Args:
            actions: Mapping from action name to a partial action dict.

        Returns:
            New mapping with the same names; each value is a shallow copy of
            the input dict with every key of ``self._noop_action`` present.
        """
        actions = {name: action.copy() for name, action in actions.items()}
        for action in actions.values():
            for key, default in self._noop_action.items():
                action.setdefault(key, default)
        return actions

    def _flatten(self, nest, prefix=None):
        """Flatten nested dicts / ``gym.spaces.Dict`` into ``"a/b/c"`` keys.

        Args:
            nest: Mapping whose values may be dicts or ``gym.spaces.Dict``.
            prefix: Key prefix accumulated during recursion.

        Returns:
            Flat dict mapping slash-joined paths to leaf values.
        """
        result = {}
        for key, value in nest.items():
            key = prefix + "/" + key if prefix else key
            if isinstance(value, gym.spaces.Dict):
                value = value.spaces
            if isinstance(value, dict):
                result.update(self._flatten(value, key))
            else:
                result[key] = value
        return result

    def _unflatten(self, flat):
        """Inverse of ``_flatten``: rebuild nested dicts from ``"a/b"`` keys.

        Args:
            flat: Dict whose keys are slash-joined paths.

        Returns:
            Nested dict with one level per path component.
        """
        result = {}
        for key, value in flat.items():
            parts = key.split("/")
            node = result
            for part in parts[:-1]:
                node = node.setdefault(part, {})
            node[parts[-1]] = value
        return result