type check in unit test (#200)
Fix #195: Add mypy test in .github/workflows/docs_and_lint.yml. Also remove the out-of-the-date api
This commit is contained in:
parent
c91def6cbc
commit
b284ace102
5
.github/workflows/lint_and_docs.yml
vendored
5
.github/workflows/lint_and_docs.yml
vendored
@ -1,4 +1,4 @@
|
||||
name: PEP8 and Docs Check
|
||||
name: PEP8, Types and Docs Check
|
||||
|
||||
on: [push, pull_request]
|
||||
|
||||
@ -20,6 +20,9 @@ jobs:
|
||||
- name: Lint with flake8
|
||||
run: |
|
||||
flake8 . --count --show-source --statistics
|
||||
- name: Type check
|
||||
run: |
|
||||
mypy
|
||||
- name: Documentation test
|
||||
run: |
|
||||
pydocstyle tianshou
|
||||
|
||||
@ -28,6 +28,16 @@ We follow PEP8 python code style. To check, in the main directory, run:
|
||||
$ flake8 . --count --show-source --statistics
|
||||
|
||||
|
||||
Type Check
|
||||
----------
|
||||
|
||||
We use `mypy <https://github.com/python/mypy/>`_ to check the type annotations. To check, in the main directory, run:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ mypy
|
||||
|
||||
|
||||
Test Locally
|
||||
------------
|
||||
|
||||
|
||||
@ -31,6 +31,7 @@ Here is Tianshou's other features:
|
||||
* Support :ref:`customize_training`
|
||||
* Support n-step returns estimation :meth:`~tianshou.policy.BasePolicy.compute_nstep_return` and prioritized experience replay :class:`~tianshou.data.PrioritizedReplayBuffer` for all Q-learning based algorithms; GAE, nstep and PER are very fast thanks to numba jit function and vectorized numpy operation
|
||||
* Support :doc:`/tutorials/tictactoe`
|
||||
* Comprehensive `unit tests <https://github.com/thu-ml/tianshou/actions>`_, including functional checking, RL pipeline checking, documentation checking, PEP8 code-style checking, and type checking
|
||||
|
||||
中文文档位于 `https://tianshou.readthedocs.io/zh/latest/ <https://tianshou.readthedocs.io/zh/latest/>`_
|
||||
|
||||
|
||||
23
setup.cfg
23
setup.cfg
@ -1,3 +1,26 @@
|
||||
[mypy]
|
||||
files = tianshou/**/*.py
|
||||
allow_redefinition = True
|
||||
check_untyped_defs = True
|
||||
disallow_incomplete_defs = True
|
||||
disallow_untyped_defs = True
|
||||
ignore_missing_imports = True
|
||||
no_implicit_optional = True
|
||||
pretty = True
|
||||
show_error_codes = True
|
||||
show_error_context = True
|
||||
show_traceback = True
|
||||
strict_equality = True
|
||||
strict_optional = True
|
||||
warn_no_return = True
|
||||
warn_redundant_casts = True
|
||||
warn_unreachable = True
|
||||
warn_unused_configs = True
|
||||
warn_unused_ignores = True
|
||||
|
||||
[mypy-tianshou.utils.net.*]
|
||||
ignore_errors = True
|
||||
|
||||
[pydocstyle]
|
||||
ignore = D100,D102,D104,D105,D107,D203,D213,D401,D402
|
||||
|
||||
|
||||
@ -100,11 +100,6 @@ def test_collect_ep(data):
|
||||
data["collector"].collect(n_episode=10)
|
||||
|
||||
|
||||
def test_sample(data):
|
||||
for _ in range(5000):
|
||||
data["collector"].sample(256)
|
||||
|
||||
|
||||
def test_init_vec_env(data):
|
||||
for _ in range(5000):
|
||||
Collector(data["policy"], data["env_vec"], data["buffer"])
|
||||
@ -125,11 +120,6 @@ def test_collect_vec_env_ep(data):
|
||||
data["collector_vec"].collect(n_episode=10)
|
||||
|
||||
|
||||
def test_sample_vec_env(data):
|
||||
for _ in range(5000):
|
||||
data["collector_vec"].sample(256)
|
||||
|
||||
|
||||
def test_init_subproc_env(data):
|
||||
for _ in range(5000):
|
||||
Collector(data["policy"], data["env_subproc_init"], data["buffer"])
|
||||
@ -150,10 +140,5 @@ def test_collect_subproc_env_ep(data):
|
||||
data["collector_subproc"].collect(n_episode=10)
|
||||
|
||||
|
||||
def test_sample_subproc_env(data):
|
||||
for _ in range(5000):
|
||||
data["collector_subproc"].sample(256)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main(["-s", "-k collector_profile", "--durations=0", "-v"])
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from tianshou import data, env, utils, policy, trainer, exploration
|
||||
|
||||
|
||||
__version__ = "0.2.7"
|
||||
__version__ = "0.3.0rc0"
|
||||
|
||||
__all__ = [
|
||||
"env",
|
||||
|
||||
@ -6,7 +6,7 @@ from copy import deepcopy
|
||||
from numbers import Number
|
||||
from collections.abc import Collection
|
||||
from typing import Any, List, Dict, Union, Iterator, Optional, Iterable, \
|
||||
Sequence, KeysView, ValuesView, ItemsView
|
||||
Sequence
|
||||
|
||||
# Disable pickle warning related to torch, since it has been removed
|
||||
# on torch master branch. See Pull Request #39003 for details:
|
||||
@ -144,7 +144,7 @@ def _parse_value(v: Any) -> Optional[Union["Batch", np.ndarray, torch.Tensor]]:
|
||||
if not isinstance(v, np.ndarray) and isinstance(v, Collection) and \
|
||||
len(v) > 0 and all(isinstance(e, torch.Tensor) for e in v):
|
||||
try:
|
||||
return torch.stack(v)
|
||||
return torch.stack(v) # type: ignore
|
||||
except RuntimeError as e:
|
||||
raise TypeError("Batch does not support non-stackable iterable"
|
||||
" of torch.Tensor as unique value yet.") from e
|
||||
@ -191,12 +191,20 @@ class Batch:
|
||||
elif _is_batch_set(batch_dict):
|
||||
self.stack_(batch_dict)
|
||||
if len(kwargs) > 0:
|
||||
self.__init__(kwargs, copy=copy)
|
||||
self.__init__(kwargs, copy=copy) # type: ignore
|
||||
|
||||
def __setattr__(self, key: str, value: Any) -> None:
|
||||
"""Set self.key = value."""
|
||||
self.__dict__[key] = _parse_value(value)
|
||||
|
||||
def __getattr__(self, key: str) -> Any:
|
||||
"""Return self.key. The "Any" return type is needed for mypy."""
|
||||
return getattr(self.__dict__, key)
|
||||
|
||||
def __contains__(self, key: str) -> bool:
|
||||
"""Return key in self."""
|
||||
return key in self.__dict__
|
||||
|
||||
def __getstate__(self) -> Dict[str, Any]:
|
||||
"""Pickling interface.
|
||||
|
||||
@ -215,11 +223,11 @@ class Batch:
|
||||
At this point, self is an empty Batch instance that has not been
|
||||
initialized, so it can safely be initialized by the pickle state.
|
||||
"""
|
||||
self.__init__(**state)
|
||||
self.__init__(**state) # type: ignore
|
||||
|
||||
def __getitem__(
|
||||
self, index: Union[str, slice, int, np.integer, np.ndarray, List[int]]
|
||||
) -> Union["Batch", np.ndarray, torch.Tensor]:
|
||||
) -> Any:
|
||||
"""Return self[index]."""
|
||||
if isinstance(index, str):
|
||||
return self.__dict__[index]
|
||||
@ -245,7 +253,7 @@ class Batch:
|
||||
if isinstance(index, str):
|
||||
self.__dict__[index] = value
|
||||
return
|
||||
if isinstance(value, (np.ndarray, torch.Tensor)):
|
||||
if not isinstance(value, Batch):
|
||||
raise ValueError("Batch does not supported tensor assignment. "
|
||||
"Use a compatible Batch or dict instead.")
|
||||
if not set(value.keys()).issubset(self.__dict__.keys()):
|
||||
@ -330,30 +338,6 @@ class Batch:
|
||||
s = self.__class__.__name__ + "()"
|
||||
return s
|
||||
|
||||
def __contains__(self, key: str) -> bool:
|
||||
"""Return key in self."""
|
||||
return key in self.__dict__
|
||||
|
||||
def keys(self) -> KeysView[str]:
|
||||
"""Return self.keys()."""
|
||||
return self.__dict__.keys()
|
||||
|
||||
def values(self) -> ValuesView[Any]:
|
||||
"""Return self.values()."""
|
||||
return self.__dict__.values()
|
||||
|
||||
def items(self) -> ItemsView[str, Any]:
|
||||
"""Return self.items()."""
|
||||
return self.__dict__.items()
|
||||
|
||||
def get(self, k: str, d: Optional[Any] = None) -> Any:
|
||||
"""Return self[k] if k in self else d. d defaults to None."""
|
||||
return self.__dict__.get(k, d)
|
||||
|
||||
def pop(self, k: str, d: Optional[Any] = None) -> Any:
|
||||
"""Return & remove self[k] if k in self else d. d defaults to None."""
|
||||
return self.__dict__.pop(k, d)
|
||||
|
||||
def to_numpy(self) -> None:
|
||||
"""Change all torch.Tensor to numpy.ndarray in-place."""
|
||||
for k, v in self.items():
|
||||
@ -375,7 +359,6 @@ class Batch:
|
||||
if isinstance(v, torch.Tensor):
|
||||
if dtype is not None and v.dtype != dtype or \
|
||||
v.device.type != device.type or \
|
||||
device.index is not None and \
|
||||
device.index != v.device.index:
|
||||
if dtype is not None:
|
||||
v = v.type(dtype)
|
||||
@ -517,7 +500,7 @@ class Batch:
|
||||
return
|
||||
batches = [x if isinstance(x, Batch) else Batch(x) for x in batches]
|
||||
if not self.is_empty():
|
||||
batches = [self] + list(batches)
|
||||
batches = [self] + batches
|
||||
# collect non-empty keys
|
||||
keys_map = [
|
||||
set(k for k, v in batch.items()
|
||||
@ -672,8 +655,8 @@ class Batch:
|
||||
for v in self.__dict__.values():
|
||||
if isinstance(v, Batch) and v.is_empty(recurse=True):
|
||||
continue
|
||||
elif hasattr(v, "__len__") and (not isinstance(
|
||||
v, (np.ndarray, torch.Tensor)) or v.ndim > 0
|
||||
elif hasattr(v, "__len__") and (
|
||||
isinstance(v, Batch) or v.ndim > 0
|
||||
):
|
||||
r.append(len(v))
|
||||
else:
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from numbers import Number
|
||||
from typing import Any, Dict, Tuple, Union, Optional
|
||||
from typing import Any, Dict, List, Tuple, Union, Optional
|
||||
|
||||
from tianshou.data import Batch, SegmentTree, to_numpy
|
||||
from tianshou.data.batch import _create_value
|
||||
@ -138,7 +138,7 @@ class ReplayBuffer:
|
||||
self._indices = np.arange(size)
|
||||
self.stack_num = stack_num
|
||||
self._avail = sample_avail and stack_num > 1
|
||||
self._avail_index = []
|
||||
self._avail_index: List[int] = []
|
||||
self._save_s_ = not ignore_obs_next
|
||||
self._last_obs = save_only_last_obs
|
||||
self._index = 0
|
||||
@ -175,12 +175,12 @@ class ReplayBuffer:
|
||||
except KeyError:
|
||||
self._meta.__dict__[name] = _create_value(inst, self._maxsize)
|
||||
value = self._meta.__dict__[name]
|
||||
if isinstance(inst, (torch.Tensor, np.ndarray)) \
|
||||
and inst.shape != value.shape[1:]:
|
||||
raise ValueError(
|
||||
"Cannot add data to a buffer with different shape, with key "
|
||||
f"{name}, expect {value.shape[1:]}, given {inst.shape}."
|
||||
)
|
||||
if isinstance(inst, (torch.Tensor, np.ndarray)):
|
||||
if inst.shape != value.shape[1:]:
|
||||
raise ValueError(
|
||||
"Cannot add data to a buffer with different shape with key"
|
||||
f" {name}, expect {value.shape[1:]}, given {inst.shape}."
|
||||
)
|
||||
try:
|
||||
value[self._index] = inst
|
||||
except KeyError:
|
||||
@ -205,7 +205,7 @@ class ReplayBuffer:
|
||||
stack_num_orig = buffer.stack_num
|
||||
buffer.stack_num = 1
|
||||
while True:
|
||||
self.add(**buffer[i])
|
||||
self.add(**buffer[i]) # type: ignore
|
||||
i = (i + 1) % len(buffer)
|
||||
if i == begin:
|
||||
break
|
||||
@ -323,7 +323,7 @@ class ReplayBuffer:
|
||||
try:
|
||||
if stack_num == 1:
|
||||
return val[indice]
|
||||
stack = []
|
||||
stack: List[Any] = []
|
||||
for _ in range(stack_num):
|
||||
stack = [val[indice]] + stack
|
||||
pre_indice = np.asarray(indice - 1)
|
||||
|
||||
@ -212,10 +212,8 @@ class Collector(object):
|
||||
finished_env_ids = []
|
||||
reward_total = 0.0
|
||||
whole_data = Batch()
|
||||
list_n_episode = False
|
||||
if n_episode is not None and not np.isscalar(n_episode):
|
||||
if isinstance(n_episode, list):
|
||||
assert len(n_episode) == self.get_env_num()
|
||||
list_n_episode = True
|
||||
finished_env_ids = [
|
||||
i for i in self._ready_env_ids if n_episode[i] <= 0]
|
||||
self._ready_env_ids = np.array(
|
||||
@ -266,7 +264,8 @@ class Collector(object):
|
||||
self.data.policy._state = self.data.state
|
||||
|
||||
self.data.act = to_numpy(result.act)
|
||||
if self._action_noise is not None: # noqa
|
||||
if self._action_noise is not None:
|
||||
assert isinstance(self.data.act, np.ndarray)
|
||||
self.data.act += self._action_noise(self.data.act.shape)
|
||||
|
||||
# step in env
|
||||
@ -291,7 +290,7 @@ class Collector(object):
|
||||
|
||||
# add data into the buffer
|
||||
if self.preprocess_fn:
|
||||
result = self.preprocess_fn(**self.data)
|
||||
result = self.preprocess_fn(**self.data) # type: ignore
|
||||
self.data.update(result)
|
||||
|
||||
for j, i in enumerate(self._ready_env_ids):
|
||||
@ -305,14 +304,14 @@ class Collector(object):
|
||||
self._cached_buf[i].add(**self.data[j])
|
||||
|
||||
if done[j]:
|
||||
if not (list_n_episode and
|
||||
episode_count[i] >= n_episode[i]):
|
||||
if not (isinstance(n_episode, list)
|
||||
and episode_count[i] >= n_episode[i]):
|
||||
episode_count[i] += 1
|
||||
reward_total += np.sum(self._cached_buf[i].rew, axis=0)
|
||||
step_count += len(self._cached_buf[i])
|
||||
if self.buffer is not None:
|
||||
self.buffer.update(self._cached_buf[i])
|
||||
if list_n_episode and \
|
||||
if isinstance(n_episode, list) and \
|
||||
episode_count[i] >= n_episode[i]:
|
||||
# env i has collected enough data, it has finished
|
||||
finished_env_ids.append(i)
|
||||
@ -324,10 +323,9 @@ class Collector(object):
|
||||
env_ind_global = self._ready_env_ids[env_ind_local]
|
||||
obs_reset = self.env.reset(env_ind_global)
|
||||
if self.preprocess_fn:
|
||||
obs_next[env_ind_local] = self.preprocess_fn(
|
||||
obs_reset = self.preprocess_fn(
|
||||
obs=obs_reset).get("obs", obs_reset)
|
||||
else:
|
||||
obs_next[env_ind_local] = obs_reset
|
||||
obs_next[env_ind_local] = obs_reset
|
||||
self.data.obs = obs_next
|
||||
if is_async:
|
||||
# set data back
|
||||
@ -362,7 +360,7 @@ class Collector(object):
|
||||
# average reward across the number of episodes
|
||||
reward_avg = reward_total / episode_count
|
||||
if np.asanyarray(reward_avg).size > 1: # non-scalar reward_avg
|
||||
reward_avg = self._rew_metric(reward_avg)
|
||||
reward_avg = self._rew_metric(reward_avg) # type: ignore
|
||||
return {
|
||||
"n/ep": episode_count,
|
||||
"n/st": step_count,
|
||||
@ -372,30 +370,6 @@ class Collector(object):
|
||||
"len": step_count / episode_count,
|
||||
}
|
||||
|
||||
def sample(self, batch_size: int) -> Batch:
|
||||
"""Sample a data batch from the internal replay buffer.
|
||||
|
||||
It will call :meth:`~tianshou.policy.BasePolicy.process_fn` before
|
||||
returning the final batch data.
|
||||
|
||||
:param int batch_size: ``0`` means it will extract all the data from
|
||||
the buffer, otherwise it will extract the data with the given
|
||||
batch_size.
|
||||
"""
|
||||
warnings.warn(
|
||||
"Collector.sample is deprecated and will cause error if you use "
|
||||
"prioritized experience replay! Collector.sample will be removed "
|
||||
"upon version 0.3. Use policy.update instead!", Warning)
|
||||
assert self.buffer is not None, "Cannot get sample from empty buffer!"
|
||||
batch_data, indice = self.buffer.sample(batch_size)
|
||||
batch_data = self.process_fn(batch_data, self.buffer, indice)
|
||||
return batch_data
|
||||
|
||||
def close(self) -> None:
|
||||
warnings.warn(
|
||||
"Collector.close is deprecated and will be removed upon version "
|
||||
"0.3.", Warning)
|
||||
|
||||
|
||||
def _batch_set_item(
|
||||
source: Batch, indices: np.ndarray, target: Batch, size: int
|
||||
|
||||
@ -45,14 +45,14 @@ def to_torch(
|
||||
if isinstance(x, np.ndarray) and issubclass(
|
||||
x.dtype.type, (np.bool_, np.number)
|
||||
): # most often case
|
||||
x = torch.from_numpy(x).to(device)
|
||||
x = torch.from_numpy(x).to(device) # type: ignore
|
||||
if dtype is not None:
|
||||
x = x.type(dtype)
|
||||
return x
|
||||
elif isinstance(x, torch.Tensor): # second often case
|
||||
if dtype is not None:
|
||||
x = x.type(dtype)
|
||||
return x.to(device)
|
||||
return x.to(device) # type: ignore
|
||||
elif isinstance(x, (np.number, np.bool_, Number)):
|
||||
return to_torch(np.asanyarray(x), dtype, device)
|
||||
elif isinstance(x, dict):
|
||||
|
||||
3
tianshou/env/__init__.py
vendored
3
tianshou/env/__init__.py
vendored
@ -1,11 +1,10 @@
|
||||
from tianshou.env.venvs import BaseVectorEnv, DummyVectorEnv, VectorEnv, \
|
||||
from tianshou.env.venvs import BaseVectorEnv, DummyVectorEnv, \
|
||||
SubprocVectorEnv, ShmemVectorEnv, RayVectorEnv
|
||||
from tianshou.env.maenv import MultiAgentEnv
|
||||
|
||||
__all__ = [
|
||||
"BaseVectorEnv",
|
||||
"DummyVectorEnv",
|
||||
"VectorEnv", # TODO: remove in later version
|
||||
"SubprocVectorEnv",
|
||||
"ShmemVectorEnv",
|
||||
"RayVectorEnv",
|
||||
|
||||
29
tianshou/env/venvs.py
vendored
29
tianshou/env/venvs.py
vendored
@ -1,5 +1,4 @@
|
||||
import gym
|
||||
import warnings
|
||||
import numpy as np
|
||||
from typing import Any, List, Union, Optional, Callable
|
||||
|
||||
@ -84,12 +83,12 @@ class BaseVectorEnv(gym.Env):
|
||||
self.timeout is None or self.timeout > 0
|
||||
), f"timeout is {timeout}, it should be positive if provided!"
|
||||
self.is_async = self.wait_num != len(env_fns) or timeout is not None
|
||||
self.waiting_conn = []
|
||||
self.waiting_conn: List[EnvWorker] = []
|
||||
# environments in self.ready_id is actually ready
|
||||
# but environments in self.waiting_id are just waiting when checked,
|
||||
# and they may be ready now, but this is not known until we check it
|
||||
# in the step() function
|
||||
self.waiting_id = []
|
||||
self.waiting_id: List[int] = []
|
||||
# all environments are ready in the beginning
|
||||
self.ready_id = list(range(self.env_num))
|
||||
self.is_closed = False
|
||||
@ -216,10 +215,11 @@ class BaseVectorEnv(gym.Env):
|
||||
self.waiting_conn.append(self.workers[env_id])
|
||||
self.waiting_id.append(env_id)
|
||||
self.ready_id = [x for x in self.ready_id if x not in id]
|
||||
ready_conns, result = [], []
|
||||
ready_conns: List[EnvWorker] = []
|
||||
while not ready_conns:
|
||||
ready_conns = self.worker_class.wait(
|
||||
self.waiting_conn, self.wait_num, self.timeout)
|
||||
result = []
|
||||
for conn in ready_conns:
|
||||
waiting_index = self.waiting_conn.index(conn)
|
||||
self.waiting_conn.pop(waiting_index)
|
||||
@ -243,11 +243,14 @@ class BaseVectorEnv(gym.Env):
|
||||
which a reproducer pass to "seed".
|
||||
"""
|
||||
self._assert_is_not_closed()
|
||||
seed_list: Union[List[None], List[int]]
|
||||
if seed is None:
|
||||
seed = [seed] * self.env_num
|
||||
elif np.isscalar(seed):
|
||||
seed = [seed + i for i in range(self.env_num)]
|
||||
return [w.seed(s) for w, s in zip(self.workers, seed)]
|
||||
seed_list = [seed] * self.env_num
|
||||
elif isinstance(seed, int):
|
||||
seed_list = [seed + i for i in range(self.env_num)]
|
||||
else:
|
||||
seed_list = seed
|
||||
return [w.seed(s) for w, s in zip(self.workers, seed_list)]
|
||||
|
||||
def render(self, **kwargs: Any) -> List[Any]:
|
||||
"""Render all of the environments."""
|
||||
@ -295,16 +298,6 @@ class DummyVectorEnv(BaseVectorEnv):
|
||||
env_fns, DummyEnvWorker, wait_num=wait_num, timeout=timeout)
|
||||
|
||||
|
||||
class VectorEnv(DummyVectorEnv):
|
||||
"""VectorEnv is renamed to DummyVectorEnv."""
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
warnings.warn(
|
||||
"VectorEnv is renamed to DummyVectorEnv, and will be removed in "
|
||||
"0.3. Use DummyVectorEnv instead!", Warning)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class SubprocVectorEnv(BaseVectorEnv):
|
||||
"""Vectorized environment wrapper based on subprocess.
|
||||
|
||||
|
||||
2
tianshou/env/worker/base.py
vendored
2
tianshou/env/worker/base.py
vendored
@ -10,7 +10,7 @@ class EnvWorker(ABC):
|
||||
def __init__(self, env_fn: Callable[[], gym.Env]) -> None:
|
||||
self._env_fn = env_fn
|
||||
self.is_closed = False
|
||||
self.result = (None, None, None, None)
|
||||
self.result: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]
|
||||
|
||||
@abstractmethod
|
||||
def __getattr__(self, key: str) -> Any:
|
||||
|
||||
2
tianshou/env/worker/dummy.py
vendored
2
tianshou/env/worker/dummy.py
vendored
@ -19,7 +19,7 @@ class DummyEnvWorker(EnvWorker):
|
||||
return self.env.reset()
|
||||
|
||||
@staticmethod
|
||||
def wait(
|
||||
def wait( # type: ignore
|
||||
workers: List["DummyEnvWorker"],
|
||||
wait_num: int,
|
||||
timeout: Optional[float] = None,
|
||||
|
||||
2
tianshou/env/worker/ray.py
vendored
2
tianshou/env/worker/ray.py
vendored
@ -24,7 +24,7 @@ class RayEnvWorker(EnvWorker):
|
||||
return ray.get(self.env.reset.remote())
|
||||
|
||||
@staticmethod
|
||||
def wait(
|
||||
def wait( # type: ignore
|
||||
workers: List["RayEnvWorker"],
|
||||
wait_num: int,
|
||||
timeout: Optional[float] = None,
|
||||
|
||||
117
tianshou/env/worker/subproc.py
vendored
117
tianshou/env/worker/subproc.py
vendored
@ -11,22 +11,71 @@ from tianshou.env.worker import EnvWorker
|
||||
from tianshou.env.utils import CloudpickleWrapper
|
||||
|
||||
|
||||
_NP_TO_CT = {
|
||||
np.bool: ctypes.c_bool,
|
||||
np.bool_: ctypes.c_bool,
|
||||
np.uint8: ctypes.c_uint8,
|
||||
np.uint16: ctypes.c_uint16,
|
||||
np.uint32: ctypes.c_uint32,
|
||||
np.uint64: ctypes.c_uint64,
|
||||
np.int8: ctypes.c_int8,
|
||||
np.int16: ctypes.c_int16,
|
||||
np.int32: ctypes.c_int32,
|
||||
np.int64: ctypes.c_int64,
|
||||
np.float32: ctypes.c_float,
|
||||
np.float64: ctypes.c_double,
|
||||
}
|
||||
|
||||
|
||||
class ShArray:
|
||||
"""Wrapper of multiprocessing Array."""
|
||||
|
||||
def __init__(self, dtype: np.generic, shape: Tuple[int]) -> None:
|
||||
self.arr = Array(
|
||||
_NP_TO_CT[dtype.type], # type: ignore
|
||||
int(np.prod(shape)),
|
||||
)
|
||||
self.dtype = dtype
|
||||
self.shape = shape
|
||||
|
||||
def save(self, ndarray: np.ndarray) -> None:
|
||||
assert isinstance(ndarray, np.ndarray)
|
||||
dst = self.arr.get_obj()
|
||||
dst_np = np.frombuffer(dst, dtype=self.dtype).reshape(self.shape)
|
||||
np.copyto(dst_np, ndarray)
|
||||
|
||||
def get(self) -> np.ndarray:
|
||||
obj = self.arr.get_obj()
|
||||
return np.frombuffer(obj, dtype=self.dtype).reshape(self.shape)
|
||||
|
||||
|
||||
def _setup_buf(space: gym.Space) -> Union[dict, tuple, ShArray]:
|
||||
if isinstance(space, gym.spaces.Dict):
|
||||
assert isinstance(space.spaces, OrderedDict)
|
||||
return {k: _setup_buf(v) for k, v in space.spaces.items()}
|
||||
elif isinstance(space, gym.spaces.Tuple):
|
||||
assert isinstance(space.spaces, tuple)
|
||||
return tuple([_setup_buf(t) for t in space.spaces])
|
||||
else:
|
||||
return ShArray(space.dtype, space.shape)
|
||||
|
||||
|
||||
def _worker(
|
||||
parent: connection.Connection,
|
||||
p: connection.Connection,
|
||||
env_fn_wrapper: CloudpickleWrapper,
|
||||
obs_bufs: Optional[Union[dict, tuple, "ShArray"]] = None,
|
||||
obs_bufs: Optional[Union[dict, tuple, ShArray]] = None,
|
||||
) -> None:
|
||||
def _encode_obs(
|
||||
obs: Union[dict, tuple, np.ndarray],
|
||||
buffer: Union[dict, tuple, ShArray],
|
||||
) -> None:
|
||||
if isinstance(obs, np.ndarray):
|
||||
if isinstance(obs, np.ndarray) and isinstance(buffer, ShArray):
|
||||
buffer.save(obs)
|
||||
elif isinstance(obs, tuple):
|
||||
elif isinstance(obs, tuple) and isinstance(buffer, tuple):
|
||||
for o, b in zip(obs, buffer):
|
||||
_encode_obs(o, b)
|
||||
elif isinstance(obs, dict):
|
||||
elif isinstance(obs, dict) and isinstance(buffer, dict):
|
||||
for k in obs.keys():
|
||||
_encode_obs(obs[k], buffer[k])
|
||||
return None
|
||||
@ -69,52 +118,6 @@ def _worker(
|
||||
p.close()
|
||||
|
||||
|
||||
_NP_TO_CT = {
|
||||
np.bool: ctypes.c_bool,
|
||||
np.bool_: ctypes.c_bool,
|
||||
np.uint8: ctypes.c_uint8,
|
||||
np.uint16: ctypes.c_uint16,
|
||||
np.uint32: ctypes.c_uint32,
|
||||
np.uint64: ctypes.c_uint64,
|
||||
np.int8: ctypes.c_int8,
|
||||
np.int16: ctypes.c_int16,
|
||||
np.int32: ctypes.c_int32,
|
||||
np.int64: ctypes.c_int64,
|
||||
np.float32: ctypes.c_float,
|
||||
np.float64: ctypes.c_double,
|
||||
}
|
||||
|
||||
|
||||
class ShArray:
|
||||
"""Wrapper of multiprocessing Array."""
|
||||
|
||||
def __init__(self, dtype: np.generic, shape: Tuple[int]) -> None:
|
||||
self.arr = Array(_NP_TO_CT[dtype.type], int(np.prod(shape)))
|
||||
self.dtype = dtype
|
||||
self.shape = shape
|
||||
|
||||
def save(self, ndarray: np.ndarray) -> None:
|
||||
assert isinstance(ndarray, np.ndarray)
|
||||
dst = self.arr.get_obj()
|
||||
dst_np = np.frombuffer(dst, dtype=self.dtype).reshape(self.shape)
|
||||
np.copyto(dst_np, ndarray)
|
||||
|
||||
def get(self) -> np.ndarray:
|
||||
obj = self.arr.get_obj()
|
||||
return np.frombuffer(obj, dtype=self.dtype).reshape(self.shape)
|
||||
|
||||
|
||||
def _setup_buf(space: gym.Space) -> Union[dict, tuple, ShArray]:
|
||||
if isinstance(space, gym.spaces.Dict):
|
||||
assert isinstance(space.spaces, OrderedDict)
|
||||
return {k: _setup_buf(v) for k, v in space.spaces.items()}
|
||||
elif isinstance(space, gym.spaces.Tuple):
|
||||
assert isinstance(space.spaces, tuple)
|
||||
return tuple([_setup_buf(t) for t in space.spaces])
|
||||
else:
|
||||
return ShArray(space.dtype, space.shape)
|
||||
|
||||
|
||||
class SubprocEnvWorker(EnvWorker):
|
||||
"""Subprocess worker used in SubprocVectorEnv and ShmemVectorEnv."""
|
||||
|
||||
@ -124,7 +127,7 @@ class SubprocEnvWorker(EnvWorker):
|
||||
super().__init__(env_fn)
|
||||
self.parent_remote, self.child_remote = Pipe()
|
||||
self.share_memory = share_memory
|
||||
self.buffer = None
|
||||
self.buffer: Optional[Union[dict, tuple, ShArray]] = None
|
||||
if self.share_memory:
|
||||
dummy = env_fn()
|
||||
obs_space = dummy.observation_space
|
||||
@ -168,25 +171,23 @@ class SubprocEnvWorker(EnvWorker):
|
||||
return obs
|
||||
|
||||
@staticmethod
|
||||
def wait(
|
||||
def wait( # type: ignore
|
||||
workers: List["SubprocEnvWorker"],
|
||||
wait_num: int,
|
||||
timeout: Optional[float] = None,
|
||||
) -> List["SubprocEnvWorker"]:
|
||||
conns, ready_conns = [x.parent_remote for x in workers], []
|
||||
remain_conns = conns
|
||||
t1 = time.time()
|
||||
remain_conns = conns = [x.parent_remote for x in workers]
|
||||
ready_conns: List[connection.Connection] = []
|
||||
remain_time, t1 = timeout, time.time()
|
||||
while len(remain_conns) > 0 and len(ready_conns) < wait_num:
|
||||
if timeout:
|
||||
remain_time = timeout - (time.time() - t1)
|
||||
if remain_time <= 0:
|
||||
break
|
||||
else:
|
||||
remain_time = timeout
|
||||
# connection.wait hangs if the list is empty
|
||||
new_ready_conns = connection.wait(
|
||||
remain_conns, timeout=remain_time)
|
||||
ready_conns.extend(new_ready_conns)
|
||||
ready_conns.extend(new_ready_conns) # type: ignore
|
||||
remain_conns = [
|
||||
conn for conn in remain_conns if conn not in ready_conns]
|
||||
return [workers[conns.index(con)] for con in ready_conns]
|
||||
|
||||
@ -9,15 +9,15 @@ class BaseNoise(ABC, object):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset to the initial state."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, size: Sequence[int]) -> np.ndarray:
|
||||
"""Generate new noise."""
|
||||
raise NotImplementedError
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset to the initial state."""
|
||||
pass
|
||||
|
||||
|
||||
class GaussianNoise(BaseNoise):
|
||||
"""The vanilla gaussian process, for exploration in DDPG by default."""
|
||||
@ -64,6 +64,10 @@ class OUNoise(BaseNoise):
|
||||
self._x0 = x0
|
||||
self.reset()
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset to the initial state."""
|
||||
self._x = self._x0
|
||||
|
||||
def __call__(
|
||||
self, size: Sequence[int], mu: Optional[float] = None
|
||||
) -> np.ndarray:
|
||||
@ -71,14 +75,11 @@ class OUNoise(BaseNoise):
|
||||
|
||||
Return an numpy array which size is equal to ``size``.
|
||||
"""
|
||||
if self._x is None or self._x.shape != size:
|
||||
if self._x is None or isinstance(
|
||||
self._x, np.ndarray) and self._x.shape != size:
|
||||
self._x = 0.0
|
||||
if mu is None:
|
||||
mu = self._mu
|
||||
r = self._beta * np.random.normal(size=size)
|
||||
self._x = self._x + self._alpha * (mu - self._x) + r
|
||||
return self._x
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset to the initial state."""
|
||||
self._x = self._x0
|
||||
|
||||
@ -190,7 +190,7 @@ class BasePolicy(ABC, nn.Module):
|
||||
array with shape (bsz, ).
|
||||
"""
|
||||
rew = batch.rew
|
||||
v_s_ = np.zeros_like(rew) if v_s_ is None else to_numpy(v_s_).flatten()
|
||||
v_s_ = np.zeros_like(rew) if v_s_ is None else to_numpy(v_s_.flatten())
|
||||
returns = _episodic_return(v_s_, rew, batch.done, gamma, gae_lambda)
|
||||
if rew_norm and not np.isclose(returns.std(), 0.0, 1e-2):
|
||||
returns = (returns - returns.mean()) / returns.std()
|
||||
|
||||
@ -55,11 +55,11 @@ class ImitationPolicy(BasePolicy):
|
||||
if self.mode == "continuous": # regression
|
||||
a = self(batch).act
|
||||
a_ = to_torch(batch.act, dtype=torch.float32, device=a.device)
|
||||
loss = F.mse_loss(a, a_)
|
||||
loss = F.mse_loss(a, a_) # type: ignore
|
||||
elif self.mode == "discrete": # classification
|
||||
a = self(batch).logits
|
||||
a_ = to_torch(batch.act, dtype=torch.long, device=a.device)
|
||||
loss = F.nll_loss(a, a_)
|
||||
loss = F.nll_loss(a, a_) # type: ignore
|
||||
loss.backward()
|
||||
self.optim.step()
|
||||
return {"loss": loss.item()}
|
||||
|
||||
@ -103,11 +103,11 @@ class A2CPolicy(PGPolicy):
|
||||
if isinstance(logits, tuple):
|
||||
dist = self.dist_fn(*logits)
|
||||
else:
|
||||
dist = self.dist_fn(logits)
|
||||
dist = self.dist_fn(logits) # type: ignore
|
||||
act = dist.sample()
|
||||
return Batch(logits=logits, act=act, state=h, dist=dist)
|
||||
|
||||
def learn(
|
||||
def learn( # type: ignore
|
||||
self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
|
||||
) -> Dict[str, List[float]]:
|
||||
losses, actor_losses, vf_losses, ent_losses = [], [], [], []
|
||||
@ -120,7 +120,7 @@ class A2CPolicy(PGPolicy):
|
||||
r = to_torch_as(b.returns, v)
|
||||
log_prob = dist.log_prob(a).reshape(len(r), -1).transpose(0, 1)
|
||||
a_loss = -(log_prob * (r - v).detach()).mean()
|
||||
vf_loss = F.mse_loss(r, v)
|
||||
vf_loss = F.mse_loss(r, v) # type: ignore
|
||||
ent_loss = dist.entropy().mean()
|
||||
loss = a_loss + self._w_vf * vf_loss - self._w_ent * ent_loss
|
||||
loss.backward()
|
||||
|
||||
@ -53,14 +53,16 @@ class DDPGPolicy(BasePolicy):
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
if actor is not None:
|
||||
self.actor, self.actor_old = actor, deepcopy(actor)
|
||||
if actor is not None and actor_optim is not None:
|
||||
self.actor: torch.nn.Module = actor
|
||||
self.actor_old = deepcopy(actor)
|
||||
self.actor_old.eval()
|
||||
self.actor_optim = actor_optim
|
||||
if critic is not None:
|
||||
self.critic, self.critic_old = critic, deepcopy(critic)
|
||||
self.actor_optim: torch.optim.Optimizer = actor_optim
|
||||
if critic is not None and critic_optim is not None:
|
||||
self.critic: torch.nn.Module = critic
|
||||
self.critic_old = deepcopy(critic)
|
||||
self.critic_old.eval()
|
||||
self.critic_optim = critic_optim
|
||||
self.critic_optim: torch.optim.Optimizer = critic_optim
|
||||
assert 0.0 <= tau <= 1.0, "tau should be in [0, 1]"
|
||||
self._tau = tau
|
||||
assert 0.0 <= gamma <= 1.0, "gamma should be in [0, 1]"
|
||||
@ -141,7 +143,7 @@ class DDPGPolicy(BasePolicy):
|
||||
obs = getattr(batch, input)
|
||||
actions, h = model(obs, state=state, info=batch.info)
|
||||
actions += self._action_bias
|
||||
if self.training and explorating:
|
||||
if self._noise and self.training and explorating:
|
||||
actions += to_torch_as(self._noise(actions.shape), actions)
|
||||
actions = actions.clamp(self._range[0], self._range[1])
|
||||
return Batch(act=actions, state=h)
|
||||
|
||||
@ -146,11 +146,10 @@ class DQNPolicy(BasePolicy):
|
||||
obs = getattr(batch, input)
|
||||
obs_ = obs.obs if hasattr(obs, "obs") else obs
|
||||
q, h = model(obs_, state=state, info=batch.info)
|
||||
act = to_numpy(q.max(dim=1)[1])
|
||||
has_mask = hasattr(obs, 'mask')
|
||||
if has_mask:
|
||||
act: np.ndarray = to_numpy(q.max(dim=1)[1])
|
||||
if hasattr(obs, "mask"):
|
||||
# some of actions are masked, they cannot be selected
|
||||
q_ = to_numpy(q)
|
||||
q_: np.ndarray = to_numpy(q)
|
||||
q_[~obs.mask] = -np.inf
|
||||
act = q_.argmax(axis=1)
|
||||
# add eps to act
|
||||
@ -160,7 +159,7 @@ class DQNPolicy(BasePolicy):
|
||||
for i in range(len(q)):
|
||||
if np.random.rand() < eps:
|
||||
q_ = np.random.rand(*q[i].shape)
|
||||
if has_mask:
|
||||
if hasattr(obs, "mask"):
|
||||
q_[~obs.mask[i]] = -np.inf
|
||||
act[i] = q_.argmax()
|
||||
return Batch(logits=q, act=act, state=h)
|
||||
@ -172,7 +171,7 @@ class DQNPolicy(BasePolicy):
|
||||
weight = batch.pop("weight", 1.0)
|
||||
q = self(batch, eps=0.0).logits
|
||||
q = q[np.arange(len(q)), batch.act]
|
||||
r = to_torch_as(batch.returns, q).flatten()
|
||||
r = to_torch_as(batch.returns.flatten(), q)
|
||||
td = r - q
|
||||
loss = (td.pow(2) * weight).mean()
|
||||
batch.weight = td # prio-buffer
|
||||
|
||||
@ -32,7 +32,8 @@ class PGPolicy(BasePolicy):
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.model = model
|
||||
if model is not None:
|
||||
self.model: torch.nn.Module = model
|
||||
self.optim = optim
|
||||
self.dist_fn = dist_fn
|
||||
assert (
|
||||
@ -81,11 +82,11 @@ class PGPolicy(BasePolicy):
|
||||
if isinstance(logits, tuple):
|
||||
dist = self.dist_fn(*logits)
|
||||
else:
|
||||
dist = self.dist_fn(logits)
|
||||
dist = self.dist_fn(logits) # type: ignore
|
||||
act = dist.sample()
|
||||
return Batch(logits=logits, act=act, state=h, dist=dist)
|
||||
|
||||
def learn(
|
||||
def learn( # type: ignore
|
||||
self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
|
||||
) -> Dict[str, List[float]]:
|
||||
losses = []
|
||||
|
||||
@ -137,13 +137,13 @@ class PPOPolicy(PGPolicy):
|
||||
if isinstance(logits, tuple):
|
||||
dist = self.dist_fn(*logits)
|
||||
else:
|
||||
dist = self.dist_fn(logits)
|
||||
dist = self.dist_fn(logits) # type: ignore
|
||||
act = dist.sample()
|
||||
if self._range:
|
||||
act = act.clamp(self._range[0], self._range[1])
|
||||
return Batch(logits=logits, act=act, state=h, dist=dist)
|
||||
|
||||
def learn(
|
||||
def learn( # type: ignore
|
||||
self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
|
||||
) -> Dict[str, List[float]]:
|
||||
losses, clip_losses, vf_losses, ent_losses = [], [], [], []
|
||||
@ -157,8 +157,9 @@ class PPOPolicy(PGPolicy):
|
||||
surr2 = ratio.clamp(1.0 - self._eps_clip,
|
||||
1.0 + self._eps_clip) * b.adv
|
||||
if self._dual_clip:
|
||||
clip_loss = -torch.max(torch.min(surr1, surr2),
|
||||
self._dual_clip * b.adv).mean()
|
||||
clip_loss = -torch.max(
|
||||
torch.min(surr1, surr2), self._dual_clip * b.adv
|
||||
).mean()
|
||||
else:
|
||||
clip_loss = -torch.min(surr1, surr2).mean()
|
||||
clip_losses.append(clip_loss.item())
|
||||
|
||||
@ -107,7 +107,7 @@ class SACPolicy(DDPGPolicy):
|
||||
):
|
||||
o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
|
||||
|
||||
def forward(
|
||||
def forward( # type: ignore
|
||||
self,
|
||||
batch: Batch,
|
||||
state: Optional[Union[dict, Batch, np.ndarray]] = None,
|
||||
@ -193,5 +193,5 @@ class SACPolicy(DDPGPolicy):
|
||||
}
|
||||
if self._is_auto_alpha:
|
||||
result["loss/alpha"] = alpha_loss.item()
|
||||
result["v/alpha"] = self._alpha.item()
|
||||
result["v/alpha"] = self._alpha.item() # type: ignore
|
||||
return result
|
||||
|
||||
@ -73,7 +73,7 @@ def offpolicy_trainer(
|
||||
"""
|
||||
global_step = 0
|
||||
best_epoch, best_reward = -1, -1.0
|
||||
stat = {}
|
||||
stat: Dict[str, MovAvg] = {}
|
||||
start_time = time.time()
|
||||
test_in_train = test_in_train and train_collector.policy == policy
|
||||
for epoch in range(1, 1 + max_epoch):
|
||||
@ -91,7 +91,7 @@ def offpolicy_trainer(
|
||||
test_result = test_episode(
|
||||
policy, test_collector, test_fn,
|
||||
epoch, episode_per_test, writer, global_step)
|
||||
if stop_fn and stop_fn(test_result["rew"]):
|
||||
if stop_fn(test_result["rew"]):
|
||||
if save_fn:
|
||||
save_fn(policy)
|
||||
for k in result.keys():
|
||||
|
||||
@ -73,7 +73,7 @@ def onpolicy_trainer(
|
||||
"""
|
||||
global_step = 0
|
||||
best_epoch, best_reward = -1, -1.0
|
||||
stat = {}
|
||||
stat: Dict[str, MovAvg] = {}
|
||||
start_time = time.time()
|
||||
test_in_train = test_in_train and train_collector.policy == policy
|
||||
for epoch in range(1, 1 + max_epoch):
|
||||
@ -91,7 +91,7 @@ def onpolicy_trainer(
|
||||
test_result = test_episode(
|
||||
policy, test_collector, test_fn,
|
||||
epoch, episode_per_test, writer, global_step)
|
||||
if stop_fn and stop_fn(test_result["rew"]):
|
||||
if stop_fn(test_result["rew"]):
|
||||
if save_fn:
|
||||
save_fn(policy)
|
||||
for k in result.keys():
|
||||
@ -109,9 +109,9 @@ def onpolicy_trainer(
|
||||
batch_size=batch_size, repeat=repeat_per_collect)
|
||||
train_collector.reset_buffer()
|
||||
step = 1
|
||||
for k in losses.keys():
|
||||
if isinstance(losses[k], list):
|
||||
step = max(step, len(losses[k]))
|
||||
for v in losses.values():
|
||||
if isinstance(v, list):
|
||||
step = max(step, len(v))
|
||||
global_step += step * collect_per_step
|
||||
for k in result.keys():
|
||||
data[k] = f"{result[k]:.2f}"
|
||||
|
||||
@ -22,7 +22,7 @@ def test_episode(
|
||||
policy.eval()
|
||||
if test_fn:
|
||||
test_fn(epoch)
|
||||
if collector.get_env_num() > 1 and np.isscalar(n_episode):
|
||||
if collector.get_env_num() > 1 and isinstance(n_episode, int):
|
||||
n = collector.get_env_num()
|
||||
n_ = np.zeros(n) + n_episode // n
|
||||
n_[:n_episode % n] += 1
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from numbers import Number
|
||||
from typing import Union
|
||||
from typing import List, Union
|
||||
|
||||
from tianshou.data import to_numpy
|
||||
|
||||
@ -28,7 +28,7 @@ class MovAvg(object):
|
||||
def __init__(self, size: int = 100) -> None:
|
||||
super().__init__()
|
||||
self.size = size
|
||||
self.cache = []
|
||||
self.cache: List[Union[Number, np.number]] = []
|
||||
self.banned = [np.inf, np.nan, -np.inf]
|
||||
|
||||
def add(
|
||||
|
||||
@ -12,7 +12,7 @@ def miniblock(
|
||||
norm_layer: Optional[Callable[[int], nn.modules.Module]],
|
||||
) -> List[nn.modules.Module]:
|
||||
"""Construct a miniblock with given input/output-size and norm layer."""
|
||||
ret = [nn.Linear(inp, oup)]
|
||||
ret: List[nn.modules.Module] = [nn.Linear(inp, oup)]
|
||||
if norm_layer is not None:
|
||||
ret += [norm_layer(oup)]
|
||||
ret += [nn.ReLU(inplace=True)]
|
||||
@ -54,36 +54,33 @@ class Net(nn.Module):
|
||||
if concat:
|
||||
input_size += np.prod(action_shape)
|
||||
|
||||
self.model = miniblock(input_size, hidden_layer_size, norm_layer)
|
||||
model = miniblock(input_size, hidden_layer_size, norm_layer)
|
||||
|
||||
for i in range(layer_num):
|
||||
self.model += miniblock(hidden_layer_size,
|
||||
hidden_layer_size, norm_layer)
|
||||
model += miniblock(
|
||||
hidden_layer_size, hidden_layer_size, norm_layer)
|
||||
|
||||
if self.dueling is None:
|
||||
if dueling is None:
|
||||
if action_shape and not concat:
|
||||
self.model += [nn.Linear(hidden_layer_size,
|
||||
np.prod(action_shape))]
|
||||
model += [nn.Linear(hidden_layer_size, np.prod(action_shape))]
|
||||
else: # dueling DQN
|
||||
assert isinstance(self.dueling, tuple) and len(self.dueling) == 2
|
||||
|
||||
q_layer_num, v_layer_num = self.dueling
|
||||
self.Q, self.V = [], []
|
||||
q_layer_num, v_layer_num = dueling
|
||||
Q, V = [], []
|
||||
|
||||
for i in range(q_layer_num):
|
||||
self.Q += miniblock(hidden_layer_size,
|
||||
hidden_layer_size, norm_layer)
|
||||
Q += miniblock(
|
||||
hidden_layer_size, hidden_layer_size, norm_layer)
|
||||
for i in range(v_layer_num):
|
||||
self.V += miniblock(hidden_layer_size,
|
||||
hidden_layer_size, norm_layer)
|
||||
V += miniblock(
|
||||
hidden_layer_size, hidden_layer_size, norm_layer)
|
||||
|
||||
if action_shape and not concat:
|
||||
self.Q += [nn.Linear(hidden_layer_size, np.prod(action_shape))]
|
||||
self.V += [nn.Linear(hidden_layer_size, 1)]
|
||||
Q += [nn.Linear(hidden_layer_size, np.prod(action_shape))]
|
||||
V += [nn.Linear(hidden_layer_size, 1)]
|
||||
|
||||
self.Q = nn.Sequential(*self.Q)
|
||||
self.V = nn.Sequential(*self.V)
|
||||
self.model = nn.Sequential(*self.model)
|
||||
self.Q = nn.Sequential(*Q)
|
||||
self.V = nn.Sequential(*V)
|
||||
self.model = nn.Sequential(*model)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -66,7 +66,7 @@ class Critic(nn.Module):
|
||||
s = to_torch(s, device=self.device, dtype=torch.float32)
|
||||
s = s.flatten(1)
|
||||
if a is not None:
|
||||
a = to_torch(a, device=self.device, dtype=torch.float32)
|
||||
a = to_torch_as(a, s)
|
||||
a = a.flatten(1)
|
||||
s = torch.cat([s, a], dim=1)
|
||||
logits, h = self.preprocess(s)
|
||||
|
||||
@ -4,6 +4,8 @@ from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Any, Dict, Tuple, Union, Optional, Sequence
|
||||
|
||||
from tianshou.data import to_torch
|
||||
|
||||
|
||||
class Actor(nn.Module):
|
||||
"""Simple actor network with MLP.
|
||||
@ -118,5 +120,5 @@ class DQN(nn.Module):
|
||||
) -> Tuple[torch.Tensor, Any]:
|
||||
r"""Mapping: x -> Q(x, \*)."""
|
||||
if not isinstance(x, torch.Tensor):
|
||||
x = torch.tensor(x, device=self.device, dtype=torch.float32)
|
||||
x = to_torch(x, device=self.device, dtype=torch.float32)
|
||||
return self.net(x), state
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user