Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

strings::contains() for multiple scalar search targets #16641

Closed
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
f4924a9
strings::contains() for multiple search targets
mythrocks Apr 15, 2024
1022c83
string contains optimization
Aug 22, 2024
45170e9
Add benchmark test
Aug 22, 2024
7e2aa43
Merge branch 'branch-24.10' into multi-string-contains-review
res-life Aug 23, 2024
32e1329
Fix comments
Aug 27, 2024
be6985b
Use new approach to improve perf: index the first chars in the targets
Aug 29, 2024
be7a1e2
Fix comments; Restore a test change
Aug 29, 2024
6b635f6
Merge branch 'branch-24.10' into multi-string-contains-review
res-life Aug 29, 2024
479788c
Improve
Aug 29, 2024
543a1f6
Fix compile error
Aug 30, 2024
f1da8b0
Merge branch 'branch-24.10' into multi-string-contains-review
res-life Aug 30, 2024
06ba14c
Update test cases; update benchmark tests
Aug 30, 2024
14418d7
Merge branch 'branch-24.10' into multi-string-contains-review
res-life Aug 30, 2024
814e002
Merge branch 'branch-24.10' into multi-string-contains-review
res-life Sep 2, 2024
587ce34
Format code
Sep 2, 2024
470355f
Fix bug
Sep 2, 2024
4b41ead
Merge branch 'branch-24.10' into multi-string-contains-review
res-life Sep 4, 2024
e56a122
Fix comments
Sep 4, 2024
31f4822
Optimize warp parallel
Sep 5, 2024
88d351d
Merge branch 'branch-24.10' into multi-string-contains-review
res-life Sep 5, 2024
7836c33
Merge branch 'branch-24.10' into multi-string-contains-review
res-life Sep 6, 2024
6ae2c00
Split targets to small groups to save shared memory when num of targe…
Sep 6, 2024
542e1ff
Merge branch 'branch-24.10' into multi-string-contains-review
res-life Sep 9, 2024
ab5ef90
Merge branch 'branch-24.10' into multi-string-contains-review
res-life Sep 10, 2024
da1d92b
Merge branch 'branch-24.10' into multi-string-contains-review
res-life Sep 11, 2024
3324671
Fix bug when strings are long: returns all falses.
Sep 11, 2024
849c093
Format code
Sep 11, 2024
85e8b17
Refactor: refine code comments
Sep 11, 2024
ce4450d
Merge branch 'branch-24.10' into multi-string-contains-review
res-life Sep 14, 2024
9fc9398
Fix bug: illegal memory access
Sep 14, 2024
b33d692
Fix bug in split logic
Sep 14, 2024
6741bef
Optimize the perf for indexing first chars
Sep 14, 2024
330e828
Fix comments from code review
Sep 14, 2024
d216993
Fix compile error
Sep 14, 2024
eb6744f
Merge branch 'branch-24.10' into multi-string-contains-review
res-life Sep 18, 2024
a32c54d
Fix bugs; update tests
Sep 18, 2024
8391239
Merge branch 'branch-24.10' into multi-string-contains-review
res-life Sep 18, 2024
5caf782
Update
Sep 18, 2024
41fb9ae
Merge branch 'branch-24.10' into multi-string-contains-review
res-life Sep 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions cpp/benchmarks/string/find.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,25 @@ static void bench_find_string(nvbench::state& state)
cudf::strings::find_multiple(input, cudf::strings_column_view(targets));
});
} else if (api == "contains") {
state.exec(nvbench::exec_tag::sync,
[&](nvbench::launch& launch) { cudf::strings::contains(input, target); });
constexpr int iters = 20;
std::vector<std::string> match_targets({"123", "abc", "4567890", "DEFGHI", "5W43"});
auto multi_targets = std::vector<std::string>{};
for (int i = 0; i < iters; i++) {
multi_targets.emplace_back(match_targets[i % match_targets.size()]);
}
cudf::test::strings_column_wrapper multi_targets_column(multi_targets.begin(), multi_targets.end());

constexpr bool combine = false;
if constexpr (not combine) {
state.exec(nvbench::exec_tag::sync, [&](nvbench::launch& launch) {
for (size_t i = 0; i < multi_targets.size(); i++) {
cudf::strings::contains(input, cudf::string_scalar(multi_targets[i]));
}
});
} else { // combine
state.exec(nvbench::exec_tag::sync,
[&](nvbench::launch& launch) { cudf::strings::multi_contains(input, cudf::strings_column_view(multi_targets_column)); });
}
davidwendt marked this conversation as resolved.
Show resolved Hide resolved
} else if (api == "starts_with") {
state.exec(nvbench::exec_tag::sync,
[&](nvbench::launch& launch) { cudf::strings::starts_with(input, target); });
Expand Down
30 changes: 30 additions & 0 deletions cpp/include/cudf/strings/find.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,36 @@ std::unique_ptr<column> contains(
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());

/**
 * @brief Returns a table of boolean columns, one per target string, where true indicates
 * the target string was found within the corresponding string of the provided column.
 *
 * Each column in the result table corresponds to the result for the target string at the same
 * ordinal, i.e. the 0th column is the boolean-column result for the 0th target string, the 1st
 * column is the result for the 1st target string, etc.
 *
 * If the target is not found for a string, false is returned for that entry in the output column.
 * If the target is an empty string, true is returned for all non-null entries in the output column.
 *
 * Any null string entries return corresponding null entries in the output columns.
 *
 * At least one target must be provided; the implementation rejects an empty `targets` column.
 *
 * Example:
 *   input:   "a", "b", "c"
 *   targets: "a", "c"
 *   output is a table with two boolean columns:
 *     column_0: true, false, false
 *     column_1: false, false, true
 *
 * @param input Strings instance for this operation
 * @param targets UTF-8 encoded strings to search for in each string in `input`
 * @param stream CUDA stream used for device memory operations and kernel launches
 * @param mr Device memory resource used to allocate the returned table's device memory
 * @return Table of BOOL8 columns, one column per string in `targets`
 */
std::unique_ptr<table> multi_contains(
strings_column_view const& input,
strings_column_view const& targets,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
* @brief Returns a column of boolean values for each string where true indicates
* the corresponding target string was found within that string in the provided column.
Expand Down
173 changes: 168 additions & 5 deletions cpp/src/strings/search/find.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cudf/column/column_device_view.cuh>
#include <cudf/column/column_factories.hpp>
#include <cudf/detail/iterator.cuh>
#include <cudf/detail/null_mask.hpp>
#include <cudf/detail/nvtx/ranges.hpp>
#include <cudf/detail/utilities/cuda.cuh>
#include <cudf/detail/utilities/vector_factories.hpp>
#include <cudf/scalar/scalar_factories.hpp>
#include <cudf/strings/detail/utilities.hpp>
#include <cudf/strings/find.hpp>
Expand Down Expand Up @@ -414,6 +414,134 @@ std::unique_ptr<column> contains_warp_parallel(strings_column_view const& input,
return results;
}

/**
 * @brief Kernel that checks every target against one input string per thread.
 *
 * For the row handled by this thread, `d_results[t][row]` is set to true when target `t`
 * occurs anywhere in the row's string (or when target `t` is empty), and to false otherwise.
 * Null rows are left untouched.
 *
 * @param d_strings Column of strings to search in
 * @param d_targets Column of strings to search for
 * @param d_results One output buffer per target, each indexed by input row
 */
CUDF_KERNEL void multi_contains_fn(column_device_view const d_strings,
                                   column_device_view const d_targets,
                                   cudf::device_span<bool*> d_results)
{
  auto const row = static_cast<size_type>(cudf::detail::grid_1d::global_thread_id());
  if (row >= d_strings.size()) { return; }
  // bitmask will set result to null.
  if (d_strings.is_null(row)) { return; }

  auto const target_count = d_targets.size();
  auto const haystack     = d_strings.element<string_view>(row);
  auto const haystack_len = haystack.size_bytes();

  // Seed every output: an empty target matches any non-null string, everything else
  // starts out as "not found".
  for (auto t = 0; t < target_count; ++t) {
    d_results[t][row] = d_targets.element<string_view>(t).size_bytes() == 0;
  }

  // Walk every byte offset of the string and try each not-yet-found target at that offset.
  for (auto offset = 0; offset < haystack_len; ++offset) {
    auto const remaining = haystack_len - offset;
    for (auto t = 0; t < target_count; ++t) {
      // already found at an earlier offset
      if (d_results[t][row]) { continue; }
      auto const needle = d_targets.element<string_view>(t);
      // the target no longer fits in what is left of the string
      if (remaining < needle.size_bytes()) { continue; }
      if (needle.compare(haystack.data() + offset, needle.size_bytes()) == 0) {
        d_results[t][row] = true;
      }
    }
  }
}

/**
 * @brief Kernel that checks every target against one input string per warp.
 *
 * Each lane of a warp inspects a strided subset of the byte offsets of the warp's string.
 * The per-lane findings are then reduced across the warp, and lane 0 writes the result
 * for each target to `d_results[target][row]`. Null rows are left untouched.
 *
 * The global thread index is kept in 64 bits: the launch uses `warp_size` threads per row,
 * so a 32-bit index overflows once the input exceeds roughly 67 million rows.
 *
 * @param d_strings Column of strings to search in
 * @param d_targets Column of strings to search for
 * @param d_results One output buffer per target, each indexed by input row
 */
CUDF_KERNEL void multi_contains_warp_parallel_multi_scalars_fn(column_device_view const d_strings,
                                                               column_device_view const d_targets,
                                                               cudf::device_span<bool*> d_results)
{
  auto const num_targets = d_targets.size();
  auto const num_rows    = d_strings.size();

  // 64-bit index, consistent with multi_contains_fn's use of grid_1d
  auto const idx = cudf::detail::grid_1d::global_thread_id();
  using warp_reduce = cub::WarpReduce<bool>;
  // NOTE(review): one TempStorage instance is shared by every warp of the block, as in the
  // original code -- confirm this is safe for the cub::WarpReduce specialization in use.
  __shared__ typename warp_reduce::TempStorage temp_storage;

  // widen before multiplying so the bound itself cannot overflow
  if (idx >= static_cast<cudf::thread_index_type>(num_rows) * cudf::detail::warp_size) { return; }

  auto const lane_idx = static_cast<size_type>(idx % cudf::detail::warp_size);
  auto const str_idx  = static_cast<size_type>(idx / cudf::detail::warp_size);
  // All threads with the same str_idx take this branch together.
  if (d_strings.is_null(str_idx)) { return; }  // bitmask will set result to null.

  // get the string for this warp
  auto const d_str = d_strings.element<string_view>(str_idx);

  for (size_type target_idx = 0; target_idx < num_targets; ++target_idx) {
    // Identify the target.
    auto const d_target = d_targets.element<string_view>(target_idx);

    // each thread of the warp will check just part of the string
    auto found = false;
    if (d_target.empty()) {
      // an empty target matches every non-null string
      found = true;
    } else {
      for (auto i = lane_idx; !found && ((i + d_target.size_bytes()) <= d_str.size_bytes());
           i += cudf::detail::warp_size) {
        // check the target matches this part of the d_str data
        if (d_target.compare(d_str.data() + i, d_target.size_bytes()) == 0) { found = true; }
      }
    }
    __syncwarp();
    // logical OR of the per-lane findings
    auto const result = warp_reduce(temp_storage).Reduce(found, cub::Max());
    if (lane_idx == 0) { d_results[target_idx][str_idx] = result; }
  }
}

/**
 * @brief Searches every input string for every target and returns one BOOL8 column per target.
 *
 * Each output column has the same size and null mask as `input`.
 *
 * @param input Strings to search in
 * @param targets Strings to search for; must be non-empty and must not contain nulls
 * @param warp_parallel If true, one warp is assigned to each input string; otherwise one thread
 * @param stream CUDA stream used for device memory operations and kernel launches
 * @param mr Device memory resource used to allocate the returned columns
 * @return One boolean column per target, in the order of `targets`
 */
std::vector<std::unique_ptr<column>> multi_contains(strings_column_view const& input,
                                                    strings_column_view const& targets,
                                                    bool warp_parallel,
                                                    rmm::cuda_stream_view stream,
                                                    rmm::mr::device_memory_resource* mr)
{
  CUDF_EXPECTS(not targets.is_empty(), "Must specify at least one target string.");
  // The kernels read every target row without consulting the targets' null mask.
  CUDF_EXPECTS(not targets.has_nulls(), "Target strings cannot be null.");

  // Create output columns.
  auto const results_iter =
    thrust::make_transform_iterator(thrust::counting_iterator<cudf::size_type>(0), [&](int) {
      return make_numeric_column(data_type{type_id::BOOL8},
                                 input.size(),
                                 cudf::detail::copy_bitmask(input.parent(), stream, mr),
                                 input.null_count(),
                                 stream,
                                 mr);
    });
  auto results_list =
    std::vector<std::unique_ptr<column>>(results_iter, results_iter + targets.size());

  // Nothing to search: avoid launching a kernel over zero rows.
  if (input.is_empty()) { return results_list; }

  // Gather the raw output pointers on the host and copy them to the device.
  // This array is a temporary, so it is not allocated from the output resource `mr`.
  auto host_results_pointers = std::vector<bool*>{};
  host_results_pointers.reserve(results_list.size());
  for (auto const& results_column : results_list) {
    host_results_pointers.push_back(results_column->mutable_view().data<bool>());
  }
  auto device_results_list = cudf::detail::make_device_uvector_async(
    host_results_pointers, stream, rmm::mr::get_current_device_resource());

  // Populate all output vectors.

  constexpr int block_size = 256;
  auto const d_strings     = column_device_view::create(input.parent(), stream);
  auto const d_targets     = column_device_view::create(targets.parent(), stream);

  if (warp_parallel) {
    // one warp handles multi-targets for a string;
    // widen before multiplying so the thread count cannot overflow a 32-bit size_type
    cudf::detail::grid_1d grid{
      static_cast<cudf::thread_index_type>(input.size()) * cudf::detail::warp_size, block_size};
    multi_contains_warp_parallel_multi_scalars_fn<<<grid.num_blocks,
                                                    grid.num_threads_per_block,
                                                    0,
                                                    stream.value()>>>(
      *d_strings, *d_targets, device_results_list);
  } else {
    // one thread handles multi-targets for a string
    cudf::detail::grid_1d grid{input.size(), block_size};
    multi_contains_fn<<<grid.num_blocks, grid.num_threads_per_block, 0, stream.value()>>>(
      *d_strings, *d_targets, device_results_list);
  }
  return results_list;
}

/**
* @brief Utility to return a bool column indicating the presence of
* a given target string in a strings column.
Expand Down Expand Up @@ -534,6 +662,16 @@ std::unique_ptr<column> contains_fn(strings_column_view const& strings,
return results;
}

/**
 * @brief Builds a boolean column reporting whether `target` occurs in each string of `input`.
 *
 * Delegates to `contains_fn` with a predicate based on `string_view::find`.
 *
 * @param input Strings to search in
 * @param target String to search for
 * @param stream CUDA stream used for device memory operations and kernel launches
 * @param mr Device memory resource used to allocate the returned column
 * @return Column produced by `contains_fn` for the find-based predicate
 */
std::unique_ptr<column> contains_small_strings_impl(strings_column_view const& input,
                                                    string_scalar const& target,
                                                    rmm::cuda_stream_view stream,
                                                    rmm::device_async_resource_ref mr)
{
  auto has_target = [] __device__(string_view d_string, string_view d_target) -> bool {
    auto const position = d_string.find(d_target);
    return position != string_view::npos;
  };
  return contains_fn(input, target, has_target, stream, mr);
}
} // namespace

std::unique_ptr<column> contains(strings_column_view const& input,
Expand All @@ -548,10 +686,26 @@ std::unique_ptr<column> contains(strings_column_view const& input,
}

// benchmark measurements showed this to be faster for smaller strings
auto pfn = [] __device__(string_view d_string, string_view d_target) {
return d_string.find(d_target) != string_view::npos;
};
return contains_fn(input, target, pfn, stream, mr);
return contains_small_strings_impl(input, target, stream, mr);
}

/**
 * @brief Searches `input` for each of `targets` and packs the per-target results into a table.
 *
 * Picks the warp-per-string kernel when the input is not entirely null and its average
 * string width exceeds AVG_CHAR_BYTES_THRESHOLD; otherwise uses the thread-per-string kernel.
 *
 * @param input Strings to search in
 * @param targets Strings to search for
 * @param stream CUDA stream used for device memory operations and kernel launches
 * @param mr Device memory resource used to allocate the returned table
 * @return Table holding one boolean column per target
 */
std::unique_ptr<table> multi_contains(strings_column_view const& input,
                                      strings_column_view const& targets,
                                      rmm::cuda_stream_view stream,
                                      rmm::mr::device_memory_resource* mr)
{
  // The null-count test short-circuits first, so the average is never computed for an
  // empty or all-null input.
  bool const has_valid_rows = input.null_count() < input.size();
  bool const large_strings =
    has_valid_rows && ((input.chars_size(stream) / input.size()) > AVG_CHAR_BYTES_THRESHOLD);

  // Large strings: use warp parallel when the average string width is greater than the
  // threshold. Small strings: searching for multiple targets in one thread seems to work
  // fastest.
  return std::make_unique<table>(
    multi_contains(input, targets, /*warp_parallel=*/large_strings, stream, mr));
}

std::unique_ptr<column> contains(strings_column_view const& strings,
Expand Down Expand Up @@ -632,6 +786,15 @@ std::unique_ptr<column> contains(strings_column_view const& strings,
return detail::contains(strings, target, stream, mr);
}

/**
 * @brief Public entry point for the multi-target contains search.
 *
 * Opens a function range via CUDF_FUNC_RANGE and forwards all arguments to
 * `detail::multi_contains`.
 */
std::unique_ptr<table> multi_contains(strings_column_view const& strings,
                                      strings_column_view const& targets,
                                      rmm::cuda_stream_view stream,
                                      rmm::mr::device_memory_resource* mr)
{
  CUDF_FUNC_RANGE();
  auto result = detail::multi_contains(strings, targets, stream, mr);
  return result;
}

std::unique_ptr<column> contains(strings_column_view const& strings,
strings_column_view const& targets,
rmm::cuda_stream_view stream,
Expand Down
35 changes: 32 additions & 3 deletions cpp/tests/strings/find_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,14 @@
#include <cudf_test/base_fixture.hpp>
#include <cudf_test/column_utilities.hpp>
#include <cudf_test/column_wrapper.hpp>
#include <cudf_test/iterator_utilities.hpp>

#include <cudf/column/column.hpp>
#include <cudf/column/column_factories.hpp>
#include <cudf/scalar/scalar.hpp>
#include <cudf/strings/attributes.hpp>
#include <cudf/strings/find.hpp>
#include <cudf/strings/strings_column_view.hpp>

#include <thrust/iterator/transform_iterator.h>

#include <vector>

struct StringsFindTest : public cudf::test::BaseFixture {};
Expand Down Expand Up @@ -198,6 +196,37 @@ TEST_F(StringsFindTest, ContainsLongStrings)
CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*results, expected);
}

// Checks multi_contains against three targets (" the ", "a" and the empty string) on an input
// that includes an empty row and a null row.
TEST_F(StringsFindTest, MultiContains)
{
  using cudf::test::iterators::null_at;
  using bool_wrapper = cudf::test::fixed_width_column_wrapper<bool>;

  // Row 7 is an empty string, row 8 is null.
  auto const input = cudf::test::strings_column_wrapper{
    {"Héllo, there world and goodbye",
     "quick brown fox jumped over the lazy brown dog; the fat cats jump in place without moving",
     "the following code snippet demonstrates how to use search for values in an ordered range",
     "it returns the last position where value could be inserted without violating the ordering",
     "algorithms execution is parallelized as determined by an execution policy. t",
     "he this is a continuation of previous row to make sure string boundaries are honored",
     "abcdefghijklmnopqrstuvwxyz 0123456789 ABCDEFGHIJKLMNOPQRSTUVWXYZ !@#$%^&*()~",
     "",
     ""},
    null_at(8)};

  std::vector<std::string> const target_strings{" the ", "a", ""};
  cudf::test::strings_column_wrapper targets(target_strings.begin(), target_strings.end());

  auto const results = cudf::strings::multi_contains(cudf::strings_column_view(input),
                                                     cudf::strings_column_view(targets));
  auto const result_view = results->view();

  // target " the "
  auto const expected_the = bool_wrapper({0, 1, 0, 1, 0, 0, 0, 0, 0}, null_at(8));
  CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(result_view.column(0), expected_the);

  // target "a"
  auto const expected_a = bool_wrapper({1, 1, 1, 1, 1, 1, 1, 0, 0}, null_at(8));
  CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(result_view.column(1), expected_a);

  // empty target: matches every non-null row
  auto const expected_empty = bool_wrapper({1, 1, 1, 1, 1, 1, 1, 1, 0}, null_at(8));
  CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(result_view.column(2), expected_empty);
}

TEST_F(StringsFindTest, StartsWith)
{
cudf::test::strings_column_wrapper strings({"Héllo", "thesé", "", "lease", "tést strings", ""},
Expand Down
22 changes: 22 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import java.util.*;
import java.util.stream.IntStream;
import java.util.stream.Stream;

import static ai.rapids.cudf.HostColumnVector.OFFSET_SIZE;

Expand Down Expand Up @@ -3351,6 +3352,22 @@ public final ColumnVector stringContains(Scalar compString) {
return new ColumnVector(stringContains(getNativeView(), compString.getScalarHandle()));
}

/**
 * Unboxes an array of Long objects into a primitive long array of the same length and order.
 */
private static long[] toPrimitive(Long[] longs) {
  return Arrays.stream(longs).mapToLong(Long::longValue).toArray();
}

/**
 * Checks whether each string in this column contains each of the target strings.
 *
 * @param targets column of strings to search for; must be a non-empty string column
 *                without nulls
 * @return one result column per native handle returned for `targets`; the caller is
 *         responsible for closing each of them
 */
public final ColumnVector[] stringContains(ColumnView targets) {
  assert type.equals(DType.STRING) : "column type must be a String";
  assert targets.getType().equals(DType.STRING) : "targets type must be a string";
  assert targets.getRowCount() > 0 : "targets must not be empty";
  // The previous check was inverted (`> 0`) and rejected every valid, null-free targets column.
  assert targets.getNullCount() == 0 : "targets must not be null";
  long[] resultPointers = stringContainsMulti(getNativeView(), targets.getNativeView());
  return Arrays.stream(resultPointers).mapToObj(ColumnVector::new).toArray(ColumnVector[]::new);
}

/**
* Replaces values less than `lo` in `input` with `lo`,
* and values greater than `hi` with `hi`.
Expand Down Expand Up @@ -4456,6 +4473,11 @@ private static native long stringReplaceWithBackrefs(long columnView, String pat
*/
private static native long stringContains(long cudfViewHandle, long compString) throws CudfException;

/**
 * Native method that checks multiple target strings against the same input column.
 *
 * @param cudfViewHandle native handle of the input column view
 * @param targets native handle of the column view holding the target strings
 * @return array of native handles; the Java caller wraps each one in a ColumnVector
 */
private static native long[] stringContainsMulti(long cudfViewHandle, long targets) throws CudfException;

/**
* Native method for extracting results from a regex program pattern. Returns a table handle.
*
Expand Down
Loading
Loading