2020-07-26 12:01:21 +02:00
|
|
|
import random
|
2021-09-03 05:05:04 +08:00
|
|
|
import time
|
2021-04-25 15:23:46 +08:00
|
|
|
from copy import deepcopy
|
2024-02-06 14:24:30 +01:00
|
|
|
from typing import Any, Literal
|
2021-09-03 05:05:04 +08:00
|
|
|
|
2023-02-03 20:57:27 +01:00
|
|
|
import gymnasium as gym
|
2021-09-03 05:05:04 +08:00
|
|
|
import networkx as nx
|
|
|
|
import numpy as np
|
2024-02-06 14:24:30 +01:00
|
|
|
from gymnasium.spaces import Box, Dict, Discrete, MultiDiscrete, Space, Tuple
|
2020-03-21 10:58:01 +08:00
|
|
|
|
|
|
|
|
2024-03-28 18:02:31 +01:00
|
|
|
class MoveToRightEnv(gym.Env):
    """A task for "going right". The task is to go right ``size`` steps.

    The observation is the current index, and the action is to go left or right.
    Action 0 is to go left, and action 1 is to go right.
    Taking action 0 at index 0 will keep the index at 0.
    Arriving at index ``size`` means the task is done.
    Once the task is done the environment has to be reset: calling :meth:`step`
    on a terminated environment raises a ``ValueError``.

    Index 0 is the starting point. If reset is called with default options, the index will
    be reset to 0.

    :param size: the index at which the task terminates.
    :param sleep: base duration in seconds to sleep in every ``reset`` and ``step``
        call; no sleeping happens if it is not positive.
    :param dict_state: if True, observations are dicts with the keys ``"index"``
        and ``"rand"``.
    :param recurse_state: if True, observations are nested dicts (``"index"`` plus
        a ``"dict"`` entry holding a tuple and a random array).
    :param ma_rew: if positive, the reward is a list of this length instead of a scalar.
    :param multidiscrete_action: if True, the action space is ``MultiDiscrete([2, 2])``
        and only the first component of an array action is used.
    :param random_sleep: if True, the sleep duration is scaled by a random factor
        in ``[0, 1)``.
    :param array_state: if True, observations are integer arrays of shape ``(4, 84, 84)``.
    """

    def __init__(
        self,
        size: int,
        sleep: float = 0.0,
        dict_state: bool = False,
        recurse_state: bool = False,
        ma_rew: int = 0,
        multidiscrete_action: bool = False,
        random_sleep: bool = False,
        array_state: bool = False,
    ) -> None:
        assert (
            dict_state + recurse_state + array_state <= 1
        ), "dict_state / recurse_state / array_state can be only one true"
        self.size = size
        self.sleep = sleep
        self.random_sleep = random_sleep
        self.dict_state = dict_state
        self.recurse_state = recurse_state
        self.array_state = array_state
        self.ma_rew = ma_rew
        self._md_action = multidiscrete_action
        # how many steps this env has stepped (never cleared by reset)
        self.steps = 0
        if dict_state:
            self.observation_space = Dict(
                {
                    "index": Box(shape=(1,), low=0, high=size - 1),
                    "rand": Box(shape=(1,), low=0, high=1, dtype=np.float64),
                },
            )
        elif recurse_state:
            self.observation_space = Dict(
                {
                    "index": Box(shape=(1,), low=0, high=size - 1),
                    "dict": Dict(
                        {
                            "tuple": Tuple(
                                (
                                    Discrete(2),
                                    Box(shape=(2,), low=0, high=1, dtype=np.float64),
                                ),
                            ),
                            "rand": Box(shape=(1, 2), low=0, high=1, dtype=np.float64),
                        },
                    ),
                },
            )
        elif array_state:
            self.observation_space = Box(shape=(4, 84, 84), low=0, high=255)
        else:
            self.observation_space = Box(shape=(1,), low=0, high=size - 1)
        if multidiscrete_action:
            self.action_space = MultiDiscrete([2, 2])
        else:
            self.action_space = Discrete(2)
        self.terminated = False
        self.index = 0

    def reset(
        self,
        seed: int | None = None,
        # TODO: passing a dict here doesn't make any sense
        options: dict[str, Any] | None = None,
    ) -> tuple[dict[str, Any] | np.ndarray, dict]:
        """Reset the environment to a start index.

        :param seed: forwarded to ``gym.Env.reset`` to seed ``self.np_random``.
        :param options: the start index is provided in ``options["state"]``;
            if ``options`` is None or has no ``"state"`` entry, the index is reset to 0.
        :return: the initial observation and an info dict.
        """
        # Tolerate a missing "state" key (e.g. ``options={}``) instead of raising KeyError.
        start_index = 0 if options is None else options.get("state", 0)
        super().reset(seed=seed)
        self.terminated = False
        self.do_sleep()
        self.index = start_index
        return self._get_state(), {"key": 1, "env": self}

    def _get_reward(self) -> list[int] | int:
        """Return 1 if the env is terminated, else 0 (as a list of length ``ma_rew`` if > 0)."""
        end_flag = int(self.terminated)
        if self.ma_rew > 0:
            return [end_flag] * self.ma_rew
        return end_flag

    def _get_state(self) -> dict[str, Any] | np.ndarray:
        """Generate the state (observation) of MoveToRightEnv for the current index."""
        if self.dict_state:
            return {
                "index": np.array([self.index], dtype=np.float32),
                "rand": self.np_random.random(1),
            }
        if self.recurse_state:
            return {
                "index": np.array([self.index], dtype=np.float32),
                "dict": {
                    "tuple": (np.array([1], dtype=int), self.np_random.random(2)),
                    "rand": self.np_random.random((1, 2)),
                },
            }
        if self.array_state:
            # Each of the four channels encodes the index with a different fill pattern.
            img = np.zeros([4, 84, 84], int)
            img[3, np.arange(84), np.arange(84)] = self.index
            img[2, np.arange(84)] = self.index
            img[1, :, np.arange(84)] = self.index
            img[0] = self.index
            return img
        return np.array([self.index], dtype=np.float32)

    def do_sleep(self) -> None:
        """Sleep for ``sleep`` seconds, scaled by a random factor if ``random_sleep`` is set."""
        if self.sleep > 0:
            sleep_time = random.random() if self.random_sleep else 1
            sleep_time *= self.sleep
            time.sleep(sleep_time)

    def step(self, action: np.ndarray | int):  # type: ignore[no-untyped-def]  # cf. issue #1080
        """Move one index to the left (action 0) or to the right (action 1).

        :param action: 0 or 1; with ``multidiscrete_action`` an array whose first
            entry is used.
        :return: ``(observation, reward, terminated, truncated, info)`` for actions
            0 and 1. Any other action falls through and returns None.
        :raises ValueError: if the environment is already terminated.
        """
        self.steps += 1
        if self._md_action and isinstance(action, np.ndarray):
            action = action[0]
        if self.terminated:
            raise ValueError("step after done !!!")
        self.do_sleep()
        # Already sitting on the terminal index without being flagged as terminated
        # (e.g. after a reset to ``size``): terminate without moving.
        if self.index == self.size:
            self.terminated = True
            return self._get_state(), self._get_reward(), self.terminated, False, {}
        if action == 0:
            self.index = max(self.index - 1, 0)
            return (
                self._get_state(),
                self._get_reward(),
                self.terminated,
                False,
                {"key": 1, "env": self} if self.dict_state else {},
            )
        if action == 1:
            self.index += 1
            self.terminated = self.index == self.size
            return (
                self._get_state(),
                self._get_reward(),
                self.terminated,
                False,
                {"key": 1, "env": self},
            )
        return None
|
2021-04-25 15:23:46 +08:00
|
|
|
|
|
|
|
|
|
|
|
class NXEnv(gym.Env):
    """Environment whose state is a graph with random features attached to every node.

    Every call to ``reset`` and ``step`` draws a fresh ``(size, feat_dim)`` feature
    matrix from ``self.np_random`` and stores row ``i`` under the ``"data"``
    attribute of node ``i``.

    :param size: number of nodes in the graph.
    :param obs_type: ``"array"`` to observe the stacked node features, ``"object"``
        to observe a deep copy of the graph itself.
    :param feat_dim: length of the feature vector of each node.
    """

    def __init__(self, size: int, obs_type: str, feat_dim: int = 32) -> None:
        self.size = size
        self.feat_dim = feat_dim
        self.graph = nx.Graph()
        self.graph.add_nodes_from(range(size))
        assert obs_type in ["array", "object"]
        self.obs_type = obs_type

    def _resample_node_features(self) -> None:
        """Assign a freshly drawn random feature vector to every node of the graph."""
        # Draw from the env's own generator so that ``reset(seed=...)`` is honoured.
        features = self.np_random.random((self.size, self.feat_dim))
        for node, feature in enumerate(features):
            self.graph.nodes[node]["data"] = feature

    def _encode_obs(self) -> np.ndarray | nx.Graph:
        """Return the observation in the representation selected by ``obs_type``."""
        if self.obs_type == "array":
            return np.stack([attrs["data"] for _, attrs in self.graph.nodes(data=True)])
        # Hand out a copy so that later steps do not mutate earlier observations.
        return deepcopy(self.graph)

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[np.ndarray | nx.Graph, dict]:
        """Re-seed the env (if ``seed`` is given) and resample all node features.

        :param seed: forwarded to ``gym.Env.reset`` to seed ``self.np_random``.
        :param options: unused.
        :return: the observation and an empty info dict.
        """
        super().reset(seed=seed)
        self._resample_node_features()
        return self._encode_obs(), {}

    def step(
        self,
        action: Space,
    ) -> tuple[np.ndarray | nx.Graph, float, Literal[False], Literal[False], dict]:
        """Resample all node features; the action is ignored.

        :param action: unused.
        :return: the observation, a constant reward of 1.0, ``terminated=False``,
            ``truncated=False`` and an empty info dict.
        """
        self._resample_node_features()
        return self._encode_obs(), 1.0, False, False, {}
|
2022-10-31 08:54:54 +09:00
|
|
|
|
|
|
|
|
2024-03-28 18:02:31 +01:00
|
|
|
class MyGoalEnv(MoveToRightEnv):
    """Goal-conditioned variant of :class:`MoveToRightEnv`.

    Observations are dicts with the keys ``"observation"``, ``"achieved_goal"`` and
    ``"desired_goal"``. The first two hold the parent's observation; the last one
    holds a fixed goal computed once at construction time.
    The ``dict_state`` and ``recurse_state`` options of the parent are not supported.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        unsupported_flags = kwargs.get("dict_state", 0) + kwargs.get("recurse_state", 0)
        assert unsupported_flags == 0, "dict_state / recurse_state not supported"
        super().__init__(*args, **kwargs)
        # The goal is the observation at index 1 scaled by ``size``, which is what
        # the parent env emits at index ``size``.
        super().reset(options={"state": 0})
        obs_at_one = super().step(1)[0]
        self._goal = obs_at_one * self.size
        base_space = self.observation_space
        goal_keys = ("observation", "achieved_goal", "desired_goal")
        self.observation_space = gym.spaces.Dict({key: base_space for key in goal_keys})

    def reset(self, *args: Any, **kwargs: Any) -> tuple[dict[str, Any], dict]:
        """Reset the parent env and wrap its observation into the goal-style dict."""
        raw_obs, info = super().reset(*args, **kwargs)
        goal_obs = {
            "observation": raw_obs,
            "achieved_goal": raw_obs,
            "desired_goal": self._goal,
        }
        return goal_obs, info

    def step(self, *args: Any, **kwargs: Any) -> tuple[dict[str, Any], float, bool, bool, dict]:
        """Step the parent env and wrap the next observation into the goal-style dict."""
        raw_obs, rew, terminated, truncated, info = super().step(*args, **kwargs)
        goal_obs = {
            "observation": raw_obs,
            "achieved_goal": raw_obs,
            "desired_goal": self._goal,
        }
        return goal_obs, rew, terminated, truncated, info

    def compute_reward_fn(
        self,
        achieved_goal: np.ndarray,
        desired_goal: np.ndarray,
        info: dict,
    ) -> np.ndarray:
        """Check element-wise whether the achieved goals equal the desired goals.

        :param achieved_goal: array of achieved goals.
        :param desired_goal: array of desired goals.
        :param info: unused.
        :return: the equality mask reduced with ``all`` over the trailing observation
            axes (the last three for ``array_state``, otherwise the last one).
        """
        axis: tuple[int, ...]
        if self.array_state:
            axis = (-3, -2, -1)
        else:
            axis = (-1,)
        matches = achieved_goal == desired_goal
        return matches.all(axis=axis)
|