Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor multi-device low-level resource and heap allocation #72

Merged
merged 9 commits into from
Aug 20, 2024
9 changes: 2 additions & 7 deletions lci/api/lci.h
Original file line number Diff line number Diff line change
Expand Up @@ -592,13 +592,6 @@ typedef enum {
} LCI_rdv_protocol_t;
extern LCI_rdv_protocol_t LCI_RDV_PROTOCOL;

/**
* @ingroup LCI_COMM
* @brief For the libfabric cxi provider, try turning off the hack to see
* whether cxi has fixed the double mr_bind error.
*/
extern bool LCI_OFI_CXI_TRY_NO_HACK;

/**
* @ingroup LCI_COMM
* @brief For the UCX backend, use try_lock to wrap the ucx function calls.
Expand Down Expand Up @@ -692,6 +685,8 @@ LCI_error_t LCI_barrier();
*/
LCI_API
LCI_error_t LCI_device_init(LCI_device_t* device_ptr);
LCI_API
LCI_error_t LCI_device_initx(LCI_device_t* device_ptr);
/**
* @ingroup LCI_DEVICE
* @brief Initialize a device.
Expand Down
3 changes: 1 addition & 2 deletions lci/backend/ibv/server_ibv.c
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,10 @@ void LCISI_event_polling_thread_fina(LCISI_server_t* server)
}
}

void LCISD_server_init(LCI_device_t device, LCIS_server_t* s)
void LCISD_server_init(LCIS_server_t* s)
{
LCISI_server_t* server = LCIU_malloc(sizeof(LCISI_server_t));
*s = (LCIS_server_t)server;
server->device = device;

int num_devices;
server->dev_list = ibv_get_device_list(&num_devices);
Expand Down
20 changes: 11 additions & 9 deletions lci/backend/ibv/server_ibv.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
;

typedef struct __attribute__((aligned(LCI_CACHE_LINE))) LCISI_server_t {
LCI_device_t device;

// Device fields.
struct ibv_device** dev_list;
struct ibv_device* ib_dev;
Expand Down Expand Up @@ -62,10 +60,11 @@ typedef struct LCISI_endpoint_t {
int qp2rank_mod;
} LCISI_endpoint_t;

static inline void* LCISI_real_server_reg(LCIS_server_t s, void* buf,
size_t size)
static inline void* LCISI_real_server_reg(LCIS_endpoint_t endpoint_pp,
void* buf, size_t size)
{
LCISI_server_t* server = (LCISI_server_t*)s;
LCISI_endpoint_t* endpoint_p = (LCISI_endpoint_t*)endpoint_pp;
LCISI_server_t* server = endpoint_p->server;
int mr_flags;
if (LCI_IBV_USE_ODP) {
mr_flags = IBV_ACCESS_ON_DEMAND | IBV_ACCESS_LOCAL_WRITE |
Expand All @@ -87,16 +86,18 @@ static inline uint32_t ibv_rma_lkey(LCIS_mr_t mr)
return ((struct ibv_mr*)mr.mr_p)->lkey;
}

static inline LCIS_mr_t LCISD_rma_reg(LCIS_server_t s, void* buf, size_t size)
static inline LCIS_mr_t LCISD_rma_reg(LCIS_endpoint_t endpoint_pp, void* buf,
size_t size)
{
LCISI_server_t* server = (LCISI_server_t*)s;
LCISI_endpoint_t* endpoint_p = (LCISI_endpoint_t*)endpoint_pp;
LCISI_server_t* server = endpoint_p->server;
LCIS_mr_t mr;
if (LCI_IBV_USE_ODP == 2) {
mr.mr_p = server->odp_mr;
mr.address = buf;
mr.length = size;
} else {
mr.mr_p = LCISI_real_server_reg(s, buf, size);
mr.mr_p = LCISI_real_server_reg(endpoint_pp, buf, size);
mr.address = buf;
mr.length = size;
}
Expand Down Expand Up @@ -139,7 +140,8 @@ static inline int LCISD_poll_cq(LCIS_endpoint_t endpoint_pp,
#ifdef LCI_ENABLE_MULTITHREAD_PROGRESS
LCIU_release_spinlock(&endpoint_p->cq_lock);
#endif
if (ne > 0) LCII_PCOUNTER_ADD(net_poll_cq_num, ne);
LCII_PCOUNTER_ADD(net_poll_cq_calls, 1);
if (ne > 0) LCII_PCOUNTER_ADD(net_poll_cq_entry_count, ne);
for (int i = 0; i < ne; i++) {
LCI_DBG_Assert(
wc[i].status == IBV_WC_SUCCESS, "Failed status %s (%d) for wr_id %d\n",
Expand Down
45 changes: 14 additions & 31 deletions lci/backend/ofi/server_ofi.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,10 @@ static struct fi_info* search_for_prov(struct fi_info* info, char* prov_name)
return NULL;
}

void LCISD_server_init(LCI_device_t device, LCIS_server_t* s)
void LCISD_server_init(LCIS_server_t* s)
{
LCISI_server_t* server = LCIU_malloc(sizeof(LCISI_server_t));
*s = (LCIS_server_t)server;
server->device = device;

// Create hint.
char* p = getenv("LCI_OFI_PROVIDER_HINT");
Expand Down Expand Up @@ -104,6 +103,10 @@ void LCISD_server_init(LCI_device_t device, LCIS_server_t* s)
LCI_Assert(LCI_USE_DREG == 0,
"The registration cache should be turned off "
"for libfabric cxi backend. Use `export LCI_USE_DREG=0`.\n");
LCI_Assert(LCI_ENABLE_PRG_NET_ENDPOINT == 0,
"The progress-specific network endpoint should be turned off "
"for libfabric cxi backend. Use `export "
"LCI_ENABLE_PRG_NET_ENDPOINT=0`.\n");
if (LCI_RDV_PROTOCOL != LCI_RDV_WRITE) {
LCI_RDV_PROTOCOL = LCI_RDV_WRITE;
LCI_Warn(
Expand All @@ -114,19 +117,11 @@ void LCISD_server_init(LCI_device_t device, LCIS_server_t* s)

// Create libfabric obj.
FI_SAFECALL(fi_fabric(server->info->fabric_attr, &server->fabric, NULL));

// Create domain.
FI_SAFECALL(fi_domain(server->fabric, server->info, &server->domain, NULL));

server->endpoint_count = 0;
}

void LCISD_server_fina(LCIS_server_t s)
{
LCISI_server_t* server = (LCISI_server_t*)s;
LCI_Assert(server->endpoint_count == 0, "Endpoint count is not zero (%d)\n",
server->endpoint_count);
FI_SAFECALL(fi_close((struct fid*)&server->domain->fid));
FI_SAFECALL(fi_close((struct fid*)&server->fabric->fid));
fi_freeinfo(server->info);
free(s);
Expand All @@ -139,32 +134,24 @@ void LCISD_endpoint_init(LCIS_server_t server_pp, LCIS_endpoint_t* endpoint_pp,
LCISI_endpoint_t* endpoint_p = LCIU_malloc(sizeof(LCISI_endpoint_t));
*endpoint_pp = (LCIS_endpoint_t)endpoint_p;
endpoint_p->server = (LCISI_server_t*)server_pp;
endpoint_p->server->endpoints[endpoint_p->server->endpoint_count++] =
endpoint_p;
endpoint_p->is_single_threaded = single_threaded;
if (!LCI_OFI_CXI_TRY_NO_HACK &&
strcmp(endpoint_p->server->info->fabric_attr->prov_name, "cxi") == 0 &&
endpoint_p->server->info->domain_attr->mr_mode & FI_MR_ENDPOINT &&
endpoint_p->server->endpoint_count > 1) {
// We are using more than one endpoint per server, but the cxi provider
// can only bind mr to one endpoint. We have to guess here.
endpoint_p->server->cxi_mr_bind_hack = true;
} else {
endpoint_p->server->cxi_mr_bind_hack = false;
}

// Create domain.
FI_SAFECALL(fi_domain(endpoint_p->server->fabric, endpoint_p->server->info,
&endpoint_p->domain, NULL));

// Create endpoint.
endpoint_p->server->info->tx_attr->size = LCI_SERVER_MAX_SENDS;
endpoint_p->server->info->rx_attr->size = LCI_SERVER_MAX_RECVS;
FI_SAFECALL(fi_endpoint(endpoint_p->server->domain, endpoint_p->server->info,
FI_SAFECALL(fi_endpoint(endpoint_p->domain, endpoint_p->server->info,
&endpoint_p->ep, NULL));

// Create cq.
struct fi_cq_attr cq_attr;
memset(&cq_attr, 0, sizeof(struct fi_cq_attr));
cq_attr.format = FI_CQ_FORMAT_DATA;
cq_attr.size = LCI_SERVER_MAX_CQES;
FI_SAFECALL(
fi_cq_open(endpoint_p->server->domain, &cq_attr, &endpoint_p->cq, NULL));
FI_SAFECALL(fi_cq_open(endpoint_p->domain, &cq_attr, &endpoint_p->cq, NULL));

// Bind my ep to cq.
FI_SAFECALL(
Expand All @@ -173,8 +160,7 @@ void LCISD_endpoint_init(LCIS_server_t server_pp, LCIS_endpoint_t* endpoint_pp,
struct fi_av_attr av_attr;
memset(&av_attr, 0, sizeof(av_attr));
av_attr.type = FI_AV_MAP;
FI_SAFECALL(
fi_av_open(endpoint_p->server->domain, &av_attr, &endpoint_p->av, NULL));
FI_SAFECALL(fi_av_open(endpoint_p->domain, &av_attr, &endpoint_p->av, NULL));
FI_SAFECALL(fi_ep_bind(endpoint_p->ep, (fid_t)endpoint_p->av, 0));
FI_SAFECALL(fi_enable(endpoint_p->ep));

Expand Down Expand Up @@ -224,11 +210,8 @@ void LCISD_endpoint_fina(LCIS_endpoint_t endpoint_pp)
LCT_pmi_barrier();
LCISI_endpoint_t* endpoint_p = (LCISI_endpoint_t*)endpoint_pp;
LCIU_free(endpoint_p->peer_addrs);
int my_idx = --endpoint_p->server->endpoint_count;
LCI_Assert(endpoint_p->server->endpoints[my_idx] == endpoint_p,
"This is not me!\n");
endpoint_p->server->endpoints[my_idx] = NULL;
FI_SAFECALL(fi_close((struct fid*)&endpoint_p->ep->fid));
FI_SAFECALL(fi_close((struct fid*)&endpoint_p->cq->fid));
FI_SAFECALL(fi_close((struct fid*)&endpoint_p->av->fid));
FI_SAFECALL(fi_close((struct fid*)&endpoint_p->domain->fid));
}
59 changes: 12 additions & 47 deletions lci/backend/ofi/server_ofi.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,14 @@
struct LCISI_endpoint_t;

typedef struct __attribute__((aligned(LCI_CACHE_LINE))) LCISI_server_t {
LCI_device_t device;
struct fi_info* info;
struct fid_fabric* fabric;
struct fid_domain* domain;
struct LCISI_endpoint_t* endpoints[LCI_SERVER_MAX_ENDPOINTS];
int endpoint_count;
bool cxi_mr_bind_hack;
} LCISI_server_t;

typedef struct __attribute__((aligned(LCI_CACHE_LINE))) LCISI_endpoint_t {
struct LCISI_endpoint_super_t super;
LCISI_server_t* server;
struct fid_domain* domain;
struct fid_ep* ep;
struct fid_cq* cq;
struct fid_av* av;
Expand All @@ -57,32 +53,23 @@ typedef struct __attribute__((aligned(LCI_CACHE_LINE))) LCISI_endpoint_t {

extern int g_next_rdma_key;

static inline void* LCISI_real_server_reg(LCIS_server_t s, void* buf,
size_t size)
static inline void* LCISI_real_server_reg(LCIS_endpoint_t endpoint_pp,
void* buf, size_t size)
{
LCISI_server_t* server = (LCISI_server_t*)s;
LCISI_endpoint_t* endpoint_p = (LCISI_endpoint_t*)endpoint_pp;
LCISI_server_t* server = endpoint_p->server;
int rdma_key;
if (server->info->domain_attr->mr_mode & FI_MR_PROV_KEY) {
rdma_key = 0;
} else {
rdma_key = __sync_fetch_and_add(&g_next_rdma_key, 1);
}
struct fid_mr* mr;
FI_SAFECALL(fi_mr_reg(server->domain, buf, size,
FI_SAFECALL(fi_mr_reg(endpoint_p->domain, buf, size,
FI_READ | FI_WRITE | FI_REMOTE_WRITE, 0, rdma_key, 0,
&mr, 0));
if (server->info->domain_attr->mr_mode & FI_MR_ENDPOINT) {
LCI_DBG_Assert(server->endpoint_count >= 1, "No endpoints available!\n");
if (server->cxi_mr_bind_hack) {
// A temporary fix for the cxi provider, currently cxi cannot bind a
// memory region to more than one endpoint.
FI_SAFECALL(fi_mr_bind(
mr, &server->endpoints[server->endpoint_count - 1]->ep->fid, 0));
} else {
for (int i = 0; i < server->endpoint_count; ++i) {
FI_SAFECALL(fi_mr_bind(mr, &server->endpoints[i]->ep->fid, 0));
}
}
FI_SAFECALL(fi_mr_bind(mr, &endpoint_p->ep->fid, 0));
FI_SAFECALL(fi_mr_enable(mr));
}
return (void*)mr;
Expand All @@ -94,10 +81,11 @@ static inline void LCISI_real_server_dereg(void* mr_opaque)
FI_SAFECALL(fi_close((struct fid*)&mr->fid));
}

static inline LCIS_mr_t LCISD_rma_reg(LCIS_server_t s, void* buf, size_t size)
static inline LCIS_mr_t LCISD_rma_reg(LCIS_endpoint_t endpoint_pp, void* buf,
size_t size)
{
LCIS_mr_t mr;
mr.mr_p = LCISI_real_server_reg(s, buf, size);
mr.mr_p = LCISI_real_server_reg(endpoint_pp, buf, size);
mr.address = buf;
mr.length = size;
return mr;
Expand Down Expand Up @@ -132,8 +120,9 @@ static inline int LCISD_poll_cq(LCIS_endpoint_t endpoint_pp,
ne = fi_cq_read(endpoint_p->cq, &fi_entry, LCI_CQ_MAX_POLL);
LCISI_OFI_CS_EXIT(endpoint_p, LCI_BACKEND_TRY_LOCK_POLL)
ret = ne;
LCII_PCOUNTER_ADD(net_poll_cq_calls, 1);
if (ne > 0) {
LCII_PCOUNTER_ADD(net_poll_cq_num, ne);
LCII_PCOUNTER_ADD(net_poll_cq_entry_count, ne);
// Got an entry here
for (int i = 0; i < ne; i++) {
if (fi_entry[i].flags & FI_RECV) {
Expand Down Expand Up @@ -240,12 +229,6 @@ static inline LCI_error_t LCISD_post_puts(LCIS_endpoint_t endpoint_pp, int rank,
LCIS_rkey_t rkey)
{
LCISI_endpoint_t* endpoint_p = (LCISI_endpoint_t*)endpoint_pp;
LCI_Assert(
!endpoint_p->server->cxi_mr_bind_hack ||
endpoint_p == endpoint_p->server
->endpoints[endpoint_p->server->endpoint_count - 1],
"We are using cxi mr_bind hacking mode but unexpected endpoint is "
"performing remote put. Try `export LCI_ENABLE_PRG_NET_ENDPOINT=0`.\n");
uintptr_t addr;
if (endpoint_p->server->info->domain_attr->mr_mode & FI_MR_VIRT_ADDR ||
endpoint_p->server->info->domain_attr->mr_mode & FI_MR_BASIC) {
Expand Down Expand Up @@ -292,12 +275,6 @@ static inline LCI_error_t LCISD_post_put(LCIS_endpoint_t endpoint_pp, int rank,
LCIS_rkey_t rkey, void* ctx)
{
LCISI_endpoint_t* endpoint_p = (LCISI_endpoint_t*)endpoint_pp;
LCI_Assert(
!endpoint_p->server->cxi_mr_bind_hack ||
endpoint_p == endpoint_p->server
->endpoints[endpoint_p->server->endpoint_count - 1],
"We are using cxi mr_bind hacking mode but an unexpected endpoint is "
"performing remote put. Try `export LCI_ENABLE_PRG_NET_ENDPOINT=0`.\n");
uintptr_t addr;
if (endpoint_p->server->info->domain_attr->mr_mode & FI_MR_VIRT_ADDR ||
endpoint_p->server->info->domain_attr->mr_mode & FI_MR_BASIC) {
Expand Down Expand Up @@ -345,12 +322,6 @@ static inline LCI_error_t LCISD_post_putImms(LCIS_endpoint_t endpoint_pp,
LCIS_rkey_t rkey, uint32_t meta)
{
LCISI_endpoint_t* endpoint_p = (LCISI_endpoint_t*)endpoint_pp;
LCI_Assert(
!endpoint_p->server->cxi_mr_bind_hack ||
endpoint_p == endpoint_p->server
->endpoints[endpoint_p->server->endpoint_count - 1],
"We are using cxi mr_bind hacking mode but an unexpected endpoint is "
"performing remote put. Try `export LCI_ENABLE_PRG_NET_ENDPOINT=0`.\n");
uintptr_t addr;
if (endpoint_p->server->info->domain_attr->mr_mode & FI_MR_VIRT_ADDR ||
endpoint_p->server->info->domain_attr->mr_mode & FI_MR_BASIC) {
Expand Down Expand Up @@ -381,12 +352,6 @@ static inline LCI_error_t LCISD_post_putImm(LCIS_endpoint_t endpoint_pp,
void* ctx)
{
LCISI_endpoint_t* endpoint_p = (LCISI_endpoint_t*)endpoint_pp;
LCI_Assert(
!endpoint_p->server->cxi_mr_bind_hack ||
endpoint_p == endpoint_p->server
->endpoints[endpoint_p->server->endpoint_count - 1],
"We are using cxi mr_bind hacking mode but an unexpected endpoint is "
"performing remote put. Try `export LCI_ENABLE_PRG_NET_ENDPOINT=0`.\n");
uintptr_t addr;
if (endpoint_p->server->info->domain_attr->mr_mode & FI_MR_VIRT_ADDR ||
endpoint_p->server->info->domain_attr->mr_mode & FI_MR_BASIC) {
Expand Down
16 changes: 8 additions & 8 deletions lci/backend/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,10 @@ static inline void LCIS_serve_send(void* ctx);

/* The following functions must be implemented by each server backend. */

void LCISD_server_init(LCI_device_t device, LCIS_server_t* s);
void LCISD_server_init(LCIS_server_t* s);
void LCISD_server_fina(LCIS_server_t s);
static inline LCIS_mr_t LCISD_rma_reg(LCIS_server_t s, void* buf, size_t size);
static inline LCIS_mr_t LCISD_rma_reg(LCIS_endpoint_t endpoint_pp, void* buf,
size_t size);
static inline void LCISD_rma_dereg(LCIS_mr_t mr);
static inline LCIS_rkey_t LCISD_rma_rkey(LCIS_mr_t mr);

Expand Down Expand Up @@ -116,10 +117,7 @@ static inline LCI_error_t LCISD_post_recv(LCIS_endpoint_t endpoint_pp,
LCIU_release_spinlock(&LCIS_endpoint_super(endpoint).lock);

/* Wrapper functions */
static inline void LCIS_server_init(LCI_device_t device, LCIS_server_t* s)
{
LCISD_server_init(device, s);
}
static inline void LCIS_server_init(LCIS_server_t* s) { LCISD_server_init(s); }

/* Wrapper: hands the server handle straight to the backend-specific
 * LCISD_server_fina() implementation. */
static inline void LCIS_server_fina(LCIS_server_t s)
{
  LCISD_server_fina(s);
}

Expand All @@ -128,10 +126,11 @@ static inline LCIS_rkey_t LCIS_rma_rkey(LCIS_mr_t mr)
return LCISD_rma_rkey(mr);
}

static inline LCIS_mr_t LCIS_rma_reg(LCIS_server_t s, void* buf, size_t size)
static inline LCIS_mr_t LCIS_rma_reg(LCIS_endpoint_t endpoint_pp, void* buf,
size_t size)
{
LCII_PCOUNTER_START(net_mem_reg_timer);
LCIS_mr_t mr = LCISD_rma_reg(s, buf, size);
LCIS_mr_t mr = LCISD_rma_reg(endpoint_pp, buf, size);
LCII_PCOUNTER_END(net_mem_reg_timer);
LCI_DBG_Log(LCI_LOG_TRACE, "server-reg",
"LCIS_rma_reg: mr %p buf %p size %lu rkey %lu\n", mr.mr_p, buf,
Expand Down Expand Up @@ -166,6 +165,7 @@ static inline void LCIS_endpoint_fina(LCIS_endpoint_t endpoint_pp)
static inline int LCIS_poll_cq(LCIS_endpoint_t endpoint_pp,
LCIS_cq_entry_t* entry)
{
LCII_PCOUNTER_ADD(net_poll_cq_attempts, 1);
LCISI_CS_ENTER(endpoint_pp, 0);
int ret = LCISD_poll_cq(endpoint_pp, entry);
LCISI_CS_EXIT(endpoint_pp);
Expand Down
4 changes: 3 additions & 1 deletion lci/profile/performance_counter.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ extern LCT_pcounter_ctx_t LCII_pcounter_ctx;
_macro(net_send_failed_lock) \
_macro(net_send_failed_nomem) \
_macro(net_recv_failed_nopacket) \
_macro(net_poll_cq_num) \
_macro(net_poll_cq_attempts) \
_macro(net_poll_cq_calls) \
_macro(net_poll_cq_entry_count) \
_macro(progress_call) \
_macro(packet_get) \
_macro(packet_put) \
Expand Down
Loading
Loading