type check in unit test (#200)

Fix #195: Add mypy test in .github/workflows/docs_and_lint.yml.

Also remove the out-of-date API.
This commit is contained in:
n+e 2020-09-13 19:31:50 +08:00 committed by GitHub
parent c91def6cbc
commit b284ace102
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
32 changed files with 224 additions and 249 deletions

View File

@ -1,4 +1,4 @@
name: PEP8 and Docs Check
name: PEP8, Types and Docs Check
on: [push, pull_request]
@ -20,6 +20,9 @@ jobs:
- name: Lint with flake8
run: |
flake8 . --count --show-source --statistics
- name: Type check
run: |
mypy
- name: Documentation test
run: |
pydocstyle tianshou

View File

@ -28,6 +28,16 @@ We follow PEP8 python code style. To check, in the main directory, run:
$ flake8 . --count --show-source --statistics
Type Check
----------
We use `mypy <https://github.com/python/mypy/>`_ to check the type annotations. To check, in the main directory, run:
.. code-block:: bash
$ mypy
Test Locally
------------

View File

@ -31,6 +31,7 @@ Here is Tianshou's other features:
* Support :ref:`customize_training`
* Support n-step returns estimation :meth:`~tianshou.policy.BasePolicy.compute_nstep_return` and prioritized experience replay :class:`~tianshou.data.PrioritizedReplayBuffer` for all Q-learning based algorithms; GAE, nstep and PER are very fast thanks to numba jit function and vectorized numpy operation
* Support :doc:`/tutorials/tictactoe`
* Comprehensive `unit tests <https://github.com/thu-ml/tianshou/actions>`_, including functional checking, RL pipeline checking, documentation checking, PEP8 code-style checking, and type checking
中文文档位于 `https://tianshou.readthedocs.io/zh/latest/ <https://tianshou.readthedocs.io/zh/latest/>`_

View File

@ -1,3 +1,26 @@
[mypy]
files = tianshou/**/*.py
allow_redefinition = True
check_untyped_defs = True
disallow_incomplete_defs = True
disallow_untyped_defs = True
ignore_missing_imports = True
no_implicit_optional = True
pretty = True
show_error_codes = True
show_error_context = True
show_traceback = True
strict_equality = True
strict_optional = True
warn_no_return = True
warn_redundant_casts = True
warn_unreachable = True
warn_unused_configs = True
warn_unused_ignores = True
[mypy-tianshou.utils.net.*]
ignore_errors = True
[pydocstyle]
ignore = D100,D102,D104,D105,D107,D203,D213,D401,D402

View File

@ -100,11 +100,6 @@ def test_collect_ep(data):
data["collector"].collect(n_episode=10)
def test_sample(data):
for _ in range(5000):
data["collector"].sample(256)
def test_init_vec_env(data):
for _ in range(5000):
Collector(data["policy"], data["env_vec"], data["buffer"])
@ -125,11 +120,6 @@ def test_collect_vec_env_ep(data):
data["collector_vec"].collect(n_episode=10)
def test_sample_vec_env(data):
for _ in range(5000):
data["collector_vec"].sample(256)
def test_init_subproc_env(data):
for _ in range(5000):
Collector(data["policy"], data["env_subproc_init"], data["buffer"])
@ -150,10 +140,5 @@ def test_collect_subproc_env_ep(data):
data["collector_subproc"].collect(n_episode=10)
def test_sample_subproc_env(data):
for _ in range(5000):
data["collector_subproc"].sample(256)
if __name__ == '__main__':
pytest.main(["-s", "-k collector_profile", "--durations=0", "-v"])

View File

@ -1,7 +1,7 @@
from tianshou import data, env, utils, policy, trainer, exploration
__version__ = "0.2.7"
__version__ = "0.3.0rc0"
__all__ = [
"env",

View File

@ -6,7 +6,7 @@ from copy import deepcopy
from numbers import Number
from collections.abc import Collection
from typing import Any, List, Dict, Union, Iterator, Optional, Iterable, \
Sequence, KeysView, ValuesView, ItemsView
Sequence
# Disable pickle warning related to torch, since it has been removed
# on torch master branch. See Pull Request #39003 for details:
@ -144,7 +144,7 @@ def _parse_value(v: Any) -> Optional[Union["Batch", np.ndarray, torch.Tensor]]:
if not isinstance(v, np.ndarray) and isinstance(v, Collection) and \
len(v) > 0 and all(isinstance(e, torch.Tensor) for e in v):
try:
return torch.stack(v)
return torch.stack(v) # type: ignore
except RuntimeError as e:
raise TypeError("Batch does not support non-stackable iterable"
" of torch.Tensor as unique value yet.") from e
@ -191,12 +191,20 @@ class Batch:
elif _is_batch_set(batch_dict):
self.stack_(batch_dict)
if len(kwargs) > 0:
self.__init__(kwargs, copy=copy)
self.__init__(kwargs, copy=copy) # type: ignore
def __setattr__(self, key: str, value: Any) -> None:
"""Set self.key = value."""
self.__dict__[key] = _parse_value(value)
def __getattr__(self, key: str) -> Any:
"""Return self.key. The "Any" return type is needed for mypy."""
return getattr(self.__dict__, key)
def __contains__(self, key: str) -> bool:
"""Return key in self."""
return key in self.__dict__
def __getstate__(self) -> Dict[str, Any]:
"""Pickling interface.
@ -215,11 +223,11 @@ class Batch:
At this point, self is an empty Batch instance that has not been
initialized, so it can safely be initialized by the pickle state.
"""
self.__init__(**state)
self.__init__(**state) # type: ignore
def __getitem__(
self, index: Union[str, slice, int, np.integer, np.ndarray, List[int]]
) -> Union["Batch", np.ndarray, torch.Tensor]:
) -> Any:
"""Return self[index]."""
if isinstance(index, str):
return self.__dict__[index]
@ -245,7 +253,7 @@ class Batch:
if isinstance(index, str):
self.__dict__[index] = value
return
if isinstance(value, (np.ndarray, torch.Tensor)):
if not isinstance(value, Batch):
raise ValueError("Batch does not supported tensor assignment. "
"Use a compatible Batch or dict instead.")
if not set(value.keys()).issubset(self.__dict__.keys()):
@ -330,30 +338,6 @@ class Batch:
s = self.__class__.__name__ + "()"
return s
def __contains__(self, key: str) -> bool:
"""Return key in self."""
return key in self.__dict__
def keys(self) -> KeysView[str]:
"""Return self.keys()."""
return self.__dict__.keys()
def values(self) -> ValuesView[Any]:
"""Return self.values()."""
return self.__dict__.values()
def items(self) -> ItemsView[str, Any]:
"""Return self.items()."""
return self.__dict__.items()
def get(self, k: str, d: Optional[Any] = None) -> Any:
"""Return self[k] if k in self else d. d defaults to None."""
return self.__dict__.get(k, d)
def pop(self, k: str, d: Optional[Any] = None) -> Any:
"""Return & remove self[k] if k in self else d. d defaults to None."""
return self.__dict__.pop(k, d)
def to_numpy(self) -> None:
"""Change all torch.Tensor to numpy.ndarray in-place."""
for k, v in self.items():
@ -375,7 +359,6 @@ class Batch:
if isinstance(v, torch.Tensor):
if dtype is not None and v.dtype != dtype or \
v.device.type != device.type or \
device.index is not None and \
device.index != v.device.index:
if dtype is not None:
v = v.type(dtype)
@ -517,7 +500,7 @@ class Batch:
return
batches = [x if isinstance(x, Batch) else Batch(x) for x in batches]
if not self.is_empty():
batches = [self] + list(batches)
batches = [self] + batches
# collect non-empty keys
keys_map = [
set(k for k, v in batch.items()
@ -672,8 +655,8 @@ class Batch:
for v in self.__dict__.values():
if isinstance(v, Batch) and v.is_empty(recurse=True):
continue
elif hasattr(v, "__len__") and (not isinstance(
v, (np.ndarray, torch.Tensor)) or v.ndim > 0
elif hasattr(v, "__len__") and (
isinstance(v, Batch) or v.ndim > 0
):
r.append(len(v))
else:

View File

@ -1,7 +1,7 @@
import torch
import numpy as np
from numbers import Number
from typing import Any, Dict, Tuple, Union, Optional
from typing import Any, Dict, List, Tuple, Union, Optional
from tianshou.data import Batch, SegmentTree, to_numpy
from tianshou.data.batch import _create_value
@ -138,7 +138,7 @@ class ReplayBuffer:
self._indices = np.arange(size)
self.stack_num = stack_num
self._avail = sample_avail and stack_num > 1
self._avail_index = []
self._avail_index: List[int] = []
self._save_s_ = not ignore_obs_next
self._last_obs = save_only_last_obs
self._index = 0
@ -175,12 +175,12 @@ class ReplayBuffer:
except KeyError:
self._meta.__dict__[name] = _create_value(inst, self._maxsize)
value = self._meta.__dict__[name]
if isinstance(inst, (torch.Tensor, np.ndarray)) \
and inst.shape != value.shape[1:]:
raise ValueError(
"Cannot add data to a buffer with different shape, with key "
f"{name}, expect {value.shape[1:]}, given {inst.shape}."
)
if isinstance(inst, (torch.Tensor, np.ndarray)):
if inst.shape != value.shape[1:]:
raise ValueError(
"Cannot add data to a buffer with different shape with key"
f" {name}, expect {value.shape[1:]}, given {inst.shape}."
)
try:
value[self._index] = inst
except KeyError:
@ -205,7 +205,7 @@ class ReplayBuffer:
stack_num_orig = buffer.stack_num
buffer.stack_num = 1
while True:
self.add(**buffer[i])
self.add(**buffer[i]) # type: ignore
i = (i + 1) % len(buffer)
if i == begin:
break
@ -323,7 +323,7 @@ class ReplayBuffer:
try:
if stack_num == 1:
return val[indice]
stack = []
stack: List[Any] = []
for _ in range(stack_num):
stack = [val[indice]] + stack
pre_indice = np.asarray(indice - 1)

View File

@ -212,10 +212,8 @@ class Collector(object):
finished_env_ids = []
reward_total = 0.0
whole_data = Batch()
list_n_episode = False
if n_episode is not None and not np.isscalar(n_episode):
if isinstance(n_episode, list):
assert len(n_episode) == self.get_env_num()
list_n_episode = True
finished_env_ids = [
i for i in self._ready_env_ids if n_episode[i] <= 0]
self._ready_env_ids = np.array(
@ -266,7 +264,8 @@ class Collector(object):
self.data.policy._state = self.data.state
self.data.act = to_numpy(result.act)
if self._action_noise is not None: # noqa
if self._action_noise is not None:
assert isinstance(self.data.act, np.ndarray)
self.data.act += self._action_noise(self.data.act.shape)
# step in env
@ -291,7 +290,7 @@ class Collector(object):
# add data into the buffer
if self.preprocess_fn:
result = self.preprocess_fn(**self.data)
result = self.preprocess_fn(**self.data) # type: ignore
self.data.update(result)
for j, i in enumerate(self._ready_env_ids):
@ -305,14 +304,14 @@ class Collector(object):
self._cached_buf[i].add(**self.data[j])
if done[j]:
if not (list_n_episode and
episode_count[i] >= n_episode[i]):
if not (isinstance(n_episode, list)
and episode_count[i] >= n_episode[i]):
episode_count[i] += 1
reward_total += np.sum(self._cached_buf[i].rew, axis=0)
step_count += len(self._cached_buf[i])
if self.buffer is not None:
self.buffer.update(self._cached_buf[i])
if list_n_episode and \
if isinstance(n_episode, list) and \
episode_count[i] >= n_episode[i]:
# env i has collected enough data, it has finished
finished_env_ids.append(i)
@ -324,10 +323,9 @@ class Collector(object):
env_ind_global = self._ready_env_ids[env_ind_local]
obs_reset = self.env.reset(env_ind_global)
if self.preprocess_fn:
obs_next[env_ind_local] = self.preprocess_fn(
obs_reset = self.preprocess_fn(
obs=obs_reset).get("obs", obs_reset)
else:
obs_next[env_ind_local] = obs_reset
obs_next[env_ind_local] = obs_reset
self.data.obs = obs_next
if is_async:
# set data back
@ -362,7 +360,7 @@ class Collector(object):
# average reward across the number of episodes
reward_avg = reward_total / episode_count
if np.asanyarray(reward_avg).size > 1: # non-scalar reward_avg
reward_avg = self._rew_metric(reward_avg)
reward_avg = self._rew_metric(reward_avg) # type: ignore
return {
"n/ep": episode_count,
"n/st": step_count,
@ -372,30 +370,6 @@ class Collector(object):
"len": step_count / episode_count,
}
def sample(self, batch_size: int) -> Batch:
"""Sample a data batch from the internal replay buffer.
It will call :meth:`~tianshou.policy.BasePolicy.process_fn` before
returning the final batch data.
:param int batch_size: ``0`` means it will extract all the data from
the buffer, otherwise it will extract the data with the given
batch_size.
"""
warnings.warn(
"Collector.sample is deprecated and will cause error if you use "
"prioritized experience replay! Collector.sample will be removed "
"upon version 0.3. Use policy.update instead!", Warning)
assert self.buffer is not None, "Cannot get sample from empty buffer!"
batch_data, indice = self.buffer.sample(batch_size)
batch_data = self.process_fn(batch_data, self.buffer, indice)
return batch_data
def close(self) -> None:
warnings.warn(
"Collector.close is deprecated and will be removed upon version "
"0.3.", Warning)
def _batch_set_item(
source: Batch, indices: np.ndarray, target: Batch, size: int

View File

@ -45,14 +45,14 @@ def to_torch(
if isinstance(x, np.ndarray) and issubclass(
x.dtype.type, (np.bool_, np.number)
): # most often case
x = torch.from_numpy(x).to(device)
x = torch.from_numpy(x).to(device) # type: ignore
if dtype is not None:
x = x.type(dtype)
return x
elif isinstance(x, torch.Tensor): # second often case
if dtype is not None:
x = x.type(dtype)
return x.to(device)
return x.to(device) # type: ignore
elif isinstance(x, (np.number, np.bool_, Number)):
return to_torch(np.asanyarray(x), dtype, device)
elif isinstance(x, dict):

View File

@ -1,11 +1,10 @@
from tianshou.env.venvs import BaseVectorEnv, DummyVectorEnv, VectorEnv, \
from tianshou.env.venvs import BaseVectorEnv, DummyVectorEnv, \
SubprocVectorEnv, ShmemVectorEnv, RayVectorEnv
from tianshou.env.maenv import MultiAgentEnv
__all__ = [
"BaseVectorEnv",
"DummyVectorEnv",
"VectorEnv", # TODO: remove in later version
"SubprocVectorEnv",
"ShmemVectorEnv",
"RayVectorEnv",

29
tianshou/env/venvs.py vendored
View File

@ -1,5 +1,4 @@
import gym
import warnings
import numpy as np
from typing import Any, List, Union, Optional, Callable
@ -84,12 +83,12 @@ class BaseVectorEnv(gym.Env):
self.timeout is None or self.timeout > 0
), f"timeout is {timeout}, it should be positive if provided!"
self.is_async = self.wait_num != len(env_fns) or timeout is not None
self.waiting_conn = []
self.waiting_conn: List[EnvWorker] = []
# environments in self.ready_id is actually ready
# but environments in self.waiting_id are just waiting when checked,
# and they may be ready now, but this is not known until we check it
# in the step() function
self.waiting_id = []
self.waiting_id: List[int] = []
# all environments are ready in the beginning
self.ready_id = list(range(self.env_num))
self.is_closed = False
@ -216,10 +215,11 @@ class BaseVectorEnv(gym.Env):
self.waiting_conn.append(self.workers[env_id])
self.waiting_id.append(env_id)
self.ready_id = [x for x in self.ready_id if x not in id]
ready_conns, result = [], []
ready_conns: List[EnvWorker] = []
while not ready_conns:
ready_conns = self.worker_class.wait(
self.waiting_conn, self.wait_num, self.timeout)
result = []
for conn in ready_conns:
waiting_index = self.waiting_conn.index(conn)
self.waiting_conn.pop(waiting_index)
@ -243,11 +243,14 @@ class BaseVectorEnv(gym.Env):
which a reproducer pass to "seed".
"""
self._assert_is_not_closed()
seed_list: Union[List[None], List[int]]
if seed is None:
seed = [seed] * self.env_num
elif np.isscalar(seed):
seed = [seed + i for i in range(self.env_num)]
return [w.seed(s) for w, s in zip(self.workers, seed)]
seed_list = [seed] * self.env_num
elif isinstance(seed, int):
seed_list = [seed + i for i in range(self.env_num)]
else:
seed_list = seed
return [w.seed(s) for w, s in zip(self.workers, seed_list)]
def render(self, **kwargs: Any) -> List[Any]:
"""Render all of the environments."""
@ -295,16 +298,6 @@ class DummyVectorEnv(BaseVectorEnv):
env_fns, DummyEnvWorker, wait_num=wait_num, timeout=timeout)
class VectorEnv(DummyVectorEnv):
"""VectorEnv is renamed to DummyVectorEnv."""
def __init__(self, *args: Any, **kwargs: Any) -> None:
warnings.warn(
"VectorEnv is renamed to DummyVectorEnv, and will be removed in "
"0.3. Use DummyVectorEnv instead!", Warning)
super().__init__(*args, **kwargs)
class SubprocVectorEnv(BaseVectorEnv):
"""Vectorized environment wrapper based on subprocess.

View File

@ -10,7 +10,7 @@ class EnvWorker(ABC):
def __init__(self, env_fn: Callable[[], gym.Env]) -> None:
self._env_fn = env_fn
self.is_closed = False
self.result = (None, None, None, None)
self.result: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]
@abstractmethod
def __getattr__(self, key: str) -> Any:

View File

@ -19,7 +19,7 @@ class DummyEnvWorker(EnvWorker):
return self.env.reset()
@staticmethod
def wait(
def wait( # type: ignore
workers: List["DummyEnvWorker"],
wait_num: int,
timeout: Optional[float] = None,

View File

@ -24,7 +24,7 @@ class RayEnvWorker(EnvWorker):
return ray.get(self.env.reset.remote())
@staticmethod
def wait(
def wait( # type: ignore
workers: List["RayEnvWorker"],
wait_num: int,
timeout: Optional[float] = None,

View File

@ -11,22 +11,71 @@ from tianshou.env.worker import EnvWorker
from tianshou.env.utils import CloudpickleWrapper
_NP_TO_CT = {
np.bool: ctypes.c_bool,
np.bool_: ctypes.c_bool,
np.uint8: ctypes.c_uint8,
np.uint16: ctypes.c_uint16,
np.uint32: ctypes.c_uint32,
np.uint64: ctypes.c_uint64,
np.int8: ctypes.c_int8,
np.int16: ctypes.c_int16,
np.int32: ctypes.c_int32,
np.int64: ctypes.c_int64,
np.float32: ctypes.c_float,
np.float64: ctypes.c_double,
}
class ShArray:
"""Wrapper of multiprocessing Array."""
def __init__(self, dtype: np.generic, shape: Tuple[int]) -> None:
self.arr = Array(
_NP_TO_CT[dtype.type], # type: ignore
int(np.prod(shape)),
)
self.dtype = dtype
self.shape = shape
def save(self, ndarray: np.ndarray) -> None:
assert isinstance(ndarray, np.ndarray)
dst = self.arr.get_obj()
dst_np = np.frombuffer(dst, dtype=self.dtype).reshape(self.shape)
np.copyto(dst_np, ndarray)
def get(self) -> np.ndarray:
obj = self.arr.get_obj()
return np.frombuffer(obj, dtype=self.dtype).reshape(self.shape)
def _setup_buf(space: gym.Space) -> Union[dict, tuple, ShArray]:
if isinstance(space, gym.spaces.Dict):
assert isinstance(space.spaces, OrderedDict)
return {k: _setup_buf(v) for k, v in space.spaces.items()}
elif isinstance(space, gym.spaces.Tuple):
assert isinstance(space.spaces, tuple)
return tuple([_setup_buf(t) for t in space.spaces])
else:
return ShArray(space.dtype, space.shape)
def _worker(
parent: connection.Connection,
p: connection.Connection,
env_fn_wrapper: CloudpickleWrapper,
obs_bufs: Optional[Union[dict, tuple, "ShArray"]] = None,
obs_bufs: Optional[Union[dict, tuple, ShArray]] = None,
) -> None:
def _encode_obs(
obs: Union[dict, tuple, np.ndarray],
buffer: Union[dict, tuple, ShArray],
) -> None:
if isinstance(obs, np.ndarray):
if isinstance(obs, np.ndarray) and isinstance(buffer, ShArray):
buffer.save(obs)
elif isinstance(obs, tuple):
elif isinstance(obs, tuple) and isinstance(buffer, tuple):
for o, b in zip(obs, buffer):
_encode_obs(o, b)
elif isinstance(obs, dict):
elif isinstance(obs, dict) and isinstance(buffer, dict):
for k in obs.keys():
_encode_obs(obs[k], buffer[k])
return None
@ -69,52 +118,6 @@ def _worker(
p.close()
_NP_TO_CT = {
np.bool: ctypes.c_bool,
np.bool_: ctypes.c_bool,
np.uint8: ctypes.c_uint8,
np.uint16: ctypes.c_uint16,
np.uint32: ctypes.c_uint32,
np.uint64: ctypes.c_uint64,
np.int8: ctypes.c_int8,
np.int16: ctypes.c_int16,
np.int32: ctypes.c_int32,
np.int64: ctypes.c_int64,
np.float32: ctypes.c_float,
np.float64: ctypes.c_double,
}
class ShArray:
"""Wrapper of multiprocessing Array."""
def __init__(self, dtype: np.generic, shape: Tuple[int]) -> None:
self.arr = Array(_NP_TO_CT[dtype.type], int(np.prod(shape)))
self.dtype = dtype
self.shape = shape
def save(self, ndarray: np.ndarray) -> None:
assert isinstance(ndarray, np.ndarray)
dst = self.arr.get_obj()
dst_np = np.frombuffer(dst, dtype=self.dtype).reshape(self.shape)
np.copyto(dst_np, ndarray)
def get(self) -> np.ndarray:
obj = self.arr.get_obj()
return np.frombuffer(obj, dtype=self.dtype).reshape(self.shape)
def _setup_buf(space: gym.Space) -> Union[dict, tuple, ShArray]:
if isinstance(space, gym.spaces.Dict):
assert isinstance(space.spaces, OrderedDict)
return {k: _setup_buf(v) for k, v in space.spaces.items()}
elif isinstance(space, gym.spaces.Tuple):
assert isinstance(space.spaces, tuple)
return tuple([_setup_buf(t) for t in space.spaces])
else:
return ShArray(space.dtype, space.shape)
class SubprocEnvWorker(EnvWorker):
"""Subprocess worker used in SubprocVectorEnv and ShmemVectorEnv."""
@ -124,7 +127,7 @@ class SubprocEnvWorker(EnvWorker):
super().__init__(env_fn)
self.parent_remote, self.child_remote = Pipe()
self.share_memory = share_memory
self.buffer = None
self.buffer: Optional[Union[dict, tuple, ShArray]] = None
if self.share_memory:
dummy = env_fn()
obs_space = dummy.observation_space
@ -168,25 +171,23 @@ class SubprocEnvWorker(EnvWorker):
return obs
@staticmethod
def wait(
def wait( # type: ignore
workers: List["SubprocEnvWorker"],
wait_num: int,
timeout: Optional[float] = None,
) -> List["SubprocEnvWorker"]:
conns, ready_conns = [x.parent_remote for x in workers], []
remain_conns = conns
t1 = time.time()
remain_conns = conns = [x.parent_remote for x in workers]
ready_conns: List[connection.Connection] = []
remain_time, t1 = timeout, time.time()
while len(remain_conns) > 0 and len(ready_conns) < wait_num:
if timeout:
remain_time = timeout - (time.time() - t1)
if remain_time <= 0:
break
else:
remain_time = timeout
# connection.wait hangs if the list is empty
new_ready_conns = connection.wait(
remain_conns, timeout=remain_time)
ready_conns.extend(new_ready_conns)
ready_conns.extend(new_ready_conns) # type: ignore
remain_conns = [
conn for conn in remain_conns if conn not in ready_conns]
return [workers[conns.index(con)] for con in ready_conns]

View File

@ -9,15 +9,15 @@ class BaseNoise(ABC, object):
def __init__(self) -> None:
super().__init__()
def reset(self) -> None:
"""Reset to the initial state."""
pass
@abstractmethod
def __call__(self, size: Sequence[int]) -> np.ndarray:
"""Generate new noise."""
raise NotImplementedError
def reset(self) -> None:
"""Reset to the initial state."""
pass
class GaussianNoise(BaseNoise):
"""The vanilla gaussian process, for exploration in DDPG by default."""
@ -64,6 +64,10 @@ class OUNoise(BaseNoise):
self._x0 = x0
self.reset()
def reset(self) -> None:
"""Reset to the initial state."""
self._x = self._x0
def __call__(
self, size: Sequence[int], mu: Optional[float] = None
) -> np.ndarray:
@ -71,14 +75,11 @@ class OUNoise(BaseNoise):
Return an numpy array which size is equal to ``size``.
"""
if self._x is None or self._x.shape != size:
if self._x is None or isinstance(
self._x, np.ndarray) and self._x.shape != size:
self._x = 0.0
if mu is None:
mu = self._mu
r = self._beta * np.random.normal(size=size)
self._x = self._x + self._alpha * (mu - self._x) + r
return self._x
def reset(self) -> None:
"""Reset to the initial state."""
self._x = self._x0

View File

@ -190,7 +190,7 @@ class BasePolicy(ABC, nn.Module):
array with shape (bsz, ).
"""
rew = batch.rew
v_s_ = np.zeros_like(rew) if v_s_ is None else to_numpy(v_s_).flatten()
v_s_ = np.zeros_like(rew) if v_s_ is None else to_numpy(v_s_.flatten())
returns = _episodic_return(v_s_, rew, batch.done, gamma, gae_lambda)
if rew_norm and not np.isclose(returns.std(), 0.0, 1e-2):
returns = (returns - returns.mean()) / returns.std()

View File

@ -55,11 +55,11 @@ class ImitationPolicy(BasePolicy):
if self.mode == "continuous": # regression
a = self(batch).act
a_ = to_torch(batch.act, dtype=torch.float32, device=a.device)
loss = F.mse_loss(a, a_)
loss = F.mse_loss(a, a_) # type: ignore
elif self.mode == "discrete": # classification
a = self(batch).logits
a_ = to_torch(batch.act, dtype=torch.long, device=a.device)
loss = F.nll_loss(a, a_)
loss = F.nll_loss(a, a_) # type: ignore
loss.backward()
self.optim.step()
return {"loss": loss.item()}

View File

@ -103,11 +103,11 @@ class A2CPolicy(PGPolicy):
if isinstance(logits, tuple):
dist = self.dist_fn(*logits)
else:
dist = self.dist_fn(logits)
dist = self.dist_fn(logits) # type: ignore
act = dist.sample()
return Batch(logits=logits, act=act, state=h, dist=dist)
def learn(
def learn( # type: ignore
self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
) -> Dict[str, List[float]]:
losses, actor_losses, vf_losses, ent_losses = [], [], [], []
@ -120,7 +120,7 @@ class A2CPolicy(PGPolicy):
r = to_torch_as(b.returns, v)
log_prob = dist.log_prob(a).reshape(len(r), -1).transpose(0, 1)
a_loss = -(log_prob * (r - v).detach()).mean()
vf_loss = F.mse_loss(r, v)
vf_loss = F.mse_loss(r, v) # type: ignore
ent_loss = dist.entropy().mean()
loss = a_loss + self._w_vf * vf_loss - self._w_ent * ent_loss
loss.backward()

View File

@ -53,14 +53,16 @@ class DDPGPolicy(BasePolicy):
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
if actor is not None:
self.actor, self.actor_old = actor, deepcopy(actor)
if actor is not None and actor_optim is not None:
self.actor: torch.nn.Module = actor
self.actor_old = deepcopy(actor)
self.actor_old.eval()
self.actor_optim = actor_optim
if critic is not None:
self.critic, self.critic_old = critic, deepcopy(critic)
self.actor_optim: torch.optim.Optimizer = actor_optim
if critic is not None and critic_optim is not None:
self.critic: torch.nn.Module = critic
self.critic_old = deepcopy(critic)
self.critic_old.eval()
self.critic_optim = critic_optim
self.critic_optim: torch.optim.Optimizer = critic_optim
assert 0.0 <= tau <= 1.0, "tau should be in [0, 1]"
self._tau = tau
assert 0.0 <= gamma <= 1.0, "gamma should be in [0, 1]"
@ -141,7 +143,7 @@ class DDPGPolicy(BasePolicy):
obs = getattr(batch, input)
actions, h = model(obs, state=state, info=batch.info)
actions += self._action_bias
if self.training and explorating:
if self._noise and self.training and explorating:
actions += to_torch_as(self._noise(actions.shape), actions)
actions = actions.clamp(self._range[0], self._range[1])
return Batch(act=actions, state=h)

View File

@ -146,11 +146,10 @@ class DQNPolicy(BasePolicy):
obs = getattr(batch, input)
obs_ = obs.obs if hasattr(obs, "obs") else obs
q, h = model(obs_, state=state, info=batch.info)
act = to_numpy(q.max(dim=1)[1])
has_mask = hasattr(obs, 'mask')
if has_mask:
act: np.ndarray = to_numpy(q.max(dim=1)[1])
if hasattr(obs, "mask"):
# some of actions are masked, they cannot be selected
q_ = to_numpy(q)
q_: np.ndarray = to_numpy(q)
q_[~obs.mask] = -np.inf
act = q_.argmax(axis=1)
# add eps to act
@ -160,7 +159,7 @@ class DQNPolicy(BasePolicy):
for i in range(len(q)):
if np.random.rand() < eps:
q_ = np.random.rand(*q[i].shape)
if has_mask:
if hasattr(obs, "mask"):
q_[~obs.mask[i]] = -np.inf
act[i] = q_.argmax()
return Batch(logits=q, act=act, state=h)
@ -172,7 +171,7 @@ class DQNPolicy(BasePolicy):
weight = batch.pop("weight", 1.0)
q = self(batch, eps=0.0).logits
q = q[np.arange(len(q)), batch.act]
r = to_torch_as(batch.returns, q).flatten()
r = to_torch_as(batch.returns.flatten(), q)
td = r - q
loss = (td.pow(2) * weight).mean()
batch.weight = td # prio-buffer

View File

@ -32,7 +32,8 @@ class PGPolicy(BasePolicy):
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.model = model
if model is not None:
self.model: torch.nn.Module = model
self.optim = optim
self.dist_fn = dist_fn
assert (
@ -81,11 +82,11 @@ class PGPolicy(BasePolicy):
if isinstance(logits, tuple):
dist = self.dist_fn(*logits)
else:
dist = self.dist_fn(logits)
dist = self.dist_fn(logits) # type: ignore
act = dist.sample()
return Batch(logits=logits, act=act, state=h, dist=dist)
def learn(
def learn( # type: ignore
self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
) -> Dict[str, List[float]]:
losses = []

View File

@ -137,13 +137,13 @@ class PPOPolicy(PGPolicy):
if isinstance(logits, tuple):
dist = self.dist_fn(*logits)
else:
dist = self.dist_fn(logits)
dist = self.dist_fn(logits) # type: ignore
act = dist.sample()
if self._range:
act = act.clamp(self._range[0], self._range[1])
return Batch(logits=logits, act=act, state=h, dist=dist)
def learn(
def learn( # type: ignore
self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
) -> Dict[str, List[float]]:
losses, clip_losses, vf_losses, ent_losses = [], [], [], []
@ -157,8 +157,9 @@ class PPOPolicy(PGPolicy):
surr2 = ratio.clamp(1.0 - self._eps_clip,
1.0 + self._eps_clip) * b.adv
if self._dual_clip:
clip_loss = -torch.max(torch.min(surr1, surr2),
self._dual_clip * b.adv).mean()
clip_loss = -torch.max(
torch.min(surr1, surr2), self._dual_clip * b.adv
).mean()
else:
clip_loss = -torch.min(surr1, surr2).mean()
clip_losses.append(clip_loss.item())

View File

@ -107,7 +107,7 @@ class SACPolicy(DDPGPolicy):
):
o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
def forward(
def forward( # type: ignore
self,
batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
@ -193,5 +193,5 @@ class SACPolicy(DDPGPolicy):
}
if self._is_auto_alpha:
result["loss/alpha"] = alpha_loss.item()
result["v/alpha"] = self._alpha.item()
result["v/alpha"] = self._alpha.item() # type: ignore
return result

View File

@ -73,7 +73,7 @@ def offpolicy_trainer(
"""
global_step = 0
best_epoch, best_reward = -1, -1.0
stat = {}
stat: Dict[str, MovAvg] = {}
start_time = time.time()
test_in_train = test_in_train and train_collector.policy == policy
for epoch in range(1, 1 + max_epoch):
@ -91,7 +91,7 @@ def offpolicy_trainer(
test_result = test_episode(
policy, test_collector, test_fn,
epoch, episode_per_test, writer, global_step)
if stop_fn and stop_fn(test_result["rew"]):
if stop_fn(test_result["rew"]):
if save_fn:
save_fn(policy)
for k in result.keys():

View File

@ -73,7 +73,7 @@ def onpolicy_trainer(
"""
global_step = 0
best_epoch, best_reward = -1, -1.0
stat = {}
stat: Dict[str, MovAvg] = {}
start_time = time.time()
test_in_train = test_in_train and train_collector.policy == policy
for epoch in range(1, 1 + max_epoch):
@ -91,7 +91,7 @@ def onpolicy_trainer(
test_result = test_episode(
policy, test_collector, test_fn,
epoch, episode_per_test, writer, global_step)
if stop_fn and stop_fn(test_result["rew"]):
if stop_fn(test_result["rew"]):
if save_fn:
save_fn(policy)
for k in result.keys():
@ -109,9 +109,9 @@ def onpolicy_trainer(
batch_size=batch_size, repeat=repeat_per_collect)
train_collector.reset_buffer()
step = 1
for k in losses.keys():
if isinstance(losses[k], list):
step = max(step, len(losses[k]))
for v in losses.values():
if isinstance(v, list):
step = max(step, len(v))
global_step += step * collect_per_step
for k in result.keys():
data[k] = f"{result[k]:.2f}"

View File

@ -22,7 +22,7 @@ def test_episode(
policy.eval()
if test_fn:
test_fn(epoch)
if collector.get_env_num() > 1 and np.isscalar(n_episode):
if collector.get_env_num() > 1 and isinstance(n_episode, int):
n = collector.get_env_num()
n_ = np.zeros(n) + n_episode // n
n_[:n_episode % n] += 1

View File

@ -1,7 +1,7 @@
import torch
import numpy as np
from numbers import Number
from typing import Union
from typing import List, Union
from tianshou.data import to_numpy
@ -28,7 +28,7 @@ class MovAvg(object):
def __init__(self, size: int = 100) -> None:
super().__init__()
self.size = size
self.cache = []
self.cache: List[Union[Number, np.number]] = []
self.banned = [np.inf, np.nan, -np.inf]
def add(

View File

@ -12,7 +12,7 @@ def miniblock(
norm_layer: Optional[Callable[[int], nn.modules.Module]],
) -> List[nn.modules.Module]:
"""Construct a miniblock with given input/output-size and norm layer."""
ret = [nn.Linear(inp, oup)]
ret: List[nn.modules.Module] = [nn.Linear(inp, oup)]
if norm_layer is not None:
ret += [norm_layer(oup)]
ret += [nn.ReLU(inplace=True)]
@ -54,36 +54,33 @@ class Net(nn.Module):
if concat:
input_size += np.prod(action_shape)
self.model = miniblock(input_size, hidden_layer_size, norm_layer)
model = miniblock(input_size, hidden_layer_size, norm_layer)
for i in range(layer_num):
self.model += miniblock(hidden_layer_size,
hidden_layer_size, norm_layer)
model += miniblock(
hidden_layer_size, hidden_layer_size, norm_layer)
if self.dueling is None:
if dueling is None:
if action_shape and not concat:
self.model += [nn.Linear(hidden_layer_size,
np.prod(action_shape))]
model += [nn.Linear(hidden_layer_size, np.prod(action_shape))]
else: # dueling DQN
assert isinstance(self.dueling, tuple) and len(self.dueling) == 2
q_layer_num, v_layer_num = self.dueling
self.Q, self.V = [], []
q_layer_num, v_layer_num = dueling
Q, V = [], []
for i in range(q_layer_num):
self.Q += miniblock(hidden_layer_size,
hidden_layer_size, norm_layer)
Q += miniblock(
hidden_layer_size, hidden_layer_size, norm_layer)
for i in range(v_layer_num):
self.V += miniblock(hidden_layer_size,
hidden_layer_size, norm_layer)
V += miniblock(
hidden_layer_size, hidden_layer_size, norm_layer)
if action_shape and not concat:
self.Q += [nn.Linear(hidden_layer_size, np.prod(action_shape))]
self.V += [nn.Linear(hidden_layer_size, 1)]
Q += [nn.Linear(hidden_layer_size, np.prod(action_shape))]
V += [nn.Linear(hidden_layer_size, 1)]
self.Q = nn.Sequential(*self.Q)
self.V = nn.Sequential(*self.V)
self.model = nn.Sequential(*self.model)
self.Q = nn.Sequential(*Q)
self.V = nn.Sequential(*V)
self.model = nn.Sequential(*model)
def forward(
self,

View File

@ -66,7 +66,7 @@ class Critic(nn.Module):
s = to_torch(s, device=self.device, dtype=torch.float32)
s = s.flatten(1)
if a is not None:
a = to_torch(a, device=self.device, dtype=torch.float32)
a = to_torch_as(a, s)
a = a.flatten(1)
s = torch.cat([s, a], dim=1)
logits, h = self.preprocess(s)

View File

@ -4,6 +4,8 @@ from torch import nn
import torch.nn.functional as F
from typing import Any, Dict, Tuple, Union, Optional, Sequence
from tianshou.data import to_torch
class Actor(nn.Module):
"""Simple actor network with MLP.
@ -118,5 +120,5 @@ class DQN(nn.Module):
) -> Tuple[torch.Tensor, Any]:
r"""Mapping: x -> Q(x, \*)."""
if not isinstance(x, torch.Tensor):
x = torch.tensor(x, device=self.device, dtype=torch.float32)
x = to_torch(x, device=self.device, dtype=torch.float32)
return self.net(x), state