diff --git a/tensorflow/core/data/service/BUILD b/tensorflow/core/data/service/BUILD
index ef0460df9c19ac..9ba42edecfcfca 100644
--- a/tensorflow/core/data/service/BUILD
+++ b/tensorflow/core/data/service/BUILD
@@ -276,6 +276,7 @@ cc_library(
     deps = [
         ":cache_utils",
         ":scaling_utils",
+        ":local_workers_utils",
        ":common_proto_cc",
        ":credentials_factory",
        ":data_service",
@@ -901,6 +902,21 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "local_workers_utils",
+    srcs = ["easl/local_workers_utils.cc"],
+    hdrs = [
+        "easl/local_workers_utils.h",
+    ],
+    deps = [
+        ":common_proto_cc",
+        ":dispatcher_state",
+        ":metadata_store",
+        "//tensorflow/core:lib",
+        "@com_google_absl//absl/strings",
+    ],
+)
+
 cc_library(
     name = "cache_model",
     srcs = ["easl/cache_model.cc"],
diff --git a/tensorflow/core/data/service/dispatcher.proto b/tensorflow/core/data/service/dispatcher.proto
index e6785a836f9a68..fe2357ac5d02c2 100644
--- a/tensorflow/core/data/service/dispatcher.proto
+++ b/tensorflow/core/data/service/dispatcher.proto
@@ -119,7 +119,7 @@ message JobKey {
   int64 job_name_index = 2;
 }
 
-// Next tag: 8
+// Next tag: 9
 message GetOrCreateJobRequest {
   reserved 3, 4;
   // The id of the dataset to create a job for.
@@ -134,6 +134,8 @@ message GetOrCreateJobRequest {
   oneof optional_num_consumers {
     int64 num_consumers = 7;
   }
+  // DSL: the client's local worker addresses, forwarded to the dispatcher's job creation logic.
+  repeated string local_workers = 8;
 }
 
 // Next tag: 2
@@ -188,7 +190,7 @@ message ClientHeartbeatRequest {
   double result_queue_size = 10;
 }
 
-// Next tag: 4
+// Next tag: 6
 message ClientHeartbeatResponse {
   // A list of all tasks that the client should read from.
   repeated TaskInfo task_info = 1;
@@ -198,6 +200,9 @@ message ClientHeartbeatResponse {
   }
   // Whether the job has finished.
   bool job_finished = 2;
+  // DSL: how many local and remote workers the client should target, based on the last epoch's metrics.
+  int64 target_local_workers = 4;
+  int64 target_remote_workers = 5;
 }
 
 // Next tag: 3
diff --git a/tensorflow/core/data/service/dispatcher_client.cc b/tensorflow/core/data/service/dispatcher_client.cc
index 84500c3439c467..79b03f4688b9fb 100644
--- a/tensorflow/core/data/service/dispatcher_client.cc
+++ b/tensorflow/core/data/service/dispatcher_client.cc
@@ -198,7 +198,8 @@ Status DataServiceDispatcherClient::RegisterDataset(
 Status DataServiceDispatcherClient::GetOrCreateJob(
     int64 dataset_id, ProcessingMode processing_mode,
     const absl::optional<JobKey>& job_key, absl::optional<int64> num_consumers,
-    int64& job_client_id) {
+    int64& job_client_id,
+    std::vector<std::string> local_workers) {
   TF_RETURN_IF_ERROR(EnsureInitialized());
   GetOrCreateJobRequest req;
   req.set_dataset_id(dataset_id);
@@ -209,6 +210,9 @@ Status DataServiceDispatcherClient::GetOrCreateJob(
   if (num_consumers.has_value()) {
     req.set_num_consumers(num_consumers.value());
   }
+  // DSL: the client sends the list of its local workers to the dispatcher.
+  *req.mutable_local_workers() = {local_workers.begin(), local_workers.end()};
+
   GetOrCreateJobResponse resp;
   grpc::ClientContext client_ctx;
   grpc::Status status = stub_->GetOrCreateJob(&client_ctx, req, &resp);
diff --git a/tensorflow/core/data/service/dispatcher_client.h b/tensorflow/core/data/service/dispatcher_client.h
index 65e5e1275b5b46..987f9b7febbff4 100644
--- a/tensorflow/core/data/service/dispatcher_client.h
+++ b/tensorflow/core/data/service/dispatcher_client.h
@@ -78,7 +78,8 @@ class DataServiceDispatcherClient : public DataServiceClientBase {
   Status GetOrCreateJob(int64 dataset_id, ProcessingMode processing_mode,
                         const absl::optional<JobKey>& job_key,
                         absl::optional<int64> num_consumers,
-                        int64& job_client_id);
+                        int64& job_client_id,
+                        std::vector<std::string> local_workers);
 
   // Releases a job client id, indicating that the id will no longer be used to
   // read from the job.
diff --git a/tensorflow/core/data/service/dispatcher_impl.cc b/tensorflow/core/data/service/dispatcher_impl.cc
index 69855b0a49bdbe..3efe3804f027a5 100644
--- a/tensorflow/core/data/service/dispatcher_impl.cc
+++ b/tensorflow/core/data/service/dispatcher_impl.cc
@@ -44,6 +44,7 @@ limitations under the License.
 #include "tensorflow/core/data/service/easl/cache_utils.h"
 #include "tensorflow/core/data/service/easl/scaling_utils.h"
 #include "tensorflow/core/data/service/easl/metadata_store.h"
+#include "tensorflow/core/data/service/easl/local_workers_utils.h"
 #include "tensorflow/core/data/standalone.h"
 #include "tensorflow/core/framework/dataset.h"
 #include "tensorflow/core/framework/graph.pb.h"
@@ -955,7 +956,6 @@ Status DataServiceDispatcherImpl::CreateJob(
   int64 job_id = state_.NextAvailableJobId();
 
   // EASL - Caching decision: should the job compute, write or read from cache?
-  int64 worker_count;
   std::string job_type;
   string job_name = "named_job_key.value().name";
   std::shared_ptr<const Dataset> dataset;
@@ -993,10 +993,37 @@ Status DataServiceDispatcherImpl::CreateJob(
   std::shared_ptr<JobMetrics> job_metrics;
   s = metadata_store_.GetJobMetrics(job_id, job_metrics);
-  worker_count = job_metrics->target_worker_count_;
-  if (config_.scaling_policy() == 2) {
-    worker_count = 100;
+  // EASL - Scaling decision: how many workers (remote/local) should be assigned to the job?
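+  // Scaling policies implemented below:
+  //   1 - the paper's autoscaling estimate, with at most one local worker
+  //       reserved when ShouldUseLocalWorkers() reports that the pipeline
+  //       produces enough bytes per element to make network transfer the
+  //       bottleneck;
+  //   2 - use every available worker (all of the client's local workers plus
+  //       the remaining remote ones);
+  //   3 - grid search over (remote, local) worker combinations across epochs.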
+  int64 total_workers = state_.ListWorkers().size();
+  int64 suggested_worker_count = job_metrics->target_worker_count_;
+
+  int64 target_remote_workers = 0, target_local_workers = 0;
+  if (config_.scaling_policy() == 1) {  // Paper autoscaling, now split into remote and local worker targets.
+    VLOG(0) << "EASL - Scalability decision for dataset_key "
+            << compute_dataset_key << ": " << suggested_worker_count;
+
+    bool should_use_local_workers;  // Is throughput high enough that local workers would save network bandwidth?
+    TF_RETURN_IF_ERROR(service::easl::local_workers_utils::ShouldUseLocalWorkers(
+        config_, metadata_store_, compute_dataset_key, should_use_local_workers
+    ));
+
+    if (should_use_local_workers && request.local_workers().size() >= 1) {
+      target_remote_workers = suggested_worker_count - 1;
+      target_local_workers = 1;
+    } else {
+      target_remote_workers = suggested_worker_count;
+      target_local_workers = 0;
+    }
+  } else if (config_.scaling_policy() == 2) {  // Use all available workers.
+    target_remote_workers = total_workers - request.local_workers().size();
+    target_local_workers = request.local_workers().size();
+  } else if (config_.scaling_policy() == 3) {  // Grid search over local and remote workers.
+    TF_RETURN_IF_ERROR(service::easl::local_workers_utils::DecideTargetWorkersGridSearch(
+        config_, metadata_store_, compute_dataset_key,
+        total_workers - request.local_workers().size(), request.local_workers().size(),
+        target_remote_workers, target_local_workers
+    ));
   }
 
   if (job_type == "PUT" || job_type == "PUT_SOURCE") {
@@ -1004,17 +1031,17 @@ Status DataServiceDispatcherImpl::CreateJob(
     std::shared_ptr<JobMetrics> dataset_fingerprint_metrics;
     s = metadata_store_.GetJobMetricsByDatasetFingerprintAndName(
         dataset_fingerprint, job_name, dataset_fingerprint_metrics);
     if (s.ok()) {
-      worker_count = std::ceil(std::max(1.0,
+      suggested_worker_count = std::ceil(std::max(1.0,
           dataset_fingerprint_metrics->target_worker_count_ * 1.5));
     }
-    job_metrics->target_worker_count_ = worker_count;
+    job_metrics->target_worker_count_ = suggested_worker_count;
   }
 
   // EASL: Logging stuff
   if (kEnableEventLogging) {
-    last_scale_[job_name] = worker_count;
+    last_scale_[job_name] = suggested_worker_count;
     RecordEvent(dataset_fingerprint, dataset_id, job_name, job_id,
-                "starting_worker_count", std::to_string(worker_count));
+                "starting_worker_count", std::to_string(suggested_worker_count));
   }
 
   int64 num_split_providers = 0;
@@ -1031,7 +1058,10 @@ Status DataServiceDispatcherImpl::CreateJob(
   create_job->set_processing_mode(ProcessingModeDef(processing_mode));
   create_job->set_job_type(job_type);
   create_job->set_num_split_providers(num_split_providers);
-  create_job->set_target_worker_count(worker_count);
+  create_job->set_target_worker_count(suggested_worker_count);
+  create_job->set_target_local_workers(target_local_workers);
+  create_job->set_target_remote_workers(target_remote_workers);
+  *create_job->mutable_local_workers() = {request.local_workers().begin(), request.local_workers().end()};
   if (named_job_key.has_value()) {
     NamedJobKeyDef* key = create_job->mutable_named_job_key();
     key->set_name(named_job_key->name);
@@ -1090,7 +1120,7 @@ Status DataServiceDispatcherImpl::CreateTasksForJob(
     std::vector<std::shared_ptr<const Task>>& tasks)
     TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
   std::vector<std::shared_ptr<const Worker>> workers = state_.ReserveWorkers(
-      job->job_id, job->target_worker_count);
+      job->job_id, job->target_remote_workers, job->target_local_workers, job->local_workers);
   if (workers.size() < job->target_worker_count){
     VLOG(0) << "EASL - Not enough workers for job. Elasticity policy requires "
@@ -1384,6 +1414,8 @@ Status DataServiceDispatcherImpl::ClientHeartbeat(
     task_info->set_starting_round(task->starting_round);
   }
   response->set_job_finished(job->finished);
+  response->set_target_local_workers(job->target_local_workers);
+  response->set_target_remote_workers(job->target_remote_workers);
   VLOG(4) << "Found " << response->task_info_size()
           << " tasks for job client id " << request->job_client_id();
diff --git a/tensorflow/core/data/service/dispatcher_state.cc b/tensorflow/core/data/service/dispatcher_state.cc
index 7ceff89650832c..c623fb4930ed7e 100644
--- a/tensorflow/core/data/service/dispatcher_state.cc
+++ b/tensorflow/core/data/service/dispatcher_state.cc
@@ -130,12 +130,20 @@ void DispatcherState::CreateJob(const CreateJobUpdate& create_job) {
       ProcessingMode(create_job.processing_mode()),
       create_job.num_split_providers(), named_job_key, num_consumers,
-      create_job.job_type(), create_job.target_worker_count());
+      create_job.job_type(),
+      create_job.target_worker_count(),
+      create_job.target_remote_workers(),
+      create_job.target_local_workers());
   DCHECK(!jobs_.contains(job_id));
   jobs_[job_id] = job;
   tasks_by_job_[job_id] = TasksById();
   ending_tasks_by_job_[job_id] = TasksById();
+  for (auto worker : create_job.local_workers()) {
+    VLOG(1) << "EASL-DSL (DispatcherState::CreateJob): adding local worker to dispatcher's state: " << worker;
+    job->local_workers.insert(worker);
+  }
+
   if (named_job_key.has_value()) {
     DCHECK(!named_jobs_.contains(named_job_key.value()) ||
            named_jobs_[named_job_key.value()]->garbage_collected);
@@ -373,29 +381,60 @@ DispatcherState::ListAvailableWorkers() const {
   return workers;
 }
 
+// Reserves available workers for a particular job. If the requested number of
+// workers is lower than or equal to 0 (or exceeds availability), all currently
+// available workers are reserved.
 std::vector<std::shared_ptr<const DispatcherState::Worker>>
 DispatcherState::ReserveWorkers(
-    int64 job_id, int64 target_num_workers) {
-  // DCHECK(num_workers <= avail_workers_.size());
-  jobs_[job_id]->target_worker_count = target_num_workers;
+    int64 job_id,
+    int64 target_remote_workers,
+    int64 target_local_workers,
+    const absl::flat_hash_set<std::string> local_workers) {
   // If the number of required workers is below those available, we just assign
   // as many as there are available at this epoch's scheduling time.
-  int64 num_workers = target_num_workers <= 0
-      || target_num_workers > avail_workers_.size() ?
-      avail_workers_.size() : target_num_workers;
avail_workers_.size() - : target_num_workers; + int64 target_worker_count = target_remote_workers + target_local_workers; + if(target_worker_count <= 0 || target_worker_count > avail_workers_.size()) { + target_remote_workers = avail_workers_.size(); + target_local_workers = avail_workers_.size(); + } + + jobs_[job_id]->target_worker_count = target_worker_count; + std::vector> workers; - workers.reserve(num_workers); - VLOG(0) << "(ReserveWorkers) User got " << num_workers << " workers from " - << "target " << target_num_workers << " workers"; + workers.reserve(avail_workers_.size()); + VLOG(0) << "EASL-DSL (DispatcherState::ReserveWorkers)" << "\n" + << "Available remote: " << avail_workers_.size() << "\n" + << "Available local: " << local_workers.size() << "\n" + << "Target remote: " << target_remote_workers << "\n" + << "Target local: " << target_local_workers << "\n"; + for (auto it = avail_workers_.begin(); it != avail_workers_.end(); ) { - num_workers--; + bool is_local = local_workers.count(it->first); + if (is_local) { + VLOG(0) << "EASL-DSL (DispatcherState::ReserveWorkers) found local worker " << it->first; + if (target_local_workers <= 0) { // No additional local workers needed + it++; + continue; + } else { + target_local_workers--; + } + } else { + VLOG(0) << "EASL-DSL (DispatcherState::ReserveWorkers) found remote worker " << it->first; + if (target_remote_workers <= 0) { // No additional remote workers needed + it++; + continue; + } else { + target_remote_workers--; + } + } + workers.push_back(it->second); VLOG(0) << "(ReserveWorkers) Assigning worker at address " << it->second->address << " to job " << job_id; - workers_by_job_[job_id][it->second->address] = it->second; + workers_by_job_[job_id].push_back(it->second); jobs_by_worker_[it->second->address][job_id] = jobs_[job_id]; avail_workers_.erase(it++); - if (num_workers == 0) + if (target_worker_count == 0) break; } VLOG(0) << "(ReserveWorkers) Number of workers for job " << job_id << " is: " diff --git a/tensorflow/core/data/service/dispatcher_state.h b/tensorflow/core/data/service/dispatcher_state.h index e7261277c8fa7e..0c4eff62e53792 100644 --- a/tensorflow/core/data/service/dispatcher_state.h +++ b/tensorflow/core/data/service/dispatcher_state.h @@ -137,14 +137,21 @@ class DispatcherState { int64 num_split_providers, absl::optional named_job_key, absl::optional num_consumers, const std::string& job_type, - int64 target_worker_count) + int64 target_worker_count, + int64 target_remote_workers, + int64 target_local_workers, + absl::flat_hash_set local_workers = {} + ) : job_id(job_id), dataset_id(dataset_id), processing_mode(processing_mode), named_job_key(named_job_key), num_consumers(num_consumers), job_type(job_type), - target_worker_count(target_worker_count){ + target_worker_count(target_worker_count), + target_remote_workers(target_remote_workers), + target_local_workers(target_local_workers), + local_workers(local_workers){ if (processing_mode == ProcessingMode::DISTRIBUTED_EPOCH) { distributed_epoch_state = DistributedEpochState(num_split_providers); } @@ -176,6 +183,10 @@ class DispatcherState { const std::string job_type; int64 target_worker_count; // Non-constant, can be dynamically adjusted. 
     int64 current_worker_count = 0;
+    // EASL - DSL
+    const int64 target_remote_workers;  // replaces the single worker count now that remote and local workers are distinguished
+    const int64 target_local_workers;   // replaces the single worker count now that remote and local workers are distinguished
+    absl::flat_hash_set<std::string> local_workers;  // addresses of the client's local workers
   };
 
   struct Task {
@@ -228,7 +239,9 @@ class DispatcherState {
   // is lower than or equal to 0, then the reserved number of workers is equal
   // to all the available workers.
   std::vector<std::shared_ptr<const Worker>> ReserveWorkers(int64 job_id,
-      int64 num_workers = 0);
+      int64 target_remote_workers = 0,
+      int64 target_local_workers = 0,
+      const absl::flat_hash_set<std::string> local_workers = {});
 
   // Returns the next available job id.
   int64 NextAvailableJobId() const;
diff --git a/tensorflow/core/data/service/easl/local_workers_utils.cc b/tensorflow/core/data/service/easl/local_workers_utils.cc
new file mode 100644
index 00000000000000..d7efe8eacfd586
--- /dev/null
+++ b/tensorflow/core/data/service/easl/local_workers_utils.cc
@@ -0,0 +1,125 @@
+//
+// Created by Muyu Li on 16.11.21.
+// Edited by the DSL group HS21 (Theodor Amariucai, Jiale Chen, Muyu Li) throughout November 2021 - February 2022
+//
+
+#include "local_workers_utils.h"
+
+#include <algorithm>  // std::min_element
+#include <ctime>      // std::time
+
+namespace tensorflow {
+namespace data {
+namespace service {
+namespace easl {
+namespace local_workers_utils {
+
+Status ShouldUseLocalWorkers(
+    const experimental::DispatcherConfig& dispatcher_config,
+    const ::tensorflow::data::easl::MetadataStore& metadata_store,
+    const std::string& dataset_key,
+    bool& should_use_local_workers) {
+  using NodeMetrics = ::tensorflow::data::easl::NodeMetrics;
+  using ModelMetrics = ::tensorflow::data::easl::ModelMetrics;
+
+  // Check if we have any metrics for this dataset.
+  std::shared_ptr<NodeMetrics> job_metrics;
+  Status s = metadata_store.GetLastNodeMetricsByDatasetFingerprint(
+      dataset_key, job_metrics);
+
+  // We do not yet have metrics for this dataset --> optimistically use local workers.
+  if (errors::IsNotFound(s)) {
+    VLOG(0) << "DSL (ShouldUseLocalWorkers) No metrics found for dataset, will use local workers (optimistic)!";
+    should_use_local_workers = true;
+    return Status::OK();
+  } else if (!s.ok()) {
+    VLOG(0) << "DSL (ShouldUseLocalWorkers) Another error has been thrown: " << s;
+    return s;
+  }
+
+  // Pipeline stats: last TF node metrics.
+  std::shared_ptr<NodeMetrics> last_tf_node_metrics;
+
+  s = metadata_store.GetLastNodeMetricsByDatasetKey(dataset_key, last_tf_node_metrics);
+  if (!s.ok()) {
+    VLOG(0) << "DSL (ShouldUseLocalWorkers) Failed to get the last TF node metrics";
+    return s;
+  }
+
+  int64_t total_bytes_produced = 0, total_num_elements = 0;
+  for (std::pair<std::string, std::shared_ptr<NodeMetrics::Metrics>> e :
+       last_tf_node_metrics->metrics_) {
+    std::shared_ptr<NodeMetrics::Metrics> node_metrics = e.second;
+    total_bytes_produced += node_metrics->bytes_produced();
+    total_num_elements += node_metrics->num_elements();
+  }
+
+  double avg_bytes_per_element = (double)total_bytes_produced / total_num_elements;
+  VLOG(0) << "DSL (ShouldUseLocalWorkers) Total bytes produced: " << total_bytes_produced << "\n"
+          << "Total num elements: " << total_num_elements << "\n"
+          << "Avg bytes produced per element: " << avg_bytes_per_element << "\n"
+          << "Decision threshold: " << dispatcher_config.avg_bytes_per_element_local_thres() << "\n";
+
+  if (avg_bytes_per_element > dispatcher_config.avg_bytes_per_element_local_thres()) {
+    should_use_local_workers = true;
+    VLOG(0) << "DSL (ShouldUseLocalWorkers) Using local workers! (because avg. bytes per element > threshold) \n";
+  }
+  else {
+    should_use_local_workers = false;
+    VLOG(0) << "DSL (ShouldUseLocalWorkers) NOT using local workers! (because avg. bytes per element < threshold) \n";
+  }
+
+  return Status::OK();
+}
+
+std::vector<std::time_t> records;
+
+void grid_search(int64 num_worker_remote_avail, int64 num_worker_local_avail,
+                 int64& num_worker_remote_target, int64& num_worker_local_target) {
+  std::vector<std::pair<int64, int64>> test_set;
+  for (int64 n_r = 0; n_r <= num_worker_remote_avail; n_r++) {
+    for (int64 n_l = 0; n_l <= num_worker_local_avail; n_l++) {
+      if (n_r + n_l <= 0) {
+        continue;
+      }
+      test_set.emplace_back(n_r, n_l);
+    }
+  }
+  std::vector<std::time_t> epoch_times;
+  for (size_t i = 1; i < records.size(); i++) {
+    epoch_times.push_back(records[i] - records[i - 1]);
+  }
+  size_t index;
+  if (epoch_times.size() < test_set.size()) {
+    index = epoch_times.size();
+  } else {
+    index = std::min_element(epoch_times.begin(), epoch_times.begin() + test_set.size()) - epoch_times.begin();
+  }
+  auto p = test_set[index];
+  num_worker_remote_target = p.first;
+  num_worker_local_target = p.second;
+}
+
+Status DecideTargetWorkersGridSearch(
+    const experimental::DispatcherConfig& dispatcher_config,
+    const ::tensorflow::data::easl::MetadataStore& metadata_store,
+    const std::string& dataset_key,
+    int64 num_worker_remote_avail,
+    int64 num_worker_local_avail,
+    int64& num_worker_remote_target,
+    int64& num_worker_local_target) {
+  std::time_t t = std::time(nullptr);
+  records.push_back(t);
+  grid_search(num_worker_remote_avail, num_worker_local_avail, num_worker_remote_target, num_worker_local_target);
+  VLOG(0) << "DSL (DecideTargetWorkersGridSearch)" << "\n"
+          << "Available remote: " << num_worker_remote_avail << "\n"
+          << "Available local: " << num_worker_local_avail << "\n"
+          << "Decided remote: " << num_worker_remote_target << "\n"
+          << "Decided local: " << num_worker_local_target << "\n";
+  return Status::OK();
+}
+
+} // namespace local_workers_utils
+} // namespace easl
+} // namespace service
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow/core/data/service/easl/local_workers_utils.h b/tensorflow/core/data/service/easl/local_workers_utils.h
new file mode 100644
index 00000000000000..06ed141718ad39
--- /dev/null
+++ b/tensorflow/core/data/service/easl/local_workers_utils.h
@@ -0,0 +1,44 @@
+//
+// Created by Muyu Li on 16.11.21.
+// Edited by the DSL group HS21 (Theodor Amariucai, Jiale Chen, Muyu Li) throughout November 2021 - February 2022
+//
+
+#ifndef ML_INPUT_DATA_SERVICE_LOCAL_WORKERS_UTILS_H
+#define ML_INPUT_DATA_SERVICE_LOCAL_WORKERS_UTILS_H
+
+#include <string>
+#include "tensorflow/core/platform/default/integral_types.h"
+#include "tensorflow/core/data/service/easl/metadata_store.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/data/service/common.pb.h"
+#include "tensorflow/core/data/service/dispatcher_state.h"
+#include "tensorflow/core/protobuf/service_config.pb.h"
+
+namespace tensorflow {
+namespace data {
+namespace service {
+namespace easl {
+namespace local_workers_utils {
+
+Status ShouldUseLocalWorkers(
+    const experimental::DispatcherConfig& dispatcher_config,
+    const ::tensorflow::data::easl::MetadataStore& metadata_store,
+    const std::string& dataset_key,
+    bool& should_use_local_workers);
+
+Status DecideTargetWorkersGridSearch(
+    const experimental::DispatcherConfig& dispatcher_config,
+    const ::tensorflow::data::easl::MetadataStore& metadata_store,
+    const std::string& dataset_key,
+    int64 num_worker_remote_avail,
+    int64 num_worker_local_avail,
+    int64& num_worker_remote_target,
+    int64& num_worker_local_target);
+
+} // namespace local_workers_utils
+} // namespace easl
+} // namespace service
+} // namespace data
+} // namespace tensorflow
+
+#endif  // ML_INPUT_DATA_SERVICE_LOCAL_WORKERS_UTILS_H
diff --git a/tensorflow/core/data/service/journal.proto b/tensorflow/core/data/service/journal.proto
index f1f064ab133179..e9c3e792460e95 100644
--- a/tensorflow/core/data/service/journal.proto
+++ b/tensorflow/core/data/service/journal.proto
@@ -46,7 +46,7 @@ message NamedJobKeyDef {
   int64 index = 2;
 }
 
-// Next tag: 10
+// Next tag: 18
 message CreateJobUpdate {
   reserved 5, 6;
   int64 job_id = 1;
@@ -63,6 +63,9 @@ message CreateJobUpdate {
   // EASL
   string job_type = 13;  // i.e read, write, cache
   int64 target_worker_count = 14;  // determined by elasticity policy
+  int64 target_local_workers = 15;  // decided between epochs
+  int64 target_remote_workers = 16;  // decided between epochs
+  repeated string local_workers = 17;
 }
 
 // Next tag: 5
diff --git a/tensorflow/core/data/service/worker_impl.cc b/tensorflow/core/data/service/worker_impl.cc
index 1ad837ddde45b4..13b0f5f7f92fa7 100644
--- a/tensorflow/core/data/service/worker_impl.cc
+++ b/tensorflow/core/data/service/worker_impl.cc
@@ -532,6 +532,20 @@ void LocalWorkers::Add(absl::string_view worker_address,
   (*local_workers_)[worker_address] = worker;
 }
 
+std::vector<std::string> LocalWorkers::GetList() {
+  tf_shared_lock l(mu_);
+  string local_workers_string = "";
+  std::vector<std::string> local_workers;
+  for (auto it = local_workers_->begin(); it != local_workers_->end(); ++it) {
+    local_workers.push_back(it->first);
+    local_workers_string += it->first + "; ";
+  }
+
+  VLOG(1) << "EASL-DSL: Check list of local workers: " << local_workers_string;
+
+  return local_workers;
+}
+
 std::shared_ptr<DataServiceWorkerImpl> LocalWorkers::Get(
     absl::string_view worker_address) {
   tf_shared_lock l(mu_);
diff --git a/tensorflow/core/data/service/worker_impl.h b/tensorflow/core/data/service/worker_impl.h
index 54adda4f986699..7dc27db3abe501 100644
--- a/tensorflow/core/data/service/worker_impl.h
+++ b/tensorflow/core/data/service/worker_impl.h
@@ -172,6 +172,9 @@ class LocalWorkers {
   // at the address.
   static void Remove(absl::string_view worker_address);
 
+  // EASL-DSL: Returns the addresses of all local workers created in this process.
+  static std::vector<std::string> GetList();
+
  private:
   using AddressToWorkerMap =
       absl::flat_hash_map<std::string, std::shared_ptr<DataServiceWorkerImpl>>;
diff --git a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc
index 4b26df5803dfab..2c1934a8bc447f 100644
--- a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc
@@ -21,6 +21,7 @@ limitations under the License.
 #include
 #include
 #include
+#include <fstream>
 
 #include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
@@ -302,7 +303,10 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
         [&]() {
           return dispatcher_->GetOrCreateJob(
               dataset()->dataset_id_, dataset()->processing_mode_, key,
-              dataset()->num_consumers_, job_client_id_);
+              dataset()->num_consumers_,
+              job_client_id_,
+              LocalWorkers::GetList()
+              );
         },
         /*description=*/
         strings::StrCat("get or create job with dispatcher at ",
@@ -1015,7 +1019,29 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
       }
       if (enqueue_result && !result.end_of_sequence) {
-        results_.push(std::move(result));
+        uint64 current_micro_timestamp = Env::Default()->NowMicros();
+        std::string data_source = task.info.worker_address();
+        bool if_local = false;
+        int result_size = result.element.size();
+        if (local_tasks_.contains(task.info.worker_address())) {
+          if_local = true;
+          local_results_buffer_.push(std::move(result));
+        } else {
+          results_.push(std::move(result));
+        }
+
+        const char* log_location = std::getenv("EASL_MUYU_WORKER_METRICS");
+        if (log_location) {
+          std::ofstream file(log_location, std::ios_base::app);
+
+          file << current_micro_timestamp << ","
+               << data_source << ","
+               << if_local << ","
+               << result_size << "\n";
+
+          file.flush();
+          file.clear();
+        }
       }
       get_next_cv_.notify_all();
     }
diff --git a/tensorflow/core/protobuf/service_config.proto b/tensorflow/core/protobuf/service_config.proto
index d6d9543f547ad0..0cd447b042db58 100644
--- a/tensorflow/core/protobuf/service_config.proto
+++ b/tensorflow/core/protobuf/service_config.proto
@@ -43,7 +43,8 @@ message DispatcherConfig {
   string log_dir = 13;
   // The interval at which the dispatcher should dump log files.
   int64 log_dumps_interval_ms = 14;
-
+  // EASL-DSL: average bytes-per-element threshold above which the dispatcher prefers local workers.
+  int64 avg_bytes_per_element_local_thres = 15;
 }
 
 // Configuration for a tf.data service WorkerServer.
diff --git a/tensorflow/python/data/experimental/service/server_lib.py b/tensorflow/python/data/experimental/service/server_lib.py
index 72540f076471d6..5e85b3141bd292 100644
--- a/tensorflow/python/data/experimental/service/server_lib.py
+++ b/tensorflow/python/data/experimental/service/server_lib.py
@@ -33,8 +33,10 @@ class DispatcherConfig(
     collections.namedtuple("DispatcherConfig", [
         "port", "protocol", "work_dir", "fault_tolerant_mode",
         "job_gc_check_interval_ms", "job_gc_timeout_ms", "cache_policy",
-        "cache_format", "cache_compression", "cache_ops_parallelism", "cache_path",
-        "scaling_policy", "log_dir", "log_dumps_interval_ms"
+        "cache_format", "cache_compression", "cache_ops_parallelism",
+        "cache_path", "scaling_policy", "log_dir",
+        "log_dumps_interval_ms",
+        "avg_bytes_per_element_local_workers_threshold"
     ])):
   """Configuration class for tf.data service dispatchers.
 
@@ -68,13 +70,16 @@ class DispatcherConfig(
     cache_policy: The cache policy applied by the dispatcher (e.g. no-cache, all-cache..).
     cache_format: The file format used for the cache of the service.
-    cache_compression: The compression schema (if any) to use for the caching ops
+    cache_compression: The compression scheme (if any) to use for the
+      caching ops.
     cache_ops_parallelism: The number of parallel threads the caching ops
      should use for reading/writing to cache.
     cache_path: The base path to use for storing the cache contents.
     scaling_policy: The scaling policy applied by the dispatcher.
-    log_dir: The directory to put the logs into. If set not empty (""), logs will be printed there.
-    log_dumps_interval_ms: How often the dispatcher should dump the logs into the log_dir.
+    log_dir: The directory to put the logs into. If set to a non-empty
+      string, logs will be written there.
+    log_dumps_interval_ms: How often the dispatcher should dump the
+      logs into the log_dir.
       Only valid if log_dir is not empty.
   """
 
@@ -92,7 +97,9 @@ def __new__(cls,
               cache_path="./outputs",
               scaling_policy=1,
               log_dir="",
-              log_dumps_interval_ms=None):
+              log_dumps_interval_ms=None,
+              avg_bytes_per_element_local_workers_threshold=1024 * 1024 * 100  # 100 MB
+              ):
     if protocol is None:
       protocol = _pywrap_utils.TF_DATA_DefaultProtocol()
     if job_gc_check_interval_ms is None:
@@ -100,7 +107,7 @@ def __new__(cls,
     if job_gc_timeout_ms is None:
       job_gc_timeout_ms = 5 * 60 * 1000  # 5 minutes.
     if log_dumps_interval_ms is None:
-      log_dumps_interval_ms = 100 # 100msec
+      log_dumps_interval_ms = 100  # 100msec
     """
    if cache_policy is None:
      cache_policy=1
@@ -115,8 +122,12 @@ def __new__(cls,
                 cls).__new__(cls, port, protocol, work_dir, fault_tolerant_mode,
                              job_gc_check_interval_ms, job_gc_timeout_ms,
                              cache_policy, cache_format,
-                             cache_compression, cache_ops_parallelism, cache_path, scaling_policy,
-                             log_dir, log_dumps_interval_ms)
+                             cache_compression, cache_ops_parallelism,
+                             cache_path,
+                             scaling_policy,
+                             log_dir,
+                             log_dumps_interval_ms,
+                             avg_bytes_per_element_local_workers_threshold)
 
 
 @tf_export("data.experimental.service.DispatchServer", v1=[])
@@ -188,7 +199,9 @@ def __init__(self, config=None, start=True):
         cache_path=config.cache_path,
         scaling_policy=config.scaling_policy,
         log_dir=config.log_dir,
-        log_dumps_interval_ms=config.log_dumps_interval_ms)
+        log_dumps_interval_ms=config.log_dumps_interval_ms,
+        avg_bytes_per_element_local_thres=
+            config.avg_bytes_per_element_local_workers_threshold)
     self._server = _pywrap_server_lib.TF_DATA_NewDispatchServer(
         config_proto.SerializeToString())
     if start:
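The two decision helpers introduced in easl/local_workers_utils.cc are easiest to follow in isolation. The sketch below mirrors the bytes-per-element threshold test from ShouldUseLocalWorkers and the epoch-timing grid search from grid_search(), but it is only an illustrative, self-contained C++ program: plain standard-library types and hypothetical metric values stand in for the DispatcherConfig, the MetadataStore, and the global records vector, and the function names are local to the sketch.

// Standalone sketch of the local-worker decision logic (assumed inputs).
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// Mirrors ShouldUseLocalWorkers: prefer local workers once the average element
// is large enough that shipping it over the network becomes the bottleneck.
bool ShouldUseLocalWorkers(int64_t total_bytes_produced,
                           int64_t total_num_elements,
                           double avg_bytes_per_element_threshold) {
  if (total_num_elements == 0) return true;  // no metrics yet: be optimistic
  double avg = static_cast<double>(total_bytes_produced) / total_num_elements;
  return avg > avg_bytes_per_element_threshold;
}

// Mirrors grid_search(): enumerate all (remote, local) combinations, time one
// combination per epoch, then keep the combination whose epoch ran fastest.
std::pair<int64_t, int64_t> GridSearch(
    int64_t remote_avail, int64_t local_avail,
    const std::vector<double>& epoch_times_seconds) {
  std::vector<std::pair<int64_t, int64_t>> test_set;
  for (int64_t r = 0; r <= remote_avail; ++r) {
    for (int64_t l = 0; l <= local_avail; ++l) {
      if (r + l > 0) test_set.emplace_back(r, l);
    }
  }
  size_t index;
  if (epoch_times_seconds.size() < test_set.size()) {
    index = epoch_times_seconds.size();  // probe the next untried combination
  } else {
    index = std::min_element(epoch_times_seconds.begin(),
                             epoch_times_seconds.begin() + test_set.size()) -
            epoch_times_seconds.begin();
  }
  return test_set[index];
}

int main() {
  // Hypothetical last-epoch metrics: 2 GB over 1000 elements, 100 MB threshold.
  std::cout << ShouldUseLocalWorkers(2'000'000'000, 1000, 100.0 * 1024 * 1024)
            << "\n";  // prints 0: ~2 MB per element is below the threshold

  // Hypothetical epoch durations with 2 remote and 1 local worker available.
  std::vector<double> epoch_times = {30.0, 21.5, 18.2, 25.0, 19.9};
  auto [r, l] = GridSearch(2, 1, epoch_times);
  std::cout << r << " remote, " << l << " local\n";  // picks the fastest combo
  return 0;
}

One design caveat worth noting: the grid search maps epoch index to a (remote, local) combination purely by position, so the mapping only stays meaningful while the number of available remote and local workers is constant across epochs.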