Skip to content

Commit

Permalink
coll/ucc: fall through when previous coll handler is NULL
Browse files Browse the repository at this point in the history
Signed-off-by: Tomislav Janjusic <[email protected]>
  • Loading branch information
janjust committed Jan 30, 2025
1 parent 07f1505 commit 2579537
Show file tree
Hide file tree
Showing 18 changed files with 70 additions and 84 deletions.
2 changes: 1 addition & 1 deletion 3rd-party/openpmix
Submodule openpmix updated 203 files
2 changes: 1 addition & 1 deletion 3rd-party/prrte
Submodule prrte updated 278 files
5 changes: 5 additions & 0 deletions ompi/mca/coll/ucc/coll_ucc.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ BEGIN_C_DECLS
"iallgatherv,ireduce,igather,igatherv,ireduce_scatter_block,"\
"ireduce_scatter,iscatterv,iscatter"

#define mca_coll_ucc_call_previous(__api, ucc_module, ...) \
(ucc_module->previous_ ## __api == NULL ? \
UCC_ERR_NOT_SUPPORTED : \
ucc_module->previous_ ## __api (__VA_ARGS__, ucc_module->previous_ ## __api ## _module))

typedef struct mca_coll_ucc_req {
ompi_request_t super;
ucc_coll_req_h ucc_req;
Expand Down
10 changes: 6 additions & 4 deletions ompi/mca/coll/ucc/coll_ucc_allgather.c
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,9 @@ int mca_coll_ucc_allgather(const void *sbuf, int scount, struct ompi_datatype_t
return OMPI_SUCCESS;
fallback:
UCC_VERBOSE(3, "running fallback allgather");
return ucc_module->previous_allgather(sbuf, scount, sdtype, rbuf, rcount, rdtype,
comm, ucc_module->previous_allgather_module);

return mca_coll_ucc_call_previous(allgather, ucc_module,
sbuf, scount, sdtype, rbuf, rcount, rdtype, comm);
}

int mca_coll_ucc_iallgather(const void *sbuf, int scount, struct ompi_datatype_t *sdtype,
Expand All @@ -104,6 +105,7 @@ int mca_coll_ucc_iallgather(const void *sbuf, int scount, struct ompi_datatype_t
if (coll_req) {
mca_coll_ucc_req_free((ompi_request_t **)&coll_req);
}
return ucc_module->previous_iallgather(sbuf, scount, sdtype, rbuf, rcount, rdtype,
comm, request, ucc_module->previous_iallgather_module);

return mca_coll_ucc_call_previous(iallgather, ucc_module,
sbuf, scount, sdtype, rbuf, rcount, rdtype, comm, request);
}
10 changes: 4 additions & 6 deletions ompi/mca/coll/ucc/coll_ucc_allgatherv.c
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,8 @@ int mca_coll_ucc_allgatherv(const void *sbuf, int scount,
return OMPI_SUCCESS;
fallback:
UCC_VERBOSE(3, "running fallback allgatherv");
return ucc_module->previous_allgatherv(sbuf, scount, sdtype,
rbuf, rcounts, rdisps, rdtype,
comm, ucc_module->previous_allgatherv_module);
return mca_coll_ucc_call_previous(allgatherv, ucc_module,
sbuf, scount, sdtype, rbuf, rcounts, rdisps, rdtype, comm);
}

int mca_coll_ucc_iallgatherv(const void *sbuf, int scount,
Expand Down Expand Up @@ -108,7 +107,6 @@ int mca_coll_ucc_iallgatherv(const void *sbuf, int scount,
if (coll_req) {
mca_coll_ucc_req_free((ompi_request_t **)&coll_req);
}
return ucc_module->previous_iallgatherv(sbuf, scount, sdtype,
rbuf, rcounts, rdisps, rdtype,
comm, request, ucc_module->previous_iallgatherv_module);
return mca_coll_ucc_call_previous(iallgatherv, ucc_module,
sbuf, scount, sdtype, rbuf, rcounts, rdisps, rdtype, comm, request);
}
8 changes: 4 additions & 4 deletions ompi/mca/coll/ucc/coll_ucc_allreduce.c
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ int mca_coll_ucc_allreduce(const void *sbuf, void *rbuf, int count,
return OMPI_SUCCESS;
fallback:
UCC_VERBOSE(3, "running fallback allreduce");
return ucc_module->previous_allreduce(sbuf, rbuf, count, dtype, op,
comm, ucc_module->previous_allreduce_module);
return mca_coll_ucc_call_previous(allreduce, ucc_module,
sbuf, rbuf, count, dtype, op, comm);
}

int mca_coll_ucc_iallreduce(const void *sbuf, void *rbuf, int count,
Expand All @@ -100,6 +100,6 @@ int mca_coll_ucc_iallreduce(const void *sbuf, void *rbuf, int count,
if (coll_req) {
mca_coll_ucc_req_free((ompi_request_t **)&coll_req);
}
return ucc_module->previous_iallreduce(sbuf, rbuf, count, dtype, op,
comm, request, ucc_module->previous_iallreduce_module);
return mca_coll_ucc_call_previous(iallreduce, ucc_module,
sbuf, rbuf, count, dtype, op, comm, request);
}
8 changes: 4 additions & 4 deletions ompi/mca/coll/ucc/coll_ucc_alltoall.c
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ int mca_coll_ucc_alltoall(const void *sbuf, int scount, struct ompi_datatype_t *
return OMPI_SUCCESS;
fallback:
UCC_VERBOSE(3, "running fallback alltoall");
return ucc_module->previous_alltoall(sbuf, scount, sdtype, rbuf, rcount, rdtype,
comm, ucc_module->previous_alltoall_module);
return mca_coll_ucc_call_previous(alltoall, ucc_module,
sbuf, scount, sdtype, rbuf, rcount, rdtype, comm);
}

int mca_coll_ucc_ialltoall(const void *sbuf, int scount, struct ompi_datatype_t *sdtype,
Expand All @@ -104,6 +104,6 @@ int mca_coll_ucc_ialltoall(const void *sbuf, int scount, struct ompi_datatype_t
if (coll_req) {
mca_coll_ucc_req_free((ompi_request_t **)&coll_req);
}
return ucc_module->previous_ialltoall(sbuf, scount, sdtype, rbuf, rcount, rdtype,
comm, request, ucc_module->previous_ialltoall_module);
return mca_coll_ucc_call_previous(ialltoall, ucc_module,
sbuf, scount, sdtype, rbuf, rcount, rdtype, comm, request);
}
10 changes: 4 additions & 6 deletions ompi/mca/coll/ucc/coll_ucc_alltoallv.c
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,8 @@ int mca_coll_ucc_alltoallv(const void *sbuf, const int *scounts,
return OMPI_SUCCESS;
fallback:
UCC_VERBOSE(3, "running fallback alltoallv");
return ucc_module->previous_alltoallv(sbuf, scounts, sdisps, sdtype,
rbuf, rcounts, rdisps, rdtype,
comm, ucc_module->previous_alltoallv_module);
return mca_coll_ucc_call_previous(alltoallv, ucc_module,
sbuf, scounts, sdisps, sdtype, rbuf, rcounts, rdisps, rdtype, comm);
}

int mca_coll_ucc_ialltoallv(const void *sbuf, const int *scounts,
Expand Down Expand Up @@ -109,7 +108,6 @@ int mca_coll_ucc_ialltoallv(const void *sbuf, const int *scounts,
if (coll_req) {
mca_coll_ucc_req_free((ompi_request_t **)&coll_req);
}
return ucc_module->previous_ialltoallv(sbuf, scounts, sdisps, sdtype,
rbuf, rcounts, rdisps, rdtype,
comm, request, ucc_module->previous_ialltoallv_module);
return mca_coll_ucc_call_previous(ialltoallv, ucc_module,
sbuf, scounts, sdisps, sdtype, rbuf, rcounts, rdisps, rdtype, comm, request);
}
5 changes: 2 additions & 3 deletions ompi/mca/coll/ucc/coll_ucc_barrier.c
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ int mca_coll_ucc_barrier(struct ompi_communicator_t *comm,
return OMPI_SUCCESS;
fallback:
UCC_VERBOSE(3, "running fallback barrier");
return ucc_module->previous_barrier(comm, ucc_module->previous_barrier_module);
return mca_coll_ucc_call_previous(barrier, ucc_module, comm);
}

int mca_coll_ucc_ibarrier(struct ompi_communicator_t *comm,
Expand All @@ -58,6 +58,5 @@ int mca_coll_ucc_ibarrier(struct ompi_communicator_t *comm,
if (coll_req) {
mca_coll_ucc_req_free((ompi_request_t **)&coll_req);
}
return ucc_module->previous_ibarrier(comm, request,
ucc_module->previous_ibarrier_module);
return mca_coll_ucc_call_previous(ibarrier, ucc_module, comm, request);
}
9 changes: 5 additions & 4 deletions ompi/mca/coll/ucc/coll_ucc_bcast.c
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,9 @@ int mca_coll_ucc_bcast(void *buf, int count, struct ompi_datatype_t *dtype,
return OMPI_SUCCESS;
fallback:
UCC_VERBOSE(3, "running fallback bcast");
return ucc_module->previous_bcast(buf, count, dtype, root,
comm, ucc_module->previous_bcast_module);
return mca_coll_ucc_call_previous(bcast, ucc_module,
buf, count, dtype, root, comm);

}

int mca_coll_ucc_ibcast(void *buf, int count, struct ompi_datatype_t *dtype,
Expand All @@ -76,6 +77,6 @@ int mca_coll_ucc_ibcast(void *buf, int count, struct ompi_datatype_t *dtype,
if (coll_req) {
mca_coll_ucc_req_free((ompi_request_t **)&coll_req);
}
return ucc_module->previous_ibcast(buf, count, dtype, root,
comm, request, ucc_module->previous_ibcast_module);
return mca_coll_ucc_call_previous(ibcast, ucc_module,
buf, count, dtype, root, comm, request);
}
10 changes: 4 additions & 6 deletions ompi/mca/coll/ucc/coll_ucc_gather.c
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,8 @@ int mca_coll_ucc_gather(const void *sbuf, int scount, struct ompi_datatype_t *sd
return OMPI_SUCCESS;
fallback:
UCC_VERBOSE(3, "running fallback gather");
return ucc_module->previous_gather(sbuf, scount, sdtype, rbuf, rcount,
rdtype, root, comm,
ucc_module->previous_gather_module);
return mca_coll_ucc_call_previous(gather, ucc_module,
sbuf, scount, sdtype, rbuf, rcount, rdtype, root, comm);
}

int mca_coll_ucc_igather(const void *sbuf, int scount, struct ompi_datatype_t *sdtype,
Expand All @@ -119,7 +118,6 @@ int mca_coll_ucc_igather(const void *sbuf, int scount, struct ompi_datatype_t *s
if (coll_req) {
mca_coll_ucc_req_free((ompi_request_t **)&coll_req);
}
return ucc_module->previous_igather(sbuf, scount, sdtype, rbuf, rcount,
rdtype, root, comm, request,
ucc_module->previous_igather_module);
return mca_coll_ucc_call_previous(igather, ucc_module,
sbuf, scount, sdtype, rbuf, rcount, rdtype, root, comm, request);
}
10 changes: 4 additions & 6 deletions ompi/mca/coll/ucc/coll_ucc_gatherv.c
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,8 @@ int mca_coll_ucc_gatherv(const void *sbuf, int scount, struct ompi_datatype_t *s
return OMPI_SUCCESS;
fallback:
UCC_VERBOSE(3, "running fallback gatherv");
return ucc_module->previous_gatherv(sbuf, scount, sdtype, rbuf, rcounts,
disps, rdtype, root, comm,
ucc_module->previous_gatherv_module);
return mca_coll_ucc_call_previous(gatherv, ucc_module,
sbuf, scount, sdtype, rbuf, rcounts, disps, rdtype, root, comm);
}

int mca_coll_ucc_igatherv(const void *sbuf, int scount, struct ompi_datatype_t *sdtype,
Expand All @@ -115,7 +114,6 @@ int mca_coll_ucc_igatherv(const void *sbuf, int scount, struct ompi_datatype_t *
if (coll_req) {
mca_coll_ucc_req_free((ompi_request_t **)&coll_req);
}
return ucc_module->previous_igatherv(sbuf, scount, sdtype, rbuf, rcounts,
disps, rdtype, root, comm, request,
ucc_module->previous_igatherv_module);
return mca_coll_ucc_call_previous(igatherv, ucc_module,
sbuf, scount, sdtype, rbuf, rcounts, disps, rdtype, root, comm, request);
}
16 changes: 6 additions & 10 deletions ompi/mca/coll/ucc/coll_ucc_module.c
Original file line number Diff line number Diff line change
Expand Up @@ -141,13 +141,12 @@ static void mca_coll_ucc_module_destruct(mca_coll_ucc_module_t *ucc_module)
#define SAVE_PREV_COLL_API(__api) do { \
ucc_module->previous_ ## __api = comm->c_coll->coll_ ## __api; \
ucc_module->previous_ ## __api ## _module = comm->c_coll->coll_ ## __api ## _module; \
if (!comm->c_coll->coll_ ## __api || !comm->c_coll->coll_ ## __api ## _module) { \
return OMPI_ERROR; \
if (comm->c_coll->coll_ ## __api && comm->c_coll->coll_ ## __api ## _module) { \
OBJ_RETAIN(ucc_module->previous_ ## __api ## _module); \
} \
OBJ_RETAIN(ucc_module->previous_ ## __api ## _module); \
} while(0)

static int mca_coll_ucc_save_coll_handlers(mca_coll_ucc_module_t *ucc_module)
static void mca_coll_ucc_save_coll_handlers(mca_coll_ucc_module_t *ucc_module)
{
ompi_communicator_t *comm = ucc_module->comm;
SAVE_PREV_COLL_API(allreduce);
Expand Down Expand Up @@ -178,7 +177,6 @@ static int mca_coll_ucc_save_coll_handlers(mca_coll_ucc_module_t *ucc_module)
SAVE_PREV_COLL_API(iscatterv);
SAVE_PREV_COLL_API(scatter);
SAVE_PREV_COLL_API(iscatter);
return OMPI_SUCCESS;
}

/*
Expand Down Expand Up @@ -470,11 +468,6 @@ static int mca_coll_ucc_module_enable(mca_coll_base_module_t *module,
(void*)comm, (long long unsigned)team_params.id,
ompi_comm_size(comm));

if (OMPI_SUCCESS != mca_coll_ucc_save_coll_handlers(ucc_module)){
UCC_ERROR("mca_coll_ucc_save_coll_handlers failed");
goto err;
}

if (UCC_OK != ucc_team_create_post(&cm->ucc_context, 1,
&team_params, &ucc_module->ucc_team)) {
UCC_ERROR("ucc_team_create_post failed");
Expand All @@ -495,6 +488,9 @@ static int mca_coll_ucc_module_enable(mca_coll_base_module_t *module,
UCC_ERROR("ucc ompi_attr_set_c failed");
goto err;
}

mca_coll_ucc_save_coll_handlers(ucc_module);

return OMPI_SUCCESS;

err:
Expand Down
8 changes: 4 additions & 4 deletions ompi/mca/coll/ucc/coll_ucc_reduce.c
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ int mca_coll_ucc_reduce(const void *sbuf, void* rbuf, int count,
return OMPI_SUCCESS;
fallback:
UCC_VERBOSE(3, "running fallback reduce");
return ucc_module->previous_reduce(sbuf, rbuf, count, dtype, op, root,
comm, ucc_module->previous_reduce_module);
return mca_coll_ucc_call_previous(reduce, ucc_module,
sbuf, rbuf, count, dtype, op, root, comm);
}

int mca_coll_ucc_ireduce(const void *sbuf, void* rbuf, int count,
Expand All @@ -103,6 +103,6 @@ int mca_coll_ucc_ireduce(const void *sbuf, void* rbuf, int count,
if (coll_req) {
mca_coll_ucc_req_free((ompi_request_t **)&coll_req);
}
return ucc_module->previous_ireduce(sbuf, rbuf, count, dtype, op, root,
comm, request, ucc_module->previous_ireduce_module);
return mca_coll_ucc_call_previous(ireduce, ucc_module,
sbuf, rbuf, count, dtype, op, root, comm, request);
}
10 changes: 4 additions & 6 deletions ompi/mca/coll/ucc/coll_ucc_reduce_scatter.c
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,8 @@ int mca_coll_ucc_reduce_scatter(const void *sbuf, void *rbuf, const int *rcounts
return OMPI_SUCCESS;
fallback:
UCC_VERBOSE(3, "running fallback reduce_scatter");
return ucc_module->previous_reduce_scatter(sbuf, rbuf, rcounts, dtype, op,
comm,
ucc_module->previous_reduce_scatter_module);
return mca_coll_ucc_call_previous(reduce_scatter, ucc_module,
sbuf, rbuf, rcounts, dtype, op, comm);
}

int mca_coll_ucc_ireduce_scatter(const void *sbuf, void *rbuf, const int *rcounts,
Expand All @@ -115,7 +114,6 @@ int mca_coll_ucc_ireduce_scatter(const void *sbuf, void *rbuf, const int *rcount
if (coll_req) {
mca_coll_ucc_req_free((ompi_request_t **)&coll_req);
}
return ucc_module->previous_ireduce_scatter(sbuf, rbuf, rcounts, dtype, op,
comm, request,
ucc_module->previous_ireduce_scatter_module);
return mca_coll_ucc_call_previous(ireduce_scatter, ucc_module,
sbuf, rbuf, rcounts, dtype, op, comm, request);
}
10 changes: 4 additions & 6 deletions ompi/mca/coll/ucc/coll_ucc_reduce_scatter_block.c
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,8 @@ int mca_coll_ucc_reduce_scatter_block(const void *sbuf, void *rbuf, int rcount,
return OMPI_SUCCESS;
fallback:
UCC_VERBOSE(3, "running fallback reduce_scatter_block");
return ucc_module->previous_reduce_scatter_block(sbuf, rbuf, rcount, dtype,
op, comm,
ucc_module->previous_reduce_scatter_block_module);
return mca_coll_ucc_call_previous(reduce_scatter_block, ucc_module,
sbuf, rbuf, rcount, dtype, op, comm);
}

int mca_coll_ucc_ireduce_scatter_block(const void *sbuf, void *rbuf, int rcount,
Expand All @@ -111,7 +110,6 @@ int mca_coll_ucc_ireduce_scatter_block(const void *sbuf, void *rbuf, int rcount,
if (coll_req) {
mca_coll_ucc_req_free((ompi_request_t **)&coll_req);
}
return ucc_module->previous_ireduce_scatter_block(sbuf, rbuf, rcount, dtype,
op, comm, request,
ucc_module->previous_ireduce_scatter_block_module);
return mca_coll_ucc_call_previous(ireduce_scatter_block, ucc_module,
sbuf, rbuf, rcount, dtype, op, comm, request);
}
11 changes: 4 additions & 7 deletions ompi/mca/coll/ucc/coll_ucc_scatter.c
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,8 @@ int mca_coll_ucc_scatter(const void *sbuf, int scount,
return OMPI_SUCCESS;
fallback:
UCC_VERBOSE(3, "running fallback scatter");
return ucc_module->previous_scatter(sbuf, scount, sdtype, rbuf, rcount,
rdtype, root, comm,
ucc_module->previous_scatter_module);

return mca_coll_ucc_call_previous(scatter, ucc_module,
sbuf, scount, sdtype, rbuf, rcount, rdtype, root, comm);
}

int mca_coll_ucc_iscatter(const void *sbuf, int scount,
Expand All @@ -117,7 +115,6 @@ int mca_coll_ucc_iscatter(const void *sbuf, int scount,
if (coll_req) {
mca_coll_ucc_req_free((ompi_request_t **)&coll_req);
}
return ucc_module->previous_iscatter(sbuf, scount, sdtype, rbuf, rcount,
rdtype, root, comm, request,
ucc_module->previous_iscatter_module);
return mca_coll_ucc_call_previous(iscatter, ucc_module,
sbuf, scount, sdtype, rbuf, rcount, rdtype, root, comm, request);
}
10 changes: 4 additions & 6 deletions ompi/mca/coll/ucc/coll_ucc_scatterv.c
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,8 @@ int mca_coll_ucc_scatterv(const void *sbuf, const int *scounts,
return OMPI_SUCCESS;
fallback:
UCC_VERBOSE(3, "running fallback scatterv");
return ucc_module->previous_scatterv(sbuf, scounts, disps, sdtype, rbuf,
rcount, rdtype, root, comm,
ucc_module->previous_scatterv_module);
return mca_coll_ucc_call_previous(scatterv, ucc_module,
sbuf, scounts, disps, sdtype, rbuf, rcount, rdtype, root, comm);
}

int mca_coll_ucc_iscatterv(const void *sbuf, const int *scounts,
Expand Down Expand Up @@ -120,7 +119,6 @@ int mca_coll_ucc_iscatterv(const void *sbuf, const int *scounts,
if (coll_req) {
mca_coll_ucc_req_free((ompi_request_t **)&coll_req);
}
return ucc_module->previous_iscatterv(sbuf, scounts, disps, sdtype, rbuf,
rcount, rdtype, root, comm, request,
ucc_module->previous_iscatterv_module);
return mca_coll_ucc_call_previous(iscatterv, ucc_module,
sbuf, scounts, disps, sdtype, rbuf, rcount, rdtype, root, comm, request);
}

0 comments on commit 2579537

Please sign in to comment.