Skip to content

Commit

Permalink
delete local tasks when local client is destroyed (reference: 6f2ce60 b…
Browse files Browse the repository at this point in the history
  • Loading branch information
jialchen committed Jan 28, 2022
1 parent c079023 commit e5036a6
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 7 deletions.
44 changes: 38 additions & 6 deletions tensorflow/core/data/service/worker_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ void DataServiceWorkerImpl::Stop() {

Status DataServiceWorkerImpl::GetElementResult(
const GetElementRequest* request, struct GetElementResult* result) {
Task* task;
Task* task = nullptr;
{
mutex_lock l(mu_);
if (cancelled_) {
Expand All @@ -175,17 +175,23 @@ Status DataServiceWorkerImpl::GetElementResult(
}
auto it = tasks_.find(request->task_id());
if (it == tasks_.end()) {
if (deleted_tasks_.contains(request->task_id())) {
return errors::FailedPrecondition(
"Got request for local task ", request->task_id(), " of worker ",
worker_address_, ", which has been deleted. You may be creating ",
"a duplicate job which has already finished. To fix this, make "
"sure to create your dataset only once, as opposed to re-creating "
"it repeatedly inside a loop.");
}
if (finished_tasks_.contains(request->task_id())) {
VLOG(3) << "Task is already finished";
result->end_of_sequence = true;
result->skip = false;
return Status::OK();
} else {
// Perhaps the workers hasn't gotten the task from the dispatcher yet.
// Return Unavailable so that the client knows to continue retrying.
VLOG(1) << "Task not found (probably not received from dispatcher yet";
return errors::Unavailable("Task ", request->task_id(), " not found");
}
// Perhaps the worker hasn't gotten the task from the dispatcher yet.
// Return Unavailable so that the client knows to continue retrying.
return errors::Unavailable("Task ", request->task_id(), " not found");
}
task = it->second.get();
TF_RETURN_IF_ERROR(EnsureTaskInitialized(*task));
Expand Down Expand Up @@ -470,6 +476,9 @@ Status DataServiceWorkerImpl::Heartbeat() TF_LOCKS_EXCLUDED(mu_) {
mutex_lock l(mu_);
for (const auto& task : new_tasks) {
VLOG(1) << "Received new task from dispatcher with id " << task.task_id();
if (deleted_tasks_.contains(task.task_id())) {
continue;
}
Status s = ProcessTaskInternal(task);
if (!s.ok() && !errors::IsAlreadyExists(s)) {
LOG(WARNING) << "Failed to start processing task " << task.task_id()
Expand All @@ -480,6 +489,9 @@ Status DataServiceWorkerImpl::Heartbeat() TF_LOCKS_EXCLUDED(mu_) {
for (int64 task_id : task_ids_to_delete) {
VLOG(3) << "Deleting task " << task_id
<< " at the request of the dispatcher";
if (!tasks_.contains(task_id)) {
continue;
}
tasks_to_delete.push_back(std::move(tasks_[task_id]));
tasks_.erase(task_id);
finished_tasks_.insert(task_id);
Expand All @@ -492,6 +504,26 @@ Status DataServiceWorkerImpl::Heartbeat() TF_LOCKS_EXCLUDED(mu_) {
return Status::OK();
}

void DataServiceWorkerImpl::DeleteLocalTask(const TaskInfo& task_info)
TF_LOCKS_EXCLUDED(mu_) {
std::shared_ptr<Task> task;
{
mutex_lock l(mu_);
auto it = tasks_.find(task_info.task_id());
if (it == tasks_.end() || !it->second) {
return;
}
task = std::move(it->second);
tasks_.erase(task_info.task_id());
pending_completed_tasks_.insert(task_info.task_id());
deleted_tasks_.insert(task_info.task_id());
}

VLOG(2) << "Delete local task " << task_info.task_id() << " from worker "
<< worker_address_ << " at the request of the client.";
StopTask(*task);
}

void LocalWorkers::Add(absl::string_view worker_address,
std::shared_ptr<DataServiceWorkerImpl> worker) {
DCHECK(worker != nullptr) << "Adding a nullptr local worker is disallowed.";
Expand Down
8 changes: 8 additions & 0 deletions tensorflow/core/data/service/worker_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ class DataServiceWorkerImpl {
Status GetElementResult(const GetElementRequest* request,
GetElementResult* result);

// Deletes the local task and iterator. Only called by local clients to delete
// unused task iterators assuming the task is not read by remote clients. This
// method is not visible to gRPC clients.
void DeleteLocalTask(const TaskInfo& task_info);

// See worker.proto for API documentation.

/// Dispatcher-facing API.
Expand Down Expand Up @@ -126,6 +131,9 @@ class DataServiceWorkerImpl {
absl::flat_hash_set<int64> finished_tasks_ TF_GUARDED_BY(mu_);
// Completed tasks which haven't yet been communicated to the dispatcher.
absl::flat_hash_set<int64> pending_completed_tasks_ TF_GUARDED_BY(mu_);
// Tasks deleted by the local client. If the client tries to read from them
// again, the worker will return a non-retriable FailedPrecondition error.
absl::flat_hash_set<int64_t> deleted_tasks_ TF_GUARDED_BY(mu_);
bool cancelled_ TF_GUARDED_BY(mu_) = false;
// Whether the worker has registered with the dispatcher yet.
bool registered_ TF_GUARDED_BY(mu_) = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
for (auto& worker_thread : worker_threads_) {
worker_thread.reset();
}

DeleteLocalWorkerTasks();
VLOG(1) << "Destroyed data service dataset iterator for job id "
<< job_client_id_;
}
Expand Down Expand Up @@ -550,6 +550,25 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
get_next_cv_.notify_all();
}

void DeleteLocalWorkerTasks() {
if (dataset()->target_workers_ != TargetWorkers::LOCAL) {
return;
}
std::vector<std::shared_ptr<Task>> tasks;
{
mutex_lock l(mu_);
tasks = tasks_;
}

for (const std::shared_ptr<Task>& task : tasks) {
std::shared_ptr<DataServiceWorkerImpl> worker =
LocalWorkers::Get(task->info.worker_address());
if (worker) {
worker->DeleteLocalTask(task->info);
}
}
}

// Periodically refresh the task list.
// Maintain one thread fetching elements for each task.
// TODO(aaudibert): Instead of polling, have dispatcher send updates when
Expand Down

0 comments on commit e5036a6

Please sign in to comment.