From 3fa13060724a777e36985f04ca059806bc5ef390 Mon Sep 17 00:00:00 2001 From: "xinyang.gxy" Date: Fri, 25 Jan 2019 15:06:08 +0800 Subject: [PATCH 1/2] use worker_finish_op to end sync&async&barrier --- .../ps-plus/scheduler/asynchronizer.cc | 13 ++++------- .../ps-plus/scheduler/scheduler_impl.cc | 23 +++++++++++++++---- xdl/ps-plus/ps-plus/scheduler/synchronizer.cc | 22 ++++++++++++------ xdl/ps-plus/ps-plus/scheduler/synchronizer.h | 6 +++-- .../scheduler/test/asynchronizer_tests.cc | 14 +++++------ .../ops/ps_ops/ps_convert_ckpt_variable_op.cc | 2 +- .../core/ops/ps_ops/ps_synchronizer_ops.cc | 7 +++--- 7 files changed, 52 insertions(+), 35 deletions(-) diff --git a/xdl/ps-plus/ps-plus/scheduler/asynchronizer.cc b/xdl/ps-plus/ps-plus/scheduler/asynchronizer.cc index 31fca016..43a10d63 100644 --- a/xdl/ps-plus/ps-plus/scheduler/asynchronizer.cc +++ b/xdl/ps-plus/ps-plus/scheduler/asynchronizer.cc @@ -32,7 +32,6 @@ namespace { function MkCb(int id) { return [id](const Status& st) { - //LOG_WARN("Null callback for worker %d invoked", id); }; } @@ -71,7 +70,6 @@ void Asynchronizer::Enter(int id, function cb) { return; } if (removed_workers_.find(id) != removed_workers_.end()) { - //LOG_FATAL("Worker %d revived", id); abort(); } Context* ctx = contexts_.get() + id; @@ -92,12 +90,11 @@ void Asynchronizer::Enter(int id, function cb) { cb(Status::Ok()); } -void Asynchronizer::WorkerReportFinish(int id, std::function cb) { +Status Asynchronizer::WorkerReportFinish(int id) { if (id < 0 || id >= worker_count_) { - cb(Status::ArgumentError("Offset out of bound: min=0, max=" - + to_string(worker_count_) + ", actual=" - + to_string(id))); - return; + return Status::ArgumentError("Offset out of bound: min=0, max=" + + to_string(worker_count_) + ", actual=" + + to_string(id)); } removed_workers_.insert(id); Context* ctx = contexts_.get() + id; @@ -109,7 +106,7 @@ void Asynchronizer::WorkerReportFinish(int id, std::function(sync_.get()); - if (sync == nullptr) { - LOG(ERROR) << "Call async method in sync mode."; - cb(Status::ArgumentError("Call async method in sync mode.")); + if (sync_.get() != nullptr) { + Status st = sync_->WorkerReportFinish(id); + if (!st.IsOk()) { + cb(st); + return; + } + } + finished_workers_.insert(id); + auto iter = worker_barriers_.find(id); + if (iter != worker_barriers_.end()) { + worker_barriers_.erase(iter); + } + if (worker_barriers_.size() == worker_count_ - finished_workers_.size()) { + for (auto iter : worker_barriers_) { + (iter.second)(Status::Ok()); + } + worker_barriers_.clear(); } - sync->WorkerReportFinish(id, cb); + cb(Status::Ok()); } void SchedulerImpl::InternalWorkerBarrier(Version version, int id, int worker_count, function cb) { diff --git a/xdl/ps-plus/ps-plus/scheduler/synchronizer.cc b/xdl/ps-plus/ps-plus/scheduler/synchronizer.cc index 525c0f44..31c6f85d 100644 --- a/xdl/ps-plus/ps-plus/scheduler/synchronizer.cc +++ b/xdl/ps-plus/ps-plus/scheduler/synchronizer.cc @@ -16,10 +16,9 @@ limitations under the License. #include "synchronizer.h" #include "ps-plus/common/status.h" - +#include #include #include -#include using namespace std; using namespace std::chrono; @@ -33,7 +32,7 @@ namespace { function MkCb(int id) { return [id](int, const Status& st) { - LOG(WARNING) << "Null callback for worker" << id << " invoked"; + LOG(WARNING) << "Null callback for worker " << id << " invoked"; }; } @@ -77,15 +76,24 @@ void Synchronizer::Enter(int id, function cb) { waiting_list_.insert(ctx); } +Status Synchronizer::WorkerReportFinish(int id) { + if (working_list_.find(id) == working_list_.end()) { + return Status::Ok(); + } + working_list_.erase(id); + if (left_token_ == 0 && working_list_.empty()) { + UnlockNewToken(); + } + return Status::Ok(); +} + void Synchronizer::Leave(int id, int64_t token, function cb) { if (token != current_token_) { - LOG(WARNING) << "Receive token " << token << " from " << id << - "while current_token_ is " << current_token_; + LOG(WARNING) << "Receive token " << token << " from " << id << " while current_token_ is " << current_token_; cb(Status::Ok()); } if (working_list_.find(id) == working_list_.end()) { - LOG(FATAL) << "Worker " << id << " not granted token, but it call leave with token " << token << - ", current token is " << current_token_; + LOG(FATAL) << "Worker " << id << " not granted token, but it call leave with token " << token << ", current token is " << current_token_; abort(); } working_list_.erase(id); diff --git a/xdl/ps-plus/ps-plus/scheduler/synchronizer.h b/xdl/ps-plus/ps-plus/scheduler/synchronizer.h index 957df6d6..82f8ca59 100644 --- a/xdl/ps-plus/ps-plus/scheduler/synchronizer.h +++ b/xdl/ps-plus/ps-plus/scheduler/synchronizer.h @@ -34,6 +34,7 @@ class SyncMechanism { SyncMechanism() {} virtual ~SyncMechanism() {} virtual void Reset() = 0; + virtual Status WorkerReportFinish(int id) = 0; }; class Asynchronizer : public SyncMechanism { @@ -53,7 +54,7 @@ class Asynchronizer : public SyncMechanism { Asynchronizer(int staleness, int worker_count); ~Asynchronizer(); void Enter(int id, std::function cb); - void WorkerReportFinish(int id, std::function cb); + Status WorkerReportFinish(int id); void Reset(); }; @@ -74,7 +75,8 @@ class Synchronizer : public SyncMechanism { Synchronizer(int worker_count); ~Synchronizer() {} void Enter(int id, std::function cb); - void Leave(int id, int64_t token, std::function cb); + void Leave(int id, int64_t token, std::function cb); + Status WorkerReportFinish(int id); void Reset(); }; diff --git a/xdl/ps-plus/ps-plus/scheduler/test/asynchronizer_tests.cc b/xdl/ps-plus/ps-plus/scheduler/test/asynchronizer_tests.cc index ad7dd5bd..13deecfd 100644 --- a/xdl/ps-plus/ps-plus/scheduler/test/asynchronizer_tests.cc +++ b/xdl/ps-plus/ps-plus/scheduler/test/asynchronizer_tests.cc @@ -44,14 +44,12 @@ TEST(Asynchronizer, EnterAndFinish) { execute_log += "0-2"; }); EXPECT_EQ(execute_log, "0-01-00-12-0"); - async->WorkerReportFinish(1, [&execute_log](const Status& st) { - execute_log += "1-finish"; - }); - EXPECT_EQ(execute_log, "0-01-00-12-01-finish"); - async->WorkerReportFinish(2, [&execute_log](const Status& st) { - execute_log += "2-finish"; - }); - EXPECT_EQ(execute_log, "0-01-00-12-01-finish0-22-finish"); + Status st = async->WorkerReportFinish(1); + EXPECT_TRUE(st.IsOk()); + EXPECT_EQ(execute_log, "0-01-00-12-0"); + st = async->WorkerReportFinish(2); + EXPECT_TRUE(st.IsOk()); + EXPECT_EQ(execute_log, "0-01-00-12-00-2"); } TEST(Asynchronizer, Reset) { diff --git a/xdl/xdl/core/ops/ps_ops/ps_convert_ckpt_variable_op.cc b/xdl/xdl/core/ops/ps_ops/ps_convert_ckpt_variable_op.cc index 6c3ed5d8..3ce43bb5 100644 --- a/xdl/xdl/core/ops/ps_ops/ps_convert_ckpt_variable_op.cc +++ b/xdl/xdl/core/ops/ps_ops/ps_convert_ckpt_variable_op.cc @@ -107,7 +107,7 @@ class PsConvertCkptVariableOp : public xdl::OpKernelAsync { PS_CHECK_STATUS(ps::FileSystem::OpenWriteStreamAny(file_name, &output_stream)); for (size_t i = 0; i < info.parts.size(); i++) { ps::server::CheckpointUtils::VariableStruct vs; - printf("Start convert [%s], part[%d]\n", info.name.c_str(), i); + printf("Start convert [%s], part[%ld]\n", info.name.c_str(), i); PS_CHECK_STATUS(utils.LoadVariable(info.name, i, &vs)); if (!vs.initialized) { return ps::Status::DataLoss("Load variable " + info.name + " failed."); diff --git a/xdl/xdl/core/ops/ps_ops/ps_synchronizer_ops.cc b/xdl/xdl/core/ops/ps_ops/ps_synchronizer_ops.cc index 54ac64da..1fe8b1d2 100644 --- a/xdl/xdl/core/ops/ps_ops/ps_synchronizer_ops.cc +++ b/xdl/xdl/core/ops/ps_ops/ps_synchronizer_ops.cc @@ -90,7 +90,7 @@ class PsSynchronizeLeaveOp: public xdl::OpKernelAsync { } }; -class PsSemiSynchronizeLeaveOp: public xdl::OpKernelAsync { +class WorkerReportFinishOp: public xdl::OpKernelAsync { public: Status Init(OpKernelConstruction* ctx) override { return Status::Ok(); @@ -113,7 +113,6 @@ class WorkerBarrierOp: public xdl::OpKernelAsync { Status Init(OpKernelConstruction* ctx) override { return Status::Ok(); } - void Compute(OpKernelContext* ctx, Callback done) override { ps::client::BaseClient* client; XDL_CHECK_STATUS_ASYNC(GetClient(&client), done); @@ -139,7 +138,7 @@ XDL_DEFINE_OP(PsSynchronizeEnterOp) XDL_DEFINE_OP(PsSynchronizeLeaveOp) .Input("id", DataType::kInt32); -XDL_DEFINE_OP(PsSemiSynchronizeLeaveOp) +XDL_DEFINE_OP(WorkerReportFinishOp) .Input("id", DataType::kInt32); XDL_DEFINE_OP(WorkerBarrierOp) @@ -149,7 +148,7 @@ XDL_DEFINE_OP(WorkerBarrierOp) XDL_REGISTER_KERNEL(PsAsynchronizeEnterOp, PsAsynchronizeEnterOp).Device("CPU"); XDL_REGISTER_KERNEL(PsSynchronizeEnterOp, PsSynchronizeEnterOp).Device("CPU"); XDL_REGISTER_KERNEL(PsSynchronizeLeaveOp, PsSynchronizeLeaveOp).Device("CPU"); -XDL_REGISTER_KERNEL(PsSemiSynchronizeLeaveOp, PsSemiSynchronizeLeaveOp).Device("CPU"); +XDL_REGISTER_KERNEL(WorkerReportFinishOp, WorkerReportFinishOp).Device("CPU"); XDL_REGISTER_KERNEL(WorkerBarrierOp, WorkerBarrierOp).Device("CPU"); } From bc0c3a32c5cd83b71a80f7697a23a317889a5b20 Mon Sep 17 00:00:00 2001 From: "xinyang.gxy" Date: Fri, 25 Jan 2019 20:06:09 +0800 Subject: [PATCH 2/2] fix bug of scalar_integer_logger not deal with NOT_ADD_ID --- xdl/ps-plus/ps-plus/server/udf/scalar_integer_logger.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/xdl/ps-plus/ps-plus/server/udf/scalar_integer_logger.cc b/xdl/ps-plus/ps-plus/server/udf/scalar_integer_logger.cc index d57daed7..04dc5542 100644 --- a/xdl/ps-plus/ps-plus/server/udf/scalar_integer_logger.cc +++ b/xdl/ps-plus/ps-plus/server/udf/scalar_integer_logger.cc @@ -16,6 +16,7 @@ limitations under the License. #include "ps-plus/server/udf/simple_udf.h" #include "ps-plus/server/slice.h" #include "ps-plus/common/initializer/constant_initializer.h" +#include "ps-plus/common/hashmap.h" namespace ps { namespace server { @@ -32,6 +33,9 @@ class ScalarIntegerLogger : public SimpleUdf { int64_t* data = t->Raw(); int64_t val = pval; for (size_t slice : slices.slice_id) { + if ((int64_t)slice == ps::HashMap::NOT_ADD_ID) { + continue; + } data[slice] = val; } return Status::Ok();