From 68096d1f62498d0d45af23268f764652e547ffa9 Mon Sep 17 00:00:00 2001
From: NM512
Date: Wed, 16 Aug 2023 15:52:33 +0900
Subject: [PATCH] Add logging of inventory items in Minecraft

---
 README.md         |  3 +--
 envs/minecraft.py | 42 ++++++++++++++++++++++++++++--------------
 tools.py          | 14 +++++++++++++-
 3 files changed, 42 insertions(+), 17 deletions(-)

diff --git a/README.md b/README.md
index 7b660bc..4e37bf5 100644
--- a/README.md
+++ b/README.md
@@ -36,8 +36,7 @@ So far, the following benchmarks can be used for testing.
 ![atari100k](https://github.com/NM512/dreamerv3-torch/assets/70328564/0da6d899-d91d-44b4-a8c4-d5b37413aa11)
 
 #### Crafter
-
-
+
 
 ## Acknowledgments
 This code is heavily inspired by the following works:
diff --git a/envs/minecraft.py b/envs/minecraft.py
index 59f13e7..338d31f 100644
--- a/envs/minecraft.py
+++ b/envs/minecraft.py
@@ -66,20 +66,22 @@ class MinecraftDiamond(gym.Wrapper):
             "place_furnace": dict(place="furnace"),
             "smelt_iron_ingot": dict(nearbySmelt="iron_ingot"),
         }
-        self.rewards = [
-            CollectReward("log", once=1),
-            CollectReward("planks", once=1),
-            CollectReward("stick", once=1),
-            CollectReward("crafting_table", once=1),
-            CollectReward("wooden_pickaxe", once=1),
-            CollectReward("cobblestone", once=1),
-            CollectReward("stone_pickaxe", once=1),
-            CollectReward("iron_ore", once=1),
-            CollectReward("furnace", once=1),
-            CollectReward("iron_ingot", once=1),
-            CollectReward("iron_pickaxe", once=1),
-            CollectReward("diamond", once=1),
-            HealthReward(),
+        self.items = [
+            "log",
+            "planks",
+            "stick",
+            "crafting_table",
+            "wooden_pickaxe",
+            "cobblestone",
+            "stone_pickaxe",
+            "iron_ore",
+            "furnace",
+            "iron_ingot",
+            "iron_pickaxe",
+            "diamond",
+        ]
+        self.rewards = [CollectReward(item, once=1) for item in self.items] + [
+            HealthReward()
         ]
         env = minecraft_base.MinecraftBase(actions, *args, **kwargs)
         super().__init__(env)
@@ -87,12 +89,24 @@ class MinecraftDiamond(gym.Wrapper):
 
     def step(self, action):
         obs, reward, done, info = self.env.step(action)
         reward = sum([fn(obs, self.env.inventory) for fn in self.rewards])
+        # drop untracked per-item log keys from obs to save memory
+        obs = {
+            k: v
+            for k, v in obs.items()
+            if "log" not in k or k.split("/")[-1] in self.items
+        }
         return obs, reward, done, info
 
     def reset(self):
         obs = self.env.reset()
         # called for reset of reward calculations
         _ = sum([fn(obs, self.env.inventory) for fn in self.rewards])
+        # drop untracked per-item log keys from obs to save memory
+        obs = {
+            k: v
+            for k, v in obs.items()
+            if "log" not in k or k.split("/")[-1] in self.items
+        }
         return obs
 
diff --git a/tools.py b/tools.py
index 1cfed7d..d5da141 100644
--- a/tools.py
+++ b/tools.py
@@ -84,7 +84,10 @@ class Logger:
         with (self._logdir / "metrics.jsonl").open("a") as f:
             f.write(json.dumps({"step": step, **dict(scalars)}) + "\n")
         for name, value in scalars:
-            self._writer.add_scalar("scalars/" + name, value, step)
+            if "/" not in name:
+                self._writer.add_scalar("scalars/" + name, value, step)
+            else:
+                self._writer.add_scalar(name, value, step)
         for name, value in self._images.items():
             self._writer.add_image(name, value, step)
         for name, value in self._videos.items():
@@ -203,6 +206,15 @@ def simulate(
             length = len(cache[envs[i].id]["reward"]) - 1
             score = float(np.array(cache[envs[i].id]["reward"]).sum())
             video = cache[envs[i].id]["image"]
+            # record the log entries provided by the environments
+            for key in list(cache[envs[i].id].keys()):
+                if "log_" in key:
+                    logger.scalar(
+                        key, float(np.array(cache[envs[i].id][key]).sum())
+                    )
+                    # the log entries are not needed after this point
+                    cache[envs[i].id].pop(key)
+
             if not is_eval:
                 step_in_dataset = erase_over_episodes(cache, limit)
                 logger.scalar(f"dataset_size", step_in_dataset)
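
The standalone sketch below (not part of the patch) walks the new per-item log pipeline end to end. The key name "log_inventory/diamond", the toy values, and the helper tag_for are illustrative assumptions written only to mirror the logic added above; nothing here is taken from the codebase beyond the three mirrored fragments.

    # Standalone sketch; key names and values are illustrative assumptions.
    import numpy as np

    items = ["log", "planks", "diamond"]

    # 1) Key filter mirrored from MinecraftDiamond.step/reset: non-log keys
    #    always pass; log keys survive only if their item suffix is tracked.
    obs = {
        "image": np.zeros((64, 64, 3)),  # regular observation, kept
        "log_inventory/diamond": 1,      # tracked item, kept
        "log_inventory/red_flower": 3,   # untracked item, dropped
    }
    obs = {
        k: v
        for k, v in obs.items()
        if "log" not in k or k.split("/")[-1] in items
    }
    assert sorted(obs) == ["image", "log_inventory/diamond"]

    # 2) Episode aggregation mirrored from simulate(): per-step counters are
    #    summed into one scalar per episode, then removed from the cache.
    episode = {"reward": [0.0, 1.0], "log_inventory/diamond": [0, 1]}
    for key in list(episode.keys()):
        if "log_" in key:
            total = float(np.array(episode[key]).sum())
            print(key, total)  # stands in for logger.scalar(key, total)
            episode.pop(key)
    assert "log_inventory/diamond" not in episode

    # 3) Tag routing mirrored from Logger.scalar: names that already contain
    #    "/" keep their own TensorBoard namespace instead of "scalars/".
    def tag_for(name):
        return name if "/" in name else "scalars/" + name

    assert tag_for("log_inventory/diamond") == "log_inventory/diamond"
    assert tag_for("train_return") == "scalars/train_return"

With the routing in (3), each tracked item's episode total appears in TensorBoard under its own tag group rather than being folded into the generic scalars/ prefix.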