diff --git a/Makefile.am b/Makefile.am index 2b71edfe515..ad91ddd8e60 100644 --- a/Makefile.am +++ b/Makefile.am @@ -88,6 +88,7 @@ common_srcs = \ prov/util/src/ze_mem_monitor.c \ prov/util/src/cuda_ipc_monitor.c \ prov/util/src/rocr_ipc_monitor.c \ + prov/util/src/ze_ipc_monitor.c \ prov/util/src/xpmem_monitor.c \ prov/util/src/util_profile.c \ prov/coll/src/coll_attr.c \ diff --git a/include/linux/osd.h b/include/linux/osd.h index 9662be6c31a..5b8d0fcd4ee 100644 --- a/include/linux/osd.h +++ b/include/linux/osd.h @@ -100,6 +100,14 @@ size_t ofi_ifaddr_get_speed(struct ifaddrs *ifa); # define __NR_process_vm_writev 311 #endif +#ifndef __NR_pidfd_open +# define __NR_pidfd_open 434 +#endif + +#ifndef __NR_pidfd_getfd +# define __NR_pidfd_getfd 438 +#endif + static inline ssize_t ofi_process_vm_readv(pid_t pid, const struct iovec *local_iov, unsigned long liovcnt, @@ -122,6 +130,16 @@ static inline size_t ofi_process_vm_writev(pid_t pid, remote_iov, riovcnt, flags); } +static inline int ofi_pidfd_open(pid_t pid, unsigned int flags) +{ + return syscall(__NR_pidfd_open, pid, flags); +} + +static inline int ofi_pidfd_getfd(int pidfd, int targetfd, unsigned int flags) +{ + return syscall(__NR_pidfd_getfd, pidfd, targetfd, flags); +} + static inline ssize_t ofi_read_socket(SOCKET fd, void *buf, size_t count) { return read(fd, buf, count); diff --git a/include/ofi_hmem.h b/include/ofi_hmem.h index 4b3bc0c7b13..9db6d94cd70 100644 --- a/include/ofi_hmem.h +++ b/include/ofi_hmem.h @@ -204,24 +204,34 @@ int cuda_gdrcopy_dev_register(const void *buf, size_t len, uint64_t *handle); int cuda_gdrcopy_dev_unregister(uint64_t handle); int cuda_set_sync_memops(void *ptr); +/* + * ze handle is just an fd so we need a wrapper to contain extra fields for ipc + * copies + * get_fd fd from handle_get + * open_fd fd from pidfd_get_fd. Needed for closing extra fd and for + * opening with open_handle + * pid_fd fd that references a pid used to open an fd across processes + */ +struct ze_pid_handle { + int get_fd; + int open_fd; + int pid_fd; +}; + int ze_hmem_copy(uint64_t device, void *dst, const void *src, size_t size); int ze_hmem_init(void); int ze_hmem_cleanup(void); bool ze_hmem_is_addr_valid(const void *addr, uint64_t *device, uint64_t *flags); int ze_hmem_get_handle(void *dev_buf, size_t size, void **handle); +void ze_set_pid_fd(void **handle, int pid_fd); int ze_hmem_open_handle(void **handle, size_t size, uint64_t device, void **ipc_ptr); -int ze_hmem_get_shared_handle(uint64_t device, void *dev_buf, int *ze_fd, - void **handle); -int ze_hmem_open_shared_handle(uint64_t device, int *peer_fds, void **handle, - int *ze_fd, void **ipc_ptr); int ze_hmem_close_handle(void *ipc_ptr, void **handle); bool ze_hmem_p2p_enabled(void); int ze_hmem_get_ipc_handle_size(size_t *size); int ze_hmem_get_base_addr(const void *ptr, size_t len, void **base, size_t *size); int ze_hmem_get_id(const void *ptr, uint64_t *id); -int *ze_hmem_get_dev_fds(int *nfds); int ze_hmem_host_register(void *ptr, size_t size); int ze_hmem_host_unregister(void *ptr); int ze_dev_register(const void *addr, size_t size, uint64_t *handle); diff --git a/include/ofi_mr.h b/include/ofi_mr.h index a357121beaf..876a97a8223 100644 --- a/include/ofi_mr.h +++ b/include/ofi_mr.h @@ -233,6 +233,7 @@ extern struct ofi_mem_monitor *cuda_ipc_monitor; extern struct ofi_mem_monitor *rocr_monitor; extern struct ofi_mem_monitor *rocr_ipc_monitor; extern struct ofi_mem_monitor *ze_monitor; +extern struct ofi_mem_monitor *ze_ipc_monitor; extern struct ofi_mem_monitor *import_monitor; extern struct ofi_mem_monitor *xpmem_monitor; diff --git a/libfabric.vcxproj b/libfabric.vcxproj index d297285781b..48209c6a344 100644 --- a/libfabric.vcxproj +++ b/libfabric.vcxproj @@ -755,6 +755,7 @@ + diff --git a/prov/shm/src/smr.h b/prov/shm/src/smr.h index a393d7038d6..d8385d335dc 100644 --- a/prov/shm/src/smr.h +++ b/prov/shm/src/smr.h @@ -124,7 +124,6 @@ struct smr_tx_entry { void *map_ptr; struct smr_ep_name *map_name; struct ofi_mr *mr[SMR_IOV_LIMIT]; - int fd; }; struct smr_pend_entry { @@ -165,8 +164,6 @@ struct smr_domain { #define SMR_PREFIX "fi_shm://" #define SMR_PREFIX_NS "fi_ns://" -#define SMR_ZE_SOCK_PATH "/dev/shm/ze_" - #define SMR_RMA_ORDER (OFI_ORDER_RAR_SET | OFI_ORDER_RAW_SET | FI_ORDER_RAS | \ OFI_ORDER_WAR_SET | OFI_ORDER_WAW_SET | FI_ORDER_WAS | \ FI_ORDER_SAR | FI_ORDER_SAW) @@ -282,7 +279,8 @@ size_t smr_copy_from_sar(struct smr_freestack *sar_pool, struct smr_resp *resp, const struct iovec *iov, size_t count, size_t *bytes_done); int smr_select_proto(void **desc, size_t iov_count, bool cma_avail, - uint32_t op, uint64_t total_len, uint64_t op_flags); + bool ipc_valid, uint32_t op, uint64_t total_len, + uint64_t op_flags); typedef ssize_t (*smr_proto_func)(struct smr_ep *ep, struct smr_region *peer_smr, int64_t id, int64_t peer_id, uint32_t op, uint64_t tag, uint64_t data, uint64_t op_flags, struct ofi_mr **desc, @@ -320,6 +318,20 @@ static inline bool smr_vma_enabled(struct smr_ep *ep, peer_smr->xpmem_cap_self == SMR_VMA_CAP_ON); } +static inline void smr_set_ipc_valid(struct smr_region *region, uint64_t id) +{ + if (ofi_hmem_initialized(FI_HMEM_ZE) && + region->map->peers[id].pid_fd == -1) + smr_peer_data(region)[id].ipc_invalid = 1; +} + +static inline bool smr_ipc_valid(struct smr_ep *ep, struct smr_region *peer_smr, + int64_t id, int64_t peer_id) +{ + return (!smr_peer_data(ep->region)[id].ipc_invalid && + !smr_peer_data(peer_smr)[peer_id].ipc_invalid); +} + static inline bool smr_ze_ipc_enabled(struct smr_region *smr, struct smr_region *peer_smr) { diff --git a/prov/shm/src/smr_ep.c b/prov/shm/src/smr_ep.c index 8b2884e3fa4..b6379471a48 100644 --- a/prov/shm/src/smr_ep.c +++ b/prov/shm/src/smr_ep.c @@ -300,44 +300,9 @@ static void smr_format_iov(struct smr_cmd *cmd, const struct iovec *iov, memcpy(cmd->msg.data.iov, iov, sizeof(*iov) * count); } -static int smr_format_ze_ipc(struct smr_ep *ep, int64_t id, struct smr_cmd *cmd, - const struct iovec *iov, uint64_t device, size_t total_len, - struct smr_region *smr, struct smr_resp *resp, - struct smr_tx_entry *pend) -{ - int ret; - void *base; - - cmd->msg.hdr.op_src = smr_src_ipc; - cmd->msg.hdr.src_data = smr_get_offset(smr, resp); - cmd->msg.hdr.size = total_len; - cmd->msg.data.ipc_info.iface = FI_HMEM_ZE; - - if (ep->sock_info->peers[id].state == SMR_CMAP_INIT) - smr_ep_exchange_fds(ep, id); - if (ep->sock_info->peers[id].state != SMR_CMAP_SUCCESS) - return -FI_EAGAIN; - - ret = ze_hmem_get_base_addr(iov[0].iov_base, iov[0].iov_len, &base, - NULL); - if (ret) - return ret; - - ret = ze_hmem_get_shared_handle(device, base, &pend->fd, - (void **) &cmd->msg.data.ipc_info.ipc_handle); - if (ret) - return ret; - - cmd->msg.data.ipc_info.device = device; - cmd->msg.data.ipc_info.offset = (char *) iov[0].iov_base - - (char *) base; - - return FI_SUCCESS; -} - static int smr_format_ipc(struct smr_cmd *cmd, void *ptr, size_t len, - struct smr_region *smr, struct smr_resp *resp, - enum fi_hmem_iface iface, uint64_t device) + struct smr_region *smr, struct smr_resp *resp, + enum fi_hmem_iface iface, uint64_t device) { int ret; void *base; @@ -347,15 +312,16 @@ static int smr_format_ipc(struct smr_cmd *cmd, void *ptr, size_t len, cmd->msg.hdr.size = len; cmd->msg.data.ipc_info.iface = iface; cmd->msg.data.ipc_info.device = device; - ret = ofi_hmem_get_base_addr(cmd->msg.data.ipc_info.iface, ptr, len, - &base, + + ret = ofi_hmem_get_base_addr(cmd->msg.data.ipc_info.iface, ptr, + len, &base, &cmd->msg.data.ipc_info.base_length); if (ret) return ret; ret = ofi_hmem_get_handle(cmd->msg.data.ipc_info.iface, base, - cmd->msg.data.ipc_info.base_length, - (void **)&cmd->msg.data.ipc_info.ipc_handle); + cmd->msg.data.ipc_info.base_length, + (void **)&cmd->msg.data.ipc_info.ipc_handle); if (ret) return ret; @@ -574,8 +540,8 @@ static int smr_format_sar(struct smr_ep *ep, struct smr_cmd *cmd, return FI_SUCCESS; } -int smr_select_proto(void **desc, size_t iov_count, - bool vma_avail, uint32_t op, uint64_t total_len, +int smr_select_proto(void **desc, size_t iov_count, bool vma_avail, + bool ipc_valid, uint32_t op, uint64_t total_len, uint64_t op_flags) { struct ofi_mr *smr_desc; @@ -584,7 +550,7 @@ int smr_select_proto(void **desc, size_t iov_count, /* Do not inline/inject if IPC is available so device to device * transfer may occur if possible. */ - if (iov_count == 1 && desc && desc[0]) { + if (iov_count == 1 && desc && desc[0] && ipc_valid) { smr_desc = (struct ofi_mr *) *desc; iface = smr_desc->iface; use_ipc = ofi_hmem_is_ipc_enabled(iface) && @@ -740,15 +706,8 @@ static ssize_t smr_do_ipc(struct smr_ep *ep, struct smr_region *peer_smr, int64_ smr_generic_format(cmd, peer_id, op, tag, data, op_flags); assert(iov_count == 1 && desc && desc[0]); - if (desc[0]->iface == FI_HMEM_ZE) { - if (smr_ze_ipc_enabled(ep->region, peer_smr)) - ret = smr_format_ze_ipc(ep, id, cmd, iov, - desc[0]->device, total_len, ep->region, - resp, pend); - } else { - ret = smr_format_ipc(cmd, iov[0].iov_base, total_len, ep->region, - resp, desc[0]->iface, desc[0]->device); - } + ret = smr_format_ipc(cmd, iov[0].iov_base, total_len, ep->region, + resp, desc[0]->iface, desc[0]->device); if (ret) { FI_WARN_ONCE(&smr_prov, FI_LOG_EP_CTRL, @@ -1005,114 +964,6 @@ static int smr_recvmsg_fd(int sock, int64_t *peer_id, int *fds, int nfds) return ret; } -static void *smr_start_listener(void *args) -{ - struct smr_ep *ep = (struct smr_ep *) args; - struct sockaddr_un sockaddr; - struct ofi_epollfds_event events[SMR_MAX_PEERS + 1]; - int i, ret, poll_fds, sock = -1; - int *peer_fds = NULL; - socklen_t len = sizeof(sockaddr); - int64_t id, peer_id; - - peer_fds = calloc(ep->sock_info->nfds, sizeof(*peer_fds)); - if (!peer_fds) - goto out; - - ep->region->flags |= SMR_FLAG_IPC_SOCK; - while (1) { - poll_fds = ofi_epoll_wait(ep->sock_info->epollfd, events, - SMR_MAX_PEERS + 1, -1); - - if (poll_fds < 0) { - FI_WARN(&smr_prov, FI_LOG_EP_CTRL, - "epoll error\n"); - continue; - } - - for (i = 0; i < poll_fds; i++) { - if (!events[i].data.ptr) - goto out; - - sock = accept(ep->sock_info->listen_sock, - (struct sockaddr *) &sockaddr, &len); - if (sock < 0) { - FI_WARN(&smr_prov, FI_LOG_EP_CTRL, - "accept error\n"); - continue; - } - - FI_DBG(&smr_prov, FI_LOG_EP_CTRL, - "EP accepted connection request from %s\n", - sockaddr.sun_path); - - ret = smr_recvmsg_fd(sock, &id, peer_fds, - ep->sock_info->nfds); - if (!ret) { - if (!ep->sock_info->peers[id].device_fds) { - ep->sock_info->peers[id].device_fds = - calloc(ep->sock_info->nfds, - sizeof(*ep->sock_info->peers[id].device_fds)); - if (!ep->sock_info->peers[id].device_fds) { - close(sock); - goto out; - } - } - memcpy(ep->sock_info->peers[id].device_fds, - peer_fds, sizeof(*peer_fds) * - ep->sock_info->nfds); - - peer_id = smr_peer_data(ep->region)[id].addr.id; - ret = smr_sendmsg_fd(sock, id, peer_id, - ep->sock_info->my_fds, - ep->sock_info->nfds); - ep->sock_info->peers[id].state = - ret ? SMR_CMAP_FAILED : - SMR_CMAP_SUCCESS; - } - - close(sock); - unlink(sockaddr.sun_path); - } - } -out: - free(peer_fds); - close(ep->sock_info->listen_sock); - unlink(ep->sock_info->name); - return NULL; -} - -static int smr_init_epoll(struct smr_sock_info *sock_info) -{ - int ret; - - ret = ofi_epoll_create(&sock_info->epollfd); - if (ret < 0) - return ret; - - ret = fd_signal_init(&sock_info->signal); - if (ret < 0) - goto err2; - - ret = ofi_epoll_add(sock_info->epollfd, - sock_info->signal.fd[FI_READ_FD], - OFI_EPOLL_IN, NULL); - if (ret != 0) - goto err1; - - ret = ofi_epoll_add(sock_info->epollfd, sock_info->listen_sock, - OFI_EPOLL_IN, sock_info); - if (ret != 0) - goto err1; - - return FI_SUCCESS; -err1: - ofi_epoll_close(sock_info->epollfd); -err2: - fd_signal_free(&sock_info->signal); - return ret; -} - void smr_ep_exchange_fds(struct smr_ep *ep, int64_t id) { struct smr_region *peer_smr = smr_peer_region(ep->region, id); @@ -1137,8 +988,6 @@ void smr_ep_exchange_fds(struct smr_ep *ep, int64_t id) name2 = smr_sock_name(ep->region); } client_sockaddr.sun_family = AF_UNIX; - snprintf(client_sockaddr.sun_path, SMR_SOCK_NAME_MAX, "%s%s:%s", - SMR_ZE_SOCK_PATH, name1, name2); ret = bind(sock, (struct sockaddr *) &client_sockaddr, (socklen_t) sizeof(client_sockaddr)); @@ -1152,8 +1001,6 @@ void smr_ep_exchange_fds(struct smr_ep *ep, int64_t id) } server_sockaddr.sun_family = AF_UNIX; - snprintf(server_sockaddr.sun_path, SMR_SOCK_NAME_MAX, "%s%s", - SMR_ZE_SOCK_PATH, smr_sock_name(peer_smr)); ret = connect(sock, (struct sockaddr *) &server_sockaddr, sizeof(server_sockaddr)); @@ -1189,80 +1036,6 @@ void smr_ep_exchange_fds(struct smr_ep *ep, int64_t id) SMR_CMAP_FAILED : SMR_CMAP_SUCCESS; } -static void smr_init_ipc_socket(struct smr_ep *ep) -{ - struct smr_sock_name *sock_name; - struct sockaddr_un sockaddr = {0}; - int ret; - - ep->sock_info = calloc(1, sizeof(*ep->sock_info)); - if (!ep->sock_info) - goto err_out; - - ep->sock_info->listen_sock = socket(AF_UNIX, SOCK_STREAM, 0); - if (ep->sock_info->listen_sock < 0) - goto free; - - snprintf(smr_sock_name(ep->region), SMR_SOCK_NAME_MAX, - "%ld:%d", (long) ep->region->pid, ep->ep_idx); - - sockaddr.sun_family = AF_UNIX; - snprintf(sockaddr.sun_path, SMR_SOCK_NAME_MAX, - "%s%s", SMR_ZE_SOCK_PATH, smr_sock_name(ep->region)); - - ret = bind(ep->sock_info->listen_sock, (struct sockaddr *) &sockaddr, - (socklen_t) sizeof(sockaddr)); - if (ret) - goto close; - - ret = listen(ep->sock_info->listen_sock, SMR_MAX_PEERS); - if (ret) - goto close; - - FI_DBG(&smr_prov, FI_LOG_EP_CTRL, "EP listening on UNIX socket %s\n", - sockaddr.sun_path); - - ret = smr_init_epoll(ep->sock_info); - if (ret) - goto close; - - sock_name = calloc(1, sizeof(*sock_name)); - if (!sock_name) - goto cleanup; - - memcpy(sock_name->name, sockaddr.sun_path, strlen(sockaddr.sun_path)); - memcpy(ep->sock_info->name, sockaddr.sun_path, - strlen(sockaddr.sun_path)); - - pthread_mutex_lock(&sock_list_lock); - dlist_insert_tail(&sock_name->entry, &sock_name_list); - pthread_mutex_unlock(&sock_list_lock); - - ep->sock_info->my_fds = ze_hmem_get_dev_fds(&ep->sock_info->nfds); - ret = pthread_create(&ep->sock_info->listener_thread, NULL, - &smr_start_listener, ep); - if (ret) - goto remove; - - return; - -remove: - pthread_mutex_lock(&sock_list_lock); - dlist_remove(&sock_name->entry); - pthread_mutex_unlock(&sock_list_lock); - free(sock_name); -cleanup: - smr_cleanup_epoll(ep->sock_info); -close: - close(ep->sock_info->listen_sock); - unlink(sockaddr.sun_path); -free: - smr_free_sock_info(ep); -err_out: - FI_WARN(&smr_prov, FI_LOG_EP_CTRL, "Unable to initialize IPC socket." - "Defaulting to SAR for device transfers\n"); -} - static int smr_discard(struct fi_peer_rx_entry *rx_entry) { struct smr_cmd_ctx *cmd_ctx = rx_entry->peer_context; @@ -1390,10 +1163,6 @@ static int smr_ep_ctrl(struct fid *fid, int command, void *arg) if (ep->util_ep.caps & FI_HMEM || smr_env.disable_cma) { ep->region->cma_cap_peer = SMR_VMA_CAP_OFF; ep->region->cma_cap_self = SMR_VMA_CAP_OFF; - if (ep->util_ep.caps & FI_HMEM) { - if (ze_hmem_p2p_enabled()) - smr_init_ipc_socket(ep); - } } if (ofi_hmem_any_ipc_enabled()) diff --git a/prov/shm/src/smr_msg.c b/prov/shm/src/smr_msg.c index a26425857e7..641645ea5d3 100644 --- a/prov/shm/src/smr_msg.c +++ b/prov/shm/src/smr_msg.c @@ -112,7 +112,8 @@ static ssize_t smr_generic_sendmsg(struct smr_ep *ep, const struct iovec *iov, assert(!(op_flags & FI_INJECT) || total_len <= SMR_INJECT_SIZE); proto = smr_select_proto(desc, iov_count, smr_vma_enabled(ep, peer_smr), - op, total_len, op_flags); + smr_ipc_valid(ep, peer_smr, id, peer_id), op, + total_len, op_flags); ret = smr_proto_ops[proto](ep, peer_smr, id, peer_id, op, tag, data, op_flags, (struct ofi_mr **)desc, iov, iov_count, total_len, diff --git a/prov/shm/src/smr_progress.c b/prov/shm/src/smr_progress.c index e66ebde9f59..759b3ffb313 100644 --- a/prov/shm/src/smr_progress.c +++ b/prov/shm/src/smr_progress.c @@ -533,9 +533,8 @@ static struct smr_pend_entry *smr_progress_ipc(struct smr_cmd *cmd, { struct smr_region *peer_smr; struct smr_resp *resp; - void *base, *ptr; - int64_t id; - int ret, ipc_fd; + void *ptr; + int ret; ssize_t hmem_copy_ret; struct ofi_mr_entry *mr_entry; struct smr_domain *domain; @@ -547,27 +546,18 @@ static struct smr_pend_entry *smr_progress_ipc(struct smr_cmd *cmd, peer_smr = smr_peer_region(ep->region, cmd->msg.hdr.id); resp = smr_get_ptr(peer_smr, cmd->msg.hdr.src_data); + if (cmd->msg.data.ipc_info.iface == FI_HMEM_ZE) + ze_set_pid_fd((void **) &cmd->msg.data.ipc_info.ipc_handle, + ep->region->map->peers[cmd->msg.hdr.id].pid_fd); + //TODO disable IPC if more than 1 interface is initialized - if (cmd->msg.data.ipc_info.iface == FI_HMEM_ZE) { - id = cmd->msg.hdr.id; - ret = ze_hmem_open_shared_handle(cmd->msg.data.ipc_info.device, - ep->sock_info->peers[id].device_fds, - (void **) &cmd->msg.data.ipc_info.ipc_handle, - &ipc_fd, &base); - } else { - ret = ofi_ipc_cache_search(domain->ipc_cache, - cmd->msg.hdr.id, - &cmd->msg.data.ipc_info, - &mr_entry); - } + ret = ofi_ipc_cache_search(domain->ipc_cache, cmd->msg.hdr.id, + &cmd->msg.data.ipc_info, &mr_entry); if (ret) goto out; - if (cmd->msg.data.ipc_info.iface == FI_HMEM_ZE) - ptr = (char *) base + (uintptr_t) cmd->msg.data.ipc_info.offset; - else - ptr = (char *) (uintptr_t) mr_entry->info.mapped_addr + - (uintptr_t) cmd->msg.data.ipc_info.offset; + ptr = (char *) (uintptr_t) mr_entry->info.mapped_addr + + (uintptr_t) cmd->msg.data.ipc_info.offset; if (cmd->msg.data.ipc_info.iface == FI_HMEM_ROCR) { *total_len = 0; @@ -594,14 +584,7 @@ static struct smr_pend_entry *smr_progress_ipc(struct smr_cmd *cmd, iov_count, 0, ptr, cmd->msg.hdr.size); } - if (cmd->msg.data.ipc_info.iface == FI_HMEM_ZE) { - close(ipc_fd); - /* Truncation error takes precedence over close_handle error */ - ret = ofi_hmem_close_handle(cmd->msg.data.ipc_info.iface, base, - NULL); - } else { - ofi_mr_cache_delete(domain->ipc_cache, mr_entry); - } + ofi_mr_cache_delete(domain->ipc_cache, mr_entry); if (hmem_copy_ret < 0) *err = hmem_copy_ret; @@ -906,6 +889,8 @@ static void smr_progress_connreq(struct smr_ep *ep, struct smr_cmd *cmd) smr_map_to_region(&smr_prov, ep->region->map, idx); peer_smr = smr_peer_region(ep->region, idx); } + + smr_set_ipc_valid(ep->region, idx); smr_peer_data(peer_smr)[cmd->msg.hdr.id].addr.id = idx; smr_peer_data(ep->region)[idx].addr.id = cmd->msg.hdr.id; diff --git a/prov/shm/src/smr_rma.c b/prov/shm/src/smr_rma.c index 62ee11d8d81..d1317424a92 100644 --- a/prov/shm/src/smr_rma.c +++ b/prov/shm/src/smr_rma.c @@ -171,7 +171,8 @@ static ssize_t smr_generic_rma(struct smr_ep *ep, const struct iovec *iov, assert(!(op_flags & FI_INJECT) || total_len <= SMR_INJECT_SIZE); proto = smr_select_proto(desc, iov_count, smr_vma_enabled(ep, peer_smr), - op, total_len, op_flags); + smr_ipc_valid(ep, peer_smr, id, peer_id), op, + total_len, op_flags); ret = smr_proto_ops[proto](ep, peer_smr, id, peer_id, op, 0, data, op_flags, (struct ofi_mr **)desc, iov, diff --git a/prov/shm/src/smr_util.c b/prov/shm/src/smr_util.c index 85e76f6bcab..5ddcd4bc512 100644 --- a/prov/shm/src/smr_util.c +++ b/prov/shm/src/smr_util.c @@ -453,6 +453,15 @@ int smr_map_to_region(const struct fi_provider *prov, struct smr_map *map, if (ret) FI_WARN(prov, FI_LOG_EP_CTRL, "unable to register shm with iface\n"); + if (ofi_hmem_is_initialized(FI_HMEM_ZE)) { + peer_buf->pid_fd = ofi_pidfd_open(peer->pid, 0); + if (peer_buf->pid_fd < 0) { + FI_WARN(prov, FI_LOG_EP_CTRL, + "unable to open pidfd\n"); + } + } else { + peer_buf->pid_fd = -1; + } } out: @@ -499,6 +508,8 @@ void smr_map_to_endpoint(struct smr_region *region, int64_t id) local_peers[id].xpmem.cap = SMR_VMA_CAP_OFF; } + smr_set_ipc_valid(region, id); + return; } @@ -585,8 +596,12 @@ void smr_map_del(struct smr_map *map, int64_t id) goto unlock; if (!entry) { - if (map->flags & SMR_FLAG_HMEM_ENABLED) + if (map->flags & SMR_FLAG_HMEM_ENABLED) { + if (map->peers[id].pid_fd != -1) + close(map->peers[id].pid_fd); + (void) ofi_hmem_host_unregister(map->peers[id].region); + } munmap(map->peers[id].region, map->peers[id].region->total_size); } diff --git a/prov/shm/src/smr_util.h b/prov/shm/src/smr_util.h index 9ad2fe9337e..3aa2d67f143 100644 --- a/prov/shm/src/smr_util.h +++ b/prov/shm/src/smr_util.h @@ -54,7 +54,7 @@ extern "C" { #endif -#define SMR_VERSION 7 +#define SMR_VERSION 8 #define SMR_FLAG_ATOMIC (1 << 0) #define SMR_FLAG_DEBUG (1 << 1) @@ -177,7 +177,8 @@ struct smr_addr { struct smr_peer_data { struct smr_addr addr; uint32_t sar_status; - uint32_t name_sent; + uint16_t name_sent; + uint16_t ipc_invalid; struct xpmem_client xpmem; }; @@ -205,6 +206,7 @@ struct smr_peer { struct smr_addr peer; fi_addr_t fiaddr; struct smr_region *region; + int pid_fd; }; #define SMR_MAX_PEERS 256 diff --git a/prov/util/src/util_mem_monitor.c b/prov/util/src/util_mem_monitor.c index f7385c5d11f..d1c980bf94b 100644 --- a/prov/util/src/util_mem_monitor.c +++ b/prov/util/src/util_mem_monitor.c @@ -189,6 +189,7 @@ static void initialize_monitor_list() rocr_ipc_monitor, xpmem_monitor, ze_monitor, + ze_ipc_monitor, import_monitor, }; diff --git a/prov/util/src/ze_ipc_monitor.c b/prov/util/src/ze_ipc_monitor.c new file mode 100644 index 00000000000..70cb61a04d4 --- /dev/null +++ b/prov/util/src/ze_ipc_monitor.c @@ -0,0 +1,70 @@ +/* + * (C) Copyright Intel Corporation + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the + * BSD license below: + * + * Redistribution and use in source and binary forms, with or + * without modification, are permitted provided that the following + * conditions are met: + * + * - Redistributions of source code must retain the above + * copyright notice, this list of conditions and the following + * disclaimer. + * + * - Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include "ofi_mr.h" + +#if HAVE_ZE + +#include "ofi_hmem.h" + +static bool ze_ipc_monitor_valid(struct ofi_mem_monitor *monitor, + const struct ofi_mr_info *info, + struct ofi_mr_entry *entry) +{ + struct ze_pid_handle *existing = (struct ze_pid_handle *)info->handle; + struct ze_pid_handle *comp = (struct ze_pid_handle *)entry->info.handle; + + return (existing->get_fd == comp->get_fd); +} + +#else + +static bool ze_ipc_monitor_valid(struct ofi_mem_monitor *monitor, + const struct ofi_mr_info *info, + struct ofi_mr_entry *entry) +{ + return false; +} + +#endif /* HAVE_ZE */ +static struct ofi_mem_monitor ze_ipc_monitor_ = { + .init = ofi_monitor_init, + .cleanup = ofi_monitor_cleanup, + .start = ofi_monitor_start_no_op, + .stop = ofi_monitor_stop_no_op, + .subscribe = ofi_monitor_subscribe_no_op, + .unsubscribe = ofi_monitor_unsubscribe_no_op, + .valid = ze_ipc_monitor_valid, + .name = "ze_ipc", +}; + +struct ofi_mem_monitor *ze_ipc_monitor = &ze_ipc_monitor_; diff --git a/src/hmem_ipc_cache.c b/src/hmem_ipc_cache.c index 62ae1d37dbf..39bc61fe4bb 100644 --- a/src/hmem_ipc_cache.c +++ b/src/hmem_ipc_cache.c @@ -94,11 +94,13 @@ int ofi_ipc_cache_open(struct ofi_mr_cache **cache, int ret; if (!ofi_hmem_is_ipc_enabled(FI_HMEM_CUDA) && - !ofi_hmem_is_ipc_enabled(FI_HMEM_ROCR)) + !ofi_hmem_is_ipc_enabled(FI_HMEM_ROCR) && + !ofi_hmem_is_ipc_enabled(FI_HMEM_ZE)) return FI_SUCCESS; memory_monitors[FI_HMEM_CUDA] = cuda_ipc_monitor; memory_monitors[FI_HMEM_ROCR] = rocr_ipc_monitor; + memory_monitors[FI_HMEM_ZE] = ze_ipc_monitor; *cache = calloc(1, sizeof(*(*cache))); if (!*cache) { diff --git a/src/hmem_ze.c b/src/hmem_ze.c index e3390580870..2ef42db4c12 100644 --- a/src/hmem_ze.c +++ b/src/hmem_ze.c @@ -60,8 +60,6 @@ static ze_context_handle_t context; static ze_device_handle_t *devices = NULL; static struct ofi_ze_dev_info *dev_info = NULL; static int num_devices = 0; -static int num_pci_devices = 0; -static int *dev_fds = NULL; static bool p2p_enabled = false; static bool host_reg_enabled = true; ofi_spin_t cl_lock; @@ -382,111 +380,6 @@ ze_result_t ofi_zexDriverReleaseImportedPointer(ze_driver_handle_t hDriver, #include #include -/* -static int ze_hmem_init_fds(void) -{ - const char *dev_dir = "/dev/dri/by-path/"; - const char *suffix = "-render"; - DIR *dir; - struct dirent *ent = NULL; - char dev_name[NAME_MAX]; - int ret, i; - char *str; - uint16_t domain_id; - uint8_t pci_id; - char *saveptr; - - dir = opendir(dev_dir); - if (dir == NULL) - return -FI_EIO; - - while ((ent = readdir(dir)) != NULL) { - if (ent->d_name[0] == '.' || - strstr(ent->d_name, suffix) == NULL) - continue; - - memset(dev_name, 0, sizeof(dev_name)); - ret = snprintf(dev_name, NAME_MAX, "%s%s", dev_dir, ent->d_name); - if (ret < 0 || ret >= NAME_MAX) - goto err; - - dev_fds[num_pci_devices] = open(dev_name, O_RDWR); - if (dev_fds[num_pci_devices] == -1) - goto err; - str = strtok_r(ent->d_name, "-", &saveptr); - str = strtok_r(NULL, ":", &saveptr); - domain_id = (uint16_t) strtol(str, NULL, 16); - str = strtok_r(NULL, ":", &saveptr); - pci_id = (uint8_t) strtol(str, NULL, 16); - for (i = 0; i < num_devices; i++) { - if (dev_info[i].uuid.id[8] == pci_id && - (uint16_t) dev_info[i].uuid.id[6] == domain_id) - dev_info[i].pci_device = num_pci_devices; - } - num_pci_devices++; - } - - (void) closedir(dir); - return FI_SUCCESS; - -err: - (void) closedir(dir); - FI_WARN(&core_prov, FI_LOG_CORE, - "Failed open device %d\n", num_pci_devices); - return -FI_EIO; -} -*/ - -int ze_hmem_get_shared_handle(uint64_t device, void *dev_buf, int *ze_fd, - void **handle) -{ - struct drm_prime_handle open_fd = {0, 0, 0}; - ze_ipc_mem_handle_t ze_handle; - int dev_fd = dev_fds[dev_info[device].pci_device]; - int ret; - - assert(dev_fd != -1); - ret = ze_hmem_get_handle(dev_buf, 0, (void **) &ze_handle); - if (ret) - return ret; - - memcpy(ze_fd, &ze_handle, sizeof(*ze_fd)); - memcpy(&open_fd.fd, &ze_handle, sizeof(open_fd.fd)); - ret = ioctl(dev_fd, DRM_IOCTL_PRIME_FD_TO_HANDLE, &open_fd); - if (ret) { - FI_WARN(&core_prov, FI_LOG_CORE, - "ioctl call failed on get, err %d\n", errno); - return -FI_EINVAL; - } - - *(int *) handle = open_fd.handle; - return FI_SUCCESS; -} - -int ze_hmem_open_shared_handle(uint64_t device, int *peer_fds, void **handle, - int *ze_fd, void **ipc_ptr) -{ - struct drm_prime_handle open_fd = {0, 0, 0}; - ze_ipc_mem_handle_t ze_handle; - int dev_fd = peer_fds[dev_info[device].pci_device]; - int ret; - - open_fd.flags = DRM_CLOEXEC | DRM_RDWR; - open_fd.handle = *(int *) handle; - - ret = ioctl(dev_fd, DRM_IOCTL_PRIME_HANDLE_TO_FD, &open_fd); - if (ret) { - FI_WARN(&core_prov, FI_LOG_CORE, - "ioctl call failed on open, err %d\n", errno); - return -FI_EINVAL; - } - - *ze_fd = open_fd.fd; - memset(&ze_handle, 0, sizeof(ze_handle)); - memcpy(&ze_handle, &open_fd.fd, sizeof(open_fd.fd)); - return ze_hmem_open_handle((void **) &ze_handle, 0, device, ipc_ptr); -} - bool ze_hmem_p2p_enabled(void) { return !ofi_hmem_p2p_disabled() && p2p_enabled; @@ -494,23 +387,6 @@ bool ze_hmem_p2p_enabled(void) #else -static int ze_hmem_init_fds(void) -{ - return FI_SUCCESS; -} - -int ze_hmem_get_shared_handle(uint64_t device, void *dev_buf, int *ze_fd, - void **handle) -{ - return -FI_ENOSYS; -} - -int ze_hmem_open_shared_handle(uint64_t device, int *peer_fds, void **handle, - int *ze_fd, void **ipc_ptr) -{ - return -FI_ENOSYS; -} - bool ze_hmem_p2p_enabled(void) { return false; @@ -762,15 +638,10 @@ static int ze_hmem_cleanup_internal(int fini_workaround) } if (dev_info[i].cl_pool) ofi_bufpool_destroy(dev_info[i].cl_pool); - if (dev_fds[i] != -1) { - close(dev_fds[i]); - dev_fds[i] = -1; - } } free(devices); free(dev_info); - free(dev_fds); if (!fini_workaround) { if (ofi_zeContextDestroy(context)) @@ -792,9 +663,9 @@ int ze_hmem_init(void) ze_context_desc_t context_desc = {0}; ze_device_properties_t dev_prop = {0}; ze_result_t ze_ret; -// ze_bool_t access; + ze_bool_t access; uint32_t count, i; -// bool p2p = false; + bool p2p = true; int ret; char *enginestr = NULL; int ordinal = -1; @@ -850,16 +721,13 @@ int ze_hmem_init(void) devices = calloc(count, sizeof(*devices)); dev_info = calloc(count, sizeof(*dev_info)); - dev_fds = calloc(count, sizeof(*dev_fds)); - if (!devices || !dev_info || !dev_fds) + if (!devices || !dev_info) goto err; ze_ret = ofi_zeDeviceGet(driver, &count, devices); if (ze_ret) goto err; - for (i = 0; i < count; dev_fds[i++] = -1) - ; for (num_devices = 0; num_devices < count; num_devices++) { dev_info[num_devices].cl_pool = NULL; @@ -878,22 +746,14 @@ int ze_hmem_init(void) if (ze_ret) goto err; -/* for (i = 0; i < count; i++) { if (ofi_zeDeviceCanAccessPeer(devices[num_devices], devices[i], &access) || !access) p2p = false; } -*/ } -/* - ret = ze_hmem_init_fds(); - if (ret) - goto err; - p2p_enabled = p2p; -*/ return FI_SUCCESS; err: @@ -1055,42 +915,67 @@ bool ze_hmem_is_addr_valid(const void *addr, uint64_t *device, uint64_t *flags) int ze_hmem_get_handle(void *dev_buf, size_t size, void **handle) { ze_result_t ze_ret; + struct ze_pid_handle *pid_handle = (struct ze_pid_handle *) handle; + ze_ipc_mem_handle_t temp_handle; - ze_ret = ofi_zeMemGetIpcHandle(context, dev_buf, - (ze_ipc_mem_handle_t *) handle); + ze_ret = ofi_zeMemGetIpcHandle(context, dev_buf, &temp_handle); if (ze_ret) { FI_WARN(&core_prov, FI_LOG_CORE, "Unable to get handle\n"); return -FI_EINVAL; } + memcpy(&pid_handle->get_fd, &temp_handle, + sizeof(pid_handle->get_fd)); return FI_SUCCESS; } +void ze_set_pid_fd(void **handle, int pid_fd) +{ + ((struct ze_pid_handle *) handle)->pid_fd = pid_fd; +} + int ze_hmem_open_handle(void **handle, size_t size, uint64_t device, void **ipc_ptr) { ze_result_t ze_ret; + struct ze_pid_handle *pid_handle = (struct ze_pid_handle *) handle; + ze_ipc_mem_handle_t temp_handle = {0}; + int fd; int dev_id = ze_get_device_idx(device); + void *buf; /* only device memory is supported */ assert(!ze_get_driver_idx(device) && dev_id >= 0); - ze_ret = ofi_zeMemOpenIpcHandle(context, devices[dev_id], - *((ze_ipc_mem_handle_t *) handle), - 0, ipc_ptr); + fd = ofi_pidfd_getfd(pid_handle->pid_fd, pid_handle->get_fd, 0); + if (fd < 0) { + FI_WARN(&core_prov, FI_LOG_CORE, + "Unable find fd from peer pid\n"); + return -FI_EINVAL; + } + + pid_handle->open_fd = fd; + memcpy(&temp_handle, &pid_handle->open_fd, sizeof(pid_handle->open_fd)); + ze_ret = ofi_zeMemOpenIpcHandle(context, devices[dev_id], temp_handle, + 0, &buf); if (ze_ret) { FI_WARN(&core_prov, FI_LOG_CORE, "Unable to open memory handle\n"); + pid_handle->open_fd = -1; + close(fd); return -FI_EINVAL; } + *ipc_ptr = buf; return FI_SUCCESS; } int ze_hmem_close_handle(void *ipc_ptr, void **handle) { ze_result_t ze_ret; + struct ze_pid_handle *pid_handle = (struct ze_pid_handle *) handle; + close(pid_handle->open_fd); ze_ret = ofi_zeMemCloseIpcHandle(context, ipc_ptr); if (ze_ret) { FI_WARN(&core_prov, FI_LOG_CORE, @@ -1103,7 +988,7 @@ int ze_hmem_close_handle(void *ipc_ptr, void **handle) int ze_hmem_get_ipc_handle_size(size_t *size) { - *size = sizeof(ze_ipc_mem_handle_t); + *size = sizeof(struct ze_pid_handle); return FI_SUCCESS; } @@ -1140,12 +1025,6 @@ int ze_hmem_get_id(const void *ptr, uint64_t *id) return FI_SUCCESS; } -int *ze_hmem_get_dev_fds(int *nfds) -{ - *nfds = num_pci_devices; - return dev_fds; -} - int ze_hmem_host_register(void *ptr, size_t size) { ze_result_t ze_ret; @@ -1395,20 +1274,13 @@ int ze_hmem_get_handle(void *dev_buf, size_t size, void **handle) return -FI_ENOSYS; } -int ze_hmem_open_handle(void **handle, size_t size, uint64_t device, - void **ipc_ptr) -{ - return -FI_ENOSYS; -} - -int ze_hmem_get_shared_handle(uint64_t device, void *dev_buf, int *ze_fd, - void **handle) +void ze_set_pid_fd(void **handle, int pid_fd) { - return -FI_ENOSYS; + /* no-op */ } -int ze_hmem_open_shared_handle(uint64_t device, int *peer_fds, void **handle, - int *ze_fd, void **ipc_ptr) +int ze_hmem_open_handle(void **handle, size_t size, uint64_t device, + void **ipc_ptr) { return -FI_ENOSYS; } @@ -1439,12 +1311,6 @@ int ze_hmem_get_id(const void *ptr, uint64_t *id) return -FI_ENOSYS; } -int *ze_hmem_get_dev_fds(int *nfds) -{ - *nfds = 0; - return NULL; -} - int ze_hmem_host_register(void *ptr, size_t size) { return FI_SUCCESS;