diff --git a/src/ucp/core/ucp_context.c b/src/ucp/core/ucp_context.c index a10dd627e38..b575e007cb4 100644 --- a/src/ucp/core/ucp_context.c +++ b/src/ucp/core/ucp_context.c @@ -150,6 +150,10 @@ static ucs_config_field_t ucp_config_table[] = { "Also the value has to be bigger than UCX_TM_THRESH to take an effect." , ucs_offsetof(ucp_config_t, ctx.tm_max_bcopy), UCS_CONFIG_TYPE_MEMUNITS}, + {"RNDV_FRAG_SIZE", "65536", + "RNDV fragment size \n", + ucs_offsetof(ucp_config_t, ctx.rndv_frag_size), UCS_CONFIG_TYPE_MEMUNITS}, + {NULL} }; diff --git a/src/ucp/core/ucp_context.h b/src/ucp/core/ucp_context.h index 827f486ef87..4e848c2c618 100644 --- a/src/ucp/core/ucp_context.h +++ b/src/ucp/core/ucp_context.h @@ -51,6 +51,8 @@ typedef struct ucp_context_config { ucp_atomic_mode_t atomic_mode; /** If use mutex for MT support or not */ int use_mt_mutex; + /** RNDV pipeline fragment size */ + size_t rndv_frag_size; /** On-demand progress */ int adaptive_progress; } ucp_context_config_t; diff --git a/src/ucp/core/ucp_worker.c b/src/ucp/core/ucp_worker.c index 4f3270ae5d6..f5e1cf5f730 100644 --- a/src/ucp/core/ucp_worker.c +++ b/src/ucp/core/ucp_worker.c @@ -36,6 +36,88 @@ static ucs_stats_class_t ucp_worker_stats_class = { #endif +static ucs_status_t ucp_mpool_dereg_mds(ucp_context_h context, ucp_mem_h memh) { + unsigned md_index, uct_index; + ucs_status_t status; + + uct_index = 0; + + for (md_index = 0; md_index < context->num_mds; ++md_index) { + if (!(memh->md_map & UCS_BIT(md_index))) { + continue; + } + + status = uct_md_mem_dereg(context->tl_mds[md_index].md, + memh->uct[uct_index]); + if (status != UCS_OK) { + ucs_error("Failed to dereg address %p with md %s", memh->address, + context->tl_mds[md_index].rsc.md_name); + return status; + } + + ++uct_index; + } + + return UCS_OK; +} + +static ucs_status_t ucp_mpool_reg_mds(ucp_context_h context, ucp_mem_h memh) { + unsigned md_index, uct_memh_count; + ucs_status_t status; + + uct_memh_count = 0; + memh->md_map = 0; + + for (md_index = 0; md_index < context->num_mds; ++md_index) { + if (context->tl_mds[md_index].attr.cap.flags & UCT_MD_FLAG_REG) { + status = uct_md_mem_reg(context->tl_mds[md_index].md, memh->address, + memh->length, 0, memh->uct[uct_memh_count]); + if (status != UCS_OK) { + ucs_error("Failed to register memory pool chunk %p with md %s", + memh->address, context->tl_mds[md_index].rsc.md_name); + return status; + } + + memh->md_map |= UCS_BIT(md_index); + uct_memh_count++; + } + } + + return UCS_OK; +} + + +static ucs_status_t ucp_mpool_rndv_malloc(ucs_mpool_t *mp, size_t *size_p, void **chunk_p) { + ucp_worker_h worker = ucs_container_of(mp, ucp_worker_t, reg_mp); + ucp_mem_desc_t *chunk_hdr; + ucs_status_t status; + + status = ucp_mpool_malloc(mp, size_p, chunk_p); + if (status != UCS_OK) { + ucs_error("Failed to allocate memory pool chunk: %s", ucs_status_string(status)); + return UCS_ERR_NO_MEMORY; + } + + chunk_hdr = (ucp_mem_desc_t *)(*chunk_p) - 1; + + status = ucp_mpool_reg_mds(worker->context, chunk_hdr->memh); + if (status != UCS_OK) { + ucp_mpool_dereg_mds(worker->context, chunk_hdr->memh); + return status; + } + + return UCS_OK; +} + + +static void ucp_mpool_rndv_free(ucs_mpool_t *mp, void *chunk) { + ucp_worker_h worker = ucs_container_of(mp, ucp_worker_t, reg_mp); + ucp_mem_desc_t *chunk_hdr = (ucp_mem_desc_t *)chunk - 1; + ucp_mpool_dereg_mds(worker->context, chunk_hdr->memh); + ucp_mpool_free(mp, chunk); +} + + ucs_mpool_ops_t ucp_am_mpool_ops = { .chunk_alloc = ucs_mpool_hugetlb_malloc, .chunk_release = ucs_mpool_hugetlb_free, @@ -52,6 +134,14 @@ ucs_mpool_ops_t ucp_reg_mpool_ops = { }; +ucs_mpool_ops_t ucp_rndv_frag_mpool_ops = { + .chunk_alloc = ucp_mpool_rndv_malloc, + .chunk_release = ucp_mpool_rndv_free, + .obj_init = ucs_empty_function, + .obj_cleanup = ucs_empty_function +}; + + void ucp_worker_iface_check_events(ucp_worker_iface_t *wiface, int force); @@ -909,8 +999,19 @@ static ucs_status_t ucp_worker_init_mpools(ucp_worker_h worker, goto err_release_am_mpool; } + + status = ucs_mpool_init(&worker->rndv_frag_mp, 0, + context->config.ext.rndv_frag_size, + 0, 128, 128, UINT_MAX, + &ucp_rndv_frag_mpool_ops, "ucp_rndv_frags"); + if (status != UCS_OK) { + goto err_release_reg_mpool; + } + return UCS_OK; +err_release_reg_mpool: + ucs_mpool_cleanup(&worker->reg_mp, 0); err_release_am_mpool: ucs_mpool_cleanup(&worker->am_mp, 0); out: @@ -1120,6 +1221,7 @@ void ucp_worker_destroy(ucp_worker_h worker) ucp_worker_destroy_eps(worker); ucs_mpool_cleanup(&worker->am_mp, 1); ucs_mpool_cleanup(&worker->reg_mp, 1); + ucs_mpool_cleanup(&worker->rndv_frag_mp, 1); ucp_worker_close_ifaces(worker); ucp_worker_wakeup_cleanup(worker); ucs_mpool_cleanup(&worker->req_mp, 1); diff --git a/src/ucp/core/ucp_worker.h b/src/ucp/core/ucp_worker.h index 916755910a5..202e5ffcd99 100644 --- a/src/ucp/core/ucp_worker.h +++ b/src/ucp/core/ucp_worker.h @@ -147,6 +147,7 @@ typedef struct ucp_worker { ucs_mpool_t am_mp; /* Memory pool for AM receives */ ucs_mpool_t reg_mp; /* Registered memory pool */ ucp_mt_lock_t mt_lock; /* Configuration of multi-threading support */ + ucs_mpool_t rndv_frag_mp; /* Memory pool for RNDV fragments */ UCS_STATS_NODE_DECLARE(stats); diff --git a/src/uct/cuda/cuda_copy/cuda_copy_md.c b/src/uct/cuda/cuda_copy/cuda_copy_md.c index 17cc88eb225..33b3b3982a7 100644 --- a/src/uct/cuda/cuda_copy/cuda_copy_md.c +++ b/src/uct/cuda/cuda_copy/cuda_copy_md.c @@ -25,7 +25,7 @@ static ucs_config_field_t uct_cuda_copy_md_config_table[] = { static ucs_status_t uct_cuda_copy_md_query(uct_md_h md, uct_md_attr_t *md_attr) { - md_attr->cap.flags = UCT_MD_FLAG_REG | UCT_MD_FLAG_RNDV_REG | UCT_MD_FLAG_ADDR_DN; + md_attr->cap.flags = UCT_MD_FLAG_REG | UCT_MD_FLAG_ADDR_DN; md_attr->cap.addr_dn = UCT_MD_ADDR_DOMAIN_CUDA; md_attr->cap.max_alloc = 0; md_attr->cap.max_reg = ULONG_MAX; diff --git a/src/uct/ib/base/ib_md.c b/src/uct/ib/base/ib_md.c index 7afc1184962..161ed70ca46 100644 --- a/src/uct/ib/base/ib_md.c +++ b/src/uct/ib/base/ib_md.c @@ -155,7 +155,6 @@ static ucs_status_t uct_ib_md_query(uct_md_h uct_md, uct_md_attr_t *md_attr) md_attr->cap.max_alloc = ULONG_MAX; /* TODO query device */ md_attr->cap.max_reg = ULONG_MAX; /* TODO query device */ md_attr->cap.flags = UCT_MD_FLAG_REG | - UCT_MD_FLAG_RNDV_REG | UCT_MD_FLAG_NEED_MEMH | UCT_MD_FLAG_NEED_RKEY | UCT_MD_FLAG_ADVISE;