diff --git a/cpp/src/wholememory/communicator.cpp b/cpp/src/wholememory/communicator.cpp index d08fe0804..dabb9ba1b 100644 --- a/cpp/src/wholememory/communicator.cpp +++ b/cpp/src/wholememory/communicator.cpp @@ -24,6 +24,7 @@ #include #include +#include #include #include @@ -557,7 +558,7 @@ void exchange_rank_info(wholememory_comm_t wm_comm) wm_comm->clique_info.clique_rank = -1; wm_comm->clique_info.clique_rank_num = 0; - std::set clique_ids{}; + std::set clique_uuids{}; for (int r = 0; r < wm_comm->world_size; r++) { WHOLEMEMORY_CHECK(r == p_rank_info.get()[r].rank); @@ -583,16 +584,21 @@ void exchange_rank_info(wholememory_comm_t wm_comm) if (wm_comm->clique_info.clique_rank_num == 0) { wm_comm->clique_info.clique_first_rank = r; } wm_comm->clique_info.clique_rank_num++; } - clique_ids.insert(p_rank_info.get()[r].fabric_info.cliqueId); + clique_uuids.insert( + std::string(reinterpret_cast(p_rank_info.get()[r].fabric_info.clusterUuid), + NVML_GPU_FABRIC_UUID_LEN)); #endif } #if CUDA_VERSION >= 12030 - wm_comm->clique_info.clique_num = clique_ids.size(); - int id = 0; - for (auto clique_id : clique_ids) { - if (clique_id == ri.fabric_info.cliqueId) { wm_comm->clique_info.clique_id = id; } + wm_comm->clique_info.clique_num = clique_uuids.size(); + + std::string uuid = std::string(reinterpret_cast(ri.fabric_info.clusterUuid), + NVML_GPU_FABRIC_UUID_LEN); + int id = 0; + for (auto clique_uuid : clique_uuids) { + if (clique_uuid == uuid) { wm_comm->clique_info.clique_id = id; } id++; }