Skip to content

Commit

Permalink
small fixes and refactoring for the UCX backend.
Browse files Browse the repository at this point in the history
  • Loading branch information
JiakunYan committed Feb 17, 2024
1 parent 84d9c42 commit aed11fe
Show file tree
Hide file tree
Showing 12 changed files with 445 additions and 440 deletions.
33 changes: 14 additions & 19 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -80,29 +80,23 @@ if(NOT LCI_WITH_LCT_ONLY)
"If using the ofi(libfabric) backend, provide a hint for the provider to use"
)

find_package(OFI)
find_package(IBV)
find_package(ucx)
if(IBV_FOUND AND OFI_FOUND)
if(LCI_SERVER STREQUAL "ofi")
set(FABRIC OFI)
else()
set(FABRIC IBV)
endif()
find_package(OFI)
find_package(UCX)
string(TOUPPER ${LCI_SERVER} LCI_SERVER_UPPER)
if(${LCI_SERVER_UPPER}_FOUND)
# If the user-specified server are found, just use it.
set(FABRIC ${LCI_SERVER_UPPER})
elseif(IBV_FOUND)
set(FABRIC IBV)
elseif(OFI_FOUND)
set(FABRIC OFI)
elseif(UCX_FOUND)
set(FABRIC UCX)
else()
message(FATAL_ERROR "Find neither libfabric nor libibverbs. Give up!")
message(FATAL_ERROR "Cannot find any servers. Give up!")
endif()
if(LCI_SERVER STREQUAL "ucx")
if(NOT ucx_FOUND)
message(FATAL_ERROR "ucx is chosen as network backend but not found!")
endif()
set(FABRIC ucx)
endif()
string(TOUPPER ${LCI_SERVER} LCI_SERVER_UPPER)

if(LCI_FORCE_SERVER AND NOT LCI_SERVER_UPPER STREQUAL FABRIC)
message(
FATAL_ERROR
Expand Down Expand Up @@ -269,10 +263,11 @@ if(NOT LCI_WITH_LCT_ONLY)
C_STANDARD 11
C_EXTENSIONS ON)
target_compile_definitions(LCI PRIVATE _GNU_SOURCE)
if(FABRIC STREQUAL ucx)
target_link_libraries(LCI PUBLIC Threads::Threads ${FABRIC}::ucp LCT)
target_link_libraries(LCI PUBLIC Threads::Threads LCT)
if(FABRIC STREQUAL UCX)
target_link_libraries(LCI PUBLIC ucx::ucp)
else()
target_link_libraries(LCI PUBLIC Threads::Threads ${FABRIC}::${FABRIC} LCT)
target_link_libraries(LCI PUBLIC ${FABRIC}::${FABRIC})
endif()
if(LCI_USE_AVX)
target_compile_options(LCI PUBLIC -mavx)
Expand Down
14 changes: 10 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,17 @@ make install
- This is the same across all the cmake projects.
- `LCI_DEBUG=ON/OFF`: Enable/disable the debug mode (more assertions and logs).
The default value is `OFF`.
- `LCI_SERVER=ibv/ofi`: Hint to which network backend to use. If both `ibv` and `ofi` are found, LCI will use the one
indicated by this variable. The default value is `ibv`. Typically, you don't need to
- `LCI_SERVER=ibv/ofi/ucx`: Hint to which network backend to use.
If the backend indicated by this variable are found, LCI will just use it.
Otherwise, LCI will use whatever are found with the priority `ibv` > `ofi` > `ucx`.
The default value is `ibv`. Typically, you don't need to
modify this variable as if `libibverbs` presents, it is likely to be the recommended one to use.
- `ibv`: libibverbs, typically for infiniband.
- `ofi`: libfabrics, for all other networks (slingshot-11, ethernet, shared memory).
- `ibv`: [libibverbs](https://github.com/linux-rdma/rdma-core/blob/master/Documentation/libibverbs.md),
typically for infiniband.
- `ofi`: [libfabrics](https://ofiwg.github.io/libfabric/),
for all other networks (slingshot-11, ethernet, shared memory).
- `ucx`: [UCX](https://openucx.org/).
Currently, the backend is in the experimental state.
- `LCI_FORCE_SERVER=ON/OFF`: Default value is `OFF`. If it is set to `ON`,
`LCI_SERVER` will not be treated as a hint but a requirement.
- `LCI_WITH_LCT_ONLY=ON/OFF`: Whether to only build LCT (The Lightweight Communication Tools).
Expand Down
3 changes: 2 additions & 1 deletion contrib/spack/packages/lci/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def is_positive_int(val):
except ValueError:
return val == 'auto'

variant('fabric', default='ibv', values=('ofi', 'ibv'), multi=False,
variant('fabric', default='ibv', values=('ofi', 'ibv', 'ucx'), multi=False,
description='Communication fabric')
variant('completion', default='sync,cq,am',
values=('sync', 'cq', 'am', 'glob'), multi=True,
Expand Down Expand Up @@ -82,6 +82,7 @@ def is_positive_int(val):
depends_on('[email protected]:', type='build')
depends_on('libfabric', when='fabric=ofi')
depends_on('rdma-core', when='fabric=ibv')
depends_on('ucx', when='fabric=ucx')
depends_on('mpi', when='default-pm=mpi')
depends_on('papi', when='+papi')
depends_on('doxygen', when='+docs')
Expand Down
19 changes: 14 additions & 5 deletions lci/api/lci.h
Original file line number Diff line number Diff line change
Expand Up @@ -593,12 +593,25 @@ typedef enum {
extern LCI_rdv_protocol_t LCI_RDV_PROTOCOL;

/**
* @ingroup
* @ingroup LCI_COMM
* @brief For the libfabric cxi provider, Try turning off the hacking to see
* whether cxi has fixed the double mr_bind error.
*/
extern bool LCI_OFI_CXI_TRY_NO_HACK;

/**
* @ingroup LCI_COMM
* @brief For the UCX backend, use try_lock to wrap the ucx function calls.
*/
extern bool LCI_UCX_USE_TRY_LOCK;

/**
* @ingroup LCI_COMM
* @brief For the UCX backend, use blocking lock to wrap the ucx_progress
* function calls.
*/
extern bool LCI_UCX_PROGRESS_FOCUSED;

/**
* @ingroup LCI_COMM
* @brief Try_lock mode of network backend.
Expand Down Expand Up @@ -629,10 +642,6 @@ extern LCI_endpoint_t LCI_UR_ENDPOINT;
*/
extern LCI_comp_t LCI_UR_CQ;

extern bool LCI_UCX_USE_TRY_LOCK;

extern bool LCI_UCX_PROGRESS_FOCUSED;

/**
* @ingroup LCI_SETUP
* @brief Initialize the LCI runtime. No LCI calls are allowed to be called
Expand Down
2 changes: 0 additions & 2 deletions lci/api/lci_config.h.in
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@
#cmakedefine LCI_ENABLE_SLOWDOWN
#cmakedefine LCI_USE_PAPI
#cmakedefine01 LCI_USE_DREG_DEFAULT

#cmakedefine LCI_UCX_NO_PROGRESS_THREAD
#cmakedefine LCI_UCX_USE_SEGMENTED_PUT

#define LCI_PACKET_SIZE_DEFAULT @LCI_PACKET_SIZE_DEFAULT@
Expand Down
3 changes: 2 additions & 1 deletion lci/backend/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ typedef struct LCIS_mr_t {

#ifdef LCI_USE_SERVER_UCX
typedef struct {
char tmp[128];
uint64_t val[2];
} LCIS_rkey_t;
#else
typedef uint64_t LCIS_rkey_t;
Expand Down Expand Up @@ -95,6 +95,7 @@ static inline LCI_error_t LCISD_post_recv(LCIS_endpoint_t endpoint_pp,
#endif
#ifdef LCI_USE_SERVER_UCX
#include "backend/ucx/server_ucx.h"
#include "backend/ucx/lcisi_ucx_detail.h"
#endif

/* Wrapper functions */
Expand Down
27 changes: 27 additions & 0 deletions lci/backend/ucx/lcisi_ucx_detail.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#ifndef LCI_LCISI_UCX_DETAIL_H
#define LCI_LCISI_UCX_DETAIL_H

#include <ucp/api/ucp.h>

// Borrowed from UCX library
static ucs_status_t LCISI_wait_status_ptr(ucp_worker_h worker,
ucs_status_ptr_t status_ptr)
{
ucs_status_t status;

if (status_ptr == NULL) {
status = UCS_OK;
} else if (UCS_PTR_IS_PTR(status_ptr)) {
do {
ucp_worker_progress(worker);
status = ucp_request_test(status_ptr, NULL);
} while (status == UCS_INPROGRESS);
ucp_request_release(status_ptr);
} else {
status = UCS_PTR_STATUS(status_ptr);
}

return status;
}

#endif // LCI_LCISI_UCX_DETAIL_H
102 changes: 55 additions & 47 deletions lci/backend/ucx/server_ucx.c
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#include "runtime/lcii.h"
#include "backend/ucx/server_ucx.h"
#include "lcisi_ucx_detail.h"

#define ENCODED_LIMIT 8192 // length of buffer to store encoded ucp address during initialization, user can change it
#define ENCODED_LIMIT \
8192 // length of buffer to store encoded ucp address during initialization,
// user can change it
#define DECODED_LIMIT 8192

static int g_endpoint_num = 0;
Expand All @@ -13,9 +15,11 @@ static int g_endpoint_num = 0;
// it
void encode_ucp_address(char* my_addrs, int addrs_length, char* encoded_value)
{
// Encoding as hexdecimal at most doubles the length, so available length should be at least twice
// of the original length to avoid overflow
LCI_Assert(ENCODED_LIMIT >= 2 * addrs_length, "Buffer to store encoded address is too short! Use a higher ENCODED_LIMIT");
// Encoding as hexdecimal at most doubles the length, so available length
// should be at least twice of the original length to avoid overflow
LCI_Assert(ENCODED_LIMIT >= 2 * addrs_length,
"Buffer to store encoded address is too short! Use a higher "
"ENCODED_LIMIT");
int segs = (addrs_length + sizeof(uint64_t) - 1) / sizeof(uint64_t);
for (int i = 0; i < segs; i++) {
sprintf(encoded_value + 2 * i * sizeof(uint64_t), "%016lx",
Expand All @@ -28,7 +32,9 @@ void encode_ucp_address(char* my_addrs, int addrs_length, char* encoded_value)
void decode_ucp_address(char* encoded_addrs, char* decoded_addrs)
{
// Avoid overflow
LCI_Assert(DECODED_LIMIT >= strlen(encoded_addrs), "Buffer to store decoded address is too short! Use a higher DECODED_LIMIT");
LCI_Assert(DECODED_LIMIT >= strlen(encoded_addrs),
"Buffer to store decoded address is too short! Use a higher "
"DECODED_LIMIT");
int segs = (strlen(encoded_addrs) + 16 - 1) / 16;
char tmp_buf[17];
tmp_buf[16] = 0;
Expand Down Expand Up @@ -89,14 +95,13 @@ void LCISD_server_init(LCI_device_t device, LCIS_server_t* s)
server->device = device;

// Create server (ucp_context)
ucs_status_t status;
ucp_config_t* config;
status = ucp_config_read(NULL, NULL, &config);
UCX_SAFECALL(ucp_config_read(NULL, NULL, &config));
ucp_params_t params;
params.field_mask = UCP_PARAM_FIELD_FEATURES;
params.features = UCP_FEATURE_TAG | UCP_FEATURE_RMA | UCP_FEATURE_AM;
ucp_context_h context;
status = ucp_init(&params, config, &context);
UCX_SAFECALL(ucp_init(&params, config, &context));
server->context = context;
server->endpoint_count = 0;
}
Expand All @@ -105,12 +110,11 @@ void LCISD_server_init(LCI_device_t device, LCIS_server_t* s)
// result in errors
void LCISD_server_fina(LCIS_server_t s)
{
// LCISI_server_t* server = (LCISI_server_t*)s;
// LCI_Assert(server->endpoint_count == 0, "Endpoint count is not zero
// (%d)\n",
// server->endpoint_count);
// ucp_cleanup(server->context);
// free(s);
LCISI_server_t* server = (LCISI_server_t*)s;
LCI_Assert(server->endpoint_count == 0, "Endpoint count is not zero (%d)\n",
server->endpoint_count);
ucp_cleanup(server->context);
LCIU_free(s);
}

void LCISD_endpoint_init(LCIS_server_t server_pp, LCIS_endpoint_t* endpoint_pp,
Expand All @@ -126,38 +130,39 @@ void LCISD_endpoint_init(LCIS_server_t server_pp, LCIS_endpoint_t* endpoint_pp,
// Create endpoint (ucp_worker)
ucp_worker_h worker;
ucp_worker_params_t params;
ucs_status_t status;
params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE | UCP_WORKER_PARAM_FIELD_FLAGS;
params.field_mask =
UCP_WORKER_PARAM_FIELD_THREAD_MODE | UCP_WORKER_PARAM_FIELD_FLAGS;
params.flags = UCP_WORKER_FLAG_IGNORE_REQUEST_LEAK;
if (single_threaded) {
params.thread_mode = UCS_THREAD_MODE_SINGLE;
} else {
params.thread_mode = UCS_THREAD_MODE_MULTI;
}

status = ucp_worker_create(endpoint_p->server->context, &params, &worker);
LCI_Assert(status == UCS_OK, "Error in creating UCP worker!");
UCX_SAFECALL(
ucp_worker_create(endpoint_p->server->context, &params, &worker));
endpoint_p->worker = worker;

// Create lock
#ifdef LCI_ENABLE_MULTITHREAD_PROGRESS
LCIU_spinlock_init(&(endpoint_p->cq_lock));
printf("\nUsing multiple progress threads");
#endif
#ifdef LCI_ENABLE_MULTITHREAD_PROGRESS
LCIU_spinlock_init(&(endpoint_p->cq_lock));
#endif
if (LCI_UCX_USE_TRY_LOCK == true) {
LCIU_spinlock_init(&(endpoint_p->try_lock));
printf("\nUsing try lock for progress and send/recv");
if (LCI_UCX_PROGRESS_FOCUSED) printf("\nGiving priority to lock for progress thread");
LCIU_spinlock_init(&(endpoint_p->wrapper_lock));
LCI_Log(LCI_LOG_INFO, "ucx", "\nUsing try lock for progress and send/recv");
if (LCI_UCX_PROGRESS_FOCUSED)
LCI_Log(LCI_LOG_INFO, "ucx",
"\nGiving priority to lock for progress thread");
}
// Create completion queue
LCM_dq_init(&endpoint_p->completed_ops, 2 * LCI_PACKET_SIZE);
LCM_dq_init(&endpoint_p->cq, 2 * LCI_PACKET_SIZE);

// Exchange endpoint address
endpoint_p->peers = LCIU_malloc(sizeof(ucp_ep_h) * LCI_NUM_PROCESSES);
ucp_address_t* my_addrs;
size_t addrs_length;
status = ucp_worker_get_address(worker, &my_addrs, &addrs_length);
LCI_Assert(status == UCS_OK, "Error in getting worker address!");
UCX_SAFECALL(ucp_worker_get_address(worker, &my_addrs, &addrs_length));
endpoint_p->if_address = my_addrs;

// Publish worker address
// Worker address is encoded into a string of hex representation of original
Expand All @@ -176,7 +181,7 @@ void LCISD_endpoint_init(LCIS_server_t server_pp, LCIS_endpoint_t* endpoint_pp,
sprintf(seg_key, "LCI_SEG_%d_%d", endpoint_id, LCI_RANK);

// Encode the address
encode_ucp_address((char*)my_addrs, addrs_length, encoded_value);
encode_ucp_address((char*)my_addrs, (int)addrs_length, encoded_value);

// Publish address, get number of segments
size_t num_segments;
Expand All @@ -194,7 +199,6 @@ void LCISD_endpoint_init(LCIS_server_t server_pp, LCIS_endpoint_t* endpoint_pp,
memset(decoded_value, 0, DECODED_LIMIT);

for (int i = 0; i < LCI_NUM_PROCESSES; i++) {
size_t size;
// Create ucp endpoint to connect workers
ucp_ep_params_t ep_params;
ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS |
Expand Down Expand Up @@ -245,21 +249,25 @@ void LCISD_endpoint_init(LCIS_server_t server_pp, LCIS_endpoint_t* endpoint_pp,
// result in errors
void LCISD_endpoint_fina(LCIS_endpoint_t endpoint_pp)
{
LCT_pmi_barrier();
LCISI_endpoint_t* endpoint_p = (LCISI_endpoint_t*)endpoint_pp;
int my_idx = --endpoint_p->server->endpoint_count;
LCI_Assert(endpoint_p->server->endpoints[my_idx] == endpoint_p,
"This is not me!\n");
endpoint_p->server->endpoints[my_idx] = NULL;
for (int i = 0; i < LCI_NUM_PROCESSES; i++) {
ucs_status_ptr_t status;
ucp_request_param_t params;
params.flags = UCP_EP_CLOSE_FLAG_FORCE;
status = ucp_ep_close_nbx((endpoint_p->peers)[i], &params);
}
LCT_pmi_barrier();
LCISI_endpoint_t* endpoint_p = (LCISI_endpoint_t*)endpoint_pp;
int my_idx = --endpoint_p->server->endpoint_count;
LCI_Assert(endpoint_p->server->endpoints[my_idx] == endpoint_p,
"This is not me!\n");
endpoint_p->server->endpoints[my_idx] = NULL;
for (int i = 0; i < LCI_NUM_PROCESSES; i++) {
ucp_request_param_t params;
// It seems the FORCE flag here is necessary, otherwise I will
// sometimes get the "Connection reset by remote peer" error
params.op_attr_mask = UCP_OP_ATTR_FIELD_FLAGS;
params.flags = UCP_EP_CLOSE_FLAG_FORCE;
ucs_status_ptr_t status_ptr;
status_ptr = ucp_ep_close_nbx((endpoint_p->peers)[i], &params);
UCX_SAFECALL(LCISI_wait_status_ptr(endpoint_p->worker, status_ptr));
}

// Should other ucp ep owned by other workers be destoryed?
ucp_worker_destroy(endpoint_p->worker);
LCM_dq_finalize(&(endpoint_p->completed_ops));
free(endpoint_pp);
ucp_worker_release_address(endpoint_p->worker, endpoint_p->if_address);
ucp_worker_destroy(endpoint_p->worker);
LCM_dq_finalize(&(endpoint_p->cq));
LCIU_free(endpoint_pp);
}
Loading

0 comments on commit aed11fe

Please sign in to comment.