changed treatment of obs shape in minecraft
This commit is contained in:
parent
d94a719421
commit
3f6659d365
@ -163,10 +163,10 @@ class MinecraftBase(gym.Env):
|
||||
"inventory": inventory,
|
||||
"inventory_max": self._max_inventory.copy(),
|
||||
"equipped": equipped,
|
||||
"health": np.float32(obs["life_stats/life"] / 20),
|
||||
"hunger": np.float32(obs["life_stats/food"] / 20),
|
||||
"breath": np.float32(obs["life_stats/air"] / 300),
|
||||
"reward": 0.0,
|
||||
"health": np.float32([obs["life_stats/life"]]) / 20,
|
||||
"hunger": np.float32([obs["life_stats/food"]]) / 20,
|
||||
"breath": np.float32([obs["life_stats/air"]]) / 300,
|
||||
"reward": [0.0],
|
||||
"is_first": obs["is_first"],
|
||||
"is_last": obs["is_last"],
|
||||
"is_terminal": obs["is_terminal"],
|
||||
|
13
models.py
13
models.py
@ -174,20 +174,19 @@ class WorldModel(nn.Module):
|
||||
post = {k: v.detach() for k, v in post.items()}
|
||||
return post, context, metrics
|
||||
|
||||
# this function is called during both rollout and training
|
||||
def preprocess(self, obs):
|
||||
obs = obs.copy()
|
||||
obs["image"] = torch.Tensor(obs["image"]) / 255.0 - 0.5
|
||||
# (batch_size, batch_length) -> (batch_size, batch_length, 1)
|
||||
obs["reward"] = torch.Tensor(obs["reward"]).unsqueeze(-1)
|
||||
if "discount" in obs:
|
||||
obs["discount"] *= self._config.discount
|
||||
# (batch_size, batch_length) -> (batch_size, batch_length, 1)
|
||||
obs["discount"] = torch.Tensor(obs["discount"]).unsqueeze(-1)
|
||||
if "is_terminal" in obs:
|
||||
# this label is necessary to train cont_head
|
||||
obs["cont"] = torch.Tensor(1.0 - obs["is_terminal"]).unsqueeze(-1)
|
||||
else:
|
||||
raise ValueError('"is_terminal" was not found in observation.')
|
||||
# 'is_first' is necessary to initialize hidden state during training
|
||||
assert "is_first" in obs
|
||||
# 'is_terminal' is necessary to train cont_head
|
||||
assert "is_terminal" in obs
|
||||
obs["cont"] = torch.Tensor(1.0 - obs["is_terminal"]).unsqueeze(-1)
|
||||
obs = {k: torch.Tensor(v).to(self._config.device) for k, v in obs.items()}
|
||||
return obs
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user