-
Notifications
You must be signed in to change notification settings - Fork 84
平台架构 超参搜索
nni: 微软开源的一个 AutoML 工具包,用于超参数优化、神经架构搜索、模型压缩和特征工程。
nni的超参搜索,支持更多的算法。同时在pipeline中也支持了边训练边超参搜索的模板。只需要用户的代码能接受超参数作为输入,同时上报作为超参算法的目标值
每个用户可以启动多个超参搜索的实例,通过url prefix作为前端的路由。
可以参考nni官网的书写方式
启动超参搜索,会根据用户配置的超参搜索算法,选择好超参的可选值,并将选择值传递给用户的容器。例如上面的超参定义会在用户docker运行时传递下面的参数。所以用户不需要在启动命令或参数中添加这些变量,系统会自动添加,用户只需要在自己的业务代码中接收这些参数,并根据这些参数输出值就可以了。
--lr=0.021593113434583065 --num-layers=5 --optimizer=ftrl
业务方容器和代码启动接收超参进行迭代计算,通过主动上报结果来进行迭代。 示例如下,用户代码需要能接受超参可取值为输入参数,同时每次迭代通过nni.report_intermediate_result上报每次epoch的结果值,并使用nni.report_final_result上报每次实例的结果值。
import os
import argparse
import logging,random,time
import nni
from nni.utils import merge_parameter
logger = logging.getLogger('mnist_AutoML')
def main(args):
test_acc=random.randint(30,50)
for epoch in range(1, 11):
test_acc_epoch= random.randint(3,5)
time.sleep(3)
test_acc+=test_acc_epoch
# 上报当前迭代目标值
nni.report_intermediate_result(test_acc)
# 上报最总目标值
nni.report_final_result(test_acc)
def get_params():
# 必须接收超参数为输入参数
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch_size', type=int, default=64, help='input batch size for training (default: 64)')
args, _ = parser.parse_known_args()
return args
if __name__ == '__main__':
try:
# get parameters form tuner
tuner_params = nni.get_next_parameter()
params = vars(merge_parameter(get_params(), tuner_params))
print(tuner_params,params)
main(params)
except Exception as exception:
logger.exception(exception)
raise
choice | choice(nested) | randint | uniform | quniform | loguniform | qloguniform | normal | qnormal | lognormal | qlognormal |
---|---|---|---|---|---|---|---|---|---|---|
TPE Tuner | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ |
Random Search Tuner | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ |
Anneal Tuner | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ |
Evolution Tuner | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ |
SMAC Tuner | ✓ | ✓ | ✓ | ✓ | ✓ | |||||
Batch Tuner | ✓ | |||||||||
Grid Search Tuner | ✓ | ✓ | ✓ | |||||||
Hyperband Advisor | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | |
Metis Tuner | ✓ | ✓ | ✓ | ✓ | ||||||
GP Tuner | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ |
必须是标准的json。示例
{
"batch_size": {"_type":"choice", "_value": [16, 32, 64, 128]},
"hidden_size":{"_type":"choice","_value":[128, 256, 512, 1024]},
"lr":{"_type":"choice","_value":[0.0001, 0.001, 0.01, 0.1]},
"momentum":{"_type":"uniform","_value":[0, 1]}
}
启动镜像
点击运行 | 容器 | 日志 | 清理
中的 运行
按钮
运行按钮:启动搜索实验 容器按钮:查看搜索实验容器运行状况 日志按钮:查看平台日志,这里不是业务代码日志 清理按钮:删除清理实验
直到每个pod都处于running状态
等待全部容器运行成功,点击名称,进入web界面查看实验进度和日志
总览界面可以看到实验的id,和当前示例运行的状态
可以看每次trial的运行情况,计算出来的目标值
也可以看某次trial中每次epoch得到的结果值
由于是分布式进行超参数搜索,日志保存在分布式存储。我们想看某次搜索运行实例的日志。
在web界面看到是 日志地址为
Trial stdout:file://test-worker-2.test:/tmp/nni-experiments/$实验名/envs/$任务名
则对应分布式存储日志路径地址为
/mnt/$用户名/nni/$nni名称/log/$实验名/envs/$任务名
点击运行 | 容器 | 日志 | 清理
中的 清理
按钮
欢迎大家传播分享文章