Skip to content

Commit

Permalink
[WIP][VL] Support celeborn sort based shuffle
Browse files Browse the repository at this point in the history
  • Loading branch information
kerwin-zk committed May 14, 2024
1 parent e2ff3c6 commit 5513f38
Showing 1 changed file with 20 additions and 28 deletions.
48 changes: 20 additions & 28 deletions cpp/core/shuffle/HashPartitioner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,24 @@

namespace gluten {

// Maps the raw partition id stored at pidArr[i] to a non-negative partition
// index.
//
// C++ `%` truncates toward zero, so a negative pidArr[i] gives a negative
// remainder; in that case numPartitions is added once to wrap it back.
//
// @param pidArr        array of raw partition ids; must be readable at index i.
// @param i             index of the row to read from pidArr.
// @param numPartitions divisor; must be non-zero (modulo by zero is undefined).
//                      NOTE(review): presumably always positive - confirm with callers.
// @return pidArr[i] % numPartitions, plus numPartitions when that remainder is negative.
int32_t computePid(const int32_t* pidArr, int64_t i, int32_t numPartitions) {
  auto pid = pidArr[i] % numPartitions;
#if defined(__x86_64__)
  // force to generate ASM
  // tmp is written by `lea` before pid and num_partitions are last read, so it
  // must be an early-clobber output ("=&r") rather than an input operand;
  // `test` modifies the flags, hence the "cc" clobber.
  int32_t tmp;
  __asm__(
      "lea (%[num_partitions],%[pid],1),%[tmp]\n"
      "test %[pid],%[pid]\n"
      "cmovs %[tmp],%[pid]\n"
      : [pid] "+r"(pid), [tmp] "=&r"(tmp)
      : [num_partitions] "r"(numPartitions)
      : "cc");
#else
  // Portable fallback: use the function parameter (the member numPartitions_
  // is not in scope in this free function).
  if (pid < 0) {
    pid += numPartitions;
  }
#endif
  return pid;
}

arrow::Status gluten::HashPartitioner::compute(
const int32_t* pidArr,
const int64_t numRows,
Expand All @@ -28,20 +46,7 @@ arrow::Status gluten::HashPartitioner::compute(
std::fill(std::begin(partition2RowCount), std::end(partition2RowCount), 0);

for (auto i = 0; i < numRows; ++i) {
auto pid = pidArr[i] % numPartitions_;
#if defined(__x86_64__)
// force to generate ASM
__asm__(
"lea (%[num_partitions],%[pid],1),%[tmp]\n"
"test %[pid],%[pid]\n"
"cmovs %[tmp],%[pid]\n"
: [pid] "+r"(pid)
: [num_partitions] "r"(numPartitions_), [tmp] "r"(0));
#else
if (pid < 0) {
pid += numPartitions_;
}
#endif
auto pid = computePid(pidArr, i, numPartitions_);
row2partition[i] = pid;
}

Expand All @@ -59,20 +64,7 @@ arrow::Status gluten::HashPartitioner::compute(
std::unordered_map<int32_t, std::vector<int64_t>>& rowVectorIndexMap) {
auto index = static_cast<int64_t>(vectorIndex) << 32;
for (auto i = 0; i < numRows; ++i) {
auto pid = pidArr[i] % numPartitions_;
#if defined(__x86_64__)
// force to generate ASM
__asm__(
"lea (%[num_partitions],%[pid],1),%[tmp]\n"
"test %[pid],%[pid]\n"
"cmovs %[tmp],%[pid]\n"
: [pid] "+r"(pid)
: [num_partitions] "r"(numPartitions_), [tmp] "r"(0));
#else
if (pid < 0) {
pid += numPartitions_;
}
#endif
auto pid = computePid(pidArr, i, numPartitions_);
int64_t combined = index | (i & 0xFFFFFFFFLL);
auto& vec = rowVectorIndexMap[pid];
vec.push_back(combined);
Expand Down

0 comments on commit 5513f38

Please sign in to comment.