Skip to content

Commit

Permalink
add sendmc: send medium with completion notification
Browse files Browse the repository at this point in the history
  • Loading branch information
JiakunYan committed Dec 20, 2023
1 parent aa8f5f4 commit aeb5bc6
Show file tree
Hide file tree
Showing 11 changed files with 152 additions and 70 deletions.
20 changes: 20 additions & 0 deletions lci/api/lci.h
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,26 @@ LCI_error_t LCI_endpoint_free(LCI_endpoint_t* ep_ptr);
LCI_API
LCI_error_t LCI_sends(LCI_endpoint_t ep, LCI_short_t src, int rank,
LCI_tag_t tag);
/**
* @ingroup LCI_COMM
* @brief Send a medium message with a user-provided buffer (up to
* LCI_MEDIUM_SIZE bytes). The send buffer can be reused after completion
* notification.
* @param [in] ep The endpoint to post this send to.
* @param [in] buffer The buffer to send.
* @param [in] rank The rank of the destination process.
* @param [in] tag The tag of this message.
* @param [in] completion The completion object to be associated with.
* @param [in] user_context Arbitrary data the user wants to attach to this
* operation. It will be returned to the user through the completion object.
* @return LCI_OK if the send succeeds. LCI_ERR_RETRY if the send fails due to
* temporarily unavailable resources. All the other errors are fatal as defined
* by @ref LCI_error_t.
*/
LCI_API
LCI_error_t LCI_sendmc(LCI_endpoint_t ep, LCI_mbuffer_t buffer, int rank,
LCI_tag_t tag, LCI_comp_t completion,
void* user_context);
/**
* @ingroup LCI_COMM
* @brief Send a medium message with a user-provided buffer (up to
Expand Down
44 changes: 40 additions & 4 deletions lci/backend/ofi/server_ofi.h
Original file line number Diff line number Diff line change
Expand Up @@ -246,10 +246,28 @@ static inline LCI_error_t LCISD_post_puts(LCIS_endpoint_t endpoint_pp, int rank,
} else {
addr = offset;
}
struct fi_msg_rma msg;
struct iovec iov;
struct fi_rma_iov riov;
iov.iov_base = buf;
iov.iov_len = size;
msg.msg_iov = &iov;
msg.desc = NULL;
msg.iov_count = 1;
msg.addr = endpoint_p->peer_addrs[rank];
riov.addr = addr;
riov.len = size;
riov.key = rkey;
msg.rma_iov = &riov;
msg.rma_iov_count = 1;
msg.context = NULL;
msg.data = 0;
LCISI_OFI_CS_TRY_ENTER(endpoint_p, LCI_BACKEND_TRY_LOCK_SEND,
LCI_ERR_RETRY_LOCK)
ssize_t ret = fi_inject_write(endpoint_p->ep, buf, size,
endpoint_p->peer_addrs[rank], addr, rkey);
// ssize_t ret = fi_inject_write(endpoint_p->ep, buf, size,
// endpoint_p->peer_addrs[rank], addr, rkey);
ssize_t ret =
fi_writemsg(endpoint_p->ep, &msg, FI_INJECT | FI_DELIVERY_COMPLETE);
LCISI_OFI_CS_EXIT(endpoint_p, LCI_BACKEND_TRY_LOCK_SEND)
if (ret == FI_SUCCESS)
return LCI_OK;
Expand Down Expand Up @@ -280,10 +298,28 @@ static inline LCI_error_t LCISD_post_put(LCIS_endpoint_t endpoint_pp, int rank,
} else {
addr = offset;
}
struct fi_msg_rma msg;
struct iovec iov;
struct fi_rma_iov riov;
void* desc = ofi_rma_lkey(mr);
iov.iov_base = buf;
iov.iov_len = size;
msg.msg_iov = &iov;
msg.desc = &desc;
msg.iov_count = 1;
msg.addr = endpoint_p->peer_addrs[rank];
riov.addr = addr;
riov.len = size;
riov.key = rkey;
msg.rma_iov = &riov;
msg.rma_iov_count = 1;
msg.context = ctx;
msg.data = 0;
LCISI_OFI_CS_TRY_ENTER(endpoint_p, LCI_BACKEND_TRY_LOCK_SEND,
LCI_ERR_RETRY_LOCK)
ssize_t ret = fi_write(endpoint_p->ep, buf, size, ofi_rma_lkey(mr),
endpoint_p->peer_addrs[rank], addr, rkey, ctx);
// ssize_t ret = fi_write(endpoint_p->ep, buf, size, ofi_rma_lkey(mr),
// endpoint_p->peer_addrs[rank], addr, rkey, ctx);
ssize_t ret = fi_writemsg(endpoint_p->ep, &msg, FI_DELIVERY_COMPLETE);
LCISI_OFI_CS_EXIT(endpoint_p, LCI_BACKEND_TRY_LOCK_SEND)
if (ret == FI_SUCCESS)
return LCI_OK;
Expand Down
8 changes: 4 additions & 4 deletions lci/runtime/1sided_primitive.c
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ LCI_error_t LCI_putma(LCI_endpoint_t ep, LCI_mbuffer_t buffer, int rank,
memcpy(packet->data.address, buffer.address, buffer.length);

LCII_context_t* ctx = LCIU_malloc(sizeof(LCII_context_t));
ctx->data.mbuffer.address = (void*)packet->data.address;
ctx->data.packet = packet;
LCII_initilize_comp_attr(ctx->comp_attr);
LCII_comp_attr_set_free_packet(ctx->comp_attr, 1);

Expand Down Expand Up @@ -108,7 +108,7 @@ LCI_error_t LCI_putmna(LCI_endpoint_t ep, LCI_mbuffer_t buffer, int rank,
: -1;

LCII_context_t* ctx = LCIU_malloc(sizeof(LCII_context_t));
ctx->data.mbuffer.address = (void*)packet->data.address;
ctx->data.packet = packet;
LCII_initilize_comp_attr(ctx->comp_attr);
LCII_comp_attr_set_free_packet(ctx->comp_attr, 1);

Expand Down Expand Up @@ -158,7 +158,7 @@ LCI_error_t LCI_putla(LCI_endpoint_t ep, LCI_lbuffer_t buffer,
packet->context.poolid = LCII_POOLID_LOCAL;

LCII_context_t* rts_ctx = LCIU_malloc(sizeof(LCII_context_t));
rts_ctx->data.mbuffer.address = (void*)packet->data.address;
rts_ctx->data.packet = packet;
LCII_initilize_comp_attr(rts_ctx->comp_attr);
LCII_comp_attr_set_free_packet(rts_ctx->comp_attr, 1);

Expand Down Expand Up @@ -245,7 +245,7 @@ LCI_error_t LCI_putva(LCI_endpoint_t ep, LCI_iovec_t iovec,
: -1;

LCII_context_t* rts_ctx = LCIU_malloc(sizeof(LCII_context_t));
rts_ctx->data.mbuffer.address = (void*)packet->data.address;
rts_ctx->data.packet = packet;
LCII_initilize_comp_attr(rts_ctx->comp_attr);
LCII_comp_attr_set_free_packet(rts_ctx->comp_attr, 1);

Expand Down
90 changes: 47 additions & 43 deletions lci/runtime/2sided_primitive.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,87 +17,91 @@ LCI_error_t LCI_sends(LCI_endpoint_t ep, LCI_short_t src, int rank,
return ret;
}

LCI_error_t LCI_sendm(LCI_endpoint_t ep, LCI_mbuffer_t buffer, int rank,
LCI_tag_t tag)
LCI_error_t LCI_sendmc(LCI_endpoint_t ep, LCI_mbuffer_t buffer, int rank,
LCI_tag_t tag, LCI_comp_t completion, void* user_context)
{
LCI_DBG_Assert(tag <= LCI_MAX_TAG, "tag %d is too large (maximum: %d)\n", tag,
LCI_MAX_TAG);
LCI_DBG_Assert(buffer.length <= LCI_MEDIUM_SIZE,
"buffer is too large %lu (maximum: %d)\n", buffer.length,
LCI_MEDIUM_SIZE);
LCI_error_t ret = LCI_OK;
if (buffer.length <= LCI_SHORT_SIZE) {
bool is_user_provided_packet = LCII_is_packet(ep->device, buffer.address);
if (completion == NULL && buffer.length <= LCI_SHORT_SIZE) {
/* if data is this short, we will be able to inline it
* no reason to get a packet, allocate a ctx, etc */
ret = LCIS_post_sends(ep->device->endpoint_worker->endpoint, rank,
buffer.address, buffer.length,
LCII_MAKE_PROTO(ep->gid, LCI_MSG_MEDIUM, tag));
if (ret == LCI_OK && is_user_provided_packet) {
LCII_packet_t* packet = LCII_mbuffer2packet(buffer);
packet->context.poolid = (buffer.length > LCI_PACKET_RETURN_THRESHOLD)
? lc_pool_get_local(ep->pkpool)
: -1;
LCII_free_packet(packet);
}
} else {
LCII_packet_t* packet = LCII_alloc_packet_nb(ep->pkpool);
if (packet == NULL) {
// no packet is available
return LCI_ERR_RETRY;
LCII_packet_t* packet;
if (is_user_provided_packet) {
packet = LCII_mbuffer2packet(buffer);
} else {
packet = LCII_alloc_packet_nb(ep->pkpool);
if (packet == NULL) {
// no packet is available
return LCI_ERR_RETRY;
}
memcpy(packet->data.address, buffer.address, buffer.length);
}
packet->context.poolid = (buffer.length > LCI_PACKET_RETURN_THRESHOLD)
? lc_pool_get_local(ep->pkpool)
: -1;
memcpy(packet->data.address, buffer.address, buffer.length);

LCII_context_t* ctx = LCIU_malloc(sizeof(LCII_context_t));
ctx->data.mbuffer.address = (void*)packet->data.address;
ctx->data.packet = packet;
LCII_initilize_comp_attr(ctx->comp_attr);
LCII_comp_attr_set_free_packet(ctx->comp_attr, 1);
if (!(is_user_provided_packet && completion)) {
LCII_comp_attr_set_free_packet(ctx->comp_attr, 1);
}
if (completion) {
ctx->data_type = LCI_MEDIUM;
ctx->data.mbuffer = buffer;
ctx->rank = rank;
ctx->tag = tag;
ctx->user_context = user_context;
LCII_comp_attr_set_comp_type(ctx->comp_attr, ep->cmd_comp_type);
ctx->completion = completion;
}

ret = LCIS_post_send(ep->device->endpoint_worker->endpoint, rank,
packet->data.address, buffer.length,
ep->device->heap.segment->mr,
LCII_MAKE_PROTO(ep->gid, LCI_MSG_MEDIUM, tag), ctx);
if (ret == LCI_ERR_RETRY) {
LCII_free_packet(packet);
if (!is_user_provided_packet) LCII_free_packet(packet);
LCIU_free(ctx);
}
}
if (ret == LCI_OK) {
LCII_PCOUNTER_ADD(send, (int64_t)buffer.length);
}
LCI_DBG_Log(LCI_LOG_TRACE, "comm",
"LCI_sendm(ep %p, buffer {%p, %lu}, rank %d, tag %u) -> %d\n", ep,
buffer.address, buffer.length, rank, tag, ret);
"LCI_sendmc(ep %p, buffer {%p, %lu}(%d), rank %d, tag %u, "
"completion %p, user_context %p) -> %d\n",
ep, buffer.address, buffer.length, is_user_provided_packet, rank,
tag, ret, completion, user_context);
return ret;
}

/**
 * Send a medium message without a completion object.
 *
 * Thin wrapper: forwards to LCI_sendmc with a NULL completion and a NULL
 * user_context, and returns whatever LCI_sendmc returns.
 */
LCI_error_t LCI_sendm(LCI_endpoint_t ep, LCI_mbuffer_t buffer, int rank,
LCI_tag_t tag)
{
return LCI_sendmc(ep, buffer, rank, tag, NULL, NULL);
}

/**
 * Send a medium message without a completion object.
 *
 * Delegates to LCI_sendmc with a NULL completion and a NULL user_context and
 * returns its result. LCI_sendmc itself checks (via LCII_is_packet) whether
 * `buffer` is one of the device's packets, so no packet handling is done
 * here.
 */
LCI_error_t LCI_sendmn(LCI_endpoint_t ep, LCI_mbuffer_t buffer, int rank,
                       LCI_tag_t tag)
{
  return LCI_sendmc(ep, buffer, rank, tag, NULL, NULL);
}

LCI_error_t LCI_sendl(LCI_endpoint_t ep, LCI_lbuffer_t buffer, int rank,
Expand All @@ -116,7 +120,7 @@ LCI_error_t LCI_sendl(LCI_endpoint_t ep, LCI_lbuffer_t buffer, int rank,
packet->context.poolid = LCII_POOLID_LOCAL;

LCII_context_t* rts_ctx = LCIU_malloc(sizeof(LCII_context_t));
rts_ctx->data.mbuffer.address = (void*)packet->data.address;
rts_ctx->data.packet = packet;
LCII_initilize_comp_attr(rts_ctx->comp_attr);
LCII_comp_attr_set_free_packet(rts_ctx->comp_attr, 1);

Expand Down
9 changes: 6 additions & 3 deletions lci/runtime/device.c
Original file line number Diff line number Diff line change
Expand Up @@ -67,18 +67,21 @@ LCI_error_t LCI_device_init(LCI_device_t* device_ptr)
LCI_Assert(ret == LCI_OK, "Device heap memory allocation failed\n");
uintptr_t base_addr = (uintptr_t)device->heap.address;

uintptr_t base_packet;
LCI_Assert(sizeof(struct LCII_packet_context) <= LCI_CACHE_LINE,
"Unexpected packet_context size\n");
base_packet = base_addr + LCI_CACHE_LINE - sizeof(struct LCII_packet_context);
device->base_packet =
base_addr + LCI_CACHE_LINE - sizeof(struct LCII_packet_context);
LCI_Assert(LCI_PACKET_SIZE % LCI_CACHE_LINE == 0,
"The size of packets should be a multiple of cache line size\n");

LCII_pool_create(&device->pkpool);
for (size_t i = 0; i < LCI_SERVER_NUM_PKTS; i++) {
LCII_packet_t* packet = (LCII_packet_t*)(base_packet + i * LCI_PACKET_SIZE);
LCII_packet_t* packet =
(LCII_packet_t*)(device->base_packet + i * LCI_PACKET_SIZE);
LCI_Assert(((uint64_t) & (packet->data)) % LCI_CACHE_LINE == 0,
"packet.data is not well-aligned\n");
LCI_Assert(LCII_is_packet(device, packet->data.address),
"Not a packet. The computation is wrong!\n");
packet->context.pkpool = device->pkpool;
packet->context.poolid = 0;
#ifdef LCI_DEBUG
Expand Down
7 changes: 5 additions & 2 deletions lci/runtime/lci.c
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,15 @@ LCI_error_t LCII_barrier()
while (LCI_sync_test(sync, NULL) != LCI_OK) {
LCI_progress(LCI_UR_DEVICE);
}
LCI_sync_free(&sync);
// Phase 2: rank 0 send a message to all the other ranks.
for (int i = 1; i < LCI_NUM_PROCESSES; ++i) {
while (LCI_sendm(ep, buffer, i, tag) != LCI_OK)
while (LCI_sendmc(ep, buffer, i, tag, sync, NULL) != LCI_OK)
LCI_progress(LCI_UR_DEVICE);
}
while (LCI_sync_test(sync, NULL) != LCI_OK) {
LCI_progress(LCI_UR_DEVICE);
}
LCI_sync_free(&sync);
}
LCI_Log(LCI_LOG_INFO, "coll", "End barrier (%d, %p).\n", tag, ep);
return LCI_OK;
Expand Down
29 changes: 17 additions & 12 deletions lci/runtime/lcii.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,11 @@ struct __attribute__((aligned(LCI_CACHE_LINE))) LCI_device_s {
LCI_matchtable_t mt; // 8B
LCII_rcache_t rcache; // 8B
LCI_lbuffer_t heap; // 24B
uintptr_t base_packet; // 8B
LCIU_CACHE_PADDING(sizeof(LCIS_server_t) + 2 * sizeof(LCIS_endpoint_t) -
sizeof(LCII_pool_t*) + sizeof(LCI_matchtable_t) -
sizeof(LCII_rcache_t*) + sizeof(LCI_lbuffer_t));
sizeof(LCII_rcache_t*) + sizeof(LCI_lbuffer_t) +
sizeof(uintptr_t));
// the following is shared by both progress threads and worker threads
LCM_archive_t ctx_archive; // used for long message protocol
LCIU_CACHE_PADDING(sizeof(LCM_archive_t));
Expand Down Expand Up @@ -138,16 +140,16 @@ typedef struct __attribute__((aligned(LCI_CACHE_LINE))) {
LCI_data_type_t data_type; // 4 bytes
void* user_context; // 8 bytes
union {
LCI_short_t immediate; // 32 bytes
struct { // 24 bytes
LCI_short_t immediate; // 32 bytes
struct { // 24 bytes
LCI_mbuffer_t mbuffer;
LCII_packet_t *packet;
LCII_packet_t* packet;
};
LCI_lbuffer_t lbuffer; // 24 bytes
LCI_iovec_t iovec; // 28 bytes
} data; // 32 bytes
uint32_t rank; // 4 bytes
LCI_tag_t tag; // 4 bytes
LCI_lbuffer_t lbuffer; // 24 bytes
LCI_iovec_t iovec; // 28 bytes
} data; // 32 bytes
uint32_t rank; // 4 bytes
LCI_tag_t tag; // 4 bytes
// used by LCI internally
LCI_comp_t completion; // 8 bytes
#ifdef LCI_USE_PERFORMANCE_COUNTER
Expand Down Expand Up @@ -209,10 +211,13 @@ static inline LCI_request_t LCII_ctx2req(LCII_context_t* ctx)
.tag = ctx->tag,
.type = ctx->data_type,
.user_context = ctx->user_context};
LCI_DBG_Assert(sizeof(request.data) == sizeof(ctx->data), "Unexpected size!\n");
LCI_DBG_Assert(sizeof(request.data) == sizeof(ctx->data),
"Unexpected size!\n");
memcpy(&request.data, &ctx->data, sizeof(request.data));
LCI_DBG_Assert(request.data.mbuffer.address == ctx->data.mbuffer.address, "Invalid conversion!");
LCI_DBG_Assert(request.data.mbuffer.length == ctx->data.mbuffer.length, "Invalid conversion!");
LCI_DBG_Assert(request.data.mbuffer.address == ctx->data.mbuffer.address,
"Invalid conversion!");
LCI_DBG_Assert(request.data.mbuffer.length == ctx->data.mbuffer.length,
"Invalid conversion!");
LCIU_free(ctx);
return request;
}
Expand Down
1 change: 1 addition & 0 deletions lci/runtime/memory_registration.c
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ LCI_error_t LCI_mbuffer_alloc(LCI_device_t device, LCI_mbuffer_t* mbuffer)

mbuffer->address = packet->data.address;
mbuffer->length = LCI_MEDIUM_SIZE;
LCI_DBG_Assert(LCII_is_packet(device, mbuffer->address), "");
return LCI_OK;
}

Expand Down
10 changes: 10 additions & 0 deletions lci/runtime/packet.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,4 +102,14 @@ static inline LCII_packet_t* LCII_mbuffer2packet(LCI_mbuffer_t mbuffer)
return (LCII_packet_t*)(mbuffer.address - offsetof(LCII_packet_t, data));
}

/**
 * Check whether `address` is the data area of one of `device`'s packets.
 *
 * The candidate packet start is `address` minus the offset of the `data`
 * member inside LCII_packet_t. The address is accepted only if that start
 * lies at or after device->base_packet, is a whole multiple of
 * LCI_PACKET_SIZE away from it, and its index is below LCI_SERVER_NUM_PKTS.
 *
 * The arithmetic is done on uintptr_t rather than on the `void*` itself,
 * because pointer arithmetic on `void*` is a compiler extension and not
 * valid ISO C.
 *
 * @param device  Device whose packet region is checked against.
 * @param address Address to classify.
 * @return true if `address` matches a packet data area, false otherwise.
 */
static inline bool LCII_is_packet(LCI_device_t device, void* address)
{
  uintptr_t packet_address =
      (uintptr_t)address - offsetof(LCII_packet_t, data);
  if (packet_address < device->base_packet) {
    // Below the first packet: reject before computing an offset.
    return false;
  }
  uintptr_t offset = packet_address - device->base_packet;
  return offset % LCI_PACKET_SIZE == 0 &&
         offset / LCI_PACKET_SIZE < LCI_SERVER_NUM_PKTS;
}

#endif
2 changes: 1 addition & 1 deletion lci/runtime/protocol.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ static inline void LCIS_serve_send(void* raw_ctx)
return;
}
if (LCII_comp_attr_get_free_packet(ctx->comp_attr) == 1) {
LCII_free_packet(LCII_mbuffer2packet(ctx->data.mbuffer));
LCII_free_packet(ctx->data.packet);
}
lc_ce_dispatch(ctx);
}
Expand Down
Loading

0 comments on commit aeb5bc6

Please sign in to comment.