added log for inventory items in minecraft
This commit is contained in:
		
							parent
							
								
									99dc4e4ed1
								
							
						
					
					
						commit
						68096d1f62
					
				| @ -36,8 +36,7 @@ So far, the following benchmarks can be used for testing. | ||||
|  | ||||
| 
 | ||||
| #### Crafter | ||||
| <img src="https://github.com/NM512/dreamerv3-torch/assets/70328564/2a4d65d3-7e7b-4a95-b0cf-146d978054f0" width="300" height="150" /> | ||||
| 
 | ||||
| <img src="https://github.com/NM512/dreamerv3-torch/assets/70328564/a0626038-53f6-4300-a622-7ac257f4c290" width="300" height="150" /> | ||||
| 
 | ||||
| ## Acknowledgments | ||||
| This code is heavily inspired by the following works: | ||||
|  | ||||
| @ -66,20 +66,22 @@ class MinecraftDiamond(gym.Wrapper): | ||||
|             "place_furnace": dict(place="furnace"), | ||||
|             "smelt_iron_ingot": dict(nearbySmelt="iron_ingot"), | ||||
|         } | ||||
|         self.rewards = [ | ||||
|             CollectReward("log", once=1), | ||||
|             CollectReward("planks", once=1), | ||||
|             CollectReward("stick", once=1), | ||||
|             CollectReward("crafting_table", once=1), | ||||
|             CollectReward("wooden_pickaxe", once=1), | ||||
|             CollectReward("cobblestone", once=1), | ||||
|             CollectReward("stone_pickaxe", once=1), | ||||
|             CollectReward("iron_ore", once=1), | ||||
|             CollectReward("furnace", once=1), | ||||
|             CollectReward("iron_ingot", once=1), | ||||
|             CollectReward("iron_pickaxe", once=1), | ||||
|             CollectReward("diamond", once=1), | ||||
|             HealthReward(), | ||||
|         self.items = [ | ||||
|             "log", | ||||
|             "planks", | ||||
|             "stick", | ||||
|             "crafting_table", | ||||
|             "wooden_pickaxe", | ||||
|             "cobblestone", | ||||
|             "stone_pickaxe", | ||||
|             "iron_ore", | ||||
|             "furnace", | ||||
|             "iron_ingot", | ||||
|             "iron_pickaxe", | ||||
|             "diamond", | ||||
|         ] | ||||
|         self.rewards = [CollectReward(item, once=1) for item in self.items] + [ | ||||
|             HealthReward() | ||||
|         ] | ||||
|         env = minecraft_base.MinecraftBase(actions, *args, **kwargs) | ||||
|         super().__init__(env) | ||||
| @ -87,12 +89,24 @@ class MinecraftDiamond(gym.Wrapper): | ||||
|     def step(self, action): | ||||
|         obs, reward, done, info = self.env.step(action) | ||||
|         reward = sum([fn(obs, self.env.inventory) for fn in self.rewards]) | ||||
|         # Save memory: drop every key containing "log" unless its last "/"-separated segment is listed in self.items. | ||||
|         obs = { | ||||
|             k: v | ||||
|             for k, v in obs.items() | ||||
|             if "log" not in k or k.split("/")[-1] in self.items | ||||
|         } | ||||
|         return obs, reward, done, info | ||||
| 
 | ||||
|     def reset(self): | ||||
|         obs = self.env.reset() | ||||
|         # Call every reward function on the initial observation and discard the sum; this is done to reset the reward calculations. | ||||
|         _ = sum([fn(obs, self.env.inventory) for fn in self.rewards]) | ||||
|         # Save memory: drop every key containing "log" unless its last "/"-separated segment is listed in self.items. | ||||
|         obs = { | ||||
|             k: v | ||||
|             for k, v in obs.items() | ||||
|             if "log" not in k or k.split("/")[-1] in self.items | ||||
|         } | ||||
|         return obs | ||||
| 
 | ||||
| 
 | ||||
|  | ||||
							
								
								
									
										14
									
								
								tools.py
									
									
									
									
									
								
							
							
						
						
									
										14
									
								
								tools.py
									
									
									
									
									
								
							| @ -84,7 +84,10 @@ class Logger: | ||||
|         with (self._logdir / "metrics.jsonl").open("a") as f: | ||||
|             f.write(json.dumps({"step": step, **dict(scalars)}) + "\n") | ||||
|         for name, value in scalars: | ||||
|             self._writer.add_scalar("scalars/" + name, value, step) | ||||
|             if "/" not in name: | ||||
|                 self._writer.add_scalar("scalars/" + name, value, step) | ||||
|             else: | ||||
|                 self._writer.add_scalar(name, value, step) | ||||
|         for name, value in self._images.items(): | ||||
|             self._writer.add_image(name, value, step) | ||||
|         for name, value in self._videos.items(): | ||||
| @ -203,6 +206,15 @@ def simulate( | ||||
|                 length = len(cache[envs[i].id]["reward"]) - 1 | ||||
|                 score = float(np.array(cache[envs[i].id]["reward"]).sum()) | ||||
|                 video = cache[envs[i].id]["image"] | ||||
|                 # record logs given from environments | ||||
|                 for key in list(cache[envs[i].id].keys()): | ||||
|                     if "log_" in key: | ||||
|                         logger.scalar( | ||||
|                             key, float(np.array(cache[envs[i].id][key]).sum()) | ||||
|                         ) | ||||
|                         # log items won't be used later | ||||
|                         cache[envs[i].id].pop(key) | ||||
| 
 | ||||
|                 if not is_eval: | ||||
|                     step_in_dataset = erase_over_episodes(cache, limit) | ||||
|                     logger.scalar(f"dataset_size", step_in_dataset) | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user