diff --git a/.github/workflows/lint_and_docs.yml b/.github/workflows/lint_and_docs.yml
index 07e1def..6226798 100644
--- a/.github/workflows/lint_and_docs.yml
+++ b/.github/workflows/lint_and_docs.yml
@@ -1,4 +1,4 @@
-name: PEP8 and Docs Check
+name: PEP8, Types and Docs Check
on: [push, pull_request]
@@ -20,6 +20,9 @@ jobs:
- name: Lint with flake8
run: |
flake8 . --count --show-source --statistics
+ - name: Type check
+ run: |
+ mypy
- name: Documentation test
run: |
pydocstyle tianshou
diff --git a/docs/contributing.rst b/docs/contributing.rst
index b56a7a4..cf015a9 100644
--- a/docs/contributing.rst
+++ b/docs/contributing.rst
@@ -28,6 +28,16 @@ We follow PEP8 python code style. To check, in the main directory, run:
$ flake8 . --count --show-source --statistics
+Type Check
+----------
+
+We use `mypy <https://github.com/python/mypy/>`_ to check the type annotations. To check, in the main directory, run:
+
+.. code-block:: bash
+
+ $ mypy
+
+
Test Locally
------------
diff --git a/docs/index.rst b/docs/index.rst
index 5d962af..587dc7e 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -31,6 +31,7 @@ Here is Tianshou's other features:
* Support :ref:`customize_training`
* Support n-step returns estimation :meth:`~tianshou.policy.BasePolicy.compute_nstep_return` and prioritized experience replay :class:`~tianshou.data.PrioritizedReplayBuffer` for all Q-learning based algorithms; GAE, nstep and PER are very fast thanks to numba jit function and vectorized numpy operation
* Support :doc:`/tutorials/tictactoe`
+* Comprehensive `unit tests <https://github.com/thu-ml/tianshou/actions>`_, including functional checking, RL pipeline checking, documentation checking, PEP8 code-style checking, and type checking
中文文档位于 `https://tianshou.readthedocs.io/zh/latest/ `_
diff --git a/setup.cfg b/setup.cfg
index 188700c..20f8d6d 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,3 +1,26 @@
+[mypy]
+files = tianshou/**/*.py
+allow_redefinition = True
+check_untyped_defs = True
+disallow_incomplete_defs = True
+disallow_untyped_defs = True
+ignore_missing_imports = True
+no_implicit_optional = True
+pretty = True
+show_error_codes = True
+show_error_context = True
+show_traceback = True
+strict_equality = True
+strict_optional = True
+warn_no_return = True
+warn_redundant_casts = True
+warn_unreachable = True
+warn_unused_configs = True
+warn_unused_ignores = True
+
+[mypy-tianshou.utils.net.*]
+ignore_errors = True
+
[pydocstyle]
ignore = D100,D102,D104,D105,D107,D203,D213,D401,D402
diff --git a/test/throughput/test_collector_profile.py b/test/throughput/test_collector_profile.py
index 21260f5..4036472 100644
--- a/test/throughput/test_collector_profile.py
+++ b/test/throughput/test_collector_profile.py
@@ -100,11 +100,6 @@ def test_collect_ep(data):
data["collector"].collect(n_episode=10)
-def test_sample(data):
- for _ in range(5000):
- data["collector"].sample(256)
-
-
def test_init_vec_env(data):
for _ in range(5000):
Collector(data["policy"], data["env_vec"], data["buffer"])
@@ -125,11 +120,6 @@ def test_collect_vec_env_ep(data):
data["collector_vec"].collect(n_episode=10)
-def test_sample_vec_env(data):
- for _ in range(5000):
- data["collector_vec"].sample(256)
-
-
def test_init_subproc_env(data):
for _ in range(5000):
Collector(data["policy"], data["env_subproc_init"], data["buffer"])
@@ -150,10 +140,5 @@ def test_collect_subproc_env_ep(data):
data["collector_subproc"].collect(n_episode=10)
-def test_sample_subproc_env(data):
- for _ in range(5000):
- data["collector_subproc"].sample(256)
-
-
if __name__ == '__main__':
pytest.main(["-s", "-k collector_profile", "--durations=0", "-v"])
diff --git a/tianshou/__init__.py b/tianshou/__init__.py
index 5ad1c9f..0b9c0e9 100644
--- a/tianshou/__init__.py
+++ b/tianshou/__init__.py
@@ -1,7 +1,7 @@
from tianshou import data, env, utils, policy, trainer, exploration
-__version__ = "0.2.7"
+__version__ = "0.3.0rc0"
__all__ = [
"env",
diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py
index a06c0df..fe35160 100644
--- a/tianshou/data/batch.py
+++ b/tianshou/data/batch.py
@@ -6,7 +6,7 @@ from copy import deepcopy
from numbers import Number
from collections.abc import Collection
from typing import Any, List, Dict, Union, Iterator, Optional, Iterable, \
- Sequence, KeysView, ValuesView, ItemsView
+ Sequence
# Disable pickle warning related to torch, since it has been removed
# on torch master branch. See Pull Request #39003 for details:
@@ -144,7 +144,7 @@ def _parse_value(v: Any) -> Optional[Union["Batch", np.ndarray, torch.Tensor]]:
if not isinstance(v, np.ndarray) and isinstance(v, Collection) and \
len(v) > 0 and all(isinstance(e, torch.Tensor) for e in v):
try:
- return torch.stack(v)
+ return torch.stack(v) # type: ignore
except RuntimeError as e:
raise TypeError("Batch does not support non-stackable iterable"
" of torch.Tensor as unique value yet.") from e
@@ -191,12 +191,20 @@ class Batch:
elif _is_batch_set(batch_dict):
self.stack_(batch_dict)
if len(kwargs) > 0:
- self.__init__(kwargs, copy=copy)
+ self.__init__(kwargs, copy=copy) # type: ignore
def __setattr__(self, key: str, value: Any) -> None:
"""Set self.key = value."""
self.__dict__[key] = _parse_value(value)
+ def __getattr__(self, key: str) -> Any:
+ """Return self.key. The "Any" return type is needed for mypy."""
+ return getattr(self.__dict__, key)
+
+ def __contains__(self, key: str) -> bool:
+ """Return key in self."""
+ return key in self.__dict__
+
def __getstate__(self) -> Dict[str, Any]:
"""Pickling interface.
@@ -215,11 +223,11 @@ class Batch:
At this point, self is an empty Batch instance that has not been
initialized, so it can safely be initialized by the pickle state.
"""
- self.__init__(**state)
+ self.__init__(**state) # type: ignore
def __getitem__(
self, index: Union[str, slice, int, np.integer, np.ndarray, List[int]]
- ) -> Union["Batch", np.ndarray, torch.Tensor]:
+ ) -> Any:
"""Return self[index]."""
if isinstance(index, str):
return self.__dict__[index]
@@ -245,7 +253,7 @@ class Batch:
if isinstance(index, str):
self.__dict__[index] = value
return
- if isinstance(value, (np.ndarray, torch.Tensor)):
+ if not isinstance(value, Batch):
raise ValueError("Batch does not supported tensor assignment. "
"Use a compatible Batch or dict instead.")
if not set(value.keys()).issubset(self.__dict__.keys()):
@@ -330,30 +338,6 @@ class Batch:
s = self.__class__.__name__ + "()"
return s
- def __contains__(self, key: str) -> bool:
- """Return key in self."""
- return key in self.__dict__
-
- def keys(self) -> KeysView[str]:
- """Return self.keys()."""
- return self.__dict__.keys()
-
- def values(self) -> ValuesView[Any]:
- """Return self.values()."""
- return self.__dict__.values()
-
- def items(self) -> ItemsView[str, Any]:
- """Return self.items()."""
- return self.__dict__.items()
-
- def get(self, k: str, d: Optional[Any] = None) -> Any:
- """Return self[k] if k in self else d. d defaults to None."""
- return self.__dict__.get(k, d)
-
- def pop(self, k: str, d: Optional[Any] = None) -> Any:
- """Return & remove self[k] if k in self else d. d defaults to None."""
- return self.__dict__.pop(k, d)
-
def to_numpy(self) -> None:
"""Change all torch.Tensor to numpy.ndarray in-place."""
for k, v in self.items():
@@ -375,7 +359,6 @@ class Batch:
if isinstance(v, torch.Tensor):
if dtype is not None and v.dtype != dtype or \
v.device.type != device.type or \
- device.index is not None and \
device.index != v.device.index:
if dtype is not None:
v = v.type(dtype)
@@ -517,7 +500,7 @@ class Batch:
return
batches = [x if isinstance(x, Batch) else Batch(x) for x in batches]
if not self.is_empty():
- batches = [self] + list(batches)
+ batches = [self] + batches
# collect non-empty keys
keys_map = [
set(k for k, v in batch.items()
@@ -672,8 +655,8 @@ class Batch:
for v in self.__dict__.values():
if isinstance(v, Batch) and v.is_empty(recurse=True):
continue
- elif hasattr(v, "__len__") and (not isinstance(
- v, (np.ndarray, torch.Tensor)) or v.ndim > 0
+ elif hasattr(v, "__len__") and (
+ isinstance(v, Batch) or v.ndim > 0
):
r.append(len(v))
else:
diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py
index ff16085..1e817d9 100644
--- a/tianshou/data/buffer.py
+++ b/tianshou/data/buffer.py
@@ -1,7 +1,7 @@
import torch
import numpy as np
from numbers import Number
-from typing import Any, Dict, Tuple, Union, Optional
+from typing import Any, Dict, List, Tuple, Union, Optional
from tianshou.data import Batch, SegmentTree, to_numpy
from tianshou.data.batch import _create_value
@@ -138,7 +138,7 @@ class ReplayBuffer:
self._indices = np.arange(size)
self.stack_num = stack_num
self._avail = sample_avail and stack_num > 1
- self._avail_index = []
+ self._avail_index: List[int] = []
self._save_s_ = not ignore_obs_next
self._last_obs = save_only_last_obs
self._index = 0
@@ -175,12 +175,12 @@ class ReplayBuffer:
except KeyError:
self._meta.__dict__[name] = _create_value(inst, self._maxsize)
value = self._meta.__dict__[name]
- if isinstance(inst, (torch.Tensor, np.ndarray)) \
- and inst.shape != value.shape[1:]:
- raise ValueError(
- "Cannot add data to a buffer with different shape, with key "
- f"{name}, expect {value.shape[1:]}, given {inst.shape}."
- )
+ if isinstance(inst, (torch.Tensor, np.ndarray)):
+ if inst.shape != value.shape[1:]:
+ raise ValueError(
+ "Cannot add data to a buffer with different shape with key"
+ f" {name}, expect {value.shape[1:]}, given {inst.shape}."
+ )
try:
value[self._index] = inst
except KeyError:
@@ -205,7 +205,7 @@ class ReplayBuffer:
stack_num_orig = buffer.stack_num
buffer.stack_num = 1
while True:
- self.add(**buffer[i])
+ self.add(**buffer[i]) # type: ignore
i = (i + 1) % len(buffer)
if i == begin:
break
@@ -323,7 +323,7 @@ class ReplayBuffer:
try:
if stack_num == 1:
return val[indice]
- stack = []
+ stack: List[Any] = []
for _ in range(stack_num):
stack = [val[indice]] + stack
pre_indice = np.asarray(indice - 1)
diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py
index aacaa9b..5315538 100644
--- a/tianshou/data/collector.py
+++ b/tianshou/data/collector.py
@@ -212,10 +212,8 @@ class Collector(object):
finished_env_ids = []
reward_total = 0.0
whole_data = Batch()
- list_n_episode = False
- if n_episode is not None and not np.isscalar(n_episode):
+ if isinstance(n_episode, list):
assert len(n_episode) == self.get_env_num()
- list_n_episode = True
finished_env_ids = [
i for i in self._ready_env_ids if n_episode[i] <= 0]
self._ready_env_ids = np.array(
@@ -266,7 +264,8 @@ class Collector(object):
self.data.policy._state = self.data.state
self.data.act = to_numpy(result.act)
- if self._action_noise is not None: # noqa
+ if self._action_noise is not None:
+ assert isinstance(self.data.act, np.ndarray)
self.data.act += self._action_noise(self.data.act.shape)
# step in env
@@ -291,7 +290,7 @@ class Collector(object):
# add data into the buffer
if self.preprocess_fn:
- result = self.preprocess_fn(**self.data)
+ result = self.preprocess_fn(**self.data) # type: ignore
self.data.update(result)
for j, i in enumerate(self._ready_env_ids):
@@ -305,14 +304,14 @@ class Collector(object):
self._cached_buf[i].add(**self.data[j])
if done[j]:
- if not (list_n_episode and
- episode_count[i] >= n_episode[i]):
+ if not (isinstance(n_episode, list)
+ and episode_count[i] >= n_episode[i]):
episode_count[i] += 1
reward_total += np.sum(self._cached_buf[i].rew, axis=0)
step_count += len(self._cached_buf[i])
if self.buffer is not None:
self.buffer.update(self._cached_buf[i])
- if list_n_episode and \
+ if isinstance(n_episode, list) and \
episode_count[i] >= n_episode[i]:
# env i has collected enough data, it has finished
finished_env_ids.append(i)
@@ -324,10 +323,9 @@ class Collector(object):
env_ind_global = self._ready_env_ids[env_ind_local]
obs_reset = self.env.reset(env_ind_global)
if self.preprocess_fn:
- obs_next[env_ind_local] = self.preprocess_fn(
+ obs_reset = self.preprocess_fn(
obs=obs_reset).get("obs", obs_reset)
- else:
- obs_next[env_ind_local] = obs_reset
+ obs_next[env_ind_local] = obs_reset
self.data.obs = obs_next
if is_async:
# set data back
@@ -362,7 +360,7 @@ class Collector(object):
# average reward across the number of episodes
reward_avg = reward_total / episode_count
if np.asanyarray(reward_avg).size > 1: # non-scalar reward_avg
- reward_avg = self._rew_metric(reward_avg)
+ reward_avg = self._rew_metric(reward_avg) # type: ignore
return {
"n/ep": episode_count,
"n/st": step_count,
@@ -372,30 +370,6 @@ class Collector(object):
"len": step_count / episode_count,
}
- def sample(self, batch_size: int) -> Batch:
- """Sample a data batch from the internal replay buffer.
-
- It will call :meth:`~tianshou.policy.BasePolicy.process_fn` before
- returning the final batch data.
-
- :param int batch_size: ``0`` means it will extract all the data from
- the buffer, otherwise it will extract the data with the given
- batch_size.
- """
- warnings.warn(
- "Collector.sample is deprecated and will cause error if you use "
- "prioritized experience replay! Collector.sample will be removed "
- "upon version 0.3. Use policy.update instead!", Warning)
- assert self.buffer is not None, "Cannot get sample from empty buffer!"
- batch_data, indice = self.buffer.sample(batch_size)
- batch_data = self.process_fn(batch_data, self.buffer, indice)
- return batch_data
-
- def close(self) -> None:
- warnings.warn(
- "Collector.close is deprecated and will be removed upon version "
- "0.3.", Warning)
-
def _batch_set_item(
source: Batch, indices: np.ndarray, target: Batch, size: int
diff --git a/tianshou/data/utils/converter.py b/tianshou/data/utils/converter.py
index 6a5d43e..a8e8845 100644
--- a/tianshou/data/utils/converter.py
+++ b/tianshou/data/utils/converter.py
@@ -45,14 +45,14 @@ def to_torch(
if isinstance(x, np.ndarray) and issubclass(
x.dtype.type, (np.bool_, np.number)
): # most often case
- x = torch.from_numpy(x).to(device)
+ x = torch.from_numpy(x).to(device) # type: ignore
if dtype is not None:
x = x.type(dtype)
return x
elif isinstance(x, torch.Tensor): # second often case
if dtype is not None:
x = x.type(dtype)
- return x.to(device)
+ return x.to(device) # type: ignore
elif isinstance(x, (np.number, np.bool_, Number)):
return to_torch(np.asanyarray(x), dtype, device)
elif isinstance(x, dict):
diff --git a/tianshou/env/__init__.py b/tianshou/env/__init__.py
index d3a49b7..a25e06f 100644
--- a/tianshou/env/__init__.py
+++ b/tianshou/env/__init__.py
@@ -1,11 +1,10 @@
-from tianshou.env.venvs import BaseVectorEnv, DummyVectorEnv, VectorEnv, \
+from tianshou.env.venvs import BaseVectorEnv, DummyVectorEnv, \
SubprocVectorEnv, ShmemVectorEnv, RayVectorEnv
from tianshou.env.maenv import MultiAgentEnv
__all__ = [
"BaseVectorEnv",
"DummyVectorEnv",
- "VectorEnv", # TODO: remove in later version
"SubprocVectorEnv",
"ShmemVectorEnv",
"RayVectorEnv",
diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py
index 6b82001..ddf7aca 100644
--- a/tianshou/env/venvs.py
+++ b/tianshou/env/venvs.py
@@ -1,5 +1,4 @@
import gym
-import warnings
import numpy as np
from typing import Any, List, Union, Optional, Callable
@@ -84,12 +83,12 @@ class BaseVectorEnv(gym.Env):
self.timeout is None or self.timeout > 0
), f"timeout is {timeout}, it should be positive if provided!"
self.is_async = self.wait_num != len(env_fns) or timeout is not None
- self.waiting_conn = []
+ self.waiting_conn: List[EnvWorker] = []
# environments in self.ready_id is actually ready
# but environments in self.waiting_id are just waiting when checked,
# and they may be ready now, but this is not known until we check it
# in the step() function
- self.waiting_id = []
+ self.waiting_id: List[int] = []
# all environments are ready in the beginning
self.ready_id = list(range(self.env_num))
self.is_closed = False
@@ -216,10 +215,11 @@ class BaseVectorEnv(gym.Env):
self.waiting_conn.append(self.workers[env_id])
self.waiting_id.append(env_id)
self.ready_id = [x for x in self.ready_id if x not in id]
- ready_conns, result = [], []
+ ready_conns: List[EnvWorker] = []
while not ready_conns:
ready_conns = self.worker_class.wait(
self.waiting_conn, self.wait_num, self.timeout)
+ result = []
for conn in ready_conns:
waiting_index = self.waiting_conn.index(conn)
self.waiting_conn.pop(waiting_index)
@@ -243,11 +243,14 @@ class BaseVectorEnv(gym.Env):
which a reproducer pass to "seed".
"""
self._assert_is_not_closed()
+ seed_list: Union[List[None], List[int]]
if seed is None:
- seed = [seed] * self.env_num
- elif np.isscalar(seed):
- seed = [seed + i for i in range(self.env_num)]
- return [w.seed(s) for w, s in zip(self.workers, seed)]
+ seed_list = [seed] * self.env_num
+ elif isinstance(seed, int):
+ seed_list = [seed + i for i in range(self.env_num)]
+ else:
+ seed_list = seed
+ return [w.seed(s) for w, s in zip(self.workers, seed_list)]
def render(self, **kwargs: Any) -> List[Any]:
"""Render all of the environments."""
@@ -295,16 +298,6 @@ class DummyVectorEnv(BaseVectorEnv):
env_fns, DummyEnvWorker, wait_num=wait_num, timeout=timeout)
-class VectorEnv(DummyVectorEnv):
- """VectorEnv is renamed to DummyVectorEnv."""
-
- def __init__(self, *args: Any, **kwargs: Any) -> None:
- warnings.warn(
- "VectorEnv is renamed to DummyVectorEnv, and will be removed in "
- "0.3. Use DummyVectorEnv instead!", Warning)
- super().__init__(*args, **kwargs)
-
-
class SubprocVectorEnv(BaseVectorEnv):
"""Vectorized environment wrapper based on subprocess.
diff --git a/tianshou/env/worker/base.py b/tianshou/env/worker/base.py
index d097260..f13fe37 100644
--- a/tianshou/env/worker/base.py
+++ b/tianshou/env/worker/base.py
@@ -10,7 +10,7 @@ class EnvWorker(ABC):
def __init__(self, env_fn: Callable[[], gym.Env]) -> None:
self._env_fn = env_fn
self.is_closed = False
- self.result = (None, None, None, None)
+ self.result: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]
@abstractmethod
def __getattr__(self, key: str) -> Any:
diff --git a/tianshou/env/worker/dummy.py b/tianshou/env/worker/dummy.py
index 9e88840..9e0c3c5 100644
--- a/tianshou/env/worker/dummy.py
+++ b/tianshou/env/worker/dummy.py
@@ -19,7 +19,7 @@ class DummyEnvWorker(EnvWorker):
return self.env.reset()
@staticmethod
- def wait(
+ def wait( # type: ignore
workers: List["DummyEnvWorker"],
wait_num: int,
timeout: Optional[float] = None,
diff --git a/tianshou/env/worker/ray.py b/tianshou/env/worker/ray.py
index 5517388..165e104 100644
--- a/tianshou/env/worker/ray.py
+++ b/tianshou/env/worker/ray.py
@@ -24,7 +24,7 @@ class RayEnvWorker(EnvWorker):
return ray.get(self.env.reset.remote())
@staticmethod
- def wait(
+ def wait( # type: ignore
workers: List["RayEnvWorker"],
wait_num: int,
timeout: Optional[float] = None,
diff --git a/tianshou/env/worker/subproc.py b/tianshou/env/worker/subproc.py
index 4b280c2..b578dd7 100644
--- a/tianshou/env/worker/subproc.py
+++ b/tianshou/env/worker/subproc.py
@@ -11,22 +11,71 @@ from tianshou.env.worker import EnvWorker
from tianshou.env.utils import CloudpickleWrapper
+_NP_TO_CT = {
+ np.bool: ctypes.c_bool,
+ np.bool_: ctypes.c_bool,
+ np.uint8: ctypes.c_uint8,
+ np.uint16: ctypes.c_uint16,
+ np.uint32: ctypes.c_uint32,
+ np.uint64: ctypes.c_uint64,
+ np.int8: ctypes.c_int8,
+ np.int16: ctypes.c_int16,
+ np.int32: ctypes.c_int32,
+ np.int64: ctypes.c_int64,
+ np.float32: ctypes.c_float,
+ np.float64: ctypes.c_double,
+}
+
+
+class ShArray:
+ """Wrapper of multiprocessing Array."""
+
+ def __init__(self, dtype: np.generic, shape: Tuple[int]) -> None:
+ self.arr = Array(
+ _NP_TO_CT[dtype.type], # type: ignore
+ int(np.prod(shape)),
+ )
+ self.dtype = dtype
+ self.shape = shape
+
+ def save(self, ndarray: np.ndarray) -> None:
+ assert isinstance(ndarray, np.ndarray)
+ dst = self.arr.get_obj()
+ dst_np = np.frombuffer(dst, dtype=self.dtype).reshape(self.shape)
+ np.copyto(dst_np, ndarray)
+
+ def get(self) -> np.ndarray:
+ obj = self.arr.get_obj()
+ return np.frombuffer(obj, dtype=self.dtype).reshape(self.shape)
+
+
+def _setup_buf(space: gym.Space) -> Union[dict, tuple, ShArray]:
+ if isinstance(space, gym.spaces.Dict):
+ assert isinstance(space.spaces, OrderedDict)
+ return {k: _setup_buf(v) for k, v in space.spaces.items()}
+ elif isinstance(space, gym.spaces.Tuple):
+ assert isinstance(space.spaces, tuple)
+ return tuple([_setup_buf(t) for t in space.spaces])
+ else:
+ return ShArray(space.dtype, space.shape)
+
+
def _worker(
parent: connection.Connection,
p: connection.Connection,
env_fn_wrapper: CloudpickleWrapper,
- obs_bufs: Optional[Union[dict, tuple, "ShArray"]] = None,
+ obs_bufs: Optional[Union[dict, tuple, ShArray]] = None,
) -> None:
def _encode_obs(
obs: Union[dict, tuple, np.ndarray],
buffer: Union[dict, tuple, ShArray],
) -> None:
- if isinstance(obs, np.ndarray):
+ if isinstance(obs, np.ndarray) and isinstance(buffer, ShArray):
buffer.save(obs)
- elif isinstance(obs, tuple):
+ elif isinstance(obs, tuple) and isinstance(buffer, tuple):
for o, b in zip(obs, buffer):
_encode_obs(o, b)
- elif isinstance(obs, dict):
+ elif isinstance(obs, dict) and isinstance(buffer, dict):
for k in obs.keys():
_encode_obs(obs[k], buffer[k])
return None
@@ -69,52 +118,6 @@ def _worker(
p.close()
-_NP_TO_CT = {
- np.bool: ctypes.c_bool,
- np.bool_: ctypes.c_bool,
- np.uint8: ctypes.c_uint8,
- np.uint16: ctypes.c_uint16,
- np.uint32: ctypes.c_uint32,
- np.uint64: ctypes.c_uint64,
- np.int8: ctypes.c_int8,
- np.int16: ctypes.c_int16,
- np.int32: ctypes.c_int32,
- np.int64: ctypes.c_int64,
- np.float32: ctypes.c_float,
- np.float64: ctypes.c_double,
-}
-
-
-class ShArray:
- """Wrapper of multiprocessing Array."""
-
- def __init__(self, dtype: np.generic, shape: Tuple[int]) -> None:
- self.arr = Array(_NP_TO_CT[dtype.type], int(np.prod(shape)))
- self.dtype = dtype
- self.shape = shape
-
- def save(self, ndarray: np.ndarray) -> None:
- assert isinstance(ndarray, np.ndarray)
- dst = self.arr.get_obj()
- dst_np = np.frombuffer(dst, dtype=self.dtype).reshape(self.shape)
- np.copyto(dst_np, ndarray)
-
- def get(self) -> np.ndarray:
- obj = self.arr.get_obj()
- return np.frombuffer(obj, dtype=self.dtype).reshape(self.shape)
-
-
-def _setup_buf(space: gym.Space) -> Union[dict, tuple, ShArray]:
- if isinstance(space, gym.spaces.Dict):
- assert isinstance(space.spaces, OrderedDict)
- return {k: _setup_buf(v) for k, v in space.spaces.items()}
- elif isinstance(space, gym.spaces.Tuple):
- assert isinstance(space.spaces, tuple)
- return tuple([_setup_buf(t) for t in space.spaces])
- else:
- return ShArray(space.dtype, space.shape)
-
-
class SubprocEnvWorker(EnvWorker):
"""Subprocess worker used in SubprocVectorEnv and ShmemVectorEnv."""
@@ -124,7 +127,7 @@ class SubprocEnvWorker(EnvWorker):
super().__init__(env_fn)
self.parent_remote, self.child_remote = Pipe()
self.share_memory = share_memory
- self.buffer = None
+ self.buffer: Optional[Union[dict, tuple, ShArray]] = None
if self.share_memory:
dummy = env_fn()
obs_space = dummy.observation_space
@@ -168,25 +171,23 @@ class SubprocEnvWorker(EnvWorker):
return obs
@staticmethod
- def wait(
+ def wait( # type: ignore
workers: List["SubprocEnvWorker"],
wait_num: int,
timeout: Optional[float] = None,
) -> List["SubprocEnvWorker"]:
- conns, ready_conns = [x.parent_remote for x in workers], []
- remain_conns = conns
- t1 = time.time()
+ remain_conns = conns = [x.parent_remote for x in workers]
+ ready_conns: List[connection.Connection] = []
+ remain_time, t1 = timeout, time.time()
while len(remain_conns) > 0 and len(ready_conns) < wait_num:
if timeout:
remain_time = timeout - (time.time() - t1)
if remain_time <= 0:
break
- else:
- remain_time = timeout
# connection.wait hangs if the list is empty
new_ready_conns = connection.wait(
remain_conns, timeout=remain_time)
- ready_conns.extend(new_ready_conns)
+ ready_conns.extend(new_ready_conns) # type: ignore
remain_conns = [
conn for conn in remain_conns if conn not in ready_conns]
return [workers[conns.index(con)] for con in ready_conns]
diff --git a/tianshou/exploration/random.py b/tianshou/exploration/random.py
index f06b7b5..2e495dc 100644
--- a/tianshou/exploration/random.py
+++ b/tianshou/exploration/random.py
@@ -9,15 +9,15 @@ class BaseNoise(ABC, object):
def __init__(self) -> None:
super().__init__()
+ def reset(self) -> None:
+ """Reset to the initial state."""
+ pass
+
@abstractmethod
def __call__(self, size: Sequence[int]) -> np.ndarray:
"""Generate new noise."""
raise NotImplementedError
- def reset(self) -> None:
- """Reset to the initial state."""
- pass
-
class GaussianNoise(BaseNoise):
"""The vanilla gaussian process, for exploration in DDPG by default."""
@@ -64,6 +64,10 @@ class OUNoise(BaseNoise):
self._x0 = x0
self.reset()
+ def reset(self) -> None:
+ """Reset to the initial state."""
+ self._x = self._x0
+
def __call__(
self, size: Sequence[int], mu: Optional[float] = None
) -> np.ndarray:
@@ -71,14 +75,11 @@ class OUNoise(BaseNoise):
Return an numpy array which size is equal to ``size``.
"""
- if self._x is None or self._x.shape != size:
+ if self._x is None or isinstance(
+ self._x, np.ndarray) and self._x.shape != size:
self._x = 0.0
if mu is None:
mu = self._mu
r = self._beta * np.random.normal(size=size)
self._x = self._x + self._alpha * (mu - self._x) + r
return self._x
-
- def reset(self) -> None:
- """Reset to the initial state."""
- self._x = self._x0
diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py
index 0b00585..8bd0fcb 100644
--- a/tianshou/policy/base.py
+++ b/tianshou/policy/base.py
@@ -190,7 +190,7 @@ class BasePolicy(ABC, nn.Module):
array with shape (bsz, ).
"""
rew = batch.rew
- v_s_ = np.zeros_like(rew) if v_s_ is None else to_numpy(v_s_).flatten()
+ v_s_ = np.zeros_like(rew) if v_s_ is None else to_numpy(v_s_.flatten())
returns = _episodic_return(v_s_, rew, batch.done, gamma, gae_lambda)
if rew_norm and not np.isclose(returns.std(), 0.0, 1e-2):
returns = (returns - returns.mean()) / returns.std()
diff --git a/tianshou/policy/imitation/base.py b/tianshou/policy/imitation/base.py
index 13dda4e..d65cbc8 100644
--- a/tianshou/policy/imitation/base.py
+++ b/tianshou/policy/imitation/base.py
@@ -55,11 +55,11 @@ class ImitationPolicy(BasePolicy):
if self.mode == "continuous": # regression
a = self(batch).act
a_ = to_torch(batch.act, dtype=torch.float32, device=a.device)
- loss = F.mse_loss(a, a_)
+ loss = F.mse_loss(a, a_) # type: ignore
elif self.mode == "discrete": # classification
a = self(batch).logits
a_ = to_torch(batch.act, dtype=torch.long, device=a.device)
- loss = F.nll_loss(a, a_)
+ loss = F.nll_loss(a, a_) # type: ignore
loss.backward()
self.optim.step()
return {"loss": loss.item()}
diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py
index ae7cfa7..df34e01 100644
--- a/tianshou/policy/modelfree/a2c.py
+++ b/tianshou/policy/modelfree/a2c.py
@@ -103,11 +103,11 @@ class A2CPolicy(PGPolicy):
if isinstance(logits, tuple):
dist = self.dist_fn(*logits)
else:
- dist = self.dist_fn(logits)
+ dist = self.dist_fn(logits) # type: ignore
act = dist.sample()
return Batch(logits=logits, act=act, state=h, dist=dist)
- def learn(
+ def learn( # type: ignore
self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
) -> Dict[str, List[float]]:
losses, actor_losses, vf_losses, ent_losses = [], [], [], []
@@ -120,7 +120,7 @@ class A2CPolicy(PGPolicy):
r = to_torch_as(b.returns, v)
log_prob = dist.log_prob(a).reshape(len(r), -1).transpose(0, 1)
a_loss = -(log_prob * (r - v).detach()).mean()
- vf_loss = F.mse_loss(r, v)
+ vf_loss = F.mse_loss(r, v) # type: ignore
ent_loss = dist.entropy().mean()
loss = a_loss + self._w_vf * vf_loss - self._w_ent * ent_loss
loss.backward()
diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py
index 3770c64..f19f56e 100644
--- a/tianshou/policy/modelfree/ddpg.py
+++ b/tianshou/policy/modelfree/ddpg.py
@@ -53,14 +53,16 @@ class DDPGPolicy(BasePolicy):
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
- if actor is not None:
- self.actor, self.actor_old = actor, deepcopy(actor)
+ if actor is not None and actor_optim is not None:
+ self.actor: torch.nn.Module = actor
+ self.actor_old = deepcopy(actor)
self.actor_old.eval()
- self.actor_optim = actor_optim
- if critic is not None:
- self.critic, self.critic_old = critic, deepcopy(critic)
+ self.actor_optim: torch.optim.Optimizer = actor_optim
+ if critic is not None and critic_optim is not None:
+ self.critic: torch.nn.Module = critic
+ self.critic_old = deepcopy(critic)
self.critic_old.eval()
- self.critic_optim = critic_optim
+ self.critic_optim: torch.optim.Optimizer = critic_optim
assert 0.0 <= tau <= 1.0, "tau should be in [0, 1]"
self._tau = tau
assert 0.0 <= gamma <= 1.0, "gamma should be in [0, 1]"
@@ -141,7 +143,7 @@ class DDPGPolicy(BasePolicy):
obs = getattr(batch, input)
actions, h = model(obs, state=state, info=batch.info)
actions += self._action_bias
- if self.training and explorating:
+ if self._noise and self.training and explorating:
actions += to_torch_as(self._noise(actions.shape), actions)
actions = actions.clamp(self._range[0], self._range[1])
return Batch(act=actions, state=h)
diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py
index c6ba4fa..070f11d 100644
--- a/tianshou/policy/modelfree/dqn.py
+++ b/tianshou/policy/modelfree/dqn.py
@@ -146,11 +146,10 @@ class DQNPolicy(BasePolicy):
obs = getattr(batch, input)
obs_ = obs.obs if hasattr(obs, "obs") else obs
q, h = model(obs_, state=state, info=batch.info)
- act = to_numpy(q.max(dim=1)[1])
- has_mask = hasattr(obs, 'mask')
- if has_mask:
+ act: np.ndarray = to_numpy(q.max(dim=1)[1])
+ if hasattr(obs, "mask"):
# some of actions are masked, they cannot be selected
- q_ = to_numpy(q)
+ q_: np.ndarray = to_numpy(q)
q_[~obs.mask] = -np.inf
act = q_.argmax(axis=1)
# add eps to act
@@ -160,7 +159,7 @@ class DQNPolicy(BasePolicy):
for i in range(len(q)):
if np.random.rand() < eps:
q_ = np.random.rand(*q[i].shape)
- if has_mask:
+ if hasattr(obs, "mask"):
q_[~obs.mask[i]] = -np.inf
act[i] = q_.argmax()
return Batch(logits=q, act=act, state=h)
@@ -172,7 +171,7 @@ class DQNPolicy(BasePolicy):
weight = batch.pop("weight", 1.0)
q = self(batch, eps=0.0).logits
q = q[np.arange(len(q)), batch.act]
- r = to_torch_as(batch.returns, q).flatten()
+ r = to_torch_as(batch.returns.flatten(), q)
td = r - q
loss = (td.pow(2) * weight).mean()
batch.weight = td # prio-buffer
diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py
index e93b4f9..2f33046 100644
--- a/tianshou/policy/modelfree/pg.py
+++ b/tianshou/policy/modelfree/pg.py
@@ -32,7 +32,8 @@ class PGPolicy(BasePolicy):
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
- self.model = model
+ if model is not None:
+ self.model: torch.nn.Module = model
self.optim = optim
self.dist_fn = dist_fn
assert (
@@ -81,11 +82,11 @@ class PGPolicy(BasePolicy):
if isinstance(logits, tuple):
dist = self.dist_fn(*logits)
else:
- dist = self.dist_fn(logits)
+ dist = self.dist_fn(logits) # type: ignore
act = dist.sample()
return Batch(logits=logits, act=act, state=h, dist=dist)
- def learn(
+ def learn( # type: ignore
self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
) -> Dict[str, List[float]]:
losses = []
diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py
index 349ab9b..60bb19e 100644
--- a/tianshou/policy/modelfree/ppo.py
+++ b/tianshou/policy/modelfree/ppo.py
@@ -137,13 +137,13 @@ class PPOPolicy(PGPolicy):
if isinstance(logits, tuple):
dist = self.dist_fn(*logits)
else:
- dist = self.dist_fn(logits)
+ dist = self.dist_fn(logits) # type: ignore
act = dist.sample()
if self._range:
act = act.clamp(self._range[0], self._range[1])
return Batch(logits=logits, act=act, state=h, dist=dist)
- def learn(
+ def learn( # type: ignore
self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
) -> Dict[str, List[float]]:
losses, clip_losses, vf_losses, ent_losses = [], [], [], []
@@ -157,8 +157,9 @@ class PPOPolicy(PGPolicy):
surr2 = ratio.clamp(1.0 - self._eps_clip,
1.0 + self._eps_clip) * b.adv
if self._dual_clip:
- clip_loss = -torch.max(torch.min(surr1, surr2),
- self._dual_clip * b.adv).mean()
+ clip_loss = -torch.max(
+ torch.min(surr1, surr2), self._dual_clip * b.adv
+ ).mean()
else:
clip_loss = -torch.min(surr1, surr2).mean()
clip_losses.append(clip_loss.item())
diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py
index b596eb5..83c7150 100644
--- a/tianshou/policy/modelfree/sac.py
+++ b/tianshou/policy/modelfree/sac.py
@@ -107,7 +107,7 @@ class SACPolicy(DDPGPolicy):
):
o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
- def forward(
+ def forward( # type: ignore
self,
batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
@@ -193,5 +193,5 @@ class SACPolicy(DDPGPolicy):
}
if self._is_auto_alpha:
result["loss/alpha"] = alpha_loss.item()
- result["v/alpha"] = self._alpha.item()
+ result["v/alpha"] = self._alpha.item() # type: ignore
return result
diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py
index bbf5233..01e6530 100644
--- a/tianshou/trainer/offpolicy.py
+++ b/tianshou/trainer/offpolicy.py
@@ -73,7 +73,7 @@ def offpolicy_trainer(
"""
global_step = 0
best_epoch, best_reward = -1, -1.0
- stat = {}
+ stat: Dict[str, MovAvg] = {}
start_time = time.time()
test_in_train = test_in_train and train_collector.policy == policy
for epoch in range(1, 1 + max_epoch):
@@ -91,7 +91,7 @@ def offpolicy_trainer(
test_result = test_episode(
policy, test_collector, test_fn,
epoch, episode_per_test, writer, global_step)
- if stop_fn and stop_fn(test_result["rew"]):
+ if stop_fn(test_result["rew"]):
if save_fn:
save_fn(policy)
for k in result.keys():
diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py
index ac97ba7..37f4278 100644
--- a/tianshou/trainer/onpolicy.py
+++ b/tianshou/trainer/onpolicy.py
@@ -73,7 +73,7 @@ def onpolicy_trainer(
"""
global_step = 0
best_epoch, best_reward = -1, -1.0
- stat = {}
+ stat: Dict[str, MovAvg] = {}
start_time = time.time()
test_in_train = test_in_train and train_collector.policy == policy
for epoch in range(1, 1 + max_epoch):
@@ -91,7 +91,7 @@ def onpolicy_trainer(
test_result = test_episode(
policy, test_collector, test_fn,
epoch, episode_per_test, writer, global_step)
- if stop_fn and stop_fn(test_result["rew"]):
+ if stop_fn(test_result["rew"]):
if save_fn:
save_fn(policy)
for k in result.keys():
@@ -109,9 +109,9 @@ def onpolicy_trainer(
batch_size=batch_size, repeat=repeat_per_collect)
train_collector.reset_buffer()
step = 1
- for k in losses.keys():
- if isinstance(losses[k], list):
- step = max(step, len(losses[k]))
+ for v in losses.values():
+ if isinstance(v, list):
+ step = max(step, len(v))
global_step += step * collect_per_step
for k in result.keys():
data[k] = f"{result[k]:.2f}"
diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py
index 5f6698e..0c5d2dd 100644
--- a/tianshou/trainer/utils.py
+++ b/tianshou/trainer/utils.py
@@ -22,7 +22,7 @@ def test_episode(
policy.eval()
if test_fn:
test_fn(epoch)
- if collector.get_env_num() > 1 and np.isscalar(n_episode):
+ if collector.get_env_num() > 1 and isinstance(n_episode, int):
n = collector.get_env_num()
n_ = np.zeros(n) + n_episode // n
n_[:n_episode % n] += 1
diff --git a/tianshou/utils/moving_average.py b/tianshou/utils/moving_average.py
index 58c6792..58e8860 100644
--- a/tianshou/utils/moving_average.py
+++ b/tianshou/utils/moving_average.py
@@ -1,7 +1,7 @@
import torch
import numpy as np
from numbers import Number
-from typing import Union
+from typing import List, Union
from tianshou.data import to_numpy
@@ -28,7 +28,7 @@ class MovAvg(object):
def __init__(self, size: int = 100) -> None:
super().__init__()
self.size = size
- self.cache = []
+ self.cache: List[Union[Number, np.number]] = []
self.banned = [np.inf, np.nan, -np.inf]
def add(
diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py
index 96023fe..8b52a87 100644
--- a/tianshou/utils/net/common.py
+++ b/tianshou/utils/net/common.py
@@ -12,7 +12,7 @@ def miniblock(
norm_layer: Optional[Callable[[int], nn.modules.Module]],
) -> List[nn.modules.Module]:
"""Construct a miniblock with given input/output-size and norm layer."""
- ret = [nn.Linear(inp, oup)]
+ ret: List[nn.modules.Module] = [nn.Linear(inp, oup)]
if norm_layer is not None:
ret += [norm_layer(oup)]
ret += [nn.ReLU(inplace=True)]
@@ -54,36 +54,33 @@ class Net(nn.Module):
if concat:
input_size += np.prod(action_shape)
- self.model = miniblock(input_size, hidden_layer_size, norm_layer)
+ model = miniblock(input_size, hidden_layer_size, norm_layer)
for i in range(layer_num):
- self.model += miniblock(hidden_layer_size,
- hidden_layer_size, norm_layer)
+ model += miniblock(
+ hidden_layer_size, hidden_layer_size, norm_layer)
- if self.dueling is None:
+ if dueling is None:
if action_shape and not concat:
- self.model += [nn.Linear(hidden_layer_size,
- np.prod(action_shape))]
+ model += [nn.Linear(hidden_layer_size, np.prod(action_shape))]
else: # dueling DQN
- assert isinstance(self.dueling, tuple) and len(self.dueling) == 2
-
- q_layer_num, v_layer_num = self.dueling
- self.Q, self.V = [], []
+ q_layer_num, v_layer_num = dueling
+ Q, V = [], []
for i in range(q_layer_num):
- self.Q += miniblock(hidden_layer_size,
- hidden_layer_size, norm_layer)
+ Q += miniblock(
+ hidden_layer_size, hidden_layer_size, norm_layer)
for i in range(v_layer_num):
- self.V += miniblock(hidden_layer_size,
- hidden_layer_size, norm_layer)
+ V += miniblock(
+ hidden_layer_size, hidden_layer_size, norm_layer)
if action_shape and not concat:
- self.Q += [nn.Linear(hidden_layer_size, np.prod(action_shape))]
- self.V += [nn.Linear(hidden_layer_size, 1)]
+ Q += [nn.Linear(hidden_layer_size, np.prod(action_shape))]
+ V += [nn.Linear(hidden_layer_size, 1)]
- self.Q = nn.Sequential(*self.Q)
- self.V = nn.Sequential(*self.V)
- self.model = nn.Sequential(*self.model)
+ self.Q = nn.Sequential(*Q)
+ self.V = nn.Sequential(*V)
+ self.model = nn.Sequential(*model)
def forward(
self,
diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py
index 85f03d2..7afd9eb 100644
--- a/tianshou/utils/net/continuous.py
+++ b/tianshou/utils/net/continuous.py
@@ -66,7 +66,7 @@ class Critic(nn.Module):
s = to_torch(s, device=self.device, dtype=torch.float32)
s = s.flatten(1)
if a is not None:
- a = to_torch(a, device=self.device, dtype=torch.float32)
+ a = to_torch_as(a, s)
a = a.flatten(1)
s = torch.cat([s, a], dim=1)
logits, h = self.preprocess(s)
diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py
index 6734316..58e44a7 100644
--- a/tianshou/utils/net/discrete.py
+++ b/tianshou/utils/net/discrete.py
@@ -4,6 +4,8 @@ from torch import nn
import torch.nn.functional as F
from typing import Any, Dict, Tuple, Union, Optional, Sequence
+from tianshou.data import to_torch
+
class Actor(nn.Module):
"""Simple actor network with MLP.
@@ -118,5 +120,5 @@ class DQN(nn.Module):
) -> Tuple[torch.Tensor, Any]:
r"""Mapping: x -> Q(x, \*)."""
if not isinstance(x, torch.Tensor):
- x = torch.tensor(x, device=self.device, dtype=torch.float32)
+ x = to_torch(x, device=self.device, dtype=torch.float32)
return self.net(x), state