From 036e9a8028c00a2ff1b6f65accdf372f334cbf8b Mon Sep 17 00:00:00 2001
From: NM512 <morihira3513@gmail.com>
Date: Sun, 2 Jul 2023 11:29:48 +0900
Subject: [PATCH] added minecraft environment

---
 configs.yaml             |   9 ++
 dreamer.py               |   5 +-
 envs/minecraft.py        | 154 +++++++++++++++++++++++++++
 envs/minecraft_base.py   | 219 +++++++++++++++++++++++++++++++++++++++
 envs/minecraft_minerl.py | 150 +++++++++++++++++++++++++++
 requirements.txt         |  11 +-
 6 files changed, 543 insertions(+), 5 deletions(-)
 create mode 100644 envs/minecraft.py
 create mode 100644 envs/minecraft_base.py
 create mode 100644 envs/minecraft_minerl.py

diff --git a/configs.yaml b/configs.yaml
index 684cbb2..d085760 100644
--- a/configs.yaml
+++ b/configs.yaml
@@ -172,6 +172,15 @@ atari100k:
   imag_gradient: 'reinforce'
   time_limit: 108000
 
+minecraft:
+  task: minecraft_diamond
+  break_speed: 100.0
+  envs: 16
+  train_ratio: 16
+  log_keys_max: '^log_inventory.*'
+  encoder: {mlp_keys: 'inventory|inventory_max|equipped|health|hunger|breath|reward', cnn_keys: 'image'}
+  decoder: {mlp_keys: 'inventory|inventory_max|equipped|health|hunger|breath', cnn_keys: 'image'}
+  time_limit: 36000
 
 debug:
   debug: True
diff --git a/dreamer.py b/dreamer.py
index eeadf5a..f593669 100644
--- a/dreamer.py
+++ b/dreamer.py
@@ -218,9 +218,12 @@ def make_env(config, logger, mode, train_eps, eval_eps):
         env = wrappers.OneHotAction(env)
     elif suite == "crafter":
         import envs.crafter as crafter
-
         env = crafter.Crafter(task, config.size)
         env = wrappers.OneHotAction(env)
+    elif suite == "minecraft":
+        import envs.minecraft as minecraft
+        env = minecraft.make_env(task, size=config.size, break_speed=config.break_speed)
+        env = wrappers.OneHotAction(env)
     else:
         raise NotImplementedError(suite)
     env = wrappers.TimeLimit(env, config.time_limit)
diff --git a/envs/minecraft.py b/envs/minecraft.py
new file mode 100644
index 0000000..c94525b
--- /dev/null
+++ b/envs/minecraft.py
@@ -0,0 +1,154 @@
+import numpy as np
+from . import minecraft_base
+
+import gym
+
+def make_env(task, *args, **kwargs):
+    return {
+        'wood': MinecraftWood,
+        'climb': MinecraftClimb,
+        'diamond': MinecraftDiamond,
+        }[task](*args, **kwargs)
+
+
+class MinecraftWood:
+
+  def __init__(self, *args, **kwargs):
+    actions = BASIC_ACTIONS
+    self.rewards = [
+        CollectReward('log', repeated=1),
+        HealthReward(),
+    ]
+    env = minecraft_base.MinecraftBase(actions, *args, **kwargs)
+
+  def step(self, action):
+    obs, reward, done, info = self.env.step(action)
+    reward = sum([fn(obs, self.env.inventory) for fn in self.rewards])
+    obs['reward'] = reward
+    return obs, reward, done, info
+
+
+class MinecraftClimb:
+
+  def __init__(self, *args, **kwargs):
+    actions = BASIC_ACTIONS
+    env = minecraft_base.MinecraftBase(actions, *args, **kwargs)
+    self._previous = None
+    self._health_reward = HealthReward()
+
+  def step(self, action):
+    obs, reward, done, info = self.env.step(action)
+    x, y, z = obs['log_player_pos']
+    height = np.float32(y)
+    if obs['is_first']:
+      self._previous = height
+    reward = height - self._previous
+    reward += self._health_reward(obs)
+    obs['reward'] = reward
+    self._previous = height
+    return obs, reward, done, info
+
+
+class MinecraftDiamond(gym.Wrapper):
+
+  def __init__(self, *args, **kwargs):
+    actions = {
+        **BASIC_ACTIONS,
+        'craft_planks': dict(craft='planks'),
+        'craft_stick': dict(craft='stick'),
+        'craft_crafting_table': dict(craft='crafting_table'),
+        'place_crafting_table': dict(place='crafting_table'),
+        'craft_wooden_pickaxe': dict(nearbyCraft='wooden_pickaxe'),
+        'craft_stone_pickaxe': dict(nearbyCraft='stone_pickaxe'),
+        'craft_iron_pickaxe': dict(nearbyCraft='iron_pickaxe'),
+        'equip_stone_pickaxe': dict(equip='stone_pickaxe'),
+        'equip_wooden_pickaxe': dict(equip='wooden_pickaxe'),
+        'equip_iron_pickaxe': dict(equip='iron_pickaxe'),
+        'craft_furnace': dict(nearbyCraft='furnace'),
+        'place_furnace': dict(place='furnace'),
+        'smelt_iron_ingot': dict(nearbySmelt='iron_ingot'),
+    }
+    self.rewards = [
+        CollectReward('log', once=1),
+        CollectReward('planks', once=1),
+        CollectReward('stick', once=1),
+        CollectReward('crafting_table', once=1),
+        CollectReward('wooden_pickaxe', once=1),
+        CollectReward('cobblestone', once=1),
+        CollectReward('stone_pickaxe', once=1),
+        CollectReward('iron_ore', once=1),
+        CollectReward('furnace', once=1),
+        CollectReward('iron_ingot', once=1),
+        CollectReward('iron_pickaxe', once=1),
+        CollectReward('diamond', once=1),
+        HealthReward(),
+    ]
+    env = minecraft_base.MinecraftBase(actions, *args, **kwargs)
+    super().__init__(env)
+
+  def step(self, action):
+    obs, reward, done, info  = self.env.step(action)
+    reward = sum([fn(obs, self.env.inventory) for fn in self.rewards])
+    obs['reward'] = reward
+    return obs, reward, done, info
+
+  def reset(self):
+    obs = self.env.reset()
+    # called for reset of reward calculations
+    _ = sum([fn(obs, self.env.inventory) for fn in self.rewards])
+    return obs
+
+
+class CollectReward:
+
+  def __init__(self, item, once=0, repeated=0):
+    self.item = item
+    self.once = once
+    self.repeated = repeated
+    self.previous = 0
+    self.maximum = 0
+
+  def __call__(self, obs, inventory):
+    current = inventory[self.item]
+    if obs['is_first']:
+      self.previous = current
+      self.maximum = current
+      return 0
+    reward = self.repeated * max(0, current - self.previous)
+    if self.maximum == 0 and current > 0:
+      reward += self.once
+    self.previous = current
+    self.maximum = max(self.maximum, current)
+    return reward
+
+
+class HealthReward:
+
+  def __init__(self, scale=0.01):
+    self.scale = scale
+    self.previous = None
+
+  def __call__(self, obs, inventory=None):
+    health = obs['health']
+    if obs['is_first']:
+      self.previous = health
+      return 0
+    reward = self.scale * (health - self.previous)
+    self.previous = health
+    return np.float32(reward)
+
+
+BASIC_ACTIONS = {
+    'noop': dict(),
+    'attack': dict(attack=1),
+    'turn_up': dict(camera=(-15, 0)),
+    'turn_down': dict(camera=(15, 0)),
+    'turn_left': dict(camera=(0, -15)),
+    'turn_right': dict(camera=(0, 15)),
+    'forward': dict(forward=1),
+    'back': dict(back=1),
+    'left': dict(left=1),
+    'right': dict(right=1),
+    'jump': dict(jump=1, forward=1),
+    'place_dirt': dict(place='dirt'),
+}
diff --git a/envs/minecraft_base.py b/envs/minecraft_base.py
new file mode 100644
index 0000000..47a3a56
--- /dev/null
+++ b/envs/minecraft_base.py
@@ -0,0 +1,219 @@
+import logging
+import threading
+
+import numpy as np
+import gym
+
+class MinecraftBase(gym.Env):
+
+  _LOCK = threading.Lock()
+
+  def __init__(
+      self, actions,
+      repeat=1,
+      size=(64, 64),
+      break_speed=100.0,
+      gamma=10.0,
+      sticky_attack=30,
+      sticky_jump=10,
+      pitch_limit=(-60, 60),
+      logs=True,
+  ):
+    if logs:
+      logging.basicConfig(level=logging.DEBUG)
+    self._repeat = repeat
+    self._size = size
+    if break_speed != 1.0:
+      sticky_attack = 0
+
+    # Make env
+    with self._LOCK:
+        from .import minecraft_minerl
+        self._env = minecraft_minerl.MineRLEnv(size, break_speed, gamma).make()
+    self._inventory = {}
+
+    # Observations
+    self._inv_keys = [
+        k for k in self._flatten(self._env.observation_space.spaces) if k.startswith('inventory/')
+        if k != 'inventory/log2']
+    self._step = 0
+    self._max_inventory = None
+    self._equip_enum = self._env.observation_space[
+        'equipped_items']['mainhand']['type'].values.tolist()
+
+    # Actions
+    self._noop_action = minecraft_minerl.NOOP_ACTION
+    actions = self._insert_defaults(actions)
+    self._action_names = tuple(actions.keys())
+    self._action_values = tuple(actions.values())
+    message = f'Minecraft action space ({len(self._action_values)}):'
+    print(message, ', '.join(self._action_names))
+    self._sticky_attack_length = sticky_attack
+    self._sticky_attack_counter = 0
+    self._sticky_jump_length = sticky_jump
+    self._sticky_jump_counter = 0
+    self._pitch_limit = pitch_limit
+    self._pitch = 0
+
+  @property
+  def observation_space(self):
+    return gym.spaces.Dict(
+        {
+        'image': gym.spaces.Box(0, 255, self._size + (3,), np.uint8),
+        'inventory': gym.spaces.Box(-np.inf, np.inf, (len(self._inv_keys),), dtype=np.float32),
+        'inventory_max': gym.spaces.Box(-np.inf, np.inf, (len(self._inv_keys),), dtype=np.float32),
+        'equipped': gym.spaces.Box(-np.inf, np.inf, (len(self._equip_enum),), dtype=np.float32),
+        'reward': gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32),
+        'health': gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32),
+        'hunger': gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32),
+        'breath': gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32),
+        'is_first': gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8),
+        'is_last': gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8),
+        'is_terminal': gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8),
+        **{f'log_{k}': gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.int64) for k in self._inv_keys},
+        'log_player_pos': gym.spaces.Box(-np.inf, np.inf, (3,), dtype=np.float32),
+        }
+    )
+
+  @property
+  def action_space(self):
+    space = gym.spaces.discrete.Discrete(len(self._action_values))
+    space.discrete = True
+    return space
+
+  def step(self, action):
+    action = action.copy()
+    print(self._step, action)
+    action = self._action_values[action]
+    action = self._action(action)
+    following = self._noop_action.copy()
+    for key in ('attack', 'forward', 'back', 'left', 'right'):
+        following[key] = action[key]
+    for act in [action] + ([following] * (self._repeat - 1)):
+        obs, reward, done, info = self._env.step(act)
+        if 'error' in info:
+            done = True
+            break
+        obs['is_first'] = False
+        obs['is_last'] = bool(done)
+        obs['is_terminal'] = bool(info.get('is_terminal', done))
+
+    obs = self._obs(obs)
+    self._step += 1
+    assert 'pov' not in obs, list(obs.keys())
+    return obs, reward, done, info
+
+  @property
+  def inventory(self):
+    return self._inventory
+
+  def reset(self):
+    # inventory will be added in _obs
+    self._inventory = {}
+    self._max_inventory = None
+
+    with self._LOCK:
+      obs = self._env.reset()
+    obs['is_first'] = True
+    obs['is_last'] = False
+    obs['is_terminal'] = False
+    obs = self._obs(obs)
+
+    self._step = 0
+    self._sticky_attack_counter = 0
+    self._sticky_jump_counter = 0
+    self._pitch = 0
+    return obs
+
+  def _obs(self, obs):
+    obs = self._flatten(obs)
+    obs['inventory/log'] += obs.pop('inventory/log2')
+    self._inventory = {
+        k.split('/', 1)[1]: obs[k] for k in self._inv_keys
+        if k != 'inventory/air'}
+    inventory = np.array([obs[k] for k in self._inv_keys], np.float32)
+    if self._max_inventory is None:
+      self._max_inventory = inventory
+    else:
+      self._max_inventory = np.maximum(self._max_inventory, inventory)
+    index = self._equip_enum.index(obs['equipped_items/mainhand/type'])
+    equipped = np.zeros(len(self._equip_enum), np.float32)
+    equipped[index] = 1.0
+    player_x = obs['location_stats/xpos']
+    player_y = obs['location_stats/ypos']
+    player_z = obs['location_stats/zpos']
+    obs = {
+        'image': obs['pov'],
+        'inventory': inventory,
+        'inventory_max': self._max_inventory.copy(),
+        'equipped': equipped,
+        'health': np.float32(obs['life_stats/life'] / 20),
+        'hunger': np.float32(obs['life_stats/food'] / 20),
+        'breath': np.float32(obs['life_stats/air'] / 300),
+        'reward': 0.0,
+        'is_first': obs['is_first'],
+        'is_last': obs['is_last'],
+        'is_terminal': obs['is_terminal'],
+        **{f'log_{k}': np.int64(obs[k]) for k in self._inv_keys},
+        'log_player_pos': np.array([player_x, player_y, player_z], np.float32),
+    }
+    for key, value in obs.items():
+      space = self.observation_space[key]
+      if not isinstance(value, np.ndarray):
+        value = np.array(value)
+      assert (key, value, value.dtype, value.shape, space)
+    return obs
+
+  def _action(self, action):
+    if self._sticky_attack_length:
+      if action['attack']:
+        self._sticky_attack_counter = self._sticky_attack_length
+      if self._sticky_attack_counter > 0:
+        action['attack'] = 1
+        action['jump'] = 0
+        self._sticky_attack_counter -= 1
+    if self._sticky_jump_length:
+      if action['jump']:
+        self._sticky_jump_counter = self._sticky_jump_length
+      if self._sticky_jump_counter > 0:
+        action['jump'] = 1
+        action['forward'] = 1
+        self._sticky_jump_counter -= 1
+    if self._pitch_limit and action['camera'][0]:
+      lo, hi = self._pitch_limit
+      if not (lo <= self._pitch + action['camera'][0] <= hi):
+        action['camera'] = (0, action['camera'][1])
+      self._pitch += action['camera'][0]
+    return action
+
+  def _insert_defaults(self, actions):
+    actions = {name: action.copy() for name, action in actions.items()}
+    for key, default in self._noop_action.items():
+      for action in actions.values():
+        if key not in action:
+          action[key] = default
+    return actions
+
+  def _flatten(self, nest, prefix=None):
+    result = {}
+    for key, value in nest.items():
+      key = prefix + '/' + key if prefix else key
+      if isinstance(value, gym.spaces.Dict):
+        value = value.spaces
+      if isinstance(value, dict):
+        result.update(self._flatten(value, key))
+      else:
+        result[key] = value
+    return result
+
+  def _unflatten(self, flat):
+    result = {}
+    for key, value in flat.items():
+      parts = key.split('/')
+      node = result
+      for part in parts[:-1]:
+        if part not in node:
+          node[part] = {}
+        node = node[part]
+      node[parts[-1]] = value
+    return result
\ No newline at end of file
diff --git a/envs/minecraft_minerl.py b/envs/minecraft_minerl.py
new file mode 100644
index 0000000..b412218
--- /dev/null
+++ b/envs/minecraft_minerl.py
@@ -0,0 +1,150 @@
+from minerl.herobraine.env_spec import EnvSpec
+from minerl.herobraine.hero import handler
+from minerl.herobraine.hero import handlers
+from minerl.herobraine.hero import mc
+from minerl.herobraine.hero.mc import INVERSE_KEYMAP
+
+
+def edit_options(**kwargs):
+  import os, pathlib, re
+  for word in os.popen('pip3 --version').read().split(' '):
+    if '-packages/pip' in word:
+      break
+  else:
+    raise RuntimeError('Could not found python package directory.')
+  packages = pathlib.Path(word).parent
+  filename = packages / 'minerl/Malmo/Minecraft/run/options.txt'
+  options = filename.read_text()
+  if 'fovEffectScale:' not in options:
+    options += 'fovEffectScale:1.0\n'
+  if 'simulationDistance:' not in options:
+    options += 'simulationDistance:12\n'
+  for key, value in kwargs.items():
+    assert f'{key}:' in options, key
+    assert isinstance(value, str), (value, type(value))
+    options = re.sub(f'{key}:.*\n', f'{key}:{value}\n', options)
+  filename.write_text(options)
+
+
+edit_options(
+    difficulty='2',
+    renderDistance='6',
+    simulationDistance='6',
+    fovEffectScale='0.0',
+    ao='1',
+    gamma='5.0',
+)
+
+
+class MineRLEnv(EnvSpec):
+
+  def __init__(self, resolution=(64, 64), break_speed=50, gamma=10.0):
+    self.resolution = resolution
+    self.break_speed = break_speed
+    self.gamma = gamma
+    super().__init__(name='MineRLEnv-v1')
+
+  def create_agent_start(self):
+    return [
+        BreakSpeedMultiplier(self.break_speed),
+    ]
+
+  def create_agent_handlers(self):
+    return []
+
+  def create_server_world_generators(self):
+    return [handlers.DefaultWorldGenerator(force_reset=True)]
+
+  def create_server_quit_producers(self):
+    return [handlers.ServerQuitWhenAnyAgentFinishes()]
+
+  def create_server_initial_conditions(self):
+    return [
+        handlers.TimeInitialCondition(
+            allow_passage_of_time=True,
+            start_time=0,
+        ),
+        handlers.SpawningInitialCondition(
+            allow_spawning=True,
+        )
+    ]
+
+  def create_observables(self):
+    return [
+        handlers.POVObservation(self.resolution),
+        handlers.FlatInventoryObservation(mc.ALL_ITEMS),
+        handlers.EquippedItemObservation(
+            mc.ALL_ITEMS, _default='air', _other='other'),
+        handlers.ObservationFromCurrentLocation(),
+        handlers.ObservationFromLifeStats(),
+    ]
+
+  def create_actionables(self):
+    kw = dict(_other='none', _default='none')
+    return [
+        handlers.KeybasedCommandAction('forward', INVERSE_KEYMAP['forward']),
+        handlers.KeybasedCommandAction('back', INVERSE_KEYMAP['back']),
+        handlers.KeybasedCommandAction('left', INVERSE_KEYMAP['left']),
+        handlers.KeybasedCommandAction('right', INVERSE_KEYMAP['right']),
+        handlers.KeybasedCommandAction('jump', INVERSE_KEYMAP['jump']),
+        handlers.KeybasedCommandAction('sneak', INVERSE_KEYMAP['sneak']),
+        handlers.KeybasedCommandAction('attack', INVERSE_KEYMAP['attack']),
+        handlers.CameraAction(),
+        handlers.PlaceBlock(['none'] + mc.ALL_ITEMS, **kw),
+        handlers.EquipAction(['none'] + mc.ALL_ITEMS, **kw),
+        handlers.CraftAction(['none'] + mc.ALL_ITEMS, **kw),
+        handlers.CraftNearbyAction(['none'] + mc.ALL_ITEMS, **kw),
+        handlers.SmeltItemNearby(['none'] + mc.ALL_ITEMS, **kw),
+    ]
+
+  def is_from_folder(self, folder):
+    return folder == 'none'
+
+  def get_docstring(self):
+    return ''
+
+  def determine_success_from_rewards(self, rewards):
+    return True
+
+  def create_rewardables(self):
+    return []
+
+  def create_server_decorators(self):
+    return []
+
+  def create_mission_handlers(self):
+    return []
+
+  def create_monitors(self):
+    return []
+
+
+class BreakSpeedMultiplier(handler.Handler):
+
+  def __init__(self, multiplier=1.0):
+    self.multiplier = multiplier
+
+  def to_string(self):
+    return f'break_speed({self.multiplier})'
+
+  def xml_template(self):
+    return '<BreakSpeedMultiplier>{{multiplier}}</BreakSpeedMultiplier>'
+
+
+class Gamma(handler.Handler):
+
+  def __init__(self, gamma=2.0):
+    self.gamma = gamma
+
+  def to_string(self):
+    return f'gamma({self.gamma})'
+
+  def xml_template(self):
+    return '<GammaSetting>{{gamma}}</GammaSetting>'
+
+
+NOOP_ACTION = dict(
+    camera=(0, 0), forward=0, back=0, left=0, right=0, attack=0, sprint=0,
+    jump=0, sneak=0, craft='none', nearbyCraft='none', nearbySmelt='none',
+    place='none', equip='none',
+)
diff --git a/requirements.txt b/requirements.txt
index 077499a..fc7a2a3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,8 +1,7 @@
 setuptools==60.0.0
 torch==2.0.0
 torchvision==0.15.1
-numpy==1.20.1
-tensorboard==2.5.0
+tensorboard==2.10.0
 pandas==1.2.4
 matplotlib==3.5.0
 ruamel.yaml==0.17.4
@@ -11,8 +10,12 @@ einops==0.3.0
 protobuf==3.20.0
 gym==0.19.0
 dm_control==1.0.9
-scipy==1.7.0
+scipy==1.8.0
 memory_maze==1.0.2
 atari-py==0.2.9
 crafter==1.8.0
-opencv-python==4.7.0.72
\ No newline at end of file
+opencv-python==4.7.0.72
+numpy==1.21.0
+# minerl==0.4.4
+# This was needed for minerl
+# conda install -c conda-forge openjdk=8