Skip to content

Commit

Permalink
[VL] Row-based sort shuffle follow-up (minor) (#6628)
Browse files Browse the repository at this point in the history
  • Loading branch information
marin-ma authored Jul 30, 2024
1 parent 73fd854 commit c0d633d
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 69 deletions.
10 changes: 6 additions & 4 deletions cpp/velox/shuffle/RadixSort.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@

namespace gluten {

template <typename Element>
// Spark radix sort implementation. This implementation is for shuffle sort only as it removes unused
// params (desc, signed) in shuffle.
// https://github.com/apache/spark/blob/308669fc301916837bacb7c3ec1ecef93190c094/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java#L25
class RadixSort {
public:
/**
Expand All @@ -39,7 +41,7 @@ class RadixSort {
* @return The starting index of the sorted data within the given array. We return this instead
* of always copying the data back to position zero for efficiency.
*/
static int32_t sort(Element* array, size_t size, int64_t numRecords, int32_t startByteIndex, int32_t endByteIndex) {
static int32_t sort(uint64_t* array, size_t size, int64_t numRecords, int32_t startByteIndex, int32_t endByteIndex) {
assert(startByteIndex >= 0 && "startByteIndex should >= 0");
assert(endByteIndex <= 7 && "endByteIndex should <= 7");
assert(endByteIndex > startByteIndex);
Expand Down Expand Up @@ -75,7 +77,7 @@ class RadixSort {
* @param outIndex the starting index where sorted output data should be written.
*/
static void sortAtByte(
Element* array,
uint64_t* array,
int64_t numRecords,
std::vector<int64_t>& counts,
int32_t byteIdx,
Expand Down Expand Up @@ -103,7 +105,7 @@ class RadixSort {
* significant byte. If the byte does not need sorting the vector entry will be empty.
*/
static std::vector<std::vector<int64_t>>
getCounts(Element* array, int64_t numRecords, int32_t startByteIndex, int32_t endByteIndex) {
getCounts(uint64_t* array, int64_t numRecords, int32_t startByteIndex, int32_t endByteIndex) {
std::vector<std::vector<int64_t>> counts;
counts.resize(8);

Expand Down
7 changes: 2 additions & 5 deletions cpp/velox/shuffle/VeloxSortShuffleWriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -220,12 +220,9 @@ arrow::Status VeloxSortShuffleWriter::evictAllPartitions() {
{
ScopedTimer timer(&sortTime_);
if (options_.useRadixSort) {
begin = RadixSort<uint64_t>::sort(
arrayPtr_, arraySize_, numRecords, kPartitionIdStartByteIndex, kPartitionIdEndByteIndex);
begin = RadixSort::sort(arrayPtr_, arraySize_, numRecords, kPartitionIdStartByteIndex, kPartitionIdEndByteIndex);
} else {
auto ptr = arrayPtr_;
qsort(ptr, numRecords, sizeof(uint64_t), compare);
(void)ptr;
std::sort(arrayPtr_, arrayPtr_ + numRecords);
}
}

Expand Down
Loading

0 comments on commit c0d633d

Please sign in to comment.