n+e 09692c84fe
fix numpy>=1.20 typing check (#323)
Change the behavior of to_numpy and to_torch: from now on, dict is automatically converted to Batch and list is automatically converted to np.ndarray (if an error occurs, raise the exception instead of converting each element in the list).
2021-03-30 16:06:03 +08:00

40 lines
1.1 KiB
Python

import gym
import numpy as np
from typing import Any, List, Callable, Optional
from tianshou.env.worker import EnvWorker
class DummyEnvWorker(EnvWorker):
    """Dummy worker used in sequential vector environments.

    Holds a single environment built in the current process and forwards
    every call to it directly (synchronously); there is no subprocess or
    communication channel involved.
    """

    def __init__(self, env_fn: Callable[[], gym.Env]) -> None:
        """Build the environment eagerly and initialise the base worker.

        :param env_fn: zero-argument factory returning the environment.
        """
        self.env = env_fn()
        super().__init__(env_fn)

    def __getattr__(self, key: str) -> Any:
        """Forward attribute lookups that fail on the worker to the wrapped env.

        Python only calls ``__getattr__`` after normal attribute lookup fails,
        so it is reached for ``env`` itself only when ``env`` has not been
        assigned yet, e.g. on an instance created via ``__new__`` without
        running ``__init__`` (as ``copy`` and ``pickle`` do). Raising
        ``AttributeError`` in that case stops ``self.env`` below from
        re-entering this method until ``RecursionError``, and lets
        ``hasattr``-style probes return ``False`` as expected.

        :raises AttributeError: if ``env`` is not set, or the wrapped env has
            no attribute ``key``.
        """
        if key == "env":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute 'env'"
            )
        return getattr(self.env, key)

    def reset(self) -> Any:
        """Reset the wrapped environment and return what its ``reset()`` returns."""
        return self.env.reset()

    @staticmethod
    def wait(  # type: ignore
        workers: List["DummyEnvWorker"], wait_num: int, timeout: Optional[float] = None
    ) -> List["DummyEnvWorker"]:
        """Return ``workers`` unchanged.

        Actions are executed synchronously in ``send_action``, so every worker
        has already finished by the time this is called. ``wait_num`` and
        ``timeout`` are accepted only for interface compatibility and are
        ignored.
        """
        # Sequential EnvWorker objects are always ready
        return workers

    def send_action(self, action: np.ndarray) -> None:
        """Step the wrapped environment with ``action`` and keep the output.

        The raw return value of ``env.step`` is stored in ``self.result``.
        NOTE(review): presumably the base ``EnvWorker`` hands ``self.result``
        back to callers; confirm in the base class.
        """
        self.result = self.env.step(action)

    def seed(self, seed: Optional[int] = None) -> List[int]:
        """Seed via the base class, then seed the wrapped env and return its result.

        :param seed: seed forwarded unchanged to both calls.
        """
        # NOTE(review): base EnvWorker.seed is not visible here; confirm what
        # it seeds before relying on it.
        super().seed(seed)
        return self.env.seed(seed)

    def render(self, **kwargs: Any) -> Any:
        """Render the wrapped environment, forwarding all keyword arguments."""
        return self.env.render(**kwargs)

    def close_env(self) -> None:
        """Close the wrapped environment."""
        self.env.close()