Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
Fix search space compatibility with JSON (#4455)
Browse files Browse the repository at this point in the history
  • Loading branch information
liuzhe-lz authored Jan 10, 2022
1 parent 452e69f commit 31f11f5
Show file tree
Hide file tree
Showing 7 changed files with 122 additions and 4 deletions.
26 changes: 22 additions & 4 deletions nni/experiment/config/experiment_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
__all__ = ['ExperimentConfig']

from dataclasses import dataclass
import json
import logging
from pathlib import Path
from typing import Any, List, Optional, Union

import yaml
Expand Down Expand Up @@ -113,6 +115,16 @@ def _canonicalize(self, _parents):

super()._canonicalize([self])

if self.search_space_file is not None:
yaml_error = None
try:
self.search_space = _load_search_space_file(self.search_space_file)
except Exception as e:
yaml_error = repr(e)
if yaml_error is not None: # raise it outside except block to make stack trace clear
msg = f'ExperimentConfig: Failed to load search space file "{self.search_space_file}": {yaml_error}'
raise ValueError(msg)

if self.nni_manager_ip is None:
# show a warning if user does not set nni_manager_ip. we have many issues caused by this
# the simple detection logic won't work for hybrid, but advanced users should not need it
Expand All @@ -133,10 +145,6 @@ def _validate_canonical(self):
if not self.use_annotation and space_cnt < 1:
raise ValueError('ExperimentConfig: search_space and search_space_file must be set one')

if self.search_space_file is not None:
with open(self.search_space_file) as ss_file:
self.search_space = yaml.safe_load(ss_file)

# to make the error message clear, ideally it should be:
# `if concurrency < 0: raise ValueError('trial_concurrency ({concurrency}) must greater than 0')`
# but I believe there will be hardy few users make this kind of mistakes, so let's keep it simple
Expand All @@ -156,3 +164,13 @@ def _validate_canonical(self):
tuner_cnt = (self.tuner is not None) + (self.advisor is not None)
if tuner_cnt != 1:
raise ValueError('ExperimentConfig: tuner and advisor must be set one')

def _load_search_space_file(search_space_path):
# FIXME
# we need this because PyYAML 6.0 does not support YAML 1.2,
# which means it is not fully compatible with JSON
content = Path(search_space_path).read_text(encoding='utf8')
try:
return json.loads(content)
except Exception:
return yaml.safe_load(content)
9 changes: 9 additions & 0 deletions test/ut/experiment/assets/ss.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
pool_type:
_type: choice
_value:
- max
- min
- avg
学习率: # test unicode
_type: loguniform
_value: [ 0.0000001, 0.1 ]
10 changes: 10 additions & 0 deletions test/ut/experiment/assets/ss_comma.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"pool_type": {
"_type": "choice",
"_value": [ "max", "min", "avg" ],
},
"学习率": {
"_type": "loguniform",
"_value": [ 0.0000001, 0.1 ],
},
}
10 changes: 10 additions & 0 deletions test/ut/experiment/assets/ss_tab.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"pool_type": {
"_type": "choice",
"_value": [ "max", "min", "avg" ]
},
"学习率": {
"_type": "loguniform",
"_value": [ 1e-7, 0.1 ]
}
}
10 changes: 10 additions & 0 deletions test/ut/experiment/assets/ss_tab_comma.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"pool_type": {
"_type": "choice",
"_value": [ "max", "min", "avg" ],
},
"学习率": {
"_type": "loguniform",
"_value": [ 1e-7, 0.1 ],
},
}
9 changes: 9 additions & 0 deletions test/ut/experiment/assets/ss_yaml12.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
pool_type:
_type: choice
_value:
- max
- min
- avg
学习率: # test unicode
_type: loguniform
_value: [ 1e-7, 0.1 ] # test scientific notation
52 changes: 52 additions & 0 deletions test/ut/experiment/test_search_space.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import json
from pathlib import Path

import yaml

from nni.experiment.config import ExperimentConfig, AlgorithmConfig, LocalConfig

## template ##

config = ExperimentConfig(
search_space_file = '',
trial_command = 'echo hello',
trial_concurrency = 1,
tuner = AlgorithmConfig(name='randomm'),
training_service = LocalConfig()
)

space_correct = {
'pool_type': {
'_type': 'choice',
'_value': ['max', 'min', 'avg']
},
'学习率': {
'_type': 'loguniform',
'_value': [1e-7, 0.1]
}
}

# FIXME
# PyYAML 6.0 (YAML 1.1) does not support tab and scientific notation
# JSON does not support comment and extra comma
# So some combinations will fail to load
formats = [
('ss_tab.json', 'JSON (tabs + scientific notation)'),
('ss_comma.json', 'JSON with extra comma'),
#('ss_tab_comma.json', 'JSON (tabs + scientific notation) with extra comma'),
('ss.yaml', 'YAML'),
#('ss_yaml12.yaml', 'YAML 1.2 with scientific notation'),
]

def test_search_space():
for space_file, description in formats:
try:
config.search_space_file = Path(__file__).parent / 'assets' / space_file
space = config.json()['searchSpace']
assert space == space_correct
except Exception as e:
print('Failed to load search space format: ' + description)
raise e

if __name__ == '__main__':
test_search_space()

0 comments on commit 31f11f5

Please sign in to comment.