type check in unit test (#200)

Fix #195: Add mypy test in .github/workflows/docs_and_lint.yml.

Also remove the out-of-date API.
This commit is contained in:
n+e 2020-09-13 19:31:50 +08:00 committed by GitHub
parent c91def6cbc
commit b284ace102
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
32 changed files with 224 additions and 249 deletions

View File

@ -1,4 +1,4 @@
name: PEP8 and Docs Check name: PEP8, Types and Docs Check
on: [push, pull_request] on: [push, pull_request]
@ -20,6 +20,9 @@ jobs:
- name: Lint with flake8 - name: Lint with flake8
run: | run: |
flake8 . --count --show-source --statistics flake8 . --count --show-source --statistics
- name: Type check
run: |
mypy
- name: Documentation test - name: Documentation test
run: | run: |
pydocstyle tianshou pydocstyle tianshou

View File

@ -28,6 +28,16 @@ We follow PEP8 python code style. To check, in the main directory, run:
$ flake8 . --count --show-source --statistics $ flake8 . --count --show-source --statistics
Type Check
----------
We use `mypy <https://github.com/python/mypy/>`_ to check the type annotations. To check, in the main directory, run:
.. code-block:: bash
$ mypy
Test Locally Test Locally
------------ ------------

View File

@ -31,6 +31,7 @@ Here is Tianshou's other features:
* Support :ref:`customize_training` * Support :ref:`customize_training`
* Support n-step returns estimation :meth:`~tianshou.policy.BasePolicy.compute_nstep_return` and prioritized experience replay :class:`~tianshou.data.PrioritizedReplayBuffer` for all Q-learning based algorithms; GAE, nstep and PER are very fast thanks to numba jit function and vectorized numpy operation * Support n-step returns estimation :meth:`~tianshou.policy.BasePolicy.compute_nstep_return` and prioritized experience replay :class:`~tianshou.data.PrioritizedReplayBuffer` for all Q-learning based algorithms; GAE, nstep and PER are very fast thanks to numba jit function and vectorized numpy operation
* Support :doc:`/tutorials/tictactoe` * Support :doc:`/tutorials/tictactoe`
* Comprehensive `unit tests <https://github.com/thu-ml/tianshou/actions>`_, including functional checking, RL pipeline checking, documentation checking, PEP8 code-style checking, and type checking
中文文档位于 `https://tianshou.readthedocs.io/zh/latest/ <https://tianshou.readthedocs.io/zh/latest/>`_ 中文文档位于 `https://tianshou.readthedocs.io/zh/latest/ <https://tianshou.readthedocs.io/zh/latest/>`_

View File

@ -1,3 +1,26 @@
[mypy]
files = tianshou/**/*.py
allow_redefinition = True
check_untyped_defs = True
disallow_incomplete_defs = True
disallow_untyped_defs = True
ignore_missing_imports = True
no_implicit_optional = True
pretty = True
show_error_codes = True
show_error_context = True
show_traceback = True
strict_equality = True
strict_optional = True
warn_no_return = True
warn_redundant_casts = True
warn_unreachable = True
warn_unused_configs = True
warn_unused_ignores = True
[mypy-tianshou.utils.net.*]
ignore_errors = True
[pydocstyle] [pydocstyle]
ignore = D100,D102,D104,D105,D107,D203,D213,D401,D402 ignore = D100,D102,D104,D105,D107,D203,D213,D401,D402

View File

@ -100,11 +100,6 @@ def test_collect_ep(data):
data["collector"].collect(n_episode=10) data["collector"].collect(n_episode=10)
def test_sample(data):
for _ in range(5000):
data["collector"].sample(256)
def test_init_vec_env(data): def test_init_vec_env(data):
for _ in range(5000): for _ in range(5000):
Collector(data["policy"], data["env_vec"], data["buffer"]) Collector(data["policy"], data["env_vec"], data["buffer"])
@ -125,11 +120,6 @@ def test_collect_vec_env_ep(data):
data["collector_vec"].collect(n_episode=10) data["collector_vec"].collect(n_episode=10)
def test_sample_vec_env(data):
for _ in range(5000):
data["collector_vec"].sample(256)
def test_init_subproc_env(data): def test_init_subproc_env(data):
for _ in range(5000): for _ in range(5000):
Collector(data["policy"], data["env_subproc_init"], data["buffer"]) Collector(data["policy"], data["env_subproc_init"], data["buffer"])
@ -150,10 +140,5 @@ def test_collect_subproc_env_ep(data):
data["collector_subproc"].collect(n_episode=10) data["collector_subproc"].collect(n_episode=10)
def test_sample_subproc_env(data):
for _ in range(5000):
data["collector_subproc"].sample(256)
if __name__ == '__main__': if __name__ == '__main__':
pytest.main(["-s", "-k collector_profile", "--durations=0", "-v"]) pytest.main(["-s", "-k collector_profile", "--durations=0", "-v"])

View File

@ -1,7 +1,7 @@
from tianshou import data, env, utils, policy, trainer, exploration from tianshou import data, env, utils, policy, trainer, exploration
__version__ = "0.2.7" __version__ = "0.3.0rc0"
__all__ = [ __all__ = [
"env", "env",

View File

@ -6,7 +6,7 @@ from copy import deepcopy
from numbers import Number from numbers import Number
from collections.abc import Collection from collections.abc import Collection
from typing import Any, List, Dict, Union, Iterator, Optional, Iterable, \ from typing import Any, List, Dict, Union, Iterator, Optional, Iterable, \
Sequence, KeysView, ValuesView, ItemsView Sequence
# Disable pickle warning related to torch, since it has been removed # Disable pickle warning related to torch, since it has been removed
# on torch master branch. See Pull Request #39003 for details: # on torch master branch. See Pull Request #39003 for details:
@ -144,7 +144,7 @@ def _parse_value(v: Any) -> Optional[Union["Batch", np.ndarray, torch.Tensor]]:
if not isinstance(v, np.ndarray) and isinstance(v, Collection) and \ if not isinstance(v, np.ndarray) and isinstance(v, Collection) and \
len(v) > 0 and all(isinstance(e, torch.Tensor) for e in v): len(v) > 0 and all(isinstance(e, torch.Tensor) for e in v):
try: try:
return torch.stack(v) return torch.stack(v) # type: ignore
except RuntimeError as e: except RuntimeError as e:
raise TypeError("Batch does not support non-stackable iterable" raise TypeError("Batch does not support non-stackable iterable"
" of torch.Tensor as unique value yet.") from e " of torch.Tensor as unique value yet.") from e
@ -191,12 +191,20 @@ class Batch:
elif _is_batch_set(batch_dict): elif _is_batch_set(batch_dict):
self.stack_(batch_dict) self.stack_(batch_dict)
if len(kwargs) > 0: if len(kwargs) > 0:
self.__init__(kwargs, copy=copy) self.__init__(kwargs, copy=copy) # type: ignore
def __setattr__(self, key: str, value: Any) -> None: def __setattr__(self, key: str, value: Any) -> None:
"""Set self.key = value.""" """Set self.key = value."""
self.__dict__[key] = _parse_value(value) self.__dict__[key] = _parse_value(value)
def __getattr__(self, key: str) -> Any:
"""Return self.key. The "Any" return type is needed for mypy."""
return getattr(self.__dict__, key)
def __contains__(self, key: str) -> bool:
"""Return key in self."""
return key in self.__dict__
def __getstate__(self) -> Dict[str, Any]: def __getstate__(self) -> Dict[str, Any]:
"""Pickling interface. """Pickling interface.
@ -215,11 +223,11 @@ class Batch:
At this point, self is an empty Batch instance that has not been At this point, self is an empty Batch instance that has not been
initialized, so it can safely be initialized by the pickle state. initialized, so it can safely be initialized by the pickle state.
""" """
self.__init__(**state) self.__init__(**state) # type: ignore
def __getitem__( def __getitem__(
self, index: Union[str, slice, int, np.integer, np.ndarray, List[int]] self, index: Union[str, slice, int, np.integer, np.ndarray, List[int]]
) -> Union["Batch", np.ndarray, torch.Tensor]: ) -> Any:
"""Return self[index].""" """Return self[index]."""
if isinstance(index, str): if isinstance(index, str):
return self.__dict__[index] return self.__dict__[index]
@ -245,7 +253,7 @@ class Batch:
if isinstance(index, str): if isinstance(index, str):
self.__dict__[index] = value self.__dict__[index] = value
return return
if isinstance(value, (np.ndarray, torch.Tensor)): if not isinstance(value, Batch):
raise ValueError("Batch does not supported tensor assignment. " raise ValueError("Batch does not supported tensor assignment. "
"Use a compatible Batch or dict instead.") "Use a compatible Batch or dict instead.")
if not set(value.keys()).issubset(self.__dict__.keys()): if not set(value.keys()).issubset(self.__dict__.keys()):
@ -330,30 +338,6 @@ class Batch:
s = self.__class__.__name__ + "()" s = self.__class__.__name__ + "()"
return s return s
def __contains__(self, key: str) -> bool:
"""Return key in self."""
return key in self.__dict__
def keys(self) -> KeysView[str]:
"""Return self.keys()."""
return self.__dict__.keys()
def values(self) -> ValuesView[Any]:
"""Return self.values()."""
return self.__dict__.values()
def items(self) -> ItemsView[str, Any]:
"""Return self.items()."""
return self.__dict__.items()
def get(self, k: str, d: Optional[Any] = None) -> Any:
"""Return self[k] if k in self else d. d defaults to None."""
return self.__dict__.get(k, d)
def pop(self, k: str, d: Optional[Any] = None) -> Any:
"""Return & remove self[k] if k in self else d. d defaults to None."""
return self.__dict__.pop(k, d)
def to_numpy(self) -> None: def to_numpy(self) -> None:
"""Change all torch.Tensor to numpy.ndarray in-place.""" """Change all torch.Tensor to numpy.ndarray in-place."""
for k, v in self.items(): for k, v in self.items():
@ -375,7 +359,6 @@ class Batch:
if isinstance(v, torch.Tensor): if isinstance(v, torch.Tensor):
if dtype is not None and v.dtype != dtype or \ if dtype is not None and v.dtype != dtype or \
v.device.type != device.type or \ v.device.type != device.type or \
device.index is not None and \
device.index != v.device.index: device.index != v.device.index:
if dtype is not None: if dtype is not None:
v = v.type(dtype) v = v.type(dtype)
@ -517,7 +500,7 @@ class Batch:
return return
batches = [x if isinstance(x, Batch) else Batch(x) for x in batches] batches = [x if isinstance(x, Batch) else Batch(x) for x in batches]
if not self.is_empty(): if not self.is_empty():
batches = [self] + list(batches) batches = [self] + batches
# collect non-empty keys # collect non-empty keys
keys_map = [ keys_map = [
set(k for k, v in batch.items() set(k for k, v in batch.items()
@ -672,8 +655,8 @@ class Batch:
for v in self.__dict__.values(): for v in self.__dict__.values():
if isinstance(v, Batch) and v.is_empty(recurse=True): if isinstance(v, Batch) and v.is_empty(recurse=True):
continue continue
elif hasattr(v, "__len__") and (not isinstance( elif hasattr(v, "__len__") and (
v, (np.ndarray, torch.Tensor)) or v.ndim > 0 isinstance(v, Batch) or v.ndim > 0
): ):
r.append(len(v)) r.append(len(v))
else: else:

View File

@ -1,7 +1,7 @@
import torch import torch
import numpy as np import numpy as np
from numbers import Number from numbers import Number
from typing import Any, Dict, Tuple, Union, Optional from typing import Any, Dict, List, Tuple, Union, Optional
from tianshou.data import Batch, SegmentTree, to_numpy from tianshou.data import Batch, SegmentTree, to_numpy
from tianshou.data.batch import _create_value from tianshou.data.batch import _create_value
@ -138,7 +138,7 @@ class ReplayBuffer:
self._indices = np.arange(size) self._indices = np.arange(size)
self.stack_num = stack_num self.stack_num = stack_num
self._avail = sample_avail and stack_num > 1 self._avail = sample_avail and stack_num > 1
self._avail_index = [] self._avail_index: List[int] = []
self._save_s_ = not ignore_obs_next self._save_s_ = not ignore_obs_next
self._last_obs = save_only_last_obs self._last_obs = save_only_last_obs
self._index = 0 self._index = 0
@ -175,12 +175,12 @@ class ReplayBuffer:
except KeyError: except KeyError:
self._meta.__dict__[name] = _create_value(inst, self._maxsize) self._meta.__dict__[name] = _create_value(inst, self._maxsize)
value = self._meta.__dict__[name] value = self._meta.__dict__[name]
if isinstance(inst, (torch.Tensor, np.ndarray)) \ if isinstance(inst, (torch.Tensor, np.ndarray)):
and inst.shape != value.shape[1:]: if inst.shape != value.shape[1:]:
raise ValueError( raise ValueError(
"Cannot add data to a buffer with different shape, with key " "Cannot add data to a buffer with different shape with key"
f"{name}, expect {value.shape[1:]}, given {inst.shape}." f" {name}, expect {value.shape[1:]}, given {inst.shape}."
) )
try: try:
value[self._index] = inst value[self._index] = inst
except KeyError: except KeyError:
@ -205,7 +205,7 @@ class ReplayBuffer:
stack_num_orig = buffer.stack_num stack_num_orig = buffer.stack_num
buffer.stack_num = 1 buffer.stack_num = 1
while True: while True:
self.add(**buffer[i]) self.add(**buffer[i]) # type: ignore
i = (i + 1) % len(buffer) i = (i + 1) % len(buffer)
if i == begin: if i == begin:
break break
@ -323,7 +323,7 @@ class ReplayBuffer:
try: try:
if stack_num == 1: if stack_num == 1:
return val[indice] return val[indice]
stack = [] stack: List[Any] = []
for _ in range(stack_num): for _ in range(stack_num):
stack = [val[indice]] + stack stack = [val[indice]] + stack
pre_indice = np.asarray(indice - 1) pre_indice = np.asarray(indice - 1)

View File

@ -212,10 +212,8 @@ class Collector(object):
finished_env_ids = [] finished_env_ids = []
reward_total = 0.0 reward_total = 0.0
whole_data = Batch() whole_data = Batch()
list_n_episode = False if isinstance(n_episode, list):
if n_episode is not None and not np.isscalar(n_episode):
assert len(n_episode) == self.get_env_num() assert len(n_episode) == self.get_env_num()
list_n_episode = True
finished_env_ids = [ finished_env_ids = [
i for i in self._ready_env_ids if n_episode[i] <= 0] i for i in self._ready_env_ids if n_episode[i] <= 0]
self._ready_env_ids = np.array( self._ready_env_ids = np.array(
@ -266,7 +264,8 @@ class Collector(object):
self.data.policy._state = self.data.state self.data.policy._state = self.data.state
self.data.act = to_numpy(result.act) self.data.act = to_numpy(result.act)
if self._action_noise is not None: # noqa if self._action_noise is not None:
assert isinstance(self.data.act, np.ndarray)
self.data.act += self._action_noise(self.data.act.shape) self.data.act += self._action_noise(self.data.act.shape)
# step in env # step in env
@ -291,7 +290,7 @@ class Collector(object):
# add data into the buffer # add data into the buffer
if self.preprocess_fn: if self.preprocess_fn:
result = self.preprocess_fn(**self.data) result = self.preprocess_fn(**self.data) # type: ignore
self.data.update(result) self.data.update(result)
for j, i in enumerate(self._ready_env_ids): for j, i in enumerate(self._ready_env_ids):
@ -305,14 +304,14 @@ class Collector(object):
self._cached_buf[i].add(**self.data[j]) self._cached_buf[i].add(**self.data[j])
if done[j]: if done[j]:
if not (list_n_episode and if not (isinstance(n_episode, list)
episode_count[i] >= n_episode[i]): and episode_count[i] >= n_episode[i]):
episode_count[i] += 1 episode_count[i] += 1
reward_total += np.sum(self._cached_buf[i].rew, axis=0) reward_total += np.sum(self._cached_buf[i].rew, axis=0)
step_count += len(self._cached_buf[i]) step_count += len(self._cached_buf[i])
if self.buffer is not None: if self.buffer is not None:
self.buffer.update(self._cached_buf[i]) self.buffer.update(self._cached_buf[i])
if list_n_episode and \ if isinstance(n_episode, list) and \
episode_count[i] >= n_episode[i]: episode_count[i] >= n_episode[i]:
# env i has collected enough data, it has finished # env i has collected enough data, it has finished
finished_env_ids.append(i) finished_env_ids.append(i)
@ -324,10 +323,9 @@ class Collector(object):
env_ind_global = self._ready_env_ids[env_ind_local] env_ind_global = self._ready_env_ids[env_ind_local]
obs_reset = self.env.reset(env_ind_global) obs_reset = self.env.reset(env_ind_global)
if self.preprocess_fn: if self.preprocess_fn:
obs_next[env_ind_local] = self.preprocess_fn( obs_reset = self.preprocess_fn(
obs=obs_reset).get("obs", obs_reset) obs=obs_reset).get("obs", obs_reset)
else: obs_next[env_ind_local] = obs_reset
obs_next[env_ind_local] = obs_reset
self.data.obs = obs_next self.data.obs = obs_next
if is_async: if is_async:
# set data back # set data back
@ -362,7 +360,7 @@ class Collector(object):
# average reward across the number of episodes # average reward across the number of episodes
reward_avg = reward_total / episode_count reward_avg = reward_total / episode_count
if np.asanyarray(reward_avg).size > 1: # non-scalar reward_avg if np.asanyarray(reward_avg).size > 1: # non-scalar reward_avg
reward_avg = self._rew_metric(reward_avg) reward_avg = self._rew_metric(reward_avg) # type: ignore
return { return {
"n/ep": episode_count, "n/ep": episode_count,
"n/st": step_count, "n/st": step_count,
@ -372,30 +370,6 @@ class Collector(object):
"len": step_count / episode_count, "len": step_count / episode_count,
} }
def sample(self, batch_size: int) -> Batch:
"""Sample a data batch from the internal replay buffer.
It will call :meth:`~tianshou.policy.BasePolicy.process_fn` before
returning the final batch data.
:param int batch_size: ``0`` means it will extract all the data from
the buffer, otherwise it will extract the data with the given
batch_size.
"""
warnings.warn(
"Collector.sample is deprecated and will cause error if you use "
"prioritized experience replay! Collector.sample will be removed "
"upon version 0.3. Use policy.update instead!", Warning)
assert self.buffer is not None, "Cannot get sample from empty buffer!"
batch_data, indice = self.buffer.sample(batch_size)
batch_data = self.process_fn(batch_data, self.buffer, indice)
return batch_data
def close(self) -> None:
warnings.warn(
"Collector.close is deprecated and will be removed upon version "
"0.3.", Warning)
def _batch_set_item( def _batch_set_item(
source: Batch, indices: np.ndarray, target: Batch, size: int source: Batch, indices: np.ndarray, target: Batch, size: int

View File

@ -45,14 +45,14 @@ def to_torch(
if isinstance(x, np.ndarray) and issubclass( if isinstance(x, np.ndarray) and issubclass(
x.dtype.type, (np.bool_, np.number) x.dtype.type, (np.bool_, np.number)
): # most often case ): # most often case
x = torch.from_numpy(x).to(device) x = torch.from_numpy(x).to(device) # type: ignore
if dtype is not None: if dtype is not None:
x = x.type(dtype) x = x.type(dtype)
return x return x
elif isinstance(x, torch.Tensor): # second often case elif isinstance(x, torch.Tensor): # second often case
if dtype is not None: if dtype is not None:
x = x.type(dtype) x = x.type(dtype)
return x.to(device) return x.to(device) # type: ignore
elif isinstance(x, (np.number, np.bool_, Number)): elif isinstance(x, (np.number, np.bool_, Number)):
return to_torch(np.asanyarray(x), dtype, device) return to_torch(np.asanyarray(x), dtype, device)
elif isinstance(x, dict): elif isinstance(x, dict):

View File

@ -1,11 +1,10 @@
from tianshou.env.venvs import BaseVectorEnv, DummyVectorEnv, VectorEnv, \ from tianshou.env.venvs import BaseVectorEnv, DummyVectorEnv, \
SubprocVectorEnv, ShmemVectorEnv, RayVectorEnv SubprocVectorEnv, ShmemVectorEnv, RayVectorEnv
from tianshou.env.maenv import MultiAgentEnv from tianshou.env.maenv import MultiAgentEnv
__all__ = [ __all__ = [
"BaseVectorEnv", "BaseVectorEnv",
"DummyVectorEnv", "DummyVectorEnv",
"VectorEnv", # TODO: remove in later version
"SubprocVectorEnv", "SubprocVectorEnv",
"ShmemVectorEnv", "ShmemVectorEnv",
"RayVectorEnv", "RayVectorEnv",

29
tianshou/env/venvs.py vendored
View File

@ -1,5 +1,4 @@
import gym import gym
import warnings
import numpy as np import numpy as np
from typing import Any, List, Union, Optional, Callable from typing import Any, List, Union, Optional, Callable
@ -84,12 +83,12 @@ class BaseVectorEnv(gym.Env):
self.timeout is None or self.timeout > 0 self.timeout is None or self.timeout > 0
), f"timeout is {timeout}, it should be positive if provided!" ), f"timeout is {timeout}, it should be positive if provided!"
self.is_async = self.wait_num != len(env_fns) or timeout is not None self.is_async = self.wait_num != len(env_fns) or timeout is not None
self.waiting_conn = [] self.waiting_conn: List[EnvWorker] = []
# environments in self.ready_id is actually ready # environments in self.ready_id is actually ready
# but environments in self.waiting_id are just waiting when checked, # but environments in self.waiting_id are just waiting when checked,
# and they may be ready now, but this is not known until we check it # and they may be ready now, but this is not known until we check it
# in the step() function # in the step() function
self.waiting_id = [] self.waiting_id: List[int] = []
# all environments are ready in the beginning # all environments are ready in the beginning
self.ready_id = list(range(self.env_num)) self.ready_id = list(range(self.env_num))
self.is_closed = False self.is_closed = False
@ -216,10 +215,11 @@ class BaseVectorEnv(gym.Env):
self.waiting_conn.append(self.workers[env_id]) self.waiting_conn.append(self.workers[env_id])
self.waiting_id.append(env_id) self.waiting_id.append(env_id)
self.ready_id = [x for x in self.ready_id if x not in id] self.ready_id = [x for x in self.ready_id if x not in id]
ready_conns, result = [], [] ready_conns: List[EnvWorker] = []
while not ready_conns: while not ready_conns:
ready_conns = self.worker_class.wait( ready_conns = self.worker_class.wait(
self.waiting_conn, self.wait_num, self.timeout) self.waiting_conn, self.wait_num, self.timeout)
result = []
for conn in ready_conns: for conn in ready_conns:
waiting_index = self.waiting_conn.index(conn) waiting_index = self.waiting_conn.index(conn)
self.waiting_conn.pop(waiting_index) self.waiting_conn.pop(waiting_index)
@ -243,11 +243,14 @@ class BaseVectorEnv(gym.Env):
which a reproducer pass to "seed". which a reproducer pass to "seed".
""" """
self._assert_is_not_closed() self._assert_is_not_closed()
seed_list: Union[List[None], List[int]]
if seed is None: if seed is None:
seed = [seed] * self.env_num seed_list = [seed] * self.env_num
elif np.isscalar(seed): elif isinstance(seed, int):
seed = [seed + i for i in range(self.env_num)] seed_list = [seed + i for i in range(self.env_num)]
return [w.seed(s) for w, s in zip(self.workers, seed)] else:
seed_list = seed
return [w.seed(s) for w, s in zip(self.workers, seed_list)]
def render(self, **kwargs: Any) -> List[Any]: def render(self, **kwargs: Any) -> List[Any]:
"""Render all of the environments.""" """Render all of the environments."""
@ -295,16 +298,6 @@ class DummyVectorEnv(BaseVectorEnv):
env_fns, DummyEnvWorker, wait_num=wait_num, timeout=timeout) env_fns, DummyEnvWorker, wait_num=wait_num, timeout=timeout)
class VectorEnv(DummyVectorEnv):
"""VectorEnv is renamed to DummyVectorEnv."""
def __init__(self, *args: Any, **kwargs: Any) -> None:
warnings.warn(
"VectorEnv is renamed to DummyVectorEnv, and will be removed in "
"0.3. Use DummyVectorEnv instead!", Warning)
super().__init__(*args, **kwargs)
class SubprocVectorEnv(BaseVectorEnv): class SubprocVectorEnv(BaseVectorEnv):
"""Vectorized environment wrapper based on subprocess. """Vectorized environment wrapper based on subprocess.

View File

@ -10,7 +10,7 @@ class EnvWorker(ABC):
def __init__(self, env_fn: Callable[[], gym.Env]) -> None: def __init__(self, env_fn: Callable[[], gym.Env]) -> None:
self._env_fn = env_fn self._env_fn = env_fn
self.is_closed = False self.is_closed = False
self.result = (None, None, None, None) self.result: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]
@abstractmethod @abstractmethod
def __getattr__(self, key: str) -> Any: def __getattr__(self, key: str) -> Any:

View File

@ -19,7 +19,7 @@ class DummyEnvWorker(EnvWorker):
return self.env.reset() return self.env.reset()
@staticmethod @staticmethod
def wait( def wait( # type: ignore
workers: List["DummyEnvWorker"], workers: List["DummyEnvWorker"],
wait_num: int, wait_num: int,
timeout: Optional[float] = None, timeout: Optional[float] = None,

View File

@ -24,7 +24,7 @@ class RayEnvWorker(EnvWorker):
return ray.get(self.env.reset.remote()) return ray.get(self.env.reset.remote())
@staticmethod @staticmethod
def wait( def wait( # type: ignore
workers: List["RayEnvWorker"], workers: List["RayEnvWorker"],
wait_num: int, wait_num: int,
timeout: Optional[float] = None, timeout: Optional[float] = None,

View File

@ -11,22 +11,71 @@ from tianshou.env.worker import EnvWorker
from tianshou.env.utils import CloudpickleWrapper from tianshou.env.utils import CloudpickleWrapper
_NP_TO_CT = {
np.bool: ctypes.c_bool,
np.bool_: ctypes.c_bool,
np.uint8: ctypes.c_uint8,
np.uint16: ctypes.c_uint16,
np.uint32: ctypes.c_uint32,
np.uint64: ctypes.c_uint64,
np.int8: ctypes.c_int8,
np.int16: ctypes.c_int16,
np.int32: ctypes.c_int32,
np.int64: ctypes.c_int64,
np.float32: ctypes.c_float,
np.float64: ctypes.c_double,
}
class ShArray:
"""Wrapper of multiprocessing Array."""
def __init__(self, dtype: np.generic, shape: Tuple[int]) -> None:
self.arr = Array(
_NP_TO_CT[dtype.type], # type: ignore
int(np.prod(shape)),
)
self.dtype = dtype
self.shape = shape
def save(self, ndarray: np.ndarray) -> None:
assert isinstance(ndarray, np.ndarray)
dst = self.arr.get_obj()
dst_np = np.frombuffer(dst, dtype=self.dtype).reshape(self.shape)
np.copyto(dst_np, ndarray)
def get(self) -> np.ndarray:
obj = self.arr.get_obj()
return np.frombuffer(obj, dtype=self.dtype).reshape(self.shape)
def _setup_buf(space: gym.Space) -> Union[dict, tuple, ShArray]:
if isinstance(space, gym.spaces.Dict):
assert isinstance(space.spaces, OrderedDict)
return {k: _setup_buf(v) for k, v in space.spaces.items()}
elif isinstance(space, gym.spaces.Tuple):
assert isinstance(space.spaces, tuple)
return tuple([_setup_buf(t) for t in space.spaces])
else:
return ShArray(space.dtype, space.shape)
def _worker( def _worker(
parent: connection.Connection, parent: connection.Connection,
p: connection.Connection, p: connection.Connection,
env_fn_wrapper: CloudpickleWrapper, env_fn_wrapper: CloudpickleWrapper,
obs_bufs: Optional[Union[dict, tuple, "ShArray"]] = None, obs_bufs: Optional[Union[dict, tuple, ShArray]] = None,
) -> None: ) -> None:
def _encode_obs( def _encode_obs(
obs: Union[dict, tuple, np.ndarray], obs: Union[dict, tuple, np.ndarray],
buffer: Union[dict, tuple, ShArray], buffer: Union[dict, tuple, ShArray],
) -> None: ) -> None:
if isinstance(obs, np.ndarray): if isinstance(obs, np.ndarray) and isinstance(buffer, ShArray):
buffer.save(obs) buffer.save(obs)
elif isinstance(obs, tuple): elif isinstance(obs, tuple) and isinstance(buffer, tuple):
for o, b in zip(obs, buffer): for o, b in zip(obs, buffer):
_encode_obs(o, b) _encode_obs(o, b)
elif isinstance(obs, dict): elif isinstance(obs, dict) and isinstance(buffer, dict):
for k in obs.keys(): for k in obs.keys():
_encode_obs(obs[k], buffer[k]) _encode_obs(obs[k], buffer[k])
return None return None
@ -69,52 +118,6 @@ def _worker(
p.close() p.close()
_NP_TO_CT = {
np.bool: ctypes.c_bool,
np.bool_: ctypes.c_bool,
np.uint8: ctypes.c_uint8,
np.uint16: ctypes.c_uint16,
np.uint32: ctypes.c_uint32,
np.uint64: ctypes.c_uint64,
np.int8: ctypes.c_int8,
np.int16: ctypes.c_int16,
np.int32: ctypes.c_int32,
np.int64: ctypes.c_int64,
np.float32: ctypes.c_float,
np.float64: ctypes.c_double,
}
class ShArray:
"""Wrapper of multiprocessing Array."""
def __init__(self, dtype: np.generic, shape: Tuple[int]) -> None:
self.arr = Array(_NP_TO_CT[dtype.type], int(np.prod(shape)))
self.dtype = dtype
self.shape = shape
def save(self, ndarray: np.ndarray) -> None:
assert isinstance(ndarray, np.ndarray)
dst = self.arr.get_obj()
dst_np = np.frombuffer(dst, dtype=self.dtype).reshape(self.shape)
np.copyto(dst_np, ndarray)
def get(self) -> np.ndarray:
obj = self.arr.get_obj()
return np.frombuffer(obj, dtype=self.dtype).reshape(self.shape)
def _setup_buf(space: gym.Space) -> Union[dict, tuple, ShArray]:
if isinstance(space, gym.spaces.Dict):
assert isinstance(space.spaces, OrderedDict)
return {k: _setup_buf(v) for k, v in space.spaces.items()}
elif isinstance(space, gym.spaces.Tuple):
assert isinstance(space.spaces, tuple)
return tuple([_setup_buf(t) for t in space.spaces])
else:
return ShArray(space.dtype, space.shape)
class SubprocEnvWorker(EnvWorker): class SubprocEnvWorker(EnvWorker):
"""Subprocess worker used in SubprocVectorEnv and ShmemVectorEnv.""" """Subprocess worker used in SubprocVectorEnv and ShmemVectorEnv."""
@ -124,7 +127,7 @@ class SubprocEnvWorker(EnvWorker):
super().__init__(env_fn) super().__init__(env_fn)
self.parent_remote, self.child_remote = Pipe() self.parent_remote, self.child_remote = Pipe()
self.share_memory = share_memory self.share_memory = share_memory
self.buffer = None self.buffer: Optional[Union[dict, tuple, ShArray]] = None
if self.share_memory: if self.share_memory:
dummy = env_fn() dummy = env_fn()
obs_space = dummy.observation_space obs_space = dummy.observation_space
@ -168,25 +171,23 @@ class SubprocEnvWorker(EnvWorker):
return obs return obs
@staticmethod @staticmethod
def wait( def wait( # type: ignore
workers: List["SubprocEnvWorker"], workers: List["SubprocEnvWorker"],
wait_num: int, wait_num: int,
timeout: Optional[float] = None, timeout: Optional[float] = None,
) -> List["SubprocEnvWorker"]: ) -> List["SubprocEnvWorker"]:
conns, ready_conns = [x.parent_remote for x in workers], [] remain_conns = conns = [x.parent_remote for x in workers]
remain_conns = conns ready_conns: List[connection.Connection] = []
t1 = time.time() remain_time, t1 = timeout, time.time()
while len(remain_conns) > 0 and len(ready_conns) < wait_num: while len(remain_conns) > 0 and len(ready_conns) < wait_num:
if timeout: if timeout:
remain_time = timeout - (time.time() - t1) remain_time = timeout - (time.time() - t1)
if remain_time <= 0: if remain_time <= 0:
break break
else:
remain_time = timeout
# connection.wait hangs if the list is empty # connection.wait hangs if the list is empty
new_ready_conns = connection.wait( new_ready_conns = connection.wait(
remain_conns, timeout=remain_time) remain_conns, timeout=remain_time)
ready_conns.extend(new_ready_conns) ready_conns.extend(new_ready_conns) # type: ignore
remain_conns = [ remain_conns = [
conn for conn in remain_conns if conn not in ready_conns] conn for conn in remain_conns if conn not in ready_conns]
return [workers[conns.index(con)] for con in ready_conns] return [workers[conns.index(con)] for con in ready_conns]

View File

@ -9,15 +9,15 @@ class BaseNoise(ABC, object):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
def reset(self) -> None:
"""Reset to the initial state."""
pass
@abstractmethod @abstractmethod
def __call__(self, size: Sequence[int]) -> np.ndarray: def __call__(self, size: Sequence[int]) -> np.ndarray:
"""Generate new noise.""" """Generate new noise."""
raise NotImplementedError raise NotImplementedError
def reset(self) -> None:
"""Reset to the initial state."""
pass
class GaussianNoise(BaseNoise): class GaussianNoise(BaseNoise):
"""The vanilla gaussian process, for exploration in DDPG by default.""" """The vanilla gaussian process, for exploration in DDPG by default."""
@ -64,6 +64,10 @@ class OUNoise(BaseNoise):
self._x0 = x0 self._x0 = x0
self.reset() self.reset()
def reset(self) -> None:
"""Reset to the initial state."""
self._x = self._x0
def __call__( def __call__(
self, size: Sequence[int], mu: Optional[float] = None self, size: Sequence[int], mu: Optional[float] = None
) -> np.ndarray: ) -> np.ndarray:
@ -71,14 +75,11 @@ class OUNoise(BaseNoise):
Return an numpy array which size is equal to ``size``. Return an numpy array which size is equal to ``size``.
""" """
if self._x is None or self._x.shape != size: if self._x is None or isinstance(
self._x, np.ndarray) and self._x.shape != size:
self._x = 0.0 self._x = 0.0
if mu is None: if mu is None:
mu = self._mu mu = self._mu
r = self._beta * np.random.normal(size=size) r = self._beta * np.random.normal(size=size)
self._x = self._x + self._alpha * (mu - self._x) + r self._x = self._x + self._alpha * (mu - self._x) + r
return self._x return self._x
def reset(self) -> None:
"""Reset to the initial state."""
self._x = self._x0

View File

@ -190,7 +190,7 @@ class BasePolicy(ABC, nn.Module):
array with shape (bsz, ). array with shape (bsz, ).
""" """
rew = batch.rew rew = batch.rew
v_s_ = np.zeros_like(rew) if v_s_ is None else to_numpy(v_s_).flatten() v_s_ = np.zeros_like(rew) if v_s_ is None else to_numpy(v_s_.flatten())
returns = _episodic_return(v_s_, rew, batch.done, gamma, gae_lambda) returns = _episodic_return(v_s_, rew, batch.done, gamma, gae_lambda)
if rew_norm and not np.isclose(returns.std(), 0.0, 1e-2): if rew_norm and not np.isclose(returns.std(), 0.0, 1e-2):
returns = (returns - returns.mean()) / returns.std() returns = (returns - returns.mean()) / returns.std()

View File

@ -55,11 +55,11 @@ class ImitationPolicy(BasePolicy):
if self.mode == "continuous": # regression if self.mode == "continuous": # regression
a = self(batch).act a = self(batch).act
a_ = to_torch(batch.act, dtype=torch.float32, device=a.device) a_ = to_torch(batch.act, dtype=torch.float32, device=a.device)
loss = F.mse_loss(a, a_) loss = F.mse_loss(a, a_) # type: ignore
elif self.mode == "discrete": # classification elif self.mode == "discrete": # classification
a = self(batch).logits a = self(batch).logits
a_ = to_torch(batch.act, dtype=torch.long, device=a.device) a_ = to_torch(batch.act, dtype=torch.long, device=a.device)
loss = F.nll_loss(a, a_) loss = F.nll_loss(a, a_) # type: ignore
loss.backward() loss.backward()
self.optim.step() self.optim.step()
return {"loss": loss.item()} return {"loss": loss.item()}

View File

@ -103,11 +103,11 @@ class A2CPolicy(PGPolicy):
if isinstance(logits, tuple): if isinstance(logits, tuple):
dist = self.dist_fn(*logits) dist = self.dist_fn(*logits)
else: else:
dist = self.dist_fn(logits) dist = self.dist_fn(logits) # type: ignore
act = dist.sample() act = dist.sample()
return Batch(logits=logits, act=act, state=h, dist=dist) return Batch(logits=logits, act=act, state=h, dist=dist)
def learn( def learn( # type: ignore
self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
) -> Dict[str, List[float]]: ) -> Dict[str, List[float]]:
losses, actor_losses, vf_losses, ent_losses = [], [], [], [] losses, actor_losses, vf_losses, ent_losses = [], [], [], []
@ -120,7 +120,7 @@ class A2CPolicy(PGPolicy):
r = to_torch_as(b.returns, v) r = to_torch_as(b.returns, v)
log_prob = dist.log_prob(a).reshape(len(r), -1).transpose(0, 1) log_prob = dist.log_prob(a).reshape(len(r), -1).transpose(0, 1)
a_loss = -(log_prob * (r - v).detach()).mean() a_loss = -(log_prob * (r - v).detach()).mean()
vf_loss = F.mse_loss(r, v) vf_loss = F.mse_loss(r, v) # type: ignore
ent_loss = dist.entropy().mean() ent_loss = dist.entropy().mean()
loss = a_loss + self._w_vf * vf_loss - self._w_ent * ent_loss loss = a_loss + self._w_vf * vf_loss - self._w_ent * ent_loss
loss.backward() loss.backward()

View File

@ -53,14 +53,16 @@ class DDPGPolicy(BasePolicy):
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
super().__init__(**kwargs) super().__init__(**kwargs)
if actor is not None: if actor is not None and actor_optim is not None:
self.actor, self.actor_old = actor, deepcopy(actor) self.actor: torch.nn.Module = actor
self.actor_old = deepcopy(actor)
self.actor_old.eval() self.actor_old.eval()
self.actor_optim = actor_optim self.actor_optim: torch.optim.Optimizer = actor_optim
if critic is not None: if critic is not None and critic_optim is not None:
self.critic, self.critic_old = critic, deepcopy(critic) self.critic: torch.nn.Module = critic
self.critic_old = deepcopy(critic)
self.critic_old.eval() self.critic_old.eval()
self.critic_optim = critic_optim self.critic_optim: torch.optim.Optimizer = critic_optim
assert 0.0 <= tau <= 1.0, "tau should be in [0, 1]" assert 0.0 <= tau <= 1.0, "tau should be in [0, 1]"
self._tau = tau self._tau = tau
assert 0.0 <= gamma <= 1.0, "gamma should be in [0, 1]" assert 0.0 <= gamma <= 1.0, "gamma should be in [0, 1]"
@ -141,7 +143,7 @@ class DDPGPolicy(BasePolicy):
obs = getattr(batch, input) obs = getattr(batch, input)
actions, h = model(obs, state=state, info=batch.info) actions, h = model(obs, state=state, info=batch.info)
actions += self._action_bias actions += self._action_bias
if self.training and explorating: if self._noise and self.training and explorating:
actions += to_torch_as(self._noise(actions.shape), actions) actions += to_torch_as(self._noise(actions.shape), actions)
actions = actions.clamp(self._range[0], self._range[1]) actions = actions.clamp(self._range[0], self._range[1])
return Batch(act=actions, state=h) return Batch(act=actions, state=h)

View File

@ -146,11 +146,10 @@ class DQNPolicy(BasePolicy):
obs = getattr(batch, input) obs = getattr(batch, input)
obs_ = obs.obs if hasattr(obs, "obs") else obs obs_ = obs.obs if hasattr(obs, "obs") else obs
q, h = model(obs_, state=state, info=batch.info) q, h = model(obs_, state=state, info=batch.info)
act = to_numpy(q.max(dim=1)[1]) act: np.ndarray = to_numpy(q.max(dim=1)[1])
has_mask = hasattr(obs, 'mask') if hasattr(obs, "mask"):
if has_mask:
# some of actions are masked, they cannot be selected # some of actions are masked, they cannot be selected
q_ = to_numpy(q) q_: np.ndarray = to_numpy(q)
q_[~obs.mask] = -np.inf q_[~obs.mask] = -np.inf
act = q_.argmax(axis=1) act = q_.argmax(axis=1)
# add eps to act # add eps to act
@ -160,7 +159,7 @@ class DQNPolicy(BasePolicy):
for i in range(len(q)): for i in range(len(q)):
if np.random.rand() < eps: if np.random.rand() < eps:
q_ = np.random.rand(*q[i].shape) q_ = np.random.rand(*q[i].shape)
if has_mask: if hasattr(obs, "mask"):
q_[~obs.mask[i]] = -np.inf q_[~obs.mask[i]] = -np.inf
act[i] = q_.argmax() act[i] = q_.argmax()
return Batch(logits=q, act=act, state=h) return Batch(logits=q, act=act, state=h)
@ -172,7 +171,7 @@ class DQNPolicy(BasePolicy):
weight = batch.pop("weight", 1.0) weight = batch.pop("weight", 1.0)
q = self(batch, eps=0.0).logits q = self(batch, eps=0.0).logits
q = q[np.arange(len(q)), batch.act] q = q[np.arange(len(q)), batch.act]
r = to_torch_as(batch.returns, q).flatten() r = to_torch_as(batch.returns.flatten(), q)
td = r - q td = r - q
loss = (td.pow(2) * weight).mean() loss = (td.pow(2) * weight).mean()
batch.weight = td # prio-buffer batch.weight = td # prio-buffer

View File

@ -32,7 +32,8 @@ class PGPolicy(BasePolicy):
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
super().__init__(**kwargs) super().__init__(**kwargs)
self.model = model if model is not None:
self.model: torch.nn.Module = model
self.optim = optim self.optim = optim
self.dist_fn = dist_fn self.dist_fn = dist_fn
assert ( assert (
@ -81,11 +82,11 @@ class PGPolicy(BasePolicy):
if isinstance(logits, tuple): if isinstance(logits, tuple):
dist = self.dist_fn(*logits) dist = self.dist_fn(*logits)
else: else:
dist = self.dist_fn(logits) dist = self.dist_fn(logits) # type: ignore
act = dist.sample() act = dist.sample()
return Batch(logits=logits, act=act, state=h, dist=dist) return Batch(logits=logits, act=act, state=h, dist=dist)
def learn( def learn( # type: ignore
self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
) -> Dict[str, List[float]]: ) -> Dict[str, List[float]]:
losses = [] losses = []

View File

@ -137,13 +137,13 @@ class PPOPolicy(PGPolicy):
if isinstance(logits, tuple): if isinstance(logits, tuple):
dist = self.dist_fn(*logits) dist = self.dist_fn(*logits)
else: else:
dist = self.dist_fn(logits) dist = self.dist_fn(logits) # type: ignore
act = dist.sample() act = dist.sample()
if self._range: if self._range:
act = act.clamp(self._range[0], self._range[1]) act = act.clamp(self._range[0], self._range[1])
return Batch(logits=logits, act=act, state=h, dist=dist) return Batch(logits=logits, act=act, state=h, dist=dist)
def learn( def learn( # type: ignore
self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
) -> Dict[str, List[float]]: ) -> Dict[str, List[float]]:
losses, clip_losses, vf_losses, ent_losses = [], [], [], [] losses, clip_losses, vf_losses, ent_losses = [], [], [], []
@ -157,8 +157,9 @@ class PPOPolicy(PGPolicy):
surr2 = ratio.clamp(1.0 - self._eps_clip, surr2 = ratio.clamp(1.0 - self._eps_clip,
1.0 + self._eps_clip) * b.adv 1.0 + self._eps_clip) * b.adv
if self._dual_clip: if self._dual_clip:
clip_loss = -torch.max(torch.min(surr1, surr2), clip_loss = -torch.max(
self._dual_clip * b.adv).mean() torch.min(surr1, surr2), self._dual_clip * b.adv
).mean()
else: else:
clip_loss = -torch.min(surr1, surr2).mean() clip_loss = -torch.min(surr1, surr2).mean()
clip_losses.append(clip_loss.item()) clip_losses.append(clip_loss.item())

View File

@ -107,7 +107,7 @@ class SACPolicy(DDPGPolicy):
): ):
o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau) o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
def forward( def forward( # type: ignore
self, self,
batch: Batch, batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None, state: Optional[Union[dict, Batch, np.ndarray]] = None,
@ -193,5 +193,5 @@ class SACPolicy(DDPGPolicy):
} }
if self._is_auto_alpha: if self._is_auto_alpha:
result["loss/alpha"] = alpha_loss.item() result["loss/alpha"] = alpha_loss.item()
result["v/alpha"] = self._alpha.item() result["v/alpha"] = self._alpha.item() # type: ignore
return result return result

View File

@ -73,7 +73,7 @@ def offpolicy_trainer(
""" """
global_step = 0 global_step = 0
best_epoch, best_reward = -1, -1.0 best_epoch, best_reward = -1, -1.0
stat = {} stat: Dict[str, MovAvg] = {}
start_time = time.time() start_time = time.time()
test_in_train = test_in_train and train_collector.policy == policy test_in_train = test_in_train and train_collector.policy == policy
for epoch in range(1, 1 + max_epoch): for epoch in range(1, 1 + max_epoch):
@ -91,7 +91,7 @@ def offpolicy_trainer(
test_result = test_episode( test_result = test_episode(
policy, test_collector, test_fn, policy, test_collector, test_fn,
epoch, episode_per_test, writer, global_step) epoch, episode_per_test, writer, global_step)
if stop_fn and stop_fn(test_result["rew"]): if stop_fn(test_result["rew"]):
if save_fn: if save_fn:
save_fn(policy) save_fn(policy)
for k in result.keys(): for k in result.keys():

View File

@ -73,7 +73,7 @@ def onpolicy_trainer(
""" """
global_step = 0 global_step = 0
best_epoch, best_reward = -1, -1.0 best_epoch, best_reward = -1, -1.0
stat = {} stat: Dict[str, MovAvg] = {}
start_time = time.time() start_time = time.time()
test_in_train = test_in_train and train_collector.policy == policy test_in_train = test_in_train and train_collector.policy == policy
for epoch in range(1, 1 + max_epoch): for epoch in range(1, 1 + max_epoch):
@ -91,7 +91,7 @@ def onpolicy_trainer(
test_result = test_episode( test_result = test_episode(
policy, test_collector, test_fn, policy, test_collector, test_fn,
epoch, episode_per_test, writer, global_step) epoch, episode_per_test, writer, global_step)
if stop_fn and stop_fn(test_result["rew"]): if stop_fn(test_result["rew"]):
if save_fn: if save_fn:
save_fn(policy) save_fn(policy)
for k in result.keys(): for k in result.keys():
@ -109,9 +109,9 @@ def onpolicy_trainer(
batch_size=batch_size, repeat=repeat_per_collect) batch_size=batch_size, repeat=repeat_per_collect)
train_collector.reset_buffer() train_collector.reset_buffer()
step = 1 step = 1
for k in losses.keys(): for v in losses.values():
if isinstance(losses[k], list): if isinstance(v, list):
step = max(step, len(losses[k])) step = max(step, len(v))
global_step += step * collect_per_step global_step += step * collect_per_step
for k in result.keys(): for k in result.keys():
data[k] = f"{result[k]:.2f}" data[k] = f"{result[k]:.2f}"

View File

@ -22,7 +22,7 @@ def test_episode(
policy.eval() policy.eval()
if test_fn: if test_fn:
test_fn(epoch) test_fn(epoch)
if collector.get_env_num() > 1 and np.isscalar(n_episode): if collector.get_env_num() > 1 and isinstance(n_episode, int):
n = collector.get_env_num() n = collector.get_env_num()
n_ = np.zeros(n) + n_episode // n n_ = np.zeros(n) + n_episode // n
n_[:n_episode % n] += 1 n_[:n_episode % n] += 1

View File

@ -1,7 +1,7 @@
import torch import torch
import numpy as np import numpy as np
from numbers import Number from numbers import Number
from typing import Union from typing import List, Union
from tianshou.data import to_numpy from tianshou.data import to_numpy
@ -28,7 +28,7 @@ class MovAvg(object):
def __init__(self, size: int = 100) -> None: def __init__(self, size: int = 100) -> None:
super().__init__() super().__init__()
self.size = size self.size = size
self.cache = [] self.cache: List[Union[Number, np.number]] = []
self.banned = [np.inf, np.nan, -np.inf] self.banned = [np.inf, np.nan, -np.inf]
def add( def add(

View File

@ -12,7 +12,7 @@ def miniblock(
norm_layer: Optional[Callable[[int], nn.modules.Module]], norm_layer: Optional[Callable[[int], nn.modules.Module]],
) -> List[nn.modules.Module]: ) -> List[nn.modules.Module]:
"""Construct a miniblock with given input/output-size and norm layer.""" """Construct a miniblock with given input/output-size and norm layer."""
ret = [nn.Linear(inp, oup)] ret: List[nn.modules.Module] = [nn.Linear(inp, oup)]
if norm_layer is not None: if norm_layer is not None:
ret += [norm_layer(oup)] ret += [norm_layer(oup)]
ret += [nn.ReLU(inplace=True)] ret += [nn.ReLU(inplace=True)]
@ -54,36 +54,33 @@ class Net(nn.Module):
if concat: if concat:
input_size += np.prod(action_shape) input_size += np.prod(action_shape)
self.model = miniblock(input_size, hidden_layer_size, norm_layer) model = miniblock(input_size, hidden_layer_size, norm_layer)
for i in range(layer_num): for i in range(layer_num):
self.model += miniblock(hidden_layer_size, model += miniblock(
hidden_layer_size, norm_layer) hidden_layer_size, hidden_layer_size, norm_layer)
if self.dueling is None: if dueling is None:
if action_shape and not concat: if action_shape and not concat:
self.model += [nn.Linear(hidden_layer_size, model += [nn.Linear(hidden_layer_size, np.prod(action_shape))]
np.prod(action_shape))]
else: # dueling DQN else: # dueling DQN
assert isinstance(self.dueling, tuple) and len(self.dueling) == 2 q_layer_num, v_layer_num = dueling
Q, V = [], []
q_layer_num, v_layer_num = self.dueling
self.Q, self.V = [], []
for i in range(q_layer_num): for i in range(q_layer_num):
self.Q += miniblock(hidden_layer_size, Q += miniblock(
hidden_layer_size, norm_layer) hidden_layer_size, hidden_layer_size, norm_layer)
for i in range(v_layer_num): for i in range(v_layer_num):
self.V += miniblock(hidden_layer_size, V += miniblock(
hidden_layer_size, norm_layer) hidden_layer_size, hidden_layer_size, norm_layer)
if action_shape and not concat: if action_shape and not concat:
self.Q += [nn.Linear(hidden_layer_size, np.prod(action_shape))] Q += [nn.Linear(hidden_layer_size, np.prod(action_shape))]
self.V += [nn.Linear(hidden_layer_size, 1)] V += [nn.Linear(hidden_layer_size, 1)]
self.Q = nn.Sequential(*self.Q) self.Q = nn.Sequential(*Q)
self.V = nn.Sequential(*self.V) self.V = nn.Sequential(*V)
self.model = nn.Sequential(*self.model) self.model = nn.Sequential(*model)
def forward( def forward(
self, self,

View File

@ -66,7 +66,7 @@ class Critic(nn.Module):
s = to_torch(s, device=self.device, dtype=torch.float32) s = to_torch(s, device=self.device, dtype=torch.float32)
s = s.flatten(1) s = s.flatten(1)
if a is not None: if a is not None:
a = to_torch(a, device=self.device, dtype=torch.float32) a = to_torch_as(a, s)
a = a.flatten(1) a = a.flatten(1)
s = torch.cat([s, a], dim=1) s = torch.cat([s, a], dim=1)
logits, h = self.preprocess(s) logits, h = self.preprocess(s)

View File

@ -4,6 +4,8 @@ from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
from typing import Any, Dict, Tuple, Union, Optional, Sequence from typing import Any, Dict, Tuple, Union, Optional, Sequence
from tianshou.data import to_torch
class Actor(nn.Module): class Actor(nn.Module):
"""Simple actor network with MLP. """Simple actor network with MLP.
@ -118,5 +120,5 @@ class DQN(nn.Module):
) -> Tuple[torch.Tensor, Any]: ) -> Tuple[torch.Tensor, Any]:
r"""Mapping: x -> Q(x, \*).""" r"""Mapping: x -> Q(x, \*)."""
if not isinstance(x, torch.Tensor): if not isinstance(x, torch.Tensor):
x = torch.tensor(x, device=self.device, dtype=torch.float32) x = to_torch(x, device=self.device, dtype=torch.float32)
return self.net(x), state return self.net(x), state