applied formatter
This commit is contained in:
parent
afa5ab988d
commit
12ed21e06d
33
dreamer.py
33
dreamer.py
@ -217,10 +217,12 @@ def make_env(config, mode):
|
|||||||
env = wrappers.OneHotAction(env)
|
env = wrappers.OneHotAction(env)
|
||||||
elif suite == "crafter":
|
elif suite == "crafter":
|
||||||
import envs.crafter as crafter
|
import envs.crafter as crafter
|
||||||
|
|
||||||
env = crafter.Crafter(task, config.size)
|
env = crafter.Crafter(task, config.size)
|
||||||
env = wrappers.OneHotAction(env)
|
env = wrappers.OneHotAction(env)
|
||||||
elif suite == "minecraft":
|
elif suite == "minecraft":
|
||||||
import envs.minecraft as minecraft
|
import envs.minecraft as minecraft
|
||||||
|
|
||||||
env = minecraft.make_env(task, size=config.size, break_speed=config.break_speed)
|
env = minecraft.make_env(task, size=config.size, break_speed=config.break_speed)
|
||||||
env = wrappers.OneHotAction(env)
|
env = wrappers.OneHotAction(env)
|
||||||
else:
|
else:
|
||||||
@ -294,7 +296,15 @@ def main(config):
|
|||||||
logprob = random_actor.log_prob(action)
|
logprob = random_actor.log_prob(action)
|
||||||
return {"action": action, "logprob": logprob}, None
|
return {"action": action, "logprob": logprob}, None
|
||||||
|
|
||||||
state = tools.simulate(random_agent, train_envs, train_eps, config.traindir, logger, limit=config.dataset_size, steps=prefill)
|
state = tools.simulate(
|
||||||
|
random_agent,
|
||||||
|
train_envs,
|
||||||
|
train_eps,
|
||||||
|
config.traindir,
|
||||||
|
logger,
|
||||||
|
limit=config.dataset_size,
|
||||||
|
steps=prefill,
|
||||||
|
)
|
||||||
logger.step += prefill * config.action_repeat
|
logger.step += prefill * config.action_repeat
|
||||||
print(f"Logger: ({logger.step} steps).")
|
print(f"Logger: ({logger.step} steps).")
|
||||||
|
|
||||||
@ -317,12 +327,29 @@ def main(config):
|
|||||||
logger.write()
|
logger.write()
|
||||||
print("Start evaluation.")
|
print("Start evaluation.")
|
||||||
eval_policy = functools.partial(agent, training=False)
|
eval_policy = functools.partial(agent, training=False)
|
||||||
tools.simulate(eval_policy, eval_envs, eval_eps, config.evaldir, logger, is_eval=True, episodes=config.eval_episode_num)
|
tools.simulate(
|
||||||
|
eval_policy,
|
||||||
|
eval_envs,
|
||||||
|
eval_eps,
|
||||||
|
config.evaldir,
|
||||||
|
logger,
|
||||||
|
is_eval=True,
|
||||||
|
episodes=config.eval_episode_num,
|
||||||
|
)
|
||||||
if config.video_pred_log:
|
if config.video_pred_log:
|
||||||
video_pred = agent._wm.video_pred(next(eval_dataset))
|
video_pred = agent._wm.video_pred(next(eval_dataset))
|
||||||
logger.video("eval_openl", to_np(video_pred))
|
logger.video("eval_openl", to_np(video_pred))
|
||||||
print("Start training.")
|
print("Start training.")
|
||||||
state = tools.simulate(agent, train_envs, train_eps, config.traindir, logger, limit=config.dataset_size, steps=config.eval_every, state=state)
|
state = tools.simulate(
|
||||||
|
agent,
|
||||||
|
train_envs,
|
||||||
|
train_eps,
|
||||||
|
config.traindir,
|
||||||
|
logger,
|
||||||
|
limit=config.dataset_size,
|
||||||
|
steps=config.eval_every,
|
||||||
|
state=state,
|
||||||
|
)
|
||||||
torch.save(agent.state_dict(), logdir / "latest_model.pt")
|
torch.save(agent.state_dict(), logdir / "latest_model.pt")
|
||||||
for env in train_envs + eval_envs:
|
for env in train_envs + eval_envs:
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -5,6 +5,7 @@ import numpy as np
|
|||||||
class Atari:
|
class Atari:
|
||||||
LOCK = None
|
LOCK = None
|
||||||
metadata = {}
|
metadata = {}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
name,
|
name,
|
||||||
|
|||||||
@ -4,6 +4,7 @@ import numpy as np
|
|||||||
|
|
||||||
class Crafter:
|
class Crafter:
|
||||||
metadata = {}
|
metadata = {}
|
||||||
|
|
||||||
def __init__(self, task, size=(64, 64), seed=None):
|
def __init__(self, task, size=(64, 64), seed=None):
|
||||||
assert task in ("reward", "noreward")
|
assert task in ("reward", "noreward")
|
||||||
import crafter
|
import crafter
|
||||||
|
|||||||
@ -4,6 +4,7 @@ import numpy as np
|
|||||||
|
|
||||||
class DeepMindControl:
|
class DeepMindControl:
|
||||||
metadata = {}
|
metadata = {}
|
||||||
|
|
||||||
def __init__(self, name, action_repeat=1, size=(64, 64), camera=None):
|
def __init__(self, name, action_repeat=1, size=(64, 64), camera=None):
|
||||||
domain, task = name.split("_", 1)
|
domain, task = name.split("_", 1)
|
||||||
if domain == "cup": # Only domain with multiple words.
|
if domain == "cup": # Only domain with multiple words.
|
||||||
|
|||||||
@ -3,152 +3,148 @@ from . import minecraft_base
|
|||||||
|
|
||||||
import gym
|
import gym
|
||||||
|
|
||||||
|
|
||||||
def make_env(task, *args, **kwargs):
|
def make_env(task, *args, **kwargs):
|
||||||
return {
|
return {
|
||||||
'wood': MinecraftWood,
|
"wood": MinecraftWood,
|
||||||
'climb': MinecraftClimb,
|
"climb": MinecraftClimb,
|
||||||
'diamond': MinecraftDiamond,
|
"diamond": MinecraftDiamond,
|
||||||
}[task](*args, **kwargs)
|
}[task](*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class MinecraftWood:
|
class MinecraftWood:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
actions = BASIC_ACTIONS
|
||||||
|
self.rewards = [
|
||||||
|
CollectReward("log", repeated=1),
|
||||||
|
HealthReward(),
|
||||||
|
]
|
||||||
|
env = minecraft_base.MinecraftBase(actions, *args, **kwargs)
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def step(self, action):
|
||||||
actions = BASIC_ACTIONS
|
obs, reward, done, info = self.env.step(action)
|
||||||
self.rewards = [
|
reward = sum([fn(obs, self.env.inventory) for fn in self.rewards])
|
||||||
CollectReward('log', repeated=1),
|
obs["reward"] = reward
|
||||||
HealthReward(),
|
return obs, reward, done, info
|
||||||
]
|
|
||||||
env = minecraft_base.MinecraftBase(actions, *args, **kwargs)
|
|
||||||
|
|
||||||
def step(self, action):
|
|
||||||
obs, reward, done, info = self.env.step(action)
|
|
||||||
reward = sum([fn(obs, self.env.inventory) for fn in self.rewards])
|
|
||||||
obs['reward'] = reward
|
|
||||||
return obs, reward, done, info
|
|
||||||
|
|
||||||
|
|
||||||
class MinecraftClimb:
|
class MinecraftClimb:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
actions = BASIC_ACTIONS
|
||||||
|
env = minecraft_base.MinecraftBase(actions, *args, **kwargs)
|
||||||
|
self._previous = None
|
||||||
|
self._health_reward = HealthReward()
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def step(self, action):
|
||||||
actions = BASIC_ACTIONS
|
obs, reward, done, info = self.env.step(action)
|
||||||
env = minecraft_base.MinecraftBase(actions, *args, **kwargs)
|
x, y, z = obs["log_player_pos"]
|
||||||
self._previous = None
|
height = np.float32(y)
|
||||||
self._health_reward = HealthReward()
|
if obs["is_first"]:
|
||||||
|
self._previous = height
|
||||||
def step(self, action):
|
reward = height - self._previous
|
||||||
obs, reward, done, info = self.env.step(action)
|
reward += self._health_reward(obs)
|
||||||
x, y, z = obs['log_player_pos']
|
obs["reward"] = reward
|
||||||
height = np.float32(y)
|
self._previous = height
|
||||||
if obs['is_first']:
|
return obs, reward, done, info
|
||||||
self._previous = height
|
|
||||||
reward = height - self._previous
|
|
||||||
reward += self._health_reward(obs)
|
|
||||||
obs['reward'] = reward
|
|
||||||
self._previous = height
|
|
||||||
return obs, reward, done, info
|
|
||||||
|
|
||||||
|
|
||||||
class MinecraftDiamond(gym.Wrapper):
|
class MinecraftDiamond(gym.Wrapper):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
actions = {
|
||||||
|
**BASIC_ACTIONS,
|
||||||
|
"craft_planks": dict(craft="planks"),
|
||||||
|
"craft_stick": dict(craft="stick"),
|
||||||
|
"craft_crafting_table": dict(craft="crafting_table"),
|
||||||
|
"place_crafting_table": dict(place="crafting_table"),
|
||||||
|
"craft_wooden_pickaxe": dict(nearbyCraft="wooden_pickaxe"),
|
||||||
|
"craft_stone_pickaxe": dict(nearbyCraft="stone_pickaxe"),
|
||||||
|
"craft_iron_pickaxe": dict(nearbyCraft="iron_pickaxe"),
|
||||||
|
"equip_stone_pickaxe": dict(equip="stone_pickaxe"),
|
||||||
|
"equip_wooden_pickaxe": dict(equip="wooden_pickaxe"),
|
||||||
|
"equip_iron_pickaxe": dict(equip="iron_pickaxe"),
|
||||||
|
"craft_furnace": dict(nearbyCraft="furnace"),
|
||||||
|
"place_furnace": dict(place="furnace"),
|
||||||
|
"smelt_iron_ingot": dict(nearbySmelt="iron_ingot"),
|
||||||
|
}
|
||||||
|
self.rewards = [
|
||||||
|
CollectReward("log", once=1),
|
||||||
|
CollectReward("planks", once=1),
|
||||||
|
CollectReward("stick", once=1),
|
||||||
|
CollectReward("crafting_table", once=1),
|
||||||
|
CollectReward("wooden_pickaxe", once=1),
|
||||||
|
CollectReward("cobblestone", once=1),
|
||||||
|
CollectReward("stone_pickaxe", once=1),
|
||||||
|
CollectReward("iron_ore", once=1),
|
||||||
|
CollectReward("furnace", once=1),
|
||||||
|
CollectReward("iron_ingot", once=1),
|
||||||
|
CollectReward("iron_pickaxe", once=1),
|
||||||
|
CollectReward("diamond", once=1),
|
||||||
|
HealthReward(),
|
||||||
|
]
|
||||||
|
env = minecraft_base.MinecraftBase(actions, *args, **kwargs)
|
||||||
|
super().__init__(env)
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def step(self, action):
|
||||||
actions = {
|
obs, reward, done, info = self.env.step(action)
|
||||||
**BASIC_ACTIONS,
|
reward = sum([fn(obs, self.env.inventory) for fn in self.rewards])
|
||||||
'craft_planks': dict(craft='planks'),
|
obs["reward"] = reward
|
||||||
'craft_stick': dict(craft='stick'),
|
return obs, reward, done, info
|
||||||
'craft_crafting_table': dict(craft='crafting_table'),
|
|
||||||
'place_crafting_table': dict(place='crafting_table'),
|
|
||||||
'craft_wooden_pickaxe': dict(nearbyCraft='wooden_pickaxe'),
|
|
||||||
'craft_stone_pickaxe': dict(nearbyCraft='stone_pickaxe'),
|
|
||||||
'craft_iron_pickaxe': dict(nearbyCraft='iron_pickaxe'),
|
|
||||||
'equip_stone_pickaxe': dict(equip='stone_pickaxe'),
|
|
||||||
'equip_wooden_pickaxe': dict(equip='wooden_pickaxe'),
|
|
||||||
'equip_iron_pickaxe': dict(equip='iron_pickaxe'),
|
|
||||||
'craft_furnace': dict(nearbyCraft='furnace'),
|
|
||||||
'place_furnace': dict(place='furnace'),
|
|
||||||
'smelt_iron_ingot': dict(nearbySmelt='iron_ingot'),
|
|
||||||
}
|
|
||||||
self.rewards = [
|
|
||||||
CollectReward('log', once=1),
|
|
||||||
CollectReward('planks', once=1),
|
|
||||||
CollectReward('stick', once=1),
|
|
||||||
CollectReward('crafting_table', once=1),
|
|
||||||
CollectReward('wooden_pickaxe', once=1),
|
|
||||||
CollectReward('cobblestone', once=1),
|
|
||||||
CollectReward('stone_pickaxe', once=1),
|
|
||||||
CollectReward('iron_ore', once=1),
|
|
||||||
CollectReward('furnace', once=1),
|
|
||||||
CollectReward('iron_ingot', once=1),
|
|
||||||
CollectReward('iron_pickaxe', once=1),
|
|
||||||
CollectReward('diamond', once=1),
|
|
||||||
HealthReward(),
|
|
||||||
]
|
|
||||||
env = minecraft_base.MinecraftBase(actions, *args, **kwargs)
|
|
||||||
super().__init__(env)
|
|
||||||
|
|
||||||
def step(self, action):
|
def reset(self):
|
||||||
obs, reward, done, info = self.env.step(action)
|
obs = self.env.reset()
|
||||||
reward = sum([fn(obs, self.env.inventory) for fn in self.rewards])
|
# called for reset of reward calculations
|
||||||
obs['reward'] = reward
|
_ = sum([fn(obs, self.env.inventory) for fn in self.rewards])
|
||||||
return obs, reward, done, info
|
return obs
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
obs = self.env.reset()
|
|
||||||
# called for reset of reward calculations
|
|
||||||
_ = sum([fn(obs, self.env.inventory) for fn in self.rewards])
|
|
||||||
return obs
|
|
||||||
|
|
||||||
|
|
||||||
class CollectReward:
|
class CollectReward:
|
||||||
|
def __init__(self, item, once=0, repeated=0):
|
||||||
|
self.item = item
|
||||||
|
self.once = once
|
||||||
|
self.repeated = repeated
|
||||||
|
self.previous = 0
|
||||||
|
self.maximum = 0
|
||||||
|
|
||||||
def __init__(self, item, once=0, repeated=0):
|
def __call__(self, obs, inventory):
|
||||||
self.item = item
|
current = inventory[self.item]
|
||||||
self.once = once
|
if obs["is_first"]:
|
||||||
self.repeated = repeated
|
self.previous = current
|
||||||
self.previous = 0
|
self.maximum = current
|
||||||
self.maximum = 0
|
return 0
|
||||||
|
reward = self.repeated * max(0, current - self.previous)
|
||||||
def __call__(self, obs, inventory):
|
if self.maximum == 0 and current > 0:
|
||||||
current = inventory[self.item]
|
reward += self.once
|
||||||
if obs['is_first']:
|
self.previous = current
|
||||||
self.previous = current
|
self.maximum = max(self.maximum, current)
|
||||||
self.maximum = current
|
return reward
|
||||||
return 0
|
|
||||||
reward = self.repeated * max(0, current - self.previous)
|
|
||||||
if self.maximum == 0 and current > 0:
|
|
||||||
reward += self.once
|
|
||||||
self.previous = current
|
|
||||||
self.maximum = max(self.maximum, current)
|
|
||||||
return reward
|
|
||||||
|
|
||||||
|
|
||||||
class HealthReward:
|
class HealthReward:
|
||||||
|
def __init__(self, scale=0.01):
|
||||||
|
self.scale = scale
|
||||||
|
self.previous = None
|
||||||
|
|
||||||
def __init__(self, scale=0.01):
|
def __call__(self, obs, inventory=None):
|
||||||
self.scale = scale
|
health = obs["health"]
|
||||||
self.previous = None
|
if obs["is_first"]:
|
||||||
|
self.previous = health
|
||||||
def __call__(self, obs, inventory=None):
|
return 0
|
||||||
health = obs['health']
|
reward = self.scale * (health - self.previous)
|
||||||
if obs['is_first']:
|
self.previous = health
|
||||||
self.previous = health
|
return np.float32(reward)
|
||||||
return 0
|
|
||||||
reward = self.scale * (health - self.previous)
|
|
||||||
self.previous = health
|
|
||||||
return np.float32(reward)
|
|
||||||
|
|
||||||
|
|
||||||
BASIC_ACTIONS = {
|
BASIC_ACTIONS = {
|
||||||
'noop': dict(),
|
"noop": dict(),
|
||||||
'attack': dict(attack=1),
|
"attack": dict(attack=1),
|
||||||
'turn_up': dict(camera=(-15, 0)),
|
"turn_up": dict(camera=(-15, 0)),
|
||||||
'turn_down': dict(camera=(15, 0)),
|
"turn_down": dict(camera=(15, 0)),
|
||||||
'turn_left': dict(camera=(0, -15)),
|
"turn_left": dict(camera=(0, -15)),
|
||||||
'turn_right': dict(camera=(0, 15)),
|
"turn_right": dict(camera=(0, 15)),
|
||||||
'forward': dict(forward=1),
|
"forward": dict(forward=1),
|
||||||
'back': dict(back=1),
|
"back": dict(back=1),
|
||||||
'left': dict(left=1),
|
"left": dict(left=1),
|
||||||
'right': dict(right=1),
|
"right": dict(right=1),
|
||||||
'jump': dict(jump=1, forward=1),
|
"jump": dict(jump=1, forward=1),
|
||||||
'place_dirt': dict(place='dirt'),
|
"place_dirt": dict(place="dirt"),
|
||||||
}
|
}
|
||||||
|
|||||||
@ -4,215 +4,232 @@ import threading
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import gym
|
import gym
|
||||||
|
|
||||||
|
|
||||||
class MinecraftBase(gym.Env):
|
class MinecraftBase(gym.Env):
|
||||||
|
_LOCK = threading.Lock()
|
||||||
|
|
||||||
_LOCK = threading.Lock()
|
def __init__(
|
||||||
|
self,
|
||||||
|
actions,
|
||||||
|
repeat=1,
|
||||||
|
size=(64, 64),
|
||||||
|
break_speed=100.0,
|
||||||
|
gamma=10.0,
|
||||||
|
sticky_attack=30,
|
||||||
|
sticky_jump=10,
|
||||||
|
pitch_limit=(-60, 60),
|
||||||
|
logs=True,
|
||||||
|
):
|
||||||
|
if logs:
|
||||||
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
self._repeat = repeat
|
||||||
|
self._size = size
|
||||||
|
if break_speed != 1.0:
|
||||||
|
sticky_attack = 0
|
||||||
|
|
||||||
def __init__(
|
# Make env
|
||||||
self, actions,
|
with self._LOCK:
|
||||||
repeat=1,
|
from . import minecraft_minerl
|
||||||
size=(64, 64),
|
|
||||||
break_speed=100.0,
|
|
||||||
gamma=10.0,
|
|
||||||
sticky_attack=30,
|
|
||||||
sticky_jump=10,
|
|
||||||
pitch_limit=(-60, 60),
|
|
||||||
logs=True,
|
|
||||||
):
|
|
||||||
if logs:
|
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
|
||||||
self._repeat = repeat
|
|
||||||
self._size = size
|
|
||||||
if break_speed != 1.0:
|
|
||||||
sticky_attack = 0
|
|
||||||
|
|
||||||
# Make env
|
self._env = minecraft_minerl.MineRLEnv(size, break_speed, gamma).make()
|
||||||
with self._LOCK:
|
self._inventory = {}
|
||||||
from .import minecraft_minerl
|
|
||||||
self._env = minecraft_minerl.MineRLEnv(size, break_speed, gamma).make()
|
|
||||||
self._inventory = {}
|
|
||||||
|
|
||||||
# Observations
|
# Observations
|
||||||
self._inv_keys = [
|
self._inv_keys = [
|
||||||
k for k in self._flatten(self._env.observation_space.spaces) if k.startswith('inventory/')
|
k
|
||||||
if k != 'inventory/log2']
|
for k in self._flatten(self._env.observation_space.spaces)
|
||||||
self._step = 0
|
if k.startswith("inventory/")
|
||||||
self._max_inventory = None
|
if k != "inventory/log2"
|
||||||
self._equip_enum = self._env.observation_space[
|
]
|
||||||
'equipped_items']['mainhand']['type'].values.tolist()
|
self._step = 0
|
||||||
|
self._max_inventory = None
|
||||||
|
self._equip_enum = self._env.observation_space["equipped_items"]["mainhand"][
|
||||||
|
"type"
|
||||||
|
].values.tolist()
|
||||||
|
|
||||||
# Actions
|
# Actions
|
||||||
self._noop_action = minecraft_minerl.NOOP_ACTION
|
self._noop_action = minecraft_minerl.NOOP_ACTION
|
||||||
actions = self._insert_defaults(actions)
|
actions = self._insert_defaults(actions)
|
||||||
self._action_names = tuple(actions.keys())
|
self._action_names = tuple(actions.keys())
|
||||||
self._action_values = tuple(actions.values())
|
self._action_values = tuple(actions.values())
|
||||||
message = f'Minecraft action space ({len(self._action_values)}):'
|
message = f"Minecraft action space ({len(self._action_values)}):"
|
||||||
print(message, ', '.join(self._action_names))
|
print(message, ", ".join(self._action_names))
|
||||||
self._sticky_attack_length = sticky_attack
|
self._sticky_attack_length = sticky_attack
|
||||||
self._sticky_attack_counter = 0
|
self._sticky_attack_counter = 0
|
||||||
self._sticky_jump_length = sticky_jump
|
self._sticky_jump_length = sticky_jump
|
||||||
self._sticky_jump_counter = 0
|
self._sticky_jump_counter = 0
|
||||||
self._pitch_limit = pitch_limit
|
self._pitch_limit = pitch_limit
|
||||||
self._pitch = 0
|
self._pitch = 0
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def observation_space(self):
|
def observation_space(self):
|
||||||
return gym.spaces.Dict(
|
return gym.spaces.Dict(
|
||||||
{
|
{
|
||||||
'image': gym.spaces.Box(0, 255, self._size + (3,), np.uint8),
|
"image": gym.spaces.Box(0, 255, self._size + (3,), np.uint8),
|
||||||
'inventory': gym.spaces.Box(-np.inf, np.inf, (len(self._inv_keys),), dtype=np.float32),
|
"inventory": gym.spaces.Box(
|
||||||
'inventory_max': gym.spaces.Box(-np.inf, np.inf, (len(self._inv_keys),), dtype=np.float32),
|
-np.inf, np.inf, (len(self._inv_keys),), dtype=np.float32
|
||||||
'equipped': gym.spaces.Box(-np.inf, np.inf, (len(self._equip_enum),), dtype=np.float32),
|
),
|
||||||
'reward': gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32),
|
"inventory_max": gym.spaces.Box(
|
||||||
'health': gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32),
|
-np.inf, np.inf, (len(self._inv_keys),), dtype=np.float32
|
||||||
'hunger': gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32),
|
),
|
||||||
'breath': gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32),
|
"equipped": gym.spaces.Box(
|
||||||
'is_first': gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8),
|
-np.inf, np.inf, (len(self._equip_enum),), dtype=np.float32
|
||||||
'is_last': gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8),
|
),
|
||||||
'is_terminal': gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8),
|
"reward": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32),
|
||||||
**{f'log_{k}': gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.int64) for k in self._inv_keys},
|
"health": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32),
|
||||||
'log_player_pos': gym.spaces.Box(-np.inf, np.inf, (3,), dtype=np.float32),
|
"hunger": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32),
|
||||||
|
"breath": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32),
|
||||||
|
"is_first": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8),
|
||||||
|
"is_last": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8),
|
||||||
|
"is_terminal": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8),
|
||||||
|
**{
|
||||||
|
f"log_{k}": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.int64)
|
||||||
|
for k in self._inv_keys
|
||||||
|
},
|
||||||
|
"log_player_pos": gym.spaces.Box(
|
||||||
|
-np.inf, np.inf, (3,), dtype=np.float32
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def action_space(self):
|
||||||
|
space = gym.spaces.discrete.Discrete(len(self._action_values))
|
||||||
|
space.discrete = True
|
||||||
|
return space
|
||||||
|
|
||||||
|
def step(self, action):
|
||||||
|
action = action.copy()
|
||||||
|
action = self._action_values[action]
|
||||||
|
action = self._action(action)
|
||||||
|
following = self._noop_action.copy()
|
||||||
|
for key in ("attack", "forward", "back", "left", "right"):
|
||||||
|
following[key] = action[key]
|
||||||
|
for act in [action] + ([following] * (self._repeat - 1)):
|
||||||
|
obs, reward, done, info = self._env.step(act)
|
||||||
|
if "error" in info:
|
||||||
|
done = True
|
||||||
|
break
|
||||||
|
obs["is_first"] = False
|
||||||
|
obs["is_last"] = bool(done)
|
||||||
|
obs["is_terminal"] = bool(info.get("is_terminal", done))
|
||||||
|
|
||||||
|
obs = self._obs(obs)
|
||||||
|
self._step += 1
|
||||||
|
assert "pov" not in obs, list(obs.keys())
|
||||||
|
return obs, reward, done, info
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inventory(self):
|
||||||
|
return self._inventory
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
# inventory will be added in _obs
|
||||||
|
self._inventory = {}
|
||||||
|
self._max_inventory = None
|
||||||
|
|
||||||
|
with self._LOCK:
|
||||||
|
obs = self._env.reset()
|
||||||
|
obs["is_first"] = True
|
||||||
|
obs["is_last"] = False
|
||||||
|
obs["is_terminal"] = False
|
||||||
|
obs = self._obs(obs)
|
||||||
|
|
||||||
|
self._step = 0
|
||||||
|
self._sticky_attack_counter = 0
|
||||||
|
self._sticky_jump_counter = 0
|
||||||
|
self._pitch = 0
|
||||||
|
return obs
|
||||||
|
|
||||||
|
def _obs(self, obs):
|
||||||
|
obs = self._flatten(obs)
|
||||||
|
obs["inventory/log"] += obs.pop("inventory/log2")
|
||||||
|
self._inventory = {
|
||||||
|
k.split("/", 1)[1]: obs[k] for k in self._inv_keys if k != "inventory/air"
|
||||||
}
|
}
|
||||||
)
|
inventory = np.array([obs[k] for k in self._inv_keys], np.float32)
|
||||||
|
if self._max_inventory is None:
|
||||||
|
self._max_inventory = inventory
|
||||||
|
else:
|
||||||
|
self._max_inventory = np.maximum(self._max_inventory, inventory)
|
||||||
|
index = self._equip_enum.index(obs["equipped_items/mainhand/type"])
|
||||||
|
equipped = np.zeros(len(self._equip_enum), np.float32)
|
||||||
|
equipped[index] = 1.0
|
||||||
|
player_x = obs["location_stats/xpos"]
|
||||||
|
player_y = obs["location_stats/ypos"]
|
||||||
|
player_z = obs["location_stats/zpos"]
|
||||||
|
obs = {
|
||||||
|
"image": obs["pov"],
|
||||||
|
"inventory": inventory,
|
||||||
|
"inventory_max": self._max_inventory.copy(),
|
||||||
|
"equipped": equipped,
|
||||||
|
"health": np.float32(obs["life_stats/life"] / 20),
|
||||||
|
"hunger": np.float32(obs["life_stats/food"] / 20),
|
||||||
|
"breath": np.float32(obs["life_stats/air"] / 300),
|
||||||
|
"reward": 0.0,
|
||||||
|
"is_first": obs["is_first"],
|
||||||
|
"is_last": obs["is_last"],
|
||||||
|
"is_terminal": obs["is_terminal"],
|
||||||
|
**{f"log_{k}": np.int64(obs[k]) for k in self._inv_keys},
|
||||||
|
"log_player_pos": np.array([player_x, player_y, player_z], np.float32),
|
||||||
|
}
|
||||||
|
for key, value in obs.items():
|
||||||
|
space = self.observation_space[key]
|
||||||
|
if not isinstance(value, np.ndarray):
|
||||||
|
value = np.array(value)
|
||||||
|
assert (key, value, value.dtype, value.shape, space)
|
||||||
|
return obs
|
||||||
|
|
||||||
@property
|
def _action(self, action):
|
||||||
def action_space(self):
|
if self._sticky_attack_length:
|
||||||
space = gym.spaces.discrete.Discrete(len(self._action_values))
|
if action["attack"]:
|
||||||
space.discrete = True
|
self._sticky_attack_counter = self._sticky_attack_length
|
||||||
return space
|
if self._sticky_attack_counter > 0:
|
||||||
|
action["attack"] = 1
|
||||||
|
action["jump"] = 0
|
||||||
|
self._sticky_attack_counter -= 1
|
||||||
|
if self._sticky_jump_length:
|
||||||
|
if action["jump"]:
|
||||||
|
self._sticky_jump_counter = self._sticky_jump_length
|
||||||
|
if self._sticky_jump_counter > 0:
|
||||||
|
action["jump"] = 1
|
||||||
|
action["forward"] = 1
|
||||||
|
self._sticky_jump_counter -= 1
|
||||||
|
if self._pitch_limit and action["camera"][0]:
|
||||||
|
lo, hi = self._pitch_limit
|
||||||
|
if not (lo <= self._pitch + action["camera"][0] <= hi):
|
||||||
|
action["camera"] = (0, action["camera"][1])
|
||||||
|
self._pitch += action["camera"][0]
|
||||||
|
return action
|
||||||
|
|
||||||
def step(self, action):
|
def _insert_defaults(self, actions):
|
||||||
action = action.copy()
|
actions = {name: action.copy() for name, action in actions.items()}
|
||||||
action = self._action_values[action]
|
for key, default in self._noop_action.items():
|
||||||
action = self._action(action)
|
for action in actions.values():
|
||||||
following = self._noop_action.copy()
|
if key not in action:
|
||||||
for key in ('attack', 'forward', 'back', 'left', 'right'):
|
action[key] = default
|
||||||
following[key] = action[key]
|
return actions
|
||||||
for act in [action] + ([following] * (self._repeat - 1)):
|
|
||||||
obs, reward, done, info = self._env.step(act)
|
|
||||||
if 'error' in info:
|
|
||||||
done = True
|
|
||||||
break
|
|
||||||
obs['is_first'] = False
|
|
||||||
obs['is_last'] = bool(done)
|
|
||||||
obs['is_terminal'] = bool(info.get('is_terminal', done))
|
|
||||||
|
|
||||||
obs = self._obs(obs)
|
def _flatten(self, nest, prefix=None):
|
||||||
self._step += 1
|
result = {}
|
||||||
assert 'pov' not in obs, list(obs.keys())
|
for key, value in nest.items():
|
||||||
return obs, reward, done, info
|
key = prefix + "/" + key if prefix else key
|
||||||
|
if isinstance(value, gym.spaces.Dict):
|
||||||
|
value = value.spaces
|
||||||
|
if isinstance(value, dict):
|
||||||
|
result.update(self._flatten(value, key))
|
||||||
|
else:
|
||||||
|
result[key] = value
|
||||||
|
return result
|
||||||
|
|
||||||
@property
|
def _unflatten(self, flat):
|
||||||
def inventory(self):
|
result = {}
|
||||||
return self._inventory
|
for key, value in flat.items():
|
||||||
|
parts = key.split("/")
|
||||||
def reset(self):
|
node = result
|
||||||
# inventory will be added in _obs
|
for part in parts[:-1]:
|
||||||
self._inventory = {}
|
if part not in node:
|
||||||
self._max_inventory = None
|
node[part] = {}
|
||||||
|
node = node[part]
|
||||||
with self._LOCK:
|
node[parts[-1]] = value
|
||||||
obs = self._env.reset()
|
return result
|
||||||
obs['is_first'] = True
|
|
||||||
obs['is_last'] = False
|
|
||||||
obs['is_terminal'] = False
|
|
||||||
obs = self._obs(obs)
|
|
||||||
|
|
||||||
self._step = 0
|
|
||||||
self._sticky_attack_counter = 0
|
|
||||||
self._sticky_jump_counter = 0
|
|
||||||
self._pitch = 0
|
|
||||||
return obs
|
|
||||||
|
|
||||||
def _obs(self, obs):
|
|
||||||
obs = self._flatten(obs)
|
|
||||||
obs['inventory/log'] += obs.pop('inventory/log2')
|
|
||||||
self._inventory = {
|
|
||||||
k.split('/', 1)[1]: obs[k] for k in self._inv_keys
|
|
||||||
if k != 'inventory/air'}
|
|
||||||
inventory = np.array([obs[k] for k in self._inv_keys], np.float32)
|
|
||||||
if self._max_inventory is None:
|
|
||||||
self._max_inventory = inventory
|
|
||||||
else:
|
|
||||||
self._max_inventory = np.maximum(self._max_inventory, inventory)
|
|
||||||
index = self._equip_enum.index(obs['equipped_items/mainhand/type'])
|
|
||||||
equipped = np.zeros(len(self._equip_enum), np.float32)
|
|
||||||
equipped[index] = 1.0
|
|
||||||
player_x = obs['location_stats/xpos']
|
|
||||||
player_y = obs['location_stats/ypos']
|
|
||||||
player_z = obs['location_stats/zpos']
|
|
||||||
obs = {
|
|
||||||
'image': obs['pov'],
|
|
||||||
'inventory': inventory,
|
|
||||||
'inventory_max': self._max_inventory.copy(),
|
|
||||||
'equipped': equipped,
|
|
||||||
'health': np.float32(obs['life_stats/life'] / 20),
|
|
||||||
'hunger': np.float32(obs['life_stats/food'] / 20),
|
|
||||||
'breath': np.float32(obs['life_stats/air'] / 300),
|
|
||||||
'reward': 0.0,
|
|
||||||
'is_first': obs['is_first'],
|
|
||||||
'is_last': obs['is_last'],
|
|
||||||
'is_terminal': obs['is_terminal'],
|
|
||||||
**{f'log_{k}': np.int64(obs[k]) for k in self._inv_keys},
|
|
||||||
'log_player_pos': np.array([player_x, player_y, player_z], np.float32),
|
|
||||||
}
|
|
||||||
for key, value in obs.items():
|
|
||||||
space = self.observation_space[key]
|
|
||||||
if not isinstance(value, np.ndarray):
|
|
||||||
value = np.array(value)
|
|
||||||
assert (key, value, value.dtype, value.shape, space)
|
|
||||||
return obs
|
|
||||||
|
|
||||||
def _action(self, action):
|
|
||||||
if self._sticky_attack_length:
|
|
||||||
if action['attack']:
|
|
||||||
self._sticky_attack_counter = self._sticky_attack_length
|
|
||||||
if self._sticky_attack_counter > 0:
|
|
||||||
action['attack'] = 1
|
|
||||||
action['jump'] = 0
|
|
||||||
self._sticky_attack_counter -= 1
|
|
||||||
if self._sticky_jump_length:
|
|
||||||
if action['jump']:
|
|
||||||
self._sticky_jump_counter = self._sticky_jump_length
|
|
||||||
if self._sticky_jump_counter > 0:
|
|
||||||
action['jump'] = 1
|
|
||||||
action['forward'] = 1
|
|
||||||
self._sticky_jump_counter -= 1
|
|
||||||
if self._pitch_limit and action['camera'][0]:
|
|
||||||
lo, hi = self._pitch_limit
|
|
||||||
if not (lo <= self._pitch + action['camera'][0] <= hi):
|
|
||||||
action['camera'] = (0, action['camera'][1])
|
|
||||||
self._pitch += action['camera'][0]
|
|
||||||
return action
|
|
||||||
|
|
||||||
def _insert_defaults(self, actions):
|
|
||||||
actions = {name: action.copy() for name, action in actions.items()}
|
|
||||||
for key, default in self._noop_action.items():
|
|
||||||
for action in actions.values():
|
|
||||||
if key not in action:
|
|
||||||
action[key] = default
|
|
||||||
return actions
|
|
||||||
|
|
||||||
def _flatten(self, nest, prefix=None):
|
|
||||||
result = {}
|
|
||||||
for key, value in nest.items():
|
|
||||||
key = prefix + '/' + key if prefix else key
|
|
||||||
if isinstance(value, gym.spaces.Dict):
|
|
||||||
value = value.spaces
|
|
||||||
if isinstance(value, dict):
|
|
||||||
result.update(self._flatten(value, key))
|
|
||||||
else:
|
|
||||||
result[key] = value
|
|
||||||
return result
|
|
||||||
|
|
||||||
def _unflatten(self, flat):
|
|
||||||
result = {}
|
|
||||||
for key, value in flat.items():
|
|
||||||
parts = key.split('/')
|
|
||||||
node = result
|
|
||||||
for part in parts[:-1]:
|
|
||||||
if part not in node:
|
|
||||||
node[part] = {}
|
|
||||||
node = node[part]
|
|
||||||
node[parts[-1]] = value
|
|
||||||
return result
|
|
||||||
|
|||||||
@ -6,145 +6,155 @@ from minerl.herobraine.hero.mc import INVERSE_KEYMAP
|
|||||||
|
|
||||||
|
|
||||||
def edit_options(**kwargs):
|
def edit_options(**kwargs):
|
||||||
import os, pathlib, re
|
import os, pathlib, re
|
||||||
for word in os.popen('pip3 --version').read().split(' '):
|
|
||||||
if '-packages/pip' in word:
|
for word in os.popen("pip3 --version").read().split(" "):
|
||||||
break
|
if "-packages/pip" in word:
|
||||||
else:
|
break
|
||||||
raise RuntimeError('Could not found python package directory.')
|
else:
|
||||||
packages = pathlib.Path(word).parent
|
raise RuntimeError("Could not found python package directory.")
|
||||||
filename = packages / 'minerl/Malmo/Minecraft/run/options.txt'
|
packages = pathlib.Path(word).parent
|
||||||
options = filename.read_text()
|
filename = packages / "minerl/Malmo/Minecraft/run/options.txt"
|
||||||
if 'fovEffectScale:' not in options:
|
options = filename.read_text()
|
||||||
options += 'fovEffectScale:1.0\n'
|
if "fovEffectScale:" not in options:
|
||||||
if 'simulationDistance:' not in options:
|
options += "fovEffectScale:1.0\n"
|
||||||
options += 'simulationDistance:12\n'
|
if "simulationDistance:" not in options:
|
||||||
for key, value in kwargs.items():
|
options += "simulationDistance:12\n"
|
||||||
assert f'{key}:' in options, key
|
for key, value in kwargs.items():
|
||||||
assert isinstance(value, str), (value, type(value))
|
assert f"{key}:" in options, key
|
||||||
options = re.sub(f'{key}:.*\n', f'{key}:{value}\n', options)
|
assert isinstance(value, str), (value, type(value))
|
||||||
filename.write_text(options)
|
options = re.sub(f"{key}:.*\n", f"{key}:{value}\n", options)
|
||||||
|
filename.write_text(options)
|
||||||
|
|
||||||
|
|
||||||
edit_options(
|
edit_options(
|
||||||
difficulty='2',
|
difficulty="2",
|
||||||
renderDistance='6',
|
renderDistance="6",
|
||||||
simulationDistance='6',
|
simulationDistance="6",
|
||||||
fovEffectScale='0.0',
|
fovEffectScale="0.0",
|
||||||
ao='1',
|
ao="1",
|
||||||
gamma='5.0',
|
gamma="5.0",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class MineRLEnv(EnvSpec):
|
class MineRLEnv(EnvSpec):
|
||||||
|
def __init__(self, resolution=(64, 64), break_speed=50, gamma=10.0):
|
||||||
|
self.resolution = resolution
|
||||||
|
self.break_speed = break_speed
|
||||||
|
self.gamma = gamma
|
||||||
|
super().__init__(name="MineRLEnv-v1")
|
||||||
|
|
||||||
def __init__(self, resolution=(64, 64), break_speed=50, gamma=10.0):
|
def create_agent_start(self):
|
||||||
self.resolution = resolution
|
return [
|
||||||
self.break_speed = break_speed
|
BreakSpeedMultiplier(self.break_speed),
|
||||||
self.gamma = gamma
|
]
|
||||||
super().__init__(name='MineRLEnv-v1')
|
|
||||||
|
|
||||||
def create_agent_start(self):
|
def create_agent_handlers(self):
|
||||||
return [
|
return []
|
||||||
BreakSpeedMultiplier(self.break_speed),
|
|
||||||
]
|
|
||||||
|
|
||||||
def create_agent_handlers(self):
|
def create_server_world_generators(self):
|
||||||
return []
|
return [handlers.DefaultWorldGenerator(force_reset=True)]
|
||||||
|
|
||||||
def create_server_world_generators(self):
|
def create_server_quit_producers(self):
|
||||||
return [handlers.DefaultWorldGenerator(force_reset=True)]
|
return [handlers.ServerQuitWhenAnyAgentFinishes()]
|
||||||
|
|
||||||
def create_server_quit_producers(self):
|
def create_server_initial_conditions(self):
|
||||||
return [handlers.ServerQuitWhenAnyAgentFinishes()]
|
return [
|
||||||
|
handlers.TimeInitialCondition(
|
||||||
|
allow_passage_of_time=True,
|
||||||
|
start_time=0,
|
||||||
|
),
|
||||||
|
handlers.SpawningInitialCondition(
|
||||||
|
allow_spawning=True,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
def create_server_initial_conditions(self):
|
def create_observables(self):
|
||||||
return [
|
return [
|
||||||
handlers.TimeInitialCondition(
|
handlers.POVObservation(self.resolution),
|
||||||
allow_passage_of_time=True,
|
handlers.FlatInventoryObservation(mc.ALL_ITEMS),
|
||||||
start_time=0,
|
handlers.EquippedItemObservation(
|
||||||
),
|
mc.ALL_ITEMS, _default="air", _other="other"
|
||||||
handlers.SpawningInitialCondition(
|
),
|
||||||
allow_spawning=True,
|
handlers.ObservationFromCurrentLocation(),
|
||||||
)
|
handlers.ObservationFromLifeStats(),
|
||||||
]
|
]
|
||||||
|
|
||||||
def create_observables(self):
|
def create_actionables(self):
|
||||||
return [
|
kw = dict(_other="none", _default="none")
|
||||||
handlers.POVObservation(self.resolution),
|
return [
|
||||||
handlers.FlatInventoryObservation(mc.ALL_ITEMS),
|
handlers.KeybasedCommandAction("forward", INVERSE_KEYMAP["forward"]),
|
||||||
handlers.EquippedItemObservation(
|
handlers.KeybasedCommandAction("back", INVERSE_KEYMAP["back"]),
|
||||||
mc.ALL_ITEMS, _default='air', _other='other'),
|
handlers.KeybasedCommandAction("left", INVERSE_KEYMAP["left"]),
|
||||||
handlers.ObservationFromCurrentLocation(),
|
handlers.KeybasedCommandAction("right", INVERSE_KEYMAP["right"]),
|
||||||
handlers.ObservationFromLifeStats(),
|
handlers.KeybasedCommandAction("jump", INVERSE_KEYMAP["jump"]),
|
||||||
]
|
handlers.KeybasedCommandAction("sneak", INVERSE_KEYMAP["sneak"]),
|
||||||
|
handlers.KeybasedCommandAction("attack", INVERSE_KEYMAP["attack"]),
|
||||||
|
handlers.CameraAction(),
|
||||||
|
handlers.PlaceBlock(["none"] + mc.ALL_ITEMS, **kw),
|
||||||
|
handlers.EquipAction(["none"] + mc.ALL_ITEMS, **kw),
|
||||||
|
handlers.CraftAction(["none"] + mc.ALL_ITEMS, **kw),
|
||||||
|
handlers.CraftNearbyAction(["none"] + mc.ALL_ITEMS, **kw),
|
||||||
|
handlers.SmeltItemNearby(["none"] + mc.ALL_ITEMS, **kw),
|
||||||
|
]
|
||||||
|
|
||||||
def create_actionables(self):
|
def is_from_folder(self, folder):
|
||||||
kw = dict(_other='none', _default='none')
|
return folder == "none"
|
||||||
return [
|
|
||||||
handlers.KeybasedCommandAction('forward', INVERSE_KEYMAP['forward']),
|
|
||||||
handlers.KeybasedCommandAction('back', INVERSE_KEYMAP['back']),
|
|
||||||
handlers.KeybasedCommandAction('left', INVERSE_KEYMAP['left']),
|
|
||||||
handlers.KeybasedCommandAction('right', INVERSE_KEYMAP['right']),
|
|
||||||
handlers.KeybasedCommandAction('jump', INVERSE_KEYMAP['jump']),
|
|
||||||
handlers.KeybasedCommandAction('sneak', INVERSE_KEYMAP['sneak']),
|
|
||||||
handlers.KeybasedCommandAction('attack', INVERSE_KEYMAP['attack']),
|
|
||||||
handlers.CameraAction(),
|
|
||||||
handlers.PlaceBlock(['none'] + mc.ALL_ITEMS, **kw),
|
|
||||||
handlers.EquipAction(['none'] + mc.ALL_ITEMS, **kw),
|
|
||||||
handlers.CraftAction(['none'] + mc.ALL_ITEMS, **kw),
|
|
||||||
handlers.CraftNearbyAction(['none'] + mc.ALL_ITEMS, **kw),
|
|
||||||
handlers.SmeltItemNearby(['none'] + mc.ALL_ITEMS, **kw),
|
|
||||||
]
|
|
||||||
|
|
||||||
def is_from_folder(self, folder):
|
def get_docstring(self):
|
||||||
return folder == 'none'
|
return ""
|
||||||
|
|
||||||
def get_docstring(self):
|
def determine_success_from_rewards(self, rewards):
|
||||||
return ''
|
return True
|
||||||
|
|
||||||
def determine_success_from_rewards(self, rewards):
|
def create_rewardables(self):
|
||||||
return True
|
return []
|
||||||
|
|
||||||
def create_rewardables(self):
|
def create_server_decorators(self):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def create_server_decorators(self):
|
def create_mission_handlers(self):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def create_mission_handlers(self):
|
def create_monitors(self):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def create_monitors(self):
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
class BreakSpeedMultiplier(handler.Handler):
|
class BreakSpeedMultiplier(handler.Handler):
|
||||||
|
def __init__(self, multiplier=1.0):
|
||||||
|
self.multiplier = multiplier
|
||||||
|
|
||||||
def __init__(self, multiplier=1.0):
|
def to_string(self):
|
||||||
self.multiplier = multiplier
|
return f"break_speed({self.multiplier})"
|
||||||
|
|
||||||
def to_string(self):
|
def xml_template(self):
|
||||||
return f'break_speed({self.multiplier})'
|
return "<BreakSpeedMultiplier>{{multiplier}}</BreakSpeedMultiplier>"
|
||||||
|
|
||||||
def xml_template(self):
|
|
||||||
return '<BreakSpeedMultiplier>{{multiplier}}</BreakSpeedMultiplier>'
|
|
||||||
|
|
||||||
|
|
||||||
class Gamma(handler.Handler):
|
class Gamma(handler.Handler):
|
||||||
|
def __init__(self, gamma=2.0):
|
||||||
|
self.gamma = gamma
|
||||||
|
|
||||||
def __init__(self, gamma=2.0):
|
def to_string(self):
|
||||||
self.gamma = gamma
|
return f"gamma({self.gamma})"
|
||||||
|
|
||||||
def to_string(self):
|
def xml_template(self):
|
||||||
return f'gamma({self.gamma})'
|
return "<GammaSetting>{{gamma}}</GammaSetting>"
|
||||||
|
|
||||||
def xml_template(self):
|
|
||||||
return '<GammaSetting>{{gamma}}</GammaSetting>'
|
|
||||||
|
|
||||||
|
|
||||||
NOOP_ACTION = dict(
|
NOOP_ACTION = dict(
|
||||||
camera=(0, 0), forward=0, back=0, left=0, right=0, attack=0, sprint=0,
|
camera=(0, 0),
|
||||||
jump=0, sneak=0, craft='none', nearbyCraft='none', nearbySmelt='none',
|
forward=0,
|
||||||
place='none', equip='none',
|
back=0,
|
||||||
|
left=0,
|
||||||
|
right=0,
|
||||||
|
attack=0,
|
||||||
|
sprint=0,
|
||||||
|
jump=0,
|
||||||
|
sneak=0,
|
||||||
|
craft="none",
|
||||||
|
nearbyCraft="none",
|
||||||
|
nearbySmelt="none",
|
||||||
|
place="none",
|
||||||
|
equip="none",
|
||||||
)
|
)
|
||||||
|
|||||||
@ -52,7 +52,6 @@ class OneHotAction(gym.Wrapper):
|
|||||||
super().__init__(env)
|
super().__init__(env)
|
||||||
self._random = np.random.RandomState()
|
self._random = np.random.RandomState()
|
||||||
|
|
||||||
|
|
||||||
def action_space(self):
|
def action_space(self):
|
||||||
shape = (self.env.action_space.n,)
|
shape = (self.env.action_space.n,)
|
||||||
space = gym.spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
|
space = gym.spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
|
||||||
@ -83,7 +82,6 @@ class RewardObs(gym.Wrapper):
|
|||||||
def __init__(self, env):
|
def __init__(self, env):
|
||||||
super().__init__(env)
|
super().__init__(env)
|
||||||
|
|
||||||
|
|
||||||
def observation_space(self):
|
def observation_space(self):
|
||||||
spaces = self.env.observation_space.spaces
|
spaces = self.env.observation_space.spaces
|
||||||
if "reward" not in spaces:
|
if "reward" not in spaces:
|
||||||
@ -110,17 +108,16 @@ class SelectAction(gym.Wrapper):
|
|||||||
super().__init__(env)
|
super().__init__(env)
|
||||||
self._key = key
|
self._key = key
|
||||||
|
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
return self.env.step(action[self._key])
|
return self.env.step(action[self._key])
|
||||||
|
|
||||||
|
|
||||||
class UUID(gym.Wrapper):
|
class UUID(gym.Wrapper):
|
||||||
def __init__(self, env):
|
def __init__(self, env):
|
||||||
super().__init__(env)
|
super().__init__(env)
|
||||||
timestamp = datetime.datetime.now().strftime("%Y%m%dT%H%M%S")
|
timestamp = datetime.datetime.now().strftime("%Y%m%dT%H%M%S")
|
||||||
self.id = f"{timestamp}-{str(uuid.uuid4().hex)}"
|
self.id = f"{timestamp}-{str(uuid.uuid4().hex)}"
|
||||||
|
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
timestamp = datetime.datetime.now().strftime("%Y%m%dT%H%M%S")
|
timestamp = datetime.datetime.now().strftime("%Y%m%dT%H%M%S")
|
||||||
self.id = f"{timestamp}-{str(uuid.uuid4().hex)}"
|
self.id = f"{timestamp}-{str(uuid.uuid4().hex)}"
|
||||||
|
|||||||
@ -194,6 +194,7 @@ class Future:
|
|||||||
self._complete = True
|
self._complete = True
|
||||||
return self._result
|
return self._result
|
||||||
|
|
||||||
|
|
||||||
class Damy:
|
class Damy:
|
||||||
def __init__(self, env):
|
def __init__(self, env):
|
||||||
self._env = env
|
self._env = env
|
||||||
@ -202,7 +203,7 @@ class Damy:
|
|||||||
return getattr(self._env, name)
|
return getattr(self._env, name)
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
return lambda :self._env.step(action)
|
return lambda: self._env.step(action)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
return lambda :self._env.reset()
|
return lambda: self._env.reset()
|
||||||
|
|||||||
19
tools.py
19
tools.py
@ -122,7 +122,18 @@ class Logger:
|
|||||||
self._writer.add_video(name, value, step, 16)
|
self._writer.add_video(name, value, step, 16)
|
||||||
|
|
||||||
|
|
||||||
def simulate(agent, envs, cache, directory, logger, is_eval=False, limit=None, steps=0, episodes=0, state=None):
|
def simulate(
|
||||||
|
agent,
|
||||||
|
envs,
|
||||||
|
cache,
|
||||||
|
directory,
|
||||||
|
logger,
|
||||||
|
is_eval=False,
|
||||||
|
limit=None,
|
||||||
|
steps=0,
|
||||||
|
episodes=0,
|
||||||
|
state=None,
|
||||||
|
):
|
||||||
# initialize or unpack simulation state
|
# initialize or unpack simulation state
|
||||||
if state is None:
|
if state is None:
|
||||||
step, episode = 0, 0
|
step, episode = 0, 0
|
||||||
@ -200,7 +211,7 @@ def simulate(agent, envs, cache, directory, logger, is_eval=False, limit=None, s
|
|||||||
logger.scalar(f"train_episodes", len(cache))
|
logger.scalar(f"train_episodes", len(cache))
|
||||||
logger.write(step=logger.step)
|
logger.write(step=logger.step)
|
||||||
else:
|
else:
|
||||||
if not 'eval_lengths' in locals():
|
if not "eval_lengths" in locals():
|
||||||
eval_lengths = []
|
eval_lengths = []
|
||||||
eval_scores = []
|
eval_scores = []
|
||||||
eval_done = False
|
eval_done = False
|
||||||
@ -278,6 +289,7 @@ class CollectDataset:
|
|||||||
self.add_to_cache(transition)
|
self.add_to_cache(transition)
|
||||||
return obs
|
return obs
|
||||||
|
|
||||||
|
|
||||||
def add_to_cache(cache, id, transition):
|
def add_to_cache(cache, id, transition):
|
||||||
if id not in cache:
|
if id not in cache:
|
||||||
cache[id] = dict()
|
cache[id] = dict()
|
||||||
@ -292,6 +304,7 @@ def add_to_cache(cache, id, transition):
|
|||||||
else:
|
else:
|
||||||
cache[id][key].append(convert(val))
|
cache[id][key].append(convert(val))
|
||||||
|
|
||||||
|
|
||||||
def erase_over_episodes(cache, dataset_size):
|
def erase_over_episodes(cache, dataset_size):
|
||||||
step_in_dataset = 0
|
step_in_dataset = 0
|
||||||
for key, ep in reversed(sorted(cache.items(), key=lambda x: x[0])):
|
for key, ep in reversed(sorted(cache.items(), key=lambda x: x[0])):
|
||||||
@ -304,6 +317,7 @@ def erase_over_episodes(cache, dataset_size):
|
|||||||
del cache[key]
|
del cache[key]
|
||||||
return step_in_dataset
|
return step_in_dataset
|
||||||
|
|
||||||
|
|
||||||
def convert(value, precision=32):
|
def convert(value, precision=32):
|
||||||
value = np.array(value)
|
value = np.array(value)
|
||||||
if np.issubdtype(value.dtype, np.floating):
|
if np.issubdtype(value.dtype, np.floating):
|
||||||
@ -318,6 +332,7 @@ def convert(value, precision=32):
|
|||||||
raise NotImplementedError(value.dtype)
|
raise NotImplementedError(value.dtype)
|
||||||
return value.astype(dtype)
|
return value.astype(dtype)
|
||||||
|
|
||||||
|
|
||||||
def save_episodes(directory, episodes):
|
def save_episodes(directory, episodes):
|
||||||
directory = pathlib.Path(directory).expanduser()
|
directory = pathlib.Path(directory).expanduser()
|
||||||
directory.mkdir(parents=True, exist_ok=True)
|
directory.mkdir(parents=True, exist_ok=True)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user