changed treatment of obs shape in minecraft

This commit is contained in:
NM512 2023-08-03 08:12:44 +09:00
parent d94a719421
commit 3f6659d365
2 changed files with 10 additions and 11 deletions

View File

@ -163,10 +163,10 @@ class MinecraftBase(gym.Env):
"inventory": inventory,
"inventory_max": self._max_inventory.copy(),
"equipped": equipped,
"health": np.float32(obs["life_stats/life"] / 20),
"hunger": np.float32(obs["life_stats/food"] / 20),
"breath": np.float32(obs["life_stats/air"] / 300),
"reward": 0.0,
"health": np.float32([obs["life_stats/life"]]) / 20,
"hunger": np.float32([obs["life_stats/food"]]) / 20,
"breath": np.float32([obs["life_stats/air"]]) / 300,
"reward": [0.0],
"is_first": obs["is_first"],
"is_last": obs["is_last"],
"is_terminal": obs["is_terminal"],

View File

@ -174,20 +174,19 @@ class WorldModel(nn.Module):
post = {k: v.detach() for k, v in post.items()}
return post, context, metrics
# this function is called during both rollout and training
def preprocess(self, obs):
# Builds tensor versions of the entries in `obs` and returns them in a new
# dict, with every value moved to self._config.device.
# NOTE(review): this listing appears to come from a diff view whose +/-
# markers and indentation were stripped, so pre- and post-commit lines may
# both be present (obs["cont"] is assigned twice and "is_terminal" is checked
# twice below) -- confirm against the committed file before relying on it.
# Work on a copy so the keys reassigned below do not replace entries in the
# caller's mapping (a shallow copy, if obs is a plain dict).
obs = obs.copy()
# Divide by 255 and shift by 0.5; presumably maps 8-bit pixel values into
# [-0.5, 0.5] -- TODO confirm the input range.
obs["image"] = torch.Tensor(obs["image"]) / 255.0 - 0.5
# (batch_size, batch_length) -> (batch_size, batch_length, 1)
obs["reward"] = torch.Tensor(obs["reward"]).unsqueeze(-1)
if "discount" in obs:
# Scale the stored discount by the configured discount factor.
# NOTE(review): this is an in-place multiply on a value shared with the
# caller's mapping; if that value is a mutable array the caller's data is
# modified too -- confirm this is intended.
obs["discount"] *= self._config.discount
# (batch_size, batch_length) -> (batch_size, batch_length, 1)
obs["discount"] = torch.Tensor(obs["discount"]).unsqueeze(-1)
if "is_terminal" in obs:
# this label is necessary to train cont_head
obs["cont"] = torch.Tensor(1.0 - obs["is_terminal"]).unsqueeze(-1)
else:
raise ValueError('"is_terminal" was not found in observation.')
# 'is_first' is necessary to initialize hidden state at training
# NOTE(review): assert statements are removed when Python runs with -O, so
# these two key checks would not fire in optimized mode.
assert "is_first" in obs
# 'is_terminal' is necessary to train cont_head
assert "is_terminal" in obs
# Continuation label: 1.0 minus the terminal flag, with a trailing axis added.
obs["cont"] = torch.Tensor(1.0 - obs["is_terminal"]).unsqueeze(-1)
# Convert every entry to a torch.Tensor and move it to the configured device.
obs = {k: torch.Tensor(v).to(self._config.device) for k, v in obs.items()}
return obs