Skip to content

Commit

Permalink
prov/shm: use owner-allocated srx
Browse files Browse the repository at this point in the history
The peer API has been updated to specify that the owner must allocate
the peer's fid_peer_srx. The shm implementation was allocating its
own internal fid_peer_srx.
This updates the shm implementation to assume it has a unique
fid_peer_srx and updates the imported fid_peer_srx peer_ops, saving
a pointer to the fid_peer_srx instead of the internal fid_ep which
required a wrapper function to get back to the fid_peer_srx

It also returns an internal fid_ep for the created srx which is used
to close the srx by the owner. Even though shm doesn't need anything
attached to the internal fid_ep, it is there for consistency and to
track the domain reference counting for errors.

This patch also moves the srx specific functions into smr_domain
where they belong

Signed-off-by: Alexia Ingerson <[email protected]>
  • Loading branch information
aingerson authored and shijin-aws committed Oct 11, 2024
1 parent dbc685d commit 472dc49
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 78 deletions.
11 changes: 2 additions & 9 deletions prov/shm/src/smr.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ struct smr_domain {
int fast_rma;
/* cache for use with hmem ipc */
struct ofi_mr_cache *ipc_cache;
struct fid_ep rx_ep;
struct fid_peer_srx *srx;
};

Expand Down Expand Up @@ -220,7 +221,7 @@ struct smr_ep {
const char *name;
uint64_t msg_id;
struct smr_region *volatile region;
struct fid_ep *srx;
struct fid_peer_srx *srx;
struct ofi_bufpool *cmd_ctx_pool;
struct ofi_bufpool *unexp_buf_pool;
struct ofi_bufpool *pend_buf_pool;
Expand All @@ -236,11 +237,6 @@ struct smr_ep {
void (*smr_progress_ipc_list)(struct smr_ep *ep);
};

static inline struct fid_peer_srx *smr_get_peer_srx(struct smr_ep *ep)
{
return container_of(ep->srx, struct fid_peer_srx, ep_fid);
}

#define smr_ep_rx_flags(smr_ep) ((smr_ep)->util_ep.rx_op_flags)
#define smr_ep_tx_flags(smr_ep) ((smr_ep)->util_ep.tx_op_flags)

Expand All @@ -251,9 +247,6 @@ static inline int smr_mmap_name(char *shm_name, const char *ep_name,
ep_name, msg_id);
}

int smr_srx_context(struct fid_domain *domain, struct fi_rx_attr *attr,
struct fid_ep **rx_ep, void *context);

int smr_endpoint(struct fid_domain *domain, struct fi_info *info,
struct fid_ep **ep, void *context);
void smr_ep_exchange_fds(struct smr_ep *ep, int64_t id);
Expand Down
5 changes: 2 additions & 3 deletions prov/shm/src/smr_av.c
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ static int smr_av_insert(struct fid_av *av_fid, const void *addr, size_t count,
struct util_ep *util_ep;
struct smr_av *smr_av;
struct smr_ep *smr_ep;
struct fid_peer_srx *srx;
struct dlist_entry *av_entry;
fi_addr_t util_addr;
int64_t shm_id = -1;
Expand Down Expand Up @@ -173,8 +172,8 @@ static int smr_av_insert(struct fid_av *av_fid, const void *addr, size_t count,
smr_ep = container_of(util_ep, struct smr_ep, util_ep);
smr_ep->region->max_sar_buf_per_peer =
SMR_MAX_PEERS / smr_av->smr_map.num_peers;
srx = smr_get_peer_srx(smr_ep);
srx->owner_ops->foreach_unspec_addr(srx, &smr_get_addr);
smr_ep->srx->owner_ops->foreach_unspec_addr(smr_ep->srx,
&smr_get_addr);
}

}
Expand Down
72 changes: 72 additions & 0 deletions prov/shm/src/smr_domain.c
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,78 @@

#include "smr.h"

extern struct fi_ops_srx_peer smr_srx_peer_ops;

static int smr_srx_close(struct fid *fid)
{
struct smr_domain *domain = container_of(fid, struct smr_domain,
rx_ep.fid);

ofi_atomic_dec32(&domain->util_domain.ref);

return FI_SUCCESS;
}

static struct fi_ops smr_srx_fi_ops = {
.size = sizeof(struct fi_ops),
.close = smr_srx_close,
.bind = fi_no_bind,
.control = fi_no_control,
.ops_open = fi_no_ops_open,
};

static struct fi_ops_msg smr_srx_msg_ops = {
.size = sizeof(struct fi_ops_msg),
.recv = fi_no_msg_recv,
.recvv = fi_no_msg_recvv,
.recvmsg = fi_no_msg_recvmsg,
.send = fi_no_msg_send,
.sendv = fi_no_msg_sendv,
.sendmsg = fi_no_msg_sendmsg,
.inject = fi_no_msg_inject,
.senddata = fi_no_msg_senddata,
.injectdata = fi_no_msg_injectdata,
};

static struct fi_ops_tagged smr_srx_tagged_ops = {
.size = sizeof(struct fi_ops_msg),
.recv = fi_no_tagged_recv,
.recvv = fi_no_tagged_recvv,
.recvmsg = fi_no_tagged_recvmsg,
.send = fi_no_tagged_send,
.sendv = fi_no_tagged_sendv,
.sendmsg = fi_no_tagged_sendmsg,
.inject = fi_no_tagged_inject,
.senddata = fi_no_tagged_senddata,
.injectdata = fi_no_tagged_injectdata,
};

static int smr_srx_context(struct fid_domain *domain, struct fi_rx_attr *attr,
struct fid_ep **rx_ep, void *context)
{
struct smr_domain *smr_domain;

smr_domain = container_of(domain, struct smr_domain,
util_domain.domain_fid);

if (attr->op_flags & FI_PEER) {
smr_domain->srx = ((struct fi_peer_srx_context *)
(context))->srx;
smr_domain->srx->peer_ops = &smr_srx_peer_ops;
smr_domain->rx_ep.msg = &smr_srx_msg_ops;
smr_domain->rx_ep.tagged = &smr_srx_tagged_ops;
smr_domain->rx_ep.fid.ops = &smr_srx_fi_ops;
smr_domain->rx_ep.fid.fclass = FI_CLASS_SRX_CTX;
*rx_ep = &smr_domain->rx_ep;
ofi_atomic_inc32(&smr_domain->util_domain.ref);
return FI_SUCCESS;
}
FI_WARN(&smr_prov, FI_LOG_EP_CTRL,
"shared srx only supported with FI_PEER flag\n");
return -FI_EINVAL;
}


static struct fi_ops_domain smr_domain_ops = {
.size = sizeof(struct fi_ops_domain),
.av_open = smr_av_open,
Expand Down
56 changes: 16 additions & 40 deletions prov/shm/src/smr_ep.c
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ int smr_ep_getopt(fid_t fid, int level, int optname, void *optval,
struct smr_ep *smr_ep =
container_of(fid, struct smr_ep, util_ep.ep_fid);

return smr_ep->srx->ops->getopt(&smr_ep->srx->fid, level, optname,
optval, optlen);
return smr_ep->srx->ep_fid.ops->getopt(&smr_ep->srx->ep_fid.fid, level,
optname, optval, optlen);
}

int smr_ep_setopt(fid_t fid, int level, int optname, const void *optval,
Expand All @@ -134,7 +134,7 @@ int smr_ep_setopt(fid_t fid, int level, int optname, const void *optval,
return -FI_ENOPROTOOPT;

if (optname == FI_OPT_MIN_MULTI_RECV) {
srx = util_get_peer_srx(smr_ep->srx)->ep_fid.fid.context;
srx = smr_ep->srx->ep_fid.fid.context;
srx->min_multi_recv_size = *(size_t *)optval;
return FI_SUCCESS;
}
Expand All @@ -159,7 +159,7 @@ static ssize_t smr_ep_cancel(fid_t ep_fid, void *context)
struct smr_ep *ep;

ep = container_of(ep_fid, struct smr_ep, util_ep.ep_fid);
return ep->srx->ops->cancel(&ep->srx->fid, context);
return ep->srx->ep_fid.ops->cancel(&ep->srx->ep_fid.fid, context);
}

static struct fi_ops_ep smr_ep_ops = {
Expand Down Expand Up @@ -808,9 +808,7 @@ static int smr_ep_close(struct fid *fid)
if (ep->srx) {
/* shm is an owner provider */
if (ep->util_ep.ep_fid.msg != &smr_no_recv_msg_ops)
(void) util_srx_close(&ep->srx->fid);
else /* shm is a peer provider */
free(ep->srx);
(void) util_srx_close(&ep->srx->ep_fid.fid);
}

ofi_endpoint_close(&ep->util_ep);
Expand Down Expand Up @@ -1062,30 +1060,11 @@ static void smr_update(struct util_srx_ctx *srx, struct util_rx_entry *rx_entry)
//by another provider
}

int smr_srx_context(struct fid_domain *domain, struct fi_rx_attr *attr,
struct fid_ep **rx_ep, void *context)
{
struct smr_domain *smr_domain;

smr_domain = container_of(domain, struct smr_domain,
util_domain.domain_fid);

if (attr->op_flags & FI_PEER) {
smr_domain->srx = ((struct fi_peer_srx_context *)
(context))->srx;
return FI_SUCCESS;
}
FI_WARN(&smr_prov, FI_LOG_EP_CTRL,
"shared srx only supported with FI_PEER flag\n");
return -FI_EINVAL;
}

static int smr_ep_bind(struct fid *ep_fid, struct fid *bfid, uint64_t flags)
{
struct smr_ep *ep;
struct util_av *av;
int ret = 0;
struct fid_peer_srx *srx, *srx_b;

ep = container_of(ep_fid, struct smr_ep, util_ep.ep_fid.fid);
switch (bfid->fclass) {
Expand All @@ -1109,16 +1088,10 @@ static int smr_ep_bind(struct fid *ep_fid, struct fid *bfid, uint64_t flags)
struct util_cntr, cntr_fid.fid), flags);
break;
case FI_CLASS_SRX_CTX:
srx = calloc(1, sizeof(*srx));
srx_b = container_of(bfid, struct fid_peer_srx, ep_fid.fid);
srx->peer_ops = &smr_srx_peer_ops;
srx->owner_ops = srx_b->owner_ops;
srx->ep_fid.fid.context = srx_b->ep_fid.fid.context;
ep->srx = &srx->ep_fid;
ep->srx = (container_of(bfid, struct smr_domain, rx_ep.fid))->srx;
break;
default:
FI_WARN(&smr_prov, FI_LOG_EP_CTRL,
"invalid fid class\n");
FI_WARN(&smr_prov, FI_LOG_EP_CTRL, "invalid fid class\n");
ret = -FI_EINVAL;
break;
}
Expand All @@ -1131,6 +1104,7 @@ static int smr_ep_ctrl(struct fid *fid, int command, void *arg)
struct smr_domain *domain;
struct smr_ep *ep;
struct smr_av *av;
struct fid_ep *srx;
int ret;

ep = container_of(fid, struct smr_ep, util_ep.ep_fid.fid);
Expand Down Expand Up @@ -1171,15 +1145,17 @@ static int smr_ep_ctrl(struct fid *fid, int command, void *arg)
ret = util_ep_srx_context(&domain->util_domain,
ep->rx_size, SMR_IOV_LIMIT,
SMR_INJECT_SIZE, &smr_update,
&ep->util_ep.lock, &ep->srx);
&ep->util_ep.lock, &srx);
if (ret)
return ret;

util_get_peer_srx(ep->srx)->peer_ops =
&smr_srx_peer_ops;
ret = util_srx_bind(&ep->srx->fid,
&ep->util_ep.rx_cq->cq_fid.fid,
FI_RECV);
ep->srx = container_of(srx, struct fid_peer_srx,
ep_fid.fid);
ep->srx->peer_ops = &smr_srx_peer_ops;

ret = util_srx_bind(&ep->srx->ep_fid.fid,
&ep->util_ep.rx_cq->cq_fid.fid,
FI_RECV);
if (ret)
return ret;
} else {
Expand Down
28 changes: 15 additions & 13 deletions prov/shm/src/smr_msg.c
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ static ssize_t smr_recvmsg(struct fid_ep *ep_fid, const struct fi_msg *msg,

ep = container_of(ep_fid, struct smr_ep, util_ep.ep_fid.fid);

return util_srx_generic_recv(ep->srx, msg->msg_iov, msg->desc,
return util_srx_generic_recv(&ep->srx->ep_fid, msg->msg_iov, msg->desc,
msg->iov_count, msg->addr, msg->context,
flags | ep->util_ep.rx_msg_flags);
}
Expand All @@ -58,8 +58,8 @@ static ssize_t smr_recvv(struct fid_ep *ep_fid, const struct iovec *iov,

ep = container_of(ep_fid, struct smr_ep, util_ep.ep_fid.fid);

return util_srx_generic_recv(ep->srx, iov, desc, count, src_addr,
context, smr_ep_rx_flags(ep));
return util_srx_generic_recv(&ep->srx->ep_fid, iov, desc, count,
src_addr, context, smr_ep_rx_flags(ep));
}

static ssize_t smr_recv(struct fid_ep *ep_fid, void *buf, size_t len,
Expand All @@ -73,8 +73,8 @@ static ssize_t smr_recv(struct fid_ep *ep_fid, void *buf, size_t len,
iov.iov_base = buf;
iov.iov_len = len;

return util_srx_generic_recv(ep->srx, &iov, &desc, 1, src_addr, context,
smr_ep_rx_flags(ep));
return util_srx_generic_recv(&ep->srx->ep_fid, &iov, &desc, 1, src_addr,
context, smr_ep_rx_flags(ep));
}

static ssize_t smr_generic_sendmsg(struct smr_ep *ep, const struct iovec *iov,
Expand Down Expand Up @@ -293,8 +293,9 @@ static ssize_t smr_trecv(struct fid_ep *ep_fid, void *buf, size_t len,
iov.iov_base = buf;
iov.iov_len = len;

return util_srx_generic_trecv(ep->srx, &iov, &desc, 1, src_addr, context,
tag, ignore, smr_ep_rx_flags(ep));
return util_srx_generic_trecv(&ep->srx->ep_fid, &iov, &desc, 1,
src_addr, context, tag, ignore,
smr_ep_rx_flags(ep));
}

static ssize_t smr_trecvv(struct fid_ep *ep_fid, const struct iovec *iov,
Expand All @@ -305,8 +306,9 @@ static ssize_t smr_trecvv(struct fid_ep *ep_fid, const struct iovec *iov,

ep = container_of(ep_fid, struct smr_ep, util_ep.ep_fid.fid);

return util_srx_generic_trecv(ep->srx, iov, desc, count, src_addr,
context, tag, ignore, smr_ep_rx_flags(ep));
return util_srx_generic_trecv(&ep->srx->ep_fid, iov, desc, count,
src_addr, context, tag, ignore,
smr_ep_rx_flags(ep));
}

static ssize_t smr_trecvmsg(struct fid_ep *ep_fid,
Expand All @@ -316,10 +318,10 @@ static ssize_t smr_trecvmsg(struct fid_ep *ep_fid,

ep = container_of(ep_fid, struct smr_ep, util_ep.ep_fid.fid);

return util_srx_generic_trecv(ep->srx, msg->msg_iov, msg->desc,
msg->iov_count, msg->addr, msg->context,
msg->tag, msg->ignore,
flags | ep->util_ep.rx_msg_flags);
return util_srx_generic_trecv(&ep->srx->ep_fid, msg->msg_iov, msg->desc,
msg->iov_count, msg->addr, msg->context,
msg->tag, msg->ignore,
flags | ep->util_ep.rx_msg_flags);
}

static ssize_t smr_tsend(struct fid_ep *ep_fid, const void *buf, size_t len,
Expand Down
Loading

0 comments on commit 472dc49

Please sign in to comment.