Refactor fetch handler (#21264)
* Fix fetch handler problem and refactor.
When a user defines a FetchHandler class, he or she should initialize the
handler with a variable dict. The key of the variable dict is a user-defined
name, and the value is a Variable generated from the Python API.

For each fetch, the user should implement a handler function in which
fetched_result_dict is available, so the fetched values can be accessed
with the user-defined keys.
guru4elephant authored Nov 24, 2019
1 parent f1b09ba commit 691ced8
Showing 7 changed files with 112 additions and 75 deletions.
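
To illustrate the interface described in the commit message, here is a minimal sketch of a user-defined handler. The key name "auc" and the variable auc_var are placeholders assumed to come from the user's own program; they are not part of this commit.

import paddle.fluid as fluid

# Minimal sketch. `auc_var` is assumed to be a Variable the user already
# built with the Python API; the key "auc" is a user-defined name.
class MyFetchHandler(fluid.executor.FetchHandler):
    def handler(self, res_dict):
        # res_dict maps each user-defined key to a numpy array, or to None
        # when the variable is not yet available in the worker scope.
        if res_dict.get("auc") is not None:
            print("auc: {}".format(res_dict["auc"]))

handler = MyFetchHandler(var_dict={"auc": auc_var}, period_secs=30)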
3 changes: 3 additions & 0 deletions paddle/fluid/framework/executor.cc
@@ -185,6 +185,9 @@ void Executor::RunFromDataset(std::shared_ptr<TrainerBase> trainer) {
// training and finalize training
VLOG(3) << "Trainer starts to run";
trainer->Run();
}

void Executor::ReleaseTrainer(std::shared_ptr<TrainerBase> trainer) {
VLOG(3) << "Trainer going to finalize";
trainer->Finalize();
}
2 changes: 2 additions & 0 deletions paddle/fluid/framework/executor.h
@@ -120,6 +120,8 @@ class Executor {
Scope* scope, Dataset* dataset);
void RunFromDataset(std::shared_ptr<TrainerBase> trainer);

void ReleaseTrainer(std::shared_ptr<TrainerBase> trainer);

const platform::Place GetPlace() const { return place_; }

private:
6 changes: 2 additions & 4 deletions paddle/fluid/framework/multi_trainer.cc
@@ -77,14 +77,12 @@ void MultiTrainer::Run() {
workers_[thidx].get()));
}
}
}

void MultiTrainer::Finalize() {
for (auto& th : threads_) {
th.join();
}
root_scope_->DropKids();
}

void MultiTrainer::Finalize() { root_scope_->DropKids(); }

} // end namespace framework
} // end namespace paddle
3 changes: 3 additions & 0 deletions paddle/fluid/pybind/pybind.cc
@@ -1324,10 +1324,13 @@ All parameter, weight, gradient are variables in Paddle.
.def("close", &Executor::Close)
.def("run_from_dataset", &Executor::RunFromDataset,
py::call_guard<py::gil_scoped_release>())
.def("release_trainer", &Executor::ReleaseTrainer,
py::call_guard<py::gil_scoped_release>())
.def("init_for_dataset",
[](Executor &self, const ProgramDesc &prog,
const std::string &trainer_desc, Scope *scope,
Dataset *dataset) -> std::shared_ptr<TrainerBase> {
pybind11::gil_scoped_release release;
return self.InitForDataset(prog, trainer_desc, scope, dataset);
})
.def("run_from_dataset",
31 changes: 18 additions & 13 deletions python/paddle/fluid/executor.py
@@ -395,23 +395,28 @@ def _as_lodtensor(data, place):


class FetchHandler(object):
def __init__(self, fetch_target_names, period_secs=60, return_np=True):
self.fetch_target_names = fetch_target_names
def __init__(self, var_dict=None, period_secs=60):
assert var_dict != None
self.var_dict = var_dict
self.period_secs = period_secs
self.return_np = return_np

def handler(self, fetch_target_vars):
return
def handler(self, res_dict):
for key in res_dict:
if type(res_dict[key]) is np.ndarray:
sys.stdout.write("{}[0]: {} ".format(key, res_dict[key][0]))
sys.stdout.write("\n")

@staticmethod
def help():
print("""
class FetchHandlerExamlpe(FetchHandler):
def handler(self, fetch_target_vars):
b_auc = fetch_target_vars[0]
g_auc = fetch_target_vars[1]
print("b_auc: {}, g_auc: {} at time: {}".format(b_auc, g_auc, time.ctime()))
class FetchHandlerExample(FetchHandler):
def handler(self, res_dict):
print(res_dict["auc"])
print("auc: {}, {}".format(res_dict["auc"], time.ctime()))
auc = Variable()
var_dict = {"auc": auc}
handler = FetchHandlerExample(var_dict=var_dict)
""")


@@ -1010,13 +1015,13 @@ def _run_from_dataset(self,
scope0 = trainer_instance.get_worker_scope(0)
fetch_monitor = FetchHandlerMonitor(scope0, fetch_handler)
fetch_monitor.start()

self._default_executor.run_from_dataset(trainer_instance)

fetch_monitor.stop()
self._default_executor.release_trainer(trainer_instance)
else:

self._default_executor.run_from_dataset(trainer_instance)
self._default_executor.release_trainer(trainer_instance)

dataset._dynamic_adjust_after_train()
dataset._finish_to_run()
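
For context, a sketch of how such a handler might be driven during dataset training. This assumes the public train_from_dataset entry point forwards fetch_handler to _run_from_dataset as shown in the diff above, and that program, dataset, auc_var, and MyFetchHandler (from the earlier sketch) were already built by the user.

# Sketch only; program, dataset, auc_var and MyFetchHandler are assumed
# to exist. fetch_handler is expected to be passed through to
# _run_from_dataset, which starts a FetchHandlerMonitor on worker scope 0,
# runs training, stops the monitor, then releases the trainer.
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())

handler = MyFetchHandler(var_dict={"auc": auc_var}, period_secs=30)
exe.train_from_dataset(program=program, dataset=dataset, fetch_handler=handler)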
26 changes: 21 additions & 5 deletions python/paddle/fluid/tests/unittests/test_fetch_handler.py
@@ -17,6 +17,7 @@
import time
import unittest
import numpy as np
from paddle.fluid.framework import Program

import paddle.fluid.core as core
import paddle.fluid as fluid
@@ -29,20 +30,35 @@ def test_fetch_handler(self):

table = np.random.random((3, 10)).astype("float32")

prog = Program()
block = prog.current_block()
var_emb = block.create_var(name='emb', type=core.VarDesc.VarType.FP32)
var_emb3 = block.create_var(name='emb3', type=core.VarDesc.VarType.FP32)

class FH(fluid.executor.FetchHandler):
def handler(self, fetch_target_vars):
assert len(fetch_target_vars) == 1
def handler(self, fetch_dict):
assert len(fetch_dict) == 1

table_var = scope.var('emb').get_tensor()
table_var.set(table, place)

fh = FH(['emb'], period_secs=2, return_np=True)
fh = FH(var_dict={'emb': var_emb}, period_secs=2)
fm = fluid.trainer_factory.FetchHandlerMonitor(scope, fh)

fm.start()
time.sleep(10)
time.sleep(3)
fm.stop()

default_fh = fluid.executor.FetchHandler(
var_dict={'emb': var_emb,
'emb2': None,
'emb3': var_emb3},
period_secs=1)
default_fm = fluid.trainer_factory.FetchHandlerMonitor(scope,
default_fh)
default_fm.start()
time.sleep(5)
default_fm.stop()


if __name__ == "__main__":
unittest.main()
116 changes: 63 additions & 53 deletions python/paddle/fluid/trainer_factory.py
@@ -15,11 +15,15 @@

import threading
import time

import logging
import numpy as np

logging.basicConfig()

from .trainer_desc import MultiTrainer, DistMultiTrainer, PipelineTrainer
from .device_worker import Hogwild, DownpourSGD, Section
from .framework import Variable
from multiprocessing import Process, Manager

__all__ = ["TrainerFactory", "FetchHandler", "FetchHandlerMonitor"]

@@ -93,68 +97,74 @@ class FetchHandlerMonitor(object):
def __init__(self, scope, handler):
self.fetch_instance = handler
self.fetch_thread = threading.Thread(
target=self.handler_decorator,
args=(scope, self.fetch_instance.handler))
target=self.handler_launch_func, args=(scope, self.fetch_instance))
self.running_lock = threading.Lock()
self.running = False

def handler_launch_func(self, scope, handler):
fetch_instance = handler
period_secs = fetch_instance.period_secs
var_name_to_key = {}
for key in fetch_instance.var_dict:
if isinstance(fetch_instance.var_dict[key], Variable):
var_name_to_key[fetch_instance.var_dict[key].name] = key
else:
logging.warning("the value of {} is not a Variable".format(key))
var_name_to_key["None.var"] = key
elapsed_secs = 0
while True:
self.running_lock.acquire()
if self.running == False:
break
if elapsed_secs < period_secs:
# TODO(guru4elephant): needs customized condition
time.sleep(1)
elapsed_secs += 1
else:
elapsed_secs = 0
fetch_dict = {}
for key in var_name_to_key:
var = scope.find_var(key)
fetch_dict[key] = var
if var == None:
logging.warning("{} value currently not available".
format(var_name_to_key[key]))
res_dict = {}
for key in fetch_dict:
user_name = var_name_to_key[key]
if fetch_dict[key] == None:
res_dict[user_name] = None
continue
else:
res_dict[user_name] = fetch_dict[key].get_tensor()

lod = res_dict[user_name].lod()
if len(lod) > 0:
raise RuntimeError("Some of your fetched tensors \
hold LoD information. \
They can not be completely cast \
to Python ndarray. We can \
not return LoDTensor itself directly, \
please choose another targets")
if res_dict[user_name]._is_initialized():
res_dict[user_name] = np.array(res_dict[user_name])
else:
res_dict[user_name] = None
fetch_instance.handler(res_dict)
self.running_lock.release()

def start(self):
"""
start monitor,
it will start a monitor thread.
"""
self.running_lock.acquire()
self.running = True
self.running_lock.release()
self.fetch_thread.setDaemon(True)
self.fetch_thread.start()

def handler_decorator(self, fetch_scope, fetch_handler):
"""
decorator of handler,
Args:
fetch_scope(Scope): fetch scope
fetch_handler(Handler): fetch handler
"""
fetch_target_names = self.fetch_instance.fetch_target_names
period_secs = self.fetch_instance.period_secs

elapsed_secs = 0
while True:
while self.running and elapsed_secs >= period_secs:
elapsed_secs = 0

fetch_vars = [
fetch_scope.find_var(varname)
for varname in fetch_target_names
]

if None in fetch_vars:
continue

fetch_tensors = [var.get_tensor() for var in fetch_vars]

if self.fetch_instance.return_np:
fetch_nps = []

for tensor in fetch_tensors:
lod = tensor.lod()

if len(lod) > 0:
raise RuntimeError(
"Some of your fetched tensors hold LoD information. \
They can not be completely cast to Python ndarray. We can not \
return LoDTensor itself directly, please choose another targets"
)

if tensor._is_initialized():
fetch_nps.append(np.array(tensor))
else:
fetch_nps.append(None)

fetch_handler(fetch_nps)
else:
fetch_handler(fetch_tensors)
else:
time.sleep(1)
elapsed_secs += 1

def stop(self):
self.running_lock.acquire()
self.running = False
self.running_lock.release()
