fix conda support and keep API compatibility (#536)
* loosen constraints * fix nni issue (#478) * fix coverage
This commit is contained in:
parent
97df511a13
commit
c248b4f87e
2
.github/workflows/pytest.yml
vendored
2
.github/workflows/pytest.yml
vendored
@ -27,7 +27,7 @@ jobs:
|
|||||||
- name: Test with pytest
|
- name: Test with pytest
|
||||||
# ignore test/throughput which only profiles the code
|
# ignore test/throughput which only profiles the code
|
||||||
run: |
|
run: |
|
||||||
pytest test --ignore-glob='*profile.py' --cov=tianshou --cov-report=xml --cov-report=term-missing --durations=0 -v
|
pytest test --ignore-glob='*profile.py' --ignore="test/3rd_party" --cov=tianshou --cov-report=xml --cov-report=term-missing --durations=0 -v
|
||||||
- name: Upload coverage to Codecov
|
- name: Upload coverage to Codecov
|
||||||
uses: codecov/codecov-action@v1
|
uses: codecov/codecov-action@v1
|
||||||
with:
|
with:
|
||||||
|
5
setup.py
5
setup.py
@ -15,14 +15,13 @@ def get_version() -> str:
|
|||||||
|
|
||||||
def get_install_requires() -> str:
|
def get_install_requires() -> str:
|
||||||
return [
|
return [
|
||||||
"gym>=0.21",
|
"gym>=0.15.4",
|
||||||
"tqdm",
|
"tqdm",
|
||||||
"numpy>1.16.0", # https://github.com/numpy/numpy/issues/12793
|
"numpy>1.16.0", # https://github.com/numpy/numpy/issues/12793
|
||||||
"tensorboard>=2.5.0",
|
"tensorboard>=2.5.0",
|
||||||
"torch>=1.4.0",
|
"torch>=1.4.0",
|
||||||
"numba>=0.51.0",
|
"numba>=0.51.0",
|
||||||
"h5py>=2.10.0", # to match tensorflow's minimal requirements
|
"h5py>=2.10.0", # to match tensorflow's minimal requirements
|
||||||
"pettingzoo>=1.15",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -46,8 +45,10 @@ def get_extras_require() -> str:
|
|||||||
"doc8",
|
"doc8",
|
||||||
"scipy",
|
"scipy",
|
||||||
"pillow",
|
"pillow",
|
||||||
|
"pettingzoo>=1.12",
|
||||||
"pygame>=2.1.0", # pettingzoo test cases pistonball
|
"pygame>=2.1.0", # pettingzoo test cases pistonball
|
||||||
"pymunk>=6.2.1", # pettingzoo test cases pistonball
|
"pymunk>=6.2.1", # pettingzoo test cases pistonball
|
||||||
|
"nni>=2.3",
|
||||||
],
|
],
|
||||||
"atari": ["atari_py", "opencv-python"],
|
"atari": ["atari_py", "opencv-python"],
|
||||||
"mujoco": ["mujoco_py"],
|
"mujoco": ["mujoco_py"],
|
||||||
|
126
test/3rd_party/test_nni.py
vendored
Normal file
126
test/3rd_party/test_nni.py
vendored
Normal file
@ -0,0 +1,126 @@
|
|||||||
|
# https://github.com/microsoft/nni/blob/master/test/ut/retiarii/test_strategy.py
|
||||||
|
|
||||||
|
import random
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
import nni.retiarii.execution.api
|
||||||
|
import nni.retiarii.nn.pytorch as nn
|
||||||
|
import nni.retiarii.strategy as strategy
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from nni.retiarii import Model
|
||||||
|
from nni.retiarii.converter import convert_to_graph
|
||||||
|
from nni.retiarii.execution import wait_models
|
||||||
|
from nni.retiarii.execution.interface import (
|
||||||
|
AbstractExecutionEngine,
|
||||||
|
AbstractGraphListener,
|
||||||
|
MetricData,
|
||||||
|
WorkerInfo,
|
||||||
|
)
|
||||||
|
from nni.retiarii.graph import DebugEvaluator, ModelStatus
|
||||||
|
from nni.retiarii.nn.pytorch.mutator import process_inline_mutation
|
||||||
|
|
||||||
|
|
||||||
|
class MockExecutionEngine(AbstractExecutionEngine):
    """In-memory execution engine that fakes model training for strategy tests.

    Each submitted model is handled on its own thread: the thread sleeps for
    a random interval of up to one second and then marks the model either
    ``Failed`` (with probability ``failure_prob``) or ``Trained`` with a
    random metric in ``[0, 1]``.
    """

    def __init__(self, failure_prob=0.):
        """Create an engine with four free resource slots.

        :param failure_prob: probability that a submitted model ends up
            with status ``ModelStatus.Failed``.
        """
        self.models = []
        self.failure_prob = failure_prob
        self._resource_left = 4
        # submit_models() and the worker threads both update the bookkeeping
        # above; the lock keeps those read-modify-write updates from racing.
        self._lock = threading.Lock()

    def _model_complete(self, model: Model):
        """Simulate one training run and release the slot it occupied."""
        # Sleep outside the lock so simulated trainings overlap.
        time.sleep(random.uniform(0, 1))
        if random.uniform(0, 1) < self.failure_prob:
            model.status = ModelStatus.Failed
        else:
            # Set the metric before the status so a reader that sees
            # ``Trained`` also sees the metric.
            model.metric = random.uniform(0, 1)
            model.status = ModelStatus.Trained
        with self._lock:
            self._resource_left += 1

    def submit_models(self, *models: Model) -> None:
        """Record every model, take one slot for it and start its worker."""
        for model in models:
            with self._lock:
                self.models.append(model)
                self._resource_left -= 1
            threading.Thread(target=self._model_complete, args=(model, )).start()

    def list_models(self) -> List[Model]:
        """Return the list of all models submitted so far."""
        return self.models

    def query_available_resource(self) -> Union[List[WorkerInfo], int]:
        """Return the current number of free slots."""
        with self._lock:
            return self._resource_left

    def budget_exhausted(self) -> bool:
        """Report that the budget is never exhausted (the mock has none)."""
        # Return an actual bool to match the annotation; the previous
        # implicit ``None`` was equally falsy.
        return False

    def register_graph_listener(self, listener: AbstractGraphListener) -> None:
        """Ignore the listener; the mock emits no graph events."""
        pass

    @classmethod
    def trial_execute_graph(cls) -> MetricData:
        """Do nothing; the mock never executes a real trial."""
        # The first parameter is ``cls``, so bind the method accordingly.
        pass
|
||||||
|
|
||||||
|
|
||||||
|
def _reset_execution_engine(engine=None):
    """Install ``engine`` as NNI's module-level execution engine.

    Called without an argument it clears the engine by setting it to ``None``.
    """
    api_module = nni.retiarii.execution.api
    api_module._execution_engine = engine
|
||||||
|
|
||||||
|
|
||||||
|
class Net(nn.Module):
    """Small two-conv CNN whose fully connected layers are NNI layer choices.

    ``fc1`` chooses between a linear layer with and without bias.  ``fc2``
    has two candidates, plus a third one when ``diff_size`` is true, so the
    two choices can have different numbers of candidates.
    """

    def __init__(self, hidden_size=32, diff_size=False):
        """Build the layers.

        :param hidden_size: width of the hidden fully connected layer.
        :param diff_size: if true, ``fc2`` gets a third candidate.
        """
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)

        # Flattened size of the conv output: 50 channels of 4x4.
        flat_features = 4 * 4 * 50
        fc1_candidates = [
            nn.Linear(flat_features, hidden_size, bias=True),
            nn.Linear(flat_features, hidden_size, bias=False),
        ]
        self.fc1 = nn.LayerChoice(fc1_candidates, label='fc1')

        fc2_candidates = [
            nn.Linear(hidden_size, 10, bias=False),
            nn.Linear(hidden_size, 10, bias=True),
        ]
        if diff_size:
            fc2_candidates.append(nn.Linear(hidden_size, 10, bias=False))
        self.fc2 = nn.LayerChoice(fc2_candidates, label='fc2')

    def forward(self, x):
        """Return log-softmax scores over the 10 outputs for batch ``x``."""
        x = F.max_pool2d(F.relu(self.conv1(x)), 2, 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2, 2)
        x = x.view(-1, 4 * 4 * 50)
        x = self.fc2(F.relu(self.fc1(x)))
        return F.log_softmax(x, dim=1)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_model_and_mutators(**kwargs):
    """Build the graph IR of a fresh ``Net`` together with its inline mutators.

    Keyword arguments are forwarded to ``Net``.  Returns a
    ``(model_ir, mutators)`` pair; the IR's evaluator is a ``DebugEvaluator``.
    """
    net = Net(**kwargs)
    model_ir = convert_to_graph(torch.jit.script(net), net)
    model_ir.evaluator = DebugEvaluator()
    return model_ir, process_inline_mutation(model_ir)
|
||||||
|
|
||||||
|
|
||||||
|
def test_rl():
    """Smoke-test ``PolicyBasedRL`` against the mock engine.

    The strategy is run twice, each time with a fresh strategy and engine:
    first on the search space built with ``diff_size=True``, then on the
    default one.
    """
    for net_kwargs in ({"diff_size": True}, {}):
        rl = strategy.PolicyBasedRL(max_collect=2, trial_per_collect=10)
        engine = MockExecutionEngine(failure_prob=0.2)
        _reset_execution_engine(engine)
        try:
            rl.run(*_get_model_and_mutators(**net_kwargs))
            wait_models(*engine.models)
        finally:
            # Always clear the global engine so a failure here cannot leak
            # the mock into whatever runs next in the same process.
            _reset_execution_engine()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_rl()
|
@ -1,6 +1,6 @@
|
|||||||
from tianshou import data, env, exploration, policy, trainer, utils
|
from tianshou import data, env, exploration, policy, trainer, utils
|
||||||
|
|
||||||
__version__ = "0.4.6"
|
__version__ = "0.4.6.post1"
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"env",
|
"env",
|
||||||
|
6
tianshou/env/__init__.py
vendored
6
tianshou/env/__init__.py
vendored
@ -1,6 +1,5 @@
|
|||||||
"""Env package."""
|
"""Env package."""
|
||||||
|
|
||||||
from tianshou.env.pettingzoo_env import PettingZooEnv
|
|
||||||
from tianshou.env.venvs import (
|
from tianshou.env.venvs import (
|
||||||
BaseVectorEnv,
|
BaseVectorEnv,
|
||||||
DummyVectorEnv,
|
DummyVectorEnv,
|
||||||
@ -9,6 +8,11 @@ from tianshou.env.venvs import (
|
|||||||
SubprocVectorEnv,
|
SubprocVectorEnv,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from tianshou.env.pettingzoo_env import PettingZooEnv
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BaseVectorEnv",
|
"BaseVectorEnv",
|
||||||
"DummyVectorEnv",
|
"DummyVectorEnv",
|
||||||
|
6
tianshou/env/venvs.py
vendored
6
tianshou/env/venvs.py
vendored
@ -2,7 +2,6 @@ from typing import Any, Callable, List, Optional, Tuple, Union
|
|||||||
|
|
||||||
import gym
|
import gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pettingzoo
|
|
||||||
|
|
||||||
from tianshou.env.worker import (
|
from tianshou.env.worker import (
|
||||||
DummyEnvWorker,
|
DummyEnvWorker,
|
||||||
@ -365,10 +364,7 @@ class DummyVectorEnv(BaseVectorEnv):
|
|||||||
Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
|
Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None:
|
||||||
self, env_fns: List[Callable[[], Union[gym.Env, pettingzoo.AECEnv]]],
|
|
||||||
**kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
super().__init__(env_fns, DummyEnvWorker, **kwargs)
|
super().__init__(env_fns, DummyEnvWorker, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
25
tianshou/env/worker/base.py
vendored
25
tianshou/env/worker/base.py
vendored
@ -1,3 +1,4 @@
|
|||||||
|
import warnings
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Callable, List, Optional, Tuple, Union
|
from typing import Any, Callable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
@ -11,8 +12,10 @@ class EnvWorker(ABC):
|
|||||||
def __init__(self, env_fn: Callable[[], gym.Env]) -> None:
|
def __init__(self, env_fn: Callable[[], gym.Env]) -> None:
|
||||||
self._env_fn = env_fn
|
self._env_fn = env_fn
|
||||||
self.is_closed = False
|
self.is_closed = False
|
||||||
self.result: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]
|
self.result: Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
|
||||||
|
np.ndarray]
|
||||||
self.action_space = self.get_env_attr("action_space") # noqa: B009
|
self.action_space = self.get_env_attr("action_space") # noqa: B009
|
||||||
|
self.is_reset = False
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_env_attr(self, key: str) -> Any:
|
def get_env_attr(self, key: str) -> Any:
|
||||||
@ -22,7 +25,6 @@ class EnvWorker(ABC):
|
|||||||
def set_env_attr(self, key: str, value: Any) -> None:
|
def set_env_attr(self, key: str, value: Any) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def send(self, action: Optional[np.ndarray]) -> None:
|
def send(self, action: Optional[np.ndarray]) -> None:
|
||||||
"""Send action signal to low-level worker.
|
"""Send action signal to low-level worker.
|
||||||
|
|
||||||
@ -30,7 +32,17 @@ class EnvWorker(ABC):
|
|||||||
it indicates "step" signal. The paired return value from "recv"
|
it indicates "step" signal. The paired return value from "recv"
|
||||||
function is determined by such kind of different signal.
|
function is determined by such kind of different signal.
|
||||||
"""
|
"""
|
||||||
pass
|
if hasattr(self, "send_action"):
|
||||||
|
warnings.warn(
|
||||||
|
"send_action will soon be deprecated. "
|
||||||
|
"Please use send and recv for your own EnvWorker."
|
||||||
|
)
|
||||||
|
if action is None:
|
||||||
|
self.is_reset = True
|
||||||
|
self.result = self.reset()
|
||||||
|
else:
|
||||||
|
self.is_reset = False
|
||||||
|
self.send_action(action) # type: ignore
|
||||||
|
|
||||||
def recv(
|
def recv(
|
||||||
self
|
self
|
||||||
@ -41,6 +53,13 @@ class EnvWorker(ABC):
|
|||||||
single observation; otherwise it returns a tuple of (obs, rew, done,
|
single observation; otherwise it returns a tuple of (obs, rew, done,
|
||||||
info).
|
info).
|
||||||
"""
|
"""
|
||||||
|
if hasattr(self, "get_result"):
|
||||||
|
warnings.warn(
|
||||||
|
"get_result will soon be deprecated. "
|
||||||
|
"Please use send and recv for your own EnvWorker."
|
||||||
|
)
|
||||||
|
if not self.is_reset:
|
||||||
|
self.result = self.get_result() # type: ignore
|
||||||
return self.result
|
return self.result
|
||||||
|
|
||||||
def reset(self) -> np.ndarray:
|
def reset(self) -> np.ndarray:
|
||||||
|
@ -3,9 +3,13 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tianshou.data import Batch, ReplayBuffer
|
from tianshou.data import Batch, ReplayBuffer
|
||||||
from tianshou.env.pettingzoo_env import PettingZooEnv
|
|
||||||
from tianshou.policy import BasePolicy
|
from tianshou.policy import BasePolicy
|
||||||
|
|
||||||
|
try:
|
||||||
|
from tianshou.env.pettingzoo_env import PettingZooEnv
|
||||||
|
except ImportError:
|
||||||
|
PettingZooEnv = None # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class MultiAgentPolicyManager(BasePolicy):
|
class MultiAgentPolicyManager(BasePolicy):
|
||||||
"""Multi-agent policy manager for MARL.
|
"""Multi-agent policy manager for MARL.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user