Skip to content

Commit

Permalink
Added DSL contributions
Browse files Browse the repository at this point in the history
  • Loading branch information
amariucaitheodor committed Feb 3, 2022
1 parent e5036a6 commit 335ce65
Show file tree
Hide file tree
Showing 15 changed files with 377 additions and 44 deletions.
16 changes: 16 additions & 0 deletions tensorflow/core/data/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ cc_library(
deps = [
":cache_utils",
":scaling_utils",
":local_workers_utils",
":common_proto_cc",
":credentials_factory",
":data_service",
Expand Down Expand Up @@ -901,6 +902,21 @@ cc_library(
],
)

cc_library(
name = "local_workers_utils",
srcs = ["easl/local_workers_utils.cc"],
hdrs = [
"easl/local_workers_utils.h",
],
deps = [
":common_proto_cc",
":dispatcher_state",
":metadata_store",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
],
)

cc_library(
name = "cache_model",
srcs = ["easl/cache_model.cc"],
Expand Down
8 changes: 6 additions & 2 deletions tensorflow/core/data/service/dispatcher.proto
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ message JobKey {
int64 job_name_index = 2;
}

// Next tag: 8
// Next tag: 9
message GetOrCreateJobRequest {
reserved 3, 4;
// The id of the dataset to create a job for.
Expand All @@ -134,6 +134,8 @@ message GetOrCreateJobRequest {
oneof optional_num_consumers {
int64 num_consumers = 7;
}
// DSL - to pass the array of available local workers to the dispatcher's job creation logic
repeated string local_workers = 8;
}

// Next tag: 2
Expand Down Expand Up @@ -188,7 +190,7 @@ message ClientHeartbeatRequest {
double result_queue_size = 10;
}

// Next tag: 4
// Next tag: 5
message ClientHeartbeatResponse {
// A list of all tasks that the client should read from.
repeated TaskInfo task_info = 1;
Expand All @@ -198,6 +200,8 @@ message ClientHeartbeatResponse {
}
// Whether the job has finished.
bool job_finished = 2;
// DSL: to check whether we should use local workers (based on last epoch's metrics)
bool target_local_workers = 4;
}

// Next tag: 3
Expand Down
6 changes: 5 additions & 1 deletion tensorflow/core/data/service/dispatcher_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,8 @@ Status DataServiceDispatcherClient::RegisterDataset(
Status DataServiceDispatcherClient::GetOrCreateJob(
int64 dataset_id, ProcessingMode processing_mode,
const absl::optional<JobKey>& job_key, absl::optional<int64> num_consumers,
int64& job_client_id) {
int64& job_client_id,
std::vector<std::string> local_workers) {
TF_RETURN_IF_ERROR(EnsureInitialized());
GetOrCreateJobRequest req;
req.set_dataset_id(dataset_id);
Expand All @@ -209,6 +210,9 @@ Status DataServiceDispatcherClient::GetOrCreateJob(
if (num_consumers.has_value()) {
req.set_num_consumers(num_consumers.value());
}
// DSL - client sends a list of its local workers to the dispatcher
*req.mutable_local_workers() = {local_workers.begin(), local_workers.end()};

GetOrCreateJobResponse resp;
grpc::ClientContext client_ctx;
grpc::Status status = stub_->GetOrCreateJob(&client_ctx, req, &resp);
Expand Down
3 changes: 2 additions & 1 deletion tensorflow/core/data/service/dispatcher_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ class DataServiceDispatcherClient : public DataServiceClientBase {
Status GetOrCreateJob(int64 dataset_id, ProcessingMode processing_mode,
const absl::optional<JobKey>& job_key,
absl::optional<int64> num_consumers,
int64& job_client_id);
int64& job_client_id,
std::vector<std::string> local_workers);

// Releases a job client id, indicating that the id will no longer be used to
// read from the job.
Expand Down
52 changes: 42 additions & 10 deletions tensorflow/core/data/service/dispatcher_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ limitations under the License.
#include "tensorflow/core/data/service/easl/cache_utils.h"
#include "tensorflow/core/data/service/easl/scaling_utils.h"
#include "tensorflow/core/data/service/easl/metadata_store.h"
#include "tensorflow/core/data/service/easl/local_workers_utils.h"
#include "tensorflow/core/data/standalone.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/graph.pb.h"
Expand Down Expand Up @@ -955,7 +956,6 @@ Status DataServiceDispatcherImpl::CreateJob(
int64 job_id = state_.NextAvailableJobId();

// EASL - Caching decision: should the job compute, write or read from cache?
int64 worker_count;
std::string job_type;
string job_name = "named_job_key.value().name";
std::shared_ptr<const Dataset> dataset;
Expand Down Expand Up @@ -993,28 +993,55 @@ Status DataServiceDispatcherImpl::CreateJob(

std::shared_ptr<easl::JobMetrics> job_metrics;
s = metadata_store_.GetJobMetrics(job_id, job_metrics);
worker_count = job_metrics->target_worker_count_;

if (config_.scaling_policy() == 2) {
worker_count = 100;
// EASL - Scaling decision: how many workers (remote/local) should the job assign?
int64 total_workers = state_.ListWorkers().size();
int64 suggested_worker_count = job_metrics->target_worker_count_;

int64 num_worker_remote_target, num_worker_local_target;
if(config_.scaling_policy() == 1) { // Paper autoscaling, except a discrimination between local and remote workers is now made
VLOG(0) << "EASL - Scalability decision for dataset_key "
<< compute_dataset_key << ": " << suggested_worker_count;

bool should_use_local_workers; // Do we have enough throughput to decide to use local workers to save network bandwidth?
TF_RETURN_IF_ERROR(service::easl::local_workers_utils::ShouldUseLocalWorkers(
config_, metadata_store_, compute_dataset_key, should_use_local_workers
));

if(should_use_local_workers && local_workers.size() >= 1) {
num_worker_remote_target = suggested_worker_count - 1;
num_worker_local_target = 1;
} else {
num_worker_remote_target = suggested_worker_count;
num_worker_local_target = 0;
}
} else if(config_.scaling_policy() == 2) { // Use all available workers
num_worker_remote_target = total_workers - local_workers.size();
num_worker_local_target = local_workers.size();
} else if(config_.scaling_policy() == 3) { // Grid search over local and remote workers
TF_RETURN_IF_ERROR(service::easl::local_workers_utils::DecideTargetWorkersGridSearch(
config_, metadata_store_, compute_dataset_key,
total_workers - local_workers.size(), local_workers.size(),
num_worker_remote_target, num_worker_local_target
));
}

if (job_type == "PUT" || job_type == "PUT_SOURCE") {
std::shared_ptr<easl::JobMetrics> dataset_fingerprint_metrics;
s = metadata_store_.GetJobMetricsByDatasetFingerprintAndName(
dataset_fingerprint, job_name, dataset_fingerprint_metrics);
if (s.ok()) {
worker_count = std::ceil(std::max(1.0,
suggested_worker_count = std::ceil(std::max(1.0,
dataset_fingerprint_metrics->target_worker_count_ * 1.5));
}
job_metrics->target_worker_count_ = worker_count;
job_metrics->target_worker_count_ = suggested_worker_count;
}

// EASL: Logging stuff
if (kEnableEventLogging) {
last_scale_[job_name] = worker_count;
last_scale_[job_name] = suggested_worker_count;
RecordEvent(dataset_fingerprint, dataset_id, job_name, job_id,
"starting_worker_count", std::to_string(worker_count));
"starting_worker_count", std::to_string(suggested_worker_count));
}

int64 num_split_providers = 0;
Expand All @@ -1031,7 +1058,10 @@ Status DataServiceDispatcherImpl::CreateJob(
create_job->set_processing_mode(ProcessingModeDef(processing_mode));
create_job->set_job_type(job_type);
create_job->set_num_split_providers(num_split_providers);
create_job->set_target_worker_count(worker_count);
create_job->set_target_worker_count(suggested_worker_count);
create_job->set_target_local_workers(target_local_workers);
create_job->set_target_remote_workers(target_remote_workers);
*create_job->mutable_local_workers() = {request.local_workers().begin(), request.local_workers().end()};
if (named_job_key.has_value()) {
NamedJobKeyDef* key = create_job->mutable_named_job_key();
key->set_name(named_job_key->name);
Expand Down Expand Up @@ -1090,7 +1120,7 @@ Status DataServiceDispatcherImpl::CreateTasksForJob(
std::vector<std::shared_ptr<const Task>>& tasks)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
std::vector<std::shared_ptr<const Worker>> workers = state_.ReserveWorkers(
job->job_id, job->target_worker_count);
job->job_id, job->target_worker_count, job->target_remote_workers, job->target_local_workers, job->local_workers);
if (workers.size() < job->target_worker_count){
VLOG(0)
<< "EASL - Not enough workers for job. Elasticity policy requires "
Expand Down Expand Up @@ -1384,6 +1414,8 @@ Status DataServiceDispatcherImpl::ClientHeartbeat(
task_info->set_starting_round(task->starting_round);
}
response->set_job_finished(job->finished);
response->set_target_local_workers(job->target_local_workers);
response->set_target_remote_workers(job->target_remote_workers);
VLOG(4) << "Found " << response->task_info_size()
<< " tasks for job client id " << request->job_client_id();

Expand Down
65 changes: 52 additions & 13 deletions tensorflow/core/data/service/dispatcher_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -130,12 +130,20 @@ void DispatcherState::CreateJob(const CreateJobUpdate& create_job) {
ProcessingMode(create_job.processing_mode()),
create_job.num_split_providers(),
named_job_key, num_consumers,
create_job.job_type(), create_job.target_worker_count());
create_job.job_type(),
create_job.target_worker_count(),
create_job.target_remote_workers(),
create_job.target_local_workers());
DCHECK(!jobs_.contains(job_id));
jobs_[job_id] = job;
tasks_by_job_[job_id] = TasksById();
ending_tasks_by_job_[job_id] = TasksById();

for (auto worker: create_job.local_workers()) {
VLOG(1) << "EASL-DSL (DispatcherState::CreateJob): adding local worker to dispatcher's state: " << worker;
job->local_workers.insert(worker);
}

if (named_job_key.has_value()) {
DCHECK(!named_jobs_.contains(named_job_key.value()) ||
named_jobs_[named_job_key.value()]->garbage_collected);
Expand Down Expand Up @@ -373,29 +381,60 @@ DispatcherState::ListAvailableWorkers() const {
return workers;
}

// Reserves a number of available workers for a particular job. If num_workers
// is lower than or equal to 0, then the reserved number of workers is equal
// to all the available workers.
std::vector<std::shared_ptr<const DispatcherState::Worker>>
DispatcherState::ReserveWorkers(
int64 job_id, int64 target_num_workers) {
// DCHECK(num_workers <= avail_workers_.size());
jobs_[job_id]->target_worker_count = target_num_workers;
int64 job_id,
int64 target_remote_workers,
int64 target_local_workers,
const absl::flat_hash_set<std::string> local_workers) {
// If the number of required workers is below those available, we just assign
// as many as there are available at this epoch's scheduling time.
int64 num_workers = target_num_workers <= 0
|| target_num_workers > avail_workers_.size() ? avail_workers_.size()
: target_num_workers;
int64 target_worker_count = target_remote_workers + target_local_workers;
if(target_worker_count <= 0 || target_worker_count > avail_workers_.size()) {
target_remote_workers = avail_workers_.size();
target_local_workers = avail_workers_.size();
}

jobs_[job_id]->target_worker_count = target_worker_count;

std::vector<std::shared_ptr<const Worker>> workers;
workers.reserve(num_workers);
VLOG(0) << "(ReserveWorkers) User got " << num_workers << " workers from "
<< "target " << target_num_workers << " workers";
workers.reserve(avail_workers_.size());
VLOG(0) << "EASL-DSL (DispatcherState::ReserveWorkers)" << "\n"
<< "Available remote: " << avail_workers_.size() << "\n"
<< "Available local: " << local_workers.size() << "\n"
<< "Target remote: " << target_remote_workers << "\n"
<< "Target local: " << target_local_workers << "\n";

for (auto it = avail_workers_.begin(); it != avail_workers_.end(); ) {
num_workers--;
bool is_local = local_workers.count(it->first);
if (is_local) {
VLOG(0) << "EASL-DSL (DispatcherState::ReserveWorkers) found local worker " << it->first;
if (target_local_workers <= 0) { // No additional local workers needed
it++;
continue;
} else {
target_local_workers--;
}
} else {
VLOG(0) << "EASL-DSL (DispatcherState::ReserveWorkers) found remote worker " << it->first;
if (target_remote_workers <= 0) { // No additional remote workers needed
it++;
continue;
} else {
target_remote_workers--;
}
}

workers.push_back(it->second);
VLOG(0) << "(ReserveWorkers) Assigning worker at address "
<< it->second->address << " to job " << job_id;
workers_by_job_[job_id][it->second->address] = it->second;
workers_by_job_[job_id].push_back(it->second);
jobs_by_worker_[it->second->address][job_id] = jobs_[job_id];
avail_workers_.erase(it++);
if (num_workers == 0)
if (target_worker_count == 0)
break;
}
VLOG(0) << "(ReserveWorkers) Number of workers for job " << job_id << " is: "
Expand Down
19 changes: 16 additions & 3 deletions tensorflow/core/data/service/dispatcher_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,21 @@ class DispatcherState {
int64 num_split_providers,
absl::optional<NamedJobKey> named_job_key,
absl::optional<int64> num_consumers, const std::string& job_type,
int64 target_worker_count)
int64 target_worker_count,
int64 target_remote_workers,
int64 target_local_workers,
absl::flat_hash_set<std::string> local_workers = {}
)
: job_id(job_id),
dataset_id(dataset_id),
processing_mode(processing_mode),
named_job_key(named_job_key),
num_consumers(num_consumers),
job_type(job_type),
target_worker_count(target_worker_count){
target_worker_count(target_worker_count),
target_remote_workers(target_remote_workers),
target_local_workers(target_local_workers),
local_workers(local_workers){
if (processing_mode == ProcessingMode::DISTRIBUTED_EPOCH) {
distributed_epoch_state = DistributedEpochState(num_split_providers);
}
Expand Down Expand Up @@ -176,6 +183,10 @@ class DispatcherState {
const std::string job_type;
int64 target_worker_count; // Non-constant, can be dynamically adjusted.
int64 current_worker_count = 0;
// EASL - DSL
const int64 target_remote_workers; // replaces worker_count as there is a distinction now
const int64 target_local_workers; // replaces worker_count as there is a distinction now
absl::flat_hash_set<std::string> local_workers; // list of local workers in the client
};

struct Task {
Expand Down Expand Up @@ -228,7 +239,9 @@ class DispatcherState {
// is lower than or equal to 0, then the reserved number of workers is equal
// to all the available workers.
std::vector<std::shared_ptr<const Worker>> ReserveWorkers(int64 job_id,
int64 num_workers = 0);
int64 target_remote_workers = 0,
int64 target_local_workers = 0,
const absl::flat_hash_set<std::string> local_workers = {});

// Returns the next available job id.
int64 NextAvailableJobId() const;
Expand Down
Loading

0 comments on commit 335ce65

Please sign in to comment.