Skip to content

Commit

Permalink
Ensure that we can run reduce_by_key with const inputs (#1528)
Browse files Browse the repository at this point in the history
* Ensure that we can run `reduce_by_key` with const inputs

It seems that explicitly passing in `thrust::device` was key here; otherwise, the bug did not manifest.

Fixes #1527 ([BUG]: `reduce_by_key` fails with zip_iterator to const pointers).

* Also address nvbug4550097

The explicit use of `int` can trigger a conversion warning, so just use the right difference type instead.

* Update thrust/testing/zip_iterator_reduce_by_key.cu
  • Loading branch information
miscco authored Mar 12, 2024
1 parent dbf9749 commit c1be3a5
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 3 deletions.
47 changes: 46 additions & 1 deletion thrust/testing/zip_iterator_reduce_by_key.cu
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ struct TestZipIteratorReduceByKey
ASSERT_EQUAL(h_data4, d_data4);
ASSERT_EQUAL(h_data5, d_data5);
}

// The tests below get miscompiled on Tesla hw for 8b types

#if THRUST_DEVICE_SYSTEM == THRUST_DEVICE_SYSTEM_CUDA
Expand Down Expand Up @@ -118,6 +118,51 @@ struct TestZipIteratorReduceByKey
ASSERT_EQUAL(h_data5, d_data5);
ASSERT_EQUAL(h_data6, d_data6);
}

// Regression test for #1527: reduce_by_key must accept read-only inputs.
// The keys are a zip_iterator over two pointers-to-const and the values are a
// pointer-to-const; only the output ranges are writable.
// h_data1/h_data2, d_data1/d_data2 and n are declared earlier in the enclosing
// scope (outside this block).
{
// Value inputs (data3) and all outputs (data4..data6) start zero-filled, with
// matching host and device copies so the results can be compared element-wise.
host_vector<float> h_data3(n, 0.0f);
host_vector<T> h_data4(n, 0);
host_vector<T> h_data5(n, 0);
host_vector<float> h_data6(n, 0.0f);
device_vector<float> d_data3(n, 0.0f);
device_vector<T> d_data4(n, 0);
device_vector<T> d_data5(n, 0);
device_vector<float> d_data6(n, 0.0f);

// run on host
// Inputs (begin1..begin3) are viewed through const raw pointers; the output
// pointers (begin4..begin6) stay mutable.
const T* h_begin1 = thrust::raw_pointer_cast(h_data1.data());
const T* h_begin2 = thrust::raw_pointer_cast(h_data2.data());
const float* h_begin3 = thrust::raw_pointer_cast(h_data3.data());
T* h_begin4 = thrust::raw_pointer_cast(h_data4.data());
T* h_begin5 = thrust::raw_pointer_cast(h_data5.data());
float* h_begin6 = thrust::raw_pointer_cast(h_data6.data());
// Argument order: policy, keys first, keys last, values first, keys output,
// values output. No predicate or reduction functor is passed, so the
// overload's defaults apply. The returned end iterators are discarded.
thrust::reduce_by_key(thrust::host,
thrust::make_zip_iterator(thrust::make_tuple(h_begin1, h_begin2)),
thrust::make_zip_iterator(thrust::make_tuple(h_begin1, h_begin2)) + n,
h_begin3,
thrust::make_zip_iterator(thrust::make_tuple(h_begin4, h_begin5)),
h_begin6);

// run on device
// Same call shape as the host run, but on device memory with an explicit
// thrust::device policy.
// NOTE(review): the commit message for this test reports that the failure only
// showed up when the policy was passed explicitly — keep it explicit.
const T* d_begin1 = thrust::raw_pointer_cast(d_data1.data());
const T* d_begin2 = thrust::raw_pointer_cast(d_data2.data());
const float* d_begin3 = thrust::raw_pointer_cast(d_data3.data());
T* d_begin4 = thrust::raw_pointer_cast(d_data4.data());
T* d_begin5 = thrust::raw_pointer_cast(d_data5.data());
float* d_begin6 = thrust::raw_pointer_cast(d_data6.data());
thrust::reduce_by_key(thrust::device,
thrust::make_zip_iterator(thrust::make_tuple(d_begin1, d_begin2)),
thrust::make_zip_iterator(thrust::make_tuple(d_begin1, d_begin2)) + n,
d_begin3,
thrust::make_zip_iterator(thrust::make_tuple(d_begin4, d_begin5)),
d_begin6);

// Host and device must agree on every buffer. The whole vectors are compared,
// so the zero-filled tail past the reduced output is included as well.
// data3 is an input; presumably it is compared to confirm neither run wrote
// through the const value pointer — TODO confirm intent.
ASSERT_EQUAL(h_data3, d_data3);
ASSERT_EQUAL(h_data4, d_data4);
ASSERT_EQUAL(h_data5, d_data5);
ASSERT_EQUAL(h_data6, d_data6);
}
}
};
VariableUnitTest<TestZipIteratorReduceByKey, UnsignedIntegralTypes> TestZipIteratorReduceByKeyInstance;
Expand Down
4 changes: 2 additions & 2 deletions thrust/thrust/system/cuda/detail/reduce_by_key.h
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,7 @@ namespace __reduce_by_key {
}

key_type tile_pred_key = (threadIdx.x == 0)
? keys_load_it[tile_offset - 1]
? key_type(keys_load_it[tile_offset - 1])
: key_type();

sync_threadblock();
Expand Down Expand Up @@ -1057,7 +1057,7 @@ namespace __reduce_by_key {
status = cuda_cub::synchronize(policy);
cuda_cub::throw_on_error(status, "reduce_by_key: failed to synchronize");

int num_runs_out = cuda_cub::get_value(policy, d_num_runs_out);
const auto num_runs_out = cuda_cub::get_value(policy, d_num_runs_out);

return thrust::make_pair(
keys_output + num_runs_out,
Expand Down

0 comments on commit c1be3a5

Please sign in to comment.