Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Fix search space compatibility with JSON #4455

Merged
merged 3 commits into from
Jan 10, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions nni/experiment/config/experiment_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
__all__ = ['ExperimentConfig']

from dataclasses import dataclass
import json
import logging
from pathlib import Path
from typing import Any, List, Optional, Union

import yaml
Expand Down Expand Up @@ -113,6 +115,16 @@ def _canonicalize(self, _parents):

super()._canonicalize([self])

if self.search_space_file is not None:
yaml_error = None
try:
self.search_space = _load_search_space_file(self.search_space_file)
except Exception as e:
yaml_error = repr(e)
if yaml_error is not None: # raise it outside except block to make stack trace clear
msg = f'ExperimentConfig: Failed to load search space file "{self.search_space_file}": {yaml_error}'
raise ValueError(msg)

if self.nni_manager_ip is None:
# show a warning if user does not set nni_manager_ip. we have many issues caused by this
# the simple detection logic won't work for hybrid, but advanced users should not need it
Expand All @@ -133,10 +145,6 @@ def _validate_canonical(self):
if not self.use_annotation and space_cnt < 1:
raise ValueError('ExperimentConfig: search_space and search_space_file must be set one')

if self.search_space_file is not None:
with open(self.search_space_file) as ss_file:
self.search_space = yaml.safe_load(ss_file)

# to make the error message clear, ideally it should be:
# `if concurrency < 0: raise ValueError('trial_concurrency ({concurrency}) must be greater than 0')`
# but I believe hardly any users will make this kind of mistake, so let's keep it simple
Expand All @@ -156,3 +164,13 @@ def _validate_canonical(self):
tuner_cnt = (self.tuner is not None) + (self.advisor is not None)
if tuner_cnt != 1:
raise ValueError('ExperimentConfig: tuner and advisor must be set one')

def _load_search_space_file(search_space_path):
    """Load a search space definition from a JSON or YAML file.

    The file is parsed as JSON first; if the content is not valid JSON,
    it is parsed again with ``yaml.safe_load``.

    Parameters
    ----------
    search_space_path
        Location of the search space file; anything accepted by
        ``pathlib.Path`` (``str`` or path-like).

    Returns
    -------
    The parsed content, as produced by ``json.loads`` or ``yaml.safe_load``.

    Raises
    ------
    OSError
        If the file cannot be read.
    UnicodeDecodeError
        If the file is not valid UTF-8.
    Exception
        Whatever ``yaml.safe_load`` raises when the content is neither
        valid JSON nor loadable YAML.
    """
    # FIXME
    # we need this because PyYAML 6.0 does not support YAML 1.2,
    # which means it is not fully compatible with JSON

    # Decode explicitly as UTF-8: the locale default encoding is not UTF-8 on
    # every platform, and search spaces may contain non-ASCII keys.
    content = Path(search_space_path).read_text(encoding='utf-8')
    try:
        return json.loads(content)
    except ValueError:
        # json.JSONDecodeError is a ValueError: not strict JSON, fall back to YAML
        return yaml.safe_load(content)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the YAML-format search space still does not support scientific notation, right?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't.

9 changes: 9 additions & 0 deletions test/ut/experiment/assets/ss.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
pool_type:
_type: choice
_value:
- max
- min
- avg
学习率: # test unicode
_type: loguniform
_value: [ 0.0000001, 0.1 ]
10 changes: 10 additions & 0 deletions test/ut/experiment/assets/ss_comma.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"pool_type": {
"_type": "choice",
"_value": [ "max", "min", "avg" ],
},
"学习率": {
"_type": "loguniform",
"_value": [ 0.0000001, 0.1 ],
},
}
10 changes: 10 additions & 0 deletions test/ut/experiment/assets/ss_tab.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"pool_type": {
"_type": "choice",
"_value": [ "max", "min", "avg" ]
},
"学习率": {
"_type": "loguniform",
"_value": [ 1e-7, 0.1 ]
}
}
10 changes: 10 additions & 0 deletions test/ut/experiment/assets/ss_tab_comma.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"pool_type": {
"_type": "choice",
"_value": [ "max", "min", "avg" ],
},
"学习率": {
"_type": "loguniform",
"_value": [ 1e-7, 0.1 ],
},
}
9 changes: 9 additions & 0 deletions test/ut/experiment/assets/ss_yaml12.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
pool_type:
_type: choice
_value:
- max
- min
- avg
学习率: # test unicode
_type: loguniform
_value: [ 1e-7, 0.1 ] # test scientific notation
52 changes: 52 additions & 0 deletions test/ut/experiment/test_search_space.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import json
from pathlib import Path

import yaml

from nni.experiment.config import ExperimentConfig, AlgorithmConfig, LocalConfig

## template ##

# Shared experiment config reused by every case in this module.
# `search_space_file` starts as an empty placeholder; test_search_space()
# overwrites it with each asset file before serializing the config.
# NOTE(review): the tuner name 'randomm' looks like a misspelling of 'random' --
# confirm whether it is intentional.
config = ExperimentConfig(
search_space_file = '',
trial_command = 'echo hello',
trial_concurrency = 1,
tuner = AlgorithmConfig(name='randomm'),
training_service = LocalConfig()
)

# Expected search space: every enabled asset file is asserted to load as exactly this dict.
space_correct = {
'pool_type': {
'_type': 'choice',
'_value': ['max', 'min', 'avg']
},
# non-ASCII key, checks that unicode survives loading
'学习率': {
'_type': 'loguniform',
'_value': [1e-7, 0.1]
}
}

# FIXME
# PyYAML 6.0 (YAML 1.1) does not support tabs or scientific notation,
# and JSON does not support comments or extra commas,
# so some combinations fail to load; those cases are commented out below.
#
# Each entry is (asset file name under `assets/`, description printed on failure).
formats = [
('ss_tab.json', 'JSON (tabs + scientific notation)'),
('ss_comma.json', 'JSON with extra comma'),
#('ss_tab_comma.json', 'JSON (tabs + scientific notation) with extra comma'),
('ss.yaml', 'YAML'),
#('ss_yaml12.yaml', 'YAML 1.2 with scientific notation'),
]

def test_search_space():
    """Check that each file listed in ``formats`` yields ``space_correct``.

    For every entry, the shared ``config`` is pointed at the asset file and
    the ``'searchSpace'`` value of ``config.json()`` is compared with the
    expected dict. On any failure the format description is printed and the
    original exception is re-raised.
    """
    assets_dir = Path(__file__).parent / 'assets'
    for file_name, label in formats:
        try:
            config.search_space_file = assets_dir / file_name
            loaded = config.json()['searchSpace']
            assert loaded == space_correct
        except Exception as err:
            # name the offending format before propagating the error
            print('Failed to load search space format: ' + label)
            raise err

if __name__ == '__main__':
    # allow running this file directly as a script, without a test runner
    test_search_space()