From b284ace1024b0f73b0f8d4fdd3f7b2ea1fb340e5 Mon Sep 17 00:00:00 2001 From: n+e Date: Sun, 13 Sep 2020 19:31:50 +0800 Subject: [PATCH] type check in unit test (#200) Fix #195: Add mypy test in .github/workflows/lint_and_docs.yml. Also remove the out-of-date API --- .github/workflows/lint_and_docs.yml | 5 +- docs/contributing.rst | 10 ++ docs/index.rst | 1 + setup.cfg | 23 +++++ test/throughput/test_collector_profile.py | 15 --- tianshou/__init__.py | 2 +- tianshou/data/batch.py | 51 ++++------ tianshou/data/buffer.py | 20 ++-- tianshou/data/collector.py | 46 ++------- tianshou/data/utils/converter.py | 4 +- tianshou/env/__init__.py | 3 +- tianshou/env/venvs.py | 29 ++---- tianshou/env/worker/base.py | 2 +- tianshou/env/worker/dummy.py | 2 +- tianshou/env/worker/ray.py | 2 +- tianshou/env/worker/subproc.py | 117 +++++++++++----------- tianshou/exploration/random.py | 19 ++-- tianshou/policy/base.py | 2 +- tianshou/policy/imitation/base.py | 4 +- tianshou/policy/modelfree/a2c.py | 6 +- tianshou/policy/modelfree/ddpg.py | 16 +-- tianshou/policy/modelfree/dqn.py | 11 +- tianshou/policy/modelfree/pg.py | 7 +- tianshou/policy/modelfree/ppo.py | 9 +- tianshou/policy/modelfree/sac.py | 4 +- tianshou/trainer/offpolicy.py | 4 +- tianshou/trainer/onpolicy.py | 10 +- tianshou/trainer/utils.py | 2 +- tianshou/utils/moving_average.py | 4 +- tianshou/utils/net/common.py | 37 ++++--- tianshou/utils/net/continuous.py | 2 +- tianshou/utils/net/discrete.py | 4 +- 32 files changed, 224 insertions(+), 249 deletions(-) diff --git a/.github/workflows/lint_and_docs.yml b/.github/workflows/lint_and_docs.yml index 07e1def..6226798 100644 --- a/.github/workflows/lint_and_docs.yml +++ b/.github/workflows/lint_and_docs.yml @@ -1,4 +1,4 @@ -name: PEP8 and Docs Check +name: PEP8, Types and Docs Check on: [push, pull_request] @@ -20,6 +20,9 @@ jobs: - name: Lint with flake8 run: | flake8 .
--count --show-source --statistics + - name: Type check + run: | + mypy - name: Documentation test run: | pydocstyle tianshou diff --git a/docs/contributing.rst b/docs/contributing.rst index b56a7a4..cf015a9 100644 --- a/docs/contributing.rst +++ b/docs/contributing.rst @@ -28,6 +28,16 @@ We follow PEP8 python code style. To check, in the main directory, run: $ flake8 . --count --show-source --statistics +Type Check +---------- + +We use `mypy `_ to check the type annotations. To check, in the main directory, run: + +.. code-block:: bash + + $ mypy + + Test Locally ------------ diff --git a/docs/index.rst b/docs/index.rst index 5d962af..587dc7e 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -31,6 +31,7 @@ Here is Tianshou's other features: * Support :ref:`customize_training` * Support n-step returns estimation :meth:`~tianshou.policy.BasePolicy.compute_nstep_return` and prioritized experience replay :class:`~tianshou.data.PrioritizedReplayBuffer` for all Q-learning based algorithms; GAE, nstep and PER are very fast thanks to numba jit function and vectorized numpy operation * Support :doc:`/tutorials/tictactoe` +* Comprehensive `unit tests `_, including functional checking, RL pipeline checking, documentation checking, PEP8 code-style checking, and type checking 中文文档位于 `https://tianshou.readthedocs.io/zh/latest/ `_ diff --git a/setup.cfg b/setup.cfg index 188700c..20f8d6d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,3 +1,26 @@ +[mypy] +files = tianshou/**/*.py +allow_redefinition = True +check_untyped_defs = True +disallow_incomplete_defs = True +disallow_untyped_defs = True +ignore_missing_imports = True +no_implicit_optional = True +pretty = True +show_error_codes = True +show_error_context = True +show_traceback = True +strict_equality = True +strict_optional = True +warn_no_return = True +warn_redundant_casts = True +warn_unreachable = True +warn_unused_configs = True +warn_unused_ignores = True + +[mypy-tianshou.utils.net.*] +ignore_errors = True + 
[pydocstyle] ignore = D100,D102,D104,D105,D107,D203,D213,D401,D402 diff --git a/test/throughput/test_collector_profile.py b/test/throughput/test_collector_profile.py index 21260f5..4036472 100644 --- a/test/throughput/test_collector_profile.py +++ b/test/throughput/test_collector_profile.py @@ -100,11 +100,6 @@ def test_collect_ep(data): data["collector"].collect(n_episode=10) -def test_sample(data): - for _ in range(5000): - data["collector"].sample(256) - - def test_init_vec_env(data): for _ in range(5000): Collector(data["policy"], data["env_vec"], data["buffer"]) @@ -125,11 +120,6 @@ def test_collect_vec_env_ep(data): data["collector_vec"].collect(n_episode=10) -def test_sample_vec_env(data): - for _ in range(5000): - data["collector_vec"].sample(256) - - def test_init_subproc_env(data): for _ in range(5000): Collector(data["policy"], data["env_subproc_init"], data["buffer"]) @@ -150,10 +140,5 @@ def test_collect_subproc_env_ep(data): data["collector_subproc"].collect(n_episode=10) -def test_sample_subproc_env(data): - for _ in range(5000): - data["collector_subproc"].sample(256) - - if __name__ == '__main__': pytest.main(["-s", "-k collector_profile", "--durations=0", "-v"]) diff --git a/tianshou/__init__.py b/tianshou/__init__.py index 5ad1c9f..0b9c0e9 100644 --- a/tianshou/__init__.py +++ b/tianshou/__init__.py @@ -1,7 +1,7 @@ from tianshou import data, env, utils, policy, trainer, exploration -__version__ = "0.2.7" +__version__ = "0.3.0rc0" __all__ = [ "env", diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index a06c0df..fe35160 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -6,7 +6,7 @@ from copy import deepcopy from numbers import Number from collections.abc import Collection from typing import Any, List, Dict, Union, Iterator, Optional, Iterable, \ - Sequence, KeysView, ValuesView, ItemsView + Sequence # Disable pickle warning related to torch, since it has been removed # on torch master branch. 
See Pull Request #39003 for details: @@ -144,7 +144,7 @@ def _parse_value(v: Any) -> Optional[Union["Batch", np.ndarray, torch.Tensor]]: if not isinstance(v, np.ndarray) and isinstance(v, Collection) and \ len(v) > 0 and all(isinstance(e, torch.Tensor) for e in v): try: - return torch.stack(v) + return torch.stack(v) # type: ignore except RuntimeError as e: raise TypeError("Batch does not support non-stackable iterable" " of torch.Tensor as unique value yet.") from e @@ -191,12 +191,20 @@ class Batch: elif _is_batch_set(batch_dict): self.stack_(batch_dict) if len(kwargs) > 0: - self.__init__(kwargs, copy=copy) + self.__init__(kwargs, copy=copy) # type: ignore def __setattr__(self, key: str, value: Any) -> None: """Set self.key = value.""" self.__dict__[key] = _parse_value(value) + def __getattr__(self, key: str) -> Any: + """Return self.key. The "Any" return type is needed for mypy.""" + return getattr(self.__dict__, key) + + def __contains__(self, key: str) -> bool: + """Return key in self.""" + return key in self.__dict__ + def __getstate__(self) -> Dict[str, Any]: """Pickling interface. @@ -215,11 +223,11 @@ class Batch: At this point, self is an empty Batch instance that has not been initialized, so it can safely be initialized by the pickle state. """ - self.__init__(**state) + self.__init__(**state) # type: ignore def __getitem__( self, index: Union[str, slice, int, np.integer, np.ndarray, List[int]] - ) -> Union["Batch", np.ndarray, torch.Tensor]: + ) -> Any: """Return self[index].""" if isinstance(index, str): return self.__dict__[index] @@ -245,7 +253,7 @@ class Batch: if isinstance(index, str): self.__dict__[index] = value return - if isinstance(value, (np.ndarray, torch.Tensor)): + if not isinstance(value, Batch): raise ValueError("Batch does not supported tensor assignment. 
" "Use a compatible Batch or dict instead.") if not set(value.keys()).issubset(self.__dict__.keys()): @@ -330,30 +338,6 @@ class Batch: s = self.__class__.__name__ + "()" return s - def __contains__(self, key: str) -> bool: - """Return key in self.""" - return key in self.__dict__ - - def keys(self) -> KeysView[str]: - """Return self.keys().""" - return self.__dict__.keys() - - def values(self) -> ValuesView[Any]: - """Return self.values().""" - return self.__dict__.values() - - def items(self) -> ItemsView[str, Any]: - """Return self.items().""" - return self.__dict__.items() - - def get(self, k: str, d: Optional[Any] = None) -> Any: - """Return self[k] if k in self else d. d defaults to None.""" - return self.__dict__.get(k, d) - - def pop(self, k: str, d: Optional[Any] = None) -> Any: - """Return & remove self[k] if k in self else d. d defaults to None.""" - return self.__dict__.pop(k, d) - def to_numpy(self) -> None: """Change all torch.Tensor to numpy.ndarray in-place.""" for k, v in self.items(): @@ -375,7 +359,6 @@ class Batch: if isinstance(v, torch.Tensor): if dtype is not None and v.dtype != dtype or \ v.device.type != device.type or \ - device.index is not None and \ device.index != v.device.index: if dtype is not None: v = v.type(dtype) @@ -517,7 +500,7 @@ class Batch: return batches = [x if isinstance(x, Batch) else Batch(x) for x in batches] if not self.is_empty(): - batches = [self] + list(batches) + batches = [self] + batches # collect non-empty keys keys_map = [ set(k for k, v in batch.items() @@ -672,8 +655,8 @@ class Batch: for v in self.__dict__.values(): if isinstance(v, Batch) and v.is_empty(recurse=True): continue - elif hasattr(v, "__len__") and (not isinstance( - v, (np.ndarray, torch.Tensor)) or v.ndim > 0 + elif hasattr(v, "__len__") and ( + isinstance(v, Batch) or v.ndim > 0 ): r.append(len(v)) else: diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index ff16085..1e817d9 100644 --- a/tianshou/data/buffer.py +++ 
b/tianshou/data/buffer.py @@ -1,7 +1,7 @@ import torch import numpy as np from numbers import Number -from typing import Any, Dict, Tuple, Union, Optional +from typing import Any, Dict, List, Tuple, Union, Optional from tianshou.data import Batch, SegmentTree, to_numpy from tianshou.data.batch import _create_value @@ -138,7 +138,7 @@ class ReplayBuffer: self._indices = np.arange(size) self.stack_num = stack_num self._avail = sample_avail and stack_num > 1 - self._avail_index = [] + self._avail_index: List[int] = [] self._save_s_ = not ignore_obs_next self._last_obs = save_only_last_obs self._index = 0 @@ -175,12 +175,12 @@ class ReplayBuffer: except KeyError: self._meta.__dict__[name] = _create_value(inst, self._maxsize) value = self._meta.__dict__[name] - if isinstance(inst, (torch.Tensor, np.ndarray)) \ - and inst.shape != value.shape[1:]: - raise ValueError( - "Cannot add data to a buffer with different shape, with key " - f"{name}, expect {value.shape[1:]}, given {inst.shape}." - ) + if isinstance(inst, (torch.Tensor, np.ndarray)): + if inst.shape != value.shape[1:]: + raise ValueError( + "Cannot add data to a buffer with different shape with key" + f" {name}, expect {value.shape[1:]}, given {inst.shape}." 
+ ) try: value[self._index] = inst except KeyError: @@ -205,7 +205,7 @@ class ReplayBuffer: stack_num_orig = buffer.stack_num buffer.stack_num = 1 while True: - self.add(**buffer[i]) + self.add(**buffer[i]) # type: ignore i = (i + 1) % len(buffer) if i == begin: break @@ -323,7 +323,7 @@ class ReplayBuffer: try: if stack_num == 1: return val[indice] - stack = [] + stack: List[Any] = [] for _ in range(stack_num): stack = [val[indice]] + stack pre_indice = np.asarray(indice - 1) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index aacaa9b..5315538 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -212,10 +212,8 @@ class Collector(object): finished_env_ids = [] reward_total = 0.0 whole_data = Batch() - list_n_episode = False - if n_episode is not None and not np.isscalar(n_episode): + if isinstance(n_episode, list): assert len(n_episode) == self.get_env_num() - list_n_episode = True finished_env_ids = [ i for i in self._ready_env_ids if n_episode[i] <= 0] self._ready_env_ids = np.array( @@ -266,7 +264,8 @@ class Collector(object): self.data.policy._state = self.data.state self.data.act = to_numpy(result.act) - if self._action_noise is not None: # noqa + if self._action_noise is not None: + assert isinstance(self.data.act, np.ndarray) self.data.act += self._action_noise(self.data.act.shape) # step in env @@ -291,7 +290,7 @@ class Collector(object): # add data into the buffer if self.preprocess_fn: - result = self.preprocess_fn(**self.data) + result = self.preprocess_fn(**self.data) # type: ignore self.data.update(result) for j, i in enumerate(self._ready_env_ids): @@ -305,14 +304,14 @@ class Collector(object): self._cached_buf[i].add(**self.data[j]) if done[j]: - if not (list_n_episode and - episode_count[i] >= n_episode[i]): + if not (isinstance(n_episode, list) + and episode_count[i] >= n_episode[i]): episode_count[i] += 1 reward_total += np.sum(self._cached_buf[i].rew, axis=0) step_count += len(self._cached_buf[i]) if 
self.buffer is not None: self.buffer.update(self._cached_buf[i]) - if list_n_episode and \ + if isinstance(n_episode, list) and \ episode_count[i] >= n_episode[i]: # env i has collected enough data, it has finished finished_env_ids.append(i) @@ -324,10 +323,9 @@ class Collector(object): env_ind_global = self._ready_env_ids[env_ind_local] obs_reset = self.env.reset(env_ind_global) if self.preprocess_fn: - obs_next[env_ind_local] = self.preprocess_fn( + obs_reset = self.preprocess_fn( obs=obs_reset).get("obs", obs_reset) - else: - obs_next[env_ind_local] = obs_reset + obs_next[env_ind_local] = obs_reset self.data.obs = obs_next if is_async: # set data back @@ -362,7 +360,7 @@ class Collector(object): # average reward across the number of episodes reward_avg = reward_total / episode_count if np.asanyarray(reward_avg).size > 1: # non-scalar reward_avg - reward_avg = self._rew_metric(reward_avg) + reward_avg = self._rew_metric(reward_avg) # type: ignore return { "n/ep": episode_count, "n/st": step_count, @@ -372,30 +370,6 @@ class Collector(object): "len": step_count / episode_count, } - def sample(self, batch_size: int) -> Batch: - """Sample a data batch from the internal replay buffer. - - It will call :meth:`~tianshou.policy.BasePolicy.process_fn` before - returning the final batch data. - - :param int batch_size: ``0`` means it will extract all the data from - the buffer, otherwise it will extract the data with the given - batch_size. - """ - warnings.warn( - "Collector.sample is deprecated and will cause error if you use " - "prioritized experience replay! Collector.sample will be removed " - "upon version 0.3. Use policy.update instead!", Warning) - assert self.buffer is not None, "Cannot get sample from empty buffer!" 
- batch_data, indice = self.buffer.sample(batch_size) - batch_data = self.process_fn(batch_data, self.buffer, indice) - return batch_data - - def close(self) -> None: - warnings.warn( - "Collector.close is deprecated and will be removed upon version " - "0.3.", Warning) - def _batch_set_item( source: Batch, indices: np.ndarray, target: Batch, size: int diff --git a/tianshou/data/utils/converter.py b/tianshou/data/utils/converter.py index 6a5d43e..a8e8845 100644 --- a/tianshou/data/utils/converter.py +++ b/tianshou/data/utils/converter.py @@ -45,14 +45,14 @@ def to_torch( if isinstance(x, np.ndarray) and issubclass( x.dtype.type, (np.bool_, np.number) ): # most often case - x = torch.from_numpy(x).to(device) + x = torch.from_numpy(x).to(device) # type: ignore if dtype is not None: x = x.type(dtype) return x elif isinstance(x, torch.Tensor): # second often case if dtype is not None: x = x.type(dtype) - return x.to(device) + return x.to(device) # type: ignore elif isinstance(x, (np.number, np.bool_, Number)): return to_torch(np.asanyarray(x), dtype, device) elif isinstance(x, dict): diff --git a/tianshou/env/__init__.py b/tianshou/env/__init__.py index d3a49b7..a25e06f 100644 --- a/tianshou/env/__init__.py +++ b/tianshou/env/__init__.py @@ -1,11 +1,10 @@ -from tianshou.env.venvs import BaseVectorEnv, DummyVectorEnv, VectorEnv, \ +from tianshou.env.venvs import BaseVectorEnv, DummyVectorEnv, \ SubprocVectorEnv, ShmemVectorEnv, RayVectorEnv from tianshou.env.maenv import MultiAgentEnv __all__ = [ "BaseVectorEnv", "DummyVectorEnv", - "VectorEnv", # TODO: remove in later version "SubprocVectorEnv", "ShmemVectorEnv", "RayVectorEnv", diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py index 6b82001..ddf7aca 100644 --- a/tianshou/env/venvs.py +++ b/tianshou/env/venvs.py @@ -1,5 +1,4 @@ import gym -import warnings import numpy as np from typing import Any, List, Union, Optional, Callable @@ -84,12 +83,12 @@ class BaseVectorEnv(gym.Env): self.timeout is None or 
self.timeout > 0 ), f"timeout is {timeout}, it should be positive if provided!" self.is_async = self.wait_num != len(env_fns) or timeout is not None - self.waiting_conn = [] + self.waiting_conn: List[EnvWorker] = [] # environments in self.ready_id is actually ready # but environments in self.waiting_id are just waiting when checked, # and they may be ready now, but this is not known until we check it # in the step() function - self.waiting_id = [] + self.waiting_id: List[int] = [] # all environments are ready in the beginning self.ready_id = list(range(self.env_num)) self.is_closed = False @@ -216,10 +215,11 @@ class BaseVectorEnv(gym.Env): self.waiting_conn.append(self.workers[env_id]) self.waiting_id.append(env_id) self.ready_id = [x for x in self.ready_id if x not in id] - ready_conns, result = [], [] + ready_conns: List[EnvWorker] = [] while not ready_conns: ready_conns = self.worker_class.wait( self.waiting_conn, self.wait_num, self.timeout) + result = [] for conn in ready_conns: waiting_index = self.waiting_conn.index(conn) self.waiting_conn.pop(waiting_index) @@ -243,11 +243,14 @@ class BaseVectorEnv(gym.Env): which a reproducer pass to "seed". 
""" self._assert_is_not_closed() + seed_list: Union[List[None], List[int]] if seed is None: - seed = [seed] * self.env_num - elif np.isscalar(seed): - seed = [seed + i for i in range(self.env_num)] - return [w.seed(s) for w, s in zip(self.workers, seed)] + seed_list = [seed] * self.env_num + elif isinstance(seed, int): + seed_list = [seed + i for i in range(self.env_num)] + else: + seed_list = seed + return [w.seed(s) for w, s in zip(self.workers, seed_list)] def render(self, **kwargs: Any) -> List[Any]: """Render all of the environments.""" @@ -295,16 +298,6 @@ class DummyVectorEnv(BaseVectorEnv): env_fns, DummyEnvWorker, wait_num=wait_num, timeout=timeout) -class VectorEnv(DummyVectorEnv): - """VectorEnv is renamed to DummyVectorEnv.""" - - def __init__(self, *args: Any, **kwargs: Any) -> None: - warnings.warn( - "VectorEnv is renamed to DummyVectorEnv, and will be removed in " - "0.3. Use DummyVectorEnv instead!", Warning) - super().__init__(*args, **kwargs) - - class SubprocVectorEnv(BaseVectorEnv): """Vectorized environment wrapper based on subprocess. 
diff --git a/tianshou/env/worker/base.py b/tianshou/env/worker/base.py index d097260..f13fe37 100644 --- a/tianshou/env/worker/base.py +++ b/tianshou/env/worker/base.py @@ -10,7 +10,7 @@ class EnvWorker(ABC): def __init__(self, env_fn: Callable[[], gym.Env]) -> None: self._env_fn = env_fn self.is_closed = False - self.result = (None, None, None, None) + self.result: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray] @abstractmethod def __getattr__(self, key: str) -> Any: diff --git a/tianshou/env/worker/dummy.py b/tianshou/env/worker/dummy.py index 9e88840..9e0c3c5 100644 --- a/tianshou/env/worker/dummy.py +++ b/tianshou/env/worker/dummy.py @@ -19,7 +19,7 @@ class DummyEnvWorker(EnvWorker): return self.env.reset() @staticmethod - def wait( + def wait( # type: ignore workers: List["DummyEnvWorker"], wait_num: int, timeout: Optional[float] = None, diff --git a/tianshou/env/worker/ray.py b/tianshou/env/worker/ray.py index 5517388..165e104 100644 --- a/tianshou/env/worker/ray.py +++ b/tianshou/env/worker/ray.py @@ -24,7 +24,7 @@ class RayEnvWorker(EnvWorker): return ray.get(self.env.reset.remote()) @staticmethod - def wait( + def wait( # type: ignore workers: List["RayEnvWorker"], wait_num: int, timeout: Optional[float] = None, diff --git a/tianshou/env/worker/subproc.py b/tianshou/env/worker/subproc.py index 4b280c2..b578dd7 100644 --- a/tianshou/env/worker/subproc.py +++ b/tianshou/env/worker/subproc.py @@ -11,22 +11,71 @@ from tianshou.env.worker import EnvWorker from tianshou.env.utils import CloudpickleWrapper +_NP_TO_CT = { + np.bool: ctypes.c_bool, + np.bool_: ctypes.c_bool, + np.uint8: ctypes.c_uint8, + np.uint16: ctypes.c_uint16, + np.uint32: ctypes.c_uint32, + np.uint64: ctypes.c_uint64, + np.int8: ctypes.c_int8, + np.int16: ctypes.c_int16, + np.int32: ctypes.c_int32, + np.int64: ctypes.c_int64, + np.float32: ctypes.c_float, + np.float64: ctypes.c_double, +} + + +class ShArray: + """Wrapper of multiprocessing Array.""" + + def __init__(self, dtype: 
np.generic, shape: Tuple[int]) -> None: + self.arr = Array( + _NP_TO_CT[dtype.type], # type: ignore + int(np.prod(shape)), + ) + self.dtype = dtype + self.shape = shape + + def save(self, ndarray: np.ndarray) -> None: + assert isinstance(ndarray, np.ndarray) + dst = self.arr.get_obj() + dst_np = np.frombuffer(dst, dtype=self.dtype).reshape(self.shape) + np.copyto(dst_np, ndarray) + + def get(self) -> np.ndarray: + obj = self.arr.get_obj() + return np.frombuffer(obj, dtype=self.dtype).reshape(self.shape) + + +def _setup_buf(space: gym.Space) -> Union[dict, tuple, ShArray]: + if isinstance(space, gym.spaces.Dict): + assert isinstance(space.spaces, OrderedDict) + return {k: _setup_buf(v) for k, v in space.spaces.items()} + elif isinstance(space, gym.spaces.Tuple): + assert isinstance(space.spaces, tuple) + return tuple([_setup_buf(t) for t in space.spaces]) + else: + return ShArray(space.dtype, space.shape) + + def _worker( parent: connection.Connection, p: connection.Connection, env_fn_wrapper: CloudpickleWrapper, - obs_bufs: Optional[Union[dict, tuple, "ShArray"]] = None, + obs_bufs: Optional[Union[dict, tuple, ShArray]] = None, ) -> None: def _encode_obs( obs: Union[dict, tuple, np.ndarray], buffer: Union[dict, tuple, ShArray], ) -> None: - if isinstance(obs, np.ndarray): + if isinstance(obs, np.ndarray) and isinstance(buffer, ShArray): buffer.save(obs) - elif isinstance(obs, tuple): + elif isinstance(obs, tuple) and isinstance(buffer, tuple): for o, b in zip(obs, buffer): _encode_obs(o, b) - elif isinstance(obs, dict): + elif isinstance(obs, dict) and isinstance(buffer, dict): for k in obs.keys(): _encode_obs(obs[k], buffer[k]) return None @@ -69,52 +118,6 @@ def _worker( p.close() -_NP_TO_CT = { - np.bool: ctypes.c_bool, - np.bool_: ctypes.c_bool, - np.uint8: ctypes.c_uint8, - np.uint16: ctypes.c_uint16, - np.uint32: ctypes.c_uint32, - np.uint64: ctypes.c_uint64, - np.int8: ctypes.c_int8, - np.int16: ctypes.c_int16, - np.int32: ctypes.c_int32, - np.int64: 
ctypes.c_int64, - np.float32: ctypes.c_float, - np.float64: ctypes.c_double, -} - - -class ShArray: - """Wrapper of multiprocessing Array.""" - - def __init__(self, dtype: np.generic, shape: Tuple[int]) -> None: - self.arr = Array(_NP_TO_CT[dtype.type], int(np.prod(shape))) - self.dtype = dtype - self.shape = shape - - def save(self, ndarray: np.ndarray) -> None: - assert isinstance(ndarray, np.ndarray) - dst = self.arr.get_obj() - dst_np = np.frombuffer(dst, dtype=self.dtype).reshape(self.shape) - np.copyto(dst_np, ndarray) - - def get(self) -> np.ndarray: - obj = self.arr.get_obj() - return np.frombuffer(obj, dtype=self.dtype).reshape(self.shape) - - -def _setup_buf(space: gym.Space) -> Union[dict, tuple, ShArray]: - if isinstance(space, gym.spaces.Dict): - assert isinstance(space.spaces, OrderedDict) - return {k: _setup_buf(v) for k, v in space.spaces.items()} - elif isinstance(space, gym.spaces.Tuple): - assert isinstance(space.spaces, tuple) - return tuple([_setup_buf(t) for t in space.spaces]) - else: - return ShArray(space.dtype, space.shape) - - class SubprocEnvWorker(EnvWorker): """Subprocess worker used in SubprocVectorEnv and ShmemVectorEnv.""" @@ -124,7 +127,7 @@ class SubprocEnvWorker(EnvWorker): super().__init__(env_fn) self.parent_remote, self.child_remote = Pipe() self.share_memory = share_memory - self.buffer = None + self.buffer: Optional[Union[dict, tuple, ShArray]] = None if self.share_memory: dummy = env_fn() obs_space = dummy.observation_space @@ -168,25 +171,23 @@ class SubprocEnvWorker(EnvWorker): return obs @staticmethod - def wait( + def wait( # type: ignore workers: List["SubprocEnvWorker"], wait_num: int, timeout: Optional[float] = None, ) -> List["SubprocEnvWorker"]: - conns, ready_conns = [x.parent_remote for x in workers], [] - remain_conns = conns - t1 = time.time() + remain_conns = conns = [x.parent_remote for x in workers] + ready_conns: List[connection.Connection] = [] + remain_time, t1 = timeout, time.time() while 
len(remain_conns) > 0 and len(ready_conns) < wait_num: if timeout: remain_time = timeout - (time.time() - t1) if remain_time <= 0: break - else: - remain_time = timeout # connection.wait hangs if the list is empty new_ready_conns = connection.wait( remain_conns, timeout=remain_time) - ready_conns.extend(new_ready_conns) + ready_conns.extend(new_ready_conns) # type: ignore remain_conns = [ conn for conn in remain_conns if conn not in ready_conns] return [workers[conns.index(con)] for con in ready_conns] diff --git a/tianshou/exploration/random.py b/tianshou/exploration/random.py index f06b7b5..2e495dc 100644 --- a/tianshou/exploration/random.py +++ b/tianshou/exploration/random.py @@ -9,15 +9,15 @@ class BaseNoise(ABC, object): def __init__(self) -> None: super().__init__() + def reset(self) -> None: + """Reset to the initial state.""" + pass + @abstractmethod def __call__(self, size: Sequence[int]) -> np.ndarray: """Generate new noise.""" raise NotImplementedError - def reset(self) -> None: - """Reset to the initial state.""" - pass - class GaussianNoise(BaseNoise): """The vanilla gaussian process, for exploration in DDPG by default.""" @@ -64,6 +64,10 @@ class OUNoise(BaseNoise): self._x0 = x0 self.reset() + def reset(self) -> None: + """Reset to the initial state.""" + self._x = self._x0 + def __call__( self, size: Sequence[int], mu: Optional[float] = None ) -> np.ndarray: @@ -71,14 +75,11 @@ class OUNoise(BaseNoise): Return an numpy array which size is equal to ``size``. 
""" - if self._x is None or self._x.shape != size: + if self._x is None or isinstance( + self._x, np.ndarray) and self._x.shape != size: self._x = 0.0 if mu is None: mu = self._mu r = self._beta * np.random.normal(size=size) self._x = self._x + self._alpha * (mu - self._x) + r return self._x - - def reset(self) -> None: - """Reset to the initial state.""" - self._x = self._x0 diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 0b00585..8bd0fcb 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -190,7 +190,7 @@ class BasePolicy(ABC, nn.Module): array with shape (bsz, ). """ rew = batch.rew - v_s_ = np.zeros_like(rew) if v_s_ is None else to_numpy(v_s_).flatten() + v_s_ = np.zeros_like(rew) if v_s_ is None else to_numpy(v_s_.flatten()) returns = _episodic_return(v_s_, rew, batch.done, gamma, gae_lambda) if rew_norm and not np.isclose(returns.std(), 0.0, 1e-2): returns = (returns - returns.mean()) / returns.std() diff --git a/tianshou/policy/imitation/base.py b/tianshou/policy/imitation/base.py index 13dda4e..d65cbc8 100644 --- a/tianshou/policy/imitation/base.py +++ b/tianshou/policy/imitation/base.py @@ -55,11 +55,11 @@ class ImitationPolicy(BasePolicy): if self.mode == "continuous": # regression a = self(batch).act a_ = to_torch(batch.act, dtype=torch.float32, device=a.device) - loss = F.mse_loss(a, a_) + loss = F.mse_loss(a, a_) # type: ignore elif self.mode == "discrete": # classification a = self(batch).logits a_ = to_torch(batch.act, dtype=torch.long, device=a.device) - loss = F.nll_loss(a, a_) + loss = F.nll_loss(a, a_) # type: ignore loss.backward() self.optim.step() return {"loss": loss.item()} diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index ae7cfa7..df34e01 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -103,11 +103,11 @@ class A2CPolicy(PGPolicy): if isinstance(logits, tuple): dist = self.dist_fn(*logits) else: - dist = self.dist_fn(logits) + 
dist = self.dist_fn(logits) # type: ignore act = dist.sample() return Batch(logits=logits, act=act, state=h, dist=dist) - def learn( + def learn( # type: ignore self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any ) -> Dict[str, List[float]]: losses, actor_losses, vf_losses, ent_losses = [], [], [], [] @@ -120,7 +120,7 @@ class A2CPolicy(PGPolicy): r = to_torch_as(b.returns, v) log_prob = dist.log_prob(a).reshape(len(r), -1).transpose(0, 1) a_loss = -(log_prob * (r - v).detach()).mean() - vf_loss = F.mse_loss(r, v) + vf_loss = F.mse_loss(r, v) # type: ignore ent_loss = dist.entropy().mean() loss = a_loss + self._w_vf * vf_loss - self._w_ent * ent_loss loss.backward() diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index 3770c64..f19f56e 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -53,14 +53,16 @@ class DDPGPolicy(BasePolicy): **kwargs: Any, ) -> None: super().__init__(**kwargs) - if actor is not None: - self.actor, self.actor_old = actor, deepcopy(actor) + if actor is not None and actor_optim is not None: + self.actor: torch.nn.Module = actor + self.actor_old = deepcopy(actor) self.actor_old.eval() - self.actor_optim = actor_optim - if critic is not None: - self.critic, self.critic_old = critic, deepcopy(critic) + self.actor_optim: torch.optim.Optimizer = actor_optim + if critic is not None and critic_optim is not None: + self.critic: torch.nn.Module = critic + self.critic_old = deepcopy(critic) self.critic_old.eval() - self.critic_optim = critic_optim + self.critic_optim: torch.optim.Optimizer = critic_optim assert 0.0 <= tau <= 1.0, "tau should be in [0, 1]" self._tau = tau assert 0.0 <= gamma <= 1.0, "gamma should be in [0, 1]" @@ -141,7 +143,7 @@ class DDPGPolicy(BasePolicy): obs = getattr(batch, input) actions, h = model(obs, state=state, info=batch.info) actions += self._action_bias - if self.training and explorating: + if self._noise and self.training and explorating: 
actions += to_torch_as(self._noise(actions.shape), actions) actions = actions.clamp(self._range[0], self._range[1]) return Batch(act=actions, state=h) diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index c6ba4fa..070f11d 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -146,11 +146,10 @@ class DQNPolicy(BasePolicy): obs = getattr(batch, input) obs_ = obs.obs if hasattr(obs, "obs") else obs q, h = model(obs_, state=state, info=batch.info) - act = to_numpy(q.max(dim=1)[1]) - has_mask = hasattr(obs, 'mask') - if has_mask: + act: np.ndarray = to_numpy(q.max(dim=1)[1]) + if hasattr(obs, "mask"): # some of actions are masked, they cannot be selected - q_ = to_numpy(q) + q_: np.ndarray = to_numpy(q) q_[~obs.mask] = -np.inf act = q_.argmax(axis=1) # add eps to act @@ -160,7 +159,7 @@ class DQNPolicy(BasePolicy): for i in range(len(q)): if np.random.rand() < eps: q_ = np.random.rand(*q[i].shape) - if has_mask: + if hasattr(obs, "mask"): q_[~obs.mask[i]] = -np.inf act[i] = q_.argmax() return Batch(logits=q, act=act, state=h) @@ -172,7 +171,7 @@ class DQNPolicy(BasePolicy): weight = batch.pop("weight", 1.0) q = self(batch, eps=0.0).logits q = q[np.arange(len(q)), batch.act] - r = to_torch_as(batch.returns, q).flatten() + r = to_torch_as(batch.returns.flatten(), q) td = r - q loss = (td.pow(2) * weight).mean() batch.weight = td # prio-buffer diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index e93b4f9..2f33046 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -32,7 +32,8 @@ class PGPolicy(BasePolicy): **kwargs: Any, ) -> None: super().__init__(**kwargs) - self.model = model + if model is not None: + self.model: torch.nn.Module = model self.optim = optim self.dist_fn = dist_fn assert ( @@ -81,11 +82,11 @@ class PGPolicy(BasePolicy): if isinstance(logits, tuple): dist = self.dist_fn(*logits) else: - dist = self.dist_fn(logits) + dist = 
self.dist_fn(logits) # type: ignore act = dist.sample() return Batch(logits=logits, act=act, state=h, dist=dist) - def learn( + def learn( # type: ignore self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any ) -> Dict[str, List[float]]: losses = [] diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index 349ab9b..60bb19e 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -137,13 +137,13 @@ class PPOPolicy(PGPolicy): if isinstance(logits, tuple): dist = self.dist_fn(*logits) else: - dist = self.dist_fn(logits) + dist = self.dist_fn(logits) # type: ignore act = dist.sample() if self._range: act = act.clamp(self._range[0], self._range[1]) return Batch(logits=logits, act=act, state=h, dist=dist) - def learn( + def learn( # type: ignore self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any ) -> Dict[str, List[float]]: losses, clip_losses, vf_losses, ent_losses = [], [], [], [] @@ -157,8 +157,9 @@ class PPOPolicy(PGPolicy): surr2 = ratio.clamp(1.0 - self._eps_clip, 1.0 + self._eps_clip) * b.adv if self._dual_clip: - clip_loss = -torch.max(torch.min(surr1, surr2), - self._dual_clip * b.adv).mean() + clip_loss = -torch.max( + torch.min(surr1, surr2), self._dual_clip * b.adv + ).mean() else: clip_loss = -torch.min(surr1, surr2).mean() clip_losses.append(clip_loss.item()) diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index b596eb5..83c7150 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -107,7 +107,7 @@ class SACPolicy(DDPGPolicy): ): o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau) - def forward( + def forward( # type: ignore self, batch: Batch, state: Optional[Union[dict, Batch, np.ndarray]] = None, @@ -193,5 +193,5 @@ class SACPolicy(DDPGPolicy): } if self._is_auto_alpha: result["loss/alpha"] = alpha_loss.item() - result["v/alpha"] = self._alpha.item() + result["v/alpha"] = self._alpha.item() # type: 
ignore return result diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index bbf5233..01e6530 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -73,7 +73,7 @@ def offpolicy_trainer( """ global_step = 0 best_epoch, best_reward = -1, -1.0 - stat = {} + stat: Dict[str, MovAvg] = {} start_time = time.time() test_in_train = test_in_train and train_collector.policy == policy for epoch in range(1, 1 + max_epoch): @@ -91,7 +91,7 @@ def offpolicy_trainer( test_result = test_episode( policy, test_collector, test_fn, epoch, episode_per_test, writer, global_step) - if stop_fn and stop_fn(test_result["rew"]): + if stop_fn(test_result["rew"]): if save_fn: save_fn(policy) for k in result.keys(): diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index ac97ba7..37f4278 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -73,7 +73,7 @@ def onpolicy_trainer( """ global_step = 0 best_epoch, best_reward = -1, -1.0 - stat = {} + stat: Dict[str, MovAvg] = {} start_time = time.time() test_in_train = test_in_train and train_collector.policy == policy for epoch in range(1, 1 + max_epoch): @@ -91,7 +91,7 @@ def onpolicy_trainer( test_result = test_episode( policy, test_collector, test_fn, epoch, episode_per_test, writer, global_step) - if stop_fn and stop_fn(test_result["rew"]): + if stop_fn(test_result["rew"]): if save_fn: save_fn(policy) for k in result.keys(): @@ -109,9 +109,9 @@ def onpolicy_trainer( batch_size=batch_size, repeat=repeat_per_collect) train_collector.reset_buffer() step = 1 - for k in losses.keys(): - if isinstance(losses[k], list): - step = max(step, len(losses[k])) + for v in losses.values(): + if isinstance(v, list): + step = max(step, len(v)) global_step += step * collect_per_step for k in result.keys(): data[k] = f"{result[k]:.2f}" diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py index 5f6698e..0c5d2dd 100644 --- a/tianshou/trainer/utils.py +++ 
b/tianshou/trainer/utils.py @@ -22,7 +22,7 @@ def test_episode( policy.eval() if test_fn: test_fn(epoch) - if collector.get_env_num() > 1 and np.isscalar(n_episode): + if collector.get_env_num() > 1 and isinstance(n_episode, int): n = collector.get_env_num() n_ = np.zeros(n) + n_episode // n n_[:n_episode % n] += 1 diff --git a/tianshou/utils/moving_average.py b/tianshou/utils/moving_average.py index 58c6792..58e8860 100644 --- a/tianshou/utils/moving_average.py +++ b/tianshou/utils/moving_average.py @@ -1,7 +1,7 @@ import torch import numpy as np from numbers import Number -from typing import Union +from typing import List, Union from tianshou.data import to_numpy @@ -28,7 +28,7 @@ class MovAvg(object): def __init__(self, size: int = 100) -> None: super().__init__() self.size = size - self.cache = [] + self.cache: List[Union[Number, np.number]] = [] self.banned = [np.inf, np.nan, -np.inf] def add( diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index 96023fe..8b52a87 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -12,7 +12,7 @@ def miniblock( norm_layer: Optional[Callable[[int], nn.modules.Module]], ) -> List[nn.modules.Module]: """Construct a miniblock with given input/output-size and norm layer.""" - ret = [nn.Linear(inp, oup)] + ret: List[nn.modules.Module] = [nn.Linear(inp, oup)] if norm_layer is not None: ret += [norm_layer(oup)] ret += [nn.ReLU(inplace=True)] @@ -54,36 +54,33 @@ class Net(nn.Module): if concat: input_size += np.prod(action_shape) - self.model = miniblock(input_size, hidden_layer_size, norm_layer) + model = miniblock(input_size, hidden_layer_size, norm_layer) for i in range(layer_num): - self.model += miniblock(hidden_layer_size, - hidden_layer_size, norm_layer) + model += miniblock( + hidden_layer_size, hidden_layer_size, norm_layer) - if self.dueling is None: + if dueling is None: if action_shape and not concat: - self.model += [nn.Linear(hidden_layer_size, - np.prod(action_shape))] 
+ model += [nn.Linear(hidden_layer_size, np.prod(action_shape))] else: # dueling DQN - assert isinstance(self.dueling, tuple) and len(self.dueling) == 2 - - q_layer_num, v_layer_num = self.dueling - self.Q, self.V = [], [] + q_layer_num, v_layer_num = dueling + Q, V = [], [] for i in range(q_layer_num): - self.Q += miniblock(hidden_layer_size, - hidden_layer_size, norm_layer) + Q += miniblock( + hidden_layer_size, hidden_layer_size, norm_layer) for i in range(v_layer_num): - self.V += miniblock(hidden_layer_size, - hidden_layer_size, norm_layer) + V += miniblock( + hidden_layer_size, hidden_layer_size, norm_layer) if action_shape and not concat: - self.Q += [nn.Linear(hidden_layer_size, np.prod(action_shape))] - self.V += [nn.Linear(hidden_layer_size, 1)] + Q += [nn.Linear(hidden_layer_size, np.prod(action_shape))] + V += [nn.Linear(hidden_layer_size, 1)] - self.Q = nn.Sequential(*self.Q) - self.V = nn.Sequential(*self.V) - self.model = nn.Sequential(*self.model) + self.Q = nn.Sequential(*Q) + self.V = nn.Sequential(*V) + self.model = nn.Sequential(*model) def forward( self, diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index 85f03d2..7afd9eb 100644 --- a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -66,7 +66,7 @@ class Critic(nn.Module): s = to_torch(s, device=self.device, dtype=torch.float32) s = s.flatten(1) if a is not None: - a = to_torch(a, device=self.device, dtype=torch.float32) + a = to_torch_as(a, s) a = a.flatten(1) s = torch.cat([s, a], dim=1) logits, h = self.preprocess(s) diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index 6734316..58e44a7 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -4,6 +4,8 @@ from torch import nn import torch.nn.functional as F from typing import Any, Dict, Tuple, Union, Optional, Sequence +from tianshou.data import to_torch + class Actor(nn.Module): """Simple actor network with MLP. 
@@ -118,5 +120,5 @@ class DQN(nn.Module): ) -> Tuple[torch.Tensor, Any]: r"""Mapping: x -> Q(x, \*).""" if not isinstance(x, torch.Tensor): - x = torch.tensor(x, device=self.device, dtype=torch.float32) + x = to_torch(x, device=self.device, dtype=torch.float32) return self.net(x), state