Fix NNI tests upon v2.9 upgrade (#750)
* Fix NNI tests upon v2.9 upgrade * Un-ignore * fix
This commit is contained in:
parent
ea36dc5195
commit
65c4e3d4cd
2
.github/workflows/gputest.yml
vendored
2
.github/workflows/gputest.yml
vendored
@ -28,4 +28,4 @@ jobs:
|
||||
- name: Test with pytest
|
||||
# ignore test/throughput which only profiles the code
|
||||
run: |
|
||||
pytest test --ignore-glob='*profile.py' --ignore="test/3rd_party" --cov=tianshou --cov-report=xml --durations=0 -v --color=yes
|
||||
pytest test --ignore-glob='*profile.py' --cov=tianshou --cov-report=xml --durations=0 -v --color=yes
|
||||
|
2
setup.py
2
setup.py
@ -50,7 +50,7 @@ def get_extras_require() -> str:
|
||||
"pettingzoo>=1.17",
|
||||
"pygame>=2.1.0", # pettingzoo test cases pistonball
|
||||
"pymunk>=6.2.1", # pettingzoo test cases pistonball
|
||||
"nni>=2.3",
|
||||
"nni>=2.3,<3.0", # expect breaking changes at next major version
|
||||
"pytorch_lightning",
|
||||
],
|
||||
"atari": ["atari_py", "opencv-python"],
|
||||
|
21
test/3rd_party/test_nni.py
vendored
21
test/3rd_party/test_nni.py
vendored
@ -5,22 +5,23 @@ import threading
|
||||
import time
|
||||
from typing import List, Union
|
||||
|
||||
import nni.retiarii.execution.api
|
||||
import nni.retiarii.nn.pytorch as nn
|
||||
import nni.retiarii.strategy as strategy
|
||||
import nni.nas.execution.api
|
||||
import nni.nas.nn.pytorch as nn
|
||||
import nni.nas.strategy as strategy
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from nni.retiarii import Model
|
||||
from nni.retiarii.converter import convert_to_graph
|
||||
from nni.retiarii.execution import wait_models
|
||||
from nni.retiarii.execution.interface import (
|
||||
from nni.nas.execution import wait_models
|
||||
from nni.nas.execution.common import (
|
||||
AbstractExecutionEngine,
|
||||
AbstractGraphListener,
|
||||
DebugEvaluator,
|
||||
MetricData,
|
||||
Model,
|
||||
ModelStatus,
|
||||
WorkerInfo,
|
||||
)
|
||||
from nni.retiarii.graph import DebugEvaluator, ModelStatus
|
||||
from nni.retiarii.nn.pytorch.mutator import process_inline_mutation
|
||||
from nni.nas.execution.pytorch.converter import convert_to_graph
|
||||
from nni.nas.nn.pytorch.mutator import process_inline_mutation
|
||||
|
||||
|
||||
class MockExecutionEngine(AbstractExecutionEngine):
|
||||
@ -62,7 +63,7 @@ class MockExecutionEngine(AbstractExecutionEngine):
|
||||
|
||||
|
||||
def _reset_execution_engine(engine=None):
|
||||
nni.retiarii.execution.api._execution_engine = engine
|
||||
nni.nas.execution.api._execution_engine = engine
|
||||
|
||||
|
||||
class Net(nn.Module):
|
||||
|
Loading…
x
Reference in New Issue
Block a user