Port thrust::merge[_by_key] to CUB (#1817)

* Refactor thrust/CUB merge
* Port thrust::merge[_by_key] to cub::DeviceMerge

Fixes #1763

Co-authored-by: Georgii Evtushenko <[email protected]>
Commit 8635429 (1 parent: 1b16af7)
Showing 15 changed files with 1,697 additions and 971 deletions.
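The device-level entry point introduced by this port is cub::DeviceMerge. Below is a minimal host-side sketch of calling it, assuming the usual CUB two-phase temp-storage idiom; the header name, parameter order, and default comparator should be double-checked against <cub/device/device_merge.cuh> as introduced by this commit. Both inputs must already be sorted with respect to the comparison, and the keys-plus-values variant (backing thrust::merge_by_key) follows the same pattern.

// Sketch only: calling the new device-level merge, not the library's documented example.
#include <cub/device/device_merge.cuh>

#include <cstddef>

cudaError_t merge_keys_example(const int* d_keys1, int num_keys1,
                               const int* d_keys2, int num_keys2,
                               int* d_keys_out)
{
  // First call: query how much temporary storage the merge needs.
  void* d_temp_storage           = nullptr;
  std::size_t temp_storage_bytes = 0;
  cub::DeviceMerge::MergeKeys(
    d_temp_storage, temp_storage_bytes, d_keys1, num_keys1, d_keys2, num_keys2, d_keys_out);

  // Second call: perform the merge with the allocated scratch space.
  cudaMalloc(&d_temp_storage, temp_storage_bytes);
  const cudaError_t error = cub::DeviceMerge::MergeKeys(
    d_temp_storage, temp_storage_bytes, d_keys1, num_keys1, d_keys2, num_keys2, d_keys_out);
  cudaFree(d_temp_storage);
  return error;
}

With this in place, thrust::merge and thrust::merge_by_key forward to the CUB dispatch layer instead of Thrust's previous CUDA backend implementation.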
@@ -0,0 +1,229 @@
// SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#pragma once

#include <cub/config.cuh>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
#  pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
#  pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
#  pragma system_header
#endif // no system header

#include <cub/agent/agent_merge_sort.cuh>
#include <cub/block/block_load.cuh>
#include <cub/block/block_merge_sort.cuh>
#include <cub/block/block_store.cuh>
#include <cub/util_namespace.cuh>
#include <cub/util_type.cuh>

#include <thrust/system/cuda/detail/core/util.h>

#include <cuda/std/__cccl/dialect.h>
CUB_NAMESPACE_BEGIN
namespace detail
{
namespace merge
{
template <int ThreadsPerBlock,
          int ItemsPerThread,
          BlockLoadAlgorithm LoadAlgorithm,
          CacheLoadModifier LoadCacheModifier,
          BlockStoreAlgorithm StoreAlgorithm>
struct agent_policy_t
{
  // do not change data member names, policy_wrapper_t depends on it
  static constexpr int BLOCK_THREADS                   = ThreadsPerBlock;
  static constexpr int ITEMS_PER_THREAD                = ItemsPerThread;
  static constexpr int ITEMS_PER_TILE                  = BLOCK_THREADS * ITEMS_PER_THREAD;
  static constexpr BlockLoadAlgorithm LOAD_ALGORITHM   = LoadAlgorithm;
  static constexpr CacheLoadModifier LOAD_MODIFIER     = LoadCacheModifier;
  static constexpr BlockStoreAlgorithm STORE_ALGORITHM = StoreAlgorithm;
};
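// Illustration (not part of this commit): a hypothetical instantiation showing how the
// tuning parameters map onto the policy members. The values below are made up; the
// tunings actually shipped are defined in the device-level dispatch layer.
using example_merge_policy_t =
  agent_policy_t<256, // ThreadsPerBlock (illustrative)
                 11, // ItemsPerThread (illustrative)
                 BLOCK_LOAD_WARP_TRANSPOSE,
                 LOAD_DEFAULT,
                 BLOCK_STORE_WARP_TRANSPOSE>;
static_assert(example_merge_policy_t::ITEMS_PER_TILE == 256 * 11,
              "a tile covers ThreadsPerBlock * ItemsPerThread elements");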

// TODO(bgruber): can we unify this one with AgentMerge in agent_merge_sort.cuh?
template <typename Policy,
          typename KeysIt1,
          typename ItemsIt1,
          typename KeysIt2,
          typename ItemsIt2,
          typename KeysOutputIt,
          typename ItemsOutputIt,
          typename Offset,
          typename CompareOp>
struct agent_t
{
  using policy = Policy;

  using key_type  = typename ::cuda::std::iterator_traits<KeysIt1>::value_type;
  using item_type = typename ::cuda::std::iterator_traits<ItemsIt1>::value_type;

  using keys_load_it1  = typename THRUST_NS_QUALIFIER::cuda_cub::core::LoadIterator<Policy, KeysIt1>::type;
  using keys_load_it2  = typename THRUST_NS_QUALIFIER::cuda_cub::core::LoadIterator<Policy, KeysIt2>::type;
  using items_load_it1 = typename THRUST_NS_QUALIFIER::cuda_cub::core::LoadIterator<Policy, ItemsIt1>::type;
  using items_load_it2 = typename THRUST_NS_QUALIFIER::cuda_cub::core::LoadIterator<Policy, ItemsIt2>::type;

  using block_load_keys1  = typename BlockLoadType<Policy, keys_load_it1>::type;
  using block_load_keys2  = typename BlockLoadType<Policy, keys_load_it2>::type;
  using block_load_items1 = typename BlockLoadType<Policy, items_load_it1>::type;
  using block_load_items2 = typename BlockLoadType<Policy, items_load_it2>::type;

  using block_store_keys  = typename BlockStoreType<Policy, KeysOutputIt, key_type>::type;
  using block_store_items = typename BlockStoreType<Policy, ItemsOutputIt, item_type>::type;

  union temp_storages
  {
    typename block_load_keys1::TempStorage load_keys1;
    typename block_load_keys2::TempStorage load_keys2;
    typename block_load_items1::TempStorage load_items1;
    typename block_load_items2::TempStorage load_items2;
    typename block_store_keys::TempStorage store_keys;
    typename block_store_items::TempStorage store_items;

    key_type keys_shared[Policy::ITEMS_PER_TILE + 1];
    item_type items_shared[Policy::ITEMS_PER_TILE + 1];
  };

  struct TempStorage : Uninitialized<temp_storages>
  {};

  static constexpr int items_per_thread  = Policy::ITEMS_PER_THREAD;
  static constexpr int threads_per_block = Policy::BLOCK_THREADS;
  static constexpr Offset items_per_tile = Policy::ITEMS_PER_TILE;

  // Per thread data
  temp_storages& storage;
  keys_load_it1 keys1_in;
  items_load_it1 items1_in;
  Offset keys1_count;
  keys_load_it2 keys2_in;
  items_load_it2 items2_in;
  Offset keys2_count;
  KeysOutputIt keys_out;
  ItemsOutputIt items_out;
  CompareOp compare_op;
  Offset* merge_partitions;

  template <bool IsFullTile>
  _CCCL_DEVICE _CCCL_FORCEINLINE void consume_tile(Offset tile_idx, Offset tile_base, int num_remaining)
  {
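    // merge_partitions is filled by a preceding partition step: entry i holds how many
    // elements of the first key range fall before the merge-path diagonal at the start
    // of tile i. Consecutive entries therefore bound this tile's slice of keys1; the
    // matching keys2 slice follows below by subtracting from the tile's diagonals.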
    const Offset partition_beg = merge_partitions[tile_idx + 0];
    const Offset partition_end = merge_partitions[tile_idx + 1];

    const Offset diag0 = items_per_tile * tile_idx;
    const Offset diag1 = (cub::min)(keys1_count + keys2_count, diag0 + items_per_tile);

    // compute bounding box for keys1 & keys2
    const Offset keys1_beg = partition_beg;
    const Offset keys1_end = partition_end;
    const Offset keys2_beg = diag0 - keys1_beg;
    const Offset keys2_end = diag1 - keys1_end;

    // number of keys per tile
    const int num_keys1 = static_cast<int>(keys1_end - keys1_beg);
    const int num_keys2 = static_cast<int>(keys2_end - keys2_beg);

    key_type keys_loc[items_per_thread];
    gmem_to_reg<threads_per_block, IsFullTile>(
      keys_loc, keys1_in + keys1_beg, keys2_in + keys2_beg, num_keys1, num_keys2);
    reg_to_shared<threads_per_block>(&storage.keys_shared[0], keys_loc);
    CTA_SYNC();

    // use binary search in shared memory to find the merge path for each thread.
    // we can use int type here, because the number of items in shared memory is limited
    const int diag0_loc = min<int>(num_keys1 + num_keys2, items_per_thread * threadIdx.x);

    const int keys1_beg_loc =
      MergePath(&storage.keys_shared[0], &storage.keys_shared[num_keys1], num_keys1, num_keys2, diag0_loc, compare_op);
    const int keys1_end_loc = num_keys1;
    const int keys2_beg_loc = diag0_loc - keys1_beg_loc;
    const int keys2_end_loc = num_keys2;

    const int num_keys1_loc = keys1_end_loc - keys1_beg_loc;
    const int num_keys2_loc = keys2_end_loc - keys2_beg_loc;

    // perform serial merge
    int indices[items_per_thread];
    cub::SerialMerge(
      &storage.keys_shared[0],
      keys1_beg_loc,
      keys2_beg_loc + num_keys1,
      num_keys1_loc,
      num_keys2_loc,
      keys_loc,
      indices,
      compare_op);
    CTA_SYNC();

    // write keys
    if (IsFullTile)
    {
      block_store_keys{storage.store_keys}.Store(keys_out + tile_base, keys_loc);
    }
    else
    {
      block_store_keys{storage.store_keys}.Store(keys_out + tile_base, keys_loc, num_remaining);
    }

    // if items are provided, merge them
    static constexpr bool have_items = !std::is_same<item_type, NullType>::value;
#ifdef _CCCL_CUDACC_BELOW_11_8
    if (have_items) // nvcc 11.1 cannot handle #pragma unroll inside if constexpr, but 11.8 can;
                    // nvcc versions in between may work
#else
    _CCCL_IF_CONSTEXPR (have_items)
#endif
    {
      item_type items_loc[items_per_thread];
      gmem_to_reg<threads_per_block, IsFullTile>(
        items_loc, items1_in + keys1_beg, items2_in + keys2_beg, num_keys1, num_keys2);
      CTA_SYNC(); // block_store_keys above uses shared memory, so make sure all threads are done before we write to it
      reg_to_shared<threads_per_block>(&storage.items_shared[0], items_loc);
      CTA_SYNC();

      // gather items from shared mem
#pragma unroll
      for (int i = 0; i < items_per_thread; ++i)
      {
        items_loc[i] = storage.items_shared[indices[i]];
      }
      CTA_SYNC();

      // write from reg to gmem
      if (IsFullTile)
      {
        block_store_items{storage.store_items}.Store(items_out + tile_base, items_loc);
      }
      else
      {
        block_store_items{storage.store_items}.Store(items_out + tile_base, items_loc, num_remaining);
      }
    }
  }

  _CCCL_DEVICE _CCCL_FORCEINLINE void operator()()
  {
    // XXX with 8.5 changing the type to Offset (or long long) results in an error!
    // TODO(bgruber): is the above still true?
    const int tile_idx = static_cast<int>(blockIdx.x);
    const Offset tile_base = tile_idx * items_per_tile;
    // TODO(bgruber): random mixing of int and Offset
    const int items_in_tile =
      static_cast<int>(cub::min(static_cast<Offset>(items_per_tile), keys1_count + keys2_count - tile_base));
    if (items_in_tile == items_per_tile)
    {
      consume_tile<true>(tile_idx, tile_base, items_per_tile); // full tile
    }
    else
    {
      consume_tile<false>(tile_idx, tile_base, items_in_tile); // partial tile
    }
  }
};
} // namespace merge
} // namespace detail
CUB_NAMESPACE_END
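
The per-tile splitting above is the classic merge-path scheme: a diagonal d of the conceptual merge grid is split by binary search into a prefix of keys1 and a prefix of keys2 whose lengths sum to d, with ties resolved towards keys1 so the merge stays stable. The agent applies it twice: once in global memory to carve out per-tile slices (merge_partitions), and once in shared memory to give every thread its own slice. Below is a small host-side sketch of the idea; the names are illustrative and are not CUB's internal MergePath/SerialMerge API.

#include <algorithm>
#include <cassert>
#include <functional>
#include <iterator>
#include <vector>

// Find the split of diagonal `diag`: returns how many elements are taken from `a`;
// the remaining diag - result elements come from `b`. Ties favor `a`, which keeps
// the merge stable.
template <typename T, typename Compare>
int merge_path(const std::vector<T>& a, const std::vector<T>& b, int diag, Compare comp)
{
  int lo = std::max(0, diag - static_cast<int>(b.size()));
  int hi = std::min(diag, static_cast<int>(a.size()));
  while (lo < hi)
  {
    const int mid = (lo + hi) / 2;
    if (comp(b[diag - 1 - mid], a[mid]))
    {
      hi = mid; // b's candidate is strictly smaller, so a[mid] is past the split
    }
    else
    {
      lo = mid + 1; // a[mid] belongs to the prefix (ties go to a)
    }
  }
  return lo;
}

int main()
{
  const std::vector<int> a = {1, 3, 5, 7, 9, 11};
  const std::vector<int> b = {2, 3, 4, 8, 10};
  const int tile = 4; // stand-in for ITEMS_PER_TILE
  std::less<int> comp;

  std::vector<int> result;
  const int total = static_cast<int>(a.size() + b.size());
  for (int diag0 = 0; diag0 < total; diag0 += tile) // one iteration per "tile"
  {
    const int diag1 = std::min(total, diag0 + tile);
    const int a0 = merge_path(a, b, diag0, comp);
    const int a1 = merge_path(a, b, diag1, comp);
    const int b0 = diag0 - a0;
    const int b1 = diag1 - a1;
    // Each tile merges its own slices independently; concatenating the tiles
    // reproduces the full merge.
    std::merge(a.begin() + a0, a.begin() + a1, b.begin() + b0, b.begin() + b1,
               std::back_inserter(result), comp);
  }

  std::vector<int> reference;
  std::merge(a.begin(), a.end(), b.begin(), b.end(), std::back_inserter(reference), comp);
  assert(result == reference);
  return 0;
}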