Fix NNI tests upon v2.9 upgrade (#750)
* Fix NNI tests upon v2.9 upgrade * Un-ignore * fix
This commit is contained in:
		
							parent
							
								
									ea36dc5195
								
							
						
					
					
						commit
						65c4e3d4cd
					
				
							
								
								
									
										2
									
								
								.github/workflows/gputest.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/gputest.yml
									
									
									
									
										vendored
									
									
								
							@ -28,4 +28,4 @@ jobs:
 | 
			
		||||
      - name: Test with pytest
 | 
			
		||||
        # ignore test/throughput which only profiles the code
 | 
			
		||||
        run: |
 | 
			
		||||
          pytest test --ignore-glob='*profile.py' --ignore="test/3rd_party" --cov=tianshou --cov-report=xml --durations=0 -v --color=yes
 | 
			
		||||
          pytest test --ignore-glob='*profile.py' --cov=tianshou --cov-report=xml --durations=0 -v --color=yes
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										2
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								setup.py
									
									
									
									
									
								
							@ -50,7 +50,7 @@ def get_extras_require() -> str:
 | 
			
		||||
            "pettingzoo>=1.17",
 | 
			
		||||
            "pygame>=2.1.0",  # pettingzoo test cases pistonball
 | 
			
		||||
            "pymunk>=6.2.1",  # pettingzoo test cases pistonball
 | 
			
		||||
            "nni>=2.3",
 | 
			
		||||
            "nni>=2.3,<3.0",  # expect breaking changes at next major version
 | 
			
		||||
            "pytorch_lightning",
 | 
			
		||||
        ],
 | 
			
		||||
        "atari": ["atari_py", "opencv-python"],
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										21
									
								
								test/3rd_party/test_nni.py
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										21
									
								
								test/3rd_party/test_nni.py
									
									
									
									
										vendored
									
									
								
							@ -5,22 +5,23 @@ import threading
 | 
			
		||||
import time
 | 
			
		||||
from typing import List, Union
 | 
			
		||||
 | 
			
		||||
import nni.retiarii.execution.api
 | 
			
		||||
import nni.retiarii.nn.pytorch as nn
 | 
			
		||||
import nni.retiarii.strategy as strategy
 | 
			
		||||
import nni.nas.execution.api
 | 
			
		||||
import nni.nas.nn.pytorch as nn
 | 
			
		||||
import nni.nas.strategy as strategy
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
from nni.retiarii import Model
 | 
			
		||||
from nni.retiarii.converter import convert_to_graph
 | 
			
		||||
from nni.retiarii.execution import wait_models
 | 
			
		||||
from nni.retiarii.execution.interface import (
 | 
			
		||||
from nni.nas.execution import wait_models
 | 
			
		||||
from nni.nas.execution.common import (
 | 
			
		||||
    AbstractExecutionEngine,
 | 
			
		||||
    AbstractGraphListener,
 | 
			
		||||
    DebugEvaluator,
 | 
			
		||||
    MetricData,
 | 
			
		||||
    Model,
 | 
			
		||||
    ModelStatus,
 | 
			
		||||
    WorkerInfo,
 | 
			
		||||
)
 | 
			
		||||
from nni.retiarii.graph import DebugEvaluator, ModelStatus
 | 
			
		||||
from nni.retiarii.nn.pytorch.mutator import process_inline_mutation
 | 
			
		||||
from nni.nas.execution.pytorch.converter import convert_to_graph
 | 
			
		||||
from nni.nas.nn.pytorch.mutator import process_inline_mutation
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MockExecutionEngine(AbstractExecutionEngine):
 | 
			
		||||
@ -62,7 +63,7 @@ class MockExecutionEngine(AbstractExecutionEngine):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _reset_execution_engine(engine=None):
 | 
			
		||||
    nni.retiarii.execution.api._execution_engine = engine
 | 
			
		||||
    nni.nas.execution.api._execution_engine = engine
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Net(nn.Module):
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user