diff --git a/prov/ucx/src/ucx.h b/prov/ucx/src/ucx.h index 86b91fa3319..91729a591ef 100644 --- a/prov/ucx/src/ucx.h +++ b/prov/ucx/src/ucx.h @@ -58,13 +58,11 @@ extern "C" { #include #include #include -#include "ofi_enosys.h" -#include -#include +#include #include #include #include -#include +#include #include #include @@ -90,7 +88,7 @@ extern "C" { #define FI_UCX_RMA_CAPS (FI_RMA | FI_READ | FI_WRITE | FI_REMOTE_READ |\ FI_REMOTE_WRITE) #define FI_UCX_CAPS (FI_SEND | FI_RECV | FI_TAGGED | FI_MSG | FI_MULTI_RECV |\ - FI_UCX_RMA_CAPS | FI_UCX_DOM_CAPS) + FI_UCX_RMA_CAPS | FI_UCX_DOM_CAPS | FI_HMEM) #define FI_UCX_MODE (0ULL) #define FI_UCX_TX_FLAGS (FI_COMPLETION | FI_INJECT) diff --git a/prov/ucx/src/ucx_domain.c b/prov/ucx/src/ucx_domain.c index 08a45b4c568..9549c2bf27e 100644 --- a/prov/ucx/src/ucx_domain.c +++ b/prov/ucx/src/ucx_domain.c @@ -192,6 +192,38 @@ static struct fi_ops ucx_mr_fi_ops = { .ops_open = fi_no_ops_open, }; +static inline void ucx_update_memtype_cache(enum fi_hmem_iface iface, + void *addr, size_t len) +{ + void *base; + size_t size; + ucs_memory_info_t mem_info; + + if (iface == FI_HMEM_SYSTEM) + return; + + if (ofi_hmem_get_base_addr(iface, addr, len, &base, &size)) + return; + + if (ucs_memtype_cache_lookup(base, size, &mem_info) != UCS_ERR_NO_ELEM) + return; + + /* + * Now we know that the address range is not in the memtype cache. The + * reason could be: + * (1) No hook has been installed for the memory allocator being used; + * (2) The hook has not been installed properly. For example, when the + * allocator is obtained via dlsym(); + * (3) The memory is allocated before the hooks are installed. + * + * We only need to add the range to the memtype cache. The exact value + * of memtype will be updated next time a lookup is performed within + * this range. + */ + ucs_memtype_cache_update(base, size, UCS_MEMORY_TYPE_UNKNOWN, + UCS_SYS_DEVICE_ID_UNKNOWN); +} + static int ucx_mr_regattr(struct fid *fid, const struct fi_mr_attr *attr, uint64_t flags, struct fid_mr **mr_fid) { @@ -207,6 +239,9 @@ static int ucx_mr_regattr(struct fid *fid, const struct fi_mr_attr *attr, if (fid->fclass != FI_CLASS_DOMAIN || !attr || attr->iov_count <= 0) return -FI_EINVAL; + if (flags & FI_MR_DMABUF) + return -FI_EINVAL; + domain = container_of(fid, struct util_domain, domain_fid.fid); m_domain = container_of(domain, struct ucx_domain, u_domain); ucx_mr = calloc(1, sizeof(*ucx_mr)); @@ -229,6 +264,9 @@ static int ucx_mr_regattr(struct fid *fid, const struct fi_mr_attr *attr, um_params.length = attr->mr_iov->iov_len; um_params.flags = 0; + ucx_update_memtype_cache(attr->iface, attr->mr_iov->iov_base, + attr->mr_iov->iov_len); + status = ucp_mem_map(m_domain->context, &um_params, &ucx_mr->memh); if (status != UCS_OK) { ret = ucx_translate_errcode(status); @@ -267,6 +305,7 @@ static int ucx_mr_regv(struct fid *fid, const struct iovec *iov, attr.offset = offset; attr.requested_key = requested_key; attr.context = context; + attr.iface = FI_HMEM_SYSTEM; return ucx_mr_regattr(fid, &attr, flags, mr_fid); } diff --git a/prov/ucx/src/ucx_init.c b/prov/ucx/src/ucx_init.c index 2d6bbf7de3c..ec246fe8d0e 100644 --- a/prov/ucx/src/ucx_init.c +++ b/prov/ucx/src/ucx_init.c @@ -364,6 +364,10 @@ static int ucx_getinfo(uint32_t version, const char *node, if ((*info)->nic) (*info)->nic->link_attr->speed = (size_t) speed_gbps * 1000 * 1000 * 1000; + + if (hints && hints->domain_attr && + (hints->domain_attr->mr_mode & FI_MR_HMEM)) + (*info)->domain_attr->mr_mode |= FI_MR_HMEM; } /* make sure the memery hooks are installed for memory type cache */ @@ -374,6 +378,10 @@ static int ucx_getinfo(uint32_t version, const char *node, static void ucx_cleanup(void) { +#if HAVE_UCX_DL + ofi_hmem_cleanup(); +#endif + FI_DBG(&ucx_prov, FI_LOG_CORE, "provider goes cleanup sequence\n"); if (ucx_descriptor.config) { ucp_config_release(ucx_descriptor.config); @@ -392,6 +400,10 @@ struct fi_provider ucx_prov = { UCX_INI { +#if HAVE_UCX_DL + ofi_hmem_init(); +#endif + ucx_init_errcodes(); fi_param_define(&ucx_prov,