2023-05-14 23:38:46 +09:00

71 lines
2.6 KiB
Python

import gym
import numpy as np
class DeepMindControl:
    """Gym-style wrapper around a dm_control environment.

    The wrapped environment is driven through ``reset``/``step`` and every
    observation dict is extended with a rendered ``"image"`` plus the
    ``"is_first"`` and ``"is_terminal"`` flags.
    """

    # Camera used for rendering when the caller does not choose one; domains
    # that are not listed fall back to camera 0.
    _DEFAULT_CAMERAS = {"quadruped": 2}

    def __init__(self, name, action_repeat=1, size=(64, 64), camera=None):
        """Create the wrapped environment.

        Args:
            name: Either a ``"<domain>_<task>"`` string, which is loaded via
                ``dm_control.suite.load``, or a zero-argument callable that
                returns an already configured environment.
            action_repeat: Number of times each action is applied per
                ``step`` call; must be at least 1.
            size: Two values unpacked positionally into ``physics.render``.
            camera: Camera id passed to ``physics.render``. ``None`` selects
                a per-domain default.

        Raises:
            ValueError: If ``action_repeat`` is smaller than 1, or if a
                string ``name`` contains no underscore.
        """
        # With fewer than one repeat ``step`` would never query the
        # environment and would have no time step to report.
        if action_repeat < 1:
            raise ValueError(
                "action_repeat must be >= 1, got {!r}.".format(action_repeat)
            )
        if isinstance(name, str):
            domain, task = name.split("_", 1)
            if domain == "cup":  # Only domain with multiple words.
                domain = "ball_in_cup"
            # Imported here so that dm_control is only required when an
            # environment is actually loaded by name.
            from dm_control import suite

            self._env = suite.load(domain, task)
        else:
            # ``name`` is an environment factory; there is no domain name to
            # look a default camera up with.
            domain = None
            self._env = name()
        self._action_repeat = action_repeat
        # Stored as a tuple so it can be concatenated with the channel axis.
        self._size = tuple(size)
        if camera is None:
            camera = self._DEFAULT_CAMERAS.get(domain, 0)
        self._camera = camera

    @property
    def observation_space(self):
        """Return a ``gym.spaces.Dict`` with one Box per observation key.

        Every key of the environment's observation spec becomes an unbounded
        float32 Box (specs with an empty shape are declared with shape
        ``(1,)``), and an ``"image"`` entry is added as a uint8 Box in
        ``[0, 255]`` of shape ``size + (3,)``.
        """
        spaces = {}
        for key, value in self._env.observation_spec().items():
            shape = value.shape if len(value.shape) > 0 else (1,)
            spaces[key] = gym.spaces.Box(-np.inf, np.inf, shape, dtype=np.float32)
        image_shape = tuple(self._size) + (3,)
        spaces["image"] = gym.spaces.Box(0, 255, image_shape, dtype=np.uint8)
        return gym.spaces.Dict(spaces)

    @property
    def action_space(self):
        """Return a float32 Box bounded by the action spec's minimum/maximum."""
        spec = self._env.action_spec()
        return gym.spaces.Box(spec.minimum, spec.maximum, dtype=np.float32)

    def step(self, action):
        """Apply ``action`` up to ``action_repeat`` times.

        Rewards of the repeated steps are summed (a missing reward counts as
        0) and repetition stops early once the environment reports its last
        time step.

        Returns:
            A tuple ``(obs, reward, done, info)`` where ``obs`` is the
            observation dict of the final time step, ``done`` is that time
            step's ``last()`` value, and ``info["discount"]`` holds its
            discount as a float32 array.
        """
        assert np.isfinite(action).all(), action
        reward = 0
        for _ in range(self._action_repeat):
            time_step = self._env.step(action)
            reward += time_step.reward or 0
            if time_step.last():
                break
        obs = self._make_obs(time_step)
        done = time_step.last()
        info = {"discount": np.array(time_step.discount, np.float32)}
        return obs, reward, done, info

    def reset(self):
        """Reset the environment and return the initial observation dict."""
        return self._make_obs(self._env.reset())

    def render(self, *args, **kwargs):
        """Render the current physics state with the configured size and camera.

        Positional arguments are ignored; only the ``mode`` keyword is
        inspected.

        Raises:
            ValueError: If ``mode`` is given and is not ``"rgb_array"``.
        """
        if kwargs.get("mode", "rgb_array") != "rgb_array":
            raise ValueError("Only render mode 'rgb_array' is supported.")
        return self._env.physics.render(*self._size, camera_id=self._camera)

    def _make_obs(self, time_step):
        """Build the observation dict shared by ``reset`` and ``step``."""
        # Zero-dimensional values are wrapped in a list so every entry has at
        # least one axis.
        obs = {
            key: [val] if np.ndim(val) == 0 else val
            for key, val in dict(time_step.observation).items()
        }
        obs["image"] = self.render()
        # There is no terminal state in DMC; the flag is still derived from
        # the discount so that a zero discount is reported if one occurs.
        # The first time step is never treated as terminal.
        obs["is_terminal"] = False if time_step.first() else time_step.discount == 0
        obs["is_first"] = time_step.first()
        return obs