diff --git a/envs/crafter.py b/envs/crafter.py index 5d9e56e..5a67494 100644 --- a/envs/crafter.py +++ b/envs/crafter.py @@ -19,7 +19,6 @@ class Crafter: "image": gym.spaces.Box( 0, 255, self._env.observation_space.shape, dtype=np.uint8 ), - "reward": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32), "is_first": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8), "is_last": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8), "is_terminal": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8), @@ -50,7 +49,6 @@ class Crafter: } obs = { "image": image, - "reward": reward, "is_first": False, "is_last": done, "is_terminal": info["discount"] == 0, diff --git a/envs/memorymaze.py b/envs/memorymaze.py index 20717a8..93603a6 100644 --- a/envs/memorymaze.py +++ b/envs/memorymaze.py @@ -35,7 +35,6 @@ class MemoryMaze: return gym.spaces.Dict( { **spaces, - "reward": gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32), "is_first": gym.spaces.Box(0, 1, (), dtype=bool), "is_last": gym.spaces.Box(0, 1, (), dtype=bool), "is_terminal": gym.spaces.Box(0, 1, (), dtype=bool), @@ -52,7 +51,6 @@ class MemoryMaze: obs, reward, done, info = self._env.step(action) if not self._obs_is_dict: obs = {self._obs_key: obs} - obs["reward"] = reward obs["is_first"] = False obs["is_last"] = done obs["is_terminal"] = info.get("is_terminal", False) @@ -62,7 +60,6 @@ class MemoryMaze: obs = self._env.reset() if not self._obs_is_dict: obs = {self._obs_key: obs} - obs["reward"] = 0.0 obs["is_first"] = True obs["is_last"] = False obs["is_terminal"] = False diff --git a/envs/minecraft.py b/envs/minecraft.py index f33f52d..59f13e7 100644 --- a/envs/minecraft.py +++ b/envs/minecraft.py @@ -20,11 +20,11 @@ class MinecraftWood: HealthReward(), ] env = minecraft_base.MinecraftBase(actions, *args, **kwargs) + super().__init__(env) def step(self, action): obs, reward, done, info = self.env.step(action) reward = sum([fn(obs, self.env.inventory) for fn in self.rewards]) - obs["reward"] = reward return obs, reward, done, info @@ -34,6 +34,7 @@ class MinecraftClimb: env = minecraft_base.MinecraftBase(actions, *args, **kwargs) self._previous = None self._health_reward = HealthReward() + super().__init__(env) def step(self, action): obs, reward, done, info = self.env.step(action) @@ -43,7 +44,6 @@ class MinecraftClimb: self._previous = height reward = height - self._previous reward += self._health_reward(obs) - obs["reward"] = reward self._previous = height return obs, reward, done, info @@ -87,7 +87,6 @@ class MinecraftDiamond(gym.Wrapper): def step(self, action): obs, reward, done, info = self.env.step(action) reward = sum([fn(obs, self.env.inventory) for fn in self.rewards]) - obs["reward"] = reward return obs, reward, done, info def reset(self): @@ -131,7 +130,7 @@ class HealthReward: return 0 reward = self.scale * (health - self.previous) self.previous = health - return np.float32(reward) + return sum(reward) BASIC_ACTIONS = { diff --git a/envs/minecraft_base.py b/envs/minecraft_base.py index d7f0aba..d6d18dd 100644 --- a/envs/minecraft_base.py +++ b/envs/minecraft_base.py @@ -18,7 +18,7 @@ class MinecraftBase(gym.Env): sticky_attack=30, sticky_jump=10, pitch_limit=(-60, 60), - logs=True, + logs=False, ): if logs: logging.basicConfig(level=logging.DEBUG) @@ -41,7 +41,6 @@ class MinecraftBase(gym.Env): if k.startswith("inventory/") if k != "inventory/log2" ] - self._step = 0 self._max_inventory = None self._equip_enum = self._env.observation_space["equipped_items"]["mainhand"][ "type" @@ -75,7 +74,6 @@ class MinecraftBase(gym.Env): "equipped": gym.spaces.Box( -np.inf, np.inf, (len(self._equip_enum),), dtype=np.float32 ), - "reward": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32), "health": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32), "hunger": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32), "breath": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32), @@ -110,12 +108,11 @@ class MinecraftBase(gym.Env): if "error" in info: done = True break - obs["is_first"] = False - obs["is_last"] = bool(done) - obs["is_terminal"] = bool(info.get("is_terminal", done)) + obs["is_first"] = False + obs["is_last"] = bool(done) + obs["is_terminal"] = bool(info.get("is_terminal", done)) obs = self._obs(obs) - self._step += 1 assert "pov" not in obs, list(obs.keys()) return obs, reward, done, info @@ -135,7 +132,6 @@ class MinecraftBase(gym.Env): obs["is_terminal"] = False obs = self._obs(obs) - self._step = 0 self._sticky_attack_counter = 0 self._sticky_jump_counter = 0 self._pitch = 0 @@ -166,7 +162,6 @@ class MinecraftBase(gym.Env): "health": np.float32([obs["life_stats/life"]]) / 20, "hunger": np.float32([obs["life_stats/food"]]) / 20, "breath": np.float32([obs["life_stats/air"]]) / 300, - "reward": [0.0], "is_first": obs["is_first"], "is_last": obs["is_last"], "is_terminal": obs["is_terminal"],