From 65c4e3d4cdc1920db2cb1b8000d4038318c4b5c0 Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Tue, 27 Sep 2022 04:55:26 +0800 Subject: [PATCH] Fix NNI tests upon v2.9 upgrade (#750) * Fix NNI tests upon v2.9 upgrade * Un-ignore * fix --- .github/workflows/gputest.yml | 2 +- setup.py | 2 +- test/3rd_party/test_nni.py | 21 +++++++++++---------- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/.github/workflows/gputest.yml b/.github/workflows/gputest.yml index 2ddecd4..5461977 100644 --- a/.github/workflows/gputest.yml +++ b/.github/workflows/gputest.yml @@ -28,4 +28,4 @@ jobs: - name: Test with pytest # ignore test/throughput which only profiles the code run: | - pytest test --ignore-glob='*profile.py' --ignore="test/3rd_party" --cov=tianshou --cov-report=xml --durations=0 -v --color=yes + pytest test --ignore-glob='*profile.py' --cov=tianshou --cov-report=xml --durations=0 -v --color=yes diff --git a/setup.py b/setup.py index e1dc5e1..cbed99d 100644 --- a/setup.py +++ b/setup.py @@ -50,7 +50,7 @@ def get_extras_require() -> str: "pettingzoo>=1.17", "pygame>=2.1.0", # pettingzoo test cases pistonball "pymunk>=6.2.1", # pettingzoo test cases pistonball - "nni>=2.3", + "nni>=2.3,<3.0", # expect breaking changes at next major version "pytorch_lightning", ], "atari": ["atari_py", "opencv-python"], diff --git a/test/3rd_party/test_nni.py b/test/3rd_party/test_nni.py index 23d714b..ddeeff6 100644 --- a/test/3rd_party/test_nni.py +++ b/test/3rd_party/test_nni.py @@ -5,22 +5,23 @@ import threading import time from typing import List, Union -import nni.retiarii.execution.api -import nni.retiarii.nn.pytorch as nn -import nni.retiarii.strategy as strategy +import nni.nas.execution.api +import nni.nas.nn.pytorch as nn +import nni.nas.strategy as strategy import torch import torch.nn.functional as F -from nni.retiarii import Model -from nni.retiarii.converter import convert_to_graph -from nni.retiarii.execution import wait_models -from nni.retiarii.execution.interface import ( +from nni.nas.execution import wait_models +from nni.nas.execution.common import ( AbstractExecutionEngine, AbstractGraphListener, + DebugEvaluator, MetricData, + Model, + ModelStatus, WorkerInfo, ) -from nni.retiarii.graph import DebugEvaluator, ModelStatus -from nni.retiarii.nn.pytorch.mutator import process_inline_mutation +from nni.nas.execution.pytorch.converter import convert_to_graph +from nni.nas.nn.pytorch.mutator import process_inline_mutation class MockExecutionEngine(AbstractExecutionEngine): @@ -62,7 +63,7 @@ class MockExecutionEngine(AbstractExecutionEngine): def _reset_execution_engine(engine=None): - nni.retiarii.execution.api._execution_engine = engine + nni.nas.execution.api._execution_engine = engine class Net(nn.Module):