Support different state size and fix exception in venv.__del__ (#352)
- Batch: do not raise error when it finds list of np.array with different shape[0]. - Venv's obs: add try...except block for np.stack(obs_list) - remove venv.__del__ since it is buggy
This commit is contained in:
parent
bbc3c3e32d
commit
ff4d3cd714
4
.github/workflows/extra_sys.yml
vendored
4
.github/workflows/extra_sys.yml
vendored
@ -9,7 +9,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
os: [macos-latest, windows-latest]
|
os: [macos-latest, windows-latest]
|
||||||
python-version: [3.6, 3.7, 3.8]
|
python-version: [3.7, 3.8]
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v2
|
||||||
- name: Set up Python ${{ matrix.python-version }}
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
@ -24,4 +24,4 @@ jobs:
|
|||||||
python -m pip install ".[dev]" --upgrade
|
python -m pip install ".[dev]" --upgrade
|
||||||
- name: Test with pytest
|
- name: Test with pytest
|
||||||
run: |
|
run: |
|
||||||
pytest test/base test/continuous --ignore-glob "*env.py" --cov=tianshou --durations=0 -v
|
pytest test/base test/continuous --cov=tianshou --durations=0 -v
|
||||||
|
2
docs/_static/js/benchmark.js
vendored
2
docs/_static/js/benchmark.js
vendored
@ -14,7 +14,7 @@ function showEnv(elem) {
|
|||||||
var dataSource = {
|
var dataSource = {
|
||||||
$schema: "https://vega.github.io/schema/vega-lite/v5.json",
|
$schema: "https://vega.github.io/schema/vega-lite/v5.json",
|
||||||
data: {
|
data: {
|
||||||
url: "/_static/js/mujoco/benchmark/" + selectEnv + "/result.json"
|
url: "/en/master/_static/js/mujoco/benchmark/" + selectEnv + "/result.json"
|
||||||
},
|
},
|
||||||
mark: "line",
|
mark: "line",
|
||||||
height: 400,
|
height: 400,
|
||||||
|
1
setup.py
1
setup.py
@ -62,6 +62,7 @@ setup(
|
|||||||
"pytest",
|
"pytest",
|
||||||
"pytest-cov",
|
"pytest-cov",
|
||||||
"ray>=1.0.0",
|
"ray>=1.0.0",
|
||||||
|
"networkx",
|
||||||
"mypy",
|
"mypy",
|
||||||
"pydocstyle",
|
"pydocstyle",
|
||||||
"doc8",
|
"doc8",
|
||||||
|
@ -2,6 +2,8 @@ import gym
|
|||||||
import time
|
import time
|
||||||
import random
|
import random
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import networkx as nx
|
||||||
|
from copy import deepcopy
|
||||||
from gym.spaces import Discrete, MultiDiscrete, Box, Dict, Tuple
|
from gym.spaces import Discrete, MultiDiscrete, Box, Dict, Tuple
|
||||||
|
|
||||||
|
|
||||||
@ -107,3 +109,30 @@ class MyTestEnv(gym.Env):
|
|||||||
self.done = self.index == self.size
|
self.done = self.index == self.size
|
||||||
return self._get_state(), self._get_reward(), \
|
return self._get_state(), self._get_reward(), \
|
||||||
self.done, {'key': 1, 'env': self}
|
self.done, {'key': 1, 'env': self}
|
||||||
|
|
||||||
|
|
||||||
|
class NXEnv(gym.Env):
|
||||||
|
def __init__(self, size, obs_type, feat_dim=32):
|
||||||
|
self.size = size
|
||||||
|
self.feat_dim = feat_dim
|
||||||
|
self.graph = nx.Graph()
|
||||||
|
self.graph.add_nodes_from(list(range(size)))
|
||||||
|
assert obs_type in ["array", "object"]
|
||||||
|
self.obs_type = obs_type
|
||||||
|
|
||||||
|
def _encode_obs(self):
|
||||||
|
if self.obs_type == "array":
|
||||||
|
return np.stack([v["data"] for v in self.graph._node.values()])
|
||||||
|
return deepcopy(self.graph)
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
graph_state = np.random.rand(self.size, self.feat_dim)
|
||||||
|
for i in range(self.size):
|
||||||
|
self.graph.nodes[i]["data"] = graph_state[i]
|
||||||
|
return self._encode_obs()
|
||||||
|
|
||||||
|
def step(self, action):
|
||||||
|
next_graph_state = np.random.rand(self.size, self.feat_dim)
|
||||||
|
for i in range(self.size):
|
||||||
|
self.graph.nodes[i]["data"] = next_graph_state[i]
|
||||||
|
return self._encode_obs(), 1.0, 0, {}
|
||||||
|
@ -4,6 +4,7 @@ import torch
|
|||||||
import pickle
|
import pickle
|
||||||
import pytest
|
import pytest
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import networkx as nx
|
||||||
from itertools import starmap
|
from itertools import starmap
|
||||||
|
|
||||||
from tianshou.data import Batch, to_torch, to_numpy
|
from tianshou.data import Batch, to_torch, to_numpy
|
||||||
@ -36,8 +37,7 @@ def test_batch():
|
|||||||
assert 'a' not in b
|
assert 'a' not in b
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
Batch({1: 2})
|
Batch({1: 2})
|
||||||
with pytest.raises(TypeError):
|
assert Batch(a=[np.zeros((2, 3)), np.zeros((3, 3))]).a.dtype == object
|
||||||
Batch(a=[np.zeros((2, 3)), np.zeros((3, 3))])
|
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError):
|
||||||
Batch(a=[np.zeros((3, 2)), np.zeros((3, 3))])
|
Batch(a=[np.zeros((3, 2)), np.zeros((3, 3))])
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError):
|
||||||
@ -170,6 +170,14 @@ def test_batch():
|
|||||||
assert a.a[0] is None and a.a[1] is None
|
assert a.a[0] is None and a.a[1] is None
|
||||||
assert a.b[0] is None and a.b[1] is None
|
assert a.b[0] is None and a.b[1] is None
|
||||||
|
|
||||||
|
# nx.Graph corner case
|
||||||
|
assert Batch(a=np.array([nx.Graph(), nx.Graph()], dtype=object)).a.dtype == object
|
||||||
|
g1 = nx.Graph()
|
||||||
|
g1.add_nodes_from(list(range(10)))
|
||||||
|
g2 = nx.Graph()
|
||||||
|
g2.add_nodes_from(list(range(20)))
|
||||||
|
assert Batch(a=np.array([g1, g2])).a.dtype == object
|
||||||
|
|
||||||
|
|
||||||
def test_batch_over_batch():
|
def test_batch_over_batch():
|
||||||
batch = Batch(a=[3, 4, 5], b=[4, 5, 6])
|
batch = Batch(a=[3, 4, 5], b=[4, 5, 6])
|
||||||
|
@ -14,9 +14,9 @@ from tianshou.data import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
from env import MyTestEnv
|
from env import MyTestEnv, NXEnv
|
||||||
else: # pytest
|
else: # pytest
|
||||||
from test.base.env import MyTestEnv
|
from test.base.env import MyTestEnv, NXEnv
|
||||||
|
|
||||||
|
|
||||||
class MyPolicy(BasePolicy):
|
class MyPolicy(BasePolicy):
|
||||||
@ -137,6 +137,15 @@ def test_collector():
|
|||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError):
|
||||||
c2.collect()
|
c2.collect()
|
||||||
|
|
||||||
|
# test NXEnv
|
||||||
|
for obs_type in ["array", "object"]:
|
||||||
|
envs = SubprocVectorEnv([
|
||||||
|
lambda i=x: NXEnv(i, obs_type) for x in [5, 10, 15, 20]])
|
||||||
|
c3 = Collector(policy, envs,
|
||||||
|
VectorReplayBuffer(total_size=100, buffer_num=4))
|
||||||
|
c3.collect(n_step=6)
|
||||||
|
assert c3.buffer.obs.dtype == object
|
||||||
|
|
||||||
|
|
||||||
def test_collector_with_async():
|
def test_collector_with_async():
|
||||||
env_lens = [2, 3, 4, 5]
|
env_lens = [2, 3, 4, 5]
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import sys
|
||||||
import time
|
import time
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from gym.spaces.discrete import Discrete
|
from gym.spaces.discrete import Discrete
|
||||||
@ -6,9 +7,9 @@ from tianshou.env import DummyVectorEnv, SubprocVectorEnv, \
|
|||||||
ShmemVectorEnv, RayVectorEnv
|
ShmemVectorEnv, RayVectorEnv
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
from env import MyTestEnv
|
from env import MyTestEnv, NXEnv
|
||||||
else: # pytest
|
else: # pytest
|
||||||
from test.base.env import MyTestEnv
|
from test.base.env import MyTestEnv, NXEnv
|
||||||
|
|
||||||
|
|
||||||
def has_ray():
|
def has_ray():
|
||||||
@ -79,6 +80,7 @@ def test_async_env(size=10000, num=8, sleep=0.1):
|
|||||||
Batch.cat(o)
|
Batch.cat(o)
|
||||||
v.close()
|
v.close()
|
||||||
# assure 1/7 improvement
|
# assure 1/7 improvement
|
||||||
|
if sys.platform != "darwin": # macOS cannot pass this check
|
||||||
assert spent_time < 6.0 * sleep * num / (num + 1)
|
assert spent_time < 6.0 * sleep * num / (num + 1)
|
||||||
|
|
||||||
|
|
||||||
@ -116,6 +118,7 @@ def test_async_check_id(size=100, num=4, sleep=.2, timeout=.7):
|
|||||||
pass_check = 0
|
pass_check = 0
|
||||||
break
|
break
|
||||||
total_pass += pass_check
|
total_pass += pass_check
|
||||||
|
if sys.platform != "darwin": # macOS cannot pass this check
|
||||||
assert total_pass >= 2
|
assert total_pass >= 2
|
||||||
|
|
||||||
|
|
||||||
@ -167,7 +170,18 @@ def test_vecenv(size=10, num=8, sleep=0.001):
|
|||||||
v.close()
|
v.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_env_obs():
|
||||||
|
for obs_type in ["array", "object"]:
|
||||||
|
envs = SubprocVectorEnv([
|
||||||
|
lambda i=x: NXEnv(i, obs_type) for x in [5, 10, 15, 20]])
|
||||||
|
obs = envs.reset()
|
||||||
|
assert obs.dtype == object
|
||||||
|
obs = envs.step([1, 1, 1, 1])[0]
|
||||||
|
assert obs.dtype == object
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
test_env_obs()
|
||||||
test_vecenv()
|
test_vecenv()
|
||||||
test_async_env()
|
test_async_env()
|
||||||
test_async_check_id()
|
test_async_check_id()
|
||||||
|
@ -65,7 +65,9 @@ def _to_array_with_correct_type(v: Any) -> np.ndarray:
|
|||||||
# array([{}, array({}, dtype=object)], dtype=object)
|
# array([{}, array({}, dtype=object)], dtype=object)
|
||||||
if not v.shape:
|
if not v.shape:
|
||||||
v = v.item(0)
|
v = v.item(0)
|
||||||
elif any(isinstance(e, (np.ndarray, torch.Tensor)) for e in v.reshape(-1)):
|
elif all(isinstance(e, np.ndarray) for e in v.reshape(-1)):
|
||||||
|
return v # various length, np.array([[1], [2, 3], [4, 5, 6]])
|
||||||
|
elif any(isinstance(e, torch.Tensor) for e in v.reshape(-1)):
|
||||||
raise ValueError("Numpy arrays of tensors are not supported yet.")
|
raise ValueError("Numpy arrays of tensors are not supported yet.")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
25
tianshou/env/venvs.py
vendored
25
tianshou/env/venvs.py
vendored
@ -1,6 +1,6 @@
|
|||||||
import gym
|
import gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from typing import Any, List, Union, Optional, Callable
|
from typing import Any, List, Tuple, Union, Optional, Callable
|
||||||
|
|
||||||
from tianshou.utils import RunningMeanStd
|
from tianshou.utils import RunningMeanStd
|
||||||
from tianshou.env.worker import EnvWorker, DummyEnvWorker, SubprocEnvWorker, \
|
from tianshou.env.worker import EnvWorker, DummyEnvWorker, SubprocEnvWorker, \
|
||||||
@ -163,7 +163,11 @@ class BaseVectorEnv(gym.Env):
|
|||||||
id = self._wrap_id(id)
|
id = self._wrap_id(id)
|
||||||
if self.is_async:
|
if self.is_async:
|
||||||
self._assert_id(id)
|
self._assert_id(id)
|
||||||
obs = np.stack([self.workers[i].reset() for i in id])
|
obs_list = [self.workers[i].reset() for i in id]
|
||||||
|
try:
|
||||||
|
obs = np.stack(obs_list)
|
||||||
|
except ValueError: # different len(obs)
|
||||||
|
obs = np.array(obs_list, dtype=object)
|
||||||
if self.obs_rms and self.update_obs_rms:
|
if self.obs_rms and self.update_obs_rms:
|
||||||
self.obs_rms.update(obs)
|
self.obs_rms.update(obs)
|
||||||
return self.normalize_obs(obs)
|
return self.normalize_obs(obs)
|
||||||
@ -172,7 +176,7 @@ class BaseVectorEnv(gym.Env):
|
|||||||
self,
|
self,
|
||||||
action: np.ndarray,
|
action: np.ndarray,
|
||||||
id: Optional[Union[int, List[int], np.ndarray]] = None
|
id: Optional[Union[int, List[int], np.ndarray]] = None
|
||||||
) -> List[np.ndarray]:
|
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||||
"""Run one timestep of some environments' dynamics.
|
"""Run one timestep of some environments' dynamics.
|
||||||
|
|
||||||
If id is None, run one timestep of all the environments’ dynamics;
|
If id is None, run one timestep of all the environments’ dynamics;
|
||||||
@ -236,10 +240,16 @@ class BaseVectorEnv(gym.Env):
|
|||||||
info["env_id"] = env_id
|
info["env_id"] = env_id
|
||||||
result.append((obs, rew, done, info))
|
result.append((obs, rew, done, info))
|
||||||
self.ready_id.append(env_id)
|
self.ready_id.append(env_id)
|
||||||
obs_stack, rew_stack, done_stack, info_stack = map(np.stack, zip(*result))
|
obs_list, rew_list, done_list, info_list = zip(*result)
|
||||||
|
try:
|
||||||
|
obs_stack = np.stack(obs_list)
|
||||||
|
except ValueError: # different len(obs)
|
||||||
|
obs_stack = np.array(obs_list, dtype=object)
|
||||||
|
rew_stack, done_stack, info_stack = map(
|
||||||
|
np.stack, [rew_list, done_list, info_list])
|
||||||
if self.obs_rms and self.update_obs_rms:
|
if self.obs_rms and self.update_obs_rms:
|
||||||
self.obs_rms.update(obs_stack)
|
self.obs_rms.update(obs_stack)
|
||||||
return [self.normalize_obs(obs_stack), rew_stack, done_stack, info_stack]
|
return self.normalize_obs(obs_stack), rew_stack, done_stack, info_stack
|
||||||
|
|
||||||
def seed(
|
def seed(
|
||||||
self, seed: Optional[Union[int, List[int]]] = None
|
self, seed: Optional[Union[int, List[int]]] = None
|
||||||
@ -292,11 +302,6 @@ class BaseVectorEnv(gym.Env):
|
|||||||
obs = np.clip(obs, -clip_max, clip_max) # type: ignore
|
obs = np.clip(obs, -clip_max, clip_max) # type: ignore
|
||||||
return obs
|
return obs
|
||||||
|
|
||||||
def __del__(self) -> None:
|
|
||||||
"""Redirect to self.close()."""
|
|
||||||
if not self.is_closed:
|
|
||||||
self.close()
|
|
||||||
|
|
||||||
|
|
||||||
class DummyVectorEnv(BaseVectorEnv):
|
class DummyVectorEnv(BaseVectorEnv):
|
||||||
"""Dummy vectorized environment wrapper, implemented in for-loop.
|
"""Dummy vectorized environment wrapper, implemented in for-loop.
|
||||||
|
@ -59,7 +59,6 @@ class TRPOPolicy(NPGPolicy):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(actor, critic, optim, dist_fn, **kwargs)
|
super().__init__(actor, critic, optim, dist_fn, **kwargs)
|
||||||
del self._step_size
|
|
||||||
self._max_backtracks = max_backtracks
|
self._max_backtracks = max_backtracks
|
||||||
self._delta = max_kl
|
self._delta = max_kl
|
||||||
self._backtrack_coeff = backtrack_coeff
|
self._backtrack_coeff = backtrack_coeff
|
||||||
@ -123,7 +122,7 @@ class TRPOPolicy(NPGPolicy):
|
|||||||
" are poor and need to be changed.")
|
" are poor and need to be changed.")
|
||||||
|
|
||||||
# optimize citirc
|
# optimize citirc
|
||||||
for _ in range(self._optim_critic_iters):
|
for _ in range(self._optim_critic_iters): # type: ignore
|
||||||
value = self.critic(b.obs).flatten()
|
value = self.critic(b.obs).flatten()
|
||||||
vf_loss = F.mse_loss(b.returns, value)
|
vf_loss = F.mse_loss(b.returns, value)
|
||||||
self.optim.zero_grad()
|
self.optim.zero_grad()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user