fix conda support and keep API compatibility (#536)

* loosen constraints

* fix nni issue (#478)

* fix coverage
Author: Jiayi Weng, 2022-02-25 11:05:02 -05:00 (committed by GitHub)
parent 97df511a13
commit c248b4f87e
8 changed files with 164 additions and 14 deletions

GitHub Actions workflow (pytest job)

@@ -27,7 +27,7 @@ jobs:
       - name: Test with pytest
         # ignore test/throughput which only profiles the code
         run: |
-          pytest test --ignore-glob='*profile.py' --cov=tianshou --cov-report=xml --cov-report=term-missing --durations=0 -v
+          pytest test --ignore-glob='*profile.py' --ignore="test/3rd_party" --cov=tianshou --cov-report=xml --cov-report=term-missing --durations=0 -v
       - name: Upload coverage to Codecov
         uses: codecov/codecov-action@v1
         with:

setup.py

@@ -15,14 +15,13 @@ def get_version() -> str:


 def get_install_requires() -> str:
     return [
-        "gym>=0.21",
+        "gym>=0.15.4",
         "tqdm",
         "numpy>1.16.0",  # https://github.com/numpy/numpy/issues/12793
         "tensorboard>=2.5.0",
         "torch>=1.4.0",
         "numba>=0.51.0",
         "h5py>=2.10.0",  # to match tensorflow's minimal requirements
-        "pettingzoo>=1.15",
     ]

@@ -46,8 +45,10 @@ def get_extras_require() -> str:
             "doc8",
             "scipy",
             "pillow",
+            "pettingzoo>=1.12",
             "pygame>=2.1.0",  # pettingzoo test cases pistonball
             "pymunk>=6.2.1",  # pettingzoo test cases pistonball
+            "nni>=2.3",
         ],
         "atari": ["atari_py", "opencv-python"],
         "mujoco": ["mujoco_py"],

test/3rd_party/test_nni.py (new file, vendored, 126 lines)

@@ -0,0 +1,126 @@
# https://github.com/microsoft/nni/blob/master/test/ut/retiarii/test_strategy.py
import random
import threading
import time
from typing import List, Union

import nni.retiarii.execution.api
import nni.retiarii.nn.pytorch as nn
import nni.retiarii.strategy as strategy
import torch
import torch.nn.functional as F
from nni.retiarii import Model
from nni.retiarii.converter import convert_to_graph
from nni.retiarii.execution import wait_models
from nni.retiarii.execution.interface import (
    AbstractExecutionEngine,
    AbstractGraphListener,
    MetricData,
    WorkerInfo,
)
from nni.retiarii.graph import DebugEvaluator, ModelStatus
from nni.retiarii.nn.pytorch.mutator import process_inline_mutation


class MockExecutionEngine(AbstractExecutionEngine):

    def __init__(self, failure_prob=0.):
        self.models = []
        self.failure_prob = failure_prob
        self._resource_left = 4

    def _model_complete(self, model: Model):
        time.sleep(random.uniform(0, 1))
        if random.uniform(0, 1) < self.failure_prob:
            model.status = ModelStatus.Failed
        else:
            model.metric = random.uniform(0, 1)
            model.status = ModelStatus.Trained
        self._resource_left += 1

    def submit_models(self, *models: Model) -> None:
        for model in models:
            self.models.append(model)
            self._resource_left -= 1
            threading.Thread(target=self._model_complete, args=(model, )).start()

    def list_models(self) -> List[Model]:
        return self.models

    def query_available_resource(self) -> Union[List[WorkerInfo], int]:
        return self._resource_left

    def budget_exhausted(self) -> bool:
        pass

    def register_graph_listener(self, listener: AbstractGraphListener) -> None:
        pass

    @classmethod
    def trial_execute_graph(cls) -> MetricData:
        pass


def _reset_execution_engine(engine=None):
    nni.retiarii.execution.api._execution_engine = engine


class Net(nn.Module):

    def __init__(self, hidden_size=32, diff_size=False):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.LayerChoice(
            [
                nn.Linear(4 * 4 * 50, hidden_size, bias=True),
                nn.Linear(4 * 4 * 50, hidden_size, bias=False)
            ],
            label='fc1'
        )
        self.fc2 = nn.LayerChoice(
            [
                nn.Linear(hidden_size, 10, bias=False),
                nn.Linear(hidden_size, 10, bias=True)
            ] + ([] if not diff_size else [nn.Linear(hidden_size, 10, bias=False)]),
            label='fc2'
        )

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def _get_model_and_mutators(**kwargs):
    base_model = Net(**kwargs)
    script_module = torch.jit.script(base_model)
    base_model_ir = convert_to_graph(script_module, base_model)
    base_model_ir.evaluator = DebugEvaluator()
    mutators = process_inline_mutation(base_model_ir)
    return base_model_ir, mutators


def test_rl():
    rl = strategy.PolicyBasedRL(max_collect=2, trial_per_collect=10)
    engine = MockExecutionEngine(failure_prob=0.2)
    _reset_execution_engine(engine)
    rl.run(*_get_model_and_mutators(diff_size=True))
    wait_models(*engine.models)
    _reset_execution_engine()

    rl = strategy.PolicyBasedRL(max_collect=2, trial_per_collect=10)
    engine = MockExecutionEngine(failure_prob=0.2)
    _reset_execution_engine(engine)
    rl.run(*_get_model_and_mutators())
    wait_models(*engine.models)
    _reset_execution_engine()


if __name__ == '__main__':
    test_rl()
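Since nni is only a dev extra and the workflow change above excludes test/3rd_party from the default coverage run, this vendored test will not execute in environments without nni. A hypothetical guard of the usual pytest form (it is not present in the vendored file) would skip the module cleanly when the dependency is missing:

# Hypothetical guard, not part of test_nni.py: skip the module when nni is absent.
import pytest

pytest.importorskip("nni")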

tianshou/__init__.py

@@ -1,6 +1,6 @@
 from tianshou import data, env, exploration, policy, trainer, utils

-__version__ = "0.4.6"
+__version__ = "0.4.6.post1"

 __all__ = [
     "env",

tianshou/env/__init__.py

@@ -1,6 +1,5 @@
 """Env package."""

-from tianshou.env.pettingzoo_env import PettingZooEnv
 from tianshou.env.venvs import (
     BaseVectorEnv,
     DummyVectorEnv,
@@ -9,6 +8,11 @@ from tianshou.env.venvs import (
     SubprocVectorEnv,
 )

+try:
+    from tianshou.env.pettingzoo_env import PettingZooEnv
+except ImportError:
+    pass
+
 __all__ = [
     "BaseVectorEnv",
     "DummyVectorEnv",

tianshou/env/venvs.py

@@ -2,7 +2,6 @@ from typing import Any, Callable, List, Optional, Tuple, Union

 import gym
 import numpy as np
-import pettingzoo

 from tianshou.env.worker import (
     DummyEnvWorker,
@@ -365,10 +364,7 @@ class DummyVectorEnv(BaseVectorEnv):
     Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
     """

-    def __init__(
-        self, env_fns: List[Callable[[], Union[gym.Env, pettingzoo.AECEnv]]],
-        **kwargs: Any
-    ) -> None:
+    def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None:
         super().__init__(env_fns, DummyEnvWorker, **kwargs)

tianshou/env/worker/base.py

@@ -1,3 +1,4 @@
+import warnings
 from abc import ABC, abstractmethod
 from typing import Any, Callable, List, Optional, Tuple, Union

@@ -11,8 +12,10 @@ class EnvWorker(ABC):
     def __init__(self, env_fn: Callable[[], gym.Env]) -> None:
         self._env_fn = env_fn
         self.is_closed = False
-        self.result: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]
+        self.result: Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
+                           np.ndarray]
         self.action_space = self.get_env_attr("action_space")  # noqa: B009
+        self.is_reset = False

     @abstractmethod
     def get_env_attr(self, key: str) -> Any:
@@ -22,7 +25,6 @@
     def set_env_attr(self, key: str, value: Any) -> None:
         pass

-    @abstractmethod
     def send(self, action: Optional[np.ndarray]) -> None:
         """Send action signal to low-level worker.

@@ -30,7 +32,17 @@
         it indicates "step" signal. The paired return value from "recv"
         function is determined by such kind of different signal.
         """
-        pass
+        if hasattr(self, "send_action"):
+            warnings.warn(
+                "send_action will soon be deprecated. "
+                "Please use send and recv for your own EnvWorker."
+            )
+            if action is None:
+                self.is_reset = True
+                self.result = self.reset()
+            else:
+                self.is_reset = False
+                self.send_action(action)  # type: ignore

     def recv(
         self
@@ -41,6 +53,13 @@
         single observation; otherwise it returns a tuple of (obs, rew, done,
         info).
         """
+        if hasattr(self, "get_result"):
+            warnings.warn(
+                "get_result will soon be deprecated. "
+                "Please use send and recv for your own EnvWorker."
+            )
+            if not self.is_reset:
+                self.result = self.get_result()  # type: ignore
         return self.result

     def reset(self) -> np.ndarray:
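The send/recv fallbacks above are what keep old custom workers working. The standalone sketch below (illustrative class names, not tianshou code) reproduces the same hasattr-based shim pattern: a subclass that only implements the deprecated send_action/get_result pair still runs through the new send/recv entry points.

# Standalone sketch of the compatibility shim pattern; names are illustrative.
import warnings
from typing import Any, Optional


class Worker:
    """New-style interface with a fallback to the deprecated method pair."""

    def send(self, action: Optional[Any]) -> None:
        if hasattr(self, "send_action"):
            warnings.warn("send_action is deprecated; implement send/recv instead.")
            if action is None:
                self.is_reset = True
                self.result = self.reset()
            else:
                self.is_reset = False
                self.send_action(action)

    def recv(self) -> Any:
        if hasattr(self, "get_result"):
            warnings.warn("get_result is deprecated; implement send/recv instead.")
            if not self.is_reset:
                self.result = self.get_result()
        return self.result

    def reset(self) -> Any:
        raise NotImplementedError


class OldStyleWorker(Worker):
    """Written against the deprecated send_action/get_result API only."""

    def reset(self):
        return 0  # stand-in for an initial observation

    def send_action(self, action):
        self._pending = action + 1  # stand-in for env.step

    def get_result(self):
        return self._pending


w = OldStyleWorker()
w.send(None)     # "reset" signal, routed through the shim
print(w.recv())  # prints 0
w.send(41)       # "step" signal, dispatched to the deprecated send_action
print(w.recv())  # prints 42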

tianshou/policy/multiagent/mapolicy.py

@@ -3,9 +3,13 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import numpy as np

 from tianshou.data import Batch, ReplayBuffer
-from tianshou.env.pettingzoo_env import PettingZooEnv
 from tianshou.policy import BasePolicy

+try:
+    from tianshou.env.pettingzoo_env import PettingZooEnv
+except ImportError:
+    PettingZooEnv = None  # type: ignore
+

 class MultiAgentPolicyManager(BasePolicy):
     """Multi-agent policy manager for MARL.