From 586cfb1e7168f9ffc2069c2ceb9e74de89e7d4ec Mon Sep 17 00:00:00 2001
From: Xinyi YAN <41045439+yanxinyi620@users.noreply.github.com>
Date: Wed, 30 Oct 2024 19:41:21 +0800
Subject: [PATCH] update test ml toolkit

---
 python/wedpr_ml_toolkit/test/UserGuide.md     |  42 +++
 .../wedpr_ml_toolkit/test/config.properties   |   1 -
 .../wedpr_ml_toolkit/test/test_dataset.ipynb  | 209 ++++++++++++
 python/wedpr_ml_toolkit/test/test_psi.ipynb   | 142 ++++++++
 .../wedpr_ml_toolkit/test/test_xgboost.ipynb  | 311 ++++++++++++++++++
 .../config/wedpr_ml_config.py                 |   7 +
 .../wedpr_ml_toolkit/context/data_context.py  |  27 +-
 .../wedpr_ml_toolkit/context/job_context.py   | 224 +++++++++++--
 .../toolkit/dataset_toolkit.py                |   9 +-
 .../transport/storage_entrypoint.py           |   4 +-
 .../transport/wedpr_remote_job_client.py      |  43 ++-
 .../wedpr_ml_toolkit/wedpr_ml_toolkit.py      |  32 +-
 12 files changed, 991 insertions(+), 60 deletions(-)
 create mode 100644 python/wedpr_ml_toolkit/test/UserGuide.md
 create mode 100644 python/wedpr_ml_toolkit/test/test_dataset.ipynb
 create mode 100644 python/wedpr_ml_toolkit/test/test_psi.ipynb
 create mode 100644 python/wedpr_ml_toolkit/test/test_xgboost.ipynb

diff --git a/python/wedpr_ml_toolkit/test/UserGuide.md b/python/wedpr_ml_toolkit/test/UserGuide.md
new file mode 100644
index 00000000..47a72092
--- /dev/null
+++ b/python/wedpr_ml_toolkit/test/UserGuide.md
@@ -0,0 +1,42 @@
+# WeDPR Expert Mode User Guide
+
+## Configuration
+
+1. Create a configuration file in the user directory on the left, named: config.properties
+2. Reference configuration:
+
+```
+access_key_id=
+access_key_secret=
+remote_entrypoints=http://139.159.202.235:8005,http://139.159.202.235:8006
+
+agency_name=SGD
+workspace_path=/user/ppc/milestone2/sgd/
+user=test_user
+storage_endpoint=http://192.168.0.18:50070
+```
+
+3. Log in through the web frontend, e.g. http://139.159.202.235:8005/
+4. Create a personal project space and enter expert mode via the "Open Jupyter" button.
+
+## Basic features
+
+1. Launch Python, Jupyter, a terminal, a text editor, etc. from the launcher.
+2. Create/modify/delete configuration files, text files, bash scripts, Python notebooks and other files in the user directory.
+3. Code runs normally in Python, Jupyter and the terminal started from the launcher.
+
+## HDFS data features
+
+1. Register a dataset in either of two ways: pd.DataFrame or hdfs_path.
+2. Update a registered dataset.
+
+* See the example file [test_dataset.ipynb] for detailed usage.
+
+## WeDPR job features
+
+1. Configure job parameters.
+2. Submit PSI, model training, prediction and other jobs.
+3. Fetch job results.
+4. Post-process job results in plaintext.
+
+* See the example files [test_psi.ipynb] and [test_xgboost.ipynb] for detailed usage.
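A quick way to verify the configuration above is wired correctly is to mirror the first cells of the bundled notebooks. This is a minimal sketch, assuming config.properties sits in the working directory and the wedpr_ml_toolkit package is importable:

```python
# Minimal connectivity sketch (assumes config.properties is in the working
# directory, exactly as the test notebooks below do).
from wedpr_ml_toolkit.config.wedpr_ml_config import WeDPRMlConfigBuilder
from wedpr_ml_toolkit.wedpr_ml_toolkit import WeDPRMlToolkit

wedpr_config = WeDPRMlConfigBuilder.build_from_properties_file('config.properties')
wedpr_ml_toolkit = WeDPRMlToolkit(wedpr_config)
# sanity-check the parsed workspace path and agency name
print(wedpr_config.user_config.get_workspace_path())
print(wedpr_config.agency_config.agency_name)
```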
diff --git a/python/wedpr_ml_toolkit/test/config.properties b/python/wedpr_ml_toolkit/test/config.properties
index 358d170d..ac88fcad 100644
--- a/python/wedpr_ml_toolkit/test/config.properties
+++ b/python/wedpr_ml_toolkit/test/config.properties
@@ -6,4 +6,3 @@ agency_name=SGD
 workspace_path=/user/wedpr/milestone2/sgd/
 user=test_user
 storage_endpoint=http://127.0.0.1:50070
-
diff --git a/python/wedpr_ml_toolkit/test/test_dataset.ipynb b/python/wedpr_ml_toolkit/test/test_dataset.ipynb
new file mode 100644
index 00000000..e6fcef07
--- /dev/null
+++ b/python/wedpr_ml_toolkit/test/test_dataset.ipynb
@@ -0,0 +1,209 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "['/usr/lib/python3/dist-packages/wedpr_ml_toolkit/', 'd:\\\\github\\\\wedpr3.0\\\\WeDPR-Component\\\\python\\\\wedpr_ml_toolkit', 'd:\\\\github\\\\wedpr3.0\\\\WeDPR-Component\\\\python', 'd:\\\\github\\\\wedpr3.0\\\\WeDPR-Component\\\\python', 'c:\\\\Users\\\\yanxi\\\\anaconda3\\\\python38.zip', 'c:\\\\Users\\\\yanxi\\\\anaconda3\\\\DLLs', 'c:\\\\Users\\\\yanxi\\\\anaconda3\\\\lib', 'c:\\\\Users\\\\yanxi\\\\anaconda3', '', 'c:\\\\Users\\\\yanxi\\\\anaconda3\\\\lib\\\\site-packages', 'c:\\\\Users\\\\yanxi\\\\anaconda3\\\\lib\\\\site-packages\\\\win32', 'c:\\\\Users\\\\yanxi\\\\anaconda3\\\\lib\\\\site-packages\\\\win32\\\\lib', 'c:\\\\Users\\\\yanxi\\\\anaconda3\\\\lib\\\\site-packages\\\\Pythonwin', 'c:\\\\Users\\\\yanxi\\\\anaconda3\\\\lib\\\\site-packages\\\\IPython\\\\extensions', 'C:\\\\Users\\\\yanxi\\\\.ipython']\n"
+     ]
+    }
+   ],
+   "source": [
+    "import numpy as np\n",
+    "import pandas as pd\n",
+    "from wedpr_ml_toolkit.config.wedpr_ml_config import WeDPRMlConfigBuilder\n",
+    "from wedpr_ml_toolkit.wedpr_ml_toolkit import WeDPRMlToolkit\n",
+    "from wedpr_ml_toolkit.toolkit.dataset_toolkit import DatasetToolkit"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Load the configuration file\n",
+    "wedpr_config = WeDPRMlConfigBuilder.build_from_properties_file('config.properties')\n",
+    "wedpr_ml_toolkit = WeDPRMlToolkit(wedpr_config)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "http://139.159.202.235:50070 /user/ppc/milestone2/sgd/test_user SGD\n",
+      "/user/ppc/milestone2/sgd/test_user\\\\d-101\n",
+      "   id  y        x1        x2        x3        x4        x5        x6  \\\n",
+      "0   0  1  0.954183  0.652034  0.704070  0.180889  0.025025  0.511596   \n",
+      "1   1  1  0.302088  0.462222  0.435542  0.029966  0.931294  0.848483   \n",
+      "2   2  1  0.468104  0.430161  0.239322  0.588153  0.470668  0.225856   \n",
+      "3   3  0  0.152269  0.811666  0.834451  0.354288  0.635447  0.062092   \n",
+      "4   4  0  0.841470  0.800512  0.451507  0.118651  0.748845  0.557916   \n",
+      "\n",
+      "         x7        x8        x9       x10  \n",
+      "0  0.529848  0.759689  0.159081  0.556419  \n",
+      "1  0.962787  0.224096  0.464418  0.208487  \n",
+      "2  0.564879  0.730366  0.394245  0.299081  \n",
+      "3  0.424057  0.202234  0.577448  0.636958  \n",
+      "4  0.030906  0.514350  0.340864  0.123303  \n"
+     ]
+    }
+   ],
+   "source": [
+    "# Register a dataset; two modes are supported: pd.DataFrame, hdfs_path\n",
+    "# 1. pd.DataFrame\n",
+    "df = pd.DataFrame({\n",
+    "    'id': np.arange(0, 100),  # id column, sequential integers\n",
+    "    'y': np.random.randint(0, 2, size=100),\n",
+    "    # columns x1 to x10, random values\n",
+    "    **{f'x{i}': np.random.rand(100) for i in range(1, 11)}\n",
+    "})\n",
+    "\n",
+    "dataset1 = DatasetToolkit(storage_entrypoint=wedpr_ml_toolkit.get_storage_entry_point(),\n",
+    "                          storage_workspace=wedpr_config.user_config.get_workspace_path(),\n",
+    "                          dataset_owner='flyhuang1',\n",
+    "                          agency=wedpr_config.user_config.agency_name,\n",
+    "                          values=df,\n",
+    "                          is_label_holder=True)\n",
+    "print(dataset1.storage_client.storage_client.endpoint, dataset1.storage_workspace, dataset1.agency)\n",
+    "dataset1.storage_client = None  # skip the HDFS upload/download during local testing\n",
+    "dataset1.save_values(path='d-101')\n",
+    "print(dataset1.dataset_path)\n",
+    "print(dataset1.values.head())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "http://139.159.202.235:50070 /user/ppc/milestone2/sgd/test_user WeBank\n",
+      "/user/ppc/milestone2/webank/flyhuang/d-9606695119693829\n",
+      "/user/ppc/milestone2/webank/flyhuang/d-9606695119693829\n",
+      "   id        z1        z2        z3        z4        z5        z6        z7  \\\n",
+      "0   0  0.597205  0.942475  0.886443  0.560584  0.254432  0.370152  0.076031   \n",
+      "1   1  0.778616  0.607374  0.616211  0.602282  0.385989  0.816963  0.756814   \n",
+      "2   2  0.999795  0.596794  0.240741  0.241070  0.857676  0.342412  0.066459   \n",
+      "3   3  0.968410  0.895163  0.636140  0.978791  0.237098  0.095272  0.938806   \n",
+      "4   4  0.921513  0.454901  0.004514  0.769216  0.627185  0.676253  0.184952   \n",
+      "\n",
+      "         z8        z9       z10  \n",
+      "0  0.587627  0.851390  0.864929  \n",
+      "1  0.661537  0.865674  0.050091  \n",
+      "2  0.473916  0.080120  0.477873  \n",
+      "3  0.452399  0.953515  0.405465  \n",
+      "4  0.877475  0.316322  0.139290  \n"
+     ]
+    }
+   ],
+   "source": [
+    "# 2. hdfs_path\n",
+    "dataset2 = DatasetToolkit(storage_entrypoint=wedpr_ml_toolkit.get_storage_entry_point(),\n",
+    "                          storage_workspace=wedpr_config.user_config.get_workspace_path(),\n",
+    "                          dataset_owner='flyhuang',\n",
+    "                          dataset_path=\"/user/ppc/milestone2/webank/flyhuang/d-9606695119693829\",\n",
+    "                          agency=\"WeBank\")\n",
+    "print(dataset2.storage_client.storage_client.endpoint, dataset2.storage_workspace, dataset2.agency)\n",
+    "print(dataset2.dataset_path)\n",
+    "dataset2.storage_client = None  # skip the HDFS upload/download during local testing\n",
+    "\n",
+    "# Provide local test data\n",
+    "if dataset2.storage_client is None:\n",
+    "    # the dataset's values can be updated\n",
+    "    df2 = pd.DataFrame({\n",
+    "        'id': np.arange(0, 100),  # id column, sequential integers\n",
+    "        **{f'z{i}': np.random.rand(100) for i in range(1, 11)}  # columns z1 to z10, random values\n",
+    "    })\n",
+    "    dataset2.update_values(values=df2)\n",
+    "    dataset2.save_values()\n",
+    "    print(dataset2.dataset_path)\n",
+    "    print(dataset2.values.head())\n",
+    "\n",
+    "# load_values works for this party's own datasets; other parties' datasets are used directly without it\n",
+    "if dataset2.storage_client is not None:\n",
+    "    # only datasets on this agency's HDFS can be loaded\n",
+    "    dataset2.load_values(header=0)\n",
+    "    print(dataset2.dataset_path)\n",
+    "    print(dataset2.values.head())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "/user/ppc/milestone2/sgd/test_user\\\\d-101\n",
+      "   id  y        x1        x2        x3        x4        x5        x6  \\\n",
+      "0   0  1  0.954183  0.652034  0.704070  0.180889  0.025025  0.511596   \n",
+      "1   1  1  0.302088  0.462222  0.435542  0.029966  0.931294  0.848483   \n",
+      "2   2  1  0.468104  0.430161  0.239322  0.588153  0.470668  0.225856   \n",
+      "3   3  0  0.152269  0.811666  0.834451  0.354288  0.635447  0.062092   \n",
+      "4   4  0  0.841470  0.800512  0.451507  0.118651  0.748845  0.557916   \n",
+      "\n",
+      "         x7        x8        x9       x10  \n",
+      "0  0.529848  0.759689  0.159081  0.556419  \n",
+      "1  0.962787  0.224096  0.464418  0.208487  \n",
+      "2  0.564879  0.730366  0.394245  0.299081  \n",
+      "3  0.424057  0.202234  0.577448  0.636958  \n",
+      "4  0.030906  0.514350  0.340864  0.123303  \n"
+     ]
+    }
+   ],
+   "source": [
+    "# Update the dataset\n",
+    "if dataset1.storage_client is not None:\n",
+    "    dataset1.update_values(\n",
+    "        path='/user/ppc/milestone2/sgd/flyhuang1/d-9606704699156485')\n",
+    "    dataset1.load_values(header=0)\n",
+    "print(dataset1.dataset_path)\n",
+    "print(dataset1.values.head())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "base",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.5"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
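One detail visible in the cell output above: save_values joins the workspace and the relative path with os.path.join, so a client running on Windows produces /user/ppc/milestone2/sgd/test_user\d-101 rather than a '/'-separated HDFS path. If OS-independent HDFS paths are wanted, a posixpath-based join is one option; this is a sketch of the idea, not what the toolkit currently does:

```python
import posixpath

def join_hdfs_path(workspace: str, relative: str) -> str:
    """Join an HDFS workspace and a relative dataset path with '/' on any OS."""
    return posixpath.join(workspace, relative)

# join_hdfs_path('/user/ppc/milestone2/sgd/test_user', 'd-101')
# -> '/user/ppc/milestone2/sgd/test_user/d-101' (on Windows too)
```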
diff --git a/python/wedpr_ml_toolkit/test/test_psi.ipynb b/python/wedpr_ml_toolkit/test/test_psi.ipynb
new file mode 100644
index 00000000..435b3586
--- /dev/null
+++ b/python/wedpr_ml_toolkit/test/test_psi.ipynb
@@ -0,0 +1,142 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "import pandas as pd\n",
+    "from wedpr_ml_toolkit.config.wedpr_ml_config import WeDPRMlConfigBuilder\n",
+    "from wedpr_ml_toolkit.wedpr_ml_toolkit import WeDPRMlToolkit\n",
+    "from wedpr_ml_toolkit.toolkit.dataset_toolkit import DatasetToolkit\n",
+    "from wedpr_ml_toolkit.context.data_context import DataContext\n",
+    "from wedpr_ml_toolkit.context.job_context import JobType"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Load the configuration file\n",
+    "wedpr_config = WeDPRMlConfigBuilder.build_from_properties_file('config.properties')\n",
+    "wedpr_ml_toolkit = WeDPRMlToolkit(wedpr_config)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Register dataset1\n",
+    "dataset1 = DatasetToolkit(storage_entrypoint=wedpr_ml_toolkit.get_storage_entry_point(),\n",
+    "                          storage_workspace=wedpr_config.user_config.get_workspace_path(),\n",
+    "                          dataset_owner='flyhuang1',\n",
+    "                          agency=wedpr_config.user_config.agency_name,\n",
+    "                          dataset_id='d-9606704699156485',\n",
+    "                          dataset_path=\"/user/ppc/milestone2/sgd/flyhuang1/d-9606704699156485\",\n",
+    "                          is_label_holder=True)\n",
+    "print(dataset1.storage_client.storage_client.endpoint, dataset1.storage_workspace, dataset1.agency)\n",
+    "dataset1.load_values(header=0)\n",
+    "print(dataset1.dataset_path)\n",
+    "print(dataset1.values.head())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Register dataset2\n",
+    "dataset2 = DatasetToolkit(storage_entrypoint=wedpr_ml_toolkit.get_storage_entry_point(),\n",
+    "                          storage_workspace=wedpr_config.user_config.get_workspace_path(),\n",
+    "                          dataset_owner='flyhuang',\n",
+    "                          dataset_id='d-9606695119693829',\n",
+    "                          dataset_path=\"/user/ppc/milestone2/webank/flyhuang/d-9606695119693829\",\n",
+    "                          agency=\"WeBank\")\n",
+    "print(dataset2.storage_client.storage_client.endpoint, dataset2.storage_workspace, dataset2.agency)\n",
+    "dataset2.load_values(header=0)\n",
+    "print(dataset2.dataset_path)\n",
+    "print(dataset2.values.head())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Build the dataset context\n",
+    "dataset = DataContext(dataset1, dataset2)\n",
+    "print(dataset.datasets)\n",
+    "\n",
+    "# init the job context\n",
+    "project_id = \"9606702107011078\"\n",
+    "\n",
+    "# Build the PSI job configuration\n",
+    "psi_job_context = wedpr_ml_toolkit.build_job_context(\n",
+    "    JobType.PSI, project_id, dataset, None, \"id\")\n",
+    "print(psi_job_context.participant_id_list, psi_job_context.result_receiver_id_list)\n",
+    "print(psi_job_context.project_id)\n",
+    "\n",
+    "psi_job_param = psi_job_context.build()\n",
+    "print(psi_job_param.taskParties)\n",
+    "print(psi_job_param.datasetList)\n",
+    "print(psi_job_param.job)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Run the PSI job\n",
+    "# psi_job_id = '9670241574201350'  # skip creating a new job during testing\n",
+    "psi_job_id = psi_job_context.submit()\n",
+    "print(psi_job_id)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Fetch the PSI job result\n",
+    "# psi_job_id = '9670241574201350'  # skip creating a new job during testing\n",
+    "print(psi_job_id)\n",
+    "psi_result = psi_job_context.parse_result(psi_job_id, True)\n",
+    "psi_result.load_values()\n",
+    "print(psi_result.values.shape)\n",
+    "print(psi_result.values.head())"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "base",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.5"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
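The notebook passes block_until_success=True to parse_result, which blocks inside the client. If a caller only wants to check progress, the toolkit also exposes a non-blocking status query (see wedpr_ml_toolkit.py below). A short sketch, assuming the toolkit and psi_job_id from the cells above:

```python
# Hedged sketch: wait for a submitted job via the toolkit's status API.
# query_job_status(job_id, block_until_finish=True) blocks until the job
# reaches a terminal state and returns a JobInfo.
job_info = wedpr_ml_toolkit.query_job_status(psi_job_id, block_until_finish=True)
print(job_info)  # JobInfo.__repr__ prints job_id, owner, ownerAgency, jobType, status
```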
"3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/python/wedpr_ml_toolkit/test/test_xgboost.ipynb b/python/wedpr_ml_toolkit/test/test_xgboost.ipynb new file mode 100644 index 00000000..84082084 --- /dev/null +++ b/python/wedpr_ml_toolkit/test/test_xgboost.ipynb @@ -0,0 +1,311 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "from wedpr_ml_toolkit.config.wedpr_ml_config import WeDPRMlConfigBuilder\n", + "from wedpr_ml_toolkit.wedpr_ml_toolkit import WeDPRMlToolkit\n", + "from wedpr_ml_toolkit.toolkit.dataset_toolkit import DatasetToolkit\n", + "from wedpr_ml_toolkit.context.data_context import DataContext\n", + "from wedpr_ml_toolkit.context.job_context import JobType" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# 读取配置文件\n", + "wedpr_config = WeDPRMlConfigBuilder.build_from_properties_file('config.properties')\n", + "wedpr_ml_toolkit = WeDPRMlToolkit(wedpr_config)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 注册 dataset1\n", + "dataset1 = DatasetToolkit(storage_entrypoint=wedpr_ml_toolkit.get_storage_entry_point(),\n", + " storage_workspace=wedpr_config.user_config.get_workspace_path(),\n", + " dataset_owner='flyhuang1',\n", + " agency=wedpr_config.user_config.agency_name,\n", + " dataset_id = 'd-9606704699156485',\n", + " dataset_path=\"/user/ppc/milestone2/sgd/flyhuang1/d-9606704699156485\",\n", + " is_label_holder=True)\n", + "print(dataset1.storage_client.storage_client.endpoint, dataset1.storage_workspace, dataset1.agency)\n", + "dataset1.load_values(header=0)\n", + "print(dataset1.dataset_path)\n", + "print(dataset1.values.head())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 注册 dataset2\n", + "dataset2 = DatasetToolkit(storage_entrypoint=wedpr_ml_toolkit.get_storage_entry_point(), \n", + " storage_workspace=wedpr_config.user_config.get_workspace_path(), \n", + " dataset_owner='flyhuang',\n", + " dataset_id = 'd-9606695119693829',\n", + " dataset_path=\"/user/ppc/milestone2/webank/flyhuang/d-9606695119693829\", \n", + " agency=\"WeBank\")\n", + "print(dataset2.storage_client.storage_client.endpoint, dataset2.storage_workspace, dataset2.agency)\n", + "dataset2.load_values(header=0)\n", + "print(dataset2.dataset_path)\n", + "print(dataset2.values.head())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 构建 dataset context\n", + "dataset = DataContext(dataset1, dataset2)\n", + "print(dataset.datasets)\n", + "\n", + "# init the job context\n", + "project_id = \"9606702107011078\"\n", + "\n", + "# 构造xgb任务配置\n", + "model_setting = {'use_psi': 0, 'fillna': 0, 'na_select': 1, 'filloutlier': 0, 'normalized': 0, 'standardized': 0, 'categorical': '', 'psi_select_col': '', 'psi_select_base': '', 'psi_select_thresh': 0.3, 'psi_select_bins': 4, 'corr_select': 0, 'use_iv': 0, 'group_num': 4, 'iv_thresh': 0.1, 'use_goss': 0, 'test_dataset_percentage': 0.3, 'learning_rate': 0.1, 'num_trees': 6, 'max_depth': 3, 'max_bin': 4, 'silent': 0, 'subsample': 1, 'colsample_bytree': 1, 'colsample_bylevel': 1, 'reg_alpha': 0, 'reg_lambda': 1, 'gamma': 0, 'min_child_weight': 0, 'min_child_samples': 10, 'seed': 2024, 'early_stopping_rounds': 0, 'eval_metric': 'auc', 'verbose_eval': 1, 
+    "model_setting = {'use_psi': 0, 'fillna': 0, 'na_select': 1, 'filloutlier': 0, 'normalized': 0, 'standardized': 0, 'categorical': '', 'psi_select_col': '', 'psi_select_base': '', 'psi_select_thresh': 0.3, 'psi_select_bins': 4, 'corr_select': 0, 'use_iv': 0, 'group_num': 4, 'iv_thresh': 0.1, 'use_goss': 0, 'test_dataset_percentage': 0.3, 'learning_rate': 0.1, 'num_trees': 6, 'max_depth': 3, 'max_bin': 4, 'silent': 0, 'subsample': 1, 'colsample_bytree': 1, 'colsample_bylevel': 1, 'reg_alpha': 0, 'reg_lambda': 1, 'gamma': 0, 'min_child_weight': 0, 'min_child_samples': 10, 'seed': 2024, 'early_stopping_rounds': 0, 'eval_metric': 'auc', 'verbose_eval': 1, 'eval_set_column': '', 'train_set_value': '', 'eval_set_value': '', 'train_features': ''}\n",
+    "\n",
+    "xgb_job_context = wedpr_ml_toolkit.build_job_context(\n",
+    "    JobType.XGB_TRAINING, project_id, dataset, model_setting, \"id\")\n",
+    "print(xgb_job_context.participant_id_list, xgb_job_context.result_receiver_id_list)\n",
+    "print(xgb_job_context.project_id)\n",
+    "\n",
+    "xgb_job_param = xgb_job_context.build()\n",
+    "print(xgb_job_param.taskParties)\n",
+    "print(xgb_job_param.datasetList)\n",
+    "print(xgb_job_param.job)\n",
+    "# import json\n",
+    "# print(json.dumps(xgb_job_param.__dict__))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Run the XGB job\n",
+    "# xgb_job_id = '9707983191943174'  # skip creating a new job during testing\n",
+    "xgb_job_id = xgb_job_context.submit()\n",
+    "print(xgb_job_id)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Fetch the XGB job result\n",
+    "# xgb_job_id = '9707983191943174'  # skip creating a new job during testing\n",
+    "print(xgb_job_id)\n",
+    "xgb_result = xgb_job_context.parse_result(xgb_job_id, True)\n",
+    "xgb_result.train_result.load_values(header=0)\n",
+    "xgb_result.test_result.load_values(header=0)\n",
+    "print(xgb_result.train_result.values.head())\n",
+    "print(xgb_result.test_result.values.head())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Plaintext post-processing of the prediction result\n",
+    "from sklearn.metrics import roc_curve, roc_auc_score, precision_recall_curve, accuracy_score, f1_score, precision_score, recall_score\n",
+    "import matplotlib.pyplot as plt\n",
+    "\n",
+    "data = xgb_result.test_result.values\n",
+    "\n",
+    "# Extract the true labels and the predicted probabilities\n",
+    "y_true = data['class_label']\n",
+    "y_pred_proba = data['class_pred']\n",
+    "y_pred = np.where(y_pred_proba >= 0.5, 1, 0)  # binary classification threshold set to 0.5\n",
+    "\n",
+    "# Compute the evaluation metrics\n",
+    "accuracy = accuracy_score(y_true, y_pred)\n",
+    "precision = precision_score(y_true, y_pred)\n",
+    "recall = recall_score(y_true, y_pred)\n",
+    "f1 = f1_score(y_true, y_pred)\n",
+    "auc = roc_auc_score(y_true, y_pred_proba)\n",
+    "\n",
+    "print(f\"Accuracy: {accuracy:.2f}\")\n",
+    "print(f\"Precision: {precision:.2f}\")\n",
+    "print(f\"Recall: {recall:.2f}\")\n",
+    "print(f\"F1 Score: {f1:.2f}\")\n",
+    "print(f\"AUC: {auc:.2f}\")\n",
+    "\n",
+    "# ROC curve\n",
+    "fpr, tpr, _ = roc_curve(y_true, y_pred_proba)\n",
+    "plt.figure(figsize=(12, 5))\n",
+    "\n",
+    "# ROC curve\n",
+    "plt.subplot(1, 2, 1)\n",
+    "plt.plot(fpr, tpr, label=f'AUC = {auc:.2f}')\n",
+    "plt.plot([0, 1], [0, 1], 'k--')\n",
+    "plt.xlabel('False Positive Rate')\n",
+    "plt.ylabel('True Positive Rate')\n",
+    "plt.title('ROC Curve')\n",
+    "plt.legend()\n",
+    "\n",
+    "# Precision-recall curve\n",
+    "precision_vals, recall_vals, _ = precision_recall_curve(y_true, y_pred_proba)\n",
+    "plt.subplot(1, 2, 2)\n",
+    "plt.plot(recall_vals, precision_vals)\n",
+    "plt.xlabel('Recall')\n",
+    "plt.ylabel('Precision')\n",
+    "plt.title('Precision-Recall Curve')\n",
+    "\n",
+    "plt.tight_layout()\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Build the XGB prediction job configuration\n",
+    "model_setting = {'use_psi': 0, 'use_iv': 0}\n",
+    "\n",
+    "xgb_job_context = wedpr_ml_toolkit.build_job_context(\n",
+    "    JobType.XGB_PREDICTING, project_id, dataset, model_setting, \"id\", xgb_result.model)\n",
+    "print(xgb_job_context.participant_id_list, xgb_job_context.result_receiver_id_list)\n",
+    "print(xgb_job_context.project_id)\n",
+    "\n",
"xgb_job_param = xgb_job_context.build()\n", + "print(xgb_job_param.taskParties)\n", + "print(xgb_job_param.datasetList)\n", + "# import json\n", + "# print(json.dumps(xgb_job_param.__dict__))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 执行xgb预测任务\n", + "# xgb_job_id = '9708824062994438' # 测试时跳过创建新任务过程\n", + "xgb_job_id = xgb_job_context.submit()\n", + "print(xgb_job_id)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 获取xgb预测任务结果\n", + "# xgb_job_id = '9708824062994438' # 测试时跳过创建新任务过程\n", + "print(xgb_job_id)\n", + "xgb_result = xgb_job_context.parse_result(xgb_job_id, True)\n", + "xgb_result.test_result.load_values(header = 0)\n", + "print(xgb_result.test_result.values.head())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 明文处理预测结果\n", + "from sklearn.metrics import roc_curve, roc_auc_score, precision_recall_curve, accuracy_score, f1_score, precision_score, recall_score\n", + "import matplotlib.pyplot as plt\n", + "\n", + "data = xgb_result.test_result.values\n", + "\n", + "# 提取真实标签和预测概率\n", + "y_true = data['class_label']\n", + "y_pred_proba = data['class_pred']\n", + "y_pred = np.where(y_pred_proba >= 0.5, 1, 0) # 二分类阈值设为0.5\n", + "\n", + "# 计算评估指标\n", + "accuracy = accuracy_score(y_true, y_pred)\n", + "precision = precision_score(y_true, y_pred)\n", + "recall = recall_score(y_true, y_pred)\n", + "f1 = f1_score(y_true, y_pred)\n", + "auc = roc_auc_score(y_true, y_pred_proba)\n", + "\n", + "print(f\"Accuracy: {accuracy:.2f}\")\n", + "print(f\"Precision: {precision:.2f}\")\n", + "print(f\"Recall: {recall:.2f}\")\n", + "print(f\"F1 Score: {f1:.2f}\")\n", + "print(f\"AUC: {auc:.2f}\")\n", + "\n", + "# ROC 曲线\n", + "fpr, tpr, _ = roc_curve(y_true, y_pred_proba)\n", + "plt.figure(figsize=(12, 5))\n", + "\n", + "# ROC 曲线\n", + "plt.subplot(1, 2, 1)\n", + "plt.plot(fpr, tpr, label=f'AUC = {auc:.2f}')\n", + "plt.plot([0, 1], [0, 1], 'k--')\n", + "plt.xlabel('False Positive Rate')\n", + "plt.ylabel('True Positive Rate')\n", + "plt.title('ROC Curve')\n", + "plt.legend()\n", + "\n", + "# 精确率-召回率曲线\n", + "precision_vals, recall_vals, _ = precision_recall_curve(y_true, y_pred_proba)\n", + "plt.subplot(1, 2, 2)\n", + "plt.plot(recall_vals, precision_vals)\n", + "plt.xlabel('Recall')\n", + "plt.ylabel('Precision')\n", + "plt.title('Precision-Recall Curve')\n", + "\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/python/wedpr_ml_toolkit/wedpr_ml_toolkit/config/wedpr_ml_config.py b/python/wedpr_ml_toolkit/wedpr_ml_toolkit/config/wedpr_ml_config.py index 3e8e8830..7edd83b9 100644 --- a/python/wedpr_ml_toolkit/wedpr_ml_toolkit/config/wedpr_ml_config.py +++ b/python/wedpr_ml_toolkit/wedpr_ml_toolkit/config/wedpr_ml_config.py @@ -51,6 +51,11 @@ def __init__(self, timeout_seconds=3): self.timeout_seconds = timeout_seconds +class AgencyConfig(BaseObject): + def 
+    def __init__(self, agency_name=None):
+        self.agency_name = agency_name
+
+
 class WeDPRMlConfig:
     def __init__(self, config_dict):
         self.auth_config = AuthConfig()
@@ -63,6 +68,8 @@ def __init__(self, config_dict):
         self.user_config.set_params(**config_dict)
         self.http_config = HttpConfig()
         self.http_config.set_params(**config_dict)
+        self.agency_config = AgencyConfig()
+        self.agency_config.set_params(**config_dict)
 
 
 class WeDPRMlConfigBuilder:
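AgencyConfig inherits set_params from BaseObject, so agency_name is picked up directly from the flat dictionary parsed out of config.properties. A hedged illustration of the resulting wiring (names from this patch; it assumes BaseObject.set_params copies matching keys onto attributes, as its use throughout WeDPRMlConfig suggests):

```python
# Hypothetical snippet, mirroring how WeDPRMlConfig wires AgencyConfig.
config_dict = {"agency_name": "SGD", "user": "test_user"}  # subset of config.properties

agency_config = AgencyConfig()
agency_config.set_params(**config_dict)  # BaseObject.set_params copies known keys
print(agency_config.agency_name)         # -> "SGD"
```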
diff --git a/python/wedpr_ml_toolkit/wedpr_ml_toolkit/context/data_context.py b/python/wedpr_ml_toolkit/wedpr_ml_toolkit/context/data_context.py
index ade5b8bb..90b6119c 100644
--- a/python/wedpr_ml_toolkit/wedpr_ml_toolkit/context/data_context.py
+++ b/python/wedpr_ml_toolkit/wedpr_ml_toolkit/context/data_context.py
@@ -7,7 +7,6 @@ class DataContext:
 
     def __init__(self, *datasets):
         self.datasets = list(datasets)
-        self.ctx = self.datasets[0].ctx
 
         self._check_datasets()
 
@@ -28,13 +27,13 @@ def _check_datasets(self):
     def to_psi_format(self, merge_filed, result_receiver_id_list):
         dataset_psi = []
         for dataset in self.datasets:
-            if dataset.agency.agency_id in result_receiver_id_list:
+            if dataset.agency in result_receiver_id_list:
                 result_receiver = "true"
             else:
                 result_receiver = "false"
             dataset_psi_info = {"idFields": [merge_filed],
-                                "dataset": {"owner": dataset.ctx.user_name,
-                                            "ownerAgency": dataset.agency.agency_id,
+                                "dataset": {"owner": dataset.dataset_owner,
+                                            "ownerAgency": dataset.agency,
                                             "path": dataset.dataset_path,
                                             "storageTypeStr": "HDFS",
                                             "datasetID": dataset.dataset_id},
@@ -42,8 +41,24 @@ def to_psi_format(self, merge_filed, result_receiver_id_list):
             dataset_psi.append(dataset_psi_info)
         return dataset_psi
 
-    def to_model_formort(self):
+    def to_model_formort(self, merge_filed, result_receiver_id_list):
         dataset_model = []
         for dataset in self.datasets:
-            dataset_model.append(dataset.dataset_path)
+            if dataset.agency in result_receiver_id_list:
+                result_receiver = "true"
+            else:
+                result_receiver = "false"
+            if dataset.is_label_holder:
+                label_provider = "true"
+            else:
+                label_provider = "false"
+            dataset_psi_info = {"idFields": [merge_filed],
+                                "dataset": {"owner": dataset.dataset_owner,
+                                            "ownerAgency": dataset.agency,
+                                            "path": dataset.dataset_path,
+                                            "storageTypeStr": "HDFS",
+                                            "datasetID": dataset.dataset_id},
+                                "labelProvider": label_provider,
+                                "receiveResult": result_receiver}
+            dataset_model.append(dataset_psi_info)
         return dataset_model
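For one dataset whose agency appears in the receiver list, to_psi_format emits an entry shaped like the following. This is a sketch of the structure implied by the code above, with values taken from the test notebooks, not captured output:

```python
# Shape of one to_psi_format entry (values from the notebooks above).
psi_entry = {
    "idFields": ["id"],
    "dataset": {
        "owner": "flyhuang",
        "ownerAgency": "WeBank",
        "path": "/user/ppc/milestone2/webank/flyhuang/d-9606695119693829",
        "storageTypeStr": "HDFS",
        "datasetID": "d-9606695119693829",
    },
    "receiveResult": "true",  # serialized as the string "true"/"false", not a bool
}
# to_model_formort adds one more field per entry: "labelProvider": "true"/"false".
```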
diff --git a/python/wedpr_ml_toolkit/wedpr_ml_toolkit/context/job_context.py b/python/wedpr_ml_toolkit/wedpr_ml_toolkit/context/job_context.py
index 90355003..7bcfb109 100644
--- a/python/wedpr_ml_toolkit/wedpr_ml_toolkit/context/job_context.py
+++ b/python/wedpr_ml_toolkit/wedpr_ml_toolkit/context/job_context.py
@@ -1,21 +1,25 @@
 # -*- coding: utf-8 -*-
 
 import json
+from wedpr_ml_toolkit.toolkit.dataset_toolkit import DatasetToolkit
+from wedpr_ml_toolkit.transport.storage_entrypoint import StorageEntryPoint
 from wedpr_ml_toolkit.context.data_context import DataContext
 from wedpr_ml_toolkit.transport.wedpr_remote_job_client import JobParam
 from wedpr_ml_toolkit.transport.wedpr_remote_job_client import JobInfo
 from abc import abstractmethod
 from wedpr_ml_toolkit.transport.wedpr_remote_job_client import WeDPRRemoteJobClient
-from wedpr_ml_toolkit.transport.wedpr_remote_job_client import JobType
+from wedpr_ml_toolkit.transport.wedpr_remote_job_client import JobType, ModelType
+from wedpr_ml_toolkit.transport.wedpr_remote_job_client import ModelResult
 
 
 class JobContext:
 
-    def __init__(self, remote_job_client: WeDPRRemoteJobClient, project_name: str, dataset: DataContext = None, my_agency=None):
+    def __init__(self, remote_job_client: WeDPRRemoteJobClient, storage_entry_point: StorageEntryPoint, project_id: str, dataset: DataContext = None, my_agency=None):
         if dataset is None:
             raise Exception("Must define the job related datasets!")
         self.remote_job_client = remote_job_client
-        self.project_name = project_name
+        self.storage_entry_point = storage_entry_point
+        self.project_id = project_id
         self.dataset = dataset
         self.create_agency = my_agency
         self.participant_id_list = []
@@ -42,10 +46,10 @@ def __init_participant__(self):
         participant_id_list = []
         dataset_id_list = []
         for dataset in self.dataset.datasets:
-            participant_id_list.append(dataset.agency.agency_id)
+            participant_id_list.append(dataset.agency)
             dataset_id_list.append(dataset.dataset_id)
-            self.task_parties.append({'userName': dataset.ctx.user_name,
-                                      'agency': dataset.agency.agency_id})
+            self.task_parties.append({'userName': dataset.dataset_owner,
+                                      'agency': dataset.agency})
         self.participant_id_list = participant_id_list
         self.dataset_id_list = dataset_id_list
 
@@ -54,7 +58,7 @@ def __init_label_information__(self):
         label_columns = None
         for dataset in self.dataset.datasets:
             if dataset.is_label_holder:
-                label_holder_agency = dataset.agency.agency_id
+                label_holder_agency = dataset.agency
                 label_columns = 'y'
         self.label_holder_agency = label_holder_agency
         self.label_columns = label_columns
@@ -71,18 +75,30 @@ def submit(self):
         return self.remote_job_client.submit_job(self.build())
 
     @abstractmethod
-    def parse_result(self, result_detail):
+    def parse_result(self, job_id, block_until_success):
         pass
 
     def fetch_job_result(self, job_id, block_until_success):
-        job_result = self.query_job_status(job_id, block_until_success)
-        # TODO: determine success or not here
-        return self.parse_result(self.remote_job_client.query_job_detail(job_id))
+        # job_result = self.query_job_status(job_id, block_until_success)
+        # # TODO: determine success or not here
+        # return self.parse_result(self.remote_job_client.query_job_detail(job_id, block_until_success))
+
+        # # query_job_status
+        # job_result = self.remote_job_client.poll_job_result(job_id, block_until_success)
+        # # failed case
+        # if job_result == None or job_result.job_status == None or (not job_result.job_status.run_success()):
+        #     raise Exception(f'job {job_id} running failed!')
+        # # success case
+        # ...
+
+        # query_job_detail
+        result_detail = self.remote_job_client.query_job_detail(job_id, block_until_success)
+        return result_detail
 
 
 class PSIJobContext(JobContext):
-    def __init__(self, remote_job_client: WeDPRRemoteJobClient, project_name: str, dataset: DataContext = None, my_agency=None, merge_field: str = 'id'):
-        super().__init__(remote_job_client, project_name, dataset, my_agency)
+    def __init__(self, remote_job_client: WeDPRRemoteJobClient, storage_entry_point: StorageEntryPoint, project_id: str, dataset: DataContext = None, my_agency=None, merge_field: str = 'id'):
+        super().__init__(remote_job_client, storage_entry_point, project_id, dataset, my_agency)
         self.merge_field = merge_field
 
     def get_job_type(self) -> JobType:
@@ -91,59 +107,205 @@ def get_job_type(self) -> JobType:
     def build(self) -> JobParam:
         self.dataset_list = self.dataset.to_psi_format(
             self.merge_field, self.result_receiver_id_list)
-        job_info = JobInfo(job_type=self.get_job_type(), project_name=self.project_name, param=json.dumps(
-            {'dataSetList': self.dataset_list}).replace('"', '\\"'))
+        # job_info = JobInfo(job_type=self.get_job_type(), project_id=self.project_id, param=json.dumps(
+        #     {'dataSetList': self.dataset_list}).replace('"', '\\"'))
+        job_info = JobInfo(job_type=self.get_job_type(), project_id=self.project_id, param=json.dumps(
+            {'dataSetList': self.dataset_list}))
         job_param = JobParam(job_info, self.task_parties, self.dataset_id_list)
         return job_param
 
+    def parse_result(self, job_id, block_until_success):
+        result_detail = self.fetch_job_result(job_id, block_until_success)
+
+        psi_result = DatasetToolkit(storage_entrypoint=self.storage_entry_point,
+                                    storage_workspace=None,
+                                    dataset_owner=self.storage_entry_point.user_config.user,
+                                    dataset_path=result_detail.resultFileInfo['path'], agency=self.create_agency)
+
+        return psi_result
+
 
 class PreprocessingJobContext(JobContext):
-    def __init__(self, remote_job_client: WeDPRRemoteJobClient, project_name: str, model_setting, dataset: DataContext = None, my_agency=None):
-        super().__init__(remote_job_client, project_name, dataset, my_agency)
+    def __init__(self, remote_job_client: WeDPRRemoteJobClient, storage_entry_point: StorageEntryPoint, project_id: str, model_setting, dataset: DataContext = None, my_agency=None, merge_field: str = 'id'):
+        super().__init__(remote_job_client, storage_entry_point, project_id, dataset, my_agency)
         self.model_setting = model_setting
+        self.merge_field = merge_field
 
     def get_job_type(self) -> JobType:
         return JobType.PREPROCESSING
 
-    # TODO: build the request
     def build(self) -> JobParam:
-        return None
+        self.dataset_list = self.dataset.to_model_formort(
+            self.merge_field, self.result_receiver_id_list)
+        job_info = JobInfo(job_type=self.get_job_type(), project_id=self.project_id, param=json.dumps(
+            {'dataSetList': self.dataset_list, 'modelSetting': self.model_setting}))
+        job_param = JobParam(job_info, self.task_parties, self.dataset_id_list)
+        return job_param
+
+    def parse_result(self, job_id, block_until_success):
+        result_detail = self.fetch_job_result(job_id, block_until_success)
+
+        pre_result = result_detail
+        # pre_result = DatasetToolkit(storage_entrypoint=self.storage_entry_point,
+        #                             storage_workspace=None,
+        #                             dataset_owner=self.storage_entry_point.user_config.user,
+        #                             dataset_path=result_detail.resultFileInfo['path'], agency=self.create_agency)
+
+        return pre_result
 
 
 class FeatureEngineeringJobContext(JobContext):
-    def __init__(self, remote_job_client: WeDPRRemoteJobClient, project_name: str, model_setting, dataset: DataContext = None, my_agency=None):
-        super().__init__(remote_job_client, project_name, dataset, my_agency)
+    def __init__(self, remote_job_client: WeDPRRemoteJobClient, storage_entry_point: StorageEntryPoint, project_id: str, model_setting, dataset: DataContext = None, my_agency=None, merge_field: str = 'id'):
+        super().__init__(remote_job_client, storage_entry_point, project_id, dataset, my_agency)
         self.model_setting = model_setting
+        self.merge_field = merge_field
 
     def get_job_type(self) -> JobType:
         return JobType.FEATURE_ENGINEERING
 
-    # TODO: build the jobParam
     def build(self) -> JobParam:
-        return None
+        self.dataset_list = self.dataset.to_model_formort(
+            self.merge_field, self.result_receiver_id_list)
+        job_info = JobInfo(job_type=self.get_job_type(), project_id=self.project_id, param=json.dumps(
+            {'dataSetList': self.dataset_list, 'modelSetting': self.model_setting}))
+        job_param = JobParam(job_info, self.task_parties, self.dataset_id_list)
+        return job_param
+
+    def parse_result(self, job_id, block_until_success):
+        result_detail = self.fetch_job_result(job_id, block_until_success)
+
+        fe_result = result_detail
+        # fe_result = DatasetToolkit(storage_entrypoint=self.storage_entry_point,
+        #                            storage_workspace=None,
+        #                            dataset_owner=self.storage_entry_point.user_config.user,
+        #                            dataset_path=result_detail.resultFileInfo['path'], agency=self.create_agency)
+
+        return fe_result
 
 
 class SecureLGBMTrainingJobContext(JobContext):
-    def __init__(self, remote_job_client: WeDPRRemoteJobClient, project_name: str, model_setting, dataset: DataContext = None, my_agency=None):
-        super().__init__(remote_job_client, project_name, dataset, my_agency)
+    def __init__(self, remote_job_client: WeDPRRemoteJobClient, storage_entry_point: StorageEntryPoint, project_id: str, model_setting, dataset: DataContext = None, my_agency=None, merge_field: str = 'id'):
+        super().__init__(remote_job_client, storage_entry_point, project_id, dataset, my_agency)
         self.model_setting = model_setting
+        self.merge_field = merge_field
 
     def get_job_type(self) -> JobType:
         return JobType.XGB_TRAINING
 
-    # TODO: build the jobParam
     def build(self) -> JobParam:
-        return None
+        self.dataset_list = self.dataset.to_model_formort(
+            self.merge_field, self.result_receiver_id_list)
+        # job_info = JobInfo(job_type=self.get_job_type(), project_id=self.project_id, param=json.dumps(
+        #     {'dataSetList': self.dataset_list}).replace('"', '\\"'))
+        job_info = JobInfo(job_type=self.get_job_type(), project_id=self.project_id, param=json.dumps(
+            {'dataSetList': self.dataset_list, 'modelSetting': self.model_setting}))
+        job_param = JobParam(job_info, self.task_parties, self.dataset_id_list)
+        return job_param
+
+    def parse_result(self, job_id, block_until_success):
+        result_detail = self.fetch_job_result(job_id, block_until_success)
+        # result_detail.modelResultDetail['ModelResult']
+        train_result = DatasetToolkit(storage_entrypoint=self.storage_entry_point,
+                                      storage_workspace=None,
+                                      dataset_owner=self.storage_entry_point.user_config.user,
+                                      dataset_path=result_detail.modelResultDetail['ModelResult']['trainResultPath'], agency=self.create_agency)
+        test_result = DatasetToolkit(storage_entrypoint=self.storage_entry_point,
+                                     storage_workspace=None,
+                                     dataset_owner=self.storage_entry_point.user_config.user,
+                                     dataset_path=result_detail.modelResultDetail['ModelResult']['testResultPath'], agency=self.create_agency)
+
+        xgb_result = ModelResult(job_id, train_result, test_result, result_detail.model, ModelType.XGB_MODEL_SETTING.name)
+        return xgb_result
 
 
 class SecureLGBMPredictJobContext(JobContext):
-    def __init__(self, remote_job_client: WeDPRRemoteJobClient, project_name: str, model_setting, dataset: DataContext = None, my_agency=None):
-        super().__init__(remote_job_client, project_name, dataset, my_agency)
+    def __init__(self, remote_job_client: WeDPRRemoteJobClient, storage_entry_point: StorageEntryPoint, project_id: str, model_setting, predict_algorithm, dataset: DataContext = None, my_agency=None, merge_field: str = 'id'):
+        super().__init__(remote_job_client, storage_entry_point, project_id, dataset, my_agency)
         self.model_setting = model_setting
+        self.merge_field = merge_field
+        self.predict_algorithm = predict_algorithm
 
     def get_job_type(self) -> JobType:
         return JobType.XGB_PREDICTING
 
-    # TODO: build the jobParam
     def build(self) -> JobParam:
-        return None
+        self.dataset_list = self.dataset.to_model_formort(
+            self.merge_field, self.result_receiver_id_list)
+        # job_info = JobInfo(job_type=self.get_job_type(), project_id=self.project_id, param=json.dumps(
+        #     {'dataSetList': self.dataset_list}).replace('"', '\\"'))
+        job_info = JobInfo(job_type=self.get_job_type(), project_id=self.project_id, param=json.dumps(
+            {'dataSetList': self.dataset_list, 'modelSetting': self.model_setting, 'modelPredictAlgorithm': json.dumps(self.predict_algorithm)}))
+        job_param = JobParam(job_info, self.task_parties, self.dataset_id_list)
+        return job_param
+
+    def parse_result(self, job_id, block_until_success):
+        result_detail = self.fetch_job_result(job_id, block_until_success)
+        test_result = DatasetToolkit(storage_entrypoint=self.storage_entry_point,
+                                     storage_workspace=None,
+                                     dataset_owner=self.storage_entry_point.user_config.user,
+                                     dataset_path=result_detail.modelResultDetail['ModelResult']['testResultPath'], agency=self.create_agency)
+
+        xgb_result = ModelResult(job_id, test_result=test_result)
+        return xgb_result
+
+
+class SecureLRTrainingJobContext(JobContext):
+    def __init__(self, remote_job_client: WeDPRRemoteJobClient, storage_entry_point: StorageEntryPoint, project_id: str, model_setting, dataset: DataContext = None, my_agency=None, merge_field: str = 'id'):
+        super().__init__(remote_job_client, storage_entry_point, project_id, dataset, my_agency)
+        self.model_setting = model_setting
+        self.merge_field = merge_field
+
+    def get_job_type(self) -> JobType:
+        return JobType.LR_TRAINING
+
+    def build(self) -> JobParam:
+        self.dataset_list = self.dataset.to_model_formort(
+            self.merge_field, self.result_receiver_id_list)
+        job_info = JobInfo(job_type=self.get_job_type(), project_id=self.project_id, param=json.dumps(
+            {'dataSetList': self.dataset_list, 'modelSetting': self.model_setting}))
+        job_param = JobParam(job_info, self.task_parties, self.dataset_id_list)
+        return job_param
+
+    def parse_result(self, job_id, block_until_success):
+        result_detail = self.fetch_job_result(job_id, block_until_success)
+        # result_detail.modelResultDetail['ModelResult']
+        train_result = DatasetToolkit(storage_entrypoint=self.storage_entry_point,
+                                      storage_workspace=None,
+                                      dataset_owner=self.storage_entry_point.user_config.user,
+                                      dataset_path=result_detail.modelResultDetail['ModelResult']['trainResultPath'], agency=self.create_agency)
+        test_result = DatasetToolkit(storage_entrypoint=self.storage_entry_point,
+                                     storage_workspace=None,
+                                     dataset_owner=self.storage_entry_point.user_config.user,
+                                     dataset_path=result_detail.modelResultDetail['ModelResult']['testResultPath'], agency=self.create_agency)
+
+        lr_result = ModelResult(job_id, train_result, test_result, result_detail.model, ModelType.LR_MODEL_SETTING.name)
+        return lr_result
+
+
+class SecureLRPredictJobContext(JobContext):
+    def __init__(self, remote_job_client: WeDPRRemoteJobClient, storage_entry_point: StorageEntryPoint, project_id: str, model_setting, predict_algorithm, dataset: DataContext = None, my_agency=None, merge_field: str = 'id'):
+        super().__init__(remote_job_client, storage_entry_point, project_id, dataset, my_agency)
+        self.model_setting = model_setting
+        self.merge_field = merge_field
+        self.predict_algorithm = predict_algorithm
+
+    def get_job_type(self) -> JobType:
+        return JobType.LR_PREDICTING
+
+    def build(self) -> JobParam:
+        self.dataset_list = self.dataset.to_model_formort(
+            self.merge_field, self.result_receiver_id_list)
+        job_info = JobInfo(job_type=self.get_job_type(), project_id=self.project_id, param=json.dumps(
+            {'dataSetList': self.dataset_list, 'modelSetting': self.model_setting, 'modelPredictAlgorithm': json.dumps(self.predict_algorithm)}))
+        job_param = JobParam(job_info, self.task_parties, self.dataset_id_list)
+        return job_param
+
+    def parse_result(self, job_id, block_until_success):
+        result_detail = self.fetch_job_result(job_id, block_until_success)
+        test_result = DatasetToolkit(storage_entrypoint=self.storage_entry_point,
+                                     storage_workspace=None,
+                                     dataset_owner=self.storage_entry_point.user_config.user,
+                                     dataset_path=result_detail.modelResultDetail['ModelResult']['testResultPath'], agency=self.create_agency)
+
+        lr_result = ModelResult(job_id, test_result=test_result)
+        return lr_result
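Note that the reworked fetch_job_result no longer validates job success before parsing; the success check only survives as the commented-out draft above. A hedged sketch of how it could be restored, assuming the client's poll_job_result behaves as that draft suggests (the method and its job_status.run_success() accessor are only referenced in comments, not defined by this patch):

```python
# Hedged sketch: explicit success check before fetching the job detail,
# following the commented-out draft in JobContext.fetch_job_result.
def fetch_job_result_checked(self, job_id, block_until_success):
    job_result = self.remote_job_client.poll_job_result(job_id, block_until_success)
    if job_result is None or job_result.job_status is None \
            or not job_result.job_status.run_success():
        raise Exception(f"job {job_id} running failed!")
    return self.remote_job_client.query_job_detail(job_id, block_until_success)
```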
diff --git a/python/wedpr_ml_toolkit/wedpr_ml_toolkit/toolkit/dataset_toolkit.py b/python/wedpr_ml_toolkit/wedpr_ml_toolkit/toolkit/dataset_toolkit.py
index 2e58a3c7..cf9f3327 100644
--- a/python/wedpr_ml_toolkit/wedpr_ml_toolkit/toolkit/dataset_toolkit.py
+++ b/python/wedpr_ml_toolkit/wedpr_ml_toolkit/toolkit/dataset_toolkit.py
@@ -10,11 +10,13 @@ def __init__(self,
                  storage_workspace,
                  dataset_id=None,
                  dataset_path=None,
+                 dataset_owner=None,
                  agency=None,
                  values=None,
                  is_label_holder=False):
         self.dataset_id = dataset_id
         self.dataset_path = dataset_path
+        self.dataset_owner = dataset_owner
         self.agency = agency
         self.values = values
         self.is_label_holder = is_label_holder
@@ -28,10 +30,10 @@ def __init__(self,
             self.columns = self.values.columns
             self.shape = self.values.shape
 
-    def load_values(self):
+    def load_values(self, header=None):
         # Load the dataset from HDFS
         if self.storage_client is not None:
-            self.values = self.storage_client.download(self.dataset_path)
+            self.values = self.storage_client.download(self.dataset_path, header=header)
             self.columns = self.values.columns
             self.shape = self.values.shape
 
@@ -39,7 +41,8 @@ def save_values(self, path=None):
         # Save the data to the HDFS directory
         if path is not None:
             self.dataset_path = path
-        if not self.dataset_path.startswith(self.storage_workspace):
+        if self.storage_workspace is not None and \
+                not self.dataset_path.startswith(self.storage_workspace):
             self.dataset_path = os.path.join(
                 self.storage_workspace, self.dataset_path)
         if self.storage_client is not None:
diff --git a/python/wedpr_ml_toolkit/wedpr_ml_toolkit/transport/storage_entrypoint.py b/python/wedpr_ml_toolkit/wedpr_ml_toolkit/transport/storage_entrypoint.py
index 8ba01aa9..371ef823 100644
--- a/python/wedpr_ml_toolkit/wedpr_ml_toolkit/transport/storage_entrypoint.py
+++ b/python/wedpr_ml_toolkit/wedpr_ml_toolkit/transport/storage_entrypoint.py
@@ -26,14 +26,14 @@ def upload(self, dataframe, hdfs_path):
         self.storage_client.save_data(csv_buffer.getvalue(), hdfs_path)
         return
 
-    def download(self, hdfs_path):
+    def download(self, hdfs_path, header=None):
         """
         Download data from HDFS and return it as a pandas DataFrame
         :param hdfs_path: the HDFS file path
         :return: pandas DataFrame
         """
         content = self.storage_client.get_data(hdfs_path)
-        dataframe = pd.read_csv(io.BytesIO(content))
+        dataframe = pd.read_csv(io.BytesIO(content), header=header)
         return dataframe
 
     def download_byte(self, hdfs_path):
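The header passthrough matters because pandas treats header=None and header=0 very differently, which is why the notebooks above pass header=0 whenever the first CSV row holds column names. A quick reminder of the semantics in plain pandas:

```python
import io
import pandas as pd

csv_bytes = b"id,y\n0,1\n1,0\n"

# header=0: the first row becomes the column index -> columns ['id', 'y']
print(pd.read_csv(io.BytesIO(csv_bytes), header=0).columns.tolist())

# header=None: no header row is assumed -> integer columns [0, 1],
# and the text 'id','y' is read as an ordinary data row
print(pd.read_csv(io.BytesIO(csv_bytes), header=None).head(1))
```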
diff --git a/python/wedpr_ml_toolkit/wedpr_ml_toolkit/transport/wedpr_remote_job_client.py b/python/wedpr_ml_toolkit/wedpr_ml_toolkit/transport/wedpr_remote_job_client.py
index 8b3795a3..c04523d2 100644
--- a/python/wedpr_ml_toolkit/wedpr_ml_toolkit/transport/wedpr_remote_job_client.py
+++ b/python/wedpr_ml_toolkit/wedpr_ml_toolkit/transport/wedpr_remote_job_client.py
@@ -16,6 +16,13 @@ class JobType(Enum):
     FEATURE_ENGINEERING = "FEATURE_ENGINEERING",
     XGB_TRAINING = "XGB_TRAINING",
     XGB_PREDICTING = "XGB_PREDICTING"
+    LR_TRAINING = "LR_TRAINING",
+    LR_PREDICTING = "LR_PREDICTING"
+
+
+class ModelType(Enum):
+    XGB_MODEL_SETTING = "XGB_MODEL_SETTING",
+    LR_MODEL_SETTING = "LR_MODEL_SETTING",
 
 
 class JobStatus(Enum):
@@ -53,8 +60,9 @@ def get_job_status(job_status_tr: str):
 
 
 class JobInfo(BaseObject):
-    def __init__(self, job_id: str = None, job_type: JobType = None, project_name: str = None, param: str = None, **params: Any):
-        self.id = job_id
+    def __init__(self, job_id: str = None, job_type: JobType = None, project_id: str = None, param: str = None, **params: Any):
+        if job_id is not None:
+            self.id = job_id
         self.name = None
         self.owner = None
         self.ownerAgency = None
@@ -63,7 +71,7 @@ def __init__(self, job_id: str = None, job_type: JobType = None, project_name: s
         else:
             self.jobType = None
         self.parties = None
-        self.projectName = project_name
+        self.projectId = project_id
         self.param = param
         self.status = None
         self.result = None
@@ -76,9 +84,34 @@ def __repr__(self):
         return f"job_id: {self.id}, owner: {self.owner}, ownerAgency: {self.ownerAgency}, jobType: {self.jobType}, status: {self.status}"
 
 
+class ModelInfo(BaseObject):
+    def __init__(self, model, model_type, **params: Any):
+
+        self.type = model_type
+        # self.setting = json.loads(model)
+        self.setting = model
+        self.startTime = None
+        self.endTime = None
+        self.step = None
+        self.id = None
+        self.name = None
+        self.agency = None
+        self.owner = None
+
+        self.set_params(**params)
+
+
+class ModelResult:
+    def __init__(self, job_id: str, train_result=None, test_result=None, model=None, model_type=None):
+        self.job_id = job_id
+        self.train_result = train_result
+        self.test_result = test_result
+        self.model = ModelInfo(model, model_type).__dict__
+
+
 class JobParam:
     def __init__(self, job_info: JobInfo, task_parities, dataset_list):
-        self.job = job_info
+        self.job = job_info.__dict__
         self.taskParties = task_parities
         self.datasetList = dataset_list
 
@@ -159,7 +192,7 @@ def get_job_config(self):
     def submit_job(self, job_params: JobParam) -> WeDPRResponse:
         wedpr_response = self.send_request(True,
-                                           self.job_config._submit_job_uri, None, None, json.dumps(job_params))
+                                           self.job_config.submit_job_uri, None, None, json.dumps(job_params.__dict__))
         submit_result = WeDPRResponse(**wedpr_response)
         # return the job_id
         if submit_result.success():
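One subtlety worth noting in JobType and ModelType above: an enum member declared with a trailing comma (e.g. LR_TRAINING = "LR_TRAINING",) gets a one-element tuple as its value, while a member without one (LR_PREDICTING) gets a plain string. This is standard Python behavior, demonstrated below; code that relies on .name, as this patch does via ModelType.XGB_MODEL_SETTING.name, is unaffected, but .value comparisons would differ between the two styles.

```python
from enum import Enum

class Demo(Enum):
    WITH_COMMA = "WITH_COMMA",       # trailing comma -> value is ('WITH_COMMA',)
    WITHOUT_COMMA = "WITHOUT_COMMA"  # no comma -> value is 'WITHOUT_COMMA'

print(Demo.WITH_COMMA.value)     # ('WITH_COMMA',)
print(Demo.WITHOUT_COMMA.value)  # 'WITHOUT_COMMA'
print(Demo.WITH_COMMA.name)      # 'WITH_COMMA' -- .name is unaffected
```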
diff --git a/python/wedpr_ml_toolkit/wedpr_ml_toolkit/wedpr_ml_toolkit.py b/python/wedpr_ml_toolkit/wedpr_ml_toolkit/wedpr_ml_toolkit.py
index 6d8b008e..ac999867 100644
--- a/python/wedpr_ml_toolkit/wedpr_ml_toolkit/wedpr_ml_toolkit.py
+++ b/python/wedpr_ml_toolkit/wedpr_ml_toolkit/wedpr_ml_toolkit.py
@@ -11,6 +11,8 @@ from wedpr_ml_toolkit.context.job_context import FeatureEngineeringJobContext
 from wedpr_ml_toolkit.context.job_context import SecureLGBMPredictJobContext
 from wedpr_ml_toolkit.context.job_context import SecureLGBMTrainingJobContext
+from wedpr_ml_toolkit.context.job_context import SecureLRPredictJobContext
+from wedpr_ml_toolkit.context.job_context import SecureLRTrainingJobContext
 from wedpr_ml_toolkit.context.data_context import DataContext
 
 
@@ -40,20 +42,26 @@ def query_job_status(self, job_id, block_until_finish=False) -> JobInfo:
     def query_job_detail(self, job_id, block_until_finish=False) -> JobDetailResponse:
         return self.remote_job_client.query_job_detail(job_id, block_until_finish)
 
-    def build_job_context(self, job_type: JobType, project_name: str, dataset: DataContext, model_setting=None,
-                          id_fields='id'):
+    def build_job_context(self, job_type: JobType, project_id: str, dataset: DataContext, model_setting=None,
+                          id_fields='id', predict_algorithm=None):
         if job_type == JobType.PSI:
-            return PSIJobContext(self.remote_job_client, project_name, dataset, self.config.agency_config.agency_name,
-                                 id_fields)
+            return PSIJobContext(self.remote_job_client, self.storage_entry_point, project_id, dataset,
+                                 self.config.agency_config.agency_name, id_fields)
         if job_type == JobType.PREPROCESSING:
-            return PreprocessingJobContext(self.remote_job_client, project_name, model_setting, dataset,
-                                           self.config.agency_config.agency_name)
+            return PreprocessingJobContext(self.remote_job_client, self.storage_entry_point, project_id,
+                                           model_setting, dataset, self.config.agency_config.agency_name, id_fields)
         if job_type == JobType.FEATURE_ENGINEERING:
-            return FeatureEngineeringJobContext(self.remote_job_client, project_name, model_setting, dataset,
-                                                self.config.agency_config.agency_name)
+            return FeatureEngineeringJobContext(self.remote_job_client, self.storage_entry_point, project_id,
+                                                model_setting, dataset, self.config.agency_config.agency_name, id_fields)
         if job_type == JobType.XGB_TRAINING:
-            return SecureLGBMTrainingJobContext(self.remote_job_client, project_name, model_setting, dataset,
-                                                self.config.agency_config.agency_name)
+            return SecureLGBMTrainingJobContext(self.remote_job_client, self.storage_entry_point, project_id,
+                                                model_setting, dataset, self.config.agency_config.agency_name, id_fields)
         if job_type == JobType.XGB_PREDICTING:
-            return SecureLGBMPredictJobContext(self.remote_job_client, project_name, model_setting, dataset,
-                                               self.config.agency_config.agency_name)
+            return SecureLGBMPredictJobContext(self.remote_job_client, self.storage_entry_point, project_id,
+                                               model_setting, predict_algorithm, dataset, self.config.agency_config.agency_name, id_fields)
+        if job_type == JobType.LR_TRAINING:
+            return SecureLRTrainingJobContext(self.remote_job_client, self.storage_entry_point, project_id,
+                                              model_setting, dataset, self.config.agency_config.agency_name, id_fields)
+        if job_type == JobType.LR_PREDICTING:
+            return SecureLRPredictJobContext(self.remote_job_client, self.storage_entry_point, project_id,
+                                             model_setting, predict_algorithm, dataset, self.config.agency_config.agency_name, id_fields)
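With the LR contexts registered in build_job_context, the LR flow mirrors the XGB cells in test_xgboost.ipynb. A hedged end-to-end sketch, reusing project_id, dataset and model_setting from that notebook; LR-specific model settings are not defined by this patch, so the XGB-style dictionary is assumed to be accepted as-is:

```python
# Hedged sketch of the new LR path; mirrors the XGB training cells above.
lr_job_context = wedpr_ml_toolkit.build_job_context(
    JobType.LR_TRAINING, project_id, dataset, model_setting, "id")
lr_job_id = lr_job_context.submit()

# parse_result returns a ModelResult holding train/test result datasets
lr_result = lr_job_context.parse_result(lr_job_id, True)
lr_result.train_result.load_values(header=0)
print(lr_result.train_result.values.head())
```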