diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseStringFunctionsSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseStringFunctionsSuite.scala index 88a5439be42eb..e40b293ea9d21 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseStringFunctionsSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseStringFunctionsSuite.scala @@ -54,17 +54,17 @@ class GlutenClickhouseStringFunctionsSuite extends GlutenClickHouseWholeStageTra sql("create table trim(trim_col String, src_col String) using parquet") sql(""" |insert into trim values - | ('aba', 'a'),('bba', 'b'),('abcdef', 'abcd'), + | ('bAa', 'a'),('bba', 'b'),('abcdef', 'abcd'), | (null, '123'),('123', null), ('', 'aaa'), ('bbb', '') |""".stripMargin) val sql0 = "select rtrim('aba', 'a') from trim order by src_col" val sql1 = "select rtrim(trim_col, src_col) from trim order by src_col" - val sql2 = "select rtrim(trim_col, 'NSSS') from trim order by src_col" + val sql2 = "select rtrim(trim_col, 'cCBbAa') from trim order by src_col" val sql3 = "select rtrim(trim_col, '') from trim order by src_col" val sql4 = "select rtrim('', 'AAA') from trim order by src_col" val sql5 = "select rtrim('', src_col) from trim order by src_col" - val sql6 = "select rtrim('ttt', src_col) from trim order by src_col" + val sql6 = "select rtrim('ab', src_col) from trim order by src_col" runQueryAndCompare(sql0) { _ => } runQueryAndCompare(sql1) { _ => } diff --git a/cpp-ch/local-engine/Functions/SparkFunctionTrim.cpp b/cpp-ch/local-engine/Functions/SparkFunctionTrim.cpp index 9f7a7647fcb79..6b8227ccc24b6 100644 --- a/cpp-ch/local-engine/Functions/SparkFunctionTrim.cpp +++ b/cpp-ch/local-engine/Functions/SparkFunctionTrim.cpp @@ -104,7 +104,7 @@ namespace } ColumnPtr - executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override + executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & /*result_type*/, size_t input_rows_count) const override { const ColumnString * src_col = checkAndGetColumn(arguments[0].column.get()); const ColumnConst * src_const_col = checkAndGetColumnConst(arguments[0].column.get()); @@ -118,39 +118,34 @@ namespace if (trim_const_col) trim_const_str = trim_const_col->getValue(); if (trim_const_col && trim_const_str.empty()) { - return arguments[0].column->cloneResized(input_rows_count); + return arguments[0].column; } - if (src_const_col && trim_const_col) - { - const char * dst; - size_t dst_size; - std::unordered_set trim_set(trim_const_str.begin(), trim_const_str.end()); - trim(src_const_str.c_str(), src_const_str.size(), dst, dst_size, trim_set); - return result_type->createColumnConst(input_rows_count, String(dst, dst_size)); - } + // If both arguments are constants, it will be simplified to a constant. Skipped here. auto res_col = ColumnString::create(); ColumnString::Chars & res_data = res_col->getChars(); ColumnString::Offsets & res_offsets = res_col->getOffsets(); res_offsets.resize_exact(input_rows_count); + // Source column is constant and trim column is not constant if (src_const_col) { res_data.reserve_exact(src_const_str.size() * input_rows_count); for (size_t row = 0; row < input_rows_count; ++row) { StringRef trim_str_ref = trim_col->getDataAt(row); - std::unordered_set trim_set(trim_str_ref.data, trim_str_ref.data + trim_str_ref.size); + std::unique_ptr> trim_set = buildTrimSet(trim_str_ref.toString()); executeRow(src_const_str.c_str(), src_const_str.size(), res_data, res_offsets, row, trim_set); } return std::move(res_col); } + // Source column is not constant and trim column is constant if (trim_const_col) { res_data.reserve_exact(src_col->getChars().size()); - std::unordered_set trim_set(trim_const_str.begin(), trim_const_str.end()); + std::unique_ptr> trim_set = buildTrimSet(trim_const_str); for (size_t row = 0; row < input_rows_count; ++row) { StringRef src_str_ref = src_col->getDataAt(row); @@ -165,7 +160,7 @@ namespace { StringRef src_str_ref = src_col->getDataAt(row); StringRef trim_str_ref = trim_col->getDataAt(row); - std::unordered_set trim_set(trim_str_ref.data, trim_str_ref.data + trim_str_ref.size); + std::unique_ptr> trim_set = buildTrimSet(trim_str_ref.toString()); executeRow(src_str_ref.data, src_str_ref.size, res_data, res_offsets, row, trim_set); } return std::move(res_col); @@ -178,7 +173,7 @@ namespace ColumnString::Chars & res_data, ColumnString::Offsets & res_offsets, size_t & row, - const std::unordered_set & trim_set) const + const std::unique_ptr> & trim_set) const { const char * dst; size_t dst_size; @@ -191,16 +186,30 @@ namespace res_offsets[row] = res_offset; } - void trim(const char * src, size_t src_size, const char *& dst, size_t & dst_size, const std::unordered_set & trim_set) const + std::unique_ptr> buildTrimSet(const String& trim_str) const { - const char * src_end = src + src_size; + auto trim_set = std::make_unique>(); + for (unsigned char i : trim_str) + trim_set->set(i); + return trim_set; + } + void trim(const char * src, size_t src_size, const char *& dst, size_t & dst_size, const std::unique_ptr> & trim_set) const + { + if (!trim_set || trim_set->none()) + { + dst = src; + dst_size = src_size; + return; + } + + const char * src_end = src + src_size; if constexpr (TrimMode::trim_left) - while (src < src_end && trim_set.contains(*src)) + while (src < src_end && trim_set->test((unsigned char)*src)) ++src; if constexpr (TrimMode::trim_right) - while (src < src_end && trim_set.contains(*(src_end - 1))) + while (src < src_end && trim_set->test((unsigned char)*(src_end - 1))) --src_end; dst = const_cast(src);