Fix NNI tests upon v2.9 upgrade (#750)

* Fix NNI tests upon v2.9 upgrade

* Un-ignore

* fix
This commit is contained in:
Yuge Zhang 2022-09-27 04:55:26 +08:00 committed by GitHub
parent ea36dc5195
commit 65c4e3d4cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 13 additions and 12 deletions

View File

@ -28,4 +28,4 @@ jobs:
- name: Test with pytest
# ignore test/throughput which only profiles the code
run: |
pytest test --ignore-glob='*profile.py' --ignore="test/3rd_party" --cov=tianshou --cov-report=xml --durations=0 -v --color=yes
pytest test --ignore-glob='*profile.py' --cov=tianshou --cov-report=xml --durations=0 -v --color=yes

View File

@ -50,7 +50,7 @@ def get_extras_require() -> str:
"pettingzoo>=1.17",
"pygame>=2.1.0", # pettingzoo test cases pistonball
"pymunk>=6.2.1", # pettingzoo test cases pistonball
"nni>=2.3",
"nni>=2.3,<3.0", # expect breaking changes at next major version
"pytorch_lightning",
],
"atari": ["atari_py", "opencv-python"],

View File

@ -5,22 +5,23 @@ import threading
import time
from typing import List, Union
import nni.retiarii.execution.api
import nni.retiarii.nn.pytorch as nn
import nni.retiarii.strategy as strategy
import nni.nas.execution.api
import nni.nas.nn.pytorch as nn
import nni.nas.strategy as strategy
import torch
import torch.nn.functional as F
from nni.retiarii import Model
from nni.retiarii.converter import convert_to_graph
from nni.retiarii.execution import wait_models
from nni.retiarii.execution.interface import (
from nni.nas.execution import wait_models
from nni.nas.execution.common import (
AbstractExecutionEngine,
AbstractGraphListener,
DebugEvaluator,
MetricData,
Model,
ModelStatus,
WorkerInfo,
)
from nni.retiarii.graph import DebugEvaluator, ModelStatus
from nni.retiarii.nn.pytorch.mutator import process_inline_mutation
from nni.nas.execution.pytorch.converter import convert_to_graph
from nni.nas.nn.pytorch.mutator import process_inline_mutation
class MockExecutionEngine(AbstractExecutionEngine):
@ -62,7 +63,7 @@ class MockExecutionEngine(AbstractExecutionEngine):
def _reset_execution_engine(engine=None):
nni.retiarii.execution.api._execution_engine = engine
nni.nas.execution.api._execution_engine = engine
class Net(nn.Module):