type check in unit test (#200)
Fix #195: Add mypy test in .github/workflows/docs_and_lint.yml. Also remove the out-of-the-date api
This commit is contained in:
parent
c91def6cbc
commit
b284ace102
5
.github/workflows/lint_and_docs.yml
vendored
5
.github/workflows/lint_and_docs.yml
vendored
@ -1,4 +1,4 @@
|
|||||||
name: PEP8 and Docs Check
|
name: PEP8, Types and Docs Check
|
||||||
|
|
||||||
on: [push, pull_request]
|
on: [push, pull_request]
|
||||||
|
|
||||||
@ -20,6 +20,9 @@ jobs:
|
|||||||
- name: Lint with flake8
|
- name: Lint with flake8
|
||||||
run: |
|
run: |
|
||||||
flake8 . --count --show-source --statistics
|
flake8 . --count --show-source --statistics
|
||||||
|
- name: Type check
|
||||||
|
run: |
|
||||||
|
mypy
|
||||||
- name: Documentation test
|
- name: Documentation test
|
||||||
run: |
|
run: |
|
||||||
pydocstyle tianshou
|
pydocstyle tianshou
|
||||||
|
|||||||
@ -28,6 +28,16 @@ We follow PEP8 python code style. To check, in the main directory, run:
|
|||||||
$ flake8 . --count --show-source --statistics
|
$ flake8 . --count --show-source --statistics
|
||||||
|
|
||||||
|
|
||||||
|
Type Check
|
||||||
|
----------
|
||||||
|
|
||||||
|
We use `mypy <https://github.com/python/mypy/>`_ to check the type annotations. To check, in the main directory, run:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
$ mypy
|
||||||
|
|
||||||
|
|
||||||
Test Locally
|
Test Locally
|
||||||
------------
|
------------
|
||||||
|
|
||||||
|
|||||||
@ -31,6 +31,7 @@ Here is Tianshou's other features:
|
|||||||
* Support :ref:`customize_training`
|
* Support :ref:`customize_training`
|
||||||
* Support n-step returns estimation :meth:`~tianshou.policy.BasePolicy.compute_nstep_return` and prioritized experience replay :class:`~tianshou.data.PrioritizedReplayBuffer` for all Q-learning based algorithms; GAE, nstep and PER are very fast thanks to numba jit function and vectorized numpy operation
|
* Support n-step returns estimation :meth:`~tianshou.policy.BasePolicy.compute_nstep_return` and prioritized experience replay :class:`~tianshou.data.PrioritizedReplayBuffer` for all Q-learning based algorithms; GAE, nstep and PER are very fast thanks to numba jit function and vectorized numpy operation
|
||||||
* Support :doc:`/tutorials/tictactoe`
|
* Support :doc:`/tutorials/tictactoe`
|
||||||
|
* Comprehensive `unit tests <https://github.com/thu-ml/tianshou/actions>`_, including functional checking, RL pipeline checking, documentation checking, PEP8 code-style checking, and type checking
|
||||||
|
|
||||||
中文文档位于 `https://tianshou.readthedocs.io/zh/latest/ <https://tianshou.readthedocs.io/zh/latest/>`_
|
中文文档位于 `https://tianshou.readthedocs.io/zh/latest/ <https://tianshou.readthedocs.io/zh/latest/>`_
|
||||||
|
|
||||||
|
|||||||
23
setup.cfg
23
setup.cfg
@ -1,3 +1,26 @@
|
|||||||
|
[mypy]
|
||||||
|
files = tianshou/**/*.py
|
||||||
|
allow_redefinition = True
|
||||||
|
check_untyped_defs = True
|
||||||
|
disallow_incomplete_defs = True
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
ignore_missing_imports = True
|
||||||
|
no_implicit_optional = True
|
||||||
|
pretty = True
|
||||||
|
show_error_codes = True
|
||||||
|
show_error_context = True
|
||||||
|
show_traceback = True
|
||||||
|
strict_equality = True
|
||||||
|
strict_optional = True
|
||||||
|
warn_no_return = True
|
||||||
|
warn_redundant_casts = True
|
||||||
|
warn_unreachable = True
|
||||||
|
warn_unused_configs = True
|
||||||
|
warn_unused_ignores = True
|
||||||
|
|
||||||
|
[mypy-tianshou.utils.net.*]
|
||||||
|
ignore_errors = True
|
||||||
|
|
||||||
[pydocstyle]
|
[pydocstyle]
|
||||||
ignore = D100,D102,D104,D105,D107,D203,D213,D401,D402
|
ignore = D100,D102,D104,D105,D107,D203,D213,D401,D402
|
||||||
|
|
||||||
|
|||||||
@ -100,11 +100,6 @@ def test_collect_ep(data):
|
|||||||
data["collector"].collect(n_episode=10)
|
data["collector"].collect(n_episode=10)
|
||||||
|
|
||||||
|
|
||||||
def test_sample(data):
|
|
||||||
for _ in range(5000):
|
|
||||||
data["collector"].sample(256)
|
|
||||||
|
|
||||||
|
|
||||||
def test_init_vec_env(data):
|
def test_init_vec_env(data):
|
||||||
for _ in range(5000):
|
for _ in range(5000):
|
||||||
Collector(data["policy"], data["env_vec"], data["buffer"])
|
Collector(data["policy"], data["env_vec"], data["buffer"])
|
||||||
@ -125,11 +120,6 @@ def test_collect_vec_env_ep(data):
|
|||||||
data["collector_vec"].collect(n_episode=10)
|
data["collector_vec"].collect(n_episode=10)
|
||||||
|
|
||||||
|
|
||||||
def test_sample_vec_env(data):
|
|
||||||
for _ in range(5000):
|
|
||||||
data["collector_vec"].sample(256)
|
|
||||||
|
|
||||||
|
|
||||||
def test_init_subproc_env(data):
|
def test_init_subproc_env(data):
|
||||||
for _ in range(5000):
|
for _ in range(5000):
|
||||||
Collector(data["policy"], data["env_subproc_init"], data["buffer"])
|
Collector(data["policy"], data["env_subproc_init"], data["buffer"])
|
||||||
@ -150,10 +140,5 @@ def test_collect_subproc_env_ep(data):
|
|||||||
data["collector_subproc"].collect(n_episode=10)
|
data["collector_subproc"].collect(n_episode=10)
|
||||||
|
|
||||||
|
|
||||||
def test_sample_subproc_env(data):
|
|
||||||
for _ in range(5000):
|
|
||||||
data["collector_subproc"].sample(256)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
pytest.main(["-s", "-k collector_profile", "--durations=0", "-v"])
|
pytest.main(["-s", "-k collector_profile", "--durations=0", "-v"])
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
from tianshou import data, env, utils, policy, trainer, exploration
|
from tianshou import data, env, utils, policy, trainer, exploration
|
||||||
|
|
||||||
|
|
||||||
__version__ = "0.2.7"
|
__version__ = "0.3.0rc0"
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"env",
|
"env",
|
||||||
|
|||||||
@ -6,7 +6,7 @@ from copy import deepcopy
|
|||||||
from numbers import Number
|
from numbers import Number
|
||||||
from collections.abc import Collection
|
from collections.abc import Collection
|
||||||
from typing import Any, List, Dict, Union, Iterator, Optional, Iterable, \
|
from typing import Any, List, Dict, Union, Iterator, Optional, Iterable, \
|
||||||
Sequence, KeysView, ValuesView, ItemsView
|
Sequence
|
||||||
|
|
||||||
# Disable pickle warning related to torch, since it has been removed
|
# Disable pickle warning related to torch, since it has been removed
|
||||||
# on torch master branch. See Pull Request #39003 for details:
|
# on torch master branch. See Pull Request #39003 for details:
|
||||||
@ -144,7 +144,7 @@ def _parse_value(v: Any) -> Optional[Union["Batch", np.ndarray, torch.Tensor]]:
|
|||||||
if not isinstance(v, np.ndarray) and isinstance(v, Collection) and \
|
if not isinstance(v, np.ndarray) and isinstance(v, Collection) and \
|
||||||
len(v) > 0 and all(isinstance(e, torch.Tensor) for e in v):
|
len(v) > 0 and all(isinstance(e, torch.Tensor) for e in v):
|
||||||
try:
|
try:
|
||||||
return torch.stack(v)
|
return torch.stack(v) # type: ignore
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
raise TypeError("Batch does not support non-stackable iterable"
|
raise TypeError("Batch does not support non-stackable iterable"
|
||||||
" of torch.Tensor as unique value yet.") from e
|
" of torch.Tensor as unique value yet.") from e
|
||||||
@ -191,12 +191,20 @@ class Batch:
|
|||||||
elif _is_batch_set(batch_dict):
|
elif _is_batch_set(batch_dict):
|
||||||
self.stack_(batch_dict)
|
self.stack_(batch_dict)
|
||||||
if len(kwargs) > 0:
|
if len(kwargs) > 0:
|
||||||
self.__init__(kwargs, copy=copy)
|
self.__init__(kwargs, copy=copy) # type: ignore
|
||||||
|
|
||||||
def __setattr__(self, key: str, value: Any) -> None:
|
def __setattr__(self, key: str, value: Any) -> None:
|
||||||
"""Set self.key = value."""
|
"""Set self.key = value."""
|
||||||
self.__dict__[key] = _parse_value(value)
|
self.__dict__[key] = _parse_value(value)
|
||||||
|
|
||||||
|
def __getattr__(self, key: str) -> Any:
|
||||||
|
"""Return self.key. The "Any" return type is needed for mypy."""
|
||||||
|
return getattr(self.__dict__, key)
|
||||||
|
|
||||||
|
def __contains__(self, key: str) -> bool:
|
||||||
|
"""Return key in self."""
|
||||||
|
return key in self.__dict__
|
||||||
|
|
||||||
def __getstate__(self) -> Dict[str, Any]:
|
def __getstate__(self) -> Dict[str, Any]:
|
||||||
"""Pickling interface.
|
"""Pickling interface.
|
||||||
|
|
||||||
@ -215,11 +223,11 @@ class Batch:
|
|||||||
At this point, self is an empty Batch instance that has not been
|
At this point, self is an empty Batch instance that has not been
|
||||||
initialized, so it can safely be initialized by the pickle state.
|
initialized, so it can safely be initialized by the pickle state.
|
||||||
"""
|
"""
|
||||||
self.__init__(**state)
|
self.__init__(**state) # type: ignore
|
||||||
|
|
||||||
def __getitem__(
|
def __getitem__(
|
||||||
self, index: Union[str, slice, int, np.integer, np.ndarray, List[int]]
|
self, index: Union[str, slice, int, np.integer, np.ndarray, List[int]]
|
||||||
) -> Union["Batch", np.ndarray, torch.Tensor]:
|
) -> Any:
|
||||||
"""Return self[index]."""
|
"""Return self[index]."""
|
||||||
if isinstance(index, str):
|
if isinstance(index, str):
|
||||||
return self.__dict__[index]
|
return self.__dict__[index]
|
||||||
@ -245,7 +253,7 @@ class Batch:
|
|||||||
if isinstance(index, str):
|
if isinstance(index, str):
|
||||||
self.__dict__[index] = value
|
self.__dict__[index] = value
|
||||||
return
|
return
|
||||||
if isinstance(value, (np.ndarray, torch.Tensor)):
|
if not isinstance(value, Batch):
|
||||||
raise ValueError("Batch does not supported tensor assignment. "
|
raise ValueError("Batch does not supported tensor assignment. "
|
||||||
"Use a compatible Batch or dict instead.")
|
"Use a compatible Batch or dict instead.")
|
||||||
if not set(value.keys()).issubset(self.__dict__.keys()):
|
if not set(value.keys()).issubset(self.__dict__.keys()):
|
||||||
@ -330,30 +338,6 @@ class Batch:
|
|||||||
s = self.__class__.__name__ + "()"
|
s = self.__class__.__name__ + "()"
|
||||||
return s
|
return s
|
||||||
|
|
||||||
def __contains__(self, key: str) -> bool:
|
|
||||||
"""Return key in self."""
|
|
||||||
return key in self.__dict__
|
|
||||||
|
|
||||||
def keys(self) -> KeysView[str]:
|
|
||||||
"""Return self.keys()."""
|
|
||||||
return self.__dict__.keys()
|
|
||||||
|
|
||||||
def values(self) -> ValuesView[Any]:
|
|
||||||
"""Return self.values()."""
|
|
||||||
return self.__dict__.values()
|
|
||||||
|
|
||||||
def items(self) -> ItemsView[str, Any]:
|
|
||||||
"""Return self.items()."""
|
|
||||||
return self.__dict__.items()
|
|
||||||
|
|
||||||
def get(self, k: str, d: Optional[Any] = None) -> Any:
|
|
||||||
"""Return self[k] if k in self else d. d defaults to None."""
|
|
||||||
return self.__dict__.get(k, d)
|
|
||||||
|
|
||||||
def pop(self, k: str, d: Optional[Any] = None) -> Any:
|
|
||||||
"""Return & remove self[k] if k in self else d. d defaults to None."""
|
|
||||||
return self.__dict__.pop(k, d)
|
|
||||||
|
|
||||||
def to_numpy(self) -> None:
|
def to_numpy(self) -> None:
|
||||||
"""Change all torch.Tensor to numpy.ndarray in-place."""
|
"""Change all torch.Tensor to numpy.ndarray in-place."""
|
||||||
for k, v in self.items():
|
for k, v in self.items():
|
||||||
@ -375,7 +359,6 @@ class Batch:
|
|||||||
if isinstance(v, torch.Tensor):
|
if isinstance(v, torch.Tensor):
|
||||||
if dtype is not None and v.dtype != dtype or \
|
if dtype is not None and v.dtype != dtype or \
|
||||||
v.device.type != device.type or \
|
v.device.type != device.type or \
|
||||||
device.index is not None and \
|
|
||||||
device.index != v.device.index:
|
device.index != v.device.index:
|
||||||
if dtype is not None:
|
if dtype is not None:
|
||||||
v = v.type(dtype)
|
v = v.type(dtype)
|
||||||
@ -517,7 +500,7 @@ class Batch:
|
|||||||
return
|
return
|
||||||
batches = [x if isinstance(x, Batch) else Batch(x) for x in batches]
|
batches = [x if isinstance(x, Batch) else Batch(x) for x in batches]
|
||||||
if not self.is_empty():
|
if not self.is_empty():
|
||||||
batches = [self] + list(batches)
|
batches = [self] + batches
|
||||||
# collect non-empty keys
|
# collect non-empty keys
|
||||||
keys_map = [
|
keys_map = [
|
||||||
set(k for k, v in batch.items()
|
set(k for k, v in batch.items()
|
||||||
@ -672,8 +655,8 @@ class Batch:
|
|||||||
for v in self.__dict__.values():
|
for v in self.__dict__.values():
|
||||||
if isinstance(v, Batch) and v.is_empty(recurse=True):
|
if isinstance(v, Batch) and v.is_empty(recurse=True):
|
||||||
continue
|
continue
|
||||||
elif hasattr(v, "__len__") and (not isinstance(
|
elif hasattr(v, "__len__") and (
|
||||||
v, (np.ndarray, torch.Tensor)) or v.ndim > 0
|
isinstance(v, Batch) or v.ndim > 0
|
||||||
):
|
):
|
||||||
r.append(len(v))
|
r.append(len(v))
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from numbers import Number
|
from numbers import Number
|
||||||
from typing import Any, Dict, Tuple, Union, Optional
|
from typing import Any, Dict, List, Tuple, Union, Optional
|
||||||
|
|
||||||
from tianshou.data import Batch, SegmentTree, to_numpy
|
from tianshou.data import Batch, SegmentTree, to_numpy
|
||||||
from tianshou.data.batch import _create_value
|
from tianshou.data.batch import _create_value
|
||||||
@ -138,7 +138,7 @@ class ReplayBuffer:
|
|||||||
self._indices = np.arange(size)
|
self._indices = np.arange(size)
|
||||||
self.stack_num = stack_num
|
self.stack_num = stack_num
|
||||||
self._avail = sample_avail and stack_num > 1
|
self._avail = sample_avail and stack_num > 1
|
||||||
self._avail_index = []
|
self._avail_index: List[int] = []
|
||||||
self._save_s_ = not ignore_obs_next
|
self._save_s_ = not ignore_obs_next
|
||||||
self._last_obs = save_only_last_obs
|
self._last_obs = save_only_last_obs
|
||||||
self._index = 0
|
self._index = 0
|
||||||
@ -175,12 +175,12 @@ class ReplayBuffer:
|
|||||||
except KeyError:
|
except KeyError:
|
||||||
self._meta.__dict__[name] = _create_value(inst, self._maxsize)
|
self._meta.__dict__[name] = _create_value(inst, self._maxsize)
|
||||||
value = self._meta.__dict__[name]
|
value = self._meta.__dict__[name]
|
||||||
if isinstance(inst, (torch.Tensor, np.ndarray)) \
|
if isinstance(inst, (torch.Tensor, np.ndarray)):
|
||||||
and inst.shape != value.shape[1:]:
|
if inst.shape != value.shape[1:]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Cannot add data to a buffer with different shape, with key "
|
"Cannot add data to a buffer with different shape with key"
|
||||||
f"{name}, expect {value.shape[1:]}, given {inst.shape}."
|
f" {name}, expect {value.shape[1:]}, given {inst.shape}."
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
value[self._index] = inst
|
value[self._index] = inst
|
||||||
except KeyError:
|
except KeyError:
|
||||||
@ -205,7 +205,7 @@ class ReplayBuffer:
|
|||||||
stack_num_orig = buffer.stack_num
|
stack_num_orig = buffer.stack_num
|
||||||
buffer.stack_num = 1
|
buffer.stack_num = 1
|
||||||
while True:
|
while True:
|
||||||
self.add(**buffer[i])
|
self.add(**buffer[i]) # type: ignore
|
||||||
i = (i + 1) % len(buffer)
|
i = (i + 1) % len(buffer)
|
||||||
if i == begin:
|
if i == begin:
|
||||||
break
|
break
|
||||||
@ -323,7 +323,7 @@ class ReplayBuffer:
|
|||||||
try:
|
try:
|
||||||
if stack_num == 1:
|
if stack_num == 1:
|
||||||
return val[indice]
|
return val[indice]
|
||||||
stack = []
|
stack: List[Any] = []
|
||||||
for _ in range(stack_num):
|
for _ in range(stack_num):
|
||||||
stack = [val[indice]] + stack
|
stack = [val[indice]] + stack
|
||||||
pre_indice = np.asarray(indice - 1)
|
pre_indice = np.asarray(indice - 1)
|
||||||
|
|||||||
@ -212,10 +212,8 @@ class Collector(object):
|
|||||||
finished_env_ids = []
|
finished_env_ids = []
|
||||||
reward_total = 0.0
|
reward_total = 0.0
|
||||||
whole_data = Batch()
|
whole_data = Batch()
|
||||||
list_n_episode = False
|
if isinstance(n_episode, list):
|
||||||
if n_episode is not None and not np.isscalar(n_episode):
|
|
||||||
assert len(n_episode) == self.get_env_num()
|
assert len(n_episode) == self.get_env_num()
|
||||||
list_n_episode = True
|
|
||||||
finished_env_ids = [
|
finished_env_ids = [
|
||||||
i for i in self._ready_env_ids if n_episode[i] <= 0]
|
i for i in self._ready_env_ids if n_episode[i] <= 0]
|
||||||
self._ready_env_ids = np.array(
|
self._ready_env_ids = np.array(
|
||||||
@ -266,7 +264,8 @@ class Collector(object):
|
|||||||
self.data.policy._state = self.data.state
|
self.data.policy._state = self.data.state
|
||||||
|
|
||||||
self.data.act = to_numpy(result.act)
|
self.data.act = to_numpy(result.act)
|
||||||
if self._action_noise is not None: # noqa
|
if self._action_noise is not None:
|
||||||
|
assert isinstance(self.data.act, np.ndarray)
|
||||||
self.data.act += self._action_noise(self.data.act.shape)
|
self.data.act += self._action_noise(self.data.act.shape)
|
||||||
|
|
||||||
# step in env
|
# step in env
|
||||||
@ -291,7 +290,7 @@ class Collector(object):
|
|||||||
|
|
||||||
# add data into the buffer
|
# add data into the buffer
|
||||||
if self.preprocess_fn:
|
if self.preprocess_fn:
|
||||||
result = self.preprocess_fn(**self.data)
|
result = self.preprocess_fn(**self.data) # type: ignore
|
||||||
self.data.update(result)
|
self.data.update(result)
|
||||||
|
|
||||||
for j, i in enumerate(self._ready_env_ids):
|
for j, i in enumerate(self._ready_env_ids):
|
||||||
@ -305,14 +304,14 @@ class Collector(object):
|
|||||||
self._cached_buf[i].add(**self.data[j])
|
self._cached_buf[i].add(**self.data[j])
|
||||||
|
|
||||||
if done[j]:
|
if done[j]:
|
||||||
if not (list_n_episode and
|
if not (isinstance(n_episode, list)
|
||||||
episode_count[i] >= n_episode[i]):
|
and episode_count[i] >= n_episode[i]):
|
||||||
episode_count[i] += 1
|
episode_count[i] += 1
|
||||||
reward_total += np.sum(self._cached_buf[i].rew, axis=0)
|
reward_total += np.sum(self._cached_buf[i].rew, axis=0)
|
||||||
step_count += len(self._cached_buf[i])
|
step_count += len(self._cached_buf[i])
|
||||||
if self.buffer is not None:
|
if self.buffer is not None:
|
||||||
self.buffer.update(self._cached_buf[i])
|
self.buffer.update(self._cached_buf[i])
|
||||||
if list_n_episode and \
|
if isinstance(n_episode, list) and \
|
||||||
episode_count[i] >= n_episode[i]:
|
episode_count[i] >= n_episode[i]:
|
||||||
# env i has collected enough data, it has finished
|
# env i has collected enough data, it has finished
|
||||||
finished_env_ids.append(i)
|
finished_env_ids.append(i)
|
||||||
@ -324,10 +323,9 @@ class Collector(object):
|
|||||||
env_ind_global = self._ready_env_ids[env_ind_local]
|
env_ind_global = self._ready_env_ids[env_ind_local]
|
||||||
obs_reset = self.env.reset(env_ind_global)
|
obs_reset = self.env.reset(env_ind_global)
|
||||||
if self.preprocess_fn:
|
if self.preprocess_fn:
|
||||||
obs_next[env_ind_local] = self.preprocess_fn(
|
obs_reset = self.preprocess_fn(
|
||||||
obs=obs_reset).get("obs", obs_reset)
|
obs=obs_reset).get("obs", obs_reset)
|
||||||
else:
|
obs_next[env_ind_local] = obs_reset
|
||||||
obs_next[env_ind_local] = obs_reset
|
|
||||||
self.data.obs = obs_next
|
self.data.obs = obs_next
|
||||||
if is_async:
|
if is_async:
|
||||||
# set data back
|
# set data back
|
||||||
@ -362,7 +360,7 @@ class Collector(object):
|
|||||||
# average reward across the number of episodes
|
# average reward across the number of episodes
|
||||||
reward_avg = reward_total / episode_count
|
reward_avg = reward_total / episode_count
|
||||||
if np.asanyarray(reward_avg).size > 1: # non-scalar reward_avg
|
if np.asanyarray(reward_avg).size > 1: # non-scalar reward_avg
|
||||||
reward_avg = self._rew_metric(reward_avg)
|
reward_avg = self._rew_metric(reward_avg) # type: ignore
|
||||||
return {
|
return {
|
||||||
"n/ep": episode_count,
|
"n/ep": episode_count,
|
||||||
"n/st": step_count,
|
"n/st": step_count,
|
||||||
@ -372,30 +370,6 @@ class Collector(object):
|
|||||||
"len": step_count / episode_count,
|
"len": step_count / episode_count,
|
||||||
}
|
}
|
||||||
|
|
||||||
def sample(self, batch_size: int) -> Batch:
|
|
||||||
"""Sample a data batch from the internal replay buffer.
|
|
||||||
|
|
||||||
It will call :meth:`~tianshou.policy.BasePolicy.process_fn` before
|
|
||||||
returning the final batch data.
|
|
||||||
|
|
||||||
:param int batch_size: ``0`` means it will extract all the data from
|
|
||||||
the buffer, otherwise it will extract the data with the given
|
|
||||||
batch_size.
|
|
||||||
"""
|
|
||||||
warnings.warn(
|
|
||||||
"Collector.sample is deprecated and will cause error if you use "
|
|
||||||
"prioritized experience replay! Collector.sample will be removed "
|
|
||||||
"upon version 0.3. Use policy.update instead!", Warning)
|
|
||||||
assert self.buffer is not None, "Cannot get sample from empty buffer!"
|
|
||||||
batch_data, indice = self.buffer.sample(batch_size)
|
|
||||||
batch_data = self.process_fn(batch_data, self.buffer, indice)
|
|
||||||
return batch_data
|
|
||||||
|
|
||||||
def close(self) -> None:
|
|
||||||
warnings.warn(
|
|
||||||
"Collector.close is deprecated and will be removed upon version "
|
|
||||||
"0.3.", Warning)
|
|
||||||
|
|
||||||
|
|
||||||
def _batch_set_item(
|
def _batch_set_item(
|
||||||
source: Batch, indices: np.ndarray, target: Batch, size: int
|
source: Batch, indices: np.ndarray, target: Batch, size: int
|
||||||
|
|||||||
@ -45,14 +45,14 @@ def to_torch(
|
|||||||
if isinstance(x, np.ndarray) and issubclass(
|
if isinstance(x, np.ndarray) and issubclass(
|
||||||
x.dtype.type, (np.bool_, np.number)
|
x.dtype.type, (np.bool_, np.number)
|
||||||
): # most often case
|
): # most often case
|
||||||
x = torch.from_numpy(x).to(device)
|
x = torch.from_numpy(x).to(device) # type: ignore
|
||||||
if dtype is not None:
|
if dtype is not None:
|
||||||
x = x.type(dtype)
|
x = x.type(dtype)
|
||||||
return x
|
return x
|
||||||
elif isinstance(x, torch.Tensor): # second often case
|
elif isinstance(x, torch.Tensor): # second often case
|
||||||
if dtype is not None:
|
if dtype is not None:
|
||||||
x = x.type(dtype)
|
x = x.type(dtype)
|
||||||
return x.to(device)
|
return x.to(device) # type: ignore
|
||||||
elif isinstance(x, (np.number, np.bool_, Number)):
|
elif isinstance(x, (np.number, np.bool_, Number)):
|
||||||
return to_torch(np.asanyarray(x), dtype, device)
|
return to_torch(np.asanyarray(x), dtype, device)
|
||||||
elif isinstance(x, dict):
|
elif isinstance(x, dict):
|
||||||
|
|||||||
3
tianshou/env/__init__.py
vendored
3
tianshou/env/__init__.py
vendored
@ -1,11 +1,10 @@
|
|||||||
from tianshou.env.venvs import BaseVectorEnv, DummyVectorEnv, VectorEnv, \
|
from tianshou.env.venvs import BaseVectorEnv, DummyVectorEnv, \
|
||||||
SubprocVectorEnv, ShmemVectorEnv, RayVectorEnv
|
SubprocVectorEnv, ShmemVectorEnv, RayVectorEnv
|
||||||
from tianshou.env.maenv import MultiAgentEnv
|
from tianshou.env.maenv import MultiAgentEnv
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BaseVectorEnv",
|
"BaseVectorEnv",
|
||||||
"DummyVectorEnv",
|
"DummyVectorEnv",
|
||||||
"VectorEnv", # TODO: remove in later version
|
|
||||||
"SubprocVectorEnv",
|
"SubprocVectorEnv",
|
||||||
"ShmemVectorEnv",
|
"ShmemVectorEnv",
|
||||||
"RayVectorEnv",
|
"RayVectorEnv",
|
||||||
|
|||||||
29
tianshou/env/venvs.py
vendored
29
tianshou/env/venvs.py
vendored
@ -1,5 +1,4 @@
|
|||||||
import gym
|
import gym
|
||||||
import warnings
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from typing import Any, List, Union, Optional, Callable
|
from typing import Any, List, Union, Optional, Callable
|
||||||
|
|
||||||
@ -84,12 +83,12 @@ class BaseVectorEnv(gym.Env):
|
|||||||
self.timeout is None or self.timeout > 0
|
self.timeout is None or self.timeout > 0
|
||||||
), f"timeout is {timeout}, it should be positive if provided!"
|
), f"timeout is {timeout}, it should be positive if provided!"
|
||||||
self.is_async = self.wait_num != len(env_fns) or timeout is not None
|
self.is_async = self.wait_num != len(env_fns) or timeout is not None
|
||||||
self.waiting_conn = []
|
self.waiting_conn: List[EnvWorker] = []
|
||||||
# environments in self.ready_id is actually ready
|
# environments in self.ready_id is actually ready
|
||||||
# but environments in self.waiting_id are just waiting when checked,
|
# but environments in self.waiting_id are just waiting when checked,
|
||||||
# and they may be ready now, but this is not known until we check it
|
# and they may be ready now, but this is not known until we check it
|
||||||
# in the step() function
|
# in the step() function
|
||||||
self.waiting_id = []
|
self.waiting_id: List[int] = []
|
||||||
# all environments are ready in the beginning
|
# all environments are ready in the beginning
|
||||||
self.ready_id = list(range(self.env_num))
|
self.ready_id = list(range(self.env_num))
|
||||||
self.is_closed = False
|
self.is_closed = False
|
||||||
@ -216,10 +215,11 @@ class BaseVectorEnv(gym.Env):
|
|||||||
self.waiting_conn.append(self.workers[env_id])
|
self.waiting_conn.append(self.workers[env_id])
|
||||||
self.waiting_id.append(env_id)
|
self.waiting_id.append(env_id)
|
||||||
self.ready_id = [x for x in self.ready_id if x not in id]
|
self.ready_id = [x for x in self.ready_id if x not in id]
|
||||||
ready_conns, result = [], []
|
ready_conns: List[EnvWorker] = []
|
||||||
while not ready_conns:
|
while not ready_conns:
|
||||||
ready_conns = self.worker_class.wait(
|
ready_conns = self.worker_class.wait(
|
||||||
self.waiting_conn, self.wait_num, self.timeout)
|
self.waiting_conn, self.wait_num, self.timeout)
|
||||||
|
result = []
|
||||||
for conn in ready_conns:
|
for conn in ready_conns:
|
||||||
waiting_index = self.waiting_conn.index(conn)
|
waiting_index = self.waiting_conn.index(conn)
|
||||||
self.waiting_conn.pop(waiting_index)
|
self.waiting_conn.pop(waiting_index)
|
||||||
@ -243,11 +243,14 @@ class BaseVectorEnv(gym.Env):
|
|||||||
which a reproducer pass to "seed".
|
which a reproducer pass to "seed".
|
||||||
"""
|
"""
|
||||||
self._assert_is_not_closed()
|
self._assert_is_not_closed()
|
||||||
|
seed_list: Union[List[None], List[int]]
|
||||||
if seed is None:
|
if seed is None:
|
||||||
seed = [seed] * self.env_num
|
seed_list = [seed] * self.env_num
|
||||||
elif np.isscalar(seed):
|
elif isinstance(seed, int):
|
||||||
seed = [seed + i for i in range(self.env_num)]
|
seed_list = [seed + i for i in range(self.env_num)]
|
||||||
return [w.seed(s) for w, s in zip(self.workers, seed)]
|
else:
|
||||||
|
seed_list = seed
|
||||||
|
return [w.seed(s) for w, s in zip(self.workers, seed_list)]
|
||||||
|
|
||||||
def render(self, **kwargs: Any) -> List[Any]:
|
def render(self, **kwargs: Any) -> List[Any]:
|
||||||
"""Render all of the environments."""
|
"""Render all of the environments."""
|
||||||
@ -295,16 +298,6 @@ class DummyVectorEnv(BaseVectorEnv):
|
|||||||
env_fns, DummyEnvWorker, wait_num=wait_num, timeout=timeout)
|
env_fns, DummyEnvWorker, wait_num=wait_num, timeout=timeout)
|
||||||
|
|
||||||
|
|
||||||
class VectorEnv(DummyVectorEnv):
|
|
||||||
"""VectorEnv is renamed to DummyVectorEnv."""
|
|
||||||
|
|
||||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
|
||||||
warnings.warn(
|
|
||||||
"VectorEnv is renamed to DummyVectorEnv, and will be removed in "
|
|
||||||
"0.3. Use DummyVectorEnv instead!", Warning)
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class SubprocVectorEnv(BaseVectorEnv):
|
class SubprocVectorEnv(BaseVectorEnv):
|
||||||
"""Vectorized environment wrapper based on subprocess.
|
"""Vectorized environment wrapper based on subprocess.
|
||||||
|
|
||||||
|
|||||||
2
tianshou/env/worker/base.py
vendored
2
tianshou/env/worker/base.py
vendored
@ -10,7 +10,7 @@ class EnvWorker(ABC):
|
|||||||
def __init__(self, env_fn: Callable[[], gym.Env]) -> None:
|
def __init__(self, env_fn: Callable[[], gym.Env]) -> None:
|
||||||
self._env_fn = env_fn
|
self._env_fn = env_fn
|
||||||
self.is_closed = False
|
self.is_closed = False
|
||||||
self.result = (None, None, None, None)
|
self.result: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __getattr__(self, key: str) -> Any:
|
def __getattr__(self, key: str) -> Any:
|
||||||
|
|||||||
2
tianshou/env/worker/dummy.py
vendored
2
tianshou/env/worker/dummy.py
vendored
@ -19,7 +19,7 @@ class DummyEnvWorker(EnvWorker):
|
|||||||
return self.env.reset()
|
return self.env.reset()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def wait(
|
def wait( # type: ignore
|
||||||
workers: List["DummyEnvWorker"],
|
workers: List["DummyEnvWorker"],
|
||||||
wait_num: int,
|
wait_num: int,
|
||||||
timeout: Optional[float] = None,
|
timeout: Optional[float] = None,
|
||||||
|
|||||||
2
tianshou/env/worker/ray.py
vendored
2
tianshou/env/worker/ray.py
vendored
@ -24,7 +24,7 @@ class RayEnvWorker(EnvWorker):
|
|||||||
return ray.get(self.env.reset.remote())
|
return ray.get(self.env.reset.remote())
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def wait(
|
def wait( # type: ignore
|
||||||
workers: List["RayEnvWorker"],
|
workers: List["RayEnvWorker"],
|
||||||
wait_num: int,
|
wait_num: int,
|
||||||
timeout: Optional[float] = None,
|
timeout: Optional[float] = None,
|
||||||
|
|||||||
117
tianshou/env/worker/subproc.py
vendored
117
tianshou/env/worker/subproc.py
vendored
@ -11,22 +11,71 @@ from tianshou.env.worker import EnvWorker
|
|||||||
from tianshou.env.utils import CloudpickleWrapper
|
from tianshou.env.utils import CloudpickleWrapper
|
||||||
|
|
||||||
|
|
||||||
|
_NP_TO_CT = {
|
||||||
|
np.bool: ctypes.c_bool,
|
||||||
|
np.bool_: ctypes.c_bool,
|
||||||
|
np.uint8: ctypes.c_uint8,
|
||||||
|
np.uint16: ctypes.c_uint16,
|
||||||
|
np.uint32: ctypes.c_uint32,
|
||||||
|
np.uint64: ctypes.c_uint64,
|
||||||
|
np.int8: ctypes.c_int8,
|
||||||
|
np.int16: ctypes.c_int16,
|
||||||
|
np.int32: ctypes.c_int32,
|
||||||
|
np.int64: ctypes.c_int64,
|
||||||
|
np.float32: ctypes.c_float,
|
||||||
|
np.float64: ctypes.c_double,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ShArray:
|
||||||
|
"""Wrapper of multiprocessing Array."""
|
||||||
|
|
||||||
|
def __init__(self, dtype: np.generic, shape: Tuple[int]) -> None:
|
||||||
|
self.arr = Array(
|
||||||
|
_NP_TO_CT[dtype.type], # type: ignore
|
||||||
|
int(np.prod(shape)),
|
||||||
|
)
|
||||||
|
self.dtype = dtype
|
||||||
|
self.shape = shape
|
||||||
|
|
||||||
|
def save(self, ndarray: np.ndarray) -> None:
|
||||||
|
assert isinstance(ndarray, np.ndarray)
|
||||||
|
dst = self.arr.get_obj()
|
||||||
|
dst_np = np.frombuffer(dst, dtype=self.dtype).reshape(self.shape)
|
||||||
|
np.copyto(dst_np, ndarray)
|
||||||
|
|
||||||
|
def get(self) -> np.ndarray:
|
||||||
|
obj = self.arr.get_obj()
|
||||||
|
return np.frombuffer(obj, dtype=self.dtype).reshape(self.shape)
|
||||||
|
|
||||||
|
|
||||||
|
def _setup_buf(space: gym.Space) -> Union[dict, tuple, ShArray]:
|
||||||
|
if isinstance(space, gym.spaces.Dict):
|
||||||
|
assert isinstance(space.spaces, OrderedDict)
|
||||||
|
return {k: _setup_buf(v) for k, v in space.spaces.items()}
|
||||||
|
elif isinstance(space, gym.spaces.Tuple):
|
||||||
|
assert isinstance(space.spaces, tuple)
|
||||||
|
return tuple([_setup_buf(t) for t in space.spaces])
|
||||||
|
else:
|
||||||
|
return ShArray(space.dtype, space.shape)
|
||||||
|
|
||||||
|
|
||||||
def _worker(
|
def _worker(
|
||||||
parent: connection.Connection,
|
parent: connection.Connection,
|
||||||
p: connection.Connection,
|
p: connection.Connection,
|
||||||
env_fn_wrapper: CloudpickleWrapper,
|
env_fn_wrapper: CloudpickleWrapper,
|
||||||
obs_bufs: Optional[Union[dict, tuple, "ShArray"]] = None,
|
obs_bufs: Optional[Union[dict, tuple, ShArray]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
def _encode_obs(
|
def _encode_obs(
|
||||||
obs: Union[dict, tuple, np.ndarray],
|
obs: Union[dict, tuple, np.ndarray],
|
||||||
buffer: Union[dict, tuple, ShArray],
|
buffer: Union[dict, tuple, ShArray],
|
||||||
) -> None:
|
) -> None:
|
||||||
if isinstance(obs, np.ndarray):
|
if isinstance(obs, np.ndarray) and isinstance(buffer, ShArray):
|
||||||
buffer.save(obs)
|
buffer.save(obs)
|
||||||
elif isinstance(obs, tuple):
|
elif isinstance(obs, tuple) and isinstance(buffer, tuple):
|
||||||
for o, b in zip(obs, buffer):
|
for o, b in zip(obs, buffer):
|
||||||
_encode_obs(o, b)
|
_encode_obs(o, b)
|
||||||
elif isinstance(obs, dict):
|
elif isinstance(obs, dict) and isinstance(buffer, dict):
|
||||||
for k in obs.keys():
|
for k in obs.keys():
|
||||||
_encode_obs(obs[k], buffer[k])
|
_encode_obs(obs[k], buffer[k])
|
||||||
return None
|
return None
|
||||||
@ -69,52 +118,6 @@ def _worker(
|
|||||||
p.close()
|
p.close()
|
||||||
|
|
||||||
|
|
||||||
_NP_TO_CT = {
|
|
||||||
np.bool: ctypes.c_bool,
|
|
||||||
np.bool_: ctypes.c_bool,
|
|
||||||
np.uint8: ctypes.c_uint8,
|
|
||||||
np.uint16: ctypes.c_uint16,
|
|
||||||
np.uint32: ctypes.c_uint32,
|
|
||||||
np.uint64: ctypes.c_uint64,
|
|
||||||
np.int8: ctypes.c_int8,
|
|
||||||
np.int16: ctypes.c_int16,
|
|
||||||
np.int32: ctypes.c_int32,
|
|
||||||
np.int64: ctypes.c_int64,
|
|
||||||
np.float32: ctypes.c_float,
|
|
||||||
np.float64: ctypes.c_double,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class ShArray:
|
|
||||||
"""Wrapper of multiprocessing Array."""
|
|
||||||
|
|
||||||
def __init__(self, dtype: np.generic, shape: Tuple[int]) -> None:
|
|
||||||
self.arr = Array(_NP_TO_CT[dtype.type], int(np.prod(shape)))
|
|
||||||
self.dtype = dtype
|
|
||||||
self.shape = shape
|
|
||||||
|
|
||||||
def save(self, ndarray: np.ndarray) -> None:
|
|
||||||
assert isinstance(ndarray, np.ndarray)
|
|
||||||
dst = self.arr.get_obj()
|
|
||||||
dst_np = np.frombuffer(dst, dtype=self.dtype).reshape(self.shape)
|
|
||||||
np.copyto(dst_np, ndarray)
|
|
||||||
|
|
||||||
def get(self) -> np.ndarray:
|
|
||||||
obj = self.arr.get_obj()
|
|
||||||
return np.frombuffer(obj, dtype=self.dtype).reshape(self.shape)
|
|
||||||
|
|
||||||
|
|
||||||
def _setup_buf(space: gym.Space) -> Union[dict, tuple, ShArray]:
|
|
||||||
if isinstance(space, gym.spaces.Dict):
|
|
||||||
assert isinstance(space.spaces, OrderedDict)
|
|
||||||
return {k: _setup_buf(v) for k, v in space.spaces.items()}
|
|
||||||
elif isinstance(space, gym.spaces.Tuple):
|
|
||||||
assert isinstance(space.spaces, tuple)
|
|
||||||
return tuple([_setup_buf(t) for t in space.spaces])
|
|
||||||
else:
|
|
||||||
return ShArray(space.dtype, space.shape)
|
|
||||||
|
|
||||||
|
|
||||||
class SubprocEnvWorker(EnvWorker):
|
class SubprocEnvWorker(EnvWorker):
|
||||||
"""Subprocess worker used in SubprocVectorEnv and ShmemVectorEnv."""
|
"""Subprocess worker used in SubprocVectorEnv and ShmemVectorEnv."""
|
||||||
|
|
||||||
@ -124,7 +127,7 @@ class SubprocEnvWorker(EnvWorker):
|
|||||||
super().__init__(env_fn)
|
super().__init__(env_fn)
|
||||||
self.parent_remote, self.child_remote = Pipe()
|
self.parent_remote, self.child_remote = Pipe()
|
||||||
self.share_memory = share_memory
|
self.share_memory = share_memory
|
||||||
self.buffer = None
|
self.buffer: Optional[Union[dict, tuple, ShArray]] = None
|
||||||
if self.share_memory:
|
if self.share_memory:
|
||||||
dummy = env_fn()
|
dummy = env_fn()
|
||||||
obs_space = dummy.observation_space
|
obs_space = dummy.observation_space
|
||||||
@ -168,25 +171,23 @@ class SubprocEnvWorker(EnvWorker):
|
|||||||
return obs
|
return obs
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def wait(
|
def wait( # type: ignore
|
||||||
workers: List["SubprocEnvWorker"],
|
workers: List["SubprocEnvWorker"],
|
||||||
wait_num: int,
|
wait_num: int,
|
||||||
timeout: Optional[float] = None,
|
timeout: Optional[float] = None,
|
||||||
) -> List["SubprocEnvWorker"]:
|
) -> List["SubprocEnvWorker"]:
|
||||||
conns, ready_conns = [x.parent_remote for x in workers], []
|
remain_conns = conns = [x.parent_remote for x in workers]
|
||||||
remain_conns = conns
|
ready_conns: List[connection.Connection] = []
|
||||||
t1 = time.time()
|
remain_time, t1 = timeout, time.time()
|
||||||
while len(remain_conns) > 0 and len(ready_conns) < wait_num:
|
while len(remain_conns) > 0 and len(ready_conns) < wait_num:
|
||||||
if timeout:
|
if timeout:
|
||||||
remain_time = timeout - (time.time() - t1)
|
remain_time = timeout - (time.time() - t1)
|
||||||
if remain_time <= 0:
|
if remain_time <= 0:
|
||||||
break
|
break
|
||||||
else:
|
|
||||||
remain_time = timeout
|
|
||||||
# connection.wait hangs if the list is empty
|
# connection.wait hangs if the list is empty
|
||||||
new_ready_conns = connection.wait(
|
new_ready_conns = connection.wait(
|
||||||
remain_conns, timeout=remain_time)
|
remain_conns, timeout=remain_time)
|
||||||
ready_conns.extend(new_ready_conns)
|
ready_conns.extend(new_ready_conns) # type: ignore
|
||||||
remain_conns = [
|
remain_conns = [
|
||||||
conn for conn in remain_conns if conn not in ready_conns]
|
conn for conn in remain_conns if conn not in ready_conns]
|
||||||
return [workers[conns.index(con)] for con in ready_conns]
|
return [workers[conns.index(con)] for con in ready_conns]
|
||||||
|
|||||||
@ -9,15 +9,15 @@ class BaseNoise(ABC, object):
|
|||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
"""Reset to the initial state."""
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __call__(self, size: Sequence[int]) -> np.ndarray:
|
def __call__(self, size: Sequence[int]) -> np.ndarray:
|
||||||
"""Generate new noise."""
|
"""Generate new noise."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def reset(self) -> None:
|
|
||||||
"""Reset to the initial state."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class GaussianNoise(BaseNoise):
|
class GaussianNoise(BaseNoise):
|
||||||
"""The vanilla gaussian process, for exploration in DDPG by default."""
|
"""The vanilla gaussian process, for exploration in DDPG by default."""
|
||||||
@ -64,6 +64,10 @@ class OUNoise(BaseNoise):
|
|||||||
self._x0 = x0
|
self._x0 = x0
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
"""Reset to the initial state."""
|
||||||
|
self._x = self._x0
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self, size: Sequence[int], mu: Optional[float] = None
|
self, size: Sequence[int], mu: Optional[float] = None
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
@ -71,14 +75,11 @@ class OUNoise(BaseNoise):
|
|||||||
|
|
||||||
Return an numpy array which size is equal to ``size``.
|
Return an numpy array which size is equal to ``size``.
|
||||||
"""
|
"""
|
||||||
if self._x is None or self._x.shape != size:
|
if self._x is None or isinstance(
|
||||||
|
self._x, np.ndarray) and self._x.shape != size:
|
||||||
self._x = 0.0
|
self._x = 0.0
|
||||||
if mu is None:
|
if mu is None:
|
||||||
mu = self._mu
|
mu = self._mu
|
||||||
r = self._beta * np.random.normal(size=size)
|
r = self._beta * np.random.normal(size=size)
|
||||||
self._x = self._x + self._alpha * (mu - self._x) + r
|
self._x = self._x + self._alpha * (mu - self._x) + r
|
||||||
return self._x
|
return self._x
|
||||||
|
|
||||||
def reset(self) -> None:
|
|
||||||
"""Reset to the initial state."""
|
|
||||||
self._x = self._x0
|
|
||||||
|
|||||||
@ -190,7 +190,7 @@ class BasePolicy(ABC, nn.Module):
|
|||||||
array with shape (bsz, ).
|
array with shape (bsz, ).
|
||||||
"""
|
"""
|
||||||
rew = batch.rew
|
rew = batch.rew
|
||||||
v_s_ = np.zeros_like(rew) if v_s_ is None else to_numpy(v_s_).flatten()
|
v_s_ = np.zeros_like(rew) if v_s_ is None else to_numpy(v_s_.flatten())
|
||||||
returns = _episodic_return(v_s_, rew, batch.done, gamma, gae_lambda)
|
returns = _episodic_return(v_s_, rew, batch.done, gamma, gae_lambda)
|
||||||
if rew_norm and not np.isclose(returns.std(), 0.0, 1e-2):
|
if rew_norm and not np.isclose(returns.std(), 0.0, 1e-2):
|
||||||
returns = (returns - returns.mean()) / returns.std()
|
returns = (returns - returns.mean()) / returns.std()
|
||||||
|
|||||||
@ -55,11 +55,11 @@ class ImitationPolicy(BasePolicy):
|
|||||||
if self.mode == "continuous": # regression
|
if self.mode == "continuous": # regression
|
||||||
a = self(batch).act
|
a = self(batch).act
|
||||||
a_ = to_torch(batch.act, dtype=torch.float32, device=a.device)
|
a_ = to_torch(batch.act, dtype=torch.float32, device=a.device)
|
||||||
loss = F.mse_loss(a, a_)
|
loss = F.mse_loss(a, a_) # type: ignore
|
||||||
elif self.mode == "discrete": # classification
|
elif self.mode == "discrete": # classification
|
||||||
a = self(batch).logits
|
a = self(batch).logits
|
||||||
a_ = to_torch(batch.act, dtype=torch.long, device=a.device)
|
a_ = to_torch(batch.act, dtype=torch.long, device=a.device)
|
||||||
loss = F.nll_loss(a, a_)
|
loss = F.nll_loss(a, a_) # type: ignore
|
||||||
loss.backward()
|
loss.backward()
|
||||||
self.optim.step()
|
self.optim.step()
|
||||||
return {"loss": loss.item()}
|
return {"loss": loss.item()}
|
||||||
|
|||||||
@ -103,11 +103,11 @@ class A2CPolicy(PGPolicy):
|
|||||||
if isinstance(logits, tuple):
|
if isinstance(logits, tuple):
|
||||||
dist = self.dist_fn(*logits)
|
dist = self.dist_fn(*logits)
|
||||||
else:
|
else:
|
||||||
dist = self.dist_fn(logits)
|
dist = self.dist_fn(logits) # type: ignore
|
||||||
act = dist.sample()
|
act = dist.sample()
|
||||||
return Batch(logits=logits, act=act, state=h, dist=dist)
|
return Batch(logits=logits, act=act, state=h, dist=dist)
|
||||||
|
|
||||||
def learn(
|
def learn( # type: ignore
|
||||||
self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
|
self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
|
||||||
) -> Dict[str, List[float]]:
|
) -> Dict[str, List[float]]:
|
||||||
losses, actor_losses, vf_losses, ent_losses = [], [], [], []
|
losses, actor_losses, vf_losses, ent_losses = [], [], [], []
|
||||||
@ -120,7 +120,7 @@ class A2CPolicy(PGPolicy):
|
|||||||
r = to_torch_as(b.returns, v)
|
r = to_torch_as(b.returns, v)
|
||||||
log_prob = dist.log_prob(a).reshape(len(r), -1).transpose(0, 1)
|
log_prob = dist.log_prob(a).reshape(len(r), -1).transpose(0, 1)
|
||||||
a_loss = -(log_prob * (r - v).detach()).mean()
|
a_loss = -(log_prob * (r - v).detach()).mean()
|
||||||
vf_loss = F.mse_loss(r, v)
|
vf_loss = F.mse_loss(r, v) # type: ignore
|
||||||
ent_loss = dist.entropy().mean()
|
ent_loss = dist.entropy().mean()
|
||||||
loss = a_loss + self._w_vf * vf_loss - self._w_ent * ent_loss
|
loss = a_loss + self._w_vf * vf_loss - self._w_ent * ent_loss
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|||||||
@ -53,14 +53,16 @@ class DDPGPolicy(BasePolicy):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
if actor is not None:
|
if actor is not None and actor_optim is not None:
|
||||||
self.actor, self.actor_old = actor, deepcopy(actor)
|
self.actor: torch.nn.Module = actor
|
||||||
|
self.actor_old = deepcopy(actor)
|
||||||
self.actor_old.eval()
|
self.actor_old.eval()
|
||||||
self.actor_optim = actor_optim
|
self.actor_optim: torch.optim.Optimizer = actor_optim
|
||||||
if critic is not None:
|
if critic is not None and critic_optim is not None:
|
||||||
self.critic, self.critic_old = critic, deepcopy(critic)
|
self.critic: torch.nn.Module = critic
|
||||||
|
self.critic_old = deepcopy(critic)
|
||||||
self.critic_old.eval()
|
self.critic_old.eval()
|
||||||
self.critic_optim = critic_optim
|
self.critic_optim: torch.optim.Optimizer = critic_optim
|
||||||
assert 0.0 <= tau <= 1.0, "tau should be in [0, 1]"
|
assert 0.0 <= tau <= 1.0, "tau should be in [0, 1]"
|
||||||
self._tau = tau
|
self._tau = tau
|
||||||
assert 0.0 <= gamma <= 1.0, "gamma should be in [0, 1]"
|
assert 0.0 <= gamma <= 1.0, "gamma should be in [0, 1]"
|
||||||
@ -141,7 +143,7 @@ class DDPGPolicy(BasePolicy):
|
|||||||
obs = getattr(batch, input)
|
obs = getattr(batch, input)
|
||||||
actions, h = model(obs, state=state, info=batch.info)
|
actions, h = model(obs, state=state, info=batch.info)
|
||||||
actions += self._action_bias
|
actions += self._action_bias
|
||||||
if self.training and explorating:
|
if self._noise and self.training and explorating:
|
||||||
actions += to_torch_as(self._noise(actions.shape), actions)
|
actions += to_torch_as(self._noise(actions.shape), actions)
|
||||||
actions = actions.clamp(self._range[0], self._range[1])
|
actions = actions.clamp(self._range[0], self._range[1])
|
||||||
return Batch(act=actions, state=h)
|
return Batch(act=actions, state=h)
|
||||||
|
|||||||
@ -146,11 +146,10 @@ class DQNPolicy(BasePolicy):
|
|||||||
obs = getattr(batch, input)
|
obs = getattr(batch, input)
|
||||||
obs_ = obs.obs if hasattr(obs, "obs") else obs
|
obs_ = obs.obs if hasattr(obs, "obs") else obs
|
||||||
q, h = model(obs_, state=state, info=batch.info)
|
q, h = model(obs_, state=state, info=batch.info)
|
||||||
act = to_numpy(q.max(dim=1)[1])
|
act: np.ndarray = to_numpy(q.max(dim=1)[1])
|
||||||
has_mask = hasattr(obs, 'mask')
|
if hasattr(obs, "mask"):
|
||||||
if has_mask:
|
|
||||||
# some of actions are masked, they cannot be selected
|
# some of actions are masked, they cannot be selected
|
||||||
q_ = to_numpy(q)
|
q_: np.ndarray = to_numpy(q)
|
||||||
q_[~obs.mask] = -np.inf
|
q_[~obs.mask] = -np.inf
|
||||||
act = q_.argmax(axis=1)
|
act = q_.argmax(axis=1)
|
||||||
# add eps to act
|
# add eps to act
|
||||||
@ -160,7 +159,7 @@ class DQNPolicy(BasePolicy):
|
|||||||
for i in range(len(q)):
|
for i in range(len(q)):
|
||||||
if np.random.rand() < eps:
|
if np.random.rand() < eps:
|
||||||
q_ = np.random.rand(*q[i].shape)
|
q_ = np.random.rand(*q[i].shape)
|
||||||
if has_mask:
|
if hasattr(obs, "mask"):
|
||||||
q_[~obs.mask[i]] = -np.inf
|
q_[~obs.mask[i]] = -np.inf
|
||||||
act[i] = q_.argmax()
|
act[i] = q_.argmax()
|
||||||
return Batch(logits=q, act=act, state=h)
|
return Batch(logits=q, act=act, state=h)
|
||||||
@ -172,7 +171,7 @@ class DQNPolicy(BasePolicy):
|
|||||||
weight = batch.pop("weight", 1.0)
|
weight = batch.pop("weight", 1.0)
|
||||||
q = self(batch, eps=0.0).logits
|
q = self(batch, eps=0.0).logits
|
||||||
q = q[np.arange(len(q)), batch.act]
|
q = q[np.arange(len(q)), batch.act]
|
||||||
r = to_torch_as(batch.returns, q).flatten()
|
r = to_torch_as(batch.returns.flatten(), q)
|
||||||
td = r - q
|
td = r - q
|
||||||
loss = (td.pow(2) * weight).mean()
|
loss = (td.pow(2) * weight).mean()
|
||||||
batch.weight = td # prio-buffer
|
batch.weight = td # prio-buffer
|
||||||
|
|||||||
@ -32,7 +32,8 @@ class PGPolicy(BasePolicy):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.model = model
|
if model is not None:
|
||||||
|
self.model: torch.nn.Module = model
|
||||||
self.optim = optim
|
self.optim = optim
|
||||||
self.dist_fn = dist_fn
|
self.dist_fn = dist_fn
|
||||||
assert (
|
assert (
|
||||||
@ -81,11 +82,11 @@ class PGPolicy(BasePolicy):
|
|||||||
if isinstance(logits, tuple):
|
if isinstance(logits, tuple):
|
||||||
dist = self.dist_fn(*logits)
|
dist = self.dist_fn(*logits)
|
||||||
else:
|
else:
|
||||||
dist = self.dist_fn(logits)
|
dist = self.dist_fn(logits) # type: ignore
|
||||||
act = dist.sample()
|
act = dist.sample()
|
||||||
return Batch(logits=logits, act=act, state=h, dist=dist)
|
return Batch(logits=logits, act=act, state=h, dist=dist)
|
||||||
|
|
||||||
def learn(
|
def learn( # type: ignore
|
||||||
self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
|
self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
|
||||||
) -> Dict[str, List[float]]:
|
) -> Dict[str, List[float]]:
|
||||||
losses = []
|
losses = []
|
||||||
|
|||||||
@ -137,13 +137,13 @@ class PPOPolicy(PGPolicy):
|
|||||||
if isinstance(logits, tuple):
|
if isinstance(logits, tuple):
|
||||||
dist = self.dist_fn(*logits)
|
dist = self.dist_fn(*logits)
|
||||||
else:
|
else:
|
||||||
dist = self.dist_fn(logits)
|
dist = self.dist_fn(logits) # type: ignore
|
||||||
act = dist.sample()
|
act = dist.sample()
|
||||||
if self._range:
|
if self._range:
|
||||||
act = act.clamp(self._range[0], self._range[1])
|
act = act.clamp(self._range[0], self._range[1])
|
||||||
return Batch(logits=logits, act=act, state=h, dist=dist)
|
return Batch(logits=logits, act=act, state=h, dist=dist)
|
||||||
|
|
||||||
def learn(
|
def learn( # type: ignore
|
||||||
self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
|
self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
|
||||||
) -> Dict[str, List[float]]:
|
) -> Dict[str, List[float]]:
|
||||||
losses, clip_losses, vf_losses, ent_losses = [], [], [], []
|
losses, clip_losses, vf_losses, ent_losses = [], [], [], []
|
||||||
@ -157,8 +157,9 @@ class PPOPolicy(PGPolicy):
|
|||||||
surr2 = ratio.clamp(1.0 - self._eps_clip,
|
surr2 = ratio.clamp(1.0 - self._eps_clip,
|
||||||
1.0 + self._eps_clip) * b.adv
|
1.0 + self._eps_clip) * b.adv
|
||||||
if self._dual_clip:
|
if self._dual_clip:
|
||||||
clip_loss = -torch.max(torch.min(surr1, surr2),
|
clip_loss = -torch.max(
|
||||||
self._dual_clip * b.adv).mean()
|
torch.min(surr1, surr2), self._dual_clip * b.adv
|
||||||
|
).mean()
|
||||||
else:
|
else:
|
||||||
clip_loss = -torch.min(surr1, surr2).mean()
|
clip_loss = -torch.min(surr1, surr2).mean()
|
||||||
clip_losses.append(clip_loss.item())
|
clip_losses.append(clip_loss.item())
|
||||||
|
|||||||
@ -107,7 +107,7 @@ class SACPolicy(DDPGPolicy):
|
|||||||
):
|
):
|
||||||
o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
|
o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
|
||||||
|
|
||||||
def forward(
|
def forward( # type: ignore
|
||||||
self,
|
self,
|
||||||
batch: Batch,
|
batch: Batch,
|
||||||
state: Optional[Union[dict, Batch, np.ndarray]] = None,
|
state: Optional[Union[dict, Batch, np.ndarray]] = None,
|
||||||
@ -193,5 +193,5 @@ class SACPolicy(DDPGPolicy):
|
|||||||
}
|
}
|
||||||
if self._is_auto_alpha:
|
if self._is_auto_alpha:
|
||||||
result["loss/alpha"] = alpha_loss.item()
|
result["loss/alpha"] = alpha_loss.item()
|
||||||
result["v/alpha"] = self._alpha.item()
|
result["v/alpha"] = self._alpha.item() # type: ignore
|
||||||
return result
|
return result
|
||||||
|
|||||||
@ -73,7 +73,7 @@ def offpolicy_trainer(
|
|||||||
"""
|
"""
|
||||||
global_step = 0
|
global_step = 0
|
||||||
best_epoch, best_reward = -1, -1.0
|
best_epoch, best_reward = -1, -1.0
|
||||||
stat = {}
|
stat: Dict[str, MovAvg] = {}
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
test_in_train = test_in_train and train_collector.policy == policy
|
test_in_train = test_in_train and train_collector.policy == policy
|
||||||
for epoch in range(1, 1 + max_epoch):
|
for epoch in range(1, 1 + max_epoch):
|
||||||
@ -91,7 +91,7 @@ def offpolicy_trainer(
|
|||||||
test_result = test_episode(
|
test_result = test_episode(
|
||||||
policy, test_collector, test_fn,
|
policy, test_collector, test_fn,
|
||||||
epoch, episode_per_test, writer, global_step)
|
epoch, episode_per_test, writer, global_step)
|
||||||
if stop_fn and stop_fn(test_result["rew"]):
|
if stop_fn(test_result["rew"]):
|
||||||
if save_fn:
|
if save_fn:
|
||||||
save_fn(policy)
|
save_fn(policy)
|
||||||
for k in result.keys():
|
for k in result.keys():
|
||||||
|
|||||||
@ -73,7 +73,7 @@ def onpolicy_trainer(
|
|||||||
"""
|
"""
|
||||||
global_step = 0
|
global_step = 0
|
||||||
best_epoch, best_reward = -1, -1.0
|
best_epoch, best_reward = -1, -1.0
|
||||||
stat = {}
|
stat: Dict[str, MovAvg] = {}
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
test_in_train = test_in_train and train_collector.policy == policy
|
test_in_train = test_in_train and train_collector.policy == policy
|
||||||
for epoch in range(1, 1 + max_epoch):
|
for epoch in range(1, 1 + max_epoch):
|
||||||
@ -91,7 +91,7 @@ def onpolicy_trainer(
|
|||||||
test_result = test_episode(
|
test_result = test_episode(
|
||||||
policy, test_collector, test_fn,
|
policy, test_collector, test_fn,
|
||||||
epoch, episode_per_test, writer, global_step)
|
epoch, episode_per_test, writer, global_step)
|
||||||
if stop_fn and stop_fn(test_result["rew"]):
|
if stop_fn(test_result["rew"]):
|
||||||
if save_fn:
|
if save_fn:
|
||||||
save_fn(policy)
|
save_fn(policy)
|
||||||
for k in result.keys():
|
for k in result.keys():
|
||||||
@ -109,9 +109,9 @@ def onpolicy_trainer(
|
|||||||
batch_size=batch_size, repeat=repeat_per_collect)
|
batch_size=batch_size, repeat=repeat_per_collect)
|
||||||
train_collector.reset_buffer()
|
train_collector.reset_buffer()
|
||||||
step = 1
|
step = 1
|
||||||
for k in losses.keys():
|
for v in losses.values():
|
||||||
if isinstance(losses[k], list):
|
if isinstance(v, list):
|
||||||
step = max(step, len(losses[k]))
|
step = max(step, len(v))
|
||||||
global_step += step * collect_per_step
|
global_step += step * collect_per_step
|
||||||
for k in result.keys():
|
for k in result.keys():
|
||||||
data[k] = f"{result[k]:.2f}"
|
data[k] = f"{result[k]:.2f}"
|
||||||
|
|||||||
@ -22,7 +22,7 @@ def test_episode(
|
|||||||
policy.eval()
|
policy.eval()
|
||||||
if test_fn:
|
if test_fn:
|
||||||
test_fn(epoch)
|
test_fn(epoch)
|
||||||
if collector.get_env_num() > 1 and np.isscalar(n_episode):
|
if collector.get_env_num() > 1 and isinstance(n_episode, int):
|
||||||
n = collector.get_env_num()
|
n = collector.get_env_num()
|
||||||
n_ = np.zeros(n) + n_episode // n
|
n_ = np.zeros(n) + n_episode // n
|
||||||
n_[:n_episode % n] += 1
|
n_[:n_episode % n] += 1
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from numbers import Number
|
from numbers import Number
|
||||||
from typing import Union
|
from typing import List, Union
|
||||||
|
|
||||||
from tianshou.data import to_numpy
|
from tianshou.data import to_numpy
|
||||||
|
|
||||||
@ -28,7 +28,7 @@ class MovAvg(object):
|
|||||||
def __init__(self, size: int = 100) -> None:
|
def __init__(self, size: int = 100) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.size = size
|
self.size = size
|
||||||
self.cache = []
|
self.cache: List[Union[Number, np.number]] = []
|
||||||
self.banned = [np.inf, np.nan, -np.inf]
|
self.banned = [np.inf, np.nan, -np.inf]
|
||||||
|
|
||||||
def add(
|
def add(
|
||||||
|
|||||||
@ -12,7 +12,7 @@ def miniblock(
|
|||||||
norm_layer: Optional[Callable[[int], nn.modules.Module]],
|
norm_layer: Optional[Callable[[int], nn.modules.Module]],
|
||||||
) -> List[nn.modules.Module]:
|
) -> List[nn.modules.Module]:
|
||||||
"""Construct a miniblock with given input/output-size and norm layer."""
|
"""Construct a miniblock with given input/output-size and norm layer."""
|
||||||
ret = [nn.Linear(inp, oup)]
|
ret: List[nn.modules.Module] = [nn.Linear(inp, oup)]
|
||||||
if norm_layer is not None:
|
if norm_layer is not None:
|
||||||
ret += [norm_layer(oup)]
|
ret += [norm_layer(oup)]
|
||||||
ret += [nn.ReLU(inplace=True)]
|
ret += [nn.ReLU(inplace=True)]
|
||||||
@ -54,36 +54,33 @@ class Net(nn.Module):
|
|||||||
if concat:
|
if concat:
|
||||||
input_size += np.prod(action_shape)
|
input_size += np.prod(action_shape)
|
||||||
|
|
||||||
self.model = miniblock(input_size, hidden_layer_size, norm_layer)
|
model = miniblock(input_size, hidden_layer_size, norm_layer)
|
||||||
|
|
||||||
for i in range(layer_num):
|
for i in range(layer_num):
|
||||||
self.model += miniblock(hidden_layer_size,
|
model += miniblock(
|
||||||
hidden_layer_size, norm_layer)
|
hidden_layer_size, hidden_layer_size, norm_layer)
|
||||||
|
|
||||||
if self.dueling is None:
|
if dueling is None:
|
||||||
if action_shape and not concat:
|
if action_shape and not concat:
|
||||||
self.model += [nn.Linear(hidden_layer_size,
|
model += [nn.Linear(hidden_layer_size, np.prod(action_shape))]
|
||||||
np.prod(action_shape))]
|
|
||||||
else: # dueling DQN
|
else: # dueling DQN
|
||||||
assert isinstance(self.dueling, tuple) and len(self.dueling) == 2
|
q_layer_num, v_layer_num = dueling
|
||||||
|
Q, V = [], []
|
||||||
q_layer_num, v_layer_num = self.dueling
|
|
||||||
self.Q, self.V = [], []
|
|
||||||
|
|
||||||
for i in range(q_layer_num):
|
for i in range(q_layer_num):
|
||||||
self.Q += miniblock(hidden_layer_size,
|
Q += miniblock(
|
||||||
hidden_layer_size, norm_layer)
|
hidden_layer_size, hidden_layer_size, norm_layer)
|
||||||
for i in range(v_layer_num):
|
for i in range(v_layer_num):
|
||||||
self.V += miniblock(hidden_layer_size,
|
V += miniblock(
|
||||||
hidden_layer_size, norm_layer)
|
hidden_layer_size, hidden_layer_size, norm_layer)
|
||||||
|
|
||||||
if action_shape and not concat:
|
if action_shape and not concat:
|
||||||
self.Q += [nn.Linear(hidden_layer_size, np.prod(action_shape))]
|
Q += [nn.Linear(hidden_layer_size, np.prod(action_shape))]
|
||||||
self.V += [nn.Linear(hidden_layer_size, 1)]
|
V += [nn.Linear(hidden_layer_size, 1)]
|
||||||
|
|
||||||
self.Q = nn.Sequential(*self.Q)
|
self.Q = nn.Sequential(*Q)
|
||||||
self.V = nn.Sequential(*self.V)
|
self.V = nn.Sequential(*V)
|
||||||
self.model = nn.Sequential(*self.model)
|
self.model = nn.Sequential(*model)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -66,7 +66,7 @@ class Critic(nn.Module):
|
|||||||
s = to_torch(s, device=self.device, dtype=torch.float32)
|
s = to_torch(s, device=self.device, dtype=torch.float32)
|
||||||
s = s.flatten(1)
|
s = s.flatten(1)
|
||||||
if a is not None:
|
if a is not None:
|
||||||
a = to_torch(a, device=self.device, dtype=torch.float32)
|
a = to_torch_as(a, s)
|
||||||
a = a.flatten(1)
|
a = a.flatten(1)
|
||||||
s = torch.cat([s, a], dim=1)
|
s = torch.cat([s, a], dim=1)
|
||||||
logits, h = self.preprocess(s)
|
logits, h = self.preprocess(s)
|
||||||
|
|||||||
@ -4,6 +4,8 @@ from torch import nn
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from typing import Any, Dict, Tuple, Union, Optional, Sequence
|
from typing import Any, Dict, Tuple, Union, Optional, Sequence
|
||||||
|
|
||||||
|
from tianshou.data import to_torch
|
||||||
|
|
||||||
|
|
||||||
class Actor(nn.Module):
|
class Actor(nn.Module):
|
||||||
"""Simple actor network with MLP.
|
"""Simple actor network with MLP.
|
||||||
@ -118,5 +120,5 @@ class DQN(nn.Module):
|
|||||||
) -> Tuple[torch.Tensor, Any]:
|
) -> Tuple[torch.Tensor, Any]:
|
||||||
r"""Mapping: x -> Q(x, \*)."""
|
r"""Mapping: x -> Q(x, \*)."""
|
||||||
if not isinstance(x, torch.Tensor):
|
if not isinstance(x, torch.Tensor):
|
||||||
x = torch.tensor(x, device=self.device, dtype=torch.float32)
|
x = to_torch(x, device=self.device, dtype=torch.float32)
|
||||||
return self.net(x), state
|
return self.net(x), state
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user