Markus Krimmel 6c6c872523
Gymnasium Integration (#789)
Changes:
- Disclaimer in README
- Replaced all occurrences of Gym with Gymnasium
- Removed code that is now dead since we no longer need to support the
old step API
- Updated type hints to only allow new step API
- Increased required version of envpool to support Gymnasium
- Increased required version of PettingZoo to support Gymnasium
- Updated `PettingZooEnv` to only use the new step API, removed hack to
also support old API
- I had to add some `# type: ignore` comments due to the new type hints
in Gymnasium. I'm not that familiar with type hinting, but I believe the
issue is on the Gymnasium side and we are looking into it.
- Had to update `MyTestEnv` to support the `options` kwarg
- Skip NNI tests because they still use OpenAI Gym
- Also allow `PettingZooEnv` in vector environments
- Updated the doc page about `ReplayBuffer` to also cover the
`terminated` and `truncated` flags (a minimal sketch of the new API
follows this commit message).

Still need to do: 
- Update the Jupyter notebooks in docs
- Check the entire code base for more dead code (left over from the
compatibility shims)
- Check the reset functions of all environments/wrappers in the code
base to make sure they use the `options` kwarg
- Someone might want to check test_env_finite.py
- Is it okay to allow `PettingZooEnv` in vector environments? The docs
might need updating.
2023-02-03 11:57:27 -08:00
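
For reference, a minimal sketch of the Gymnasium API this change standardizes on (using the stock CartPole-v1 env): `reset` accepts `seed` and `options` keyword arguments and returns `(obs, info)`, and `step` returns separate `terminated` and `truncated` flags instead of a single `done`:

import gymnasium as gym

env = gym.make("CartPole-v1")
obs, info = env.reset(seed=0, options={})
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
done = terminated or truncated  # an episode ends for either reason
env.close()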


import ctypes
import time
from collections import OrderedDict
from multiprocessing import Array, Pipe, connection
from multiprocessing.context import Process
from typing import Any, Callable, List, Optional, Tuple, Union

import gymnasium as gym
import numpy as np

from tianshou.env.utils import CloudpickleWrapper, gym_new_venv_step_type
from tianshou.env.worker import EnvWorker
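
# Mapping from NumPy dtypes to ctypes, used to allocate the
# multiprocessing.Array buffers that back shared-memory observations.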
_NP_TO_CT = {
np.bool_: ctypes.c_bool,
np.uint8: ctypes.c_uint8,
np.uint16: ctypes.c_uint16,
np.uint32: ctypes.c_uint32,
np.uint64: ctypes.c_uint64,
np.int8: ctypes.c_int8,
np.int16: ctypes.c_int16,
np.int32: ctypes.c_int32,
np.int64: ctypes.c_int64,
np.float32: ctypes.c_float,
np.float64: ctypes.c_double,
}


class ShArray:
    """Wrapper of multiprocessing Array."""

    def __init__(self, dtype: np.generic, shape: Tuple[int]) -> None:
self.arr = Array(_NP_TO_CT[dtype.type], int(np.prod(shape))) # type: ignore
self.dtype = dtype
self.shape = shape

    def save(self, ndarray: np.ndarray) -> None:
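        """Copy ``ndarray`` into the shared-memory block."""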
assert isinstance(ndarray, np.ndarray)
dst = self.arr.get_obj()
        dst_np = np.frombuffer(dst, dtype=self.dtype).reshape(self.shape)  # type: ignore
np.copyto(dst_np, ndarray)

    def get(self) -> np.ndarray:
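        """Return a NumPy view onto the shared-memory block."""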
obj = self.arr.get_obj()
return np.frombuffer(obj, dtype=self.dtype).reshape(self.shape) # type: ignore


def _setup_buf(space: gym.Space) -> Union[dict, tuple, ShArray]:
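    """Allocate shared-memory buffers matching the (possibly nested) space."""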
if isinstance(space, gym.spaces.Dict):
assert isinstance(space.spaces, OrderedDict)
return {k: _setup_buf(v) for k, v in space.spaces.items()}
elif isinstance(space, gym.spaces.Tuple):
assert isinstance(space.spaces, tuple)
return tuple([_setup_buf(t) for t in space.spaces])
else:
return ShArray(space.dtype, space.shape) # type: ignore


def _worker(
parent: connection.Connection,
p: connection.Connection,
env_fn_wrapper: CloudpickleWrapper,
obs_bufs: Optional[Union[dict, tuple, ShArray]] = None,
) -> None:
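    """Event loop of the worker subprocess: run commands received over the pipe."""
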
def _encode_obs(
obs: Union[dict, tuple, np.ndarray], buffer: Union[dict, tuple, ShArray]
) -> None:
if isinstance(obs, np.ndarray) and isinstance(buffer, ShArray):
buffer.save(obs)
elif isinstance(obs, tuple) and isinstance(buffer, tuple):
for o, b in zip(obs, buffer):
_encode_obs(o, b)
elif isinstance(obs, dict) and isinstance(buffer, dict):
for k in obs.keys():
_encode_obs(obs[k], buffer[k])
return None
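
    # Close the inherited copy of the parent's pipe end; the child only uses `p`.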
parent.close()
env = env_fn_wrapper.data()
try:
while True:
try:
cmd, data = p.recv()
except EOFError: # the pipe has been closed
p.close()
break
if cmd == "step":
env_return = env.step(data)
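                # With shared memory, write the observation into the buffer
                # and transmit a placeholder instead.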
if obs_bufs is not None:
_encode_obs(env_return[0], obs_bufs)
env_return = (None, *env_return[1:])
p.send(env_return)
elif cmd == "reset":
obs, info = env.reset(**data)
if obs_bufs is not None:
_encode_obs(obs, obs_bufs)
obs = None
p.send((obs, info))
elif cmd == "close":
p.send(env.close())
p.close()
break
elif cmd == "render":
p.send(env.render(**data) if hasattr(env, "render") else None)
elif cmd == "seed":
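                # Gymnasium removed Env.seed(); fall back to reset(seed=...).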
if hasattr(env, "seed"):
p.send(env.seed(data))
else:
env.reset(seed=data)
p.send(None)
elif cmd == "getattr":
p.send(getattr(env, data) if hasattr(env, data) else None)
elif cmd == "setattr":
setattr(env.unwrapped, data["key"], data["value"])
else:
p.close()
raise NotImplementedError
except KeyboardInterrupt:
p.close()


class SubprocEnvWorker(EnvWorker):
    """Subprocess worker used in SubprocVectorEnv and ShmemVectorEnv."""

    def __init__(
self, env_fn: Callable[[], gym.Env], share_memory: bool = False
) -> None:
self.parent_remote, self.child_remote = Pipe()
self.share_memory = share_memory
self.buffer: Optional[Union[dict, tuple, ShArray]] = None
if self.share_memory:
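            # Build a throwaway env just to read its observation space,
            # then allocate matching shared-memory buffers.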
dummy = env_fn()
obs_space = dummy.observation_space
dummy.close()
del dummy
self.buffer = _setup_buf(obs_space)
args = (
self.parent_remote,
self.child_remote,
CloudpickleWrapper(env_fn),
self.buffer,
)
self.process = Process(target=_worker, args=args, daemon=True)
self.process.start()
self.child_remote.close()
super().__init__(env_fn)

    def get_env_attr(self, key: str) -> Any:
self.parent_remote.send(["getattr", key])
return self.parent_remote.recv()

    def set_env_attr(self, key: str, value: Any) -> None:
self.parent_remote.send(["setattr", {"key": key, "value": value}])

    def _decode_obs(self) -> Union[dict, tuple, np.ndarray]:
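        # Rebuild the observation from the shared-memory buffers.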
def decode_obs(
buffer: Optional[Union[dict, tuple, ShArray]]
) -> Union[dict, tuple, np.ndarray]:
if isinstance(buffer, ShArray):
return buffer.get()
elif isinstance(buffer, tuple):
return tuple([decode_obs(b) for b in buffer])
elif isinstance(buffer, dict):
return {k: decode_obs(v) for k, v in buffer.items()}
else:
raise NotImplementedError
return decode_obs(self.buffer)

    @staticmethod
def wait( # type: ignore
workers: List["SubprocEnvWorker"],
wait_num: int,
timeout: Optional[float] = None,
) -> List["SubprocEnvWorker"]:
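        # Poll until at least `wait_num` workers have results ready or the
        # timeout expires; return the workers whose results can be received.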
remain_conns = conns = [x.parent_remote for x in workers]
ready_conns: List[connection.Connection] = []
remain_time, t1 = timeout, time.time()
while len(remain_conns) > 0 and len(ready_conns) < wait_num:
if timeout:
remain_time = timeout - (time.time() - t1)
if remain_time <= 0:
break
# connection.wait hangs if the list is empty
new_ready_conns = connection.wait(remain_conns, timeout=remain_time)
ready_conns.extend(new_ready_conns) # type: ignore
remain_conns = [conn for conn in remain_conns if conn not in ready_conns]
return [workers[conns.index(con)] for con in ready_conns]

    def send(self, action: Optional[np.ndarray], **kwargs: Any) -> None:
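        # `action=None` requests a reset; any kwargs are forwarded to reset().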
if action is None:
if "seed" in kwargs:
super().seed(kwargs["seed"])
self.parent_remote.send(["reset", kwargs])
else:
self.parent_remote.send(["step", action])

    def recv(
self
) -> Union[gym_new_venv_step_type, Tuple[np.ndarray, dict]]: # noqa:E125
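        # A reset reply is an (obs, info) pair; a step reply is the 5-tuple of
        # the new step API. With shared memory, the transmitted obs is a
        # placeholder that gets decoded from the shared buffers.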
result = self.parent_remote.recv()
if isinstance(result, tuple):
if len(result) == 2:
obs, info = result
if self.share_memory:
obs = self._decode_obs()
return obs, info
obs = result[0]
if self.share_memory:
obs = self._decode_obs()
return (obs, *result[1:]) # type: ignore
else:
obs = result
if self.share_memory:
obs = self._decode_obs()
return obs

    def reset(self, **kwargs: Any) -> Tuple[np.ndarray, dict]:
if "seed" in kwargs:
super().seed(kwargs["seed"])
self.parent_remote.send(["reset", kwargs])
result = self.parent_remote.recv()
if isinstance(result, tuple):
obs, info = result
if self.share_memory:
obs = self._decode_obs()
return obs, info
else:
obs = result
if self.share_memory:
obs = self._decode_obs()
return obs

    def seed(self, seed: Optional[int] = None) -> Optional[List[int]]:
super().seed(seed)
self.parent_remote.send(["seed", seed])
return self.parent_remote.recv()

    def render(self, **kwargs: Any) -> Any:
self.parent_remote.send(["render", kwargs])
return self.parent_remote.recv()

    def close_env(self) -> None:
try:
self.parent_remote.send(["close", None])
# mp may be deleted so it may raise AttributeError
self.parent_remote.recv()
self.process.join()
except (BrokenPipeError, EOFError, AttributeError):
pass
# ensure the subproc is terminated
self.process.terminate()
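

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library API: SubprocEnvWorker is
    # normally driven by SubprocVectorEnv/ShmemVectorEnv, but it can be
    # exercised directly. Assumes gymnasium's CartPole-v1 is installed.
    worker = SubprocEnvWorker(lambda: gym.make("CartPole-v1"))
    obs, info = worker.reset(seed=0)
    worker.send(worker.get_env_attr("action_space").sample())
    obs, reward, terminated, truncated, info = worker.recv()  # new step API
    worker.close_env()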