Fix NNI tests upon v2.9 upgrade (#750)
* Fix NNI tests upon v2.9 upgrade * Un-ignore * fix
This commit is contained in:
parent
ea36dc5195
commit
65c4e3d4cd
2
.github/workflows/gputest.yml
vendored
2
.github/workflows/gputest.yml
vendored
@ -28,4 +28,4 @@ jobs:
|
|||||||
- name: Test with pytest
|
- name: Test with pytest
|
||||||
# ignore test/throughput which only profiles the code
|
# ignore test/throughput which only profiles the code
|
||||||
run: |
|
run: |
|
||||||
pytest test --ignore-glob='*profile.py' --ignore="test/3rd_party" --cov=tianshou --cov-report=xml --durations=0 -v --color=yes
|
pytest test --ignore-glob='*profile.py' --cov=tianshou --cov-report=xml --durations=0 -v --color=yes
|
||||||
|
2
setup.py
2
setup.py
@ -50,7 +50,7 @@ def get_extras_require() -> str:
|
|||||||
"pettingzoo>=1.17",
|
"pettingzoo>=1.17",
|
||||||
"pygame>=2.1.0", # pettingzoo test cases pistonball
|
"pygame>=2.1.0", # pettingzoo test cases pistonball
|
||||||
"pymunk>=6.2.1", # pettingzoo test cases pistonball
|
"pymunk>=6.2.1", # pettingzoo test cases pistonball
|
||||||
"nni>=2.3",
|
"nni>=2.3,<3.0", # expect breaking changes at next major version
|
||||||
"pytorch_lightning",
|
"pytorch_lightning",
|
||||||
],
|
],
|
||||||
"atari": ["atari_py", "opencv-python"],
|
"atari": ["atari_py", "opencv-python"],
|
||||||
|
21
test/3rd_party/test_nni.py
vendored
21
test/3rd_party/test_nni.py
vendored
@ -5,22 +5,23 @@ import threading
|
|||||||
import time
|
import time
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
|
||||||
import nni.retiarii.execution.api
|
import nni.nas.execution.api
|
||||||
import nni.retiarii.nn.pytorch as nn
|
import nni.nas.nn.pytorch as nn
|
||||||
import nni.retiarii.strategy as strategy
|
import nni.nas.strategy as strategy
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from nni.retiarii import Model
|
from nni.nas.execution import wait_models
|
||||||
from nni.retiarii.converter import convert_to_graph
|
from nni.nas.execution.common import (
|
||||||
from nni.retiarii.execution import wait_models
|
|
||||||
from nni.retiarii.execution.interface import (
|
|
||||||
AbstractExecutionEngine,
|
AbstractExecutionEngine,
|
||||||
AbstractGraphListener,
|
AbstractGraphListener,
|
||||||
|
DebugEvaluator,
|
||||||
MetricData,
|
MetricData,
|
||||||
|
Model,
|
||||||
|
ModelStatus,
|
||||||
WorkerInfo,
|
WorkerInfo,
|
||||||
)
|
)
|
||||||
from nni.retiarii.graph import DebugEvaluator, ModelStatus
|
from nni.nas.execution.pytorch.converter import convert_to_graph
|
||||||
from nni.retiarii.nn.pytorch.mutator import process_inline_mutation
|
from nni.nas.nn.pytorch.mutator import process_inline_mutation
|
||||||
|
|
||||||
|
|
||||||
class MockExecutionEngine(AbstractExecutionEngine):
|
class MockExecutionEngine(AbstractExecutionEngine):
|
||||||
@ -62,7 +63,7 @@ class MockExecutionEngine(AbstractExecutionEngine):
|
|||||||
|
|
||||||
|
|
||||||
def _reset_execution_engine(engine=None):
|
def _reset_execution_engine(engine=None):
|
||||||
nni.retiarii.execution.api._execution_engine = engine
|
nni.nas.execution.api._execution_engine = engine
|
||||||
|
|
||||||
|
|
||||||
class Net(nn.Module):
|
class Net(nn.Module):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user