DSL based on EASL2.7 #2

Draft · wants to merge 8 commits into base: easl
16 changes: 16 additions & 0 deletions tensorflow/core/data/service/BUILD
@@ -276,6 +276,7 @@ cc_library(
deps = [
":cache_utils",
":scaling_utils",
":local_workers_utils",
":common_proto_cc",
":credentials_factory",
":data_service",
@@ -901,6 +902,21 @@ cc_library(
],
)

cc_library(
name = "local_workers_utils",
srcs = ["easl/local_workers_utils.cc"],
hdrs = [
"easl/local_workers_utils.h",
],
deps = [
":common_proto_cc",
":dispatcher_state",
":metadata_store",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
],
)

cc_library(
name = "cache_model",
srcs = ["easl/cache_model.cc"],
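Note: the new target's header, easl/local_workers_utils.h, is not shown anywhere in this diff. Judging only from the two call sites in dispatcher_impl.cc further down, its interface is presumably along these lines (a reconstruction for orientation, not the actual file):

// Hypothetical sketch of easl/local_workers_utils.h, reconstructed from the
// call sites in dispatcher_impl.cc below. Not part of this PR's diff.
#include "tensorflow/core/data/service/easl/metadata_store.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/service_config.pb.h"

namespace tensorflow {
namespace data {
namespace service {
namespace easl {
namespace local_workers_utils {

// Decides, from the previous epoch's metrics for this dataset fingerprint,
// whether the client has enough throughput headroom to route part of the
// job to workers on its own host and save network bandwidth.
Status ShouldUseLocalWorkers(const experimental::DispatcherConfig& config,
                             const MetadataStore& metadata_store,
                             int64 dataset_fingerprint,
                             bool& should_use_local_workers);

// Grid search (scaling_policy == 3): picks the next combination of remote
// and local worker counts to try; results come back through the references.
Status DecideTargetWorkersGridSearch(int64 available_remote_workers,
                                     int64 available_local_workers,
                                     int64& target_remote_workers,
                                     int64& target_local_workers);

}  // namespace local_workers_utils
}  // namespace easl
}  // namespace service
}  // namespace data
}  // namespace tensorflow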
8 changes: 6 additions & 2 deletions tensorflow/core/data/service/dispatcher.proto
@@ -119,7 +119,7 @@ message JobKey {
int64 job_name_index = 2;
}

// Next tag: 8
// Next tag: 9
message GetOrCreateJobRequest {
reserved 3, 4;
// The id of the dataset to create a job for.
@@ -134,6 +134,8 @@ message GetOrCreateJobRequest {
oneof optional_num_consumers {
int64 num_consumers = 7;
}
// DSL: the client's available local workers, passed to the dispatcher's job creation logic.
repeated string local_workers = 8;
}

// Next tag: 2
@@ -188,7 +190,7 @@ message ClientHeartbeatRequest {
double result_queue_size = 10;
}

// Next tag: 4
// Next tag: 5
message ClientHeartbeatResponse {
// A list of all tasks that the client should read from.
repeated TaskInfo task_info = 1;
@@ -198,6 +200,8 @@ message ClientHeartbeatResponse {
}
// Whether the job has finished.
bool job_finished = 2;
// DSL: whether the client should use local workers (based on last epoch's metrics).
bool target_local_workers = 4;
}

// Next tag: 3
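Taken together, the two new fields form a round trip: field 8 carries the client's local worker addresses to the dispatcher once at job creation, and field 4 carries the dispatcher's local-vs-remote decision back on every heartbeat. A minimal sketch of that flow from the client's side (illustrative values only, not code from this PR):

// Illustrative only: how the two new proto fields travel.
GetOrCreateJobRequest req;
req.set_dataset_id(42);                    // hypothetical dataset id
req.add_local_workers("localhost:40001");  // tag 8: a worker sharing the client's host

// ... later, on each heartbeat ...
ClientHeartbeatResponse resp;              // filled in by the dispatcher
if (resp.target_local_workers()) {         // tag 4: last epoch had spare throughput
  // Prefer reading from the co-located worker to save network bandwidth.
}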
6 changes: 5 additions & 1 deletion tensorflow/core/data/service/dispatcher_client.cc
@@ -198,7 +198,8 @@ Status DataServiceDispatcherClient::RegisterDataset(
Status DataServiceDispatcherClient::GetOrCreateJob(
int64 dataset_id, ProcessingMode processing_mode,
const absl::optional<JobKey>& job_key, absl::optional<int64> num_consumers,
int64& job_client_id) {
int64& job_client_id,
std::vector<std::string> local_workers) {
TF_RETURN_IF_ERROR(EnsureInitialized());
GetOrCreateJobRequest req;
req.set_dataset_id(dataset_id);
@@ -209,6 +210,9 @@ Status DataServiceDispatcherClient::GetOrCreateJob(
if (num_consumers.has_value()) {
req.set_num_consumers(num_consumers.value());
}
// DSL - client sends a list of its local workers to the dispatcher
*req.mutable_local_workers() = {local_workers.begin(), local_workers.end()};

GetOrCreateJobResponse resp;
grpc::ClientContext client_ctx;
grpc::Status status = stub_->GetOrCreateJob(&client_ctx, req, &resp);
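The caller that supplies local_workers (presumably the tf.data service dataset op) is not part of the hunks shown here. A hypothetical call site using the extended signature, with made-up values:

// Hypothetical call site; the real caller is not shown in this diff.
int64 dataset_id = 42;
absl::optional<JobKey> job_key;        // anonymous job
absl::optional<int64> num_consumers;   // no coordinated reads
int64 job_client_id;
std::vector<std::string> local_workers = {"localhost:40001", "localhost:40002"};
TF_RETURN_IF_ERROR(dispatcher_->GetOrCreateJob(
    dataset_id, ProcessingMode::PARALLEL_EPOCHS, job_key, num_consumers,
    job_client_id, local_workers));

Note that local_workers is taken by value; for a list this small the copy is harmless, though a const reference would be the more conventional choice.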
3 changes: 2 additions & 1 deletion tensorflow/core/data/service/dispatcher_client.h
@@ -78,7 +78,8 @@ class DataServiceDispatcherClient : public DataServiceClientBase {
Status GetOrCreateJob(int64 dataset_id, ProcessingMode processing_mode,
const absl::optional<JobKey>& job_key,
absl::optional<int64> num_consumers,
int64& job_client_id);
int64& job_client_id,
std::vector<std::string> local_workers);

// Releases a job client id, indicating that the id will no longer be used to
// read from the job.
78 changes: 61 additions & 17 deletions tensorflow/core/data/service/dispatcher_impl.cc
@@ -44,6 +44,7 @@ limitations under the License.
#include "tensorflow/core/data/service/easl/cache_utils.h"
#include "tensorflow/core/data/service/easl/scaling_utils.h"
#include "tensorflow/core/data/service/easl/metadata_store.h"
#include "tensorflow/core/data/service/easl/local_workers_utils.h"
#include "tensorflow/core/data/standalone.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/graph.pb.h"
@@ -473,7 +474,7 @@ Status DataServiceDispatcherImpl::WorkerHeartbeat(
state_.DatasetFromId(task_object->job->dataset_id, dataset);

string job_type;
string job_name = task_object->job->named_job_key.value().name;
Status s3 = metadata_store_.GetJobTypeByJobId(job_id, job_type);

if (s3.ok()) {
@@ -631,7 +632,7 @@ Status DataServiceDispatcherImpl::GetSplit(const GetSplitRequest* request,
if (kEnableEventLogging) {
std::shared_ptr<const Dataset> dataset;
state_.DatasetFromId(job->dataset_id, dataset);
string job_name = job->named_job_key.value().name;
RecordEvent(dataset->fingerprint, dataset->dataset_id, job_name, job_id,
"extended_epoch");
}
@@ -831,9 +832,14 @@ Status DataServiceDispatcherImpl::GetOrCreateJob(
GetOrCreateJobRequest::kNumConsumers) {
num_consumers = request->num_consumers();
}

absl::flat_hash_set<std::string> local_workers;
local_workers.insert(request->local_workers().cbegin(),
request->local_workers().cend());

TF_RETURN_IF_ERROR(CreateJob(request->dataset_id(),
requested_processing_mode, key, num_consumers,
job));
job, local_workers));
int64 job_client_id;
TF_RETURN_IF_ERROR(AcquireJobClientId(job, job_client_id));
response->set_job_client_id(job_client_id);
@@ -942,7 +948,9 @@ Status DataServiceDispatcherImpl::ValidateMatchingJob(
Status DataServiceDispatcherImpl::CreateJob(
int64 dataset_id, ProcessingMode processing_mode,
absl::optional<NamedJobKey> named_job_key,
absl::optional<int64> num_consumers, std::shared_ptr<const Job>& job)
absl::optional<int64> num_consumers,
std::shared_ptr<const Job>& job,
absl::flat_hash_set<std::string> local_workers)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
switch (processing_mode) {
case ProcessingMode::PARALLEL_EPOCHS:
@@ -952,12 +960,13 @@ Status DataServiceDispatcherImpl::CreateJob(
return errors::Internal(
absl::StrCat("ProcessingMode ", processing_mode, " not recognized"));
}
LOG(INFO) << "EASL DSL - DataServiceDispatcherImpl::CreateJob triggered";

int64 job_id = state_.NextAvailableJobId();

// EASL - Caching decision: should the job compute, write or read from cache?
int64 worker_count;
std::string job_type;
string job_name = named_job_key.value().name;
string job_name = "named_job_key.value().name";
std::shared_ptr<const Dataset> dataset;
TF_RETURN_IF_ERROR(state_.DatasetFromId(dataset_id, dataset));
int64 dataset_fingerprint = dataset->fingerprint;
Expand Down Expand Up @@ -993,28 +1002,59 @@ Status DataServiceDispatcherImpl::CreateJob(

std::shared_ptr<easl::JobMetrics> job_metrics;
s = metadata_store_.GetJobMetrics(job_id, job_metrics);
worker_count = job_metrics->target_worker_count_;

if (config_.scaling_policy() == 2) {
worker_count = 100;
// EASL - Scaling decision: how many remote/local workers should the job be assigned?
int64 total_workers = state_.ListWorkers().size();
int64 suggested_worker_count = job_metrics->target_worker_count_;

int64 target_remote_workers, target_local_workers;
if(config_.scaling_policy() == 1) { // Paper autoscaling, now also distinguishing local from remote workers
VLOG(0) << "EASL DSL - Scalability decision for dataset_key "
<< compute_dataset_key << ": " << suggested_worker_count;
LOG(INFO) << "EASL - Scalability decision for dataset_key "
<< compute_dataset_key << ": " << suggested_worker_count << " with fingerprint " << dataset_fingerprint;

bool should_use_local_workers; // Is throughput high enough to route work to local workers and save network bandwidth?
TF_RETURN_IF_ERROR(service::easl::local_workers_utils::ShouldUseLocalWorkers(
config_, metadata_store_, dataset_fingerprint, should_use_local_workers
));

if(should_use_local_workers && local_workers.size() >= 1) {
target_remote_workers = suggested_worker_count - 1;
target_local_workers = 1;
} else {
target_remote_workers = suggested_worker_count;
target_local_workers = 0;
}
} else if(config_.scaling_policy() == 2) { // Use all available workers
LOG(INFO) << "EASL DSL - Use all available workers";
target_remote_workers = total_workers - local_workers.size();
target_local_workers = local_workers.size();
} else if(config_.scaling_policy() == 3) { // Grid search over local and remote workers
LOG(INFO) << "EASL DSL - Grid search over local and remote workers"
<< compute_dataset_key << ": " << suggested_worker_count;
TF_RETURN_IF_ERROR(service::easl::local_workers_utils::DecideTargetWorkersGridSearch(
total_workers - local_workers.size(), local_workers.size(),
target_remote_workers, target_local_workers // passed by reference
));
}

if (job_type == "PUT" || job_type == "PUT_SOURCE") {
std::shared_ptr<easl::JobMetrics> dataset_fingerprint_metrics;
s = metadata_store_.GetJobMetricsByDatasetFingerprintAndName(
dataset_fingerprint, job_name, dataset_fingerprint_metrics);
if (s.ok()) {
worker_count = std::ceil(std::max(1.0,
suggested_worker_count = std::ceil(std::max(1.0,
dataset_fingerprint_metrics->target_worker_count_ * 1.5));
}
job_metrics->target_worker_count_ = worker_count;
job_metrics->target_worker_count_ = suggested_worker_count;
}

// EASL: Logging stuff
if (kEnableEventLogging) {
last_scale_[job_name] = worker_count;
last_scale_[job_name] = suggested_worker_count;
RecordEvent(dataset_fingerprint, dataset_id, job_name, job_id,
"starting_worker_count", std::to_string(worker_count));
"starting_worker_count", std::to_string(suggested_worker_count));
}

int64 num_split_providers = 0;
@@ -1031,7 +1071,10 @@
create_job->set_processing_mode(ProcessingModeDef(processing_mode));
create_job->set_job_type(job_type);
create_job->set_num_split_providers(num_split_providers);
create_job->set_target_worker_count(worker_count);
create_job->set_target_worker_count(suggested_worker_count);
create_job->set_target_local_workers(target_local_workers);
create_job->set_target_remote_workers(target_remote_workers);
*create_job->mutable_local_workers() = {local_workers.begin(), local_workers.end()};
if (named_job_key.has_value()) {
NamedJobKeyDef* key = create_job->mutable_named_job_key();
key->set_name(named_job_key->name);
@@ -1090,7 +1133,7 @@ Status DataServiceDispatcherImpl::CreateTasksForJob(
std::vector<std::shared_ptr<const Task>>& tasks)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
std::vector<std::shared_ptr<const Worker>> workers = state_.ReserveWorkers(
job->job_id, job->target_worker_count);
job->job_id, job->target_remote_workers, job->target_local_workers, job->local_workers);
if (workers.size() < job->target_worker_count){
VLOG(0)
<< "EASL - Not enough workers for job. Elasticity policy requires "
@@ -1239,7 +1282,7 @@ Status DataServiceDispatcherImpl::ClientHeartbeat(
// EASL: Update the client metrics
int64 job_target_worker_count;
string job_type;
string job_name = job->named_job_key.value().name;
string job_name = "job->named_job_key.value().name";
metadata_store_.GetJobTypeByJobId(job->job_id, job_type);
// FIXME: Note that we're only checking the first split provider
if (config_.scaling_policy() == 1 &&
@@ -1292,7 +1335,7 @@
std::shared_ptr<const Dataset> dataset;
TF_RETURN_IF_ERROR(state_.DatasetFromId(job->dataset_id, dataset));
RecordEvent(dataset->fingerprint, dataset->dataset_id,
job->named_job_key.value().name, job->job_id, scale_type,
std::to_string(target_worker_count));
}
}
@@ -1384,6 +1427,7 @@ Status DataServiceDispatcherImpl::ClientHeartbeat(
task_info->set_starting_round(task->starting_round);
}
response->set_job_finished(job->finished);
response->set_target_local_workers(job->target_local_workers);
VLOG(4) << "Found " << response->task_info_size()
<< " tasks for job client id " << request->job_client_id();

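local_workers_utils.cc itself never appears in these hunks, so the heuristic behind ShouldUseLocalWorkers is not visible. Only the signature is grounded in the call site above; one plausible shape, with the metadata accessor and the threshold invented purely for illustration:

// Speculative sketch: the accessor and the heuristic are assumptions, not PR code.
Status ShouldUseLocalWorkers(const experimental::DispatcherConfig& config,
                             const MetadataStore& metadata_store,
                             const int64 dataset_fingerprint,
                             bool& should_use_local_workers) {
  should_use_local_workers = false;
  std::shared_ptr<JobMetrics> last_metrics;
  Status s = metadata_store.GetJobMetricsByDatasetFingerprint(  // hypothetical accessor
      dataset_fingerprint, last_metrics);
  if (!s.ok()) {
    // First epoch for this pipeline: no metrics yet, stay remote-only.
    return Status::OK();
  }
  // Assumed heuristic: if last epoch's pipeline throughput comfortably exceeded
  // the rate at which the client consumed batches, one worker can be co-located
  // with the client without extending the epoch.
  should_use_local_workers =
      last_metrics->throughput_ >= 1.2 * last_metrics->consumption_rate_;  // invented fields
  return Status::OK();
}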
3 changes: 2 additions & 1 deletion tensorflow/core/data/service/dispatcher_impl.h
@@ -197,7 +197,8 @@ class DataServiceDispatcherImpl {
Status CreateJob(int64 dataset_id, ProcessingMode processing_mode,
absl::optional<DispatcherState::NamedJobKey> named_job_key,
absl::optional<int64> num_consumers,
std::shared_ptr<const DispatcherState::Job>& job)
std::shared_ptr<const DispatcherState::Job>& job,
absl::flat_hash_set<std::string> local_workers)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Creates tasks for the specified worker, one task for every unfinished job.
Status CreateTasksForWorker(const std::string& worker_address);
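For scaling_policy == 3, DecideTargetWorkersGridSearch presumably steps through (remote, local) combinations across successive job creations so that each epoch measures one grid point. A toy version of such a stepper; the iteration order is a guess, only the signature comes from the call site in CreateJob:

// Toy grid-search stepper; iteration order is an assumption, not PR code.
Status DecideTargetWorkersGridSearch(int64 available_remote_workers,
                                     int64 available_local_workers,
                                     int64& target_remote_workers,
                                     int64& target_local_workers) {
  static int64 step = 0;  // advances one grid point per job creation
  const int64 remote_points = available_remote_workers + 1;  // 0..available
  const int64 local_points = available_local_workers + 1;
  target_remote_workers = step % remote_points;
  target_local_workers = (step / remote_points) % local_points;
  step++;
  // A real implementation would presumably skip the useless (0, 0) point and
  // persist its position per dataset fingerprint rather than in a static.
  return Status::OK();
}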
63 changes: 51 additions & 12 deletions tensorflow/core/data/service/dispatcher_state.cc
@@ -130,12 +130,20 @@ void DispatcherState::CreateJob(const CreateJobUpdate& create_job) {
ProcessingMode(create_job.processing_mode()),
create_job.num_split_providers(),
named_job_key, num_consumers,
create_job.job_type(), create_job.target_worker_count());
create_job.job_type(),
create_job.target_worker_count(),
create_job.target_remote_workers(),
create_job.target_local_workers());
DCHECK(!jobs_.contains(job_id));
jobs_[job_id] = job;
tasks_by_job_[job_id] = TasksById();
ending_tasks_by_job_[job_id] = TasksById();

for (auto worker: create_job.local_workers()) {
VLOG(1) << "EASL-DSL (DispatcherState::CreateJob): adding local worker to dispatcher's state: " << worker;
job->local_workers.insert(worker);
}

if (named_job_key.has_value()) {
DCHECK(!named_jobs_.contains(named_job_key.value()) ||
named_jobs_[named_job_key.value()]->garbage_collected);
@@ -373,29 +381,60 @@
return workers;
}

// Reserves available workers for a particular job. If the combined target
// worker count is non-positive, all the available workers are reserved.
std::vector<std::shared_ptr<const DispatcherState::Worker>>
DispatcherState::ReserveWorkers(
int64 job_id, int64 target_num_workers) {
// DCHECK(num_workers <= avail_workers_.size());
jobs_[job_id]->target_worker_count = target_num_workers;
int64 job_id,
int64 target_remote_workers,
int64 target_local_workers,
const absl::flat_hash_set<std::string> local_workers) {
// If the number of required workers is non-positive or exceeds those available,
// we just assign as many as there are available at this epoch's scheduling time.
int64 num_workers = target_num_workers <= 0
|| target_num_workers > avail_workers_.size() ? avail_workers_.size()
: target_num_workers;
int64 target_worker_count = target_remote_workers + target_local_workers;
if(target_worker_count <= 0 || target_worker_count > avail_workers_.size()) {
target_remote_workers = avail_workers_.size();
target_local_workers = avail_workers_.size();
}

jobs_[job_id]->target_worker_count = target_worker_count;

std::vector<std::shared_ptr<const Worker>> workers;
workers.reserve(num_workers);
VLOG(0) << "(ReserveWorkers) User got " << num_workers << " workers from "
<< "target " << target_num_workers << " workers";
workers.reserve(avail_workers_.size());
VLOG(0) << "EASL-DSL (DispatcherState::ReserveWorkers)" << "\n"
<< "Available remote: " << avail_workers_.size() << "\n"
<< "Available local: " << local_workers.size() << "\n"
<< "Target remote: " << target_remote_workers << "\n"
<< "Target local: " << target_local_workers << "\n";

for (auto it = avail_workers_.begin(); it != avail_workers_.end(); ) {
num_workers--;
bool is_local = local_workers.count(it->first);
if (is_local) {
VLOG(0) << "EASL-DSL (DispatcherState::ReserveWorkers) found local worker " << it->first;
if (target_local_workers <= 0) { // No additional local workers needed
it++;
continue;
} else {
target_local_workers--;
}
} else {
VLOG(0) << "EASL-DSL (DispatcherState::ReserveWorkers) found remote worker " << it->first;
if (target_remote_workers <= 0) { // No additional remote workers needed
it++;
continue;
} else {
target_remote_workers--;
}
}

workers.push_back(it->second);
VLOG(0) << "(ReserveWorkers) Assigning worker at address "
<< it->second->address << " to job " << job_id;
workers_by_job_[job_id][it->second->address] = it->second;
jobs_by_worker_[it->second->address][job_id] = jobs_[job_id];
avail_workers_.erase(it++);
if (num_workers == 0)
if (target_local_workers + target_remote_workers == 0)
break;
}
VLOG(0) << "(ReserveWorkers) Number of workers for job " << job_id << " is: "
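To make the new reservation loop concrete, a hedged walk-through with made-up addresses (and assuming this particular iteration order over avail_workers_, which a flat hash map does not actually guarantee):

// Walk-through (hypothetical addresses). avail_workers_ holds
// {"hostA:1", "hostA:2", "client:1", "client:2"}; the job's local_workers set
// is {"client:1", "client:2"}; ReserveWorkers is called with
// target_remote_workers = 1 and target_local_workers = 1.
//
//   "hostA:1"  -> remote, quota 1 -> reserved; remote quota drops to 0.
//   "hostA:2"  -> remote, quota 0 -> skipped, stays available.
//   "client:1" -> local,  quota 1 -> reserved; local quota drops to 0.
//   Loop exits: target_local_workers + target_remote_workers == 0.
//
// Result: the job gets exactly one remote and one local worker, and
// "hostA:2" / "client:2" remain in avail_workers_ for other jobs.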