From 113cfb14a8a27b549be84385c51bfbe24021a30e Mon Sep 17 00:00:00 2001 From: Jiakun Yan Date: Thu, 19 Dec 2024 15:54:40 -0600 Subject: [PATCH] add options for device lock mode --- lci/api/lci.h | 12 ++++++++++++ lci/runtime/1sided_primitive.c | 13 +++++++++++++ lci/runtime/2sided_primitive.c | 17 +++++++++++++++++ lci/runtime/device.c | 2 ++ lci/runtime/env.c | 20 ++++++++++++++++++++ lci/runtime/lcii.h | 20 ++++++++++++++++++++ lci/runtime/progress.c | 4 +++- 7 files changed, 87 insertions(+), 1 deletion(-) diff --git a/lci/api/lci.h b/lci/api/lci.h index ab2d6bdc..383914bd 100644 --- a/lci/api/lci.h +++ b/lci/api/lci.h @@ -644,6 +644,18 @@ typedef enum { } LCI_backend_try_lock_mode_t; extern uint64_t LCI_BACKEND_TRY_LOCK_MODE; +/** + * @ingroup LCI_COMM + * @brief Try_lock mode of LCI runtime. + */ +typedef enum { + LCI_DEVICE_LOCK_MODE_NONE, + LCI_DEVICE_LOCK_MODE_BLOCK, + LCI_DEVICE_LOCK_MODE_TRY, + LCI_DEVICE_LOCK_MODE_MAX, +} LCI_DEVICE_LOCK_MODE_t; +extern uint64_t LCI_DEVICE_LOCK_MODE; + /** * @ingroup LCI_DEVICE * @brief Default device initialized by LCI_initialize. Just for convenience. diff --git a/lci/runtime/1sided_primitive.c b/lci/runtime/1sided_primitive.c index 49db3933..2c87bab5 100644 --- a/lci/runtime/1sided_primitive.c +++ b/lci/runtime/1sided_primitive.c @@ -10,9 +10,11 @@ LCI_error_t LCI_puts(LCI_endpoint_t ep, LCI_short_t src, int rank, "Only support default remote completion " "(set by LCI_plist_set_default_comp, " "the default value is LCI_UR_CQ)\n"); + LCII_DEVICE_CS_ENTER(ep->device); LCI_error_t ret = LCIS_post_sends( ep->device->endpoint_worker->endpoint, rank, &src, sizeof(LCI_short_t), LCII_MAKE_PROTO(ep->gid, LCI_MSG_RDMA_SHORT, tag)); + LCII_DEVICE_CS_EXIT(ep->device); if (ret == LCI_OK) { LCII_PCOUNTER_ADD(put, sizeof(LCI_short_t)); } @@ -48,6 +50,7 @@ LCI_error_t LCI_putmac(LCI_endpoint_t ep, LCI_mbuffer_t buffer, int rank, "Only support default remote completion " "(set by LCI_plist_set_default_comp, " "the default value is LCI_UR_CQ)\n"); + LCII_DEVICE_CS_ENTER(ep->device); LCI_error_t ret = LCI_OK; bool is_user_provided_packet = LCII_is_packet(ep->device->heap, buffer.address); @@ -70,6 +73,7 @@ LCI_error_t LCI_putmac(LCI_endpoint_t ep, LCI_mbuffer_t buffer, int rank, packet = LCII_alloc_packet_nb(ep->pkpool); if (packet == NULL) { // no packet is available + LCII_DEVICE_CS_EXIT(ep->device); return LCI_ERR_RETRY; } memcpy(packet->data.address, buffer.address, buffer.length); @@ -102,6 +106,7 @@ LCI_error_t LCI_putmac(LCI_endpoint_t ep, LCI_mbuffer_t buffer, int rank, LCIU_free(ctx); } } + LCII_DEVICE_CS_EXIT(ep->device); if (ret == LCI_OK) { LCII_PCOUNTER_ADD(put, (int64_t)buffer.length); } @@ -143,12 +148,15 @@ LCI_error_t LCI_putla(LCI_endpoint_t ep, LCI_lbuffer_t buffer, "Only support default remote completion " "(set by LCI_plist_set_default_comp, " "the default value is LCI_UR_CQ)\n"); + LCII_DEVICE_CS_ENTER(ep->device); if (!LCII_bq_is_empty(ep->bq_p)) { + LCII_DEVICE_CS_EXIT(ep->device); return LCI_ERR_RETRY; } LCII_packet_t* packet = LCII_alloc_packet_nb(ep->pkpool); if (packet == NULL) { // no packet is available + LCII_DEVICE_CS_EXIT(ep->device); return LCI_ERR_RETRY; } packet->context.poolid = LCII_POOLID_LOCAL; @@ -187,6 +195,7 @@ LCI_error_t LCI_putla(LCI_endpoint_t ep, LCI_lbuffer_t buffer, LCIU_free(rts_ctx); LCIU_free(rdv_ctx); } + LCII_DEVICE_CS_EXIT(ep->device); if (ret == LCI_OK) { LCII_PCOUNTER_ADD(put, (int64_t)buffer.length); } @@ -217,6 +226,7 @@ LCI_error_t LCI_putva(LCI_endpoint_t ep, LCI_iovec_t iovec, iovec.piggy_back.length <= LCI_get_iovec_piggy_back_size(iovec.count), "iovec's piggy back is too large! (%lu > %lu)\n", iovec.piggy_back.length, LCI_get_iovec_piggy_back_size(iovec.count)); + LCII_DEVICE_CS_ENTER(ep->device); for (int i = 0; i < iovec.count; ++i) { LCI_DBG_Assert( (iovec.lbuffers[0].segment == LCI_SEGMENT_ALL && @@ -228,11 +238,13 @@ LCI_error_t LCI_putva(LCI_endpoint_t ep, LCI_iovec_t iovec, LCI_DBG_Assert(iovec.lbuffers[i].length > 0, "Invalid lbuffer length\n"); } if (!LCII_bq_is_empty(ep->bq_p)) { + LCII_DEVICE_CS_EXIT(ep->device); return LCI_ERR_RETRY; } LCII_packet_t* packet = LCII_alloc_packet_nb(ep->pkpool); if (packet == NULL) { // no packet is available + LCII_DEVICE_CS_EXIT(ep->device); return LCI_ERR_RETRY; } packet->context.poolid = @@ -286,6 +298,7 @@ LCI_error_t LCI_putva(LCI_endpoint_t ep, LCI_iovec_t iovec, LCIU_free(rts_ctx); LCIU_free(rdv_ctx); } + LCII_DEVICE_CS_EXIT(ep->device); if (ret == LCI_OK) { uint64_t total_length = iovec.piggy_back.length; for (int i = 0; i < iovec.count; ++i) { diff --git a/lci/runtime/2sided_primitive.c b/lci/runtime/2sided_primitive.c index 4b063854..2fb42905 100644 --- a/lci/runtime/2sided_primitive.c +++ b/lci/runtime/2sided_primitive.c @@ -6,9 +6,11 @@ LCI_error_t LCI_sends(LCI_endpoint_t ep, LCI_short_t src, int rank, { LCI_DBG_Assert(tag <= LCI_MAX_TAG, "tag %d is too large (maximum: %d)\n", tag, LCI_MAX_TAG); + LCII_DEVICE_CS_ENTER(ep->device); LCI_error_t ret = LCIS_post_sends( ep->device->endpoint_worker->endpoint, rank, &src, sizeof(LCI_short_t), LCII_MAKE_PROTO(ep->gid, LCI_MSG_SHORT, tag)); + LCII_DEVICE_CS_EXIT(ep->device); if (ret == LCI_OK) { LCII_PCOUNTER_ADD(send, sizeof(LCI_short_t)); } @@ -25,6 +27,7 @@ LCI_error_t LCI_sendmc(LCI_endpoint_t ep, LCI_mbuffer_t buffer, int rank, LCI_DBG_Assert(buffer.length <= LCI_MEDIUM_SIZE, "buffer is too large %lu (maximum: %d)\n", buffer.length, LCI_MEDIUM_SIZE); + LCII_DEVICE_CS_ENTER(ep->device); LCI_error_t ret = LCI_OK; bool is_user_provided_packet = LCII_is_packet(ep->device->heap, buffer.address); @@ -47,6 +50,7 @@ LCI_error_t LCI_sendmc(LCI_endpoint_t ep, LCI_mbuffer_t buffer, int rank, packet = LCII_alloc_packet_nb(ep->pkpool); if (packet == NULL) { // no packet is available + LCII_DEVICE_CS_EXIT(ep->device); return LCI_ERR_RETRY; } memcpy(packet->data.address, buffer.address, buffer.length); @@ -80,6 +84,7 @@ LCI_error_t LCI_sendmc(LCI_endpoint_t ep, LCI_mbuffer_t buffer, int rank, LCIU_free(ctx); } } + LCII_DEVICE_CS_EXIT(ep->device); if (ret == LCI_OK) { LCII_PCOUNTER_ADD(send, (int64_t)buffer.length); } @@ -108,12 +113,15 @@ LCI_error_t LCI_sendl(LCI_endpoint_t ep, LCI_lbuffer_t buffer, int rank, { LCI_DBG_Assert(tag <= LCI_MAX_TAG, "tag %d is too large (maximum: %d)\n", tag, LCI_MAX_TAG); + LCII_DEVICE_CS_ENTER(ep->device); if (!LCII_bq_is_empty(ep->bq_p)) { + LCII_DEVICE_CS_EXIT(ep->device); return LCI_ERR_RETRY; } LCII_packet_t* packet = LCII_alloc_packet_nb(ep->pkpool); if (packet == NULL) { // no packet is available + LCII_DEVICE_CS_EXIT(ep->device); return LCI_ERR_RETRY; } packet->context.poolid = LCII_POOLID_LOCAL; @@ -149,6 +157,7 @@ LCI_error_t LCI_sendl(LCI_endpoint_t ep, LCI_lbuffer_t buffer, int rank, LCIU_free(rts_ctx); LCIU_free(rdv_ctx); } + LCII_DEVICE_CS_EXIT(ep->device); if (ret == LCI_OK) { LCII_PCOUNTER_ADD(send, (int64_t)buffer.length); } @@ -165,6 +174,7 @@ LCI_error_t LCI_recvs(LCI_endpoint_t ep, int rank, LCI_tag_t tag, { LCI_DBG_Assert(tag <= LCI_MAX_TAG, "tag %d is too large (maximum: %d)\n", tag, LCI_MAX_TAG); + LCII_DEVICE_CS_ENTER(ep->device); LCII_context_t* ctx = LCIU_malloc(sizeof(LCII_context_t)); ctx->data_type = LCI_IMMEDIATE; ctx->rank = rank; @@ -184,6 +194,7 @@ LCI_error_t LCI_recvs(LCI_endpoint_t ep, int rank, LCI_tag_t tag, LCII_free_packet(packet); lc_ce_dispatch(ctx); } + LCII_DEVICE_CS_EXIT(ep->device); LCII_PCOUNTER_ADD(recv, 1); LCI_DBG_Log(LCI_LOG_TRACE, "comm", "LCI_recvs(ep %p, rank %d, tag %u, completion %p, user_context " @@ -200,6 +211,7 @@ LCI_error_t LCI_recvm(LCI_endpoint_t ep, LCI_mbuffer_t buffer, int rank, LCI_DBG_Assert(buffer.length <= LCI_MEDIUM_SIZE, "buffer is too large %lu (maximum: %d)\n", buffer.length, LCI_MEDIUM_SIZE); + LCII_DEVICE_CS_ENTER(ep->device); LCII_context_t* ctx = LCIU_malloc(sizeof(LCII_context_t)); ctx->data.mbuffer = buffer; ctx->data_type = LCI_MEDIUM; @@ -223,6 +235,7 @@ LCI_error_t LCI_recvm(LCI_endpoint_t ep, LCI_mbuffer_t buffer, int rank, LCII_free_packet(packet); lc_ce_dispatch(ctx); } + LCII_DEVICE_CS_EXIT(ep->device); LCII_PCOUNTER_ADD(recv, 1); LCI_DBG_Log(LCI_LOG_TRACE, "comm", "LCI_recvm(ep %p, buffer {%p, %lu}, rank %d, tag %u, completion " @@ -237,6 +250,7 @@ LCI_error_t LCI_recvmn(LCI_endpoint_t ep, int rank, LCI_tag_t tag, { LCI_DBG_Assert(tag <= LCI_MAX_TAG, "tag %d is too large (maximum: %d)\n", tag, LCI_MAX_TAG); + LCII_DEVICE_CS_ENTER(ep->device); LCII_context_t* ctx = LCIU_malloc(sizeof(LCII_context_t)); ctx->data.mbuffer.address = NULL; ctx->data_type = LCI_MEDIUM; @@ -258,6 +272,7 @@ LCI_error_t LCI_recvmn(LCI_endpoint_t ep, int rank, LCI_tag_t tag, ctx->data.mbuffer.address = packet->data.address; lc_ce_dispatch(ctx); } + LCII_DEVICE_CS_EXIT(ep->device); LCII_PCOUNTER_ADD(recv, 1); LCI_DBG_Log(LCI_LOG_TRACE, "comm", "LCI_recvmn(ep %p, rank %d, tag %u, completion %p, user_context " @@ -271,6 +286,7 @@ LCI_error_t LCI_recvl(LCI_endpoint_t ep, LCI_lbuffer_t buffer, int rank, { LCI_DBG_Assert(tag <= LCI_MAX_TAG, "tag %d is too large (maximum: %d)\n", tag, LCI_MAX_TAG); + LCII_DEVICE_CS_ENTER(ep->device); LCII_context_t* rdv_ctx = LCIU_malloc(sizeof(LCII_context_t)); rdv_ctx->data.lbuffer = buffer; rdv_ctx->data_type = LCI_LONG; @@ -292,6 +308,7 @@ LCI_error_t LCI_recvl(LCI_endpoint_t ep, LCI_lbuffer_t buffer, int rank, LCII_packet_t* packet = (LCII_packet_t*)value; LCII_handle_rts(ep, packet, packet->context.src_rank, tag, rdv_ctx, false); } + LCII_DEVICE_CS_EXIT(ep->device); LCII_PCOUNTER_ADD(recv, 1); LCI_DBG_Log(LCI_LOG_TRACE, "comm", "LCI_recvl(ep %p, buffer {%p, %lu, %p}, rank %d, tag %u, " diff --git a/lci/runtime/device.c b/lci/runtime/device.c index 96c2d06a..2a016289 100644 --- a/lci/runtime/device.c +++ b/lci/runtime/device.c @@ -59,6 +59,7 @@ LCI_error_t LCI_device_init(LCI_device_t* device_ptr) LCM_archive_init(&(device->ctx_archive), 16); LCII_bq_init(&device->bq); LCIU_spinlock_init(&device->bq_spinlock); + LCIU_spinlock_init(&device->device_lock); if (LCI_USE_GLOBAL_PACKET_POOL) { device->heap = &g_heap; @@ -101,6 +102,7 @@ LCI_error_t LCI_device_free(LCI_device_t* device_ptr) LCII_matchtable_free(&device->mt); LCM_archive_fini(&(device->ctx_archive)); LCII_bq_fini(&device->bq); + LCIU_spinlock_fina(&device->device_lock); LCIU_spinlock_fina(&device->bq_spinlock); if (LCI_USE_DREG) { LCII_rcache_fina(device); diff --git a/lci/runtime/env.c b/lci/runtime/env.c index f0470868..673af1df 100644 --- a/lci/runtime/env.c +++ b/lci/runtime/env.c @@ -32,6 +32,7 @@ LCI_API bool LCI_ENABLE_PRG_NET_ENDPOINT; LCI_API LCI_rdv_protocol_t LCI_RDV_PROTOCOL; LCI_API bool LCI_OFI_CXI_TRY_NO_HACK; LCI_API uint64_t LCI_BACKEND_TRY_LOCK_MODE; +LCI_API uint64_t LCI_DEVICE_LOCK_MODE; LCI_API bool LCI_UCX_USE_TRY_LOCK; LCI_API bool LCI_UCX_PROGRESS_FOCUSED; LCI_API bool LCI_USE_GLOBAL_PACKET_POOL; @@ -128,6 +129,25 @@ void LCII_env_init(int num_proc, int rank) LCI_Log(LCI_LOG_INFO, "env", "set LCI_BACKEND_TRY_LOCK_MODE to be %d\n", LCI_BACKEND_TRY_LOCK_MODE); } + { + // default value + LCI_DEVICE_LOCK_MODE = 0; + // if users explicitly set the value + char* p = getenv("LCI_DEVICE_LOCK_MODE"); + if (p) { + LCT_dict_str_int_t dict[] = { + {"none", LCI_DEVICE_LOCK_MODE_NONE}, + {"try", LCI_DEVICE_LOCK_MODE_TRY}, + {"block", LCI_DEVICE_LOCK_MODE_BLOCK}, + }; + LCI_DEVICE_LOCK_MODE = + LCT_parse_arg(dict, sizeof(dict) / sizeof(dict[0]), p, ","); + } + LCI_Assert(LCI_DEVICE_LOCK_MODE < LCI_DEVICE_LOCK_MODE_MAX, + "Unexpected LCI_DEVICE_LOCK_MODE %d", LCI_DEVICE_LOCK_MODE); + LCI_Log(LCI_LOG_INFO, "env", "set LCI_DEVICE_LOCK_MODE to be %d\n", + LCI_DEVICE_LOCK_MODE); + } LCI_UCX_USE_TRY_LOCK = LCIU_getenv_or("LCI_UCX_USE_TRY_LOCK", 0); LCI_UCX_PROGRESS_FOCUSED = LCIU_getenv_or("LCI_UCX_PROGRESS_FOCUSED", 0); LCI_USE_GLOBAL_PACKET_POOL = LCIU_getenv_or("LCI_USE_GLOBAL_PACKET_POOL", 1); diff --git a/lci/runtime/lcii.h b/lci/runtime/lcii.h index 1302c2a0..d0f3d144 100644 --- a/lci/runtime/lcii.h +++ b/lci/runtime/lcii.h @@ -17,6 +17,7 @@ #include "backlog_queue.h" extern uint64_t LCI_PAGESIZE; + /* * used by * - LCII_MAKE_PROTO (4 bits) for communication immediate data field @@ -89,8 +90,27 @@ struct __attribute__((aligned(LCI_CACHE_LINE))) LCI_device_s { LCII_backlog_queue_t bq; LCIU_spinlock_t bq_spinlock; LCIU_CACHE_PADDING((sizeof(LCII_backlog_queue_t) + sizeof(LCIU_spinlock_t))); + LCIU_spinlock_t device_lock; // used for device lock + LCIU_CACHE_PADDING((sizeof(LCIU_spinlock_t))); }; +// device lock mode +#define LCII_DEVICE_CS_ENTER_PROGRESS(device_p, ret) \ + if (LCI_DEVICE_LOCK_MODE == LCI_DEVICE_LOCK_MODE_TRY && \ + !LCIU_try_acquire_spinlock(&device_p->device_lock)) \ + return ret; \ + else if (LCI_DEVICE_LOCK_MODE == LCI_DEVICE_LOCK_MODE_BLOCK) \ + LCIU_acquire_spinlock(&device_p->device_lock); + +#define LCII_DEVICE_CS_ENTER(device_p) \ + if (LCI_DEVICE_LOCK_MODE == LCI_DEVICE_LOCK_MODE_TRY || \ + LCI_DEVICE_LOCK_MODE == LCI_DEVICE_LOCK_MODE_BLOCK) \ + LCIU_acquire_spinlock(&device_p->device_lock); + +#define LCII_DEVICE_CS_EXIT(device_p) \ + if (LCI_DEVICE_LOCK_MODE != LCI_DEVICE_LOCK_MODE_NONE) \ + LCIU_release_spinlock(&device_p->device_lock); + struct LCI_plist_s { LCI_match_t match_type; // matching type LCI_comp_type_t cmd_comp_type; // source-side completion type diff --git a/lci/runtime/progress.c b/lci/runtime/progress.c index 79b02b94..2f6581a5 100644 --- a/lci/runtime/progress.c +++ b/lci/runtime/progress.c @@ -154,6 +154,8 @@ LCI_error_t LCII_fill_rq(LCII_endpoint_t* endpoint, bool block) LCI_error_t LCI_progress(LCI_device_t device) { int ret = LCI_ERR_RETRY; + LCII_PCOUNTER_ADD(progress_call, 1); + LCII_DEVICE_CS_ENTER_PROGRESS(device, ret); // we want to make progress on the endpoint_progress as much as possible // to speed up rendezvous protocol while (LCI_ENABLE_PRG_NET_ENDPOINT && @@ -175,6 +177,6 @@ LCI_error_t LCI_progress(LCI_device_t device) if (LCII_fill_rq(device->endpoint_worker, false) == LCI_OK) { ret = LCI_OK; } - LCII_PCOUNTER_ADD(progress_call, 1); + LCII_DEVICE_CS_EXIT(device); return ret; }