Added logging for inventory items in Minecraft
This commit is contained in:
		
							parent
							
								
									99dc4e4ed1
								
							
						
					
					
						commit
						68096d1f62
					
				@ -36,8 +36,7 @@ So far, the following benchmarks can be used for testing.
 | 
			
		||||

 | 
			
		||||
 | 
			
		||||
#### Crafter
 | 
			
		||||
<img src="https://github.com/NM512/dreamerv3-torch/assets/70328564/2a4d65d3-7e7b-4a95-b0cf-146d978054f0" width="300" height="150" />
 | 
			
		||||
 | 
			
		||||
<img src="https://github.com/NM512/dreamerv3-torch/assets/70328564/a0626038-53f6-4300-a622-7ac257f4c290" width="300" height="150" />
 | 
			
		||||
 | 
			
		||||
## Acknowledgments
 | 
			
		||||
This code is heavily inspired by the following works:
 | 
			
		||||
 | 
			
		||||
@ -66,20 +66,22 @@ class MinecraftDiamond(gym.Wrapper):
 | 
			
		||||
            "place_furnace": dict(place="furnace"),
 | 
			
		||||
            "smelt_iron_ingot": dict(nearbySmelt="iron_ingot"),
 | 
			
		||||
        }
 | 
			
		||||
        self.rewards = [
 | 
			
		||||
            CollectReward("log", once=1),
 | 
			
		||||
            CollectReward("planks", once=1),
 | 
			
		||||
            CollectReward("stick", once=1),
 | 
			
		||||
            CollectReward("crafting_table", once=1),
 | 
			
		||||
            CollectReward("wooden_pickaxe", once=1),
 | 
			
		||||
            CollectReward("cobblestone", once=1),
 | 
			
		||||
            CollectReward("stone_pickaxe", once=1),
 | 
			
		||||
            CollectReward("iron_ore", once=1),
 | 
			
		||||
            CollectReward("furnace", once=1),
 | 
			
		||||
            CollectReward("iron_ingot", once=1),
 | 
			
		||||
            CollectReward("iron_pickaxe", once=1),
 | 
			
		||||
            CollectReward("diamond", once=1),
 | 
			
		||||
            HealthReward(),
 | 
			
		||||
        self.items = [
 | 
			
		||||
            "log",
 | 
			
		||||
            "planks",
 | 
			
		||||
            "stick",
 | 
			
		||||
            "crafting_table",
 | 
			
		||||
            "wooden_pickaxe",
 | 
			
		||||
            "cobblestone",
 | 
			
		||||
            "stone_pickaxe",
 | 
			
		||||
            "iron_ore",
 | 
			
		||||
            "furnace",
 | 
			
		||||
            "iron_ingot",
 | 
			
		||||
            "iron_pickaxe",
 | 
			
		||||
            "diamond",
 | 
			
		||||
        ]
 | 
			
		||||
        self.rewards = [CollectReward(item, once=1) for item in self.items] + [
 | 
			
		||||
            HealthReward()
 | 
			
		||||
        ]
 | 
			
		||||
        env = minecraft_base.MinecraftBase(actions, *args, **kwargs)
 | 
			
		||||
        super().__init__(env)
 | 
			
		||||
@ -87,12 +89,24 @@ class MinecraftDiamond(gym.Wrapper):
 | 
			
		||||
    def step(self, action):
 | 
			
		||||
        obs, reward, done, info = self.env.step(action)
 | 
			
		||||
        reward = sum([fn(obs, self.env.inventory) for fn in self.rewards])
 | 
			
		||||
        # keep only the log entries for tracked items to save memory
 | 
			
		||||
        obs = {
 | 
			
		||||
            k: v
 | 
			
		||||
            for k, v in obs.items()
 | 
			
		||||
            if "log" not in k or k.split("/")[-1] in self.items
 | 
			
		||||
        }
 | 
			
		||||
        return obs, reward, done, info
 | 
			
		||||
 | 
			
		||||
    def reset(self):
 | 
			
		||||
        obs = self.env.reset()
 | 
			
		||||
        # evaluate the rewards once and discard the result to reset the reward calculations
 | 
			
		||||
        _ = sum([fn(obs, self.env.inventory) for fn in self.rewards])
 | 
			
		||||
        # keep only the log entries for tracked items to save memory
 | 
			
		||||
        obs = {
 | 
			
		||||
            k: v
 | 
			
		||||
            for k, v in obs.items()
 | 
			
		||||
            if "log" not in k or k.split("/")[-1] in self.items
 | 
			
		||||
        }
 | 
			
		||||
        return obs
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										12
									
								
								tools.py
									
									
									
									
									
								
							
							
						
						
									
										12
									
								
								tools.py
									
									
									
									
									
								
							@ -84,7 +84,10 @@ class Logger:
 | 
			
		||||
        with (self._logdir / "metrics.jsonl").open("a") as f:
 | 
			
		||||
            f.write(json.dumps({"step": step, **dict(scalars)}) + "\n")
 | 
			
		||||
        for name, value in scalars:
 | 
			
		||||
            if "/" not in name:
 | 
			
		||||
                self._writer.add_scalar("scalars/" + name, value, step)
 | 
			
		||||
            else:
 | 
			
		||||
                self._writer.add_scalar(name, value, step)
 | 
			
		||||
        for name, value in self._images.items():
 | 
			
		||||
            self._writer.add_image(name, value, step)
 | 
			
		||||
        for name, value in self._videos.items():
 | 
			
		||||
@ -203,6 +206,15 @@ def simulate(
 | 
			
		||||
                length = len(cache[envs[i].id]["reward"]) - 1
 | 
			
		||||
                score = float(np.array(cache[envs[i].id]["reward"]).sum())
 | 
			
		||||
                video = cache[envs[i].id]["image"]
 | 
			
		||||
                # record the log values provided by the environments
 | 
			
		||||
                for key in list(cache[envs[i].id].keys()):
 | 
			
		||||
                    if "log_" in key:
 | 
			
		||||
                        logger.scalar(
 | 
			
		||||
                            key, float(np.array(cache[envs[i].id][key]).sum())
 | 
			
		||||
                        )
 | 
			
		||||
                        # log entries are not used after this point, so remove them from the cache
 | 
			
		||||
                        cache[envs[i].id].pop(key)
 | 
			
		||||
 | 
			
		||||
                if not is_eval:
 | 
			
		||||
                    step_in_dataset = erase_over_episodes(cache, limit)
 | 
			
		||||
                    logger.scalar(f"dataset_size", step_in_dataset)
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user