Add capability to process generic callbacks in progress thread (#65)
Implement and use generic callbacks to create endpoints, close endpoints and listeners, and probe tags. This ensures those operations never need to take the UCS spinlock or progress the worker in the application thread, both of which can be sources of deadlocks.

Authors:
  - Peter Andreas Entschev (https://github.com/pentschev)
  - Lawrence Mitchell (https://github.com/wence-)

Approvers:
  - Lawrence Mitchell (https://github.com/wence-)

URL: #65
pentschev authored Aug 3, 2023
1 parent 65deb67 commit 98b6b89
Showing 17 changed files with 485 additions and 76 deletions.
74 changes: 63 additions & 11 deletions cpp/include/ucxx/delayed_submission.h
@@ -58,9 +58,14 @@ class DelayedSubmission {

class DelayedSubmissionCollection {
private:
std::vector<DelayedSubmissionCallbackType>
_genericPre{}; ///< The collection of all known generic pre-progress operations.
std::vector<DelayedSubmissionCallbackType>
_genericPost{}; ///< The collection of all known generic post-progress operations.
std::vector<std::pair<std::shared_ptr<Request>, DelayedSubmissionCallbackType>>
_collection{}; ///< The collection of all known delayed submission operations.
_requests{}; ///< The collection of all known delayed request submission operations.
std::mutex _mutex{}; ///< Mutex to provide access to the collection.
bool _enableDelayedRequestSubmission{false};

public:
/**
@@ -69,36 +74,83 @@ class DelayedSubmissionCollection {
* Construct an empty collection of delayed submissions. Despite its name, a delayed
* submission registration may be processed right after registration, thus effectively
* making it an immediate submission.
*
* @param[in] enableDelayedRequestSubmission whether request submission should be
* enabled, if `false`, only generic
* callbacks are enabled.
*/
DelayedSubmissionCollection() = default;
explicit DelayedSubmissionCollection(bool enableDelayedRequestSubmission = false);

DelayedSubmissionCollection() = delete;
DelayedSubmissionCollection(const DelayedSubmissionCollection&) = delete;
DelayedSubmissionCollection& operator=(DelayedSubmissionCollection const&) = delete;
DelayedSubmissionCollection(DelayedSubmissionCollection&& o) = delete;
DelayedSubmissionCollection& operator=(DelayedSubmissionCollection&& o) = delete;

/**
* @brief Process all pending delayed submission operations.
* @brief Process pending delayed request submission and generic-pre callback operations.
*
* Process all pending delayed request submissions and generic callbacks. Generic
* callbacks are deemed completed when their execution completes. On the other hand, the
* execution of the delayed request submission callbacks does not imply completion of the
* operation, only that it has been submitted. The completion of each delayed request
* submission is handled externally by the implementation of the object being processed,
* for example by checking the result of `ucxx::Request::isCompleted()`.
*/
void processPre();

/**
* @brief Process all pending generic-post callback operations.
*
* Process all pending delayed submissions and execute their callbacks. The execution
* of the callbacks does not imply completion of the operation, only that it has been
* submitted. The completion of each operation is handled externally by the
* implementation of the object being processed, for example by checking the result
* of `ucxx::Request::isCompleted()`.
* Process all pending generic-post callbacks. Generic callbacks are deemed completed when
* their execution completes.
*/
void process();
void processPost();

/**
* @brief Register a request for delayed submission.
*
* Register a request for delayed submission with a callback that will be executed when
* the request is in fact submitted when `process()` is called.
* the request is in fact submitted when `processPre()` is called.
*
* @throws std::runtime_error if delayed request submission was disabled at construction.
*
* @param[in] request the request to which the callback belongs, ensuring it remains
* alive until the callback is invoked.
* @param[in] callback the callback that will be executed by `process()` when the
* @param[in] callback the callback that will be executed by `processPre()` when the
* operation is submitted.
*/
void registerRequest(std::shared_ptr<Request> request, DelayedSubmissionCallbackType callback);

/**
* @brief Register a generic callback to execute during `processPre()`.
*
* Register a generic callback that will be executed when `processPre()` is called.
* Lifetime of the callback must be ensured by the caller.
*
* @param[in] callback the callback that will be executed by `processPre()`.
*/
void registerGenericPre(DelayedSubmissionCallbackType callback);

/**
* @brief Register a generic callback to execute during `processPost()`.
*
* Register a generic callback that will be executed when `processPost()` is called.
* Lifetime of the callback must be ensured by the caller.
*
* @param[in] callback the callback that will be executed by `processPost()`.
*/
void registerGenericPost(DelayedSubmissionCallbackType callback);

/**
* @brief Inquire if delayed request submission is enabled.
*
* Check whether delayed request submission is enabled, in which case `registerRequest()`
* may be used to register requests that will be executed during `processPre()`.
*
* @returns `true` if delayed request submission is enabled, `false` otherwise.
*/
bool isDelayedRequestSubmissionEnabled() const;
};

} // namespace ucxx
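
The pre/post split documented above can be illustrated with a minimal standalone sketch. It is not the UCXX implementation: `MiniCallbackCollection`, `Callback`, `drain()` and `runOneIteration()` are hypothetical names, request submission is left out, and draining by swapping the pending vector out under the lock is an assumption about one reasonable way to run callbacks without holding the mutex.

```cpp
#include <cassert>
#include <functional>
#include <mutex>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for ucxx::DelayedSubmissionCallbackType.
using Callback = std::function<void()>;

// Hypothetical, reduced collection: generic callbacks only, no requests.
class MiniCallbackCollection {
 private:
  std::vector<Callback> _genericPre{};   // Callbacks to run before progressing.
  std::vector<Callback> _genericPost{};  // Callbacks to run after progressing.
  std::mutex _mutex{};                   // Guards both vectors.

  // Swap the pending callbacks out under the lock, then run them unlocked so
  // that a callback may itself register new callbacks without deadlocking.
  void drain(std::vector<Callback>& pending)
  {
    std::vector<Callback> toProcess;
    {
      std::lock_guard<std::mutex> lock(_mutex);
      std::swap(pending, toProcess);
    }
    for (auto& callback : toProcess)
      callback();
  }

 public:
  void registerGenericPre(Callback callback)
  {
    std::lock_guard<std::mutex> lock(_mutex);
    _genericPre.push_back(std::move(callback));
  }

  void registerGenericPost(Callback callback)
  {
    std::lock_guard<std::mutex> lock(_mutex);
    _genericPost.push_back(std::move(callback));
  }

  void processPre() { drain(_genericPre); }
  void processPost() { drain(_genericPost); }
};

// One iteration of a hypothetical progress loop; returns the order of events.
std::string runOneIteration()
{
  MiniCallbackCollection collection;
  std::string trace;

  // Registration order does not matter: each callback goes to its own queue.
  collection.registerGenericPost([&trace]() { trace += "post;"; });
  collection.registerGenericPre([&trace]() { trace += "pre;"; });

  collection.processPre();
  assert(trace == "pre;");
  trace += "progress;";  // Placeholder for progressing the worker.
  collection.processPost();
  return trace;
}
```

Calling `runOneIteration()` runs the pre callback, the placeholder progress step and the post callback in that order, regardless of the order in which the two callbacks were registered.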
9 changes: 3 additions & 6 deletions cpp/include/ucxx/listener.h
@@ -14,14 +14,11 @@

namespace ucxx {

void ucpListenerDestructor(ucp_listener_h ptr);

class Listener : public Component {
private:
std::unique_ptr<ucp_listener, void (*)(ucp_listener_h)> _handle{
nullptr, ucpListenerDestructor}; ///< The UCP listener handle
std::string _ip{}; ///< The IP address to which the listener is bound to
uint16_t _port{0}; ///< The port to which the listener is bound to
ucp_listener_h _handle{nullptr}; ///< The UCP listener handle
std::string _ip{}; ///< The IP address to which the listener is bound to
uint16_t _port{0}; ///< The port to which the listener is bound to

/**
* @brief Private constructor of `ucxx::Listener`.
80 changes: 80 additions & 0 deletions cpp/include/ucxx/utils/callback_notifier.h
@@ -0,0 +1,80 @@
/**
* SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES.
* SPDX-License-Identifier: BSD-3-Clause
*/
#include <condition_variable>
#include <functional>
#include <memory>
#include <mutex>
#include <utility>

namespace ucxx {

namespace utils {

template <typename Flag>
class CallbackNotifier {
private:
Flag _flag{}; //< flag storing state
std::mutex _mutex{}; //< lock to guard accesses
std::condition_variable _conditionVariable{}; //< notification condition var

public:
/**
* @brief Construct a thread-safe notification object with given initial value.
*
* Construct a thread-safe notification object with a given initial value which may be
* later set via `store()` in one thread and block other threads running `wait()` while
* the new value is not set.
*
* @param[in] flag The initial flag value.
*/
explicit CallbackNotifier(Flag flag) : _flag{flag} {}

~CallbackNotifier() {}

CallbackNotifier(const CallbackNotifier&) = delete;
CallbackNotifier& operator=(CallbackNotifier const&) = delete;
CallbackNotifier(CallbackNotifier&& o) = delete;
CallbackNotifier& operator=(CallbackNotifier&& o) = delete;

/**
* @brief Store a new flag value and notify waiting threads.
*
* Store a new flag value and notify other threads blocked by a call to `wait()`.
* See also `std::condition_variable::notify_all`.
*
* @param[in] flag The new flag value.
*/
void store(Flag flag)
{
{
std::lock_guard lock(_mutex);
_flag = flag;
}
_conditionVariable.notify_all();
}

/**
* @brief Wait until the predicate is satisfied by the flag value.
*
* Block until the predicate is true, which should be satisfied by a change in the flag's
* value made by a `store()` call on a different thread.
*
* @param[in] compare Function of type `Flag -> bool` called with the flag value. This
* function loops until the predicate is satisfied. See also
* `std::condition_variable::wait`.
* @returns The new flag value.
*/
template <typename Compare>
Flag wait(Compare compare)
{
std::unique_lock lock(_mutex);
_conditionVariable.wait(lock, [this, &compare]() { return compare(_flag); });
return std::move(_flag);
}
};

} // namespace utils
//
} // namespace ucxx
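
A possible use of the notifier added above is sketched below. So that the example compiles on its own, it carries a reduced copy of `CallbackNotifier` (without the deleted copy/move operations, and returning a copy of the flag rather than moving it); `waitForWorkerThread()` and the value `42` are hypothetical and only serve the illustration.

```cpp
#include <cassert>
#include <condition_variable>
#include <mutex>
#include <thread>

// Reduced copy of the class from the diff above, kept only so that this
// sketch is self-contained. Unlike the original, wait() returns a copy.
template <typename Flag>
class CallbackNotifier {
 private:
  Flag _flag{};
  std::mutex _mutex{};
  std::condition_variable _conditionVariable{};

 public:
  explicit CallbackNotifier(Flag flag) : _flag{flag} {}

  // Store the new value under the lock, then wake up all waiting threads.
  void store(Flag flag)
  {
    {
      std::lock_guard<std::mutex> lock(_mutex);
      _flag = flag;
    }
    _conditionVariable.notify_all();
  }

  // Block until `compare(_flag)` is true, then return the flag value.
  template <typename Compare>
  Flag wait(Compare compare)
  {
    std::unique_lock<std::mutex> lock(_mutex);
    _conditionVariable.wait(lock, [this, &compare]() { return compare(_flag); });
    return _flag;
  }
};

// Hypothetical usage: one thread stores a value while the calling thread
// blocks in wait() until the flag differs from its initial value.
int waitForWorkerThread()
{
  CallbackNotifier<int> notifier{0};

  std::thread producer([&notifier]() { notifier.store(42); });

  // The predicate is re-evaluated on every wake-up, so spurious wake-ups and
  // a store() that happens before wait() is entered are both handled.
  int observed = notifier.wait([](int flag) { return flag != 0; });
  assert(observed != 0);

  producer.join();
  return observed;
}
```

Because `wait()` checks the predicate before blocking, the example behaves the same whether the producer thread runs before or after the caller reaches `wait()`.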
51 changes: 48 additions & 3 deletions cpp/include/ucxx/worker.h
@@ -47,6 +47,7 @@ class Worker : public Component {
std::shared_ptr<InflightRequests> _inflightRequestsToCancel{
std::make_shared<InflightRequests>()}; ///< The inflight requests scheduled to be canceled
std::shared_ptr<WorkerProgressThread> _progressThread{nullptr}; ///< The progress thread object
std::thread::id _progressThreadId{}; ///< The progress thread ID
std::function<void(void*)> _progressThreadStartCallback{
nullptr}; ///< The callback function to execute at progress thread start
void* _progressThreadStartCallbackArg{
@@ -111,8 +112,10 @@ class Worker : public Component {
* be canceled when necessary.
*
* @param[in] request the request to register.
*
* @return the request that was registered (i.e., the `request` argument itself).
*/
void registerInflightRequest(std::shared_ptr<Request> request);
std::shared_ptr<Request> registerInflightRequest(std::shared_ptr<Request> request);

/**
* @brief Progress the worker until all communication events are completed.
@@ -376,7 +379,7 @@ class Worker : public Component {
bool progress();

/**
* @brief Register delayed submission.
* @brief Register delayed request submission.
*
* Register `ucxx::Request` for delayed submission. When the `ucxx::Worker` is created
* with `enableDelayedSubmission=true`, calling actual UCX transfer routines will not
@@ -394,14 +397,47 @@ class Worker : public Component {
void registerDelayedSubmission(std::shared_ptr<Request> request,
DelayedSubmissionCallbackType callback);

/**
* @brief Register callback to be executed in progress thread before progressing.
*
* Register callback to be executed in the current or next iteration of the progress
* thread before the worker is progressed. Which of the two iterations runs the callback
* is not guaranteed; it depends on where the progress thread is in the current iteration
* when the callback is registered. The lifetime of the callback must be ensured by the
* caller.
*
* The purpose of this method is to schedule operations to be executed in the progress
* thread, such as endpoint creation and closing, so that progressing doesn't ever need
* to occur in the application thread when using a progress thread.
*
* @param[in] callback the callback to execute before progressing the worker.
*/
void registerGenericPre(DelayedSubmissionCallbackType callback);

/**
* @brief Register callback to be executed in progress thread after progressing.
*
* Register callback to be executed in the current or next iteration of the progress
* thread after the worker is progressed. Which of the two iterations runs the callback
* is not guaranteed; it depends on where the progress thread is in the current iteration
* when the callback is registered. The lifetime of the callback must be ensured by the
* caller.
*
* The purpose of this method is to schedule operations to be executed in the progress
* thread, immediately after progressing the worker completes.
*
* @param[in] callback the callback to execute after progressing the worker.
*/
void registerGenericPost(DelayedSubmissionCallbackType callback);

/**
* @brief Inquire if worker has been created with delayed submission enabled.
*
* Check whether the worker has been created with delayed submission enabled.
*
* @returns `true` if delayed submission is enabled, `false` otherwise.
*/
bool isDelayedSubmissionEnabled() const;
bool isDelayedRequestSubmissionEnabled() const;

/**
* @brief Inquire if worker has been created with future support.
@@ -513,6 +549,15 @@ class Worker : public Component {
*/
void stopProgressThread();

/**
* @brief Inquire if worker has a progress thread running.
*
* Check whether the worker currently has a progress thread running.
*
* @returns `true` if a progress thread is running, `false` otherwise.
*/
bool isProgressThreadRunning();

/**
* @brief Cancel inflight requests.
*
11 changes: 11 additions & 0 deletions cpp/include/ucxx/worker_progress_thread.h
@@ -29,6 +29,8 @@ class WorkerProgressThread {
nullptr}; ///< Callback to execute at start of the progress thread
ProgressThreadStartCallbackArg _startCallbackArg{
nullptr}; ///< Argument to pass to start callback
std::shared_ptr<DelayedSubmissionCollection> _delayedSubmissionCollection{
nullptr}; ///< Collection of enqueued delayed submissions

/**
* @brief The function executed in the new thread.
@@ -64,6 +66,13 @@ class WorkerProgressThread {
* made private to ensure all UCXX objects are shared pointers for correct
* lifetime management.
*
* This thread runs asynchronously with the main application thread. If you require
* cross-thread synchronization (for example when tearing down the thread or canceling
* requests), use the generic pre and post callbacks with a `CallbackNotifier` that
* synchronizes with the application thread. Since the worker progress itself may change
* state, it is usually the case that synchronization is needed in both pre and post
* callbacks.
*
* @code{.cpp}
* // context is `std::shared_ptr<ucxx::Context>`
* auto worker = context->createWorker(false);
@@ -103,6 +112,8 @@ class WorkerProgressThread {
* @returns Whether polling mode is enabled.
*/
bool pollingMode() const;

std::thread::id getId() const;
};

} // namespace ucxx
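
The idea of scheduling work on the progress thread and blocking the application thread until it has run can be sketched as follows. This is a rough standalone model, not UCXX code: `MiniProgressThread` and `callbackRanOnProgressThread()` are hypothetical, the worker progress step is replaced by a placeholder, and `std::promise`/`std::future` stand in for the `CallbackNotifier` that the commit uses for synchronization.

```cpp
#include <atomic>
#include <cassert>
#include <functional>
#include <future>
#include <mutex>
#include <thread>
#include <utility>
#include <vector>

using Callback = std::function<void()>;  // Hypothetical callback type.

// Hypothetical progress thread: each iteration runs the pending pre
// callbacks, "progresses" the worker, then runs the pending post callbacks.
class MiniProgressThread {
 private:
  std::vector<Callback> _pre{};
  std::vector<Callback> _post{};
  std::mutex _mutex{};
  std::atomic<bool> _stop{false};
  std::thread _thread{};  // Declared last so all other members exist first.

  // Take the pending callbacks under the lock and run them unlocked.
  void drain(std::vector<Callback>& pending)
  {
    std::vector<Callback> toProcess;
    {
      std::lock_guard<std::mutex> lock(_mutex);
      std::swap(pending, toProcess);
    }
    for (auto& callback : toProcess)
      callback();
  }

  void run()
  {
    while (!_stop.load()) {
      drain(_pre);
      std::this_thread::yield();  // Placeholder for progressing the worker.
      drain(_post);
    }
  }

 public:
  MiniProgressThread() : _thread([this]() { run(); }) {}

  ~MiniProgressThread()
  {
    _stop.store(true);
    _thread.join();
  }

  void registerGenericPre(Callback callback)
  {
    std::lock_guard<std::mutex> lock(_mutex);
    _pre.push_back(std::move(callback));
  }

  void registerGenericPost(Callback callback)
  {
    std::lock_guard<std::mutex> lock(_mutex);
    _post.push_back(std::move(callback));
  }

  std::thread::id getId() const { return _thread.get_id(); }
};

// Hypothetical usage: schedule a callback and block the application thread
// until the progress thread has executed it.
bool callbackRanOnProgressThread()
{
  // Declared before the thread object so that it outlives the progress thread.
  std::promise<std::thread::id> promise;
  std::future<std::thread::id> future = promise.get_future();
  MiniProgressThread progressThread;

  progressThread.registerGenericPre(
    [&promise]() { promise.set_value(std::this_thread::get_id()); });

  std::thread::id ranOn = future.get();  // Blocks until the callback has run.
  assert(ranOn != std::thread::id{});
  return ranOn == progressThread.getId() && ranOn != std::this_thread::get_id();
}
```

In this model the application thread never runs the callback itself; it only registers it and waits, which mirrors the stated goal of keeping progress-related work out of the application thread.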