added log for inventory items in minecraft
This commit is contained in:
parent
99dc4e4ed1
commit
68096d1f62
@ -36,8 +36,7 @@ So far, the following benchmarks can be used for testing.
|
|||||||

|

|
||||||
|
|
||||||
#### Crafter
|
#### Crafter
|
||||||
<img src="https://github.com/NM512/dreamerv3-torch/assets/70328564/2a4d65d3-7e7b-4a95-b0cf-146d978054f0" width="300" height="150" />
|
<img src="https://github.com/NM512/dreamerv3-torch/assets/70328564/a0626038-53f6-4300-a622-7ac257f4c290" width="300" height="150" />
|
||||||
|
|
||||||
|
|
||||||
## Acknowledgments
|
## Acknowledgments
|
||||||
This code is heavily inspired by the following works:
|
This code is heavily inspired by the following works:
|
||||||
|
@ -66,20 +66,22 @@ class MinecraftDiamond(gym.Wrapper):
|
|||||||
"place_furnace": dict(place="furnace"),
|
"place_furnace": dict(place="furnace"),
|
||||||
"smelt_iron_ingot": dict(nearbySmelt="iron_ingot"),
|
"smelt_iron_ingot": dict(nearbySmelt="iron_ingot"),
|
||||||
}
|
}
|
||||||
self.rewards = [
|
self.items = [
|
||||||
CollectReward("log", once=1),
|
"log",
|
||||||
CollectReward("planks", once=1),
|
"planks",
|
||||||
CollectReward("stick", once=1),
|
"stick",
|
||||||
CollectReward("crafting_table", once=1),
|
"crafting_table",
|
||||||
CollectReward("wooden_pickaxe", once=1),
|
"wooden_pickaxe",
|
||||||
CollectReward("cobblestone", once=1),
|
"cobblestone",
|
||||||
CollectReward("stone_pickaxe", once=1),
|
"stone_pickaxe",
|
||||||
CollectReward("iron_ore", once=1),
|
"iron_ore",
|
||||||
CollectReward("furnace", once=1),
|
"furnace",
|
||||||
CollectReward("iron_ingot", once=1),
|
"iron_ingot",
|
||||||
CollectReward("iron_pickaxe", once=1),
|
"iron_pickaxe",
|
||||||
CollectReward("diamond", once=1),
|
"diamond",
|
||||||
HealthReward(),
|
]
|
||||||
|
self.rewards = [CollectReward(item, once=1) for item in self.items] + [
|
||||||
|
HealthReward()
|
||||||
]
|
]
|
||||||
env = minecraft_base.MinecraftBase(actions, *args, **kwargs)
|
env = minecraft_base.MinecraftBase(actions, *args, **kwargs)
|
||||||
super().__init__(env)
|
super().__init__(env)
|
||||||
@ -87,12 +89,24 @@ class MinecraftDiamond(gym.Wrapper):
|
|||||||
def step(self, action):
|
def step(self, action):
|
||||||
obs, reward, done, info = self.env.step(action)
|
obs, reward, done, info = self.env.step(action)
|
||||||
reward = sum([fn(obs, self.env.inventory) for fn in self.rewards])
|
reward = sum([fn(obs, self.env.inventory) for fn in self.rewards])
|
||||||
|
# restrict log for memory save
|
||||||
|
obs = {
|
||||||
|
k: v
|
||||||
|
for k, v in obs.items()
|
||||||
|
if "log" not in k or k.split("/")[-1] in self.items
|
||||||
|
}
|
||||||
return obs, reward, done, info
|
return obs, reward, done, info
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
obs = self.env.reset()
|
obs = self.env.reset()
|
||||||
# called for reset of reward calculations
|
# called for reset of reward calculations
|
||||||
_ = sum([fn(obs, self.env.inventory) for fn in self.rewards])
|
_ = sum([fn(obs, self.env.inventory) for fn in self.rewards])
|
||||||
|
# restrict log for memory save
|
||||||
|
obs = {
|
||||||
|
k: v
|
||||||
|
for k, v in obs.items()
|
||||||
|
if "log" not in k or k.split("/")[-1] in self.items
|
||||||
|
}
|
||||||
return obs
|
return obs
|
||||||
|
|
||||||
|
|
||||||
|
12
tools.py
12
tools.py
@ -84,7 +84,10 @@ class Logger:
|
|||||||
with (self._logdir / "metrics.jsonl").open("a") as f:
|
with (self._logdir / "metrics.jsonl").open("a") as f:
|
||||||
f.write(json.dumps({"step": step, **dict(scalars)}) + "\n")
|
f.write(json.dumps({"step": step, **dict(scalars)}) + "\n")
|
||||||
for name, value in scalars:
|
for name, value in scalars:
|
||||||
|
if "/" not in name:
|
||||||
self._writer.add_scalar("scalars/" + name, value, step)
|
self._writer.add_scalar("scalars/" + name, value, step)
|
||||||
|
else:
|
||||||
|
self._writer.add_scalar(name, value, step)
|
||||||
for name, value in self._images.items():
|
for name, value in self._images.items():
|
||||||
self._writer.add_image(name, value, step)
|
self._writer.add_image(name, value, step)
|
||||||
for name, value in self._videos.items():
|
for name, value in self._videos.items():
|
||||||
@ -203,6 +206,15 @@ def simulate(
|
|||||||
length = len(cache[envs[i].id]["reward"]) - 1
|
length = len(cache[envs[i].id]["reward"]) - 1
|
||||||
score = float(np.array(cache[envs[i].id]["reward"]).sum())
|
score = float(np.array(cache[envs[i].id]["reward"]).sum())
|
||||||
video = cache[envs[i].id]["image"]
|
video = cache[envs[i].id]["image"]
|
||||||
|
# record logs given from environments
|
||||||
|
for key in list(cache[envs[i].id].keys()):
|
||||||
|
if "log_" in key:
|
||||||
|
logger.scalar(
|
||||||
|
key, float(np.array(cache[envs[i].id][key]).sum())
|
||||||
|
)
|
||||||
|
# log items won't be used later
|
||||||
|
cache[envs[i].id].pop(key)
|
||||||
|
|
||||||
if not is_eval:
|
if not is_eval:
|
||||||
step_in_dataset = erase_over_episodes(cache, limit)
|
step_in_dataset = erase_over_episodes(cache, limit)
|
||||||
logger.scalar(f"dataset_size", step_in_dataset)
|
logger.scalar(f"dataset_size", step_in_dataset)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user