Added logging of inventory items in Minecraft

This commit is contained in:
NM512 2023-08-16 15:52:33 +09:00
parent 99dc4e4ed1
commit 68096d1f62
3 changed files with 42 additions and 17 deletions

View File

@ -36,8 +36,7 @@ So far, the following benchmarks can be used for testing.
![atari100k](https://github.com/NM512/dreamerv3-torch/assets/70328564/0da6d899-d91d-44b4-a8c4-d5b37413aa11) ![atari100k](https://github.com/NM512/dreamerv3-torch/assets/70328564/0da6d899-d91d-44b4-a8c4-d5b37413aa11)
#### Crafter #### Crafter
<img src="https://github.com/NM512/dreamerv3-torch/assets/70328564/2a4d65d3-7e7b-4a95-b0cf-146d978054f0" width="300" height="150" /> <img src="https://github.com/NM512/dreamerv3-torch/assets/70328564/a0626038-53f6-4300-a622-7ac257f4c290" width="300" height="150" />
## Acknowledgments ## Acknowledgments
This code is heavily inspired by the following works: This code is heavily inspired by the following works:

View File

@ -66,20 +66,22 @@ class MinecraftDiamond(gym.Wrapper):
"place_furnace": dict(place="furnace"), "place_furnace": dict(place="furnace"),
"smelt_iron_ingot": dict(nearbySmelt="iron_ingot"), "smelt_iron_ingot": dict(nearbySmelt="iron_ingot"),
} }
self.rewards = [ self.items = [
CollectReward("log", once=1), "log",
CollectReward("planks", once=1), "planks",
CollectReward("stick", once=1), "stick",
CollectReward("crafting_table", once=1), "crafting_table",
CollectReward("wooden_pickaxe", once=1), "wooden_pickaxe",
CollectReward("cobblestone", once=1), "cobblestone",
CollectReward("stone_pickaxe", once=1), "stone_pickaxe",
CollectReward("iron_ore", once=1), "iron_ore",
CollectReward("furnace", once=1), "furnace",
CollectReward("iron_ingot", once=1), "iron_ingot",
CollectReward("iron_pickaxe", once=1), "iron_pickaxe",
CollectReward("diamond", once=1), "diamond",
HealthReward(), ]
self.rewards = [CollectReward(item, once=1) for item in self.items] + [
HealthReward()
] ]
env = minecraft_base.MinecraftBase(actions, *args, **kwargs) env = minecraft_base.MinecraftBase(actions, *args, **kwargs)
super().__init__(env) super().__init__(env)
@ -87,12 +89,24 @@ class MinecraftDiamond(gym.Wrapper):
def step(self, action):
    """Step the wrapped environment and return a filtered observation.

    The reward reported by ``self.env.step`` is discarded. The returned
    reward is the sum of every callable in ``self.rewards`` evaluated on
    the new observation and ``self.env.inventory``.

    Args:
        action: Action forwarded unchanged to ``self.env.step``.

    Returns:
        A tuple ``(obs, reward, done, info)``. ``obs`` omits every key
        that contains ``"log"`` unless the key's last ``/``-separated
        segment is listed in ``self.items``.
    """
    # The wrapped environment's own reward is replaced below, so it is
    # not bound to a name.
    obs, _, done, info = self.env.step(action)
    reward = sum(fn(obs, self.env.inventory) for fn in self.rewards)
    # Keep only the log entries for tracked items to save memory.
    obs = {
        key: value
        for key, value in obs.items()
        if "log" not in key or key.split("/")[-1] in self.items
    }
    return obs, reward, done, info
def reset(self):
    """Reset the wrapped environment and return a filtered observation.

    Every callable in ``self.rewards`` is invoked once on the fresh
    observation and ``self.env.inventory``; the values they return are
    ignored.

    Returns:
        The observation from ``self.env.reset()``, omitting every key
        that contains ``"log"`` unless the key's last ``/``-separated
        segment is listed in ``self.items``.
    """
    obs = self.env.reset()
    # Called only for the side effect: per the original note, this
    # resets the reward calculations. The returned values are unused.
    for fn in self.rewards:
        fn(obs, self.env.inventory)
    # Keep only the log entries for tracked items to save memory.
    obs = {
        key: value
        for key, value in obs.items()
        if "log" not in key or key.split("/")[-1] in self.items
    }
    return obs

View File

@ -84,7 +84,10 @@ class Logger:
with (self._logdir / "metrics.jsonl").open("a") as f: with (self._logdir / "metrics.jsonl").open("a") as f:
f.write(json.dumps({"step": step, **dict(scalars)}) + "\n") f.write(json.dumps({"step": step, **dict(scalars)}) + "\n")
for name, value in scalars: for name, value in scalars:
self._writer.add_scalar("scalars/" + name, value, step) if "/" not in name:
self._writer.add_scalar("scalars/" + name, value, step)
else:
self._writer.add_scalar(name, value, step)
for name, value in self._images.items(): for name, value in self._images.items():
self._writer.add_image(name, value, step) self._writer.add_image(name, value, step)
for name, value in self._videos.items(): for name, value in self._videos.items():
@ -203,6 +206,15 @@ def simulate(
length = len(cache[envs[i].id]["reward"]) - 1 length = len(cache[envs[i].id]["reward"]) - 1
score = float(np.array(cache[envs[i].id]["reward"]).sum()) score = float(np.array(cache[envs[i].id]["reward"]).sum())
video = cache[envs[i].id]["image"] video = cache[envs[i].id]["image"]
# record logs given from environments
for key in list(cache[envs[i].id].keys()):
if "log_" in key:
logger.scalar(
key, float(np.array(cache[envs[i].id][key]).sum())
)
# log items won't be used later
cache[envs[i].id].pop(key)
if not is_eval: if not is_eval:
step_in_dataset = erase_over_episodes(cache, limit) step_in_dataset = erase_over_episodes(cache, limit)
logger.scalar(f"dataset_size", step_in_dataset) logger.scalar(f"dataset_size", step_in_dataset)