Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new 2sided primitive, sendmc: send medium messages with completion objects #64

Merged
merged 2 commits into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions lci/api/lci.h
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,26 @@ LCI_error_t LCI_endpoint_free(LCI_endpoint_t* ep_ptr);
LCI_API
LCI_error_t LCI_sends(LCI_endpoint_t ep, LCI_short_t src, int rank,
LCI_tag_t tag);
/**
* @ingroup LCI_COMM
* @brief Send a medium message with a user-provided buffer (up to
* LCI_MEDIUM_SIZE bytes). The send buffer can be reused after completion
* notification.
* @param [in] ep The endpoint to post this send to.
* @param [in] buffer The buffer to send.
* @param [in] rank The rank of the destination process.
* @param [in] tag The tag of this message.
* @param [in] completion The completion object to be associated with this
* send. May be NULL, which is what LCI_sendm passes when it delegates to this
* function.
* @param [in] user_context Arbitrary data the user wants to attach to this
* operation. It will be returned to the user through the completion object.
* @return LCI_OK if the send succeeds. LCI_ERR_RETRY if the send fails due to
* temporarily unavailable resources. All the other errors are fatal as defined
* by @ref LCI_error_t.
*/
LCI_API
LCI_error_t LCI_sendmc(LCI_endpoint_t ep, LCI_mbuffer_t buffer, int rank,
LCI_tag_t tag, LCI_comp_t completion,
void* user_context);
/**
* @ingroup LCI_COMM
* @brief Send a medium message with a user-provided buffer (up to
Expand Down
44 changes: 40 additions & 4 deletions lci/backend/ofi/server_ofi.h
Original file line number Diff line number Diff line change
Expand Up @@ -246,10 +246,28 @@ static inline LCI_error_t LCISD_post_puts(LCIS_endpoint_t endpoint_pp, int rank,
} else {
addr = offset;
}
struct fi_msg_rma msg;
struct iovec iov;
struct fi_rma_iov riov;
iov.iov_base = buf;
iov.iov_len = size;
msg.msg_iov = &iov;
msg.desc = NULL;
msg.iov_count = 1;
msg.addr = endpoint_p->peer_addrs[rank];
riov.addr = addr;
riov.len = size;
riov.key = rkey;
msg.rma_iov = &riov;
msg.rma_iov_count = 1;
msg.context = NULL;
msg.data = 0;
LCISI_OFI_CS_TRY_ENTER(endpoint_p, LCI_BACKEND_TRY_LOCK_SEND,
LCI_ERR_RETRY_LOCK)
ssize_t ret = fi_inject_write(endpoint_p->ep, buf, size,
endpoint_p->peer_addrs[rank], addr, rkey);
// ssize_t ret = fi_inject_write(endpoint_p->ep, buf, size,
// endpoint_p->peer_addrs[rank], addr, rkey);
ssize_t ret =
fi_writemsg(endpoint_p->ep, &msg, FI_INJECT | FI_DELIVERY_COMPLETE);
LCISI_OFI_CS_EXIT(endpoint_p, LCI_BACKEND_TRY_LOCK_SEND)
if (ret == FI_SUCCESS)
return LCI_OK;
Expand Down Expand Up @@ -280,10 +298,28 @@ static inline LCI_error_t LCISD_post_put(LCIS_endpoint_t endpoint_pp, int rank,
} else {
addr = offset;
}
struct fi_msg_rma msg;
struct iovec iov;
struct fi_rma_iov riov;
void* desc = ofi_rma_lkey(mr);
iov.iov_base = buf;
iov.iov_len = size;
msg.msg_iov = &iov;
msg.desc = &desc;
msg.iov_count = 1;
msg.addr = endpoint_p->peer_addrs[rank];
riov.addr = addr;
riov.len = size;
riov.key = rkey;
msg.rma_iov = &riov;
msg.rma_iov_count = 1;
msg.context = ctx;
msg.data = 0;
LCISI_OFI_CS_TRY_ENTER(endpoint_p, LCI_BACKEND_TRY_LOCK_SEND,
LCI_ERR_RETRY_LOCK)
ssize_t ret = fi_write(endpoint_p->ep, buf, size, ofi_rma_lkey(mr),
endpoint_p->peer_addrs[rank], addr, rkey, ctx);
// ssize_t ret = fi_write(endpoint_p->ep, buf, size, ofi_rma_lkey(mr),
// endpoint_p->peer_addrs[rank], addr, rkey, ctx);
ssize_t ret = fi_writemsg(endpoint_p->ep, &msg, FI_DELIVERY_COMPLETE);
LCISI_OFI_CS_EXIT(endpoint_p, LCI_BACKEND_TRY_LOCK_SEND)
if (ret == FI_SUCCESS)
return LCI_OK;
Expand Down
4 changes: 2 additions & 2 deletions lci/experimental/coll/coll.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ static inline void LCIXC_mcoll_complete(LCI_endpoint_t ep, LCI_mbuffer_t buffer,
LCII_comp_attr_set_comp_type(ctx->comp_attr, ep->msg_comp_type);
ctx->data_type = LCI_MEDIUM;
ctx->user_context = user_context;
ctx->data = (LCI_data_t){.mbuffer = buffer};
ctx->data.mbuffer = buffer;
ctx->rank = -1; /* this doesn't make much sense for collectives */
ctx->tag = tag;
ctx->completion = completion;
Expand All @@ -160,7 +160,7 @@ static inline void LCIXC_lcoll_complete(LCI_endpoint_t ep, LCI_lbuffer_t buffer,
LCII_comp_attr_set_comp_type(ctx->comp_attr, ep->msg_comp_type);
ctx->data_type = LCI_LONG;
ctx->user_context = user_context;
ctx->data = (LCI_data_t){.lbuffer = buffer};
ctx->data.lbuffer = buffer;
ctx->rank = -1; /* this doesn't make much sense for collectives */
ctx->tag = tag;
ctx->completion = completion;
Expand Down
8 changes: 4 additions & 4 deletions lci/runtime/1sided_primitive.c
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ LCI_error_t LCI_putma(LCI_endpoint_t ep, LCI_mbuffer_t buffer, int rank,
memcpy(packet->data.address, buffer.address, buffer.length);

LCII_context_t* ctx = LCIU_malloc(sizeof(LCII_context_t));
ctx->data.mbuffer.address = (void*)packet->data.address;
ctx->data.packet = packet;
LCII_initilize_comp_attr(ctx->comp_attr);
LCII_comp_attr_set_free_packet(ctx->comp_attr, 1);

Expand Down Expand Up @@ -108,7 +108,7 @@ LCI_error_t LCI_putmna(LCI_endpoint_t ep, LCI_mbuffer_t buffer, int rank,
: -1;

LCII_context_t* ctx = LCIU_malloc(sizeof(LCII_context_t));
ctx->data.mbuffer.address = (void*)packet->data.address;
ctx->data.packet = packet;
LCII_initilize_comp_attr(ctx->comp_attr);
LCII_comp_attr_set_free_packet(ctx->comp_attr, 1);

Expand Down Expand Up @@ -158,7 +158,7 @@ LCI_error_t LCI_putla(LCI_endpoint_t ep, LCI_lbuffer_t buffer,
packet->context.poolid = LCII_POOLID_LOCAL;

LCII_context_t* rts_ctx = LCIU_malloc(sizeof(LCII_context_t));
rts_ctx->data.mbuffer.address = (void*)packet->data.address;
rts_ctx->data.packet = packet;
LCII_initilize_comp_attr(rts_ctx->comp_attr);
LCII_comp_attr_set_free_packet(rts_ctx->comp_attr, 1);

Expand Down Expand Up @@ -245,7 +245,7 @@ LCI_error_t LCI_putva(LCI_endpoint_t ep, LCI_iovec_t iovec,
: -1;

LCII_context_t* rts_ctx = LCIU_malloc(sizeof(LCII_context_t));
rts_ctx->data.mbuffer.address = (void*)packet->data.address;
rts_ctx->data.packet = packet;
LCII_initilize_comp_attr(rts_ctx->comp_attr);
LCII_comp_attr_set_free_packet(rts_ctx->comp_attr, 1);

Expand Down
90 changes: 47 additions & 43 deletions lci/runtime/2sided_primitive.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,87 +17,91 @@ LCI_error_t LCI_sends(LCI_endpoint_t ep, LCI_short_t src, int rank,
return ret;
}

LCI_error_t LCI_sendm(LCI_endpoint_t ep, LCI_mbuffer_t buffer, int rank,
LCI_tag_t tag)
LCI_error_t LCI_sendmc(LCI_endpoint_t ep, LCI_mbuffer_t buffer, int rank,
LCI_tag_t tag, LCI_comp_t completion, void* user_context)
{
LCI_DBG_Assert(tag <= LCI_MAX_TAG, "tag %d is too large (maximum: %d)\n", tag,
LCI_MAX_TAG);
LCI_DBG_Assert(buffer.length <= LCI_MEDIUM_SIZE,
"buffer is too large %lu (maximum: %d)\n", buffer.length,
LCI_MEDIUM_SIZE);
LCI_error_t ret = LCI_OK;
if (buffer.length <= LCI_SHORT_SIZE) {
bool is_user_provided_packet = LCII_is_packet(ep->device, buffer.address);
if (completion == NULL && buffer.length <= LCI_SHORT_SIZE) {
/* if data is this short, we will be able to inline it
* no reason to get a packet, allocate a ctx, etc */
ret = LCIS_post_sends(ep->device->endpoint_worker->endpoint, rank,
buffer.address, buffer.length,
LCII_MAKE_PROTO(ep->gid, LCI_MSG_MEDIUM, tag));
if (ret == LCI_OK && is_user_provided_packet) {
LCII_packet_t* packet = LCII_mbuffer2packet(buffer);
packet->context.poolid = (buffer.length > LCI_PACKET_RETURN_THRESHOLD)
? lc_pool_get_local(ep->pkpool)
: -1;
LCII_free_packet(packet);
}
} else {
LCII_packet_t* packet = LCII_alloc_packet_nb(ep->pkpool);
if (packet == NULL) {
// no packet is available
return LCI_ERR_RETRY;
LCII_packet_t* packet;
if (is_user_provided_packet) {
packet = LCII_mbuffer2packet(buffer);
} else {
packet = LCII_alloc_packet_nb(ep->pkpool);
if (packet == NULL) {
// no packet is available
return LCI_ERR_RETRY;
}
memcpy(packet->data.address, buffer.address, buffer.length);
}
packet->context.poolid = (buffer.length > LCI_PACKET_RETURN_THRESHOLD)
? lc_pool_get_local(ep->pkpool)
: -1;
memcpy(packet->data.address, buffer.address, buffer.length);

LCII_context_t* ctx = LCIU_malloc(sizeof(LCII_context_t));
ctx->data.mbuffer.address = (void*)packet->data.address;
ctx->data.packet = packet;
LCII_initilize_comp_attr(ctx->comp_attr);
LCII_comp_attr_set_free_packet(ctx->comp_attr, 1);
if (!(is_user_provided_packet && completion)) {
LCII_comp_attr_set_free_packet(ctx->comp_attr, 1);
}
if (completion) {
ctx->data_type = LCI_MEDIUM;
ctx->data.mbuffer = buffer;
ctx->rank = rank;
ctx->tag = tag;
ctx->user_context = user_context;
LCII_comp_attr_set_comp_type(ctx->comp_attr, ep->cmd_comp_type);
ctx->completion = completion;
}

ret = LCIS_post_send(ep->device->endpoint_worker->endpoint, rank,
packet->data.address, buffer.length,
ep->device->heap.segment->mr,
LCII_MAKE_PROTO(ep->gid, LCI_MSG_MEDIUM, tag), ctx);
if (ret == LCI_ERR_RETRY) {
LCII_free_packet(packet);
if (!is_user_provided_packet) LCII_free_packet(packet);
LCIU_free(ctx);
}
}
if (ret == LCI_OK) {
LCII_PCOUNTER_ADD(send, (int64_t)buffer.length);
}
LCI_DBG_Log(LCI_LOG_TRACE, "comm",
"LCI_sendm(ep %p, buffer {%p, %lu}, rank %d, tag %u) -> %d\n", ep,
buffer.address, buffer.length, rank, tag, ret);
"LCI_sendmc(ep %p, buffer {%p, %lu}(%d), rank %d, tag %u, "
"completion %p, user_context %p) -> %d\n",
ep, buffer.address, buffer.length, is_user_provided_packet, rank,
tag, ret, completion, user_context);
return ret;
}

/* Send a medium message without a completion object: forwards to LCI_sendmc
 * with both completion and user_context set to NULL and returns its result. */
LCI_error_t LCI_sendm(LCI_endpoint_t ep, LCI_mbuffer_t buffer, int rank,
LCI_tag_t tag)
{
return LCI_sendmc(ep, buffer, rank, tag, NULL, NULL);
}

/**
 * Send a medium message without a completion object.
 *
 * Forwards to LCI_sendmc with completion and user_context set to NULL.
 * LCI_sendmc performs the tag/length debug assertions and checks whether
 * buffer.address is a packet (LCII_is_packet) before deciding to send it in
 * place, so no packet bookkeeping is repeated here.
 *
 * @param ep      Endpoint to post the send to.
 * @param buffer  Payload to send.
 * @param rank    Destination rank.
 * @param tag     Message tag.
 * @return Whatever LCI_sendmc returns.
 */
LCI_error_t LCI_sendmn(LCI_endpoint_t ep, LCI_mbuffer_t buffer, int rank,
                       LCI_tag_t tag)
{
  return LCI_sendmc(ep, buffer, rank, tag, NULL, NULL);
}

LCI_error_t LCI_sendl(LCI_endpoint_t ep, LCI_lbuffer_t buffer, int rank,
Expand All @@ -116,7 +120,7 @@ LCI_error_t LCI_sendl(LCI_endpoint_t ep, LCI_lbuffer_t buffer, int rank,
packet->context.poolid = LCII_POOLID_LOCAL;

LCII_context_t* rts_ctx = LCIU_malloc(sizeof(LCII_context_t));
rts_ctx->data.mbuffer.address = (void*)packet->data.address;
rts_ctx->data.packet = packet;
LCII_initilize_comp_attr(rts_ctx->comp_attr);
LCII_comp_attr_set_free_packet(rts_ctx->comp_attr, 1);

Expand Down
3 changes: 1 addition & 2 deletions lci/runtime/completion/sync_flag.c
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,8 @@ LCI_error_t LCI_sync_signal(LCI_comp_t completion, LCI_request_t request)
ctx->rank = request.rank;
ctx->tag = request.tag;
ctx->data_type = request.type;
ctx->data = request.data;
memcpy(&ctx->data, &request.data, sizeof(ctx->data));
ctx->user_context = request.user_context;

LCII_sync_signal(completion, ctx);
return LCI_OK;
}
Expand Down
9 changes: 6 additions & 3 deletions lci/runtime/device.c
Original file line number Diff line number Diff line change
Expand Up @@ -67,18 +67,21 @@ LCI_error_t LCI_device_init(LCI_device_t* device_ptr)
LCI_Assert(ret == LCI_OK, "Device heap memory allocation failed\n");
uintptr_t base_addr = (uintptr_t)device->heap.address;

uintptr_t base_packet;
LCI_Assert(sizeof(struct LCII_packet_context) <= LCI_CACHE_LINE,
"Unexpected packet_context size\n");
base_packet = base_addr + LCI_CACHE_LINE - sizeof(struct LCII_packet_context);
device->base_packet =
base_addr + LCI_CACHE_LINE - sizeof(struct LCII_packet_context);
LCI_Assert(LCI_PACKET_SIZE % LCI_CACHE_LINE == 0,
"The size of packets should be a multiple of cache line size\n");

LCII_pool_create(&device->pkpool);
for (size_t i = 0; i < LCI_SERVER_NUM_PKTS; i++) {
LCII_packet_t* packet = (LCII_packet_t*)(base_packet + i * LCI_PACKET_SIZE);
LCII_packet_t* packet =
(LCII_packet_t*)(device->base_packet + i * LCI_PACKET_SIZE);
LCI_Assert(((uint64_t) & (packet->data)) % LCI_CACHE_LINE == 0,
"packet.data is not well-aligned\n");
LCI_Assert(LCII_is_packet(device, packet->data.address),
"Not a packet. The computation is wrong!\n");
packet->context.pkpool = device->pkpool;
packet->context.poolid = 0;
#ifdef LCI_DEBUG
Expand Down
7 changes: 5 additions & 2 deletions lci/runtime/lci.c
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,15 @@ LCI_error_t LCII_barrier()
while (LCI_sync_test(sync, NULL) != LCI_OK) {
LCI_progress(LCI_UR_DEVICE);
}
LCI_sync_free(&sync);
// Phase 2: rank 0 sends a message to all the other ranks.
for (int i = 1; i < LCI_NUM_PROCESSES; ++i) {
while (LCI_sendm(ep, buffer, i, tag) != LCI_OK)
while (LCI_sendmc(ep, buffer, i, tag, sync, NULL) != LCI_OK)
LCI_progress(LCI_UR_DEVICE);
}
while (LCI_sync_test(sync, NULL) != LCI_OK) {
LCI_progress(LCI_UR_DEVICE);
}
LCI_sync_free(&sync);
}
LCI_Log(LCI_LOG_INFO, "coll", "End barrier (%d, %p).\n", tag, ep);
return LCI_OK;
Expand Down
26 changes: 21 additions & 5 deletions lci/runtime/lcii.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,11 @@ struct __attribute__((aligned(LCI_CACHE_LINE))) LCI_device_s {
LCI_matchtable_t mt; // 8B
LCII_rcache_t rcache; // 8B
LCI_lbuffer_t heap; // 24B
uintptr_t base_packet; // 8B
LCIU_CACHE_PADDING(sizeof(LCIS_server_t) + 2 * sizeof(LCIS_endpoint_t) -
sizeof(LCII_pool_t*) + sizeof(LCI_matchtable_t) -
sizeof(LCII_rcache_t*) + sizeof(LCI_lbuffer_t));
sizeof(LCII_rcache_t*) + sizeof(LCI_lbuffer_t) +
sizeof(uintptr_t));
// the following is shared by both progress threads and worker threads
LCM_archive_t ctx_archive; // used for long message protocol
LCIU_CACHE_PADDING(sizeof(LCM_archive_t));
Expand Down Expand Up @@ -137,9 +139,17 @@ typedef struct __attribute__((aligned(LCI_CACHE_LINE))) {
// LCI_request_t fields, 52 bytes
LCI_data_type_t data_type; // 4 bytes
void* user_context; // 8 bytes
LCI_data_t data; // 32 bytes
uint32_t rank; // 4 bytes
LCI_tag_t tag; // 4 bytes
union {
LCI_short_t immediate; // 32 bytes
struct { // 24 bytes
LCI_mbuffer_t mbuffer;
LCII_packet_t* packet;
};
LCI_lbuffer_t lbuffer; // 24 bytes
LCI_iovec_t iovec; // 28 bytes
} data; // 32 bytes
uint32_t rank; // 4 bytes
LCI_tag_t tag; // 4 bytes
// used by LCI internally
LCI_comp_t completion; // 8 bytes
#ifdef LCI_USE_PERFORMANCE_COUNTER
Expand Down Expand Up @@ -200,8 +210,14 @@ static inline LCI_request_t LCII_ctx2req(LCII_context_t* ctx)
.rank = ctx->rank,
.tag = ctx->tag,
.type = ctx->data_type,
.data = ctx->data,
.user_context = ctx->user_context};
LCI_DBG_Assert(sizeof(request.data) == sizeof(ctx->data),
"Unexpected size!\n");
memcpy(&request.data, &ctx->data, sizeof(request.data));
LCI_DBG_Assert(request.data.mbuffer.address == ctx->data.mbuffer.address,
"Invalid conversion!");
LCI_DBG_Assert(request.data.mbuffer.length == ctx->data.mbuffer.length,
"Invalid conversion!");
LCIU_free(ctx);
return request;
}
Expand Down
1 change: 1 addition & 0 deletions lci/runtime/memory_registration.c
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ LCI_error_t LCI_mbuffer_alloc(LCI_device_t device, LCI_mbuffer_t* mbuffer)

mbuffer->address = packet->data.address;
mbuffer->length = LCI_MEDIUM_SIZE;
LCI_DBG_Assert(LCII_is_packet(device, mbuffer->address), "");
return LCI_OK;
}

Expand Down
Loading