From ba54bb63c0dcc27d02b91e598d24c927e2d7ccdb Mon Sep 17 00:00:00 2001 From: Shruti Shivakumar Date: Wed, 27 Nov 2024 14:16:46 +0000 Subject: [PATCH] fix in string scalar --- cpp/src/text/bpe/load_merge_pairs.cu | 3 ++- cpp/tests/streams/text/bpe_test.cpp | 8 ++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/cpp/src/text/bpe/load_merge_pairs.cu b/cpp/src/text/bpe/load_merge_pairs.cu index cd68566bdec..a13a435a271 100644 --- a/cpp/src/text/bpe/load_merge_pairs.cu +++ b/cpp/src/text/bpe/load_merge_pairs.cu @@ -103,7 +103,8 @@ std::unique_ptr create_bpe_merge_pairs_im rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) { - auto pairs = cudf::strings::split_record(input, cudf::string_scalar(" "), 1, stream, mr); + auto pairs = + cudf::strings::split_record(input, cudf::string_scalar(" ", true, stream, mr), 1, stream, mr); auto content = pairs->release(); return create_bpe_merge_pairs_impl(std::move(content.children.back()), stream); } diff --git a/cpp/tests/streams/text/bpe_test.cpp b/cpp/tests/streams/text/bpe_test.cpp index 7a433086f9f..0510edc122a 100644 --- a/cpp/tests/streams/text/bpe_test.cpp +++ b/cpp/tests/streams/text/bpe_test.cpp @@ -27,6 +27,7 @@ struct TextBytePairEncoding : public cudf::test::BaseFixture {}; TEST_F(TextBytePairEncoding, BytePairEncoding) { + auto stream = cudf::test::get_default_stream(); // partial table based on values from https://huggingface.co/gpt2/raw/main/merges.txt auto mpt = cudf::test::strings_column_wrapper({ "e n", // 14 @@ -45,8 +46,7 @@ TEST_F(TextBytePairEncoding, BytePairEncoding) "s ent" // 33832 }); - auto merge_pairs = - nvtext::load_merge_pairs(cudf::strings_column_view(mpt), cudf::test::get_default_stream()); + auto merge_pairs = nvtext::load_merge_pairs(cudf::strings_column_view(mpt), stream); auto validity = cudf::test::iterators::null_at(4); cudf::test::strings_column_wrapper input( @@ -54,6 +54,6 @@ TEST_F(TextBytePairEncoding, BytePairEncoding) validity); auto sv = cudf::strings_column_view(input); - auto results = nvtext::byte_pair_encoding( - sv, *merge_pairs, cudf::string_scalar(" "), cudf::test::get_default_stream()); + auto results = + nvtext::byte_pair_encoding(sv, *merge_pairs, cudf::string_scalar(" ", true, stream), stream); }