import gym
import numpy as np


class DeepMindControl:
    """Gym-style wrapper around a dm_control suite environment."""

    def __init__(self, name, action_repeat=1, size=(64, 64), camera=None):
        # Task names look like "walker_walk"; split into domain and task.
        domain, task = name.split("_", 1)
        if domain == "cup":  # Only domain with multiple words.
            domain = "ball_in_cup"
        if isinstance(domain, str):
            from dm_control import suite

            self._env = suite.load(domain, task)
        else:
            # Allow passing an environment constructor instead of a task name.
            assert task is None
            self._env = domain()
        self._action_repeat = action_repeat
        self._size = size
        if camera is None:
            # The quadruped domain needs a different default camera.
            camera = dict(quadruped=2).get(domain, 0)
        self._camera = camera

    @property
    def observation_space(self):
        spaces = {}
        for key, value in self._env.observation_spec().items():
            # Promote scalar observations to 1-element vectors.
            if len(value.shape) == 0:
                shape = (1,)
            else:
                shape = value.shape
            spaces[key] = gym.spaces.Box(-np.inf, np.inf, shape, dtype=np.float32)
        spaces["image"] = gym.spaces.Box(0, 255, self._size + (3,), dtype=np.uint8)
        return gym.spaces.Dict(spaces)

    @property
    def action_space(self):
        spec = self._env.action_spec()
        return gym.spaces.Box(spec.minimum, spec.maximum, dtype=np.float32)

    def step(self, action):
        assert np.isfinite(action).all(), action
        reward = 0
        # Repeat the action, accumulating reward, until the repeat budget is
        # used up or the episode ends.
        for _ in range(self._action_repeat):
            time_step = self._env.step(action)
            reward += time_step.reward or 0
            if time_step.last():
                break
        obs = dict(time_step.observation)
        obs = {key: [val] if len(val.shape) == 0 else val for key, val in obs.items()}
        obs["image"] = self.render()
        # There is no terminal state in DMC; a zero discount marks termination.
        obs["is_terminal"] = False if time_step.first() else time_step.discount == 0
        obs["is_first"] = time_step.first()
        done = time_step.last()
        info = {"discount": np.array(time_step.discount, np.float32)}
        return obs, reward, done, info

    def reset(self):
        time_step = self._env.reset()
        obs = dict(time_step.observation)
        obs = {key: [val] if len(val.shape) == 0 else val for key, val in obs.items()}
        obs["image"] = self.render()
        obs["is_terminal"] = False if time_step.first() else time_step.discount == 0
        obs["is_first"] = time_step.first()
        return obs

    def render(self, *args, **kwargs):
        if kwargs.get("mode", "rgb_array") != "rgb_array":
            raise ValueError("Only render mode 'rgb_array' is supported.")
        return self._env.physics.render(*self._size, camera_id=self._camera)
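

# A minimal usage sketch, not part of the original file: drives the wrapper
# with a random policy for one episode. It assumes dm_control and gym are
# installed and that MuJoCo rendering is available; "cartpole_balance" is a
# standard suite task, but any "domain_task" name from the suite works.
if __name__ == "__main__":
    env = DeepMindControl("cartpole_balance", action_repeat=2)
    obs = env.reset()
    assert obs["is_first"] and not obs["is_terminal"]
    done = False
    total_reward = 0.0
    while not done:
        action = env.action_space.sample()
        obs, reward, done, info = env.step(action)
        total_reward += reward
    print("episode reward:", total_reward)
    print("image shape:", obs["image"].shape)  # (64, 64, 3) by default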