modified envs
This commit is contained in:
parent
a6ad132198
commit
eb14e2488b
@ -19,7 +19,6 @@ class Crafter:
|
|||||||
"image": gym.spaces.Box(
|
"image": gym.spaces.Box(
|
||||||
0, 255, self._env.observation_space.shape, dtype=np.uint8
|
0, 255, self._env.observation_space.shape, dtype=np.uint8
|
||||||
),
|
),
|
||||||
"reward": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32),
|
|
||||||
"is_first": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8),
|
"is_first": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8),
|
||||||
"is_last": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8),
|
"is_last": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8),
|
||||||
"is_terminal": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8),
|
"is_terminal": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8),
|
||||||
@ -50,7 +49,6 @@ class Crafter:
|
|||||||
}
|
}
|
||||||
obs = {
|
obs = {
|
||||||
"image": image,
|
"image": image,
|
||||||
"reward": reward,
|
|
||||||
"is_first": False,
|
"is_first": False,
|
||||||
"is_last": done,
|
"is_last": done,
|
||||||
"is_terminal": info["discount"] == 0,
|
"is_terminal": info["discount"] == 0,
|
||||||
|
@ -35,7 +35,6 @@ class MemoryMaze:
|
|||||||
return gym.spaces.Dict(
|
return gym.spaces.Dict(
|
||||||
{
|
{
|
||||||
**spaces,
|
**spaces,
|
||||||
"reward": gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32),
|
|
||||||
"is_first": gym.spaces.Box(0, 1, (), dtype=bool),
|
"is_first": gym.spaces.Box(0, 1, (), dtype=bool),
|
||||||
"is_last": gym.spaces.Box(0, 1, (), dtype=bool),
|
"is_last": gym.spaces.Box(0, 1, (), dtype=bool),
|
||||||
"is_terminal": gym.spaces.Box(0, 1, (), dtype=bool),
|
"is_terminal": gym.spaces.Box(0, 1, (), dtype=bool),
|
||||||
@ -52,7 +51,6 @@ class MemoryMaze:
|
|||||||
obs, reward, done, info = self._env.step(action)
|
obs, reward, done, info = self._env.step(action)
|
||||||
if not self._obs_is_dict:
|
if not self._obs_is_dict:
|
||||||
obs = {self._obs_key: obs}
|
obs = {self._obs_key: obs}
|
||||||
obs["reward"] = reward
|
|
||||||
obs["is_first"] = False
|
obs["is_first"] = False
|
||||||
obs["is_last"] = done
|
obs["is_last"] = done
|
||||||
obs["is_terminal"] = info.get("is_terminal", False)
|
obs["is_terminal"] = info.get("is_terminal", False)
|
||||||
@ -62,7 +60,6 @@ class MemoryMaze:
|
|||||||
obs = self._env.reset()
|
obs = self._env.reset()
|
||||||
if not self._obs_is_dict:
|
if not self._obs_is_dict:
|
||||||
obs = {self._obs_key: obs}
|
obs = {self._obs_key: obs}
|
||||||
obs["reward"] = 0.0
|
|
||||||
obs["is_first"] = True
|
obs["is_first"] = True
|
||||||
obs["is_last"] = False
|
obs["is_last"] = False
|
||||||
obs["is_terminal"] = False
|
obs["is_terminal"] = False
|
||||||
|
@ -20,11 +20,11 @@ class MinecraftWood:
|
|||||||
HealthReward(),
|
HealthReward(),
|
||||||
]
|
]
|
||||||
env = minecraft_base.MinecraftBase(actions, *args, **kwargs)
|
env = minecraft_base.MinecraftBase(actions, *args, **kwargs)
|
||||||
|
super().__init__(env)
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
obs, reward, done, info = self.env.step(action)
|
obs, reward, done, info = self.env.step(action)
|
||||||
reward = sum([fn(obs, self.env.inventory) for fn in self.rewards])
|
reward = sum([fn(obs, self.env.inventory) for fn in self.rewards])
|
||||||
obs["reward"] = reward
|
|
||||||
return obs, reward, done, info
|
return obs, reward, done, info
|
||||||
|
|
||||||
|
|
||||||
@ -34,6 +34,7 @@ class MinecraftClimb:
|
|||||||
env = minecraft_base.MinecraftBase(actions, *args, **kwargs)
|
env = minecraft_base.MinecraftBase(actions, *args, **kwargs)
|
||||||
self._previous = None
|
self._previous = None
|
||||||
self._health_reward = HealthReward()
|
self._health_reward = HealthReward()
|
||||||
|
super().__init__(env)
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
obs, reward, done, info = self.env.step(action)
|
obs, reward, done, info = self.env.step(action)
|
||||||
@ -43,7 +44,6 @@ class MinecraftClimb:
|
|||||||
self._previous = height
|
self._previous = height
|
||||||
reward = height - self._previous
|
reward = height - self._previous
|
||||||
reward += self._health_reward(obs)
|
reward += self._health_reward(obs)
|
||||||
obs["reward"] = reward
|
|
||||||
self._previous = height
|
self._previous = height
|
||||||
return obs, reward, done, info
|
return obs, reward, done, info
|
||||||
|
|
||||||
@ -87,7 +87,6 @@ class MinecraftDiamond(gym.Wrapper):
|
|||||||
def step(self, action):
|
def step(self, action):
|
||||||
obs, reward, done, info = self.env.step(action)
|
obs, reward, done, info = self.env.step(action)
|
||||||
reward = sum([fn(obs, self.env.inventory) for fn in self.rewards])
|
reward = sum([fn(obs, self.env.inventory) for fn in self.rewards])
|
||||||
obs["reward"] = reward
|
|
||||||
return obs, reward, done, info
|
return obs, reward, done, info
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
@ -131,7 +130,7 @@ class HealthReward:
|
|||||||
return 0
|
return 0
|
||||||
reward = self.scale * (health - self.previous)
|
reward = self.scale * (health - self.previous)
|
||||||
self.previous = health
|
self.previous = health
|
||||||
return np.float32(reward)
|
return sum(reward)
|
||||||
|
|
||||||
|
|
||||||
BASIC_ACTIONS = {
|
BASIC_ACTIONS = {
|
||||||
|
@ -18,7 +18,7 @@ class MinecraftBase(gym.Env):
|
|||||||
sticky_attack=30,
|
sticky_attack=30,
|
||||||
sticky_jump=10,
|
sticky_jump=10,
|
||||||
pitch_limit=(-60, 60),
|
pitch_limit=(-60, 60),
|
||||||
logs=True,
|
logs=False,
|
||||||
):
|
):
|
||||||
if logs:
|
if logs:
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
@ -41,7 +41,6 @@ class MinecraftBase(gym.Env):
|
|||||||
if k.startswith("inventory/")
|
if k.startswith("inventory/")
|
||||||
if k != "inventory/log2"
|
if k != "inventory/log2"
|
||||||
]
|
]
|
||||||
self._step = 0
|
|
||||||
self._max_inventory = None
|
self._max_inventory = None
|
||||||
self._equip_enum = self._env.observation_space["equipped_items"]["mainhand"][
|
self._equip_enum = self._env.observation_space["equipped_items"]["mainhand"][
|
||||||
"type"
|
"type"
|
||||||
@ -75,7 +74,6 @@ class MinecraftBase(gym.Env):
|
|||||||
"equipped": gym.spaces.Box(
|
"equipped": gym.spaces.Box(
|
||||||
-np.inf, np.inf, (len(self._equip_enum),), dtype=np.float32
|
-np.inf, np.inf, (len(self._equip_enum),), dtype=np.float32
|
||||||
),
|
),
|
||||||
"reward": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32),
|
|
||||||
"health": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32),
|
"health": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32),
|
||||||
"hunger": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32),
|
"hunger": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32),
|
||||||
"breath": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32),
|
"breath": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32),
|
||||||
@ -115,7 +113,6 @@ class MinecraftBase(gym.Env):
|
|||||||
obs["is_terminal"] = bool(info.get("is_terminal", done))
|
obs["is_terminal"] = bool(info.get("is_terminal", done))
|
||||||
|
|
||||||
obs = self._obs(obs)
|
obs = self._obs(obs)
|
||||||
self._step += 1
|
|
||||||
assert "pov" not in obs, list(obs.keys())
|
assert "pov" not in obs, list(obs.keys())
|
||||||
return obs, reward, done, info
|
return obs, reward, done, info
|
||||||
|
|
||||||
@ -135,7 +132,6 @@ class MinecraftBase(gym.Env):
|
|||||||
obs["is_terminal"] = False
|
obs["is_terminal"] = False
|
||||||
obs = self._obs(obs)
|
obs = self._obs(obs)
|
||||||
|
|
||||||
self._step = 0
|
|
||||||
self._sticky_attack_counter = 0
|
self._sticky_attack_counter = 0
|
||||||
self._sticky_jump_counter = 0
|
self._sticky_jump_counter = 0
|
||||||
self._pitch = 0
|
self._pitch = 0
|
||||||
@ -166,7 +162,6 @@ class MinecraftBase(gym.Env):
|
|||||||
"health": np.float32([obs["life_stats/life"]]) / 20,
|
"health": np.float32([obs["life_stats/life"]]) / 20,
|
||||||
"hunger": np.float32([obs["life_stats/food"]]) / 20,
|
"hunger": np.float32([obs["life_stats/food"]]) / 20,
|
||||||
"breath": np.float32([obs["life_stats/air"]]) / 300,
|
"breath": np.float32([obs["life_stats/air"]]) / 300,
|
||||||
"reward": [0.0],
|
|
||||||
"is_first": obs["is_first"],
|
"is_first": obs["is_first"],
|
||||||
"is_last": obs["is_last"],
|
"is_last": obs["is_last"],
|
||||||
"is_terminal": obs["is_terminal"],
|
"is_terminal": obs["is_terminal"],
|
||||||
|
Loading…
x
Reference in New Issue
Block a user