Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/branch-0.33' into branch-0.34-…
Browse files Browse the repository at this point in the history
…merge-branch-0.33
  • Loading branch information
pentschev committed Aug 7, 2023
2 parents 42e5aeb + 98b6b89 commit f1d6099
Show file tree
Hide file tree
Showing 31 changed files with 655 additions and 221 deletions.
5 changes: 4 additions & 1 deletion conda/recipes/ucxx/conda_build_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,8 @@ python:
ucx:
- 1.14.0

gtest_version:
gmock:
- ">=1.13.0"

gtest:
- ">=1.13.0"
2 changes: 1 addition & 1 deletion conda/recipes/ucxx/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ requirements:
- ucx
- python
- librmm =23.10
- gtest {{ gtest_version }}
- gtest

outputs:
- name: libucxx
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/ucxx/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,6 @@ class RMMBuffer : public Buffer {
};
#endif

Buffer* allocateBuffer(BufferType bufferType, const size_t size);
std::shared_ptr<Buffer> allocateBuffer(BufferType bufferType, const size_t size);

} // namespace ucxx
74 changes: 63 additions & 11 deletions cpp/include/ucxx/delayed_submission.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,14 @@ class DelayedSubmission {

class DelayedSubmissionCollection {
private:
std::vector<DelayedSubmissionCallbackType>
_genericPre{}; ///< The collection of all known generic pre-progress operations.
std::vector<DelayedSubmissionCallbackType>
_genericPost{}; ///< The collection of all known generic post-progress operations.
std::vector<std::pair<std::shared_ptr<Request>, DelayedSubmissionCallbackType>>
_collection{}; ///< The collection of all known delayed submission operations.
_requests{}; ///< The collection of all known delayed request submission operations.
std::mutex _mutex{}; ///< Mutex to provide access to the collection.
bool _enableDelayedRequestSubmission{false};

public:
/**
Expand All @@ -69,36 +74,83 @@ class DelayedSubmissionCollection {
* Construct an empty collection of delayed submissions. Despite its name, a delayed
* submission registration may be processed right after registration, thus effectively
* making it an immediate submission.
*
* @param[in] enableDelayedRequestSubmission whether request submission should be
* enabled, if `false`, only generic
* callbacks are enabled.
*/
DelayedSubmissionCollection() = default;
explicit DelayedSubmissionCollection(bool enableDelayedRequestSubmission = false);

DelayedSubmissionCollection() = delete;
DelayedSubmissionCollection(const DelayedSubmissionCollection&) = delete;
DelayedSubmissionCollection& operator=(DelayedSubmissionCollection const&) = delete;
DelayedSubmissionCollection(DelayedSubmissionCollection&& o) = delete;
DelayedSubmissionCollection& operator=(DelayedSubmissionCollection&& o) = delete;

/**
* @brief Process all pending delayed submission operations.
* @brief Process pending delayed request submission and generic-pre callback operations.
*
* Process all pending delayed request submissions and generic callbacks. Generic
* callbacks are deemed completed when their execution completes. On the other hand, the
* execution of the delayed request submission callbacks does not imply completion of the
* operation, only that it has been submitted. The completion of each delayed request
* submission is handled externally by the implementation of the object being processed,
* for example by checking the result of `ucxx::Request::isCompleted()`.
*/
void processPre();

/**
* @brief Process all pending generic-post callback operations.
*
* Process all pending delayed submissions and execute their callbacks. The execution
* of the callbacks does not imply completion of the operation, only that it has been
* submitted. The completion of each operation is handled externally by the
* implementation of the object being processed, for example by checking the result
* of `ucxx::Request::isCompleted()`.
* Process all pending generic-post callbacks. Generic callbacks are deemed completed when
* their execution completes.
*/
void process();
void processPost();

/**
* @brief Register a request for delayed submission.
*
* Register a request for delayed submission with a callback that will be executed when
* the request is in fact submitted when `process()` is called.
* the request is in fact submitted when `processPre()` is called.
*
* @throws std::runtime_error if delayed request submission was disabled at construction.
*
* @param[in] request the request to which the callback belongs, ensuring it remains
* alive until the callback is invoked.
* @param[in] callback the callback that will be executed by `process()` when the
* @param[in] callback the callback that will be executed by `processPre()` when the
* operation is submitted.
*/
void registerRequest(std::shared_ptr<Request> request, DelayedSubmissionCallbackType callback);

/**
* @brief Register a generic callback to execute during `processPre()`.
*
* Register a generic callback that will be executed when `processPre()` is called.
* Lifetime of the callback must be ensured by the caller.
*
* @param[in] callback the callback that will be executed by `processPre()`.
*/
void registerGenericPre(DelayedSubmissionCallbackType callback);

/**
* @brief Register a generic callback to execute during `processPost()`.
*
* Register a generic callback that will be executed when `processPost()` is called.
* Lifetime of the callback must be ensured by the caller.
*
* @param[in] callback the callback that will be executed by `processPre()`.
*/
void registerGenericPost(DelayedSubmissionCallbackType callback);

/**
* @brief Inquire if delayed request submission is enabled.
*
* Check whether delayed submission request is enabled, in which case `registerRequest()`
* may be used to register requests that will be executed during `processPre()`.
*
* @returns `true` if a delayed request submission is enabled, `false` otherwise.
*/
bool isDelayedRequestSubmissionEnabled() const;
};

} // namespace ucxx
9 changes: 3 additions & 6 deletions cpp/include/ucxx/listener.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,11 @@

namespace ucxx {

void ucpListenerDestructor(ucp_listener_h ptr);

class Listener : public Component {
private:
std::unique_ptr<ucp_listener, void (*)(ucp_listener_h)> _handle{
nullptr, ucpListenerDestructor}; ///< The UCP listener handle
std::string _ip{}; ///< The IP address to which the listener is bound to
uint16_t _port{0}; ///< The port to which the listener is bound to
ucp_listener_h _handle{nullptr}; ///< The UCP listener handle
std::string _ip{}; ///< The IP address to which the listener is bound to
uint16_t _port{0}; ///< The port to which the listener is bound to

/**
* @brief Private constructor of `ucxx::Listener`.
Expand Down
13 changes: 7 additions & 6 deletions cpp/include/ucxx/request.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ namespace ucxx {

class Request : public Component {
protected:
std::atomic<ucs_status_t> _status{UCS_INPROGRESS}; ///< Requests status
std::string _status_msg{}; ///< Human-readable status message
void* _request{nullptr}; ///< Pointer to UCP request
std::shared_ptr<Future> _future{nullptr}; ///< Future to notify upon completion
RequestCallbackUserFunction _callback{nullptr}; ///< Completion callback
RequestCallbackUserData _callbackData{nullptr}; ///< Completion callback data
ucs_status_t _status{UCS_INPROGRESS}; ///< Requests status
std::string _status_msg{}; ///< Human-readable status message
void* _request{nullptr}; ///< Pointer to UCP request
std::shared_ptr<Future> _future{nullptr}; ///< Future to notify upon completion
RequestCallbackUserFunction _callback{nullptr}; ///< Completion callback
RequestCallbackUserData _callbackData{nullptr}; ///< Completion callback data
std::shared_ptr<Worker> _worker{
nullptr}; ///< Worker that generated request (if not from endpoint)
std::shared_ptr<Endpoint> _endpoint{
Expand All @@ -39,6 +39,7 @@ class Request : public Component {
nullptr}; ///< The submission object that will dispatch the request
std::string _operationName{
"request_undefined"}; ///< Human-readable operation name, mostly used for log messages
std::recursive_mutex _mutex{}; ///< Mutex to prevent checking status while it's being set
bool _enablePythonFuture{true}; ///< Whether Python future is enabled for this request

/**
Expand Down
19 changes: 15 additions & 4 deletions cpp/include/ucxx/request_tag_multi.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,15 @@ class RequestTagMulti;
struct BufferRequest {
std::shared_ptr<Request> request{nullptr}; ///< The `ucxx::RequestTag` of a header or frame
std::shared_ptr<std::string> stringBuffer{nullptr}; ///< Serialized `Header`
Buffer* buffer{nullptr}; ///< Internally allocated buffer to receive a frame
std::shared_ptr<Buffer> buffer{nullptr}; ///< Internally allocated buffer to receive a frame

BufferRequest();
~BufferRequest();

BufferRequest(const BufferRequest&) = delete;
BufferRequest& operator=(BufferRequest const&) = delete;
BufferRequest(BufferRequest&& o) = delete;
BufferRequest& operator=(BufferRequest&& o) = delete;
};

typedef std::shared_ptr<BufferRequest> BufferRequestPtr;
Expand All @@ -34,8 +42,8 @@ class RequestTagMulti : public Request {
ucp_tag_t _tag{0}; ///< Tag to match
size_t _totalFrames{0}; ///< The total number of frames handled by this request
std::mutex
_completedRequestsMutex{}; ///< Mutex to control access to completed requests container
std::vector<BufferRequest*> _completedRequests{}; ///< Requests that already completed
_completedRequestsMutex{}; ///< Mutex to control access to completed requests container
size_t _completedRequests{0}; ///< Count requests that already completed

public:
std::vector<BufferRequestPtr> _bufferRequests{}; ///< Container of all requests posted
Expand Down Expand Up @@ -191,7 +199,10 @@ class RequestTagMulti : public Request {
*
* When this method is called, the request that completed will be pushed into a container
* which will be later used to evaluate if all frames completed and set the final status
* of the multi-transfer request and the Python future, if enabled.
* of the multi-transfer request and the Python future, if enabled. The final status is
* either `UCS_OK` if all underlying requests completed successfully, otherwise it will
* contain the status of the first failing request, for granular information the user
* may still verify each of the underlying requests individually.
*
* @param[in] status the status of the request being completed.
* @param[in] request the `ucxx::BufferRequest` object containing a single tag .
Expand Down
80 changes: 80 additions & 0 deletions cpp/include/ucxx/utils/callback_notifier.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/**
* SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES.
* SPDX-License-Identifier: BSD-3-Clause
*/
#include <condition_variable>
#include <functional>
#include <memory>
#include <mutex>
#include <utility>

namespace ucxx {

namespace utils {

template <typename Flag>
class CallbackNotifier {
private:
Flag _flag{}; //< flag storing state
std::mutex _mutex{}; //< lock to guard accesses
std::condition_variable _conditionVariable{}; //< notification condition var

public:
/**
* @brief Construct a thread-safe notification object with given initial value.
*
* Construct a thread-safe notification object with a given initial value which may be
* later set via `store()` in one thread and block other threads running `wait()` while
* the new value is not set.
*
* @param[in] init The initial flag value
*/
explicit CallbackNotifier(Flag flag) : _flag{flag} {}

~CallbackNotifier() {}

CallbackNotifier(const CallbackNotifier&) = delete;
CallbackNotifier& operator=(CallbackNotifier const&) = delete;
CallbackNotifier(CallbackNotifier&& o) = delete;
CallbackNotifier& operator=(CallbackNotifier&& o) = delete;

/**
* @brief Store a new flag value and notify waiting threads.
*
* Store a new flag value and notify others threads blocked by a call to `wait()`.
* See also `std::condition_variable::notify_all`.
*
* @param[in] flag The new flag value.
*/
void store(Flag flag)
{
{
std::lock_guard lock(_mutex);
_flag = flag;
}
_conditionVariable.notify_all();
}

/**
* @brief Wait while predicate is not true for the flag value to change.
*
* Wait while predicate is not true which should be satisfied by a change in the flag's
* value by a `store()` call on a different thread.
*
* @param[in] compare Function of type `T -> bool` called with the flag value. This
* function loops until the predicate is satisfied. See also
* `std::condition_variable::wait`.
* @param[out] The new flag value.
*/
template <typename Compare>
Flag wait(Compare compare)
{
std::unique_lock lock(_mutex);
_conditionVariable.wait(lock, [this, &compare]() { return compare(_flag); });
return std::move(_flag);
}
};

} // namespace utils
//
} // namespace ucxx
51 changes: 48 additions & 3 deletions cpp/include/ucxx/worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class Worker : public Component {
std::shared_ptr<InflightRequests> _inflightRequestsToCancel{
std::make_shared<InflightRequests>()}; ///< The inflight requests scheduled to be canceled
std::shared_ptr<WorkerProgressThread> _progressThread{nullptr}; ///< The progress thread object
std::thread::id _progressThreadId{}; ///< The progress thread ID
std::function<void(void*)> _progressThreadStartCallback{
nullptr}; ///< The callback function to execute at progress thread start
void* _progressThreadStartCallbackArg{
Expand Down Expand Up @@ -111,8 +112,10 @@ class Worker : public Component {
* be canceled when necessary.
*
* @param[in] request the request to register.
*
* @return the request that was registered (i.e., the `request` argument itself).
*/
void registerInflightRequest(std::shared_ptr<Request> request);
std::shared_ptr<Request> registerInflightRequest(std::shared_ptr<Request> request);

/**
* @brief Progress the worker until all communication events are completed.
Expand Down Expand Up @@ -376,7 +379,7 @@ class Worker : public Component {
bool progress();

/**
* @brief Register delayed submission.
* @brief Register delayed request submission.
*
* Register `ucxx::Request` for delayed submission. When the `ucxx::Worker` is created
* with `enableDelayedSubmission=true`, calling actual UCX transfer routines will not
Expand All @@ -394,14 +397,47 @@ class Worker : public Component {
void registerDelayedSubmission(std::shared_ptr<Request> request,
DelayedSubmissionCallbackType callback);

/**
* @brief Register callback to be executed in progress thread before progressing.
*
* Register callback to be executed in the current or next iteration of the progress
* thread before the worker is progressed. There is no guarantee that the callback will
* be executed in the current or next iteration, this depends on where the progress thread
* is in the current iteration when this callback is registered. The lifetime of the
* callback must be ensured by the caller.
*
* The purpose of this method is to schedule operations to be executed in the progress
* thread, such as endpoint creation and closing, so that progressing doesn't ever need
* to occur in the application thread when using a progress thread.
*
* @param[in] callback the callback to execute before progressing the worker.
*/
void registerGenericPre(DelayedSubmissionCallbackType callback);

/**
* @brief Register callback to be executed in progress thread after progressing.
*
* Register callback to be executed in the current or next iteration of the progress
* thread after the worker is progressed. There is no guarantee that the callback will
* be executed in the current or next iteration, this depends on where the progress thread
* is in the current iteration when this callback is registered. The lifetime of the
* callback must be ensured by the caller.
*
* The purpose of this method is to schedule operations to be executed in the progress
* thread, immediately after progressing the worker completes.
*
* @param[in] callback the callback to execute after progressing the worker.
*/
void registerGenericPost(DelayedSubmissionCallbackType callback);

/**
* @brief Inquire if worker has been created with delayed submission enabled.
*
* Check whether the worker has been created with delayed submission enabled.
*
* @returns `true` if delayed submission is enabled, `false` otherwise.
*/
bool isDelayedSubmissionEnabled() const;
bool isDelayedRequestSubmissionEnabled() const;

/**
* @brief Inquire if worker has been created with future support.
Expand Down Expand Up @@ -513,6 +549,15 @@ class Worker : public Component {
*/
void stopProgressThread();

/**
* @brief Inquire if worker has a progress thread running.
*
* Check whether the worker currently has a progress thread running.
*
* @returns `true` if a progress thread is running, `false` otherwise.
*/
bool isProgressThreadRunning();

/**
* @brief Cancel inflight requests.
*
Expand Down
Loading

0 comments on commit f1d6099

Please sign in to comment.