import random
import time
from copy import deepcopy
from typing import Any, Literal

import gymnasium as gym
import networkx as nx
import numpy as np
from gymnasium.spaces import Box, Dict, Discrete, MultiDiscrete, Space, Tuple


class MoveToRightEnv(gym.Env):
    """A task for "going right". The task is to go right ``size`` steps.

    The observation is the current index, and the action is to go left or right.
    Action 0 is to go left, and action 1 is to go right.
    Taking action 0 at index 0 will keep the index at 0.
    Arriving at index ``size`` means the task is done.
    In the current implementation, stepping after the task is done is possible,
    which will lead the index to be larger than ``size``.

    Index 0 is the starting point. If reset is called with default options,
    the index will be reset to 0.
    """

    def __init__(
        self,
        size: int,
        sleep: float = 0.0,
        dict_state: bool = False,
        recurse_state: bool = False,
        ma_rew: int = 0,
        multidiscrete_action: bool = False,
        random_sleep: bool = False,
        array_state: bool = False,
    ) -> None:
        assert (
            dict_state + recurse_state + array_state <= 1
        ), "at most one of dict_state / recurse_state / array_state can be true"
        self.size = size
        self.sleep = sleep
        self.random_sleep = random_sleep
        self.dict_state = dict_state
        self.recurse_state = recurse_state
        self.array_state = array_state
        self.ma_rew = ma_rew
        self._md_action = multidiscrete_action
        # how many steps this env has stepped
        self.steps = 0
        if dict_state:
            self.observation_space = Dict(
                {
                    "index": Box(shape=(1,), low=0, high=size - 1),
                    "rand": Box(shape=(1,), low=0, high=1, dtype=np.float64),
                },
            )
        elif recurse_state:
            self.observation_space = Dict(
                {
                    "index": Box(shape=(1,), low=0, high=size - 1),
                    "dict": Dict(
                        {
                            "tuple": Tuple(
                                (
                                    Discrete(2),
                                    Box(shape=(2,), low=0, high=1, dtype=np.float64),
                                ),
                            ),
                            "rand": Box(shape=(1, 2), low=0, high=1, dtype=np.float64),
                        },
                    ),
                },
            )
        elif array_state:
            self.observation_space = Box(shape=(4, 84, 84), low=0, high=255)
        else:
            self.observation_space = Box(shape=(1,), low=0, high=size - 1)
        if multidiscrete_action:
            self.action_space = MultiDiscrete([2, 2])
        else:
            self.action_space = Discrete(2)
        self.terminated = False
        self.index = 0

    def reset(
        self,
        seed: int | None = None,
        # TODO: passing a dict here doesn't make any sense
        options: dict[str, Any] | None = None,
    ) -> tuple[dict[str, Any] | np.ndarray, dict]:
        """Reset the index to the value given in ``options["state"]`` (0 by default).

        :param seed: seed for the env's random number generator.
        :param options: the start index is provided in ``options["state"]``.
        :return: the initial observation and an info dict.
        """
        if options is None:
            options = {"state": 0}
        super().reset(seed=seed)
        self.terminated = False
        self.do_sleep()
        self.index = options["state"]
        return self._get_state(), {"key": 1, "env": self}

    def _get_reward(self) -> list[int] | int:
        """Return a non-scalar (multi-agent) reward if ``ma_rew > 0``, else a scalar flag."""
        end_flag = int(self.terminated)
        if self.ma_rew > 0:
            return [end_flag] * self.ma_rew
        return end_flag

    def _get_state(self) -> dict[str, Any] | np.ndarray:
        """Generate the state (observation) of this env."""
        if self.dict_state:
            return {
                "index": np.array([self.index], dtype=np.float32),
                "rand": self.np_random.random(1),
            }
        if self.recurse_state:
            return {
                "index": np.array([self.index], dtype=np.float32),
                "dict": {
                    "tuple": (np.array([1], dtype=int), self.np_random.random(2)),
                    "rand": self.np_random.random((1, 2)),
                },
            }
        if self.array_state:
            # a fake 4x84x84 "image" whose channels all encode the current index
            img = np.zeros([4, 84, 84], int)
            img[3, np.arange(84), np.arange(84)] = self.index
            img[2, np.arange(84)] = self.index
            img[1, :, np.arange(84)] = self.index
            img[0] = self.index
            return img
        return np.array([self.index], dtype=np.float32)

    def do_sleep(self) -> None:
        if self.sleep > 0:
            sleep_time = random.random() if self.random_sleep else 1
            sleep_time *= self.sleep
            time.sleep(sleep_time)
    def step(self, action: np.ndarray | int):  # type: ignore[no-untyped-def]  # cf. issue #1080
        self.steps += 1
        if self._md_action and isinstance(action, np.ndarray):
            action = action[0]
        if self.terminated:
            raise ValueError("step after done !!!")
        self.do_sleep()
        if self.index == self.size:
            self.terminated = True
            return self._get_state(), self._get_reward(), self.terminated, False, {}
        if action == 0:
            self.index = max(self.index - 1, 0)
            return (
                self._get_state(),
                self._get_reward(),
                self.terminated,
                False,
                {"key": 1, "env": self} if self.dict_state else {},
            )
        if action == 1:
            self.index += 1
            self.terminated = self.index == self.size
            return (
                self._get_state(),
                self._get_reward(),
                self.terminated,
                False,
                {"key": 1, "env": self},
            )
        return None


class NXEnv(gym.Env):
    """An env whose observation is either a networkx graph or its stacked node features."""

    def __init__(self, size: int, obs_type: str, feat_dim: int = 32) -> None:
        self.size = size
        self.feat_dim = feat_dim
        self.graph = nx.Graph()
        self.graph.add_nodes_from(range(size))
        assert obs_type in ["array", "object"]
        self.obs_type = obs_type

    def _encode_obs(self) -> np.ndarray | nx.Graph:
        if self.obs_type == "array":
            # use the public node-attribute API rather than the private ``_node`` dict
            return np.stack([self.graph.nodes[i]["data"] for i in range(self.size)])
        return deepcopy(self.graph)

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[np.ndarray | nx.Graph, dict]:
        super().reset(seed=seed)
        # draw node features from the env's seeded RNG so that reset(seed=...) is reproducible
        graph_state = self.np_random.random((self.size, self.feat_dim))
        for i in range(self.size):
            self.graph.nodes[i]["data"] = graph_state[i]
        return self._encode_obs(), {}

    def step(
        self,
        action: Space,
    ) -> tuple[np.ndarray | nx.Graph, float, Literal[False], Literal[False], dict]:
        next_graph_state = self.np_random.random((self.size, self.feat_dim))
        for i in range(self.size):
            self.graph.nodes[i]["data"] = next_graph_state[i]
        return self._encode_obs(), 1.0, False, False, {}


class MyGoalEnv(MoveToRightEnv):
    """A goal-conditioned variant of MoveToRightEnv with dict observations."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        assert (
            kwargs.get("dict_state", 0) + kwargs.get("recurse_state", 0) == 0
        ), "dict_state / recurse_state not supported"
        super().__init__(*args, **kwargs)
        obs, _ = super().reset(options={"state": 0})
        obs, _, _, _, _ = super().step(1)
        self._goal = obs * self.size
        super_obsv = self.observation_space
        self.observation_space = gym.spaces.Dict(
            {
                "observation": super_obsv,
                "achieved_goal": super_obsv,
                "desired_goal": super_obsv,
            },
        )

    def reset(self, *args: Any, **kwargs: Any) -> tuple[dict[str, Any], dict]:
        obs, info = super().reset(*args, **kwargs)
        new_obs = {"observation": obs, "achieved_goal": obs, "desired_goal": self._goal}
        return new_obs, info

    def step(self, *args: Any, **kwargs: Any) -> tuple[dict[str, Any], float, bool, bool, dict]:
        obs_next, rew, terminated, truncated, info = super().step(*args, **kwargs)
        new_obs_next = {
            "observation": obs_next,
            "achieved_goal": obs_next,
            "desired_goal": self._goal,
        }
        return new_obs_next, rew, terminated, truncated, info

    def compute_reward_fn(
        self,
        achieved_goal: np.ndarray,
        desired_goal: np.ndarray,
        info: dict,
    ) -> np.ndarray:
        axis: tuple[int, ...] = (-3, -2, -1) if self.array_state else (-1,)
        return (achieved_goal == desired_goal).all(axis=axis)
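

if __name__ == "__main__":
    # A minimal usage sketch of the envs above; everything referenced here is
    # defined in this module. Roll out one MoveToRightEnv episode by always
    # taking action 1 ("go right"); the env terminates with reward 1 after
    # ``size`` steps to the right.
    env = MoveToRightEnv(size=5)
    obs, info = env.reset()
    terminated = False
    while not terminated:
        obs, rew, terminated, truncated, info = env.step(1)
    print(f"finished at index {obs[0]:.0f} after {env.steps} steps, reward={rew}")

    # MyGoalEnv exposes the same dynamics through a goal-conditioned dict
    # observation; compute_reward_fn checks element-wise goal equality.
    goal_env = MyGoalEnv(size=5)
    obs, info = goal_env.reset()
    assert set(obs) == {"observation", "achieved_goal", "desired_goal"}
    matched = goal_env.compute_reward_fn(obs["achieved_goal"], obs["desired_goal"], info)
    print(f"achieved == desired at reset? {bool(matched)}")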