Tianshou/test/3rd_party/test_nni.py
Michael Panchenko 600f4bbd55
Python 3.9, black + ruff formatting (#921)
Preparation for #914 and #920

Changes formatting to ruff and black. Removes Python 3.8 support.

## Additional Changes

- Removed flake8 dependencies
- Adjusted pre-commit. Now CI and Make use pre-commit, reducing the
duplication of linting calls
- Removed check-docstyle option (ruff is doing that)
- Merged format and lint. In CI, the format-lint step fails if it has to
make any changes, so it also serves as the lint check.

---------

Co-authored-by: Jiayi Weng <jiayi@openai.com>
2023-08-25 14:40:56 -07:00

131 lines
3.8 KiB
Python

# https://github.com/microsoft/nni/blob/master/test/ut/retiarii/test_strategy.py
import random
import threading
import time
from typing import Union
import nni.nas.execution.api
import nni.nas.nn.pytorch as nn
import pytest
import torch
import torch.nn.functional as F
from nni.nas import strategy
from nni.nas.execution import wait_models
from nni.nas.execution.common import (
AbstractExecutionEngine,
AbstractGraphListener,
DebugEvaluator,
MetricData,
Model,
ModelStatus,
WorkerInfo,
)
from nni.nas.execution.pytorch.converter import convert_to_graph
from nni.nas.nn.pytorch.mutator import process_inline_mutation
class MockExecutionEngine(AbstractExecutionEngine):
    """Fake execution engine that "trains" submitted models on background threads.

    Each submitted model occupies one of four resource slots. A worker thread
    sleeps for a random 0-1 s, then finishes the model in one of two ways:
    ``ModelStatus.Failed`` (with probability ``failure_prob``), or
    ``ModelStatus.Trained`` with a random metric in ``[0, 1]``. Finally it
    releases the slot.

    :param failure_prob: probability that a submitted model fails instead of
        being trained.
    """

    def __init__(self, failure_prob=0.0):
        self.models = []
        self.failure_prob = failure_prob
        self._resource_left = 4
        # The slot counter is decremented on the submitting thread and
        # incremented on worker threads. ``+=``/``-=`` on an attribute is a
        # read-modify-write, so it is guarded to avoid lost updates.
        self._resource_lock = threading.Lock()

    def _model_complete(self, model: Model):
        """Worker-thread body: finish ``model`` after a random delay and free its slot."""
        time.sleep(random.uniform(0, 1))
        if random.uniform(0, 1) < self.failure_prob:
            model.status = ModelStatus.Failed
        else:
            model.metric = random.uniform(0, 1)
            model.status = ModelStatus.Trained
        with self._resource_lock:
            self._resource_left += 1

    def submit_models(self, *models: Model) -> None:
        """Record each model, reserve a slot for it and start its worker thread."""
        for model in models:
            self.models.append(model)
            with self._resource_lock:
                self._resource_left -= 1
            threading.Thread(target=self._model_complete, args=(model,)).start()

    def list_models(self) -> list[Model]:
        """Return every model submitted so far (the live internal list, not a copy)."""
        return self.models

    def query_available_resource(self) -> Union[list[WorkerInfo], int]:
        """Return the number of free slots as a plain ``int``."""
        with self._resource_lock:
            return self._resource_left

    def budget_exhausted(self) -> bool:
        """Report that the budget is never exhausted; the mock imposes none."""
        return False

    def register_graph_listener(self, listener: AbstractGraphListener) -> None:
        """Ignore the listener; the mock never emits graph events."""

    @classmethod
    def trial_execute_graph(cls) -> MetricData:
        """Unused by the mock; always returns ``None``.

        Declared as a classmethod to mirror NNI's abstract interface
        (NOTE(review): confirm against the installed NNI version).
        """
        return None
def _reset_execution_engine(engine=None):
    """Install ``engine`` as NNI's module-level execution engine.

    Calling it with no argument stores ``None``, which clears any engine that an
    earlier call installed.
    """
    api = nni.nas.execution.api
    api._execution_engine = engine
class Net(nn.Module):
    """Small CNN whose two fully-connected layers are NNI layer choices.

    ``fc1`` chooses between a biased and an unbiased ``Linear`` projection to
    ``hidden_size`` features. ``fc2`` chooses among ``Linear`` heads with 10
    outputs. Passing ``diff_size=True`` adds a third ``fc2`` candidate, so the
    two choices have different numbers of candidates.
    """

    def __init__(self, hidden_size=32, diff_size=False):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)

        # Flattened size after the two conv/pool stages in ``forward``.
        # Assumes 1x28x28 inputs (28 -> 24 -> 12 -> 8 -> 4); TODO confirm.
        flat_features = 4 * 4 * 50
        fc1_candidates = [
            nn.Linear(flat_features, hidden_size, bias=True),
            nn.Linear(flat_features, hidden_size, bias=False),
        ]
        self.fc1 = nn.LayerChoice(fc1_candidates, label="fc1")

        fc2_candidates = [
            nn.Linear(hidden_size, 10, bias=False),
            nn.Linear(hidden_size, 10, bias=True),
        ]
        if diff_size:
            # Extra candidate so fc2 has a different cardinality from fc1.
            fc2_candidates.append(nn.Linear(hidden_size, 10, bias=False))
        self.fc2 = nn.LayerChoice(fc2_candidates, label="fc2")

    def forward(self, x):
        # Two conv -> ReLU -> 2x2 max-pool stages.
        x = F.max_pool2d(F.relu(self.conv1(x)), 2, 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2, 2)
        x = x.view(-1, 4 * 4 * 50)
        # Searchable classifier head, log-probabilities over 10 classes.
        x = F.relu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)
def _get_model_and_mutators(**kwargs):
    """Build the NNI graph IR for ``Net(**kwargs)`` together with its mutators.

    The network is compiled with TorchScript and converted to graph IR, and a
    ``DebugEvaluator`` is attached. The mutators are whatever
    ``process_inline_mutation`` derives from the IR.

    :return: ``(model_ir, mutators)``, ready to be unpacked into a strategy's
        ``run``.
    """
    net = Net(**kwargs)
    model_ir = convert_to_graph(torch.jit.script(net), net)
    model_ir.evaluator = DebugEvaluator()
    return model_ir, process_inline_mutation(model_ir)
@pytest.mark.skip(
    reason="NNI currently uses OpenAI Gym",
)  # TODO: Remove once NNI transitions to Gymnasium
def test_rl():
    """Run NNI's ``PolicyBasedRL`` strategy against ``MockExecutionEngine``.

    The strategy is exercised twice:

    - on a search space whose layer choices have different numbers of
      candidates (``diff_size=True``);
    - on the default search space.

    Each run uses a fresh strategy and a fresh mock engine in which every
    submitted model fails with probability 0.2.
    """
    for model_kwargs in ({"diff_size": True}, {}):
        rl = strategy.PolicyBasedRL(max_collect=2, trial_per_collect=10)
        engine = MockExecutionEngine(failure_prob=0.2)
        _reset_execution_engine(engine)
        try:
            rl.run(*_get_model_and_mutators(**model_kwargs))
            # Wait on every submitted model before tearing the engine down.
            wait_models(*engine.models)
        finally:
            # Always uninstall the mock so a failure here cannot leak a global
            # engine into other tests.
            _reset_execution_engine()
if __name__ == "__main__":
    # pytest skip markers are only honored by the pytest runner, so executing
    # this module directly runs ``test_rl`` unconditionally.
    test_rl()