Skip to content

Commit

Permalink
Avoid make_zip_iterator(make_tuple(...)) (#2796)
Browse files Browse the repository at this point in the history
  • Loading branch information
bernhardmgruber authored Nov 13, 2024
1 parent 8a44baa commit f4a0619
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 79 deletions.
12 changes: 6 additions & 6 deletions thrust/testing/zip_function.cu
Original file line number Diff line number Diff line change
Expand Up @@ -62,17 +62,17 @@ struct TestZipFunctionTransform
device_vector<T> d_result_zip(n);

// Tuple base case
transform(make_zip_iterator(make_tuple(h_data0.begin(), h_data1.begin(), h_data2.begin())),
make_zip_iterator(make_tuple(h_data0.end(), h_data1.end(), h_data2.end())),
transform(make_zip_iterator(h_data0.begin(), h_data1.begin(), h_data2.begin()),
make_zip_iterator(h_data0.end(), h_data1.end(), h_data2.end()),
h_result_tuple.begin(),
SumThreeTuple{});
// Zip Function
transform(make_zip_iterator(make_tuple(h_data0.begin(), h_data1.begin(), h_data2.begin())),
make_zip_iterator(make_tuple(h_data0.end(), h_data1.end(), h_data2.end())),
transform(make_zip_iterator(h_data0.begin(), h_data1.begin(), h_data2.begin()),
make_zip_iterator(h_data0.end(), h_data1.end(), h_data2.end()),
h_result_zip.begin(),
make_zip_function(SumThree{}));
transform(make_zip_iterator(make_tuple(d_data0.begin(), d_data1.begin(), d_data2.begin())),
make_zip_iterator(make_tuple(d_data0.end(), d_data1.end(), d_data2.end())),
transform(make_zip_iterator(d_data0.begin(), d_data1.begin(), d_data2.begin()),
make_zip_iterator(d_data0.end(), d_data1.end(), d_data2.end()),
d_result_zip.begin(),
make_zip_function(SumThree{}));

Expand Down
34 changes: 17 additions & 17 deletions thrust/testing/zip_iterator.cu
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ struct TestZipIteratorManipulation

// test equality
ZipIterator iter1 = iter0;
ZipIterator iter2 = make_zip_iterator(make_tuple(v0.begin(), v2.begin()));
ZipIterator iter3 = make_zip_iterator(make_tuple(v1.begin(), v2.begin()));
ZipIterator iter2 = make_zip_iterator(v0.begin(), v2.begin());
ZipIterator iter3 = make_zip_iterator(v1.begin(), v2.begin());
ASSERT_EQUAL(true, iter0 == iter1);
ASSERT_EQUAL(true, iter0 == iter2);
ASSERT_EQUAL(false, iter0 == iter3);
Expand Down Expand Up @@ -284,9 +284,9 @@ void TestZipIteratorCopy()
sequence(input0.begin(), input0.end(), T{0});
sequence(input1.begin(), input1.end(), T{13});

thrust::copy(make_zip_iterator(make_tuple(input0.begin(), input1.begin())),
make_zip_iterator(make_tuple(input0.end(), input1.end())),
make_zip_iterator(make_tuple(output0.begin(), output1.begin())));
thrust::copy(make_zip_iterator(input0.begin(), input1.begin()),
make_zip_iterator(input0.end(), input1.end()),
make_zip_iterator(output0.begin(), output1.begin()));

ASSERT_EQUAL(input0, output0);
ASSERT_EQUAL(input1, output1);
Expand Down Expand Up @@ -332,23 +332,23 @@ struct TestZipIteratorTransform
device_vector<T> d_result(n);

// Tuples with 2 elements
transform(make_zip_iterator(make_tuple(h_data0.begin(), h_data1.begin())),
make_zip_iterator(make_tuple(h_data0.end(), h_data1.end())),
transform(make_zip_iterator(h_data0.begin(), h_data1.begin()),
make_zip_iterator(h_data0.end(), h_data1.end()),
h_result.begin(),
SumTwoTuple());
transform(make_zip_iterator(make_tuple(d_data0.begin(), d_data1.begin())),
make_zip_iterator(make_tuple(d_data0.end(), d_data1.end())),
transform(make_zip_iterator(d_data0.begin(), d_data1.begin()),
make_zip_iterator(d_data0.end(), d_data1.end()),
d_result.begin(),
SumTwoTuple());
ASSERT_EQUAL(h_result, d_result);

// Tuples with 3 elements
transform(make_zip_iterator(make_tuple(h_data0.begin(), h_data1.begin(), h_data2.begin())),
make_zip_iterator(make_tuple(h_data0.end(), h_data1.end(), h_data2.end())),
transform(make_zip_iterator(h_data0.begin(), h_data1.begin(), h_data2.begin()),
make_zip_iterator(h_data0.end(), h_data1.end(), h_data2.end()),
h_result.begin(),
SumThreeTuple());
transform(make_zip_iterator(make_tuple(d_data0.begin(), d_data1.begin(), d_data2.begin())),
make_zip_iterator(make_tuple(d_data0.end(), d_data1.end(), d_data2.end())),
transform(make_zip_iterator(d_data0.begin(), d_data1.begin(), d_data2.begin()),
make_zip_iterator(d_data0.end(), d_data1.end(), d_data2.end()),
d_result.begin(),
SumThreeTuple());
ASSERT_EQUAL(h_result, d_result);
Expand All @@ -375,14 +375,14 @@ void TestZipIteratorCopyAoSToSoA()

// host to host
host_vector<int> h_field0(n), h_field1(n);
host_structure_of_arrays h_soa = make_zip_iterator(make_tuple(h_field0.begin(), h_field1.begin()));
host_structure_of_arrays h_soa = make_zip_iterator(h_field0.begin(), h_field1.begin());

thrust::copy(h_aos.begin(), h_aos.end(), h_soa);
ASSERT_EQUAL_QUIET(make_tuple(7, 13), h_soa[0]);

// host to device
device_vector<int> d_field0(n), d_field1(n);
device_structure_of_arrays d_soa = make_zip_iterator(make_tuple(d_field0.begin(), d_field1.begin()));
device_structure_of_arrays d_soa = make_zip_iterator(d_field0.begin(), d_field1.begin());

thrust::copy(h_aos.begin(), h_aos.end(), d_soa);
ASSERT_EQUAL_QUIET(make_tuple(7, 13), d_soa[0]);
Expand Down Expand Up @@ -420,8 +420,8 @@ void TestZipIteratorCopySoAToAoS()
host_vector<int> h_field0(n, 7), h_field1(n, 13);
device_vector<int> d_field0(n, 7), d_field1(n, 13);

host_structure_of_arrays h_soa = make_zip_iterator(make_tuple(h_field0.begin(), h_field1.begin()));
device_structure_of_arrays d_soa = make_zip_iterator(make_tuple(d_field0.begin(), d_field1.begin()));
host_structure_of_arrays h_soa = make_zip_iterator(h_field0.begin(), h_field1.begin());
device_structure_of_arrays d_soa = make_zip_iterator(d_field0.begin(), d_field1.begin());

host_array_of_structures h_aos(n);
device_array_of_structures d_aos(n);
Expand Down
8 changes: 4 additions & 4 deletions thrust/testing/zip_iterator_reduce.cu
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,15 @@ struct TestZipIteratorReduce

// run on host
Tuple h_result = thrust::reduce(
make_zip_iterator(make_tuple(h_data0.begin(), h_data1.begin())),
make_zip_iterator(make_tuple(h_data0.end(), h_data1.end())),
make_zip_iterator(h_data0.begin(), h_data1.begin()),
make_zip_iterator(h_data0.end(), h_data1.end()),
make_tuple<T, T>(0, 0),
TuplePlus<Tuple>());

// run on device
Tuple d_result = thrust::reduce(
make_zip_iterator(make_tuple(d_data0.begin(), d_data1.begin())),
make_zip_iterator(make_tuple(d_data0.end(), d_data1.end())),
make_zip_iterator(d_data0.begin(), d_data1.begin()),
make_zip_iterator(d_data0.end(), d_data1.end()),
make_tuple<T, T>(0, 0),
TuplePlus<Tuple>());

Expand Down
28 changes: 14 additions & 14 deletions thrust/testing/zip_iterator_reduce_by_key.cu
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,19 @@ struct TestZipIteratorReduceByKey
reduce_by_key(
h_data0.begin(),
h_data0.end(),
make_zip_iterator(make_tuple(h_data1.begin(), h_data2.begin())),
make_zip_iterator(h_data1.begin(), h_data2.begin()),
h_data3.begin(),
make_zip_iterator(make_tuple(h_data4.begin(), h_data5.begin())),
make_zip_iterator(h_data4.begin(), h_data5.begin()),
equal_to<T>(),
TuplePlus<Tuple>());

// run on device
reduce_by_key(
d_data0.begin(),
d_data0.end(),
make_zip_iterator(make_tuple(d_data1.begin(), d_data2.begin())),
make_zip_iterator(d_data1.begin(), d_data2.begin()),
d_data3.begin(),
make_zip_iterator(make_tuple(d_data4.begin(), d_data5.begin())),
make_zip_iterator(d_data4.begin(), d_data5.begin()),
equal_to<T>(),
TuplePlus<Tuple>());

Expand Down Expand Up @@ -95,21 +95,21 @@ struct TestZipIteratorReduceByKey

// run on host
reduce_by_key(
make_zip_iterator(make_tuple(h_data0.begin(), h_data0.begin())),
make_zip_iterator(make_tuple(h_data0.end(), h_data0.end())),
make_zip_iterator(make_tuple(h_data1.begin(), h_data2.begin())),
make_zip_iterator(make_tuple(h_data3.begin(), h_data4.begin())),
make_zip_iterator(make_tuple(h_data5.begin(), h_data6.begin())),
make_zip_iterator(h_data0.begin(), h_data0.begin()),
make_zip_iterator(h_data0.end(), h_data0.end()),
make_zip_iterator(h_data1.begin(), h_data2.begin()),
make_zip_iterator(h_data3.begin(), h_data4.begin()),
make_zip_iterator(h_data5.begin(), h_data6.begin()),
equal_to<Tuple>(),
TuplePlus<Tuple>());

// run on device
reduce_by_key(
make_zip_iterator(make_tuple(d_data0.begin(), d_data0.begin())),
make_zip_iterator(make_tuple(d_data0.end(), d_data0.end())),
make_zip_iterator(make_tuple(d_data1.begin(), d_data2.begin())),
make_zip_iterator(make_tuple(d_data3.begin(), d_data4.begin())),
make_zip_iterator(make_tuple(d_data5.begin(), d_data6.begin())),
make_zip_iterator(d_data0.begin(), d_data0.begin()),
make_zip_iterator(d_data0.end(), d_data0.end()),
make_zip_iterator(d_data1.begin(), d_data2.begin()),
make_zip_iterator(d_data3.begin(), d_data4.begin()),
make_zip_iterator(d_data5.begin(), d_data6.begin()),
equal_to<Tuple>(),
TuplePlus<Tuple>());

Expand Down
40 changes: 20 additions & 20 deletions thrust/testing/zip_iterator_scan.cu
Original file line number Diff line number Diff line change
Expand Up @@ -39,27 +39,27 @@ struct TestZipIteratorScan

// inclusive_scan (tuple output)
thrust::inclusive_scan(
make_zip_iterator(make_tuple(h_data0.begin(), h_data1.begin())),
make_zip_iterator(make_tuple(h_data0.end(), h_data1.end())),
make_zip_iterator(h_data0.begin(), h_data1.begin()),
make_zip_iterator(h_data0.end(), h_data1.end()),
h_result.begin(),
TuplePlus<Tuple>());
thrust::inclusive_scan(
make_zip_iterator(make_tuple(d_data0.begin(), d_data1.begin())),
make_zip_iterator(make_tuple(d_data0.end(), d_data1.end())),
make_zip_iterator(d_data0.begin(), d_data1.begin()),
make_zip_iterator(d_data0.end(), d_data1.end()),
d_result.begin(),
TuplePlus<Tuple>());
ASSERT_EQUAL_QUIET(h_result, d_result);

// exclusive_scan (tuple output)
thrust::exclusive_scan(
make_zip_iterator(make_tuple(h_data0.begin(), h_data1.begin())),
make_zip_iterator(make_tuple(h_data0.end(), h_data1.end())),
make_zip_iterator(h_data0.begin(), h_data1.begin()),
make_zip_iterator(h_data0.end(), h_data1.end()),
h_result.begin(),
make_tuple<T, T>(0, 0),
TuplePlus<Tuple>());
thrust::exclusive_scan(
make_zip_iterator(make_tuple(d_data0.begin(), d_data1.begin())),
make_zip_iterator(make_tuple(d_data0.end(), d_data1.end())),
make_zip_iterator(d_data0.begin(), d_data1.begin()),
make_zip_iterator(d_data0.end(), d_data1.end()),
d_result.begin(),
make_tuple<T, T>(0, 0),
TuplePlus<Tuple>());
Expand All @@ -72,29 +72,29 @@ struct TestZipIteratorScan

// inclusive_scan (zip_iterator output)
thrust::inclusive_scan(
make_zip_iterator(make_tuple(h_data0.begin(), h_data1.begin())),
make_zip_iterator(make_tuple(h_data0.end(), h_data1.end())),
make_zip_iterator(make_tuple(h_result0.begin(), h_result1.begin())),
make_zip_iterator(h_data0.begin(), h_data1.begin()),
make_zip_iterator(h_data0.end(), h_data1.end()),
make_zip_iterator(h_result0.begin(), h_result1.begin()),
TuplePlus<Tuple>());
thrust::inclusive_scan(
make_zip_iterator(make_tuple(d_data0.begin(), d_data1.begin())),
make_zip_iterator(make_tuple(d_data0.end(), d_data1.end())),
make_zip_iterator(make_tuple(d_result0.begin(), d_result1.begin())),
make_zip_iterator(d_data0.begin(), d_data1.begin()),
make_zip_iterator(d_data0.end(), d_data1.end()),
make_zip_iterator(d_result0.begin(), d_result1.begin()),
TuplePlus<Tuple>());
ASSERT_EQUAL_QUIET(h_result0, d_result0);
ASSERT_EQUAL_QUIET(h_result1, d_result1);

// exclusive_scan (zip_iterator output)
thrust::exclusive_scan(
make_zip_iterator(make_tuple(h_data0.begin(), h_data1.begin())),
make_zip_iterator(make_tuple(h_data0.end(), h_data1.end())),
make_zip_iterator(make_tuple(h_result0.begin(), h_result1.begin())),
make_zip_iterator(h_data0.begin(), h_data1.begin()),
make_zip_iterator(h_data0.end(), h_data1.end()),
make_zip_iterator(h_result0.begin(), h_result1.begin()),
make_tuple<T, T>(0, 0),
TuplePlus<Tuple>());
thrust::exclusive_scan(
make_zip_iterator(make_tuple(d_data0.begin(), d_data1.begin())),
make_zip_iterator(make_tuple(d_data0.end(), d_data1.end())),
make_zip_iterator(make_tuple(d_result0.begin(), d_result1.begin())),
make_zip_iterator(d_data0.begin(), d_data1.begin()),
make_zip_iterator(d_data0.end(), d_data1.end()),
make_zip_iterator(d_result0.begin(), d_result1.begin()),
make_tuple<T, T>(0, 0),
TuplePlus<Tuple>());
ASSERT_EQUAL_QUIET(h_result0, d_result0);
Expand Down
6 changes: 2 additions & 4 deletions thrust/testing/zip_iterator_sort.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,10 @@ struct TestZipIteratorStableSort
device_vector<T> d2 = h2;

// sort on host
stable_sort(make_zip_iterator(make_tuple(h1.begin(), h2.begin())),
make_zip_iterator(make_tuple(h1.end(), h2.end())));
stable_sort(make_zip_iterator(h1.begin(), h2.begin()), make_zip_iterator(h1.end(), h2.end()));

// sort on device
stable_sort(make_zip_iterator(make_tuple(d1.begin(), d2.begin())),
make_zip_iterator(make_tuple(d1.end(), d2.end())));
stable_sort(make_zip_iterator(d1.begin(), d2.begin()), make_zip_iterator(d1.end(), d2.end()));

ASSERT_EQUAL_QUIET(h1, d1);
ASSERT_EQUAL_QUIET(h2, d2);
Expand Down
24 changes: 10 additions & 14 deletions thrust/testing/zip_iterator_sort_by_key.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,29 +21,25 @@ struct TestZipIteratorStableSortByKey
device_vector<T> d4 = h4;

// sort with (tuple, scalar)
stable_sort_by_key(make_zip_iterator(make_tuple(h1.begin(), h2.begin())),
make_zip_iterator(make_tuple(h1.end(), h2.end())),
h3.begin());
stable_sort_by_key(make_zip_iterator(make_tuple(d1.begin(), d2.begin())),
make_zip_iterator(make_tuple(d1.end(), d2.end())),
d3.begin());
stable_sort_by_key(make_zip_iterator(h1.begin(), h2.begin()), make_zip_iterator(h1.end(), h2.end()), h3.begin());
stable_sort_by_key(make_zip_iterator(d1.begin(), d2.begin()), make_zip_iterator(d1.end(), d2.end()), d3.begin());

ASSERT_EQUAL_QUIET(h1, d1);
ASSERT_EQUAL_QUIET(h2, d2);
ASSERT_EQUAL_QUIET(h3, d3);
ASSERT_EQUAL_QUIET(h4, d4);

// sort with (scalar, tuple)
stable_sort_by_key(h1.begin(), h1.end(), make_zip_iterator(make_tuple(h3.begin(), h4.begin())));
stable_sort_by_key(d1.begin(), d1.end(), make_zip_iterator(make_tuple(d3.begin(), d4.begin())));
stable_sort_by_key(h1.begin(), h1.end(), make_zip_iterator(h3.begin(), h4.begin()));
stable_sort_by_key(d1.begin(), d1.end(), make_zip_iterator(d3.begin(), d4.begin()));

// sort with (tuple, tuple)
stable_sort_by_key(make_zip_iterator(make_tuple(h1.begin(), h2.begin())),
make_zip_iterator(make_tuple(h1.end(), h2.end())),
make_zip_iterator(make_tuple(h3.begin(), h4.begin())));
stable_sort_by_key(make_zip_iterator(make_tuple(d1.begin(), d2.begin())),
make_zip_iterator(make_tuple(d1.end(), d2.end())),
make_zip_iterator(make_tuple(d3.begin(), d4.begin())));
stable_sort_by_key(make_zip_iterator(h1.begin(), h2.begin()),
make_zip_iterator(h1.end(), h2.end()),
make_zip_iterator(h3.begin(), h4.begin()));
stable_sort_by_key(make_zip_iterator(d1.begin(), d2.begin()),
make_zip_iterator(d1.end(), d2.end()),
make_zip_iterator(d3.begin(), d4.begin()));

ASSERT_EQUAL_QUIET(h1, d1);
ASSERT_EQUAL_QUIET(h2, d2);
Expand Down

0 comments on commit f4a0619

Please sign in to comment.