Skip to content

Commit

Permalink
Dynamically find tag upper limit.
Browse files Browse the repository at this point in the history
  • Loading branch information
BradWhitlock committed Oct 18, 2024
1 parent 5cf37ce commit 038c261
Show file tree
Hide file tree
Showing 2 changed files with 189 additions and 74 deletions.
10 changes: 5 additions & 5 deletions src/libs/blueprint/conduit_blueprint_mpi_mesh_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -436,12 +436,12 @@ MatchQuery::execute()
C.set_logging_root("mpi_matchquery");
C.set_logging(true);
#endif
int query_tag = 1;
const int query_tag = 770;
for(size_t i = 0; i < allqueries.size(); i += ntuple_values)
{
int owner = allqueries[i];
int domain = allqueries[i + 1];
int query_domain = allqueries[i + 2];
const int owner = allqueries[i];
const int domain = allqueries[i + 1];
const int query_domain = allqueries[i + 2];

auto oppositeKey = std::make_pair(query_domain, domain);

Expand Down Expand Up @@ -668,7 +668,7 @@ compare_pointwise_impl(conduit::Node &mesh, const std::string &adjsetName,

// Iterate over each of the possible adjset relationships. Not all of these
// will have adjset groups.
const int tag = 1;
const int tag = 122;
for(int d0 = 0; d0 < maxDomains; d0++)
{
for(int d1 = d0 + 1; d1 < maxDomains; d1++)
Expand Down
253 changes: 184 additions & 69 deletions src/libs/relay/conduit_relay_mpi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,66 +257,6 @@ mpi_dtype_to_conduit_dtype_id(MPI_Datatype dt)
return res;
}

//---------------------------------------------------------------------------//
/**
 @brief Reports whether a tag value can be rejected outright.
 @param tag The MPI tag to examine.
 @return True when the tag is negative; false for zero and all positive values.
 @note Only the lower limit is examined here. No comparison against the
       implementation's tag upper bound is made by this function.
*/
bool invalid_tag(int tag)
{
    // Zero is the smallest acceptable tag, so anything below it is invalid.
    const bool below_zero = (0 > tag);
    return below_zero;
}

//---------------------------------------------------------------------------//
/**
 @brief MPI tags can be in the range [0,MPI_TAG_UB]. The upper bound is
        implementation-dependent. This function maps any integer onto that
        range so it is safe to use with MPI functions.
 @param tag The input tag.
 @param comm The MPI communicator whose MPI_TAG_UB attribute is queried.
 @return A tag value in [0, tag_ub]. Negative inputs become 0. Inputs larger
         than the upper bound are wrapped with a modulo so that distinct
         large tags tend to stay distinct instead of all collapsing onto one
         value. In-range inputs (including tag_ub itself) are returned as-is.
 @note If the attribute query fails, is not set, or yields a non-positive
       value, a conservative fallback limit of 4096 is used.
*/
int safe_tag(int tag, MPI_Comm comm)
{
    // Used whenever the communicator cannot report a usable MPI_TAG_UB.
    constexpr int backup_tag_limit = 4096;

    // For predefined attributes such as MPI_TAG_UB, MPI_Comm_get_attr stores
    // a POINTER to the integer value in attribute_val. Passing the address
    // of a plain int (as was done previously) makes MPI write a pointer over
    // the int, which yields garbage/negative bounds and can overrun the
    // variable on platforms where pointers are wider than int.
    int *tag_ub_ptr = nullptr;
    int flag = 0;
    const int mpi_error = MPI_Comm_get_attr(comm, MPI_TAG_UB, &tag_ub_ptr, &flag);

    int tag_ub = backup_tag_limit;
    if(mpi_error == MPI_SUCCESS && flag != 0 &&
       tag_ub_ptr != nullptr && *tag_ub_ptr > 0)
    {
        tag_ub = *tag_ub_ptr;
    }

    int newtag = std::max(0, tag);
    if(newtag > tag_ub)
    {
        // Some operations may emit a bunch of large tag numbers. If they fall
        // outside the allowable range, it is probably better to spread them
        // over the range than to just clamp them. The valid range is
        // inclusive of tag_ub, hence the modulus of tag_ub + 1. The addition
        // cannot overflow: this branch is unreachable when tag_ub is INT_MAX.
        newtag = newtag % (tag_ub + 1);
    }

    return newtag;
}

//---------------------------------------------------------------------------//
/**
@brief Some MPI installations install an error handler that causes functions
Expand Down Expand Up @@ -372,8 +312,8 @@ class HandleMPICommError
}

/// MPI calls this function to handle errors.
static void handler(MPI_Comm *comm,
int *errcode,
static void handler(MPI_Comm */*comm*/,
int */*errcode*/,
...)
{
#if 0
Expand All @@ -386,7 +326,7 @@ class HandleMPICommError
va_end(argp);
#endif

std::cout << "handler: comm=" << *comm << ", errcode=" << *errcode << std::endl;
//std::cout << "handler: comm=" << *comm << ", errcode=" << *errcode << std::endl;

#if 0
// We could try emitting a Conduit error.
Expand All @@ -411,6 +351,179 @@ class HandleMPICommError
MPI_Errhandler m_newHandler;
};

//---------------------------------------------------------------------------//
/**
 * @brief This class helps to determine MPI tag upper limits.
 *
 * The public entry point is upper_bound(), which probes the communicator by
 * sending small messages from the calling rank to itself with candidate tag
 * values. An attribute-based query() is also provided but is not used by
 * upper_bound().
 */
class TagLimits
{
public:
    /**
     * @brief Return the tag upper limit.
     *
     * @param comm The MPI communicator.
     *
     * @return The largest tag value that the probe found to be accepted.
     *
     * @note We probe to determine the value since query is not as reliable
     *       across MPI distributions.
     */
    static int upper_bound(MPI_Comm comm)
    {
        return probe(comm);
    }

private:
    /**
     * @brief Query MPI for the maximum tag value via the MPI_TAG_UB attribute.
     *
     * @param comm The MPI communicator.
     *
     * @return The tag upper bound, or -1 if the query fails, the attribute
     *         is not set, or the reported value is not positive.
     *
     * @note This is how we are supposed to be able to ask for the max tag
     *       value. For predefined attributes MPI_Comm_get_attr stores a
     *       pointer to the integer value in its attribute_val argument, so
     *       the result must be dereferenced rather than read as an int.
     */
    static int query(MPI_Comm comm)
    {
        int *tag_ub_ptr = nullptr;
        int flag = 0;
        const int mpi_error = MPI_Comm_get_attr(comm, MPI_TAG_UB, &tag_ub_ptr, &flag);
        if(mpi_error == MPI_SUCCESS && flag != 0 &&
           tag_ub_ptr != nullptr && *tag_ub_ptr > 0)
        {
            return *tag_ub_ptr;
        }
        return -1;
    }

    /**
     * @brief Probe MPI to determine the max tag value.
     *
     * @param comm The MPI communicator.
     *
     * @return The largest tag value found to be accepted by MPI_Isend.
     *
     * @note A HandleMPICommError object is created for the duration of the
     *       probe so that a failing MPI_Isend can be observed through its
     *       return code. It goes out of scope when leaving this function.
     */
    static int probe(MPI_Comm comm)
    {
        // Temporarily override MPI error handler with a more benign one.
        HandleMPICommError err(comm);
        return probeTagUpperBound(0, std::numeric_limits<int>::max(), comm);
    }

    /**
     * @brief Bisect a range of tag values to find the largest accepted tag.
     *
     * @param low  A tag value assumed to be accepted (0 <= low <= high).
     * @param high A tag value treated as the exclusive upper end of the search.
     * @param comm The MPI communicator.
     *
     * @return The largest tag in [low, high) for which MPI_Isend succeeded,
     *         or @a low if no larger tag was accepted.
     *
     * @note The rank sends a one-int message to itself using the midpoint
     *       tag. Success moves the lower end up; failure moves the upper
     *       end down. The search stops when fewer than two values separate
     *       the ends.
     */
    static int probeTagUpperBound(int low, int high, MPI_Comm comm)
    {
        // The rank does not change between probes, so look it up once.
        int rank = 0;
        if(MPI_Comm_rank(comm, &rank) != MPI_SUCCESS)
        {
            // Without a valid rank we cannot message ourselves; report the
            // value that is assumed to work.
            return low;
        }

        while((high - low) >= 2)
        {
            // Midpoint computed without forming low + high, which overflows
            // int once low exceeds INT_MAX / 2 (reachable when the MPI
            // implementation accepts very large tags).
            const int tag = low + (high - low) / 2;

            // Try sending with the current tag.
            int srcBuff = 0;
            MPI_Request requests[2];
            const int mpi_error = MPI_Isend(&srcBuff,
                                            1,
                                            MPI_INT,
                                            rank,
                                            tag,
                                            comm,
                                            &requests[0]);
            if(mpi_error == MPI_SUCCESS)
            {
                // It worked. Issue the matching recv so the message is
                // consumed before the next probe.
                // NOTE(review): the return codes of MPI_Irecv/MPI_Waitall are
                // not examined, matching the previous behavior.
                int destBuff = 0;
                MPI_Irecv(&destBuff,
                          1,
                          MPI_INT,
                          rank,
                          tag,
                          comm,
                          &requests[1]);

                MPI_Status statuses[2];
                MPI_Waitall(2, requests, statuses);

                low = tag;
            }
            else
            {
                high = tag;
            }
        }
        return low;
    }
};

//---------------------------------------------------------------------------//
/**
 @brief Tests a tag against the lower limit of the legal MPI tag range.
 @param tag The MPI tag to check.
 @return True if the tag is negative, false otherwise.
 @note The upper limit of the tag range is not considered by this function;
       it flags negative values only.
*/
bool invalid_tag(int tag)
{
    // Negative values are the only ones reported as invalid.
    return !(tag >= 0);
}

//---------------------------------------------------------------------------//
/**
 @brief Restricts a tag to a range that is safe to use with MPI functions.
 @param tag The input tag.
 @param comm The MPI communicator. It is used to determine the tag upper
             bound the first time the bound is needed.
 @return 0 when the input is negative, the cached upper bound when the input
         exceeds it, and the input itself otherwise.
 @note The upper bound is kept in a function-local static and is obtained
       from TagLimits::upper_bound while that static still holds its
       "not set" marker (-1); later calls reuse the stored value.
*/
int safe_tag(int tag, MPI_Comm comm)
{
    constexpr int not_yet_known = -1;
    static int cached_limit = not_yet_known;

    // Determine the limit lazily; once stored it is reused.
    if(cached_limit == not_yet_known)
        cached_limit = TagLimits::upper_bound(comm);

    // Raise negative values to zero, then cap at the limit.
    const int floored = (tag < 0) ? 0 : tag;
    return (floored > cached_limit) ? cached_limit : floored;
}

//---------------------------------------------------------------------------//
int
send_using_schema(const Node &node, int dest, int tag, MPI_Comm comm)
Expand Down Expand Up @@ -2010,14 +2123,15 @@ communicate_using_schema::execute_internal()

// Send the serialized node data.
index_t msg_data_size = operations[i].node[1]->total_bytes_compact();
const int newtag = safe_tag(operations[i].tag, comm);
if(logging)
{
log << " MPI_Isend("
<< const_cast<void*>(operations[i].node[1]->data_ptr()) << ", "
<< msg_data_size << ", "
<< "MPI_BYTE, "
<< operations[i].rank << ", "
<< safe_tag(operations[i].tag, comm) << ", "
<< newtag << ", "
<< "comm, &requests[" << i << "]);" << std::endl;
}

Expand All @@ -2032,7 +2146,7 @@ communicate_using_schema::execute_internal()
static_cast<int>(msg_data_size),
MPI_BYTE,
operations[i].rank,
safe_tag(operations[i].tag, comm),
newtag,
comm,
&requests[i]);
CONDUIT_CHECK_MPI_ERROR(mpi_error);
Expand All @@ -2051,14 +2165,15 @@ communicate_using_schema::execute_internal()
if(operations[i].op == OP_RECV)
{
// Probe the message for its buffer size.
const int newtag = safe_tag(operations[i].tag, comm);
if(logging)
{
log << " MPI_Probe("
<< operations[i].rank << ", "
<< safe_tag(operations[i].tag, comm) << ", "
<< newtag << ", "
<< "comm, &statuses[" << i << "]);" << std::endl;
}
mpi_error = MPI_Probe(operations[i].rank, safe_tag(operations[i].tag, comm), comm, &statuses[i]);
mpi_error = MPI_Probe(operations[i].rank, newtag, comm, &statuses[i]);
CONDUIT_CHECK_MPI_ERROR(mpi_error);

int buffer_size = 0;
Expand All @@ -2080,7 +2195,7 @@ communicate_using_schema::execute_internal()
<< buffer_size << ", "
<< "MPI_BYTE, "
<< operations[i].rank << ", "
<< safe_tag(operations[i].tag, comm) << ", "
<< newtag << ", "
<< "comm, &requests[" << i << "]);" << std::endl;
}

Expand All @@ -2089,7 +2204,7 @@ communicate_using_schema::execute_internal()
buffer_size,
MPI_BYTE,
operations[i].rank,
safe_tag(operations[i].tag, comm),
newtag,
comm,
&requests[i]);
CONDUIT_CHECK_MPI_ERROR(mpi_error);
Expand Down

0 comments on commit 038c261

Please sign in to comment.