Skip to content

Commit

Permalink
Add Catch2 tests for cub::DeviceSegmentedRadixSort (#1214)
Browse files Browse the repository at this point in the history
* Exit gracefully when num_segments == 0.

* Add catch2 tests for cub::DeviceSegmentedRadixSort.
  • Loading branch information
alliepiper authored Dec 15, 2023
1 parent 6ba3291 commit a2efaba
Show file tree
Hide file tree
Showing 4 changed files with 952 additions and 8 deletions.
3 changes: 1 addition & 2 deletions cub/cub/device/dispatch/dispatch_radix_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -2837,7 +2837,7 @@ struct DispatchSegmentedRadixSort : SelectedPolicy
typedef typename DispatchSegmentedRadixSort::MaxPolicy MaxPolicyT;

// Return if empty problem, or if no bits to sort and double-buffering is used
if (num_items == 0 || (begin_bit == end_bit && is_overwrite_okay))
if (num_items == 0 || num_segments == 0 || (begin_bit == end_bit && is_overwrite_okay))
{
if (d_temp_storage == nullptr)
{
Expand Down Expand Up @@ -2991,4 +2991,3 @@ CUB_NAMESPACE_END
#if defined(__clang__)
# pragma clang diagnostic pop
#endif

223 changes: 217 additions & 6 deletions cub/test/catch2_radix_sort_helper.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,21 @@
#include <thrust/host_vector.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/memory.h>
#include <thrust/scan.h>
#include <thrust/sequence.h>

#include <cub/detail/cpp_compatibility.cuh>
#include <cub/device/device_radix_sort.cuh>
#include <cub/device/device_segmented_radix_sort.cuh>
#include <cub/util_macro.cuh>
#include <cub/util_math.cuh>
#include <cub/util_type.cuh>

#include <array>
#include <climits>
#include <cstdint>

#include "c2h/generators.cuh"
#include "c2h/utility.cuh"
#include "catch2_test_helper.h"

Expand Down Expand Up @@ -104,6 +110,61 @@ public:
}
};

struct double_buffer_segmented_sort_t
{
private:
bool m_is_descending;
int* m_selector;

public:
explicit double_buffer_segmented_sort_t(bool is_descending)
: m_is_descending(is_descending),
m_selector(nullptr)
{
}

void initialize()
{
REQUIRE(cudaSuccess == cudaMallocHost(&m_selector, sizeof(int)));
}

void finalize()
{
REQUIRE(cudaSuccess == cudaFreeHost(m_selector));
m_selector = nullptr;
}

int selector() const { return *m_selector;}

template <class KeyT, class... As>
CUB_RUNTIME_FUNCTION cudaError_t
operator()(std::uint8_t* d_temp_storage, std::size_t& temp_storage_bytes, cub::DoubleBuffer<KeyT> keys, As... as)
{
const cudaError_t status =
m_is_descending ? cub::DeviceSegmentedRadixSort::SortKeysDescending(d_temp_storage, temp_storage_bytes, keys, as...)
: cub::DeviceSegmentedRadixSort::SortKeys(d_temp_storage, temp_storage_bytes, keys, as...);

*m_selector = keys.selector;
return status;
}

template <class KeyT, class ValueT, class... As>
CUB_RUNTIME_FUNCTION cudaError_t operator()(
std::uint8_t* d_temp_storage,
std::size_t& temp_storage_bytes,
cub::DoubleBuffer<KeyT> keys,
cub::DoubleBuffer<ValueT> values,
As... as)
{
const cudaError_t status =
m_is_descending ? cub::DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, keys, values, as...)
: cub::DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, keys, values, as...);

*m_selector = keys.selector;
return status;
}
};

// Helpers to assist with specifying default args to DeviceRadixSort API:
template <typename T>
constexpr int begin_bit()
Expand Down Expand Up @@ -184,10 +245,13 @@ struct indirect_binary_comparator_t
}
};

template <class KeyT>
template <class KeyT, class SegBeginIterT, class SegEndIterT>
thrust::host_vector<std::size_t>
get_permutation(const thrust::host_vector<KeyT> &h_keys,
bool is_descending,
std::size_t num_segments,
SegBeginIterT h_seg_begin_it,
SegEndIterT h_seg_end_it,
int begin_bit,
int end_bit)
{
Expand All @@ -203,9 +267,12 @@ get_permutation(const thrust::host_vector<KeyT> &h_keys,
auto bit_ordered_striped_keys =
reinterpret_cast<const bit_ordered_t*>(thrust::raw_pointer_cast(h_striped_keys.data()));

std::stable_sort(h_permutation.begin(),
h_permutation.end(),
indirect_binary_comparator_t<bit_ordered_t>{bit_ordered_striped_keys, is_descending});
indirect_binary_comparator_t<bit_ordered_t> comp{bit_ordered_striped_keys, is_descending};

for (std::size_t segment = 0; segment < num_segments; ++segment)
{
std::stable_sort(h_permutation.begin() + h_seg_begin_it[segment], h_permutation.begin() + h_seg_end_it[segment], comp);
}

return h_permutation;
}
Expand All @@ -218,8 +285,9 @@ radix_sort_reference(const thrust::device_vector<KeyT> &d_keys,
int end_bit = static_cast<int>(sizeof(KeyT) * CHAR_BIT))
{
thrust::host_vector<KeyT> h_keys(d_keys);
std::array<std::size_t, 2> segments{0, d_keys.size()};
thrust::host_vector<std::size_t> h_permutation =
get_permutation(h_keys, is_descending, begin_bit, end_bit);
get_permutation(h_keys, is_descending, 1, segments.cbegin(), segments.cbegin() + 1, begin_bit, end_bit);
thrust::host_vector<KeyT> result(d_keys.size());
thrust::gather(h_permutation.cbegin(), h_permutation.cend(), h_keys.cbegin(), result.begin());

Expand All @@ -238,9 +306,58 @@ radix_sort_reference(const thrust::device_vector<KeyT> &d_keys,
result.first.resize(d_keys.size());
result.second.resize(d_keys.size());

std::array<std::size_t, 2> segments{0, d_keys.size()};

thrust::host_vector<KeyT> h_keys(d_keys);
thrust::host_vector<std::size_t> h_permutation =
get_permutation(h_keys, is_descending, 1, segments.cbegin(), segments.cbegin() + 1, begin_bit, end_bit);

thrust::host_vector<ValueT> h_values(d_values);
thrust::gather(h_permutation.cbegin(),
h_permutation.cend(),
thrust::make_zip_iterator(h_keys.cbegin(), h_values.cbegin()),
thrust::make_zip_iterator(result.first.begin(), result.second.begin()));

return result;
}

template <class KeyT, class SegBeginIterT, class SegEndIterT>
thrust::host_vector<KeyT> segmented_radix_sort_reference(
const thrust::device_vector<KeyT>& d_keys,
bool is_descending,
std::size_t num_segments,
SegBeginIterT h_seg_begin_it,
SegEndIterT h_seg_end_it,
int begin_bit = 0,
int end_bit = static_cast<int>(sizeof(KeyT) * CHAR_BIT))
{
thrust::host_vector<KeyT> h_keys(d_keys);
thrust::host_vector<std::size_t> h_permutation =
get_permutation(h_keys, is_descending, num_segments, h_seg_begin_it, h_seg_end_it, begin_bit, end_bit);
thrust::host_vector<KeyT> result(d_keys.size());
thrust::gather(h_permutation.cbegin(), h_permutation.cend(), h_keys.cbegin(), result.begin());

return result;
}

template <class KeyT, class ValueT, class SegBeginIterT, class SegEndIterT>
std::pair<thrust::host_vector<KeyT>, thrust::host_vector<ValueT>> segmented_radix_sort_reference(
const thrust::device_vector<KeyT>& d_keys,
const thrust::device_vector<ValueT>& d_values,
bool is_descending,
std::size_t num_segments,
SegBeginIterT h_seg_begin_it,
SegEndIterT h_seg_end_it,
int begin_bit = 0,
int end_bit = static_cast<int>(sizeof(KeyT) * CHAR_BIT))
{
std::pair<thrust::host_vector<KeyT>, thrust::host_vector<ValueT>> result;
result.first.resize(d_keys.size());
result.second.resize(d_keys.size());

thrust::host_vector<KeyT> h_keys(d_keys);
thrust::host_vector<std::size_t> h_permutation =
get_permutation(h_keys, is_descending, begin_bit, end_bit);
get_permutation(h_keys, is_descending, num_segments, h_seg_begin_it, h_seg_end_it, begin_bit, end_bit);

thrust::host_vector<ValueT> h_values(d_values);
thrust::gather(h_permutation.cbegin(),
Expand All @@ -250,3 +367,97 @@ radix_sort_reference(const thrust::device_vector<KeyT> &d_keys,

return result;
}

template <class KeyT, class OffsetT>
thrust::host_vector<KeyT> segmented_radix_sort_reference(
const thrust::device_vector<KeyT>& d_keys,
bool is_descending,
const thrust::device_vector<OffsetT>& d_offsets,
int begin_bit = 0,
int end_bit = static_cast<int>(sizeof(KeyT) * CHAR_BIT))
{
const thrust::host_vector<OffsetT> h_offsets(d_offsets);
const std::size_t num_segments = h_offsets.size() - 1;
auto h_seg_begin_it = h_offsets.cbegin();
auto h_seg_end_it = h_offsets.cbegin() + 1;
return segmented_radix_sort_reference(
d_keys, is_descending, num_segments, h_seg_begin_it, h_seg_end_it, begin_bit, end_bit);
}

template <class KeyT, class OffsetT>
thrust::host_vector<KeyT> segmented_radix_sort_reference(
const thrust::device_vector<KeyT>& d_keys,
bool is_descending,
const thrust::device_vector<OffsetT>& d_offsets_begin,
const thrust::device_vector<OffsetT>& d_offsets_end,
int begin_bit = 0,
int end_bit = static_cast<int>(sizeof(KeyT) * CHAR_BIT))
{
const thrust::host_vector<OffsetT> h_offsets_begin(d_offsets_begin);
const thrust::host_vector<OffsetT> h_offsets_end(d_offsets_end);
const std::size_t num_segments = h_offsets_begin.size();
auto h_seg_begin_it = h_offsets_begin.cbegin();
auto h_seg_end_it = h_offsets_end.cbegin();
return segmented_radix_sort_reference(
d_keys, is_descending, num_segments, h_seg_begin_it, h_seg_end_it, begin_bit, end_bit);
}

template <class KeyT, class ValueT, class OffsetT>
std::pair<thrust::host_vector<KeyT>, thrust::host_vector<ValueT>> segmented_radix_sort_reference(
const thrust::device_vector<KeyT>& d_keys,
const thrust::device_vector<ValueT>& d_values,
bool is_descending,
const thrust::device_vector<OffsetT>& d_offsets,
int begin_bit = 0,
int end_bit = static_cast<int>(sizeof(KeyT) * CHAR_BIT))
{
const thrust::host_vector<OffsetT> h_offsets(d_offsets);
const std::size_t num_segments = h_offsets.size() - 1;
auto h_seg_begin_it = h_offsets.cbegin();
auto h_seg_end_it = h_offsets.cbegin() + 1;
return segmented_radix_sort_reference(
d_keys, d_values, is_descending, num_segments, h_seg_begin_it, h_seg_end_it, begin_bit, end_bit);
}

template <class KeyT, class ValueT, class OffsetT>
std::pair<thrust::host_vector<KeyT>, thrust::host_vector<ValueT>> segmented_radix_sort_reference(
const thrust::device_vector<KeyT>& d_keys,
const thrust::device_vector<ValueT>& d_values,
bool is_descending,
const thrust::device_vector<OffsetT>& d_offsets_begin,
const thrust::device_vector<OffsetT>& d_offsets_end,
int begin_bit = 0,
int end_bit = static_cast<int>(sizeof(KeyT) * CHAR_BIT))
{
const thrust::host_vector<OffsetT> h_offsets_begin(d_offsets_begin);
const thrust::host_vector<OffsetT> h_offsets_end(d_offsets_end);
const std::size_t num_segments = h_offsets_begin.size();
auto h_seg_begin_it = h_offsets_begin.cbegin();
auto h_seg_end_it = h_offsets_end.cbegin();
return segmented_radix_sort_reference(
d_keys, d_values, is_descending, num_segments, h_seg_begin_it, h_seg_end_it, begin_bit, end_bit);
}

template <typename OffsetT>
struct offset_scan_op_t
{
OffsetT num_items;

__host__ __device__
OffsetT operator()(OffsetT a, OffsetT b) const
{
const OffsetT sum = a + b;
return CUB_MIN(sum, num_items);
}
};

template <class OffsetT>
void generate_segment_offsets(c2h::seed_t seed, thrust::device_vector<OffsetT>& offsets, std::size_t num_items)
{
const std::size_t num_segments = offsets.size() - 1;
const OffsetT expected_segment_length = static_cast<OffsetT>(cub::DivideAndRoundUp(num_items, num_segments));
const OffsetT max_segment_length = (expected_segment_length * 2) + 1;
c2h::gen(seed, offsets, OffsetT{0}, max_segment_length);
thrust::exclusive_scan(
offsets.begin(), offsets.end(), offsets.begin(), OffsetT{0}, offset_scan_op_t<OffsetT>{static_cast<OffsetT>(num_items)});
}
Loading

0 comments on commit a2efaba

Please sign in to comment.