Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add key/value support to radix sort in breeze #11733

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 69 additions & 27 deletions velox/experimental/breeze/breeze/algorithms/sort.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,20 @@ struct SortBlockType<unsigned> {
}
};

template <typename PlatformT, int ITEMS_PER_THREAD, int RADIX_BITS, typename T>
// Storage used when scattering one block of a key/value sort: one slot per
// block item for the key and one slot per block item for its value.
template <class KeyT, class ValueT, int BLOCK_ITEMS>
struct KeyValueScatterType {
  KeyT keys[BLOCK_ITEMS];      // scattered keys, indexed by rank
  ValueT values[BLOCK_ITEMS];  // scattered values, indexed by rank
};

// Keys-only variant: when ValueT is utils::NullType no value storage is
// declared, so only the key array is present.
template <class KeyT, int BLOCK_ITEMS>
struct KeyValueScatterType<KeyT, utils::NullType, BLOCK_ITEMS> {
  KeyT keys[BLOCK_ITEMS];  // scattered keys, indexed by rank
};

template <typename PlatformT, int ITEMS_PER_THREAD, int RADIX_BITS,
typename KeyT, typename ValueT>
struct DeviceRadixSort {
enum {
BLOCK_THREADS = PlatformT::BLOCK_THREADS,
Expand All @@ -164,25 +177,26 @@ struct DeviceRadixSort {
unsigned global_offsets[NUM_BINS];
int block_idx;
};
struct {
T items[BLOCK_ITEMS];
} scatter;
KeyValueScatterType<KeyT, ValueT, BLOCK_ITEMS> scatter;
};
};

template <typename BlockT, typename InputSlice, typename OffsetSlice,
typename OutputSlice, typename BlockIdxSlice, typename BlockSlice,
typename ScratchSlice>
static ATTR void Sort(PlatformT p, const InputSlice in,
template <typename BlockT, typename KeyInputSlice, typename ValueInputSlice,
typename OffsetSlice, typename KeyOutputSlice,
typename ValueOutputSlice, typename BlockIdxSlice,
typename BlockSlice, typename ScratchSlice>
static ATTR void Sort(PlatformT p, const KeyInputSlice in_keys,
const ValueInputSlice in_values,
const OffsetSlice in_offsets, int start_bit,
int num_pass_bits, OutputSlice out,
int num_pass_bits, KeyOutputSlice out_keys,
ValueOutputSlice out_values,
BlockIdxSlice next_block_idx, BlockSlice blocks,
ScratchSlice scratch, int num_items) {
using namespace functions;
using namespace utils;

enum {
END_BIT = sizeof(T) * /*BITS_PER_BYTE=*/8,
END_BIT = sizeof(KeyT) * /*BITS_PER_BYTE=*/8,
WARP_THREADS = PlatformT::WARP_THREADS,
NUM_WARPS = BLOCK_THREADS / WARP_THREADS,
WARP_ITEMS = WARP_THREADS * ITEMS_PER_THREAD,
Expand Down Expand Up @@ -211,19 +225,19 @@ struct DeviceRadixSort {
// load items into a warp-striped arrangement after initializing all values
// to all bits set, as that allows us to always use the fast-path version of
// the radix rank function
T items[ITEMS_PER_THREAD];
KeyT keys[ITEMS_PER_THREAD];
#pragma unroll
for (int i = 0; i < ITEMS_PER_THREAD; ++i) {
items[i] = NumericLimits<T>::max();
keys[i] = NumericLimits<KeyT>::max();
}
const InputSlice it = in.subslice(block.offset);
const KeyInputSlice it = in_keys.subslice(block.offset);
BlockLoad<BLOCK_THREADS, ITEMS_PER_THREAD>(
p, it, make_slice<THREAD, WARP_STRIPED>(items), block.num_items);
p, it, make_slice<THREAD, WARP_STRIPED>(keys), block.num_items);

// convert items to bit ordered representation
#pragma unroll
for (int i = 0; i < ITEMS_PER_THREAD; ++i) {
items[i] = RadixSortTraits<T>::to_bit_ordered(items[i]);
keys[i] = RadixSortTraits<KeyT>::to_bit_ordered(keys[i]);
}

// determine stable rank for each item
Expand All @@ -232,18 +246,31 @@ struct DeviceRadixSort {
int exclusive_scan[BINS_PER_THREAD];
BlockRadixRankT::Rank(
p,
make_bitfield_extractor(make_slice<THREAD, WARP_STRIPED>(items),
make_bitfield_extractor(make_slice<THREAD, WARP_STRIPED>(keys),
start_bit, num_pass_bits),
make_slice<THREAD, WARP_STRIPED>(ranks), make_slice(histogram),
blocks.subslice(block_idx * NUM_BINS), make_slice(exclusive_scan),
make_slice<SHARED>(&scratch->rank));
p.syncthreads();

// scatter items by storing them in shared memory using ranks
// scatter keys by storing them in scratch using ranks
BlockStoreAt<BLOCK_THREADS, ITEMS_PER_THREAD>(
p, make_slice<THREAD, WARP_STRIPED>(items),
p, make_slice<THREAD, WARP_STRIPED>(keys),
make_slice<THREAD, WARP_STRIPED>(ranks),
make_slice<SHARED>(scratch->scatter.items));
make_slice<SHARED>(scratch->scatter.keys));

// load and scatter optional values
ValueT values[ITEMS_PER_THREAD];
if constexpr (IsDifferent<ValueT, NullType>::VALUE) {
const ValueInputSlice it = in_values.subslice(block.offset);
BlockLoad<BLOCK_THREADS, ITEMS_PER_THREAD>(
p, it, make_slice<THREAD, WARP_STRIPED>(values), block.num_items);
// scatter values by storing them in scratch using ranks
BlockStoreAt<BLOCK_THREADS, ITEMS_PER_THREAD>(
p, make_slice<THREAD, WARP_STRIPED>(values),
make_slice<THREAD, WARP_STRIPED>(ranks),
make_slice<SHARED>(scratch->scatter.values));
}
p.syncthreads();

// first block loads initial global offsets from input and other blocks
Expand Down Expand Up @@ -334,9 +361,16 @@ struct DeviceRadixSort {
global_offsets[i] -= exclusive_scan[i];
}

// gather scattered items from scratch
// gather scattered keys from scratch
BlockLoad<BLOCK_THREADS, ITEMS_PER_THREAD>(
p, make_slice<SHARED>(scratch->scatter.items), make_slice(items));
p, make_slice<SHARED>(scratch->scatter.keys), make_slice(keys));

// gather optional scattered values from scratch
if constexpr (IsDifferent<ValueT, NullType>::VALUE) {
// gather scattered values from scratch
BlockLoad<BLOCK_THREADS, ITEMS_PER_THREAD>(
p, make_slice<SHARED>(scratch->scatter.values), make_slice(values));
}
p.syncthreads();

// store global offsets in scratch
Expand All @@ -349,7 +383,7 @@ struct DeviceRadixSort {
unsigned out_offsets[ITEMS_PER_THREAD];
BlockLoadFrom<BLOCK_THREADS, ITEMS_PER_THREAD>(
p, make_slice<SHARED>(scratch->global_offsets),
make_bitfield_extractor(make_slice(items), start_bit, num_pass_bits),
make_bitfield_extractor(make_slice(keys), start_bit, num_pass_bits),
make_slice(out_offsets));

// add item index (same as rank after scatter/gather) to output offsets
Expand All @@ -358,15 +392,23 @@ struct DeviceRadixSort {
out_offsets[i] += p.thread_idx() + i * BLOCK_THREADS;
}

// convert items back to original representation
// convert keys back to original representation
#pragma unroll
for (int i = 0; i < ITEMS_PER_THREAD; ++i) {
items[i] = RadixSortTraits<T>::from_bit_ordered(items[i]);
keys[i] = RadixSortTraits<KeyT>::from_bit_ordered(keys[i]);
}

// store gathered items in global memory using output offsets
BlockStoreAt<BLOCK_THREADS, ITEMS_PER_THREAD>(
p, make_slice(items), make_slice(out_offsets), out, block.num_items);
// store gathered keys in global memory using output offsets
BlockStoreAt<BLOCK_THREADS, ITEMS_PER_THREAD>(p, make_slice(keys),
make_slice(out_offsets),
out_keys, block.num_items);

// store gathered values in global memory using output offsets
if constexpr (IsDifferent<ValueT, NullType>::VALUE) {
BlockStoreAt<BLOCK_THREADS, ITEMS_PER_THREAD>(
p, make_slice(values), make_slice(out_offsets), out_values,
block.num_items);
}
}
};

Expand Down
104 changes: 76 additions & 28 deletions velox/experimental/breeze/breeze/functions/sort.h
Original file line number Diff line number Diff line change
Expand Up @@ -269,12 +269,13 @@ struct BlockRadixRank {
}
};

template <typename PlatformT, int ITEMS_PER_THREAD, int RADIX_BITS, typename T>
template <typename PlatformT, int ITEMS_PER_THREAD, int RADIX_BITS,
typename KeyT, typename ValueT>
struct BlockRadixSort {
enum {
BLOCK_THREADS = PlatformT::BLOCK_THREADS,
WARP_THREADS = PlatformT::WARP_THREADS,
END_BIT = sizeof(T) * /*BITS_PER_BYTE=*/8,
END_BIT = sizeof(KeyT) * /*BITS_PER_BYTE=*/8,
NUM_PASSES = utils::DivideAndRoundUp<END_BIT, RADIX_BITS>::VALUE,
NUM_BINS = 1 << RADIX_BITS,
BINS_PER_THREAD = utils::DivideAndRoundUp<NUM_BINS, BLOCK_THREADS>::VALUE,
Expand All @@ -284,21 +285,27 @@ struct BlockRadixSort {
union {
typename BlockRadixRank<PlatformT, ITEMS_PER_THREAD, RADIX_BITS>::Scratch
rank;
T scatter[BLOCK_THREADS * ITEMS_PER_THREAD];
struct {
union {
KeyT keys[BLOCK_THREADS * ITEMS_PER_THREAD];
ValueT values[BLOCK_THREADS * ITEMS_PER_THREAD];
};
} scatter;
};
};

template <typename ItemSlice, typename ScratchSlice>
static ATTR void Sort(PlatformT p, ItemSlice items, ScratchSlice scratch) {
template <typename KeySlice, typename ValueSlice, typename ScratchSlice>
static ATTR void Sort(PlatformT p, KeySlice keys, ValueSlice values,
ScratchSlice scratch) {
using namespace utils;

static_assert(IsSame<typename ScratchSlice::data_type, Scratch>::VALUE,
"incorrect scratch type");

// convert items to bit ordered representation if needed
// convert keys to bit ordered representation if needed
#pragma unroll
for (int i = 0; i < ITEMS_PER_THREAD; ++i) {
items[i] = RadixSortTraits<T>::to_bit_ordered(items[i]);
keys[i] = RadixSortTraits<KeyT>::to_bit_ordered(keys[i]);
}

// start from LSB and loop until no bits are left
Expand All @@ -307,36 +314,49 @@ struct BlockRadixSort {
int start_bit = i * RADIX_BITS;
int num_pass_bits = p.min(RADIX_BITS, END_BIT - start_bit);

// determine stable rank for each item
// determine stable rank for each key
int ranks[ITEMS_PER_THREAD];
BlockRadixRank<PlatformT, ITEMS_PER_THREAD, RADIX_BITS>::Rank(
p, make_bitfield_extractor(items, start_bit, num_pass_bits),
make_slice<THREAD, ItemSlice::ARRANGEMENT>(ranks),
p, make_bitfield_extractor(keys, start_bit, num_pass_bits),
make_slice<THREAD, KeySlice::ARRANGEMENT>(ranks),
make_slice<SHARED>(&scratch->rank));
p.syncthreads();

// scatter items by storing them in shared memory using ranks
// scatter keys by storing them in scratch using ranks
BlockStoreAt<BLOCK_THREADS, ITEMS_PER_THREAD>(
p, items, make_slice<THREAD, ItemSlice::ARRANGEMENT>(ranks),
make_slice<SHARED>(scratch->scatter));
p, keys, make_slice<THREAD, KeySlice::ARRANGEMENT>(ranks),
make_slice<SHARED>(scratch->scatter.keys));
p.syncthreads();

// load scattered items
// load scattered keys
BlockLoad<BLOCK_THREADS, ITEMS_PER_THREAD>(
p, make_slice<SHARED>(scratch->scatter), items);
p, make_slice<SHARED>(scratch->scatter.keys), keys);
p.syncthreads();

if constexpr (IsDifferent<ValueT, NullType>::VALUE) {
// scatter values by storing them in scratch using ranks
BlockStoreAt<BLOCK_THREADS, ITEMS_PER_THREAD>(
p, values, make_slice<THREAD, KeySlice::ARRANGEMENT>(ranks),
make_slice<SHARED>(scratch->scatter.values));
p.syncthreads();

// load scattered values
BlockLoad<BLOCK_THREADS, ITEMS_PER_THREAD>(
p, make_slice<SHARED>(scratch->scatter.values), values);
p.syncthreads();
}
}

// convert items back to original representation
// convert keys back to original representation
#pragma unroll
for (int i = 0; i < ITEMS_PER_THREAD; ++i) {
items[i] = RadixSortTraits<T>::from_bit_ordered(items[i]);
keys[i] = RadixSortTraits<KeyT>::from_bit_ordered(keys[i]);
}
}

template <typename ItemSlice, typename ScratchSlice>
static ATTR void Sort(PlatformT p, ItemSlice items, ScratchSlice scratch,
int num_items) {
template <typename KeySlice, typename ValueSlice, typename ScratchSlice>
static ATTR void Sort(PlatformT p, KeySlice keys, ValueSlice values,
ScratchSlice scratch, int num_items) {
using namespace utils;

enum {
Expand All @@ -345,31 +365,59 @@ struct BlockRadixSort {

static_assert((BLOCK_THREADS % WARP_THREADS) == 0,
"BLOCK_THREADS must be a multiple of WARP_THREADS");
static_assert(ItemSlice::ARRANGEMENT == WARP_STRIPED,
static_assert(KeySlice::ARRANGEMENT == WARP_STRIPED,
"input must have warp-striped arrangement");

int thread_offset = p.warp_idx() * WARP_ITEMS + p.lane_idx();

// pad items with values that have all bits set
T padded_items[ITEMS_PER_THREAD];
// pad keys with values that have all bits set
KeyT padded_keys[ITEMS_PER_THREAD];
#pragma unroll
for (int i = 0; i < ITEMS_PER_THREAD; ++i) {
padded_items[i] = NumericLimits<T>::max();
padded_keys[i] = NumericLimits<KeyT>::max();
}
#pragma unroll
for (int i = 0; i < ITEMS_PER_THREAD; ++i) {
if (thread_offset + (i * WARP_THREADS) < num_items) {
padded_items[i] = items[i];
padded_keys[i] = keys[i];
}
}

Sort(p, make_slice<THREAD, WARP_STRIPED>(padded_items), scratch);
if constexpr (IsDifferent<ValueT, NullType>::VALUE) {
static_assert(ValueSlice::ARRANGEMENT == WARP_STRIPED,
"input must have warp-striped arrangement");

ValueT padded_values[ITEMS_PER_THREAD];
#pragma unroll
for (int i = 0; i < ITEMS_PER_THREAD; ++i) {
padded_values[i] = static_cast<ValueT>(0);
}
#pragma unroll
for (int i = 0; i < ITEMS_PER_THREAD; ++i) {
if (thread_offset + (i * WARP_THREADS) < num_items) {
padded_values[i] = values[i];
}
}

Sort(p, make_slice<THREAD, WARP_STRIPED>(padded_keys),
make_slice<THREAD, WARP_STRIPED>(padded_values), scratch);

// copy valid values back
#pragma unroll
for (int i = 0; i < ITEMS_PER_THREAD; ++i) {
if (thread_offset + (i * WARP_THREADS) < num_items) {
values[i] = padded_values[i];
}
}
} else {
Sort(p, make_slice<THREAD, WARP_STRIPED>(padded_keys), values, scratch);
}

// copy valid items back
// copy valid keys back
#pragma unroll
for (int i = 0; i < ITEMS_PER_THREAD; ++i) {
if (thread_offset + (i * WARP_THREADS) < num_items) {
items[i] = padded_items[i];
keys[i] = padded_keys[i];
}
}
}
Expand Down
6 changes: 5 additions & 1 deletion velox/experimental/breeze/breeze/utils/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,11 @@ enum DataArrangement {
WARP_STRIPED,
};

class EmptySlice {};
// Empty tag type used to mark the absence of a real type (for example, a
// sort that has keys but no associated values).
class NullType {};

// Placeholder slice that refers to no data. It exposes `data_type` like other
// slice types so generic code can query `Slice::data_type` uniformly.
class EmptySlice {
 public:
  // Must be public: members of a `class` are private by default, which would
  // make `EmptySlice::data_type` inaccessible to any code outside the class.
  using data_type = NullType;
};

// Returns a value-initialized EmptySlice, for call sites that need to pass a
// slice argument but have no data to provide.
ATTR EmptySlice constexpr make_empty_slice() {
  return EmptySlice();
}

Expand Down
Loading
Loading