dreamerv3-torch/envs/crafter.py
2023-06-18 00:02:22 +09:00

71 lines
2.2 KiB
Python

import gym
import numpy as np
class Crafter:
    """Gym-style wrapper around ``crafter.Env`` that emits dict observations.

    Observations are dicts containing the ``image`` and the ``is_first`` /
    ``is_last`` / ``is_terminal`` episode flags; ``step`` additionally adds
    the step reward plus ``log_reward`` and ``log_achievement_*`` entries.
    """

    def __init__(self, task, size=(64, 64), seed=None):
        """Create the underlying Crafter environment.

        Args:
            task: ``"reward"`` to enable Crafter's reward signal, or
                ``"noreward"`` to disable it.
            size: Image size forwarded to ``crafter.Env``.
            seed: Seed forwarded to ``crafter.Env``.
        """
        assert task in ("reward", "noreward")
        # Deferred import: ``crafter`` is only required once an env is built.
        import crafter

        self._env = crafter.Env(size=size, reward=(task == "reward"), seed=seed)
        # Copy so this wrapper never mutates crafter's module-level list.
        self._achievements = crafter.constants.achievements.copy()

    @property
    def observation_space(self):
        """Return a ``gym.spaces.Dict`` describing the observation dict.

        The episode flags are booleans stored as ``uint8``, so they are
        bounded to ``[0, 1]``; infinite bounds are not representable in an
        unsigned integer dtype.
        """
        spaces = {
            "image": gym.spaces.Box(
                0, 255, self._env.observation_space.shape, dtype=np.uint8
            ),
            "reward": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32),
            "is_first": gym.spaces.Box(0, 1, (1,), dtype=np.uint8),
            "is_last": gym.spaces.Box(0, 1, (1,), dtype=np.uint8),
            "is_terminal": gym.spaces.Box(0, 1, (1,), dtype=np.uint8),
            "log_reward": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32),
        }
        spaces.update(
            {
                f"log_achievement_{k}": gym.spaces.Box(
                    -np.inf, np.inf, (1,), dtype=np.float32
                )
                for k in self._achievements
            }
        )
        return gym.spaces.Dict(spaces)

    @property
    def action_space(self):
        """Return the wrapped env's action space, flagged as discrete."""
        action_space = self._env.action_space
        # NOTE: this sets the attribute on the wrapped env's own space object.
        action_space.discrete = True
        return action_space

    def step(self, action):
        """Advance the environment by one step.

        Args:
            action: Action forwarded unchanged to ``crafter.Env.step``.

        Returns:
            ``(obs, reward, done, info)``. ``reward`` is cast to
            ``np.float32``. If ``info`` is empty, the logging entries default
            to zero and ``is_terminal`` defaults to ``False``.
        """
        image, reward, done, info = self._env.step(action)
        reward = np.float32(reward)
        log_achievements = {
            f"log_achievement_{k}": info["achievements"][k] if info else 0
            for k in self._achievements
        }
        obs = {
            "image": image,
            "reward": reward,
            "is_first": False,
            "is_last": done,
            # Guarded like the other ``info`` reads so an empty info does not
            # raise; a missing discount is treated as a non-terminal step.
            "is_terminal": bool(info) and info["discount"] == 0,
            "log_reward": np.float32(info["reward"] if info else 0.0),
            **log_achievements,
        }
        return obs, reward, done, info

    def render(self):
        """Return whatever the wrapped env's ``render()`` returns."""
        return self._env.render()

    def reset(self):
        """Reset the environment and return the first observation.

        The returned dict contains only the image and the episode flags.
        """
        image = self._env.reset()
        obs = {
            "image": image,
            "is_first": True,
            "is_last": False,
            "is_terminal": False,
        }
        return obs