diff --git a/Makefile.am b/Makefile.am
index 2b71edfe515..ad91ddd8e60 100644
--- a/Makefile.am
+++ b/Makefile.am
@@ -88,6 +88,7 @@ common_srcs = \
prov/util/src/ze_mem_monitor.c \
prov/util/src/cuda_ipc_monitor.c \
prov/util/src/rocr_ipc_monitor.c \
+ prov/util/src/ze_ipc_monitor.c \
prov/util/src/xpmem_monitor.c \
prov/util/src/util_profile.c \
prov/coll/src/coll_attr.c \
diff --git a/include/linux/osd.h b/include/linux/osd.h
index 9662be6c31a..5b8d0fcd4ee 100644
--- a/include/linux/osd.h
+++ b/include/linux/osd.h
@@ -100,6 +100,14 @@ size_t ofi_ifaddr_get_speed(struct ifaddrs *ifa);
# define __NR_process_vm_writev 311
#endif
+#ifndef __NR_pidfd_open
+# define __NR_pidfd_open 434
+#endif
+
+#ifndef __NR_pidfd_getfd
+# define __NR_pidfd_getfd 438
+#endif
+
static inline ssize_t ofi_process_vm_readv(pid_t pid,
const struct iovec *local_iov,
unsigned long liovcnt,
@@ -122,6 +130,16 @@ static inline size_t ofi_process_vm_writev(pid_t pid,
remote_iov, riovcnt, flags);
}
+/*
+ * Invoke the pidfd_open system call by number through syscall().
+ * pid and flags are forwarded unchanged; the syscall() result is returned
+ * narrowed to int.
+ */
+static inline int ofi_pidfd_open(pid_t pid, unsigned int flags)
+{
+	long ret = syscall(__NR_pidfd_open, pid, flags);
+
+	return (int) ret;
+}
+
+/*
+ * Invoke the pidfd_getfd system call by number through syscall().
+ * pidfd, targetfd and flags are forwarded unchanged; the syscall() result is
+ * returned narrowed to int.
+ */
+static inline int ofi_pidfd_getfd(int pidfd, int targetfd, unsigned int flags)
+{
+	long ret = syscall(__NR_pidfd_getfd, pidfd, targetfd, flags);
+
+	return (int) ret;
+}
+
static inline ssize_t ofi_read_socket(SOCKET fd, void *buf, size_t count)
{
return read(fd, buf, count);
diff --git a/include/ofi_hmem.h b/include/ofi_hmem.h
index 4b3bc0c7b13..9db6d94cd70 100644
--- a/include/ofi_hmem.h
+++ b/include/ofi_hmem.h
@@ -204,24 +204,34 @@ int cuda_gdrcopy_dev_register(const void *buf, size_t len, uint64_t *handle);
int cuda_gdrcopy_dev_unregister(uint64_t handle);
int cuda_set_sync_memops(void *ptr);
+/*
+ * The ze IPC handle is treated as just an fd, so this wrapper is carried in
+ * its place to hold the extra fields needed for cross-process IPC copies:
+ *   get_fd   fd copied out of the ze IPC handle by ze_hmem_get_handle()
+ *   open_fd  fd returned by ofi_pidfd_getfd() in ze_hmem_open_handle(); it is
+ *            the fd handed to zeMemOpenIpcHandle and is closed again in
+ *            ze_hmem_close_handle()
+ *   pid_fd   pidfd stored by ze_set_pid_fd(); ze_hmem_open_handle() passes it
+ *            to ofi_pidfd_getfd() to obtain get_fd from the other process
+ */
+struct ze_pid_handle {
+ int get_fd;
+ int open_fd;
+ int pid_fd;
+};
+
int ze_hmem_copy(uint64_t device, void *dst, const void *src, size_t size);
int ze_hmem_init(void);
int ze_hmem_cleanup(void);
bool ze_hmem_is_addr_valid(const void *addr, uint64_t *device, uint64_t *flags);
int ze_hmem_get_handle(void *dev_buf, size_t size, void **handle);
+void ze_set_pid_fd(void **handle, int pid_fd);
int ze_hmem_open_handle(void **handle, size_t size, uint64_t device,
void **ipc_ptr);
-int ze_hmem_get_shared_handle(uint64_t device, void *dev_buf, int *ze_fd,
- void **handle);
-int ze_hmem_open_shared_handle(uint64_t device, int *peer_fds, void **handle,
- int *ze_fd, void **ipc_ptr);
int ze_hmem_close_handle(void *ipc_ptr, void **handle);
bool ze_hmem_p2p_enabled(void);
int ze_hmem_get_ipc_handle_size(size_t *size);
int ze_hmem_get_base_addr(const void *ptr, size_t len, void **base,
size_t *size);
int ze_hmem_get_id(const void *ptr, uint64_t *id);
-int *ze_hmem_get_dev_fds(int *nfds);
int ze_hmem_host_register(void *ptr, size_t size);
int ze_hmem_host_unregister(void *ptr);
int ze_dev_register(const void *addr, size_t size, uint64_t *handle);
diff --git a/include/ofi_mr.h b/include/ofi_mr.h
index a357121beaf..876a97a8223 100644
--- a/include/ofi_mr.h
+++ b/include/ofi_mr.h
@@ -233,6 +233,7 @@ extern struct ofi_mem_monitor *cuda_ipc_monitor;
extern struct ofi_mem_monitor *rocr_monitor;
extern struct ofi_mem_monitor *rocr_ipc_monitor;
extern struct ofi_mem_monitor *ze_monitor;
+extern struct ofi_mem_monitor *ze_ipc_monitor;
extern struct ofi_mem_monitor *import_monitor;
extern struct ofi_mem_monitor *xpmem_monitor;
diff --git a/libfabric.vcxproj b/libfabric.vcxproj
index d297285781b..48209c6a344 100644
--- a/libfabric.vcxproj
+++ b/libfabric.vcxproj
@@ -755,6 +755,7 @@
+
diff --git a/prov/shm/src/smr.h b/prov/shm/src/smr.h
index a393d7038d6..d8385d335dc 100644
--- a/prov/shm/src/smr.h
+++ b/prov/shm/src/smr.h
@@ -124,7 +124,6 @@ struct smr_tx_entry {
void *map_ptr;
struct smr_ep_name *map_name;
struct ofi_mr *mr[SMR_IOV_LIMIT];
- int fd;
};
struct smr_pend_entry {
@@ -165,8 +164,6 @@ struct smr_domain {
#define SMR_PREFIX "fi_shm://"
#define SMR_PREFIX_NS "fi_ns://"
-#define SMR_ZE_SOCK_PATH "/dev/shm/ze_"
-
#define SMR_RMA_ORDER (OFI_ORDER_RAR_SET | OFI_ORDER_RAW_SET | FI_ORDER_RAS | \
OFI_ORDER_WAR_SET | OFI_ORDER_WAW_SET | FI_ORDER_WAS | \
FI_ORDER_SAR | FI_ORDER_SAW)
@@ -282,7 +279,8 @@ size_t smr_copy_from_sar(struct smr_freestack *sar_pool, struct smr_resp *resp,
const struct iovec *iov, size_t count,
size_t *bytes_done);
int smr_select_proto(void **desc, size_t iov_count, bool cma_avail,
- uint32_t op, uint64_t total_len, uint64_t op_flags);
+ bool ipc_valid, uint32_t op, uint64_t total_len,
+ uint64_t op_flags);
typedef ssize_t (*smr_proto_func)(struct smr_ep *ep, struct smr_region *peer_smr,
int64_t id, int64_t peer_id, uint32_t op, uint64_t tag,
uint64_t data, uint64_t op_flags, struct ofi_mr **desc,
@@ -320,6 +318,20 @@ static inline bool smr_vma_enabled(struct smr_ep *ep,
peer_smr->xpmem_cap_self == SMR_VMA_CAP_ON);
}
+/*
+ * Record in this region's peer data whether device IPC must be avoided with
+ * peer 'id'.
+ *
+ * IPC is marked invalid when ZE is initialized but no pidfd is available for
+ * the peer (pid_fd == -1). The flag is written in both directions so that a
+ * value left over from an earlier call cannot persist once the condition no
+ * longer holds.
+ */
+static inline void smr_set_ipc_valid(struct smr_region *region, uint64_t id)
+{
+	bool no_pid_fd = ofi_hmem_is_initialized(FI_HMEM_ZE) &&
+			 region->map->peers[id].pid_fd == -1;
+
+	smr_peer_data(region)[id].ipc_invalid = no_pid_fd ? 1 : 0;
+}
+
+/*
+ * Report whether IPC may be used between this endpoint and a peer.
+ * Returns true only when neither side has flagged the other as ipc_invalid:
+ * the local region's entry for 'id' and the peer region's entry for
+ * 'peer_id' are both checked.
+ */
+static inline bool smr_ipc_valid(struct smr_ep *ep, struct smr_region *peer_smr,
+				 int64_t id, int64_t peer_id)
+{
+	if (smr_peer_data(ep->region)[id].ipc_invalid)
+		return false;
+
+	return !smr_peer_data(peer_smr)[peer_id].ipc_invalid;
+}
+
static inline bool smr_ze_ipc_enabled(struct smr_region *smr,
struct smr_region *peer_smr)
{
diff --git a/prov/shm/src/smr_ep.c b/prov/shm/src/smr_ep.c
index 8b2884e3fa4..b6379471a48 100644
--- a/prov/shm/src/smr_ep.c
+++ b/prov/shm/src/smr_ep.c
@@ -300,44 +300,9 @@ static void smr_format_iov(struct smr_cmd *cmd, const struct iovec *iov,
memcpy(cmd->msg.data.iov, iov, sizeof(*iov) * count);
}
-static int smr_format_ze_ipc(struct smr_ep *ep, int64_t id, struct smr_cmd *cmd,
- const struct iovec *iov, uint64_t device, size_t total_len,
- struct smr_region *smr, struct smr_resp *resp,
- struct smr_tx_entry *pend)
-{
- int ret;
- void *base;
-
- cmd->msg.hdr.op_src = smr_src_ipc;
- cmd->msg.hdr.src_data = smr_get_offset(smr, resp);
- cmd->msg.hdr.size = total_len;
- cmd->msg.data.ipc_info.iface = FI_HMEM_ZE;
-
- if (ep->sock_info->peers[id].state == SMR_CMAP_INIT)
- smr_ep_exchange_fds(ep, id);
- if (ep->sock_info->peers[id].state != SMR_CMAP_SUCCESS)
- return -FI_EAGAIN;
-
- ret = ze_hmem_get_base_addr(iov[0].iov_base, iov[0].iov_len, &base,
- NULL);
- if (ret)
- return ret;
-
- ret = ze_hmem_get_shared_handle(device, base, &pend->fd,
- (void **) &cmd->msg.data.ipc_info.ipc_handle);
- if (ret)
- return ret;
-
- cmd->msg.data.ipc_info.device = device;
- cmd->msg.data.ipc_info.offset = (char *) iov[0].iov_base -
- (char *) base;
-
- return FI_SUCCESS;
-}
-
static int smr_format_ipc(struct smr_cmd *cmd, void *ptr, size_t len,
- struct smr_region *smr, struct smr_resp *resp,
- enum fi_hmem_iface iface, uint64_t device)
+ struct smr_region *smr, struct smr_resp *resp,
+ enum fi_hmem_iface iface, uint64_t device)
{
int ret;
void *base;
@@ -347,15 +312,16 @@ static int smr_format_ipc(struct smr_cmd *cmd, void *ptr, size_t len,
cmd->msg.hdr.size = len;
cmd->msg.data.ipc_info.iface = iface;
cmd->msg.data.ipc_info.device = device;
- ret = ofi_hmem_get_base_addr(cmd->msg.data.ipc_info.iface, ptr, len,
- &base,
+
+ ret = ofi_hmem_get_base_addr(cmd->msg.data.ipc_info.iface, ptr,
+ len, &base,
&cmd->msg.data.ipc_info.base_length);
if (ret)
return ret;
ret = ofi_hmem_get_handle(cmd->msg.data.ipc_info.iface, base,
- cmd->msg.data.ipc_info.base_length,
- (void **)&cmd->msg.data.ipc_info.ipc_handle);
+ cmd->msg.data.ipc_info.base_length,
+ (void **)&cmd->msg.data.ipc_info.ipc_handle);
if (ret)
return ret;
@@ -574,8 +540,8 @@ static int smr_format_sar(struct smr_ep *ep, struct smr_cmd *cmd,
return FI_SUCCESS;
}
-int smr_select_proto(void **desc, size_t iov_count,
- bool vma_avail, uint32_t op, uint64_t total_len,
+int smr_select_proto(void **desc, size_t iov_count, bool vma_avail,
+ bool ipc_valid, uint32_t op, uint64_t total_len,
uint64_t op_flags)
{
struct ofi_mr *smr_desc;
@@ -584,7 +550,7 @@ int smr_select_proto(void **desc, size_t iov_count,
/* Do not inline/inject if IPC is available so device to device
* transfer may occur if possible. */
- if (iov_count == 1 && desc && desc[0]) {
+ if (iov_count == 1 && desc && desc[0] && ipc_valid) {
smr_desc = (struct ofi_mr *) *desc;
iface = smr_desc->iface;
use_ipc = ofi_hmem_is_ipc_enabled(iface) &&
@@ -740,15 +706,8 @@ static ssize_t smr_do_ipc(struct smr_ep *ep, struct smr_region *peer_smr, int64_
smr_generic_format(cmd, peer_id, op, tag, data, op_flags);
assert(iov_count == 1 && desc && desc[0]);
- if (desc[0]->iface == FI_HMEM_ZE) {
- if (smr_ze_ipc_enabled(ep->region, peer_smr))
- ret = smr_format_ze_ipc(ep, id, cmd, iov,
- desc[0]->device, total_len, ep->region,
- resp, pend);
- } else {
- ret = smr_format_ipc(cmd, iov[0].iov_base, total_len, ep->region,
- resp, desc[0]->iface, desc[0]->device);
- }
+ ret = smr_format_ipc(cmd, iov[0].iov_base, total_len, ep->region,
+ resp, desc[0]->iface, desc[0]->device);
if (ret) {
FI_WARN_ONCE(&smr_prov, FI_LOG_EP_CTRL,
@@ -1005,114 +964,6 @@ static int smr_recvmsg_fd(int sock, int64_t *peer_id, int *fds, int nfds)
return ret;
}
-static void *smr_start_listener(void *args)
-{
- struct smr_ep *ep = (struct smr_ep *) args;
- struct sockaddr_un sockaddr;
- struct ofi_epollfds_event events[SMR_MAX_PEERS + 1];
- int i, ret, poll_fds, sock = -1;
- int *peer_fds = NULL;
- socklen_t len = sizeof(sockaddr);
- int64_t id, peer_id;
-
- peer_fds = calloc(ep->sock_info->nfds, sizeof(*peer_fds));
- if (!peer_fds)
- goto out;
-
- ep->region->flags |= SMR_FLAG_IPC_SOCK;
- while (1) {
- poll_fds = ofi_epoll_wait(ep->sock_info->epollfd, events,
- SMR_MAX_PEERS + 1, -1);
-
- if (poll_fds < 0) {
- FI_WARN(&smr_prov, FI_LOG_EP_CTRL,
- "epoll error\n");
- continue;
- }
-
- for (i = 0; i < poll_fds; i++) {
- if (!events[i].data.ptr)
- goto out;
-
- sock = accept(ep->sock_info->listen_sock,
- (struct sockaddr *) &sockaddr, &len);
- if (sock < 0) {
- FI_WARN(&smr_prov, FI_LOG_EP_CTRL,
- "accept error\n");
- continue;
- }
-
- FI_DBG(&smr_prov, FI_LOG_EP_CTRL,
- "EP accepted connection request from %s\n",
- sockaddr.sun_path);
-
- ret = smr_recvmsg_fd(sock, &id, peer_fds,
- ep->sock_info->nfds);
- if (!ret) {
- if (!ep->sock_info->peers[id].device_fds) {
- ep->sock_info->peers[id].device_fds =
- calloc(ep->sock_info->nfds,
- sizeof(*ep->sock_info->peers[id].device_fds));
- if (!ep->sock_info->peers[id].device_fds) {
- close(sock);
- goto out;
- }
- }
- memcpy(ep->sock_info->peers[id].device_fds,
- peer_fds, sizeof(*peer_fds) *
- ep->sock_info->nfds);
-
- peer_id = smr_peer_data(ep->region)[id].addr.id;
- ret = smr_sendmsg_fd(sock, id, peer_id,
- ep->sock_info->my_fds,
- ep->sock_info->nfds);
- ep->sock_info->peers[id].state =
- ret ? SMR_CMAP_FAILED :
- SMR_CMAP_SUCCESS;
- }
-
- close(sock);
- unlink(sockaddr.sun_path);
- }
- }
-out:
- free(peer_fds);
- close(ep->sock_info->listen_sock);
- unlink(ep->sock_info->name);
- return NULL;
-}
-
-static int smr_init_epoll(struct smr_sock_info *sock_info)
-{
- int ret;
-
- ret = ofi_epoll_create(&sock_info->epollfd);
- if (ret < 0)
- return ret;
-
- ret = fd_signal_init(&sock_info->signal);
- if (ret < 0)
- goto err2;
-
- ret = ofi_epoll_add(sock_info->epollfd,
- sock_info->signal.fd[FI_READ_FD],
- OFI_EPOLL_IN, NULL);
- if (ret != 0)
- goto err1;
-
- ret = ofi_epoll_add(sock_info->epollfd, sock_info->listen_sock,
- OFI_EPOLL_IN, sock_info);
- if (ret != 0)
- goto err1;
-
- return FI_SUCCESS;
-err1:
- ofi_epoll_close(sock_info->epollfd);
-err2:
- fd_signal_free(&sock_info->signal);
- return ret;
-}
-
void smr_ep_exchange_fds(struct smr_ep *ep, int64_t id)
{
struct smr_region *peer_smr = smr_peer_region(ep->region, id);
@@ -1137,8 +988,6 @@ void smr_ep_exchange_fds(struct smr_ep *ep, int64_t id)
name2 = smr_sock_name(ep->region);
}
client_sockaddr.sun_family = AF_UNIX;
- snprintf(client_sockaddr.sun_path, SMR_SOCK_NAME_MAX, "%s%s:%s",
- SMR_ZE_SOCK_PATH, name1, name2);
ret = bind(sock, (struct sockaddr *) &client_sockaddr,
(socklen_t) sizeof(client_sockaddr));
@@ -1152,8 +1001,6 @@ void smr_ep_exchange_fds(struct smr_ep *ep, int64_t id)
}
server_sockaddr.sun_family = AF_UNIX;
- snprintf(server_sockaddr.sun_path, SMR_SOCK_NAME_MAX, "%s%s",
- SMR_ZE_SOCK_PATH, smr_sock_name(peer_smr));
ret = connect(sock, (struct sockaddr *) &server_sockaddr,
sizeof(server_sockaddr));
@@ -1189,80 +1036,6 @@ void smr_ep_exchange_fds(struct smr_ep *ep, int64_t id)
SMR_CMAP_FAILED : SMR_CMAP_SUCCESS;
}
-static void smr_init_ipc_socket(struct smr_ep *ep)
-{
- struct smr_sock_name *sock_name;
- struct sockaddr_un sockaddr = {0};
- int ret;
-
- ep->sock_info = calloc(1, sizeof(*ep->sock_info));
- if (!ep->sock_info)
- goto err_out;
-
- ep->sock_info->listen_sock = socket(AF_UNIX, SOCK_STREAM, 0);
- if (ep->sock_info->listen_sock < 0)
- goto free;
-
- snprintf(smr_sock_name(ep->region), SMR_SOCK_NAME_MAX,
- "%ld:%d", (long) ep->region->pid, ep->ep_idx);
-
- sockaddr.sun_family = AF_UNIX;
- snprintf(sockaddr.sun_path, SMR_SOCK_NAME_MAX,
- "%s%s", SMR_ZE_SOCK_PATH, smr_sock_name(ep->region));
-
- ret = bind(ep->sock_info->listen_sock, (struct sockaddr *) &sockaddr,
- (socklen_t) sizeof(sockaddr));
- if (ret)
- goto close;
-
- ret = listen(ep->sock_info->listen_sock, SMR_MAX_PEERS);
- if (ret)
- goto close;
-
- FI_DBG(&smr_prov, FI_LOG_EP_CTRL, "EP listening on UNIX socket %s\n",
- sockaddr.sun_path);
-
- ret = smr_init_epoll(ep->sock_info);
- if (ret)
- goto close;
-
- sock_name = calloc(1, sizeof(*sock_name));
- if (!sock_name)
- goto cleanup;
-
- memcpy(sock_name->name, sockaddr.sun_path, strlen(sockaddr.sun_path));
- memcpy(ep->sock_info->name, sockaddr.sun_path,
- strlen(sockaddr.sun_path));
-
- pthread_mutex_lock(&sock_list_lock);
- dlist_insert_tail(&sock_name->entry, &sock_name_list);
- pthread_mutex_unlock(&sock_list_lock);
-
- ep->sock_info->my_fds = ze_hmem_get_dev_fds(&ep->sock_info->nfds);
- ret = pthread_create(&ep->sock_info->listener_thread, NULL,
- &smr_start_listener, ep);
- if (ret)
- goto remove;
-
- return;
-
-remove:
- pthread_mutex_lock(&sock_list_lock);
- dlist_remove(&sock_name->entry);
- pthread_mutex_unlock(&sock_list_lock);
- free(sock_name);
-cleanup:
- smr_cleanup_epoll(ep->sock_info);
-close:
- close(ep->sock_info->listen_sock);
- unlink(sockaddr.sun_path);
-free:
- smr_free_sock_info(ep);
-err_out:
- FI_WARN(&smr_prov, FI_LOG_EP_CTRL, "Unable to initialize IPC socket."
- "Defaulting to SAR for device transfers\n");
-}
-
static int smr_discard(struct fi_peer_rx_entry *rx_entry)
{
struct smr_cmd_ctx *cmd_ctx = rx_entry->peer_context;
@@ -1390,10 +1163,6 @@ static int smr_ep_ctrl(struct fid *fid, int command, void *arg)
if (ep->util_ep.caps & FI_HMEM || smr_env.disable_cma) {
ep->region->cma_cap_peer = SMR_VMA_CAP_OFF;
ep->region->cma_cap_self = SMR_VMA_CAP_OFF;
- if (ep->util_ep.caps & FI_HMEM) {
- if (ze_hmem_p2p_enabled())
- smr_init_ipc_socket(ep);
- }
}
if (ofi_hmem_any_ipc_enabled())
diff --git a/prov/shm/src/smr_msg.c b/prov/shm/src/smr_msg.c
index a26425857e7..641645ea5d3 100644
--- a/prov/shm/src/smr_msg.c
+++ b/prov/shm/src/smr_msg.c
@@ -112,7 +112,8 @@ static ssize_t smr_generic_sendmsg(struct smr_ep *ep, const struct iovec *iov,
assert(!(op_flags & FI_INJECT) || total_len <= SMR_INJECT_SIZE);
proto = smr_select_proto(desc, iov_count, smr_vma_enabled(ep, peer_smr),
- op, total_len, op_flags);
+ smr_ipc_valid(ep, peer_smr, id, peer_id), op,
+ total_len, op_flags);
ret = smr_proto_ops[proto](ep, peer_smr, id, peer_id, op, tag, data, op_flags,
(struct ofi_mr **)desc, iov, iov_count, total_len,
diff --git a/prov/shm/src/smr_progress.c b/prov/shm/src/smr_progress.c
index e66ebde9f59..759b3ffb313 100644
--- a/prov/shm/src/smr_progress.c
+++ b/prov/shm/src/smr_progress.c
@@ -533,9 +533,8 @@ static struct smr_pend_entry *smr_progress_ipc(struct smr_cmd *cmd,
{
struct smr_region *peer_smr;
struct smr_resp *resp;
- void *base, *ptr;
- int64_t id;
- int ret, ipc_fd;
+ void *ptr;
+ int ret;
ssize_t hmem_copy_ret;
struct ofi_mr_entry *mr_entry;
struct smr_domain *domain;
@@ -547,27 +546,18 @@ static struct smr_pend_entry *smr_progress_ipc(struct smr_cmd *cmd,
peer_smr = smr_peer_region(ep->region, cmd->msg.hdr.id);
resp = smr_get_ptr(peer_smr, cmd->msg.hdr.src_data);
+ if (cmd->msg.data.ipc_info.iface == FI_HMEM_ZE)
+ ze_set_pid_fd((void **) &cmd->msg.data.ipc_info.ipc_handle,
+ ep->region->map->peers[cmd->msg.hdr.id].pid_fd);
+
//TODO disable IPC if more than 1 interface is initialized
- if (cmd->msg.data.ipc_info.iface == FI_HMEM_ZE) {
- id = cmd->msg.hdr.id;
- ret = ze_hmem_open_shared_handle(cmd->msg.data.ipc_info.device,
- ep->sock_info->peers[id].device_fds,
- (void **) &cmd->msg.data.ipc_info.ipc_handle,
- &ipc_fd, &base);
- } else {
- ret = ofi_ipc_cache_search(domain->ipc_cache,
- cmd->msg.hdr.id,
- &cmd->msg.data.ipc_info,
- &mr_entry);
- }
+ ret = ofi_ipc_cache_search(domain->ipc_cache, cmd->msg.hdr.id,
+ &cmd->msg.data.ipc_info, &mr_entry);
if (ret)
goto out;
- if (cmd->msg.data.ipc_info.iface == FI_HMEM_ZE)
- ptr = (char *) base + (uintptr_t) cmd->msg.data.ipc_info.offset;
- else
- ptr = (char *) (uintptr_t) mr_entry->info.mapped_addr +
- (uintptr_t) cmd->msg.data.ipc_info.offset;
+ ptr = (char *) (uintptr_t) mr_entry->info.mapped_addr +
+ (uintptr_t) cmd->msg.data.ipc_info.offset;
if (cmd->msg.data.ipc_info.iface == FI_HMEM_ROCR) {
*total_len = 0;
@@ -594,14 +584,7 @@ static struct smr_pend_entry *smr_progress_ipc(struct smr_cmd *cmd,
iov_count, 0, ptr, cmd->msg.hdr.size);
}
- if (cmd->msg.data.ipc_info.iface == FI_HMEM_ZE) {
- close(ipc_fd);
- /* Truncation error takes precedence over close_handle error */
- ret = ofi_hmem_close_handle(cmd->msg.data.ipc_info.iface, base,
- NULL);
- } else {
- ofi_mr_cache_delete(domain->ipc_cache, mr_entry);
- }
+ ofi_mr_cache_delete(domain->ipc_cache, mr_entry);
if (hmem_copy_ret < 0)
*err = hmem_copy_ret;
@@ -906,6 +889,8 @@ static void smr_progress_connreq(struct smr_ep *ep, struct smr_cmd *cmd)
smr_map_to_region(&smr_prov, ep->region->map, idx);
peer_smr = smr_peer_region(ep->region, idx);
}
+
+ smr_set_ipc_valid(ep->region, idx);
smr_peer_data(peer_smr)[cmd->msg.hdr.id].addr.id = idx;
smr_peer_data(ep->region)[idx].addr.id = cmd->msg.hdr.id;
diff --git a/prov/shm/src/smr_rma.c b/prov/shm/src/smr_rma.c
index 62ee11d8d81..d1317424a92 100644
--- a/prov/shm/src/smr_rma.c
+++ b/prov/shm/src/smr_rma.c
@@ -171,7 +171,8 @@ static ssize_t smr_generic_rma(struct smr_ep *ep, const struct iovec *iov,
assert(!(op_flags & FI_INJECT) || total_len <= SMR_INJECT_SIZE);
proto = smr_select_proto(desc, iov_count, smr_vma_enabled(ep, peer_smr),
- op, total_len, op_flags);
+ smr_ipc_valid(ep, peer_smr, id, peer_id), op,
+ total_len, op_flags);
ret = smr_proto_ops[proto](ep, peer_smr, id, peer_id, op, 0, data,
op_flags, (struct ofi_mr **)desc, iov,
diff --git a/prov/shm/src/smr_util.c b/prov/shm/src/smr_util.c
index 85e76f6bcab..5ddcd4bc512 100644
--- a/prov/shm/src/smr_util.c
+++ b/prov/shm/src/smr_util.c
@@ -453,6 +453,15 @@ int smr_map_to_region(const struct fi_provider *prov, struct smr_map *map,
if (ret)
FI_WARN(prov, FI_LOG_EP_CTRL,
"unable to register shm with iface\n");
+ if (ofi_hmem_is_initialized(FI_HMEM_ZE)) {
+ peer_buf->pid_fd = ofi_pidfd_open(peer->pid, 0);
+ if (peer_buf->pid_fd < 0) {
+ FI_WARN(prov, FI_LOG_EP_CTRL,
+ "unable to open pidfd\n");
+ }
+ } else {
+ peer_buf->pid_fd = -1;
+ }
}
out:
@@ -499,6 +508,8 @@ void smr_map_to_endpoint(struct smr_region *region, int64_t id)
local_peers[id].xpmem.cap = SMR_VMA_CAP_OFF;
}
+ smr_set_ipc_valid(region, id);
+
return;
}
@@ -585,8 +596,12 @@ void smr_map_del(struct smr_map *map, int64_t id)
goto unlock;
if (!entry) {
- if (map->flags & SMR_FLAG_HMEM_ENABLED)
+ if (map->flags & SMR_FLAG_HMEM_ENABLED) {
+ if (map->peers[id].pid_fd != -1)
+ close(map->peers[id].pid_fd);
+
(void) ofi_hmem_host_unregister(map->peers[id].region);
+ }
munmap(map->peers[id].region, map->peers[id].region->total_size);
}
diff --git a/prov/shm/src/smr_util.h b/prov/shm/src/smr_util.h
index 9ad2fe9337e..3aa2d67f143 100644
--- a/prov/shm/src/smr_util.h
+++ b/prov/shm/src/smr_util.h
@@ -54,7 +54,7 @@
extern "C" {
#endif
-#define SMR_VERSION 7
+#define SMR_VERSION 8
#define SMR_FLAG_ATOMIC (1 << 0)
#define SMR_FLAG_DEBUG (1 << 1)
@@ -177,7 +177,8 @@ struct smr_addr {
struct smr_peer_data {
struct smr_addr addr;
uint32_t sar_status;
- uint32_t name_sent;
+ uint16_t name_sent;
+ uint16_t ipc_invalid;
struct xpmem_client xpmem;
};
@@ -205,6 +206,7 @@ struct smr_peer {
struct smr_addr peer;
fi_addr_t fiaddr;
struct smr_region *region;
+ int pid_fd;
};
#define SMR_MAX_PEERS 256
diff --git a/prov/util/src/util_mem_monitor.c b/prov/util/src/util_mem_monitor.c
index f7385c5d11f..d1c980bf94b 100644
--- a/prov/util/src/util_mem_monitor.c
+++ b/prov/util/src/util_mem_monitor.c
@@ -189,6 +189,7 @@ static void initialize_monitor_list()
rocr_ipc_monitor,
xpmem_monitor,
ze_monitor,
+ ze_ipc_monitor,
import_monitor,
};
diff --git a/prov/util/src/ze_ipc_monitor.c b/prov/util/src/ze_ipc_monitor.c
new file mode 100644
index 00000000000..70cb61a04d4
--- /dev/null
+++ b/prov/util/src/ze_ipc_monitor.c
@@ -0,0 +1,70 @@
+/*
+ * (C) Copyright Intel Corporation
+ *
+ * This software is available to you under a choice of one of two
+ * licenses. You may choose to be licensed under the terms of the GNU
+ * General Public License (GPL) Version 2, available from the file
+ * COPYING in the main directory of this source tree, or the
+ * BSD license below:
+ *
+ * Redistribution and use in source and binary forms, with or
+ * without modification, are permitted provided that the following
+ * conditions are met:
+ *
+ * - Redistributions of source code must retain the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer.
+ *
+ * - Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+ * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+ * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+ * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+ * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+ * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "ofi_mr.h"
+
+#if HAVE_ZE
+
+#include "ofi_hmem.h"
+
+/*
+ * Validity hook for the ZE IPC monitor.
+ * Both the lookup info and the cached entry carry a struct ze_pid_handle; the
+ * entry is considered valid when the two handles have the same get_fd.
+ * The monitor argument is not used.
+ */
+static bool ze_ipc_monitor_valid(struct ofi_mem_monitor *monitor,
+				 const struct ofi_mr_info *info,
+				 struct ofi_mr_entry *entry)
+{
+	const struct ze_pid_handle *wanted;
+	const struct ze_pid_handle *cached;
+
+	wanted = (const struct ze_pid_handle *) info->handle;
+	cached = (const struct ze_pid_handle *) entry->info.handle;
+
+	return wanted->get_fd == cached->get_fd;
+}
+
+#else
+
+/* Variant compiled when HAVE_ZE is not set: always reports the entry as not valid. */
+static bool ze_ipc_monitor_valid(struct ofi_mem_monitor *monitor,
+ const struct ofi_mr_info *info,
+ struct ofi_mr_entry *entry)
+{
+ return false;
+}
+
+#endif /* HAVE_ZE */
+/*
+ * Memory monitor installed for FI_HMEM_ZE by the IPC cache. The start, stop,
+ * subscribe and unsubscribe hooks are wired to the generic *_no_op helpers;
+ * the only ZE-specific hook is .valid.
+ */
+static struct ofi_mem_monitor ze_ipc_monitor_ = {
+ .init = ofi_monitor_init,
+ .cleanup = ofi_monitor_cleanup,
+ .start = ofi_monitor_start_no_op,
+ .stop = ofi_monitor_stop_no_op,
+ .subscribe = ofi_monitor_subscribe_no_op,
+ .unsubscribe = ofi_monitor_unsubscribe_no_op,
+ .valid = ze_ipc_monitor_valid,
+ .name = "ze_ipc",
+};
+
+/* Exported pointer to the monitor above; declared extern in ofi_mr.h. */
+struct ofi_mem_monitor *ze_ipc_monitor = &ze_ipc_monitor_;
diff --git a/src/hmem_ipc_cache.c b/src/hmem_ipc_cache.c
index 62ae1d37dbf..39bc61fe4bb 100644
--- a/src/hmem_ipc_cache.c
+++ b/src/hmem_ipc_cache.c
@@ -94,11 +94,13 @@ int ofi_ipc_cache_open(struct ofi_mr_cache **cache,
int ret;
if (!ofi_hmem_is_ipc_enabled(FI_HMEM_CUDA) &&
- !ofi_hmem_is_ipc_enabled(FI_HMEM_ROCR))
+ !ofi_hmem_is_ipc_enabled(FI_HMEM_ROCR) &&
+ !ofi_hmem_is_ipc_enabled(FI_HMEM_ZE))
return FI_SUCCESS;
memory_monitors[FI_HMEM_CUDA] = cuda_ipc_monitor;
memory_monitors[FI_HMEM_ROCR] = rocr_ipc_monitor;
+ memory_monitors[FI_HMEM_ZE] = ze_ipc_monitor;
*cache = calloc(1, sizeof(*(*cache)));
if (!*cache) {
diff --git a/src/hmem_ze.c b/src/hmem_ze.c
index e3390580870..2ef42db4c12 100644
--- a/src/hmem_ze.c
+++ b/src/hmem_ze.c
@@ -60,8 +60,6 @@ static ze_context_handle_t context;
static ze_device_handle_t *devices = NULL;
static struct ofi_ze_dev_info *dev_info = NULL;
static int num_devices = 0;
-static int num_pci_devices = 0;
-static int *dev_fds = NULL;
static bool p2p_enabled = false;
static bool host_reg_enabled = true;
ofi_spin_t cl_lock;
@@ -382,111 +380,6 @@ ze_result_t ofi_zexDriverReleaseImportedPointer(ze_driver_handle_t hDriver,
#include
#include
-/*
-static int ze_hmem_init_fds(void)
-{
- const char *dev_dir = "/dev/dri/by-path/";
- const char *suffix = "-render";
- DIR *dir;
- struct dirent *ent = NULL;
- char dev_name[NAME_MAX];
- int ret, i;
- char *str;
- uint16_t domain_id;
- uint8_t pci_id;
- char *saveptr;
-
- dir = opendir(dev_dir);
- if (dir == NULL)
- return -FI_EIO;
-
- while ((ent = readdir(dir)) != NULL) {
- if (ent->d_name[0] == '.' ||
- strstr(ent->d_name, suffix) == NULL)
- continue;
-
- memset(dev_name, 0, sizeof(dev_name));
- ret = snprintf(dev_name, NAME_MAX, "%s%s", dev_dir, ent->d_name);
- if (ret < 0 || ret >= NAME_MAX)
- goto err;
-
- dev_fds[num_pci_devices] = open(dev_name, O_RDWR);
- if (dev_fds[num_pci_devices] == -1)
- goto err;
- str = strtok_r(ent->d_name, "-", &saveptr);
- str = strtok_r(NULL, ":", &saveptr);
- domain_id = (uint16_t) strtol(str, NULL, 16);
- str = strtok_r(NULL, ":", &saveptr);
- pci_id = (uint8_t) strtol(str, NULL, 16);
- for (i = 0; i < num_devices; i++) {
- if (dev_info[i].uuid.id[8] == pci_id &&
- (uint16_t) dev_info[i].uuid.id[6] == domain_id)
- dev_info[i].pci_device = num_pci_devices;
- }
- num_pci_devices++;
- }
-
- (void) closedir(dir);
- return FI_SUCCESS;
-
-err:
- (void) closedir(dir);
- FI_WARN(&core_prov, FI_LOG_CORE,
- "Failed open device %d\n", num_pci_devices);
- return -FI_EIO;
-}
-*/
-
-int ze_hmem_get_shared_handle(uint64_t device, void *dev_buf, int *ze_fd,
- void **handle)
-{
- struct drm_prime_handle open_fd = {0, 0, 0};
- ze_ipc_mem_handle_t ze_handle;
- int dev_fd = dev_fds[dev_info[device].pci_device];
- int ret;
-
- assert(dev_fd != -1);
- ret = ze_hmem_get_handle(dev_buf, 0, (void **) &ze_handle);
- if (ret)
- return ret;
-
- memcpy(ze_fd, &ze_handle, sizeof(*ze_fd));
- memcpy(&open_fd.fd, &ze_handle, sizeof(open_fd.fd));
- ret = ioctl(dev_fd, DRM_IOCTL_PRIME_FD_TO_HANDLE, &open_fd);
- if (ret) {
- FI_WARN(&core_prov, FI_LOG_CORE,
- "ioctl call failed on get, err %d\n", errno);
- return -FI_EINVAL;
- }
-
- *(int *) handle = open_fd.handle;
- return FI_SUCCESS;
-}
-
-int ze_hmem_open_shared_handle(uint64_t device, int *peer_fds, void **handle,
- int *ze_fd, void **ipc_ptr)
-{
- struct drm_prime_handle open_fd = {0, 0, 0};
- ze_ipc_mem_handle_t ze_handle;
- int dev_fd = peer_fds[dev_info[device].pci_device];
- int ret;
-
- open_fd.flags = DRM_CLOEXEC | DRM_RDWR;
- open_fd.handle = *(int *) handle;
-
- ret = ioctl(dev_fd, DRM_IOCTL_PRIME_HANDLE_TO_FD, &open_fd);
- if (ret) {
- FI_WARN(&core_prov, FI_LOG_CORE,
- "ioctl call failed on open, err %d\n", errno);
- return -FI_EINVAL;
- }
-
- *ze_fd = open_fd.fd;
- memset(&ze_handle, 0, sizeof(ze_handle));
- memcpy(&ze_handle, &open_fd.fd, sizeof(open_fd.fd));
- return ze_hmem_open_handle((void **) &ze_handle, 0, device, ipc_ptr);
-}
-
bool ze_hmem_p2p_enabled(void)
{
return !ofi_hmem_p2p_disabled() && p2p_enabled;
@@ -494,23 +387,6 @@ bool ze_hmem_p2p_enabled(void)
#else
-static int ze_hmem_init_fds(void)
-{
- return FI_SUCCESS;
-}
-
-int ze_hmem_get_shared_handle(uint64_t device, void *dev_buf, int *ze_fd,
- void **handle)
-{
- return -FI_ENOSYS;
-}
-
-int ze_hmem_open_shared_handle(uint64_t device, int *peer_fds, void **handle,
- int *ze_fd, void **ipc_ptr)
-{
- return -FI_ENOSYS;
-}
-
bool ze_hmem_p2p_enabled(void)
{
return false;
@@ -762,15 +638,10 @@ static int ze_hmem_cleanup_internal(int fini_workaround)
}
if (dev_info[i].cl_pool)
ofi_bufpool_destroy(dev_info[i].cl_pool);
- if (dev_fds[i] != -1) {
- close(dev_fds[i]);
- dev_fds[i] = -1;
- }
}
free(devices);
free(dev_info);
- free(dev_fds);
if (!fini_workaround) {
if (ofi_zeContextDestroy(context))
@@ -792,9 +663,9 @@ int ze_hmem_init(void)
ze_context_desc_t context_desc = {0};
ze_device_properties_t dev_prop = {0};
ze_result_t ze_ret;
-// ze_bool_t access;
+ ze_bool_t access;
uint32_t count, i;
-// bool p2p = false;
+ bool p2p = true;
int ret;
char *enginestr = NULL;
int ordinal = -1;
@@ -850,16 +721,13 @@ int ze_hmem_init(void)
devices = calloc(count, sizeof(*devices));
dev_info = calloc(count, sizeof(*dev_info));
- dev_fds = calloc(count, sizeof(*dev_fds));
- if (!devices || !dev_info || !dev_fds)
+ if (!devices || !dev_info)
goto err;
ze_ret = ofi_zeDeviceGet(driver, &count, devices);
if (ze_ret)
goto err;
- for (i = 0; i < count; dev_fds[i++] = -1)
- ;
for (num_devices = 0; num_devices < count; num_devices++) {
dev_info[num_devices].cl_pool = NULL;
@@ -878,22 +746,14 @@ int ze_hmem_init(void)
if (ze_ret)
goto err;
-/*
for (i = 0; i < count; i++) {
if (ofi_zeDeviceCanAccessPeer(devices[num_devices],
devices[i], &access) || !access)
p2p = false;
}
-*/
}
-/*
- ret = ze_hmem_init_fds();
- if (ret)
- goto err;
-
p2p_enabled = p2p;
-*/
return FI_SUCCESS;
err:
@@ -1055,42 +915,67 @@ bool ze_hmem_is_addr_valid(const void *addr, uint64_t *device, uint64_t *flags)
+/*
+ * Obtain the ze IPC handle for dev_buf and record its fd for the peer.
+ *
+ * 'handle' is treated as the storage of a struct ze_pid_handle. The ze IPC
+ * handle is fetched into a temporary and its leading sizeof(int) bytes are
+ * copied into get_fd; open_fd and pid_fd are not written here. The size
+ * argument is unused.
+ *
+ * Returns FI_SUCCESS, or -FI_EINVAL if the ze call fails.
+ */
int ze_hmem_get_handle(void *dev_buf, size_t size, void **handle)
{
ze_result_t ze_ret;
+ struct ze_pid_handle *pid_handle = (struct ze_pid_handle *) handle;
+ ze_ipc_mem_handle_t temp_handle;
- ze_ret = ofi_zeMemGetIpcHandle(context, dev_buf,
- (ze_ipc_mem_handle_t *) handle);
+ ze_ret = ofi_zeMemGetIpcHandle(context, dev_buf, &temp_handle);
if (ze_ret) {
FI_WARN(&core_prov, FI_LOG_CORE, "Unable to get handle\n");
return -FI_EINVAL;
}
+ memcpy(&pid_handle->get_fd, &temp_handle,
+ sizeof(pid_handle->get_fd));
return FI_SUCCESS;
}
+/*
+ * Store pid_fd in the struct ze_pid_handle that 'handle' points at, so that
+ * ze_hmem_open_handle() can later use it with ofi_pidfd_getfd().
+ */
+void ze_set_pid_fd(void **handle, int pid_fd)
+{
+	struct ze_pid_handle *pid_handle = (struct ze_pid_handle *) handle;
+
+	pid_handle->pid_fd = pid_fd;
+}
+
+/*
+ * Open a peer's ze IPC handle in this process.
+ *
+ * 'handle' is treated as a struct ze_pid_handle whose get_fd and pid_fd have
+ * already been filled in. A local fd is obtained with
+ * ofi_pidfd_getfd(pid_fd, get_fd), saved in open_fd, and placed at the start
+ * of a zeroed ze IPC handle that is passed to zeMemOpenIpcHandle. On success
+ * *ipc_ptr receives the mapped pointer. The size argument is unused.
+ *
+ * Returns FI_SUCCESS, or -FI_EINVAL if the fd cannot be obtained or the
+ * handle cannot be opened; in the latter case the new fd is closed and
+ * open_fd is reset to -1.
+ */
int ze_hmem_open_handle(void **handle, size_t size, uint64_t device,
void **ipc_ptr)
{
ze_result_t ze_ret;
+ struct ze_pid_handle *pid_handle = (struct ze_pid_handle *) handle;
+ ze_ipc_mem_handle_t temp_handle = {0};
+ int fd;
int dev_id = ze_get_device_idx(device);
+ void *buf;
/* only device memory is supported */
assert(!ze_get_driver_idx(device) && dev_id >= 0);
- ze_ret = ofi_zeMemOpenIpcHandle(context, devices[dev_id],
- *((ze_ipc_mem_handle_t *) handle),
- 0, ipc_ptr);
+ fd = ofi_pidfd_getfd(pid_handle->pid_fd, pid_handle->get_fd, 0);
+ if (fd < 0) {
+ FI_WARN(&core_prov, FI_LOG_CORE,
+ "Unable to find fd from peer pid\n");
+ return -FI_EINVAL;
+ }
+
+ pid_handle->open_fd = fd;
+ memcpy(&temp_handle, &pid_handle->open_fd, sizeof(pid_handle->open_fd));
+ ze_ret = ofi_zeMemOpenIpcHandle(context, devices[dev_id], temp_handle,
+ 0, &buf);
if (ze_ret) {
FI_WARN(&core_prov, FI_LOG_CORE,
"Unable to open memory handle\n");
+ pid_handle->open_fd = -1;
+ close(fd);
return -FI_EINVAL;
}
+ *ipc_ptr = buf;
return FI_SUCCESS;
}
int ze_hmem_close_handle(void *ipc_ptr, void **handle)
{
ze_result_t ze_ret;
+ struct ze_pid_handle *pid_handle = (struct ze_pid_handle *) handle;
+ close(pid_handle->open_fd);
ze_ret = ofi_zeMemCloseIpcHandle(context, ipc_ptr);
if (ze_ret) {
FI_WARN(&core_prov, FI_LOG_CORE,
@@ -1103,7 +988,7 @@ int ze_hmem_close_handle(void *ipc_ptr, void **handle)
int ze_hmem_get_ipc_handle_size(size_t *size)
{
- *size = sizeof(ze_ipc_mem_handle_t);
+ *size = sizeof(struct ze_pid_handle);
return FI_SUCCESS;
}
@@ -1140,12 +1025,6 @@ int ze_hmem_get_id(const void *ptr, uint64_t *id)
return FI_SUCCESS;
}
-int *ze_hmem_get_dev_fds(int *nfds)
-{
- *nfds = num_pci_devices;
- return dev_fds;
-}
-
int ze_hmem_host_register(void *ptr, size_t size)
{
ze_result_t ze_ret;
@@ -1395,20 +1274,13 @@ int ze_hmem_get_handle(void *dev_buf, size_t size, void **handle)
return -FI_ENOSYS;
}
-int ze_hmem_open_handle(void **handle, size_t size, uint64_t device,
- void **ipc_ptr)
-{
- return -FI_ENOSYS;
-}
-
-int ze_hmem_get_shared_handle(uint64_t device, void *dev_buf, int *ze_fd,
- void **handle)
+void ze_set_pid_fd(void **handle, int pid_fd)
{
- return -FI_ENOSYS;
+ /* no-op */
}
-int ze_hmem_open_shared_handle(uint64_t device, int *peer_fds, void **handle,
- int *ze_fd, void **ipc_ptr)
+int ze_hmem_open_handle(void **handle, size_t size, uint64_t device,
+ void **ipc_ptr)
{
return -FI_ENOSYS;
}
@@ -1439,12 +1311,6 @@ int ze_hmem_get_id(const void *ptr, uint64_t *id)
return -FI_ENOSYS;
}
-int *ze_hmem_get_dev_fds(int *nfds)
-{
- *nfds = 0;
- return NULL;
-}
-
int ze_hmem_host_register(void *ptr, size_t size)
{
return FI_SUCCESS;