Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor the rendezvous protocol code #61

Merged
merged 5 commits into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -197,14 +197,20 @@ if(NOT LCI_WITH_LCT_ONLY)
endif()
endif()

set(LCI_RDV_PROTOCOL_DEFAULT
writeimm
CACHE STRING "The default rendezvous protocol to use (write, writeimm).")
set_property(CACHE LCI_RDV_PROTOCOL_DEFAULT PROPERTY STRINGS write writeimm)

mark_as_advanced(
LCI_CONFIG_USE_ALIGNED_ALLOC
LCI_PACKET_SIZE_DEFAULT
LCI_SERVER_MAX_SENDS_DEFAULT
LCI_SERVER_MAX_RECVS_DEFAULT
LCI_SERVER_MAX_CQES_DEFAULT
LCI_SERVER_NUM_PKTS_DEFAULT
LCI_CACHE_LINE)
LCI_CACHE_LINE
LCI_RDV_PROTOCOL_DEFAULT)

# ############################################################################
# LCI Testing related options
Expand Down
17 changes: 17 additions & 0 deletions lci/api/lci.h
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,23 @@ extern bool LCI_IBV_ENABLE_TD;
*/
extern bool LCI_ENABLE_PRG_NET_ENDPOINT;

/**
* @ingroup LCI_COMM
* @brief Rendezvous protocol to use.
*/
typedef enum {
LCI_RDV_WRITE,
LCI_RDV_WRITEIMM,
} LCI_rdv_protocol_t;
extern LCI_rdv_protocol_t LCI_RDV_PROTOCOL;

/**
 * @ingroup LCI_COMM
 * @brief For the libfabric cxi provider, try turning off the mr_bind
 * workaround to see whether cxi has fixed the double mr_bind error.
 */
extern bool LCI_OFI_CXI_TRY_NO_HACK;

/**
* @ingroup LCI_DEVICE
* @brief Default device initialized by LCI_initialize. Just for convenience.
Expand Down
1 change: 1 addition & 0 deletions lci/api/lci_config.h.in
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#define LCI_CACHE_LINE @LCI_CACHE_LINE@
#cmakedefine01 LCI_IBV_ENABLE_TD_DEFAULT
#cmakedefine01 LCI_ENABLE_PRG_NET_ENDPOINT_DEFAULT
#define LCI_RDV_PROTOCOL_DEFAULT "@LCI_RDV_PROTOCOL_DEFAULT@"

#define LCI_CQ_MAX_POLL 16
#define LCI_SERVER_MAX_ENDPOINTS 8
Expand Down
21 changes: 21 additions & 0 deletions lci/backend/ofi/server_ofi.c
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,17 @@ void LCISD_server_init(LCI_device_t device, LCIS_server_t* s)
"inject_size (%lu) < sizeof(LCI_short_t) (%lu)!\n",
server->info->tx_attr->inject_size, sizeof(LCI_short_t));
fi_freeinfo(hints);
if (strcmp(server->info->fabric_attr->prov_name, "cxi") == 0) {
LCI_Assert(LCI_USE_DREG == 0,
"The registration cache should be turned off "
"for libfabric cxi backend. Use `export LCI_USE_DREG=0`.\n");
if (LCI_RDV_PROTOCOL != LCI_RDV_WRITE) {
LCI_RDV_PROTOCOL = LCI_RDV_WRITE;
LCI_Warn(
"Switch LCI_RDV_PROTOCOL to \"write\" "
"as required by the libfabric cxi backend\n");
}
}

// Create libfabric obj.
FI_SAFECALL(fi_fabric(server->info->fabric_attr, &server->fabric, NULL));
Expand Down Expand Up @@ -132,6 +143,16 @@ void LCISD_endpoint_init(LCIS_server_t server_pp, LCIS_endpoint_t* endpoint_pp,
endpoint_p;
endpoint_p->is_single_threaded = single_threaded;
LCIU_spinlock_init(&endpoint_p->lock);
if (!LCI_OFI_CXI_TRY_NO_HACK &&
strcmp(endpoint_p->server->info->fabric_attr->prov_name, "cxi") == 0 &&
endpoint_p->server->info->domain_attr->mr_mode & FI_MR_ENDPOINT &&
endpoint_p->server->endpoint_count > 1) {
// We are using more than one endpoint per server, but the cxi provider
// can only bind mr to one endpoint. We have to guess here.
endpoint_p->server->cxi_mr_bind_hack = true;
} else {
endpoint_p->server->cxi_mr_bind_hack = false;
}
// Create end-point;
FI_SAFECALL(fi_endpoint(endpoint_p->server->domain, endpoint_p->server->info,
&endpoint_p->ep, NULL));
Expand Down
35 changes: 30 additions & 5 deletions lci/backend/ofi/server_ofi.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ typedef struct __attribute__((aligned(LCI_CACHE_LINE))) LCISI_server_t {
struct fid_domain* domain;
struct LCISI_endpoint_t* endpoints[LCI_SERVER_MAX_ENDPOINTS];
int endpoint_count;
bool cxi_mr_bind_hack;
} LCISI_server_t;

typedef struct __attribute__((aligned(LCI_CACHE_LINE))) LCISI_endpoint_t {
Expand Down Expand Up @@ -63,11 +64,11 @@ static inline void* LCISI_real_server_reg(LCIS_server_t s, void* buf,
&mr, 0));
if (server->info->domain_attr->mr_mode & FI_MR_ENDPOINT) {
LCI_DBG_Assert(server->endpoint_count >= 1, "No endpoints available!\n");
if (strcmp(server->info->fabric_attr->prov_name, "cxi") == 0) {
// A temporary fix for the cxi provider. It appears cxi cannot bind a
// memory region to more than one endpoint, but other endpoints can still
// use this memory region to send and recv messages.
FI_SAFECALL(fi_mr_bind(mr, &server->endpoints[0]->ep->fid, 0));
if (server->cxi_mr_bind_hack) {
// A temporary fix for the cxi provider, currently cxi cannot bind a
// memory region to more than one endpoint.
FI_SAFECALL(fi_mr_bind(
mr, &server->endpoints[server->endpoint_count - 1]->ep->fid, 0));
} else {
for (int i = 0; i < server->endpoint_count; ++i) {
FI_SAFECALL(fi_mr_bind(mr, &server->endpoints[i]->ep->fid, 0));
Expand Down Expand Up @@ -227,6 +228,12 @@ static inline LCI_error_t LCISD_post_puts(LCIS_endpoint_t endpoint_pp, int rank,
LCIS_rkey_t rkey)
{
LCISI_endpoint_t* endpoint_p = (LCISI_endpoint_t*)endpoint_pp;
LCI_Assert(
!endpoint_p->server->cxi_mr_bind_hack ||
endpoint_p == endpoint_p->server
->endpoints[endpoint_p->server->endpoint_count - 1],
"We are using cxi mr_bind hacking mode but unexpected endpoint is "
"performing remote put. Try `export LCI_ENABLE_PRG_NET_ENDPOINT=0`.\n");
uintptr_t addr;
if (endpoint_p->server->info->domain_attr->mr_mode & FI_MR_VIRT_ADDR ||
endpoint_p->server->info->domain_attr->mr_mode & FI_MR_BASIC) {
Expand Down Expand Up @@ -260,6 +267,12 @@ static inline LCI_error_t LCISD_post_put(LCIS_endpoint_t endpoint_pp, int rank,
LCIS_rkey_t rkey, void* ctx)
{
LCISI_endpoint_t* endpoint_p = (LCISI_endpoint_t*)endpoint_pp;
LCI_Assert(
!endpoint_p->server->cxi_mr_bind_hack ||
endpoint_p == endpoint_p->server
->endpoints[endpoint_p->server->endpoint_count - 1],
"We are using cxi mr_bind hacking mode but an unexpected endpoint is "
"performing remote put. Try `export LCI_ENABLE_PRG_NET_ENDPOINT=0`.\n");
uintptr_t addr;
if (endpoint_p->server->info->domain_attr->mr_mode & FI_MR_VIRT_ADDR ||
endpoint_p->server->info->domain_attr->mr_mode & FI_MR_BASIC) {
Expand Down Expand Up @@ -294,6 +307,12 @@ static inline LCI_error_t LCISD_post_putImms(LCIS_endpoint_t endpoint_pp,
LCIS_rkey_t rkey, uint32_t meta)
{
LCISI_endpoint_t* endpoint_p = (LCISI_endpoint_t*)endpoint_pp;
LCI_Assert(
!endpoint_p->server->cxi_mr_bind_hack ||
endpoint_p == endpoint_p->server
->endpoints[endpoint_p->server->endpoint_count - 1],
"We are using cxi mr_bind hacking mode but an unexpected endpoint is "
"performing remote put. Try `export LCI_ENABLE_PRG_NET_ENDPOINT=0`.\n");
uintptr_t addr;
if (endpoint_p->server->info->domain_attr->mr_mode & FI_MR_VIRT_ADDR ||
endpoint_p->server->info->domain_attr->mr_mode & FI_MR_BASIC) {
Expand Down Expand Up @@ -329,6 +348,12 @@ static inline LCI_error_t LCISD_post_putImm(LCIS_endpoint_t endpoint_pp,
void* ctx)
{
LCISI_endpoint_t* endpoint_p = (LCISI_endpoint_t*)endpoint_pp;
LCI_Assert(
!endpoint_p->server->cxi_mr_bind_hack ||
endpoint_p == endpoint_p->server
->endpoints[endpoint_p->server->endpoint_count - 1],
"We are using cxi mr_bind hacking mode but an unexpected endpoint is "
"performing remote put. Try `export LCI_ENABLE_PRG_NET_ENDPOINT=0`.\n");
uintptr_t addr;
if (endpoint_p->server->info->domain_attr->mr_mode & FI_MR_VIRT_ADDR ||
endpoint_p->server->info->domain_attr->mr_mode & FI_MR_BASIC) {
Expand Down
2 changes: 0 additions & 2 deletions lci/experimental/coll/coll.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,6 @@ static inline void LCIXC_mcoll_complete(LCI_endpoint_t ep, LCI_mbuffer_t buffer,
{
LCII_context_t* ctx = LCIU_malloc(sizeof(LCII_context_t));
LCII_initilize_comp_attr(ctx->comp_attr);
LCII_comp_attr_set_msg_type(ctx->comp_attr, LCI_MSG_NONE);
LCII_comp_attr_set_comp_type(ctx->comp_attr, ep->msg_comp_type);
ctx->data_type = LCI_MEDIUM;
ctx->user_context = user_context;
Expand All @@ -158,7 +157,6 @@ static inline void LCIXC_lcoll_complete(LCI_endpoint_t ep, LCI_lbuffer_t buffer,
{
LCII_context_t* ctx = LCIU_malloc(sizeof(LCII_context_t));
LCII_initilize_comp_attr(ctx->comp_attr);
LCII_comp_attr_set_msg_type(ctx->comp_attr, LCI_MSG_NONE);
LCII_comp_attr_set_comp_type(ctx->comp_attr, ep->msg_comp_type);
ctx->data_type = LCI_LONG;
ctx->user_context = user_context;
Expand Down
4 changes: 2 additions & 2 deletions lci/profile/performance_counter.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,11 @@ extern LCT_pcounter_ctx_t LCII_pcounter_ctx;
_macro(cq_push_timer) \
_macro(cq_pop_timer) \
_macro(serve_rts_timer) \
_macro(rts_mem_reg_timer) \
_macro(rts_mem_timer) \
_macro(rts_send_timer) \
_macro(serve_rtr_timer) \
_macro(rtr_mem_reg_timer) \
_macro(rtr_putimm_timer) \
_macro(rtr_put_timer) \
_macro(serve_rdma_timer) \
_macro(packet_stealing_timer) \
_macro(mem_reg_timer) \
Expand Down
26 changes: 13 additions & 13 deletions lci/runtime/1sided_primitive.c
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ LCI_error_t LCI_puts(LCI_endpoint_t ep, LCI_short_t src, int rank,
}

LCI_error_t LCI_putm(LCI_endpoint_t ep, LCI_mbuffer_t mbuffer, int rank,
LCI_tag_t tag, LCI_lbuffer_t remote_buffer,
LCI_tag_t tag, LCI_lbuffer_t rbuffer,
uintptr_t remote_completion)
{
// LC_POOL_GET_OR_RETN(ep->pkpool, p);
Expand Down Expand Up @@ -68,7 +68,6 @@ LCI_error_t LCI_putma(LCI_endpoint_t ep, LCI_mbuffer_t buffer, int rank,
LCII_context_t* ctx = LCIU_malloc(sizeof(LCII_context_t));
ctx->data.mbuffer.address = (void*)packet->data.address;
LCII_initilize_comp_attr(ctx->comp_attr);
LCII_comp_attr_set_msg_type(ctx->comp_attr, LCI_MSG_RDMA_MEDIUM);
LCII_comp_attr_set_free_packet(ctx->comp_attr, 1);

ret = LCIS_post_send(
Expand Down Expand Up @@ -111,7 +110,6 @@ LCI_error_t LCI_putmna(LCI_endpoint_t ep, LCI_mbuffer_t buffer, int rank,
LCII_context_t* ctx = LCIU_malloc(sizeof(LCII_context_t));
ctx->data.mbuffer.address = (void*)packet->data.address;
LCII_initilize_comp_attr(ctx->comp_attr);
LCII_comp_attr_set_msg_type(ctx->comp_attr, LCI_MSG_RDMA_MEDIUM);
LCII_comp_attr_set_free_packet(ctx->comp_attr, 1);

LCI_error_t ret = LCIS_post_send(
Expand All @@ -134,7 +132,7 @@ LCI_error_t LCI_putmna(LCI_endpoint_t ep, LCI_mbuffer_t buffer, int rank,

LCI_error_t LCI_putl(LCI_endpoint_t ep, LCI_lbuffer_t local_buffer,
LCI_comp_t local_completion, int rank, LCI_tag_t tag,
LCI_lbuffer_t remote_buffer, uintptr_t remote_completion)
LCI_lbuffer_t rbuffer, uintptr_t remote_completion)
{
return LCI_ERR_FEATURE_NA;
}
Expand All @@ -157,12 +155,11 @@ LCI_error_t LCI_putla(LCI_endpoint_t ep, LCI_lbuffer_t buffer,
// no packet is available
return LCI_ERR_RETRY;
}
packet->context.poolid = -1;
packet->context.poolid = LCII_POOLID_LOCAL;

LCII_context_t* rts_ctx = LCIU_malloc(sizeof(LCII_context_t));
rts_ctx->data.mbuffer.address = (void*)packet->data.address;
LCII_initilize_comp_attr(rts_ctx->comp_attr);
LCII_comp_attr_set_msg_type(rts_ctx->comp_attr, LCI_MSG_RTS);
LCII_comp_attr_set_free_packet(rts_ctx->comp_attr, 1);

LCII_context_t* rdv_ctx = LCIU_malloc(sizeof(LCII_context_t));
Expand All @@ -172,18 +169,18 @@ LCI_error_t LCI_putla(LCI_endpoint_t ep, LCI_lbuffer_t buffer,
rdv_ctx->tag = tag;
rdv_ctx->user_context = user_context;
LCII_initilize_comp_attr(rdv_ctx->comp_attr);
LCII_comp_attr_set_msg_type(rdv_ctx->comp_attr, LCI_MSG_RDMA_LONG);
LCII_comp_attr_set_rdv_type(rdv_ctx->comp_attr, LCII_RDV_1SIDED);
LCII_comp_attr_set_comp_type(rdv_ctx->comp_attr, ep->cmd_comp_type);
LCII_comp_attr_set_dereg(rdv_ctx->comp_attr,
buffer.segment == LCI_SEGMENT_ALL);
rdv_ctx->completion = completion;

packet->data.rts.msg_type = LCI_MSG_RDMA_LONG;
packet->data.rts.rdv_type = LCII_RDV_1SIDED;
packet->data.rts.send_ctx = (uintptr_t)rdv_ctx;
packet->data.rts.size = buffer.length;

LCI_DBG_Log(LCI_LOG_TRACE, "rdv", "send rts: type %d sctx %p size %lu\n",
packet->data.rts.msg_type, (void*)packet->data.rts.send_ctx,
packet->data.rts.rdv_type, (void*)packet->data.rts.send_ctx,
packet->data.rts.size);
LCI_error_t ret = LCIS_post_send(
ep->device->endpoint_worker->endpoint, rank, packet->data.address,
Expand Down Expand Up @@ -250,7 +247,6 @@ LCI_error_t LCI_putva(LCI_endpoint_t ep, LCI_iovec_t iovec,
LCII_context_t* rts_ctx = LCIU_malloc(sizeof(LCII_context_t));
rts_ctx->data.mbuffer.address = (void*)packet->data.address;
LCII_initilize_comp_attr(rts_ctx->comp_attr);
LCII_comp_attr_set_msg_type(rts_ctx->comp_attr, LCI_MSG_RTS);
LCII_comp_attr_set_free_packet(rts_ctx->comp_attr, 1);

LCII_context_t* rdv_ctx = LCIU_malloc(sizeof(LCII_context_t));
Expand All @@ -260,11 +256,15 @@ LCI_error_t LCI_putva(LCI_endpoint_t ep, LCI_iovec_t iovec,
rdv_ctx->tag = tag;
rdv_ctx->user_context = user_context;
LCII_initilize_comp_attr(rdv_ctx->comp_attr);
LCII_comp_attr_set_msg_type(rdv_ctx->comp_attr, LCI_MSG_IOVEC);
LCII_comp_attr_set_rdv_type(rdv_ctx->comp_attr, LCII_RDV_IOVEC);
LCII_comp_attr_set_comp_type(rdv_ctx->comp_attr, ep->cmd_comp_type);
// Currently, for iovec, if one buffer uses LCI_SEGMENT_ALL,
// all buffers need to use LCI_SEGMENT_ALL
LCII_comp_attr_set_dereg(rdv_ctx->comp_attr,
iovec.lbuffers[0].segment == LCI_SEGMENT_ALL);
rdv_ctx->completion = completion;

packet->data.rts.msg_type = LCI_MSG_IOVEC;
packet->data.rts.rdv_type = LCII_RDV_IOVEC;
packet->data.rts.send_ctx = (uintptr_t)rdv_ctx;
packet->data.rts.count = iovec.count;
packet->data.rts.piggy_back_size = iovec.piggy_back.length;
Expand All @@ -277,7 +277,7 @@ LCI_error_t LCI_putva(LCI_endpoint_t ep, LCI_iovec_t iovec,
LCI_DBG_Log(LCI_LOG_TRACE, "rdv",
"send rts: type %d sctx %p count %d "
"piggy_back_size %lu\n",
packet->data.rts.msg_type, (void*)packet->data.rts.send_ctx,
packet->data.rts.rdv_type, (void*)packet->data.rts.send_ctx,
packet->data.rts.count, packet->data.rts.piggy_back_size);
size_t length = (uintptr_t)&packet->data.rts.size_p[iovec.count] -
(uintptr_t)packet->data.address + iovec.piggy_back.length;
Expand Down
Loading