import gym
import numpy as np


class DeepMindControl:
    """Gym-style wrapper around a ``dm_control`` environment.

    Observations are returned as dictionaries holding the environment's own
    observation entries plus three extra keys: ``"image"`` (a rendered
    frame), ``"is_terminal"`` and ``"is_first"``.
    """

    def __init__(self, name, action_repeat=1, size=(64, 64), camera=None):
        """Create the wrapped environment.

        Args:
            name: Either a ``"<domain>_<task>"`` string that is loaded through
                ``dm_control.suite`` (the short domain ``"cup"`` is expanded
                to ``"ball_in_cup"``), or a zero-argument callable that
                returns an already configured environment.
            action_repeat: Number of times each action is applied per call to
                :meth:`step`. Must be at least 1.
            size: Two-element sequence forwarded positionally to
                ``physics.render``; also the leading dimensions of the
                ``"image"`` observation space.
            camera: Camera id used for rendering. When ``None`` it defaults to
                2 for the ``"quadruped"`` domain and 0 otherwise.

        Raises:
            ValueError: If ``action_repeat`` is smaller than 1 or a string
                ``name`` contains no underscore.
        """
        if action_repeat < 1:
            # With zero repeats ``step`` would never query the environment
            # and would fail later on an unbound ``time_step``.
            raise ValueError(f"action_repeat must be >= 1, got {action_repeat!r}.")
        if callable(name):
            # Custom environment constructor: there is no suite domain, so
            # the camera default below falls back to 0.
            domain = None
            self._env = name()
        else:
            domain, separator, task = name.partition("_")
            if not separator:
                raise ValueError(
                    f"Expected name of the form '<domain>_<task>', got {name!r}."
                )
            if domain == "cup":  # Only domain with multiple words.
                domain = "ball_in_cup"
            # Imported lazily so the module can be imported without
            # dm_control being installed.
            from dm_control import suite

            self._env = suite.load(domain, task)
        self._action_repeat = action_repeat
        # Normalise to a tuple so ``self._size + (3,)`` also works when the
        # caller passes a list.
        self._size = tuple(size)
        if camera is None:
            camera = 2 if domain == "quadruped" else 0
        self._camera = camera

    @property
    def observation_space(self):
        """``gym.spaces.Dict`` with one unbounded float32 ``Box`` per spec entry.

        Entries whose spec has an empty shape are exposed with shape ``(1,)``.
        An additional ``"image"`` entry describes the rendered uint8 frame of
        shape ``size + (3,)``.
        """
        spaces = {}
        for key, value in self._env.observation_spec().items():
            shape = value.shape if len(value.shape) != 0 else (1,)
            spaces[key] = gym.spaces.Box(-np.inf, np.inf, shape, dtype=np.float32)
        spaces["image"] = gym.spaces.Box(0, 255, self._size + (3,), dtype=np.uint8)
        return gym.spaces.Dict(spaces)

    @property
    def action_space(self):
        """float32 ``gym.spaces.Box`` bounded by the action spec's min/max."""
        spec = self._env.action_spec()
        return gym.spaces.Box(spec.minimum, spec.maximum, dtype=np.float32)

    def step(self, action):
        """Apply ``action`` up to ``action_repeat`` times.

        Rewards are summed over the repeated steps (a missing reward counts
        as 0) and repetition stops early once the environment reports its
        last time step.

        Args:
            action: Action forwarded unchanged to the wrapped environment;
                every element must be finite.

        Returns:
            Tuple ``(obs, reward, done, info)`` where ``obs`` is the
            observation dictionary of the final time step, ``done`` is that
            time step's ``last()`` flag and ``info["discount"]`` is its
            discount as a float32 array.
        """
        assert np.isfinite(action).all(), action
        reward = 0
        for _ in range(self._action_repeat):
            time_step = self._env.step(action)
            reward += time_step.reward or 0
            if time_step.last():
                break
        obs = self._observation(time_step)
        done = time_step.last()
        info = {"discount": np.array(time_step.discount, np.float32)}
        return obs, reward, done, info

    def reset(self):
        """Reset the wrapped environment and return the first observation."""
        return self._observation(self._env.reset())

    def render(self, *args, **kwargs):
        """Render the current physics state with the configured size and camera.

        Positional arguments are ignored; the only keyword inspected is
        ``mode``.

        Raises:
            ValueError: If ``mode`` is given and is not ``"rgb_array"``.
        """
        if kwargs.get("mode", "rgb_array") != "rgb_array":
            raise ValueError("Only render mode 'rgb_array' is supported.")
        return self._env.physics.render(*self._size, camera_id=self._camera)

    def _observation(self, time_step):
        """Build the observation dictionary shared by ``step`` and ``reset``."""
        # Zero-dimensional entries are wrapped in a one-element list so every
        # entry has a leading axis.
        obs = {
            key: [val] if len(val.shape) == 0 else val
            for key, val in dict(time_step.observation).items()
        }
        obs["image"] = self.render()
        # There is no terminal state in DMC; the flag is still derived from a
        # zero discount, and is forced to False on the first time step.
        obs["is_terminal"] = False if time_step.first() else time_step.discount == 0
        obs["is_first"] = time_step.first()
        return obs