Skip to content

Commit

Permalink
feat: Add key/value support to radix sort algorithm in breeze
Browse files Browse the repository at this point in the history
  • Loading branch information
David Reveman committed Dec 4, 2024
1 parent ef1bffa commit 50bd444
Show file tree
Hide file tree
Showing 13 changed files with 718 additions and 207 deletions.
96 changes: 69 additions & 27 deletions velox/experimental/breeze/breeze/algorithms/sort.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,20 @@ struct SortBlockType<unsigned> {
}
};

template <typename PlatformT, int ITEMS_PER_THREAD, int RADIX_BITS, typename T>
// Scatter storage for a block of key/value pairs, held as two parallel
// fixed-size arrays of BLOCK_ITEMS elements each.
//
// DeviceRadixSort uses this as the type of its `scatter` scratch member:
// keys and values are stored into `keys` and `values` using the same ranks,
// so the value for a key sits at the same index in `values` as the key does
// in `keys`.
//
// Template parameters:
//   KeyT        - element type of the `keys` array
//   ValueT      - element type of the `values` array
//   BLOCK_ITEMS - number of elements in each array
template <typename KeyT, typename ValueT, int BLOCK_ITEMS>
struct KeyValueScatterType {
// keys, indexed by rank within the block
KeyT keys[BLOCK_ITEMS];
// values, indexed by the rank of their corresponding key
ValueT values[BLOCK_ITEMS];
};

// Partial specialization for the keys-only case, where ValueT is
// utils::NullType. No `values` array is declared, so only the keys occupy
// storage. Any code that touches `values` must therefore be compiled out for
// this case; DeviceRadixSort::Sort does this with
// `if constexpr (IsDifferent<ValueT, NullType>::VALUE)`.
template <typename KeyT, int BLOCK_ITEMS>
struct KeyValueScatterType<KeyT, utils::NullType, BLOCK_ITEMS> {
// keys, indexed by rank within the block
KeyT keys[BLOCK_ITEMS];
};

template <typename PlatformT, int ITEMS_PER_THREAD, int RADIX_BITS,
typename KeyT, typename ValueT>
struct DeviceRadixSort {
enum {
BLOCK_THREADS = PlatformT::BLOCK_THREADS,
Expand All @@ -164,25 +177,26 @@ struct DeviceRadixSort {
unsigned global_offsets[NUM_BINS];
int block_idx;
};
struct {
T items[BLOCK_ITEMS];
} scatter;
KeyValueScatterType<KeyT, ValueT, BLOCK_ITEMS> scatter;
};
};

template <typename BlockT, typename InputSlice, typename OffsetSlice,
typename OutputSlice, typename BlockIdxSlice, typename BlockSlice,
typename ScratchSlice>
static ATTR void Sort(PlatformT p, const InputSlice in,
template <typename BlockT, typename KeyInputSlice, typename ValueInputSlice,
typename OffsetSlice, typename KeyOutputSlice,
typename ValueOutputSlice, typename BlockIdxSlice,
typename BlockSlice, typename ScratchSlice>
static ATTR void Sort(PlatformT p, const KeyInputSlice in_keys,
const ValueInputSlice in_values,
const OffsetSlice in_offsets, int start_bit,
int num_pass_bits, OutputSlice out,
int num_pass_bits, KeyOutputSlice out_keys,
ValueOutputSlice out_values,
BlockIdxSlice next_block_idx, BlockSlice blocks,
ScratchSlice scratch, int num_items) {
using namespace functions;
using namespace utils;

enum {
END_BIT = sizeof(T) * /*BITS_PER_BYTE=*/8,
END_BIT = sizeof(KeyT) * /*BITS_PER_BYTE=*/8,
WARP_THREADS = PlatformT::WARP_THREADS,
NUM_WARPS = BLOCK_THREADS / WARP_THREADS,
WARP_ITEMS = WARP_THREADS * ITEMS_PER_THREAD,
Expand Down Expand Up @@ -211,19 +225,19 @@ struct DeviceRadixSort {
// load items into a warp-striped arrangement after initializing all values
// to all bits set, as that allows us to always use the fast-path version of
// the radix rank function
T items[ITEMS_PER_THREAD];
KeyT keys[ITEMS_PER_THREAD];
#pragma unroll
for (int i = 0; i < ITEMS_PER_THREAD; ++i) {
items[i] = NumericLimits<T>::max();
keys[i] = NumericLimits<KeyT>::max();
}
const InputSlice it = in.subslice(block.offset);
const KeyInputSlice it = in_keys.subslice(block.offset);
BlockLoad<BLOCK_THREADS, ITEMS_PER_THREAD>(
p, it, make_slice<THREAD, WARP_STRIPED>(items), block.num_items);
p, it, make_slice<THREAD, WARP_STRIPED>(keys), block.num_items);

// convert items to bit ordered representation
#pragma unroll
for (int i = 0; i < ITEMS_PER_THREAD; ++i) {
items[i] = RadixSortTraits<T>::to_bit_ordered(items[i]);
keys[i] = RadixSortTraits<KeyT>::to_bit_ordered(keys[i]);
}

// determine stable rank for each item
Expand All @@ -232,18 +246,31 @@ struct DeviceRadixSort {
int exclusive_scan[BINS_PER_THREAD];
BlockRadixRankT::Rank(
p,
make_bitfield_extractor(make_slice<THREAD, WARP_STRIPED>(items),
make_bitfield_extractor(make_slice<THREAD, WARP_STRIPED>(keys),
start_bit, num_pass_bits),
make_slice<THREAD, WARP_STRIPED>(ranks), make_slice(histogram),
blocks.subslice(block_idx * NUM_BINS), make_slice(exclusive_scan),
make_slice<SHARED>(&scratch->rank));
p.syncthreads();

// scatter items by storing them in shared memory using ranks
// scatter keys by storing them in scratch using ranks
BlockStoreAt<BLOCK_THREADS, ITEMS_PER_THREAD>(
p, make_slice<THREAD, WARP_STRIPED>(items),
p, make_slice<THREAD, WARP_STRIPED>(keys),
make_slice<THREAD, WARP_STRIPED>(ranks),
make_slice<SHARED>(scratch->scatter.items));
make_slice<SHARED>(scratch->scatter.keys));

// load and scatter optional values
ValueT values[ITEMS_PER_THREAD];
if constexpr (IsDifferent<ValueT, NullType>::VALUE) {
const ValueInputSlice it = in_values.subslice(block.offset);
BlockLoad<BLOCK_THREADS, ITEMS_PER_THREAD>(
p, it, make_slice<THREAD, WARP_STRIPED>(values), block.num_items);
// scatter values by storing them in scratch using ranks
BlockStoreAt<BLOCK_THREADS, ITEMS_PER_THREAD>(
p, make_slice<THREAD, WARP_STRIPED>(values),
make_slice<THREAD, WARP_STRIPED>(ranks),
make_slice<SHARED>(scratch->scatter.values));
}
p.syncthreads();

// first block loads initial global offsets from input and other blocks
Expand Down Expand Up @@ -334,9 +361,16 @@ struct DeviceRadixSort {
global_offsets[i] -= exclusive_scan[i];
}

// gather scattered items from scratch
// gather scattered keys from scratch
BlockLoad<BLOCK_THREADS, ITEMS_PER_THREAD>(
p, make_slice<SHARED>(scratch->scatter.items), make_slice(items));
p, make_slice<SHARED>(scratch->scatter.keys), make_slice(keys));

// gather optional scattered values from scratch
if constexpr (IsDifferent<ValueT, NullType>::VALUE) {
// gather scattered values from scratch
BlockLoad<BLOCK_THREADS, ITEMS_PER_THREAD>(
p, make_slice<SHARED>(scratch->scatter.values), make_slice(values));
}
p.syncthreads();

// store global offsets in scratch
Expand All @@ -349,7 +383,7 @@ struct DeviceRadixSort {
unsigned out_offsets[ITEMS_PER_THREAD];
BlockLoadFrom<BLOCK_THREADS, ITEMS_PER_THREAD>(
p, make_slice<SHARED>(scratch->global_offsets),
make_bitfield_extractor(make_slice(items), start_bit, num_pass_bits),
make_bitfield_extractor(make_slice(keys), start_bit, num_pass_bits),
make_slice(out_offsets));

// add item index (same as rank after scatter/gather) to output offsets
Expand All @@ -358,15 +392,23 @@ struct DeviceRadixSort {
out_offsets[i] += p.thread_idx() + i * BLOCK_THREADS;
}

// convert items back to original representation
// convert keys back to original representation
#pragma unroll
for (int i = 0; i < ITEMS_PER_THREAD; ++i) {
items[i] = RadixSortTraits<T>::from_bit_ordered(items[i]);
keys[i] = RadixSortTraits<KeyT>::from_bit_ordered(keys[i]);
}

// store gathered items in global memory using output offsets
BlockStoreAt<BLOCK_THREADS, ITEMS_PER_THREAD>(
p, make_slice(items), make_slice(out_offsets), out, block.num_items);
// store gathered keys in global memory using output offsets
BlockStoreAt<BLOCK_THREADS, ITEMS_PER_THREAD>(p, make_slice(keys),
make_slice(out_offsets),
out_keys, block.num_items);

// store gathered values in global memory using output offsets
if constexpr (IsDifferent<ValueT, NullType>::VALUE) {
BlockStoreAt<BLOCK_THREADS, ITEMS_PER_THREAD>(
p, make_slice(values), make_slice(out_offsets), out_values,
block.num_items);
}
}
};

Expand Down
Loading

0 comments on commit 50bd444

Please sign in to comment.