Skip to content

Commit

Permalink
Add device inclusive scan with init_value (#1845)
Browse files Browse the repository at this point in the history
* Add block scan raking and warp scans implementation support for initial value

* Add array based API tests for block inclusive scan

* Resolve CI issues on block_scan_api test

* Remove more value based APIs

* Remove even more value based APIs, complement docs with details

* Fix device inclusive scan and adapt unit test

* Add device inclusive scan docs and API test

* Sync with upstream main

* Default IsInclusive in AgentScan

* Resolve CI issues

* Disambiguate device scan unit tests with and without init value

* Revert changes that caused unreachable code

* Revert "Revert changes that caused unreachable code"

This reverts commit fc209a1.

* Fix unreachable code MSVC error

* Change InclusiveScan dvice overload name with init value

* Finalize inclusive_scan with init value API

* Resolve reviews

* Fix unit test

Co-authored-by: Georgii Evtushenko <[email protected]>

* Minor docs issue

Co-authored-by: Georgii Evtushenko <[email protected]>

* Rename device InclusiveScan with init value to InclusiveScanInit

* Revert original arguments order

* Amend tests after reverting arguments order

* Address documentation reviews

---------

Co-authored-by: Georgii Evtushenko <[email protected]>
  • Loading branch information
gonidelis and gevtushenko authored Jul 17, 2024
1 parent cb7e845 commit be411ff
Show file tree
Hide file tree
Showing 5 changed files with 250 additions and 9 deletions.
34 changes: 30 additions & 4 deletions cub/cub/agent/agent_scan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,8 @@ template <typename AgentScanPolicyT,
typename ScanOpT,
typename InitValueT,
typename OffsetT,
typename AccumT>
typename AccumT,
bool ForceInclusive = false>
struct AgentScan
{
//---------------------------------------------------------------------
Expand All @@ -169,7 +170,10 @@ struct AgentScan
enum
{
// Inclusive scan if no init_value type is provided
IS_INCLUSIVE = std::is_same<InitValueT, NullType>::value,
HAS_INIT = !std::is_same<InitValueT, NullType>::value,
IS_INCLUSIVE = ForceInclusive || !HAS_INIT, // We are relying on either initial value not beeing `NullType`
// or the ForceInclusive tag to be true for inclusive scan
// to get picked up.
BLOCK_THREADS = AgentScanPolicyT::BLOCK_THREADS,
ITEMS_PER_THREAD = AgentScanPolicyT::ITEMS_PER_THREAD,
TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD,
Expand Down Expand Up @@ -251,17 +255,39 @@ struct AgentScan
block_aggregate = scan_op(init_value, block_aggregate);
}

_CCCL_DEVICE _CCCL_FORCEINLINE void ScanTileInclusive(
AccumT (&items)[ITEMS_PER_THREAD],
AccumT init_value,
ScanOpT scan_op,
AccumT& block_aggregate,
Int2Type<true> /*has_init*/)
{
BlockScanT(temp_storage.scan_storage.scan).InclusiveScan(items, items, init_value, scan_op, block_aggregate);
block_aggregate = scan_op(init_value, block_aggregate);
}

_CCCL_DEVICE _CCCL_FORCEINLINE void ScanTileInclusive(
AccumT (&items)[ITEMS_PER_THREAD],
InitValueT /*init_value*/,
ScanOpT scan_op,
AccumT& block_aggregate,
Int2Type<false> /*has_init*/)

{
BlockScanT(temp_storage.scan_storage.scan).InclusiveScan(items, items, scan_op, block_aggregate);
}

/**
* Inclusive scan specialization (first tile)
*/
_CCCL_DEVICE _CCCL_FORCEINLINE void ScanTile(
AccumT (&items)[ITEMS_PER_THREAD],
InitValueT /*init_value*/,
InitValueT init_value,
ScanOpT scan_op,
AccumT& block_aggregate,
Int2Type<true> /*is_inclusive*/)
{
BlockScanT(temp_storage.scan_storage.scan).InclusiveScan(items, items, scan_op, block_aggregate);
ScanTileInclusive(items, init_value, scan_op, block_aggregate, Int2Type<HAS_INIT>());
}

/**
Expand Down
102 changes: 102 additions & 0 deletions cub/cub/device/device_scan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1164,6 +1164,108 @@ struct DeviceScan
d_temp_storage, temp_storage_bytes, d_in, d_out, scan_op, NullType(), num_items, stream);
}

//! @rst
//! Computes a device-wide inclusive prefix scan using the specified binary ``scan_op`` functor.
//! The result of applying the ``scan_op`` binary operator to ``init_value`` value and ``*d_in``
//! is assigned to ``*d_out``.
//!
//! - Supports non-commutative scan operators.
//! - Results are not deterministic for pseudo-associative operators (e.g.,
//! addition of floating-point types). Results for pseudo-associative
//! operators may vary from run to run. Additional details can be found in
//! the @lookback description.
//! - When ``d_in`` and ``d_out`` are equal, the scan is performed in-place. The
//! range ``[d_in, d_in + num_items)`` and ``[d_out, d_out + num_items)``
//! shall not overlap in any other way.
//! - @devicestorage
//!
//! Snippet
//! +++++++++++++++++++++++++++++++++++++++++++++
//!
//! The code snippet below illustrates the inclusive max-scan of an ``int`` device vector.
//!
//! .. literalinclude:: ../../../cub/test/catch2_test_device_scan_api.cu
//! :language: c++
//! :dedent:
//! :start-after: example-begin device-inclusive-scan
//! :end-before: example-end device-inclusive-scan
//!
//! @endrst
//!
//! @tparam InputIteratorT
//! **[inferred]** Random-access input iterator type for reading scan inputs @iterator
//!
//! @tparam OutputIteratorT
//! **[inferred]** Random-access output iterator type for writing scan outputs @iterator
//!
//! @tparam ScanOpT
//! **[inferred]** Binary scan functor type having member `T operator()(const T &a, const T &b)`
//!
//! @tparam InitValueT
//! **[inferred]** Type of the `init_value`
//!
//! @param[in] d_temp_storage
//! Device-accessible allocation of temporary storage.
//! When `nullptr`, the required allocation size is written to
//! `temp_storage_bytes` and no work is done.
//!
//! @param[in,out] temp_storage_bytes
//! Reference to the size in bytes of the `d_temp_storage` allocation
//!
//! @param[in] d_in
//! Random-access iterator to the input sequence of data items
//!
//! @param[out] d_out
//! Random-access iterator to the output sequence of data items
//!
//! @param[in] scan_op
//! Binary scan functor
//!
//! @param[in] init_value
//! Initial value to seed the inclusive scan (`scan_op(init_value, d_in[0])`
//! is assigned to `*d_out`)
//!
//! @param[in] num_items
//! Total number of input items (i.e., the length of `d_in`)
//!
//! @param[in] stream
//! CUDA stream to launch kernels within.
template <typename InputIteratorT, typename OutputIteratorT, typename ScanOpT, typename InitValueT>
CUB_RUNTIME_FUNCTION static cudaError_t InclusiveScanInit(
void* d_temp_storage,
size_t& temp_storage_bytes,
InputIteratorT d_in,
OutputIteratorT d_out,
ScanOpT scan_op,
InitValueT init_value,
int num_items,
cudaStream_t stream = 0)
{
CUB_DETAIL_NVTX_RANGE_SCOPE_IF(d_temp_storage, "cub::DeviceScan::InclusiveScanInit");

// Signed integer type for global offsets
using OffsetT = int;
using AccumT = cub::detail::accumulator_t<ScanOpT, InitValueT, cub::detail::value_t<InputIteratorT>>;
constexpr bool ForceInclusive = true;

return DispatchScan<
InputIteratorT,
OutputIteratorT,
ScanOpT,
detail::InputValue<InitValueT>,
OffsetT,
AccumT,
DeviceScanPolicy<AccumT, ScanOpT>,
ForceInclusive>::Dispatch(d_temp_storage,
temp_storage_bytes,
d_in,
d_out,
scan_op,
detail::InputValue<InitValueT>(init_value),
num_items,
stream);
}

template <typename InputIteratorT, typename OutputIteratorT, typename ScanOpT>
CUB_DETAIL_RUNTIME_DEBUG_SYNC_IS_NOT_SUPPORTED CUB_RUNTIME_FUNCTION static cudaError_t InclusiveScan(
void* d_temp_storage,
Expand Down
22 changes: 18 additions & 4 deletions cub/cub/device/dispatch/dispatch_scan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,8 @@ template <typename ChainedPolicyT,
typename ScanOpT,
typename InitValueT,
typename OffsetT,
typename AccumT>
typename AccumT,
bool ForceInclusive>
__launch_bounds__(int(ChainedPolicyT::ActivePolicy::ScanPolicyT::BLOCK_THREADS))
CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceScanKernel(
InputIteratorT d_in,
Expand All @@ -188,7 +189,8 @@ __launch_bounds__(int(ChainedPolicyT::ActivePolicy::ScanPolicyT::BLOCK_THREADS))
using ScanPolicyT = typename ChainedPolicyT::ActivePolicy::ScanPolicyT;

// Thread block type for scanning input tiles
using AgentScanT = AgentScan<ScanPolicyT, InputIteratorT, OutputIteratorT, ScanOpT, RealInitValueT, OffsetT, AccumT>;
using AgentScanT =
AgentScan<ScanPolicyT, InputIteratorT, OutputIteratorT, ScanOpT, RealInitValueT, OffsetT, AccumT, ForceInclusive>;

// Shared memory for AgentScan
__shared__ typename AgentScanT::TempStorage temp_storage;
Expand Down Expand Up @@ -223,6 +225,9 @@ __launch_bounds__(int(ChainedPolicyT::ActivePolicy::ScanPolicyT::BLOCK_THREADS))
* @tparam OffsetT
* Signed integer type for global offsets
*
* @tparam ForceInclusive
* Boolean flag to force InclusiveScan invocation when true.
*
*/
template <typename InputIteratorT,
typename OutputIteratorT,
Expand All @@ -234,7 +239,8 @@ template <typename InputIteratorT,
cub::detail::value_t<InputIteratorT>,
typename InitValueT::value_type>,
cub::detail::value_t<InputIteratorT>>,
typename SelectedPolicy = DeviceScanPolicy<AccumT, ScanOpT>>
typename SelectedPolicy = DeviceScanPolicy<AccumT, ScanOpT>,
bool ForceInclusive = false>
struct DispatchScan : SelectedPolicy
{
//---------------------------------------------------------------------
Expand Down Expand Up @@ -505,7 +511,15 @@ struct DispatchScan : SelectedPolicy
// Ensure kernels are instantiated.
return Invoke<ActivePolicyT>(
DeviceScanInitKernel<ScanTileStateT>,
DeviceScanKernel<MaxPolicyT, InputIteratorT, OutputIteratorT, ScanTileStateT, ScanOpT, InitValueT, OffsetT, AccumT>);
DeviceScanKernel<MaxPolicyT,
InputIteratorT,
OutputIteratorT,
ScanTileStateT,
ScanOpT,
InitValueT,
OffsetT,
AccumT,
ForceInclusive>);
}

/**
Expand Down
37 changes: 36 additions & 1 deletion cub/test/catch2_test_device_scan.cu
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@
#include <cstdint>

#include "c2h/custom_type.cuh"
#include "c2h/extended_types.cuh"
#include "catch2_test_device_reduce.cuh"
#include "catch2_test_device_scan.cuh"
#include "catch2_test_helper.h"
#include "catch2_test_launch_helper.h"

DECLARE_LAUNCH_WRAPPER(cub::DeviceScan::InclusiveScanInit, device_inclusive_scan_with_init);
DECLARE_LAUNCH_WRAPPER(cub::DeviceScan::ExclusiveSum, device_exclusive_sum);
DECLARE_LAUNCH_WRAPPER(cub::DeviceScan::ExclusiveScan, device_exclusive_scan);
DECLARE_LAUNCH_WRAPPER(cub::DeviceScan::InclusiveSum, device_inclusive_sum);
Expand Down Expand Up @@ -210,6 +210,41 @@ CUB_TEST("Device scan works with all device interfaces", "[scan][device]", full_
}
}

SECTION("inclusive scan with init value")
{
using op_t = cub::Sum;
using accum_t = cub::detail::accumulator_t<op_t, input_t, input_t>;

// Scan operator
auto scan_op = unwrap_op(reference_extended_fp(d_in_it), op_t{});

// Prepare verification data
c2h::host_vector<input_t> host_items(in_items);
c2h::host_vector<output_t> expected_result(num_items);

// Run test
c2h::device_vector<output_t> out_result(num_items);
auto d_out_it = thrust::raw_pointer_cast(out_result.data());
accum_t init_value{};
init_default_constant(init_value);
compute_inclusive_scan_reference(
host_items.cbegin(), host_items.cend(), expected_result.begin(), scan_op, init_value);

device_inclusive_scan_with_init(unwrap_it(d_in_it), unwrap_it(d_out_it), scan_op, init_value, num_items);

// Verify result
REQUIRE(expected_result == out_result);

// Run test in-place
_CCCL_IF_CONSTEXPR (std::is_same<input_t, output_t>::value)
{
device_inclusive_scan_with_init(unwrap_it(d_in_it), unwrap_it(d_in_it), scan_op, init_value, num_items);

// Verify result
REQUIRE(expected_result == in_items);
}
}

SECTION("exclusive scan")
{
using op_t = cub::Sum;
Expand Down
64 changes: 64 additions & 0 deletions cub/test/catch2_test_device_scan_api.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/******************************************************************************
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/

#include <cub/device/device_scan.cuh>

#include <thrust/device_vector.h>
#include <thrust/host_vector.h>

#include "catch2_test_helper.h"

CUB_TEST("Device inclusive scan works", "[scan][device]")
{
// example-begin device-inclusive-scan
thrust::device_vector<int> input{0, -1, 2, -3, 4, -5, 6};
thrust::device_vector<int> out(input.size());

int init = 1;
size_t temp_storage_bytes{};

cub::DeviceScan::InclusiveScanInit(
nullptr, temp_storage_bytes, input.begin(), out.begin(), cub::Max{}, init, static_cast<int>(input.size()));

// Allocate temporary storage for inclusive scan
thrust::device_vector<std::uint8_t> temp_storage(temp_storage_bytes);

// Run inclusive prefix sum
cub::DeviceScan::InclusiveScanInit(
thrust::raw_pointer_cast(temp_storage.data()),
temp_storage_bytes,
input.begin(),
out.begin(),
cub::Max{},
init,
static_cast<int>(input.size()));

thrust::host_vector<int> expected{1, 1, 2, 2, 4, 4, 6};
// example-end device-inclusive-scan

REQUIRE(expected == out);
}

0 comments on commit be411ff

Please sign in to comment.