Skip to content

Commit

Permalink
rocr: Allow 0/NULL/invalid signal handles for wait operations to be n…
Browse files Browse the repository at this point in the history
…o-op

Remove hard assertions for signal validation on hsa_amd_signal_wait_* operations, instead ignore 0/NULL/invalid signals in the dependency condition evaluation to align with HSA specs for barrier-AND and barrier-OR packets.
satisfying_values of 0/NULL/invalid signals are set to 0 for hsa_amd_signal_wait_all.

Signed-off-by: zichguan-amd <[email protected]>
  • Loading branch information
zichguan-amd committed Feb 20, 2025
1 parent 107b48f commit 1e4b46a
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 18 deletions.
48 changes: 34 additions & 14 deletions runtime/hsa-runtime/core/runtime/hsa_ext_amd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -580,20 +580,32 @@ uint32_t hsa_amd_signal_wait_all(uint32_t signal_count, hsa_signal_t* hsa_signal
assert(false && "hsa_amd_signal_wait_all called while not initialized.");
return 0;
}
// Do not check for signal invalidation. Invalidation may occur during async
// signal handler loop and is not an error.
for (int i = 0; i < signal_count; ++i)
assert(hsa_signals[i].handle != 0 && core::SharedSignal::Convert(hsa_signals[i])->IsValid() &&
"Invalid signal.");

// Treat NULL and invalid signals as already satisfied their condition and skip them
std::vector<hsa_signal_t> valid_signals;
std::vector<uint32_t> valid_signal_ids;
for (uint32_t i = 0; i < signal_count; i++){
if (hsa_signals[i].handle != 0 && core::SharedSignal::Convert(hsa_signals[i])->IsValid()){
valid_signals.emplace_back(hsa_signals[i]);
valid_signal_ids.emplace_back(i);
}
}

uint32_t valid_signal_count = valid_signals.size();

std::vector<hsa_signal_value_t> satisfying_values_vec;
satisfying_values_vec.resize(signal_count);
satisfying_values_vec.resize(valid_signal_count);
uint32_t first_satysifying_signal_idx =
core::Signal::WaitMultiple(signal_count, hsa_signals, conds, values, timeout_hint, wait_hint,
core::Signal::WaitMultiple(valid_signal_count, valid_signals.data(), conds, values, timeout_hint, wait_hint,
satisfying_values_vec, true);

if (satisfying_values) {
std::copy(satisfying_values_vec.begin(), satisfying_values_vec.end(), satisfying_values);
// Set 0 as satisfying value for NULL and invalid signals
std::vector<hsa_signal_value_t> satisfying_values_vec_result(signal_count, 0);
for (uint32_t i = 0; i < valid_signal_count; i++){
satisfying_values_vec_result[valid_signal_ids[i]] = satisfying_values_vec[i];
}
std::copy(satisfying_values_vec_result.begin(), satisfying_values_vec_result.end(), satisfying_values);
}

return first_satysifying_signal_idx;
Expand All @@ -609,16 +621,24 @@ uint32_t hsa_amd_signal_wait_any(uint32_t signal_count, hsa_signal_t* hsa_signal
assert(false && "hsa_amd_signal_wait_any called while not initialized.");
return uint32_t(0);
}
// Do not check for signal invalidation. Invalidation may occur during async
// signal handler loop and is not an error.
for (uint i = 0; i < signal_count; i++)
assert(hsa_signals[i].handle != 0 && core::SharedSignal::Convert(hsa_signals[i])->IsValid() &&
"Invalid signal.");

// Ignore NULL and invalid signals
std::vector<hsa_signal_t> valid_signals;
std::vector<uint32_t> valid_signal_ids;
for (uint32_t i = 0; i < signal_count; i++){
if (hsa_signals[i].handle != 0 && core::SharedSignal::Convert(hsa_signals[i])->IsValid()){
valid_signals.emplace_back(hsa_signals[i]);
valid_signal_ids.emplace_back(i);
}
}

std::vector<hsa_signal_value_t> satisfying_value_vec(1);
uint32_t satisfying_signal_idx =
core::Signal::WaitMultiple(signal_count, hsa_signals, conds, values, timeout_hint, wait_hint,
core::Signal::WaitMultiple(valid_signals.size(), valid_signals.data(), conds, values, timeout_hint, wait_hint,
satisfying_value_vec, false);

// Map back the index
satisfying_signal_idx = valid_signal_ids[satisfying_signal_idx];

if (satisfying_value) *satisfying_value = satisfying_value_vec.at(0);

Expand Down
10 changes: 6 additions & 4 deletions runtime/hsa-runtime/inc/hsa_ext_amd.h
Original file line number Diff line number Diff line change
Expand Up @@ -1209,8 +1209,9 @@ hsa_status_t HSA_API
* @details Allows waiting for all of several signal and condition pairs to be
* satisfied. The function returns 0 if all signals met their conditions and -1
* on a timeout. The value of each signal's satisfying value is returned in
* satisfying_value unless satisfying_value is nullptr. This function provides
* only relaxed memory semantics.
* satisfying_value unless satisfying_value is nullptr. NULL and invalid signals
* are considered to have value 0 and their conditions already satisfied. This
* function provides only relaxed memory semantics.
*/
uint32_t HSA_API hsa_amd_signal_wait_all(uint32_t signal_count, hsa_signal_t* signals,
hsa_signal_condition_t* conds, hsa_signal_value_t* values,
Expand All @@ -1223,8 +1224,9 @@ uint32_t HSA_API hsa_amd_signal_wait_all(uint32_t signal_count, hsa_signal_t* si
* @details Allows waiting for any of several signal and conditions pairs to be
* satisfied. The function returns the index into the list of signals of the
* first satisfying signal-condition pair. The value of the satisfying signal's
* value is returned in satisfying_value unless satisfying_value is NULL. This
* function provides only relaxed memory semantics.
* value is returned in satisfying_value unless satisfying_value is NULL. NULL
* and invalid signals are ignored. This function provides only relaxed memory
* semantics.
*/
uint32_t HSA_API
hsa_amd_signal_wait_any(uint32_t signal_count, hsa_signal_t* signals,
Expand Down

0 comments on commit 1e4b46a

Please sign in to comment.