-
Notifications
You must be signed in to change notification settings - Fork 165
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Introduce CUB ForEach algorithms (#1302)
- Loading branch information
1 parent
0c9d032
commit b7d4228
Showing
25 changed files
with
2,618 additions
and
158 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
/****************************************************************************** | ||
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
* | ||
* Redistribution and use in source and binary forms, with or without | ||
* modification, are permitted provided that the following conditions are met: | ||
* * Redistributions of source code must retain the above copyright | ||
* notice, this list of conditions and the following disclaimer. | ||
* * Redistributions in binary form must reproduce the above copyright | ||
* notice, this list of conditions and the following disclaimer in the | ||
* documentation and/or other materials provided with the distribution. | ||
* * Neither the name of the NVIDIA CORPORATION nor the | ||
* names of its contributors may be used to endorse or promote products | ||
* derived from this software without specific prior written permission. | ||
* | ||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND | ||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED | ||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | ||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY | ||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES | ||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; | ||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND | ||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | ||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS | ||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | ||
* | ||
******************************************************************************/ | ||
|
||
#include <cub/device/device_for.cuh> | ||
|
||
#include <nvbench_helper.cuh> | ||
|
||
template <class T> | ||
struct op_t | ||
{ | ||
int* d_count{}; | ||
|
||
__device__ void operator()(T val) const | ||
{ | ||
if (val == T{}) | ||
{ | ||
atomicAdd(d_count, 1); | ||
} | ||
} | ||
}; | ||
|
||
template <class T, class OffsetT> | ||
void for_each(nvbench::state& state, nvbench::type_list<T, OffsetT>) | ||
{ | ||
using input_it_t = const T*; | ||
using output_it_t = int*; | ||
using offset_t = OffsetT; | ||
|
||
const auto elements = static_cast<offset_t>(state.get_int64("Elements{io}")); | ||
|
||
thrust::device_vector<T> in(elements, T{42}); | ||
|
||
input_it_t d_in = thrust::raw_pointer_cast(in.data()); | ||
// `d_out` exists for visibility | ||
// All inputs are equal to `42`, while the operator is searching for `0`. | ||
// If the operator finds `0` in the input sequence, it's an issue leading to a segfault. | ||
output_it_t d_out = nullptr; | ||
|
||
state.add_element_count(elements); | ||
state.add_global_memory_reads<T>(elements); | ||
|
||
op_t<T> op{d_out}; | ||
|
||
std::size_t temp_size{}; | ||
cub::DeviceFor::ForEachN(nullptr, temp_size, d_in, elements, op); | ||
|
||
thrust::device_vector<nvbench::uint8_t> temp(temp_size); | ||
auto* temp_storage = thrust::raw_pointer_cast(temp.data()); | ||
|
||
state.exec(nvbench::exec_tag::no_batch, [&](nvbench::launch& launch) { | ||
cub::DeviceFor::ForEachN(temp_storage, temp_size, d_in, elements, op, launch.get_stream()); | ||
}); | ||
} | ||
|
||
NVBENCH_BENCH_TYPES(for_each, NVBENCH_TYPE_AXES(fundamental_types, offset_types)) | ||
.set_name("base") | ||
.set_type_axes_names({"T{ct}", "OffsetT{ct}"}) | ||
.add_int64_power_of_two_axis("Elements{io}", nvbench::range(16, 28, 4)); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
/****************************************************************************** | ||
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
* | ||
* Redistribution and use in source and binary forms, with or without | ||
* modification, are permitted provided that the following conditions are met: | ||
* * Redistributions of source code must retain the above copyright | ||
* notice, this list of conditions and the following disclaimer. | ||
* * Redistributions in binary form must reproduce the above copyright | ||
* notice, this list of conditions and the following disclaimer in the | ||
* documentation and/or other materials provided with the distribution. | ||
* * Neither the name of the NVIDIA CORPORATION nor the | ||
* names of its contributors may be used to endorse or promote products | ||
* derived from this software without specific prior written permission. | ||
* | ||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND | ||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED | ||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | ||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY | ||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES | ||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; | ||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND | ||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | ||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS | ||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | ||
* | ||
******************************************************************************/ | ||
|
||
#include <cub/device/device_for.cuh> | ||
|
||
#include <nvbench_helper.cuh> | ||
|
||
template <class T> | ||
struct op_t | ||
{ | ||
int* d_count{}; | ||
|
||
__device__ void operator()(T val) const | ||
{ | ||
if (val == T{}) | ||
{ | ||
atomicAdd(d_count, 1); | ||
} | ||
} | ||
}; | ||
|
||
template <class T, class OffsetT> | ||
void for_each(nvbench::state& state, nvbench::type_list<T, OffsetT>) | ||
{ | ||
using input_it_t = const T*; | ||
using output_it_t = int*; | ||
using offset_t = OffsetT; | ||
|
||
const auto elements = static_cast<offset_t>(state.get_int64("Elements{io}")); | ||
|
||
thrust::device_vector<T> in(elements, T{42}); | ||
|
||
input_it_t d_in = thrust::raw_pointer_cast(in.data()); | ||
output_it_t d_out = nullptr; | ||
|
||
state.add_element_count(elements); | ||
state.add_global_memory_reads<T>(elements); | ||
|
||
op_t<T> op{d_out}; | ||
|
||
std::size_t temp_size{}; | ||
cub::DeviceFor::ForEachCopyN(nullptr, temp_size, d_in, elements, op); | ||
|
||
thrust::device_vector<nvbench::uint8_t> temp(temp_size); | ||
auto* temp_storage = thrust::raw_pointer_cast(temp.data()); | ||
|
||
state.exec(nvbench::exec_tag::no_batch, [&](nvbench::launch& launch) { | ||
cub::DeviceFor::ForEachCopyN(temp_storage, temp_size, d_in, elements, op, launch.get_stream()); | ||
}); | ||
} | ||
|
||
NVBENCH_BENCH_TYPES(for_each, NVBENCH_TYPE_AXES(fundamental_types, offset_types)) | ||
.set_name("base") | ||
.set_type_axes_names({"T{ct}", "OffsetT{ct}"}) | ||
.add_int64_power_of_two_axis("Elements{io}", nvbench::range(16, 28, 4)); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
/****************************************************************************** | ||
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
* | ||
* Redistribution and use in source and binary forms, with or without | ||
* modification, are permitted provided that the following conditions are met: | ||
* * Redistributions of source code must retain the above copyright | ||
* notice, this list of conditions and the following disclaimer. | ||
* * Redistributions in binary form must reproduce the above copyright | ||
* notice, this list of conditions and the following disclaimer in the | ||
* documentation and/or other materials provided with the distribution. | ||
* * Neither the name of the NVIDIA CORPORATION nor the | ||
* names of its contributors may be used to endorse or promote products | ||
* derived from this software without specific prior written permission. | ||
* | ||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | ||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | ||
* ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY | ||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES | ||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; | ||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND | ||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | ||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS | ||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | ||
* | ||
******************************************************************************/ | ||
|
||
#pragma once | ||
|
||
#include <cub/config.cuh> | ||
|
||
#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC) | ||
# pragma GCC system_header | ||
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG) | ||
# pragma clang system_header | ||
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC) | ||
# pragma system_header | ||
#endif // no system header | ||
|
||
#include <cub/util_ptx.cuh> | ||
#include <cub/util_type.cuh> | ||
|
||
CUB_NAMESPACE_BEGIN | ||
|
||
namespace detail | ||
{ | ||
namespace for_each | ||
{ | ||
|
||
template <int BlockThreads, int ItemsPerThread> | ||
struct policy_t | ||
{ | ||
static constexpr int block_threads = BlockThreads; | ||
static constexpr int items_per_thread = ItemsPerThread; | ||
}; | ||
|
||
template <class PolicyT, class OffsetT, class OpT> | ||
struct agent_block_striped_t | ||
{ | ||
static constexpr int items_per_thread = PolicyT::items_per_thread; | ||
|
||
OffsetT tile_base; | ||
OpT op; | ||
|
||
template <bool IsFullTile> | ||
_CCCL_DEVICE _CCCL_FORCEINLINE void consume_tile(int items_in_tile, int block_threads) | ||
{ | ||
#pragma unroll | ||
for (int item = 0; item < items_per_thread; item++) | ||
{ | ||
const auto idx = static_cast<OffsetT>(block_threads * item + threadIdx.x); | ||
|
||
if (IsFullTile || idx < items_in_tile) | ||
{ | ||
(void)op(tile_base + idx); | ||
} | ||
} | ||
} | ||
}; | ||
|
||
} // namespace for_each | ||
} // namespace detail | ||
|
||
CUB_NAMESPACE_END |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.