Changes: - Disclaimer in README - Replaced all occurrences of Gym with Gymnasium - Removed code that is now dead since we no longer need to support the old step API - Updated type hints to only allow new step API - Increased required version of envpool to support Gymnasium - Increased required version of PettingZoo to support Gymnasium - Updated `PettingZooEnv` to only use the new step API, removed hack to also support old API - I had to add some `# type: ignore` comments, due to new type hinting in Gymnasium. I'm not that familiar with type hinting, but I believe that the issue is on the Gymnasium side and we are looking into it. - Had to update `MyTestEnv` to support `options` kwarg - Skip NNI tests because they still use OpenAI Gym - Also allow `PettingZooEnv` in vector environment - Updated doc page about ReplayBuffer to also talk about terminated and truncated flags. Still need to do: - Update the Jupyter notebooks in docs - Check the entire code base for more dead code (from compatibility stuff) - Check the reset functions of all environments/wrappers in code base to make sure they use the `options` kwarg - Someone might want to check test_env_finite.py - Is it okay to allow `PettingZooEnv` in vector environments? Might need to update docs?
		
			
				
	
	
		
			132 lines
		
	
	
		
			3.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			132 lines
		
	
	
		
			3.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # https://github.com/microsoft/nni/blob/master/test/ut/retiarii/test_strategy.py
 | |
| 
 | |
| import random
 | |
| import threading
 | |
| import time
 | |
| from typing import List, Union
 | |
| 
 | |
| import nni.nas.execution.api
 | |
| import nni.nas.nn.pytorch as nn
 | |
| import nni.nas.strategy as strategy
 | |
| import pytest
 | |
| import torch
 | |
| import torch.nn.functional as F
 | |
| from nni.nas.execution import wait_models
 | |
| from nni.nas.execution.common import (
 | |
|     AbstractExecutionEngine,
 | |
|     AbstractGraphListener,
 | |
|     DebugEvaluator,
 | |
|     MetricData,
 | |
|     Model,
 | |
|     ModelStatus,
 | |
|     WorkerInfo,
 | |
| )
 | |
| from nni.nas.execution.pytorch.converter import convert_to_graph
 | |
| from nni.nas.nn.pytorch.mutator import process_inline_mutation
 | |
| 
 | |
| 
 | |
class MockExecutionEngine(AbstractExecutionEngine):
    """In-process stand-in for a real execution engine, used by the strategy test.

    Each submitted model is "evaluated" on its own background thread: the
    thread sleeps for a random interval and then marks the model either as
    failed or as trained with a random metric. ``_resource_left`` counts the
    free worker slots (4 at construction); it drops by one per submitted
    model and rises by one when that model's thread finishes.
    """

    def __init__(self, failure_prob=0.):
        """Create the engine.

        Args:
            failure_prob: Threshold compared against a uniform draw from
                [0, 1]; a draw below it marks the model ``ModelStatus.Failed``.
        """
        self.models = []
        self.failure_prob = failure_prob
        self._resource_left = 4
        # ``+=`` / ``-=`` on an attribute is a read-modify-write sequence, and
        # the counter is touched by the submitting thread and by every worker
        # thread, so updates can be lost without mutual exclusion.
        self._resource_lock = threading.Lock()

    def _model_complete(self, model: Model):
        """Worker-thread body: simulate evaluation of ``model`` and free its slot."""
        time.sleep(random.uniform(0, 1))
        if random.uniform(0, 1) < self.failure_prob:
            model.status = ModelStatus.Failed
        else:
            model.metric = random.uniform(0, 1)
            model.status = ModelStatus.Trained
        with self._resource_lock:
            self._resource_left += 1

    def submit_models(self, *models: Model) -> None:
        """Record each model, take one resource slot for it and start its worker thread."""
        for model in models:
            self.models.append(model)
            with self._resource_lock:
                self._resource_left -= 1
            threading.Thread(target=self._model_complete, args=(model,)).start()

    def list_models(self) -> List[Model]:
        """Return the list of every model submitted so far (the live list, not a copy)."""
        return self.models

    def query_available_resource(self) -> Union[List[WorkerInfo], int]:
        """Return the current number of free resource slots."""
        with self._resource_lock:
            return self._resource_left

    def budget_exhausted(self) -> bool:
        """Report whether the budget is used up; the mock never runs out.

        Returns ``False`` explicitly so the result matches the ``bool``
        annotation (the previous implicit ``None`` was merely falsy).
        """
        return False

    def register_graph_listener(self, listener: AbstractGraphListener) -> None:
        """Accept and ignore ``listener``; the mock emits no graph events."""

    @classmethod
    def trial_execute_graph(cls) -> MetricData:
        """No-op placeholder; the mock never executes a graph inside a trial.

        Declared as a classmethod to match the ``cls`` parameter.
        NOTE(review): presumably mirrors a classmethod on
        ``AbstractExecutionEngine`` -- confirm against the NNI base class.
        """
 | |
| 
 | |
| 
 | |
def _reset_execution_engine(engine=None):
    """Set NNI's module-level ``_execution_engine`` to ``engine``.

    Called with no argument it stores ``None``, clearing whatever engine a
    previous test scenario installed.
    """
    api_module = nni.nas.execution.api
    api_module._execution_engine = engine
 | |
| 
 | |
| 
 | |
class Net(nn.Module):
    """Two conv layers followed by two ``LayerChoice`` fully connected layers.

    Args:
        hidden_size: Output width of the ``fc1`` candidates and input width of
            the ``fc2`` candidates.
        diff_size: When true, ``fc2`` gets a third candidate so that the two
            choices have different numbers of options (3 vs. 2).
    """

    def __init__(self, hidden_size=32, diff_size=False):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)

        # First searchable layer: same shape, with and without a bias term.
        fc1_candidates = [
            nn.Linear(4 * 4 * 50, hidden_size, bias=True),
            nn.Linear(4 * 4 * 50, hidden_size, bias=False),
        ]
        self.fc1 = nn.LayerChoice(fc1_candidates, label='fc1')

        # Second searchable layer; the optional extra candidate is appended
        # last so the candidates are created in the same order as before.
        fc2_candidates = [
            nn.Linear(hidden_size, 10, bias=False),
            nn.Linear(hidden_size, 10, bias=True),
        ]
        if diff_size:
            fc2_candidates.append(nn.Linear(hidden_size, 10, bias=False))
        self.fc2 = nn.LayerChoice(fc2_candidates, label='fc2')

    def forward(self, x):
        """Return per-class log-probabilities for the batch ``x``.

        The flatten step expects 50 channels of 4x4 features after the second
        pooling -- presumably 1x28x28 inputs; TODO confirm against callers.
        """
        # Statements are kept exactly as they were: this method is compiled
        # by ``torch.jit.script`` and then converted to a graph elsewhere in
        # this file.
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
 | |
| 
 | |
| 
 | |
def _get_model_and_mutators(**kwargs):
    """Build a ``Net``, convert it to a graph and collect its mutators.

    Args:
        **kwargs: Forwarded unchanged to ``Net`` (``hidden_size``, ``diff_size``).

    Returns:
        A ``(graph, mutators)`` tuple: the result of ``convert_to_graph`` with
        a ``DebugEvaluator`` attached, and whatever ``process_inline_mutation``
        returns for that graph.
    """
    net = Net(**kwargs)
    graph_ir = convert_to_graph(torch.jit.script(net), net)
    graph_ir.evaluator = DebugEvaluator()
    return graph_ir, process_inline_mutation(graph_ir)
 | |
| 
 | |
| 
 | |
@pytest.mark.skip(
    reason="NNI currently uses OpenAI Gym"
)  # TODO: Remove once NNI transitions to Gymnasium
def test_rl():
    """Run ``PolicyBasedRL`` against the mock engine for two search spaces.

    The first scenario builds the model with ``diff_size=True``; the second
    uses the default model. Each scenario installs a fresh mock engine, runs
    the strategy, waits for all submitted models, and then clears the global
    engine.
    """
    for model_kwargs in ({"diff_size": True}, {}):
        rl = strategy.PolicyBasedRL(max_collect=2, trial_per_collect=10)
        engine = MockExecutionEngine(failure_prob=0.2)
        _reset_execution_engine(engine)
        try:
            rl.run(*_get_model_and_mutators(**model_kwargs))
            wait_models(*engine.models)
        finally:
            # Always clear the global engine, even when the run raises, so a
            # failure here cannot leak the mock engine into later tests.
            _reset_execution_engine()
 | |
| 
 | |
| 
 | |
# Script entry point. Calling test_rl() directly bypasses the pytest skip
# mark, which is only honoured when the test is collected by pytest.
if __name__ == '__main__':
    test_rl()
 |