Michael Panchenko 600f4bbd55
Python 3.9, black + ruff formatting (#921)
Preparation for #914 and #920

Changes formatting to ruff and black. Removes Python 3.8 support.

## Additional Changes

- Removed flake8 dependencies
- Adjusted pre-commit. Now CI and Make use pre-commit, reducing the
duplication of linting calls
- Removed check-docstyle option (ruff is doing that)
- Merged format and lint. In CI the format-lint step fails if any
changes are made, so it fulfills the lint functionality.

---------

Co-authored-by: Jiayi Weng <jiayi@openai.com>
2023-08-25 14:40:56 -07:00

73 lines
2.3 KiB
Python

import contextlib
from typing import Any, Callable, Optional
import gymnasium as gym
import numpy as np
from tianshou.env.utils import gym_new_venv_step_type
from tianshou.env.worker import EnvWorker
with contextlib.suppress(ImportError):
import ray
class _SetAttrWrapper(gym.Wrapper):
    """Gym wrapper adding explicit attribute accessor methods.

    ``RayEnvWorker`` turns this class into a Ray actor and invokes these
    methods remotely to read and write attributes of the environment.
    """

    def set_env_attr(self, key: str, value: Any) -> None:
        """Assign ``value`` to attribute ``key`` of the unwrapped environment."""
        # Writes go to the innermost environment (``unwrapped``), not to the
        # directly wrapped one.
        target = self.env.unwrapped
        setattr(target, key, value)

    def get_env_attr(self, key: str) -> Any:
        """Return attribute ``key`` as looked up on the wrapped environment."""
        # Reads start at ``self.env`` itself rather than at ``unwrapped``.
        wrapped = self.env
        return getattr(wrapped, key)
class RayEnvWorker(EnvWorker):
    """Ray worker used in RayVectorEnv."""

    def __init__(self, env_fn: Callable[[], gym.Env]) -> None:
        """Create the remote environment actor and initialize the base worker.

        :param env_fn: zero-argument callable building the environment; it is
            called once here and the result is handed to the remote actor.
        """
        # NOTE(review): ``self.env`` is assigned before ``super().__init__``;
        # presumably the base initializer already queries the environment, so
        # keep this order.
        actor_cls = ray.remote(_SetAttrWrapper).options(num_cpus=0)
        self.env = actor_cls.remote(env_fn())  # type: ignore
        super().__init__(env_fn)

    def get_env_attr(self, key: str) -> Any:
        """Fetch attribute ``key`` from the remote environment (blocking)."""
        pending = self.env.get_env_attr.remote(key)
        return ray.get(pending)

    def set_env_attr(self, key: str, value: Any) -> None:
        """Set attribute ``key`` on the remote environment and wait for completion."""
        pending = self.env.set_env_attr.remote(key, value)
        ray.get(pending)

    def reset(self, **kwargs: Any) -> Any:
        """Reset the remote environment and return the result of its ``reset``.

        When a ``seed`` keyword is supplied, it is first forwarded to the base
        class ``seed`` method; all keywords are then passed to the remote reset.
        """
        if "seed" in kwargs:
            super().seed(kwargs["seed"])
        pending = self.env.reset.remote(**kwargs)
        return ray.get(pending)

    @staticmethod
    def wait(  # type: ignore
        workers: list["RayEnvWorker"],
        wait_num: int,
        timeout: Optional[float] = None,
    ) -> list["RayEnvWorker"]:
        """Return the workers whose pending results are reported ready by ``ray.wait``.

        :param workers: workers that each hold a pending handle in ``result``.
        :param wait_num: passed to ``ray.wait`` as ``num_returns``.
        :param timeout: passed to ``ray.wait`` as ``timeout``.
        """
        handles = [worker.result for worker in workers]
        finished, _unfinished = ray.wait(handles, num_returns=wait_num, timeout=timeout)
        ready_workers = []
        for handle in finished:
            # Map each finished handle back to the worker at the same position.
            ready_workers.append(workers[handles.index(handle)])
        return ready_workers

    def send(self, action: Optional[np.ndarray], **kwargs: Any) -> None:
        """Dispatch a step (or a reset when ``action`` is None) without blocking.

        The handle of the remote call is stored in ``self.result``; ``kwargs``
        are only used for the reset case.
        """
        # self.result is actually a handle
        if action is not None:
            self.result = self.env.step.remote(action)
            return
        self.result = self.env.reset.remote(**kwargs)

    def recv(self) -> gym_new_venv_step_type:
        """Block until the call dispatched by ``send`` finishes and return its value."""
        handle = self.result
        return ray.get(handle)  # type: ignore

    def seed(self, seed: Optional[int] = None) -> Optional[list[int]]:
        """Seed the base worker and the remote environment.

        Returns the value of the remote ``seed`` call. If that call raises
        ``AttributeError`` or ``NotImplementedError``, a remote reset with the
        given seed is dispatched instead (its result is not awaited) and
        ``None`` is returned.
        """
        super().seed(seed)
        try:
            seeds = ray.get(self.env.seed.remote(seed))
        except (AttributeError, NotImplementedError):
            # Fallback: seed through reset; fire-and-forget.
            self.env.reset.remote(seed=seed)
            return None
        return seeds

    def render(self, **kwargs: Any) -> Any:
        """Render the remote environment and return the result of its ``render``."""
        pending = self.env.render.remote(**kwargs)
        return ray.get(pending)

    def close_env(self) -> None:
        """Close the remote environment, waiting for the call to finish."""
        pending = self.env.close.remote()
        ray.get(pending)