Skip to content

Commit

Permalink
update script
Browse files (browse the repository at this point in the history)
  • Loading branch information
cyjseagull committed Oct 21, 2024
1 parent 676b27e commit 8dc416a
Show file tree
Hide file tree
Showing 11 changed files with 120 additions and 95 deletions.
4 changes: 2 additions & 2 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,7 @@ endif()
set(TRANSPORT_SDK_SOURCE_LIST
wedpr-protocol
wedpr-transport/ppc-front
wedpr-transport/sdk
wedpr-transport/sdk-wrapper)
wedpr-transport/sdk)

set(TRANSPORT_SDK_TOOLKIT_SOURCE_LIST
${TRANSPORT_SDK_SOURCE_LIST}
Expand All @@ -90,6 +89,7 @@ set(ALL_SOURCE_LIST
wedpr-helper/libhelper wedpr-helper/ppc-tools
wedpr-storage/ppc-io wedpr-storage/ppc-storage
wedpr-transport/ppc-gateway
wedpr-transport/ppc-rpc
wedpr-transport/ppc-http
wedpr-computing/ppc-psi wedpr-computing/ppc-mpc wedpr-computing/ppc-pir ${CEM_SOURCE}
wedpr-initializer wedpr-main)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def load(transport_config: TransportConfig) -> Transport:
return Transport(transport, transport_config)

@staticmethod
def build(self, transport_threadpool_size: int = 4,
def build(transport_threadpool_size: int = 4,
transport_node_id: str = None,
transport_gateway_targets: str = None,
transport_host_ip: str = None,
Expand Down
76 changes: 33 additions & 43 deletions python/ppc_model/common/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from ppc_common.deps_services import storage_loader
from ppc_common.ppc_utils import common_func
from ppc_common.ppc_async_executor.thread_event_manager import ThreadEventManager
from wedpr_python_gateway_sdk.transport.impl.transport_loader import TransportLoader
from ppc_model.network.wedpr_model_transport import ModelTransport
from ppc_model.task.task_manager import TaskManager
Expand All @@ -15,41 +16,35 @@
class Initializer:
def __init__(self, log_config_path, config_path, plot_lock=None):
self.log_config_path = log_config_path
logging.config.fileConfig(self.log_config_path)
self.config_path = config_path
self.config_data = None
self.grpc_options = None
self.task_manager = None
self.thread_event_manager = None
self.storage_client = None
# default send msg timeout
self.transport = None
self.send_msg_timeout_ms = 5000
self.pop_msg_timeout_ms = 60000
self.MODEL_COMPONENT = "WEDPR_MODEL"
# 只用于测试
self.mock_logger = None
self.public_key_length = 2048
self.homo_algorithm = 0
self.init_config()
self.job_cache_dir = common_func.get_config_value(
"JOB_TEMP_DIR", "/tmp", self.config_data, False)
self.thread_event_manager = ThreadEventManager()
self.task_manager = TaskManager(
logger=self.logger(),
thread_event_manager=self.thread_event_manager,
task_timeout_h=self.config_data['TASK_TIMEOUT_H']
)
self.storage_client = storage_loader.load(
self.config_data, self.logger())
# default send msg timeout
self.MODEL_COMPONENT = "WEDPR_MODEL"
self.send_msg_timeout_ms = 5000
self.pop_msg_timeout_ms = 60000
# for UT
self.transport = None
# matplotlib 线程不安全,并行任务绘图增加全局锁
self.plot_lock = plot_lock
if plot_lock is None:
self.plot_lock = threading.Lock()

def init_all(self):
self.init_log()
self.init_config()
self.init_task_manager()
self.init_transport()
self.init_storage_client()
self.init_cache()

def init_log(self):
logging.config.fileConfig(self.log_config_path)

def init_cache(self):
self.job_cache_dir = common_func.get_config_value(
"JOB_TEMP_DIR", "/tmp", self.config_data, False)

def init_config(self):
with open(self.config_path, 'rb') as f:
self.config_data = yaml.safe_load(f.read())
Expand All @@ -59,34 +54,29 @@ def init_config(self):
if 'HOMO_ALGORITHM' in self.config_data:
self.homo_algorithm = self.config_data['HOMO_ALGORITHM']

def init_transport(self):
def init_all(self):
self.init_transport(task_manager=self.task_manager,
component_type=self.MODEL_COMPONENT,
send_msg_timeout_ms=self.send_msg_timeout_ms,
pop_msg_timeout_ms=self.pop_msg_timeout_ms)

def init_transport(self, task_manager: TaskManager, component_type: str, send_msg_timeout_ms: int, pop_msg_timeout_ms: int):
# create the transport
transport = TransportLoader.build(**self.config_data)
self.logger(
f"Create transport success, config: {self.get_config().desc()}")
f"Create transport success, config: {transport.get_config().desc()}")
# start the transport
transport.start()
self.logger().info(
f"Start transport success, config: {transport.get_config().desc()}")
transport.register_component(self.MODEL_COMPONENT)
transport.register_component(component_type)
self.logger().info(
f"Register the component {self.MODEL_COMPONENT} success")
f"Register the component {component_type} success")
self.transport = ModelTransport(transport=transport,
task_manager=self.task_manager,
component_type=self.MODEL_COMPONENT,
send_msg_timeout_ms=self.send_msg_timeout_ms,
pop_msg_timeout_ms=self.pop_msg_timeout_ms)

def init_task_manager(self):
self.task_manager = TaskManager(
logger=self.logger(),
thread_event_manager=self.thread_event_manager,
task_timeout_h=self.config_data['TASK_TIMEOUT_H']
)

def init_storage_client(self):
self.storage_client = storage_loader.load(
self.config_data, self.logger())
task_manager=task_manager,
component_type=component_type,
send_msg_timeout_ms=send_msg_timeout_ms,
pop_msg_timeout_ms=pop_msg_timeout_ms)

def logger(self, name=None):
if self.mock_logger is None:
Expand Down
2 changes: 1 addition & 1 deletion python/ppc_model/conf/application-sample.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,4 @@ transport_threadpool_size: 4
transport_node_id: "MODEL_WeBank_NODE"
transport_gateway_targets: "ipv4:127.0.0.1:40600,127.0.0.1:40601"
transport_host_ip: "127.0.0.1"
transport_listen_port: 6200
transport_listen_port: 6500
2 changes: 1 addition & 1 deletion python/ppc_model/conf/logging.conf
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ formatter=simpleFormatter
[handler_consoleHandler]
class=StreamHandler
args=(sys.stdout,)
level=ERROR
level=INFO
formatter=simpleFormatter

[formatters]
Expand Down
4 changes: 4 additions & 0 deletions python/ppc_model/network/wedpr_model_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,17 @@ def select_node(self, route_type: RouteType, dst_agency: str, dst_component: str
return self.transport.select_node_by_route_policy(route_type=route_type,
dst_inst=dst_agency, dst_component=dst_component)

def stop(self):
self.transport.stop()


class ModelRouter(ModelRouterApi):
def __init__(self, logger, transport: ModelTransport, participant_id_list):
self.logger = logger
self.transport = transport
self.participant_id_list = participant_id_list
self.router_info = {}
self.__init_routers__()

def __init_routers__(self):
for participant in self.participant_id_list:
Expand Down
5 changes: 4 additions & 1 deletion python/ppc_model/ppc_model_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,10 @@ def register_task_handler():
TransLogger(app, setup_console_handler=False), numthreads=2)

protocol = 'http'
message = f"Starting ppc model server at {protocol}://{app.config['HOST']}:{app.config['HTTP_PORT']}"
message = f"Starting ppc model server at {protocol}://{app.config['HOST']}:{app.config['HTTP_PORT']} successfully"
print(message)
components.logger().info(message)
server.start()
# stop the nodes
components.transport.stop()
print("Stop ppc model server successfully")
2 changes: 1 addition & 1 deletion python/ppc_model/secure_lgbm/vertical/booster.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def _send_byte_data(self, ctx, key_type, byte_data, partner_index):
partner_id = ctx.participant_id_list[partner_index]

self.ctx.model_router.push(
task_id=ctx.task_id, task_type=key_type, dst_agency=partner_id, data=byte_data)
task_id=ctx.task_id, task_type=key_type, dst_agency=partner_id, payload=byte_data)
self.logger.info(
f"task {ctx.task_id}: Sending {key_type} to {partner_id} finished, "
f"data_size: {len(byte_data) / 1024}KB, time_costs: {time.time() - start_time}s")
Expand Down
2 changes: 1 addition & 1 deletion python/ppc_model/secure_lr/vertical/booster.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def _send_byte_data(self, ctx, key_type, byte_data, partner_index):
start_time = time.time()
partner_id = ctx.participant_id_list[partner_index]
self.ctx.model_router.push(
task_id=ctx.task_id, task_type=key_type, dst_agency=partner_id, data=byte_data)
task_id=ctx.task_id, task_type=key_type, dst_agency=partner_id, payload=byte_data)
self.logger.info(
f"task {ctx.task_id}: Sending {key_type} to {partner_id} finished, "
f"data_size: {len(byte_data) / 1024}KB, time_costs: {time.time() - start_time}s")
Expand Down
63 changes: 32 additions & 31 deletions python/ppc_model/tools/start.sh
Original file line number Diff line number Diff line change
@@ -1,36 +1,37 @@
#!/bin/bash
SHELL_FOLDER=$(cd $(dirname $0);pwd)
LOG_ERROR() {
content=${1}
echo -e "\033[31m[ERROR] ${content}\033[0m"
}

dirpath="$(cd "$(dirname "$0")" && pwd)"
cd $dirpath
LOG_INFO() {
content=${1}
echo -e "\033[32m[INFO] ${content}\033[0m"
}
binary_path=${SHELL_FOLDER}/ppc_model_app.py
cd ${SHELL_FOLDER}
node=$(basename ${SHELL_FOLDER})
node_pid=$(ps aux|grep ${binary_path}|grep -v grep|awk '{print $2}')

# kill crypto process
crypto_pro_num=`ps -ef | grep /ppc/scripts | grep j- | grep -v 'grep' | awk '{print $2}' | wc -l`
for i in $( seq 1 $crypto_pro_num )
if [ ! -z ${node_pid} ];then
echo " ${node} is running, pid is $node_pid."
exit 0
else
nohup python ${binary_path} > start.out 2>&1 &
sleep 1.5
fi
try_times=4
i=0
while [ $i -lt ${try_times} ]
do
crypto_pid=`ps -ef | grep /ppc/scripts | grep j- | grep -v 'grep' | awk '{print $2}' | awk 'NR==1{print}'`
kill -9 $crypto_pid
node_pid=$(ps aux|grep ${binary_path}|grep -v grep|awk '{print $2}')
success_flag=$(tail -n20 start.out | grep successfully)
if [[ ! -z ${node_pid} && ! -z "${success_flag}" ]];then
echo -e "\033[32m ${node} start successfully pid=${node_pid}\033[0m"
exit 0
fi
sleep 0.5
((i=i+1))
done

sleep 1

nohup python ppc_model_app.py > start.out 2>&1 &

check_service() {
try_times=5
i=0
while [ -z `ps -ef | grep ${1} | grep python | grep -v grep | awk '{print $2}'` ]; do
sleep 1
((i = i + 1))
if [ $i -lt ${try_times} ]; then
echo -e "\033[32m.\033[0m\c"
else
echo -e "\033[31m\nServer ${1} isn't running. \033[0m"
return
fi
done

echo -e "\033[32mServer ${1} started \033[0m"
}

sleep 5
check_service ppc_model_app.py
echo -e "\033[31m Exceed waiting time. Please try again to start ${node} \033[0m"
53 changes: 40 additions & 13 deletions python/ppc_model/tools/stop.sh
Original file line number Diff line number Diff line change
@@ -1,19 +1,46 @@
#!/bin/bash
SHELL_FOLDER=$(cd $(dirname $0);pwd)

dirpath="$(cd "$(dirname "$0")" && pwd)"
cd $dirpath
LOG_ERROR() {
content=${1}
echo -e "\033[31m[ERROR] ${content}\033[0m"
}

# kill crypto process
crypto_pro_num=`ps -ef | grep /ppc/scripts | grep j- | grep -v 'grep' | awk '{print $2}' | wc -l`
for i in $( seq 1 $crypto_pro_num )
do
crypto_pid=`ps -ef | grep /ppc/scripts | grep j- | grep -v 'grep' | awk '{print $2}' | awk 'NR==1{print}'`
kill -9 $crypto_pid
done
LOG_INFO() {
content=${1}
echo -e "\033[32m[INFO] ${content}\033[0m"
}

sleep 1
binary_path=${SHELL_FOLDER}/ppc_model_app.py
node=$(basename ${SHELL_FOLDER})
node_pid=$(ps aux|grep ${binary_path}|grep -v grep|awk '{print $2}')
try_times=10
i=0
if [ -z ${node_pid} ];then
echo " ${node} isn't running."
exit 0
fi

ppc_model_app_pid=`ps aux |grep ppc_model_app.py |grep -v grep |awk '{print $2}'`
kill -9 $ppc_model_app_pid
#Stop monitor here
dirs=($(ls -l ${SHELL_FOLDER} | awk '/^d/ {print $NF}'))
for dir in ${dirs[*]}
do
if [[ -f "${SHELL_FOLDER}/${dir}/node.mtail" && -f "${SHELL_FOLDER}/${dir}/stop_mtail_monitor.sh" ]];then
echo "try to start ${dir}"
bash ${SHELL_FOLDER}/${dir}/stop_mtail_monitor.sh &
fi
done

echo -e "\033[32mServer ppc_model_app.py killed. \033[0m"
[ ! -z ${node_pid} ] && kill ${node_pid} > /dev/null
while [ $i -lt ${try_times} ]
do
sleep 1
node_pid=$(ps aux|grep ${binary_path}|grep -v grep|awk '{print $2}')
if [ -z ${node_pid} ];then
echo -e "\033[32m stop ${node} success.\033[0m"
exit 0
fi
((i=i+1))
done
echo " Exceed maximum number of retries. Please try again to stop ${node}"
exit 1

0 comments on commit 8dc416a

Please sign in to comment.