Skip to content

Commit

Permalink
fix rma_reg; add UCX_SAFECALL; fix endpoint_fina
Browse files Browse the repository at this point in the history
  • Loading branch information
JiakunYan committed Feb 15, 2024
1 parent 6fa21d8 commit e3f77cf
Show file tree
Hide file tree
Showing 5 changed files with 228 additions and 170 deletions.
1 change: 1 addition & 0 deletions lci/backend/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ static inline LCI_error_t LCISD_post_recv(LCIS_endpoint_t endpoint_pp,
#endif
#ifdef LCI_USE_SERVER_UCX
#include "backend/ucx/server_ucx.h"
#include "backend/ucx/lcisi_ucx_detail.h"
#endif

/* Wrapper functions */
Expand Down
27 changes: 27 additions & 0 deletions lci/backend/ucx/lcisi_ucx_detail.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#ifndef LCI_LCISI_UCX_DETAIL_H
#define LCI_LCISI_UCX_DETAIL_H

#include <ucp/api/ucp.h>

// Borrowed from UCX library
/**
 * Block until a non-blocking UCX operation completes and return its status.
 *
 * Declared `static inline` because this header is included from server.h;
 * a plain `static` definition would trigger unused-function warnings in
 * every translation unit that does not call it.
 *
 * @param worker     Worker that is progressed while the request is pending.
 * @param status_ptr Value returned by a non-blocking UCX call: NULL, a
 *                   request handle, or an encoded status.
 * @return UCS_OK when status_ptr is NULL; the final result of
 *         ucp_request_test() when status_ptr is a request handle (the
 *         request is released before returning); otherwise the status
 *         encoded in status_ptr.
 */
static inline ucs_status_t LCISI_wait_status_ptr(ucp_worker_h worker,
                                                 ucs_status_ptr_t status_ptr)
{
  ucs_status_t status;

  if (status_ptr == NULL) {
    // The operation completed immediately.
    status = UCS_OK;
  } else if (UCS_PTR_IS_PTR(status_ptr)) {
    // A request handle: keep progressing the worker until it is done.
    do {
      ucp_worker_progress(worker);
      status = ucp_request_test(status_ptr, NULL);
    } while (status == UCS_INPROGRESS);
    ucp_request_release(status_ptr);
  } else {
    // Not a request: the pointer itself encodes the status.
    status = UCS_PTR_STATUS(status_ptr);
  }

  return status;
}

#endif // LCI_LCISI_UCX_DETAIL_H
94 changes: 49 additions & 45 deletions lci/backend/ucx/server_ucx.c
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#include "runtime/lcii.h"
#include "backend/ucx/server_ucx.h"
#include "lcisi_ucx_detail.h"

#define ENCODED_LIMIT 8192 // length of buffer to store encoded ucp address during initialization, user can change it
#define ENCODED_LIMIT \
8192 // length of buffer to store encoded ucp address during initialization,
// user can change it
#define DECODED_LIMIT 8192

static int g_endpoint_num = 0;
Expand All @@ -13,9 +15,11 @@ static int g_endpoint_num = 0;
// it
void encode_ucp_address(char* my_addrs, int addrs_length, char* encoded_value)
{
// Encoding as hexadecimal at most doubles the length, so the available length should be at least twice
// the original length to avoid overflow
LCI_Assert(ENCODED_LIMIT >= 2 * addrs_length, "Buffer to store encoded address is too short! Use a higher ENCODED_LIMIT");
// Encoding as hexadecimal at most doubles the length, so the available length
// should be at least twice the original length to avoid overflow
LCI_Assert(ENCODED_LIMIT >= 2 * addrs_length,
"Buffer to store encoded address is too short! Use a higher "
"ENCODED_LIMIT");
int segs = (addrs_length + sizeof(uint64_t) - 1) / sizeof(uint64_t);
for (int i = 0; i < segs; i++) {
sprintf(encoded_value + 2 * i * sizeof(uint64_t), "%016lx",
Expand All @@ -28,7 +32,9 @@ void encode_ucp_address(char* my_addrs, int addrs_length, char* encoded_value)
void decode_ucp_address(char* encoded_addrs, char* decoded_addrs)
{
// Avoid overflow
LCI_Assert(DECODED_LIMIT >= strlen(encoded_addrs), "Buffer to store decoded address is too short! Use a higher DECODED_LIMIT");
LCI_Assert(DECODED_LIMIT >= strlen(encoded_addrs),
"Buffer to store decoded address is too short! Use a higher "
"DECODED_LIMIT");
int segs = (strlen(encoded_addrs) + 16 - 1) / 16;
char tmp_buf[17];
tmp_buf[16] = 0;
Expand Down Expand Up @@ -89,14 +95,13 @@ void LCISD_server_init(LCI_device_t device, LCIS_server_t* s)
server->device = device;

// Create server (ucp_context)
ucs_status_t status;
ucp_config_t* config;
status = ucp_config_read(NULL, NULL, &config);
UCX_SAFECALL(ucp_config_read(NULL, NULL, &config));
ucp_params_t params;
params.field_mask = UCP_PARAM_FIELD_FEATURES;
params.features = UCP_FEATURE_TAG | UCP_FEATURE_RMA | UCP_FEATURE_AM;
ucp_context_h context;
status = ucp_init(&params, config, &context);
UCX_SAFECALL(ucp_init(&params, config, &context));
server->context = context;
server->endpoint_count = 0;
}
Expand All @@ -105,12 +110,11 @@ void LCISD_server_init(LCI_device_t device, LCIS_server_t* s)
// result in errors
void LCISD_server_fina(LCIS_server_t s)
{
  // Release the UCP context and the server object itself. All endpoints
  // must already have been finalized, which the assertion enforces.
  LCISI_server_t* srv = (LCISI_server_t*)s;
  const int remaining = srv->endpoint_count;
  LCI_Assert(remaining == 0, "Endpoint count is not zero (%d)\n", remaining);
  ucp_cleanup(srv->context);
  LCIU_free(s);
}

void LCISD_endpoint_init(LCIS_server_t server_pp, LCIS_endpoint_t* endpoint_pp,
Expand All @@ -126,28 +130,29 @@ void LCISD_endpoint_init(LCIS_server_t server_pp, LCIS_endpoint_t* endpoint_pp,
// Create endpoint (ucp_worker)
ucp_worker_h worker;
ucp_worker_params_t params;
ucs_status_t status;
params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE | UCP_WORKER_PARAM_FIELD_FLAGS;
params.field_mask =
UCP_WORKER_PARAM_FIELD_THREAD_MODE | UCP_WORKER_PARAM_FIELD_FLAGS;
params.flags = UCP_WORKER_FLAG_IGNORE_REQUEST_LEAK;
if (single_threaded) {
params.thread_mode = UCS_THREAD_MODE_SINGLE;
} else {
params.thread_mode = UCS_THREAD_MODE_MULTI;
}

status = ucp_worker_create(endpoint_p->server->context, &params, &worker);
LCI_Assert(status == UCS_OK, "Error in creating UCP worker!");
UCX_SAFECALL(
ucp_worker_create(endpoint_p->server->context, &params, &worker));
endpoint_p->worker = worker;

// Create lock
#ifdef LCI_ENABLE_MULTITHREAD_PROGRESS
LCIU_spinlock_init(&(endpoint_p->cq_lock));
printf("\nUsing multiple progress threads");
#endif
#ifdef LCI_ENABLE_MULTITHREAD_PROGRESS
LCIU_spinlock_init(&(endpoint_p->cq_lock));
#endif
if (LCI_UCX_USE_TRY_LOCK == true) {
LCIU_spinlock_init(&(endpoint_p->try_lock));
printf("\nUsing try lock for progress and send/recv");
if (LCI_UCX_PROGRESS_FOCUSED) printf("\nGiving priority to lock for progress thread");
LCI_Log(LCI_LOG_INFO, "ucx", "\nUsing try lock for progress and send/recv");
if (LCI_UCX_PROGRESS_FOCUSED)
LCI_Log(LCI_LOG_INFO, "ucx",
"\nGiving priority to lock for progress thread");
}
// Create completion queue
LCM_dq_init(&endpoint_p->completed_ops, 2 * LCI_PACKET_SIZE);
Expand All @@ -156,8 +161,7 @@ void LCISD_endpoint_init(LCIS_server_t server_pp, LCIS_endpoint_t* endpoint_pp,
endpoint_p->peers = LCIU_malloc(sizeof(ucp_ep_h) * LCI_NUM_PROCESSES);
ucp_address_t* my_addrs;
size_t addrs_length;
status = ucp_worker_get_address(worker, &my_addrs, &addrs_length);
LCI_Assert(status == UCS_OK, "Error in getting worker address!");
UCX_SAFECALL(ucp_worker_get_address(worker, &my_addrs, &addrs_length));

// Publish worker address
// Worker address is encoded into a string of hex representation of original
Expand All @@ -176,7 +180,7 @@ void LCISD_endpoint_init(LCIS_server_t server_pp, LCIS_endpoint_t* endpoint_pp,
sprintf(seg_key, "LCI_SEG_%d_%d", endpoint_id, LCI_RANK);

// Encode the address
encode_ucp_address((char*)my_addrs, addrs_length, encoded_value);
encode_ucp_address((char*)my_addrs, (int)addrs_length, encoded_value);

// Publish address, get number of segments
size_t num_segments;
Expand All @@ -194,7 +198,6 @@ void LCISD_endpoint_init(LCIS_server_t server_pp, LCIS_endpoint_t* endpoint_pp,
memset(decoded_value, 0, DECODED_LIMIT);

for (int i = 0; i < LCI_NUM_PROCESSES; i++) {
size_t size;
// Create ucp endpoint to connect workers
ucp_ep_params_t ep_params;
ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS |
Expand Down Expand Up @@ -245,21 +248,22 @@ void LCISD_endpoint_init(LCIS_server_t server_pp, LCIS_endpoint_t* endpoint_pp,
// result in errors
/**
 * Finalize an endpoint created by LCISD_endpoint_init.
 *
 * Synchronizes with the other processes, unregisters the endpoint from its
 * server, force-closes every peer UCP endpoint (waiting for each close to
 * complete), then destroys the worker and frees the endpoint's resources.
 *
 * @param endpoint_pp Endpoint handle to finalize; invalid after this call.
 */
void LCISD_endpoint_fina(LCIS_endpoint_t endpoint_pp)
{
  LCT_pmi_barrier();
  LCISI_endpoint_t* endpoint_p = (LCISI_endpoint_t*)endpoint_pp;

  // Endpoints are expected to be finalized in reverse order of creation, so
  // this endpoint must occupy the last used slot of the server's table.
  int my_idx = --endpoint_p->server->endpoint_count;
  LCI_Assert(endpoint_p->server->endpoints[my_idx] == endpoint_p,
             "This is not me!\n");
  endpoint_p->server->endpoints[my_idx] = NULL;

  for (int i = 0; i < LCI_NUM_PROCESSES; i++) {
    ucp_request_param_t params;
    // op_attr_mask tells UCX which fields of params are valid; without it
    // the mask is indeterminate and the force flag may be ignored.
    params.op_attr_mask = UCP_OP_ATTR_FIELD_FLAGS;
    params.flags = UCP_EP_CLOSE_FLAG_FORCE;
    ucs_status_ptr_t status_ptr =
        ucp_ep_close_nbx((endpoint_p->peers)[i], &params);
    // Wait for the close to finish before the worker is destroyed.
    UCX_SAFECALL(LCISI_wait_status_ptr(endpoint_p->worker, status_ptr));
  }

  // Should other ucp ep owned by other workers be destroyed?
  ucp_worker_destroy(endpoint_p->worker);
  LCM_dq_finalize(&(endpoint_p->completed_ops));
  // The peers array was allocated with LCIU_malloc in LCISD_endpoint_init.
  LCIU_free(endpoint_p->peers);
  LCIU_free(endpoint_pp);
}
Loading

0 comments on commit e3f77cf

Please sign in to comment.