fedlearner.patch · 396 lines (359 loc) · 16.5 KB
diff --git a/fedlearner/common/db_client.py b/fedlearner/common/db_client.py
index 6a21f19a..cb438307 100644
--- a/fedlearner/common/db_client.py
+++ b/fedlearner/common/db_client.py
@@ -18,7 +18,7 @@
import os
from fedlearner.common.etcd_client import EtcdClient
from fedlearner.common.dfs_client import DFSClient
-from fedlearner.common.mysql_client import MySQLClient
+#from fedlearner.common.mysql_client import MySQLClient
from fedlearner.common.leveldb import LevelDB
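
The db_client.py hunk above simply comments out the MySQL backend import so the benchmark image does not need mysqlclient/sqlalchemy installed (see the matching requirements.txt hunk at the end of the patch). Below is a minimal sketch of a softer alternative that guards the import instead of deleting it; the get_mysql_client helper is hypothetical and not part of FedLearner.

# Sketch: make the MySQL backend optional instead of removing the import.
try:
    from fedlearner.common.mysql_client import MySQLClient  # optional dependency
except ImportError:  # mysqlclient / sqlalchemy not installed in the benchmark image
    MySQLClient = None

def get_mysql_client(*args, **kwargs):
    """Return a MySQLClient, or fail with a clear message if the backend is absent."""
    if MySQLClient is None:
        raise RuntimeError("MySQL backend unavailable: install mysqlclient to use it")
    return MySQLClient(*args, **kwargs)
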
diff --git a/fedlearner/fedavg/fedavg.py b/fedlearner/fedavg/fedavg.py
index a66a5dab..745a45b6 100644
--- a/fedlearner/fedavg/fedavg.py
+++ b/fedlearner/fedavg/fedavg.py
@@ -4,7 +4,7 @@ from fedlearner.common import metrics
from fedlearner.fedavg.master import LeaderMaster, FollowerMaster
from fedlearner.fedavg.cluster.cluster_spec import FLClusterSpec
from fedlearner.fedavg._global_context import global_context as _gtx
-
+from sklearn.metrics import roc_auc_score
class MasterControlKerasCallback(tf.keras.callbacks.Callback):
@@ -68,6 +68,8 @@ class MetricsKerasCallback(tf.keras.callbacks.Callback):
def train_from_keras_model(model,
+ x_test=None,
+ y_test=None,
x=None,
y=None,
batch_size=None,
@@ -98,15 +100,29 @@ def train_from_keras_model(model,
master = master_class(model, fl_name, fl_cluster_spec, steps_per_sync,
save_filepath)
master.start()
- history = model.fit(x,
- y,
- batch_size=batch_size,
- epochs=epochs,
- callbacks=[MasterControlKerasCallback(master),
- MetricsKerasCallback()])
+ model.fit(x,
+ y,
+ batch_size=batch_size,
+ epochs=epochs,
+ callbacks=[MasterControlKerasCallback(master),
+ MetricsKerasCallback()])
master.wait()
- return history
+ result = model.evaluate(x_test, y_test)
+ logger = master.logger
+ if fl_name == 'leader':
+ with logger.model_evaluation() as e:
+ output = model.predict(x_test)
+ if len(output.shape) == 1:
+ target_metric = result[0]
+ elif output.shape[1] == 2:
+ target_metric = roc_auc_score(y_test, output[:, 1])
+ else:
+ target_metric = result[1]
+ e.report_metric('target_metric', float(target_metric))
+ e.report_metric('loss', float(result[0]))
+ logger.end()
+ return result
def eval_from_keras_model(model: tf.keras.Model, x=None, y=None,
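
The fedavg.py hunk replaces the returned Keras history with a post-training evaluation and reports either the loss, a binary AUC, or the compiled secondary metric, depending on the model's output shape. The sketch below isolates that metric-selection logic so it can be tested outside FedLearner; it assumes model.evaluate returned [loss, metric] and that a two-column output holds class probabilities, mirroring the hunk. pick_target_metric is an illustrative helper, not a FedLearner function.

# Sketch of the metric selection used in the patched train_from_keras_model.
import numpy as np
from sklearn.metrics import roc_auc_score

def pick_target_metric(result, output, y_test):
    """Choose the headline metric the way the patched fedavg.py does."""
    if output.ndim == 1:               # single-column output -> report the loss
        return float(result[0])
    if output.shape[1] == 2:           # 2-way softmax -> AUC on the positive column
        return float(roc_auc_score(y_test, output[:, 1]))
    return float(result[1])            # otherwise the compiled metric (e.g. accuracy)

# Example with a 2-class softmax output
y_test = np.array([0, 1, 1, 0])
output = np.array([[0.9, 0.1], [0.2, 0.8], [0.4, 0.6], [0.7, 0.3]])
print(pick_target_metric([0.35, 0.75], output, y_test))  # AUC == 1.0
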
diff --git a/fedlearner/fedavg/master.py b/fedlearner/fedavg/master.py
index 48d44cf3..dabf03db 100644
--- a/fedlearner/fedavg/master.py
+++ b/fedlearner/fedavg/master.py
@@ -6,6 +6,7 @@ from concurrent.futures import ThreadPoolExecutor
import grpc
import numpy as np
+import flbenchmark.logging
import fedlearner.common.fl_logging as logging
import fedlearner.common.grpc_utils as grpc_utils
from fedlearner.proxy.channel import make_insecure_channel, ChannelType
@@ -31,16 +32,27 @@ class _Master:
def on_train_begin(self):
self._initialize()
+ self.logger.training_start()
def on_train_batch_begin(self):
self._step += 1
+ if self._steps_per_sync == 1 or self._step % self._steps_per_sync == 1:
+ self.logger.training_round_start()
+ self.logger.computation_start()
return self._step
def on_train_batch_end(self):
+ self.logger.computation_end()
if self._step % self._steps_per_sync == 0:
self._sync()
+ self.logger.training_round_end()
def on_train_end(self):
+ try:
+ self.logger.training_round_end()
+ except:
+ pass
+ self.logger.training_end()
self._save_model()
self._done()
@@ -150,6 +162,7 @@ class LeaderMaster(_Master):
self._fl_name = fl_name
self._fl_cluster_spec = fl_cluster_spec
self._leader = self._fl_cluster_spec.leader
+ self.logger = flbenchmark.logging.BasicLogger(id=0, agent_type='client')
self._follower_mapping = dict()
for f in self._fl_cluster_spec.followers:
@@ -175,7 +188,9 @@ class LeaderMaster(_Master):
self._grpc_server = grpc.server(
ThreadPoolExecutor(
max_workers=8,
- thread_name_prefix="LeaderMasterGrpcServerThreadPoolExecutor"))
+ thread_name_prefix="LeaderMasterGrpcServerThreadPoolExecutor"),
+ options=[('grpc.max_send_message_length', 512 * 1024 * 1024), ('grpc.max_receive_message_length', 512 * 1024 * 1024)]
+ )
add_TrainingServiceServicer_to_server(_TrainingServiceServicer(self),
self._grpc_server)
self._grpc_server.add_insecure_port(address)
@@ -265,7 +280,8 @@ class LeaderMaster(_Master):
" train_end followers: %s", pushed, unpush, train_end)
self._cv.wait(1)
- self._aggregate_weights()
+ with self.logger.computation() as c:
+ self._aggregate_weights()
if not is_train_end:
break
@@ -413,10 +429,11 @@ class FollowerMaster(_Master):
self._fl_name = fl_name
self._fl_cluster_spec = fl_cluster_spec
self._leader = self._fl_cluster_spec.leader
+ self.logger = flbenchmark.logging.BasicLogger(id=str(int(fl_name.split('_')[1])+1), agent_type='client')
def start(self):
self._grpc_channel = make_insecure_channel(
- self._leader.address, ChannelType.REMOTE)
+ self._leader.address, ChannelType.REMOTE, options=[('grpc.max_send_message_length', 512 * 1024 * 1024), ('grpc.max_receive_message_length', 512 * 1024 * 1024)])
self._grpc_client = TrainingServiceStub(self._grpc_channel)
def wait(self):
@@ -445,8 +462,9 @@ class FollowerMaster(_Master):
version=version,
is_last_pull=is_last_pull)
while True:
- resp = grpc_utils.call_with_retry(
- lambda: self._grpc_client.Pull(req))
+ self.logger.communication_start(target_id=-1)
+ resp = grpc_utils.call_with_retry(lambda: self._grpc_client.Pull(req))
+ self.logger.communication_end(metrics={'byte': req.ByteSize()+resp.ByteSize()})
if resp.status.code == Status.Code.OK:
break
if resp.status.code == Status.Code.NOT_READY:
@@ -468,7 +486,9 @@ class FollowerMaster(_Master):
weights=_weight_mapping_to_proto_weights(weight_mapping),
version=version,
is_train_end=is_train_end)
+ self.logger.communication_start(target_id=-1)
resp = grpc_utils.call_with_retry(lambda: self._grpc_client.Push(req))
+ self.logger.communication_end(metrics={'byte': req.ByteSize()+resp.ByteSize()})
if resp.status.code != Status.Code.OK:
raise RuntimeError(
"push weights error, code: {}, message: {}".format(
diff --git a/fedlearner/model/tree/tree.py b/fedlearner/model/tree/tree.py
index 20cd3263..d4779384 100644
--- a/fedlearner/model/tree/tree.py
+++ b/fedlearner/model/tree/tree.py
@@ -33,7 +33,7 @@ from fedlearner.model.crypto import paillier, fixed_point_number
from fedlearner.common import tree_model_pb2 as tree_pb2
from fedlearner.common import common_pb2
from fedlearner.common.metrics import emit_store
-
+import flbenchmark.logging
BST_TYPE = np.float32
PRECISION = 1e38
@@ -1098,6 +1098,8 @@ class BoostingTreeEnsamble(object):
if self._role == 'leader' and self._enable_packing:
self._packer = GradHessPacker(self._public_key, PRECISION, EXPONENT)
+ self.logger = flbenchmark.logging.Logger(id = 0 if self._role == 'leader' else 1, agent_type='client')
+
@property
def loss(self):
return self._loss
@@ -1572,11 +1574,14 @@ class BoostingTreeEnsamble(object):
sum_prediction = np.zeros(num_examples, dtype=BST_TYPE)
# start iterations
+ self.logger.training_start()
while len(self._trees) < self._max_iters:
begin_time = time.time()
num_iter = len(self._trees)
+ self.logger.training_round_start()
# grow tree
+ self.logger.computation_start()
if self._bridge is None:
tree, raw_prediction = self._fit_one_round_local(
sum_prediction, binned, labels)
@@ -1588,6 +1593,7 @@ class BoostingTreeEnsamble(object):
else:
tree = self._fit_one_round_follower(binned)
self._trees.append(tree)
+ self.logger.computation_end()
logging.info("Elapsed time for one round %s s",
str(time.time()-begin_time))
@@ -1622,15 +1628,21 @@ class BoostingTreeEnsamble(object):
self._write_training_log(
output_path, 'train_%d'%num_iter, metrics, pred)
self.iter_metrics_handler(metrics, mode='train')
+ self.logger.training_round_end()
# validation
- if validation_features is not None:
- val_pred = self.batch_predict(
- validation_features,
- example_ids=validation_example_ids,
- cat_features=validation_cat_features)
- metrics = self._compute_metrics(val_pred, validation_labels)
- self.iter_metrics_handler(metrics, mode='eval')
+ if validation_features is not None and num_iter == self._max_iters - 1:
+ self.logger.training_end()
+ with self.logger.model_evaluation() as e:
+ val_pred = self.batch_predict(
+ validation_features,
+ example_ids=validation_example_ids,
+ cat_features=validation_cat_features)
+ metrics = self._compute_metrics(val_pred, validation_labels)
+ self.iter_metrics_handler(metrics, mode='eval')
+ if self._role == 'leader':
+ e.report_metric('auc', float(metrics['auc']))
+ e.report_metric('accuracy', float(metrics['acc']))
logging.info(
"Validation metrics for iter %d: %s", num_iter, metrics)
@@ -1638,6 +1650,8 @@ class BoostingTreeEnsamble(object):
self._write_training_log(
output_path, 'val_%d'%num_iter, metrics, val_pred)
+ self.logger.end()
+
return self._loss.predict(sum_prediction)
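
The tree.py hunks wrap each boosting iteration in flbenchmark round markers and defer validation to the final iteration so it can be reported inside a model_evaluation block. The call ordering is the important part; below is a stripped-down sketch of that lifecycle with a no-op stand-in for the logger, since the flbenchmark.logging API is taken from this patch rather than its documentation.

# Sketch of the logging lifecycle added around the boosting loop in tree.py.
# DummyLogger only mirrors the call order; the patch uses
# flbenchmark.logging.Logger(id=..., agent_type='client').
import contextlib

class DummyLogger:
    def __getattr__(self, name):
        def _log(*args, **kwargs):
            print(f"[logger] {name}")
        return _log
    @contextlib.contextmanager
    def model_evaluation(self):
        print("[logger] model_evaluation start")
        yield self
        print("[logger] model_evaluation end")

logger, max_iters = DummyLogger(), 3
logger.training_start()
for num_iter in range(max_iters):
    logger.training_round_start()
    logger.computation_start()      # grow one tree here (leader/follower/local)
    logger.computation_end()
    logger.training_round_end()
    if num_iter == max_iters - 1:   # validate only after the final round
        logger.training_end()
        with logger.model_evaluation() as e:
            e.report_metric("auc", 0.0)  # placeholder value
logger.end()
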
diff --git a/fedlearner/trainer/bridge.py b/fedlearner/trainer/bridge.py
index 721298b6..9e2b83d8 100644
--- a/fedlearner/trainer/bridge.py
+++ b/fedlearner/trainer/bridge.py
@@ -20,7 +20,7 @@ import collections
import threading
import time
from distutils.util import strtobool
-
+import flbenchmark.logging
import tensorflow.compat.v1 as tf
from google.protobuf import any_pb2 as any_pb
from fedlearner.common import fl_logging
@@ -42,7 +42,9 @@ class Bridge(object):
def Transmit(self, request_iterator, context):
for request in request_iterator:
+ self._bridge.logger.communication_start(target_id=1-self._bridge.id)
yield self._bridge._transmit_handler(request)
+ self._bridge.logger.communication_end(metrics={'byte': request.ByteSize()})
def LoadDataBlock(self, request, context):
return self._bridge._data_block_handler(request)
@@ -56,6 +58,8 @@ class Bridge(object):
stream_queue_size=1024,
waiting_alert_timeout=10):
self._role = role
+ self.id = 0 if role == 'leader' else 1
+ self.logger = flbenchmark.logging.Logger(id=self.id, agent_type='client')
self._listen_address = "[::]:{}".format(listen_port)
self._remote_address = remote_address
if app_id is None:
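
In bridge.py the Transmit servicer is instrumented so that every streamed request counts as one communication event, with the payload size taken from protobuf's ByteSize(). The self-contained sketch below shows the same ByteSize-based accounting; it uses Struct from the standard protobuf package purely for illustration, since FedLearner's generated service messages are not needed to demonstrate the call.

# Sketch: measuring message payloads with protobuf ByteSize(), as the patched
# Transmit handler does when it logs metrics={'byte': request.ByteSize()}.
from google.protobuf.struct_pb2 import Struct

def measure(msg) -> int:
    """Serialized size in bytes, the quantity logged by the patch."""
    return msg.ByteSize()

req = Struct()
req.update({"weights_version": 3, "payload": "x" * 1024})
print(measure(req))  # roughly 1 KB plus field/tag overhead
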
diff --git a/fedlearner/trainer/estimator.py b/fedlearner/trainer/estimator.py
index b7e5c2ac..865345eb 100644
--- a/fedlearner/trainer/estimator.py
+++ b/fedlearner/trainer/estimator.py
@@ -16,7 +16,7 @@
# pylint: disable=protected-access
import time
-
+import flbenchmark.logging
import tensorflow.compat.v1 as tf
from tensorflow_estimator.python.estimator import model_fn as model_fn_lib
from fedlearner.common import fl_logging
@@ -163,6 +163,7 @@ class FLEstimator(object):
self._model_fn = model_fn
self._trainer_master = trainer_master
self._is_chief = is_chief
+ self.logger = flbenchmark.logging.Logger(id = 0 if role == 'leader' else 1, agent_type='client')
def _get_features_and_labels_from_input_fn(self, input_fn, mode):
dataset = input_fn(self._bridge, self._trainer_master)
@@ -198,18 +199,25 @@ class FLEstimator(object):
master=self._cluster_server.target,
config=self._cluster_server.cluster_config)
+ self.logger.training_start()
self._bridge.connect()
with tf.train.MonitoredSession(
session_creator=session_creator, hooks=hooks) as sess:
while not sess.should_stop():
start_time = time.time()
self._bridge.start()
+ self.logger.training_round_start()
+ self.logger.computation_start()
sess.run(spec.train_op, feed_dict={})
+ self.logger.computation_end()
+ self.logger.training_round_end()
self._bridge.commit()
use_time = time.time() - start_time
fl_logging.debug("after session run. time: %f sec",
use_time)
self._bridge.terminate()
+ self.logger.training_end()
+ self.logger.end()
return self
@@ -254,18 +262,30 @@ class FLEstimator(object):
master=self._cluster_server.target,
config=self._cluster_server.cluster_config)
# Evaluate over dataset
- self._bridge.connect()
- with tf.train.MonitoredSession(
- session_creator=session_creator, hooks=all_hooks) as sess:
- while not sess.should_stop():
- start_time = time.time()
- self._bridge.start()
- sess.run(eval_op)
- self._bridge.commit()
- use_time = time.time() - start_time
- fl_logging.debug("after session run. time: %f sec",
- use_time)
- self._bridge.terminate()
+ with self.logger.model_evaluation() as e:
+ self._bridge.connect()
+ with tf.train.MonitoredSession(
+ session_creator=session_creator, hooks=all_hooks) as sess:
+ while not sess.should_stop():
+ start_time = time.time()
+ self._bridge.start()
+ sess.run(eval_op)
+ self._bridge.commit()
+ use_time = time.time() - start_time
+ fl_logging.debug("after session run. time: %f sec",
+ use_time)
+ self._bridge.terminate()
+ if self._role == 'leader':
+ try:
+ e.report_metric('target_metric', float(final_ops_hook.final_ops_values['auc']))
+ except:
+ try:
+ e.report_metric('target_metric', float(final_ops_hook.final_ops_values['accuracy']))
+ except:
+ e.report_metric('target_metric', float(final_ops_hook.final_ops_values['loss']))
+ e.report_metric('loss', float(final_ops_hook.final_ops_values['loss']))
+ self.logger.end()
+
# Print result
fl_logging.info('Metrics for evaluate: %s',
_dict_to_str(final_ops_hook.final_ops_values))
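
The estimator.py evaluation hunk reports whichever of auc, accuracy, or loss is available, using nested try/except around dictionary lookups. A small sketch of an equivalent, flatter fallback follows; final_ops_values is assumed to be the plain dict of final metric values exposed by the hook, as used in the patch, and pick_reported_metrics is an illustrative helper rather than FedLearner code.

# Sketch: the auc -> accuracy -> loss fallback from the patched evaluate(),
# written as a loop over candidate keys instead of nested try/except.
def pick_reported_metrics(final_ops_values: dict) -> dict:
    """Return the 'target_metric' and 'loss' that would be reported to flbenchmark."""
    for key in ("auc", "accuracy", "loss"):
        if key in final_ops_values:
            target = float(final_ops_values[key])
            break
    else:
        raise KeyError("no auc/accuracy/loss in evaluation results")
    return {"target_metric": target, "loss": float(final_ops_values["loss"])}

print(pick_reported_metrics({"accuracy": 0.91, "loss": 0.27}))
# {'target_metric': 0.91, 'loss': 0.27}
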
diff --git a/requirements.txt b/requirements.txt
index 5fce3a05..491404d8 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -22,8 +22,8 @@ elasticsearch6
guppy3
tensorflow-io==0.8.1
psutil
-sqlalchemy==1.2.19
-mysqlclient
+#sqlalchemy==1.2.19
+#mysqlclient
leveldb
prison==0.1.3
matplotlib