diff --git a/velox/exec/tests/utils/LocalExchangeSource.cpp b/velox/exec/tests/utils/LocalExchangeSource.cpp index a24da1405814..5c67848f5b9e 100644 --- a/velox/exec/tests/utils/LocalExchangeSource.cpp +++ b/velox/exec/tests/utils/LocalExchangeSource.cpp @@ -52,116 +52,97 @@ class LocalExchangeSource : public exec::ExchangeSource { VELOX_CHECK(requestPending_); auto requestedSequence = sequence_; auto self = shared_from_this(); - auto hasBeenCalled = std::make_shared(false); - static std::mutex resultCallbackMutex; // Since this lambda may outlive 'this', we need to capture a // shared_ptr to the current object (self). - auto resultCallback = - [self, requestedSequence, buffers, hasBeenCalled, this]( - std::vector> data, int64_t sequence) { - { - std::lock_guard l(resultCallbackMutex); - // This is called when data is found and when this times out. Only - // the first of the two runs the body of the function. - if (*hasBeenCalled) { - return; - } - *hasBeenCalled = true; - } - if (data.empty()) { - common::testutil::TestValue::adjust( - "facebook::velox::exec::test::LocalExchangeSource::timeout", - this); - VeloxPromise requestPromise; - { - std::lock_guard l(queue_->mutex()); - requestPending_ = false; - requestPromise = std::move(promise_); - } - if (!requestPromise.isFulfilled()) { - requestPromise.setValue(Response{0, false}); - } - return; - } - if (requestedSequence > sequence) { - VLOG(2) << "Receives earlier sequence than requested: task " - << taskId_ << ", destination " << destination_ - << ", requested " << sequence << ", received " - << requestedSequence; - int64_t nExtra = requestedSequence - sequence; - VELOX_CHECK(nExtra < data.size()); - data.erase(data.begin(), data.begin() + nExtra); - sequence = requestedSequence; + auto resultCallback = [self, requestedSequence, buffers, this]( + std::vector> data, + int64_t sequence) { + if (requestedSequence > sequence) { + VLOG(2) << "Receives earlier sequence than requested: task " << taskId_ + << ", destination " << destination_ << ", requested " + << sequence << ", received " << requestedSequence; + int64_t nExtra = requestedSequence - sequence; + VELOX_CHECK(nExtra < data.size()); + data.erase(data.begin(), data.begin() + nExtra); + sequence = requestedSequence; + } + std::vector> pages; + bool atEnd = false; + int64_t totalBytes = 0; + for (auto& inputPage : data) { + if (!inputPage) { + atEnd = true; + // Keep looping, there could be extra end markers. + continue; + } + totalBytes += inputPage->length(); + inputPage->unshare(); + pages.push_back(std::make_unique(std::move(inputPage))); + inputPage = nullptr; + } + numPages_ += pages.size(); + totalBytes_ += totalBytes; + + try { + common::testutil::TestValue::adjust( + "facebook::velox::exec::test::LocalExchangeSource", &numPages_); + } catch (const std::exception& e) { + queue_->setError(e.what()); + checkSetRequestPromise(); + return; + } + + int64_t ackSequence; + VeloxPromise requestPromise; + { + std::vector queuePromises; + { + std::lock_guard l(queue_->mutex()); + requestPending_ = false; + requestPromise = std::move(promise_); + for (auto& page : pages) { + queue_->enqueueLocked(std::move(page), queuePromises); } - std::vector> pages; - bool atEnd = false; - int64_t totalBytes = 0; - for (auto& inputPage : data) { - if (!inputPage) { - atEnd = true; - // Keep looping, there could be extra end markers. - continue; - } - totalBytes += inputPage->length(); - inputPage->unshare(); - pages.push_back( - std::make_unique(std::move(inputPage))); - inputPage = nullptr; - } - numPages_ += pages.size(); - totalBytes_ += totalBytes; - - try { - common::testutil::TestValue::adjust( - "facebook::velox::exec::test::LocalExchangeSource", &numPages_); - } catch (const std::exception& e) { - queue_->setError(e.what()); - checkSetRequestPromise(); - return; + if (atEnd) { + queue_->enqueueLocked(nullptr, queuePromises); + atEnd_ = true; } + ackSequence = sequence_ = sequence + pages.size(); + } + for (auto& promise : queuePromises) { + promise.setValue(); + } + } + // Outside of queue mutex. + if (atEnd_) { + buffers->deleteResults(taskId_, destination_); + } else { + buffers->acknowledge(taskId_, destination_, ackSequence); + } + + if (!requestPromise.isFulfilled()) { + requestPromise.setValue(Response{totalBytes, atEnd_}); + } + }; - int64_t ackSequence; + // Call the callback in any case after timeout. + auto& exec = folly::QueuedImmediateExecutor::instance(); + future = std::move(future).via(&exec).onTimeout( + std::chrono::seconds(maxWaitSeconds), [self, this] { + common::testutil::TestValue::adjust( + "facebook::velox::exec::test::LocalExchangeSource::timeout", + this); VeloxPromise requestPromise; { - std::vector queuePromises; - { - std::lock_guard l(queue_->mutex()); - requestPending_ = false; - requestPromise = std::move(promise_); - for (auto& page : pages) { - queue_->enqueueLocked(std::move(page), queuePromises); - } - if (atEnd) { - queue_->enqueueLocked(nullptr, queuePromises); - atEnd_ = true; - } - ackSequence = sequence_ = sequence + pages.size(); - } - for (auto& promise : queuePromises) { - promise.setValue(); - } - } - // Outside of queue mutex. - if (atEnd_) { - buffers->deleteResults(taskId_, destination_); - } else { - buffers->acknowledge(taskId_, destination_, ackSequence); + std::lock_guard l(queue_->mutex()); + requestPending_ = false; + requestPromise = std::move(promise_); } - + Response response = {0, false}; if (!requestPromise.isFulfilled()) { - requestPromise.setValue(Response{totalBytes, atEnd_}); + requestPromise.setValue(response); } - }; - - // Call the callback in any case after timeout. 'future' returned - // from this will be realized with no error but empty data. Also, - // the future is a SemiFuture, so setting a timeout on the future - // in this function is not possible. - auto& exec = folly::QueuedImmediateExecutor::instance(); - std::move(folly::futures::sleep(std::chrono::seconds(maxWaitSeconds))) - .via(&exec) - .thenValue([resultCallback, requestedSequence](auto /*ignore*/) { - resultCallback({}, requestedSequence); + return response; }); buffers->getData(