From c87029f521489c805bbf718c6e02a0ecc7d4266b Mon Sep 17 00:00:00 2001 From: Shuai li Date: Fri, 31 May 2024 09:45:48 +0800 Subject: [PATCH] [GLUTEN-5921][CH] Function trim of trim_character support value from column (#5922) [CH] Function trim of trim_character support value from column --- .../GlutenFunctionValidateSuite.scala | 18 ++++- .../Functions/SparkFunctionTrim.cpp | 76 +++++++++++++++---- 2 files changed, 80 insertions(+), 14 deletions(-) diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala index 5a1ca679986f..cfe8ea95abcd 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala @@ -708,6 +708,23 @@ class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerS } + test("GLUTEN-5821: trim_character support value from column.") { + withTable("trim") { + sql("create table trim(a String, b String) using parquet") + sql(""" + |insert into trim values ('aba', 'a'),('bba', 'b'),('abcdef', 'abcd') + |""".stripMargin) + + val sql_str = + s"""select + | trim(both b from a) + | from trim + """.stripMargin + + runQueryAndCompare(sql_str) { _ => } + } + } + test("GLUTEN-5897: fix regexp_extract with bracket") { withTable("regexp_extract_bracket") { sql("create table regexp_extract_bracket(a String) using parquet") @@ -727,5 +744,4 @@ class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerS runQueryAndCompare(sql_str) { _ => } } } - } diff --git a/cpp-ch/local-engine/Functions/SparkFunctionTrim.cpp b/cpp-ch/local-engine/Functions/SparkFunctionTrim.cpp index d8f6be1bfc32..88ed3f635672 100644 --- a/cpp-ch/local-engine/Functions/SparkFunctionTrim.cpp +++ b/cpp-ch/local-engine/Functions/SparkFunctionTrim.cpp @@ -77,8 +77,6 @@ namespace bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; } - ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {1}; } - DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override { if (arguments.size() != 2) @@ -112,19 +110,34 @@ namespace if (!src_str_col) throw Exception(ErrorCodes::ILLEGAL_COLUMN, "First argument of function {} must be String", getName()); - const ColumnConst * trim_str_col = checkAndGetColumnConst(arguments[1].column.get()); - if (!trim_str_col) - throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Second argument of function {} must be Const String", getName()); - - String trim_str = trim_str_col->getValue(); - if (trim_str.empty()) - return src_str_col->cloneResized(input_rows_count); - auto res_col = ColumnString::create(); - res_col->reserve(input_rows_count); + if (const auto * trim_const_str_col = checkAndGetColumnConst(arguments[1].column.get())) + { + String trim_str = trim_const_str_col->getValue(); + if (trim_str.empty()) + return src_str_col->cloneResized(input_rows_count); + + auto res_col = ColumnString::create(); + res_col->reserve(input_rows_count); + executeVector(src_str_col->getChars(), src_str_col->getOffsets(), res_col->getChars(), res_col->getOffsets(), trim_str); + return std::move(res_col); + } + else if (const auto * trim_str_col = checkAndGetColumn(arguments[1].column.get())) + { + auto res_col = ColumnString::create(); + res_col->reserve(input_rows_count); + + executeVector( + src_str_col->getChars(), + src_str_col->getOffsets(), + res_col->getChars(), + res_col->getOffsets(), + trim_str_col->getChars(), + trim_str_col->getOffsets()); + return std::move(res_col); + } - executeVector(src_str_col->getChars(), src_str_col->getOffsets(), res_col->getChars(), res_col->getOffsets(), trim_str); - return std::move(res_col); + throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Second argument of function {} must be String or Const String", getName()); } private: @@ -159,6 +172,43 @@ namespace } } + void executeVector( + const ColumnString::Chars & data, + const ColumnString::Offsets & offsets, + ColumnString::Chars & res_data, + ColumnString::Offsets & res_offsets, + const ColumnString::Chars & trim_data, + const ColumnString::Offsets & trim_offsets) const + { + res_data.reserve_exact(data.size()); + + size_t rows = offsets.size(); + res_offsets.resize_exact(rows); + + size_t prev_offset = 0; + size_t prev_trim_str_offset = 0; + size_t res_offset = 0; + + const UInt8 * start; + size_t length; + + for (size_t i = 0; i < rows; ++i) + { + std::unordered_set trim_set( + &trim_data[prev_trim_str_offset], &trim_data[prev_trim_str_offset] + trim_offsets[i] - prev_trim_str_offset - 1); + + trim(reinterpret_cast(&data[prev_offset]), offsets[i] - prev_offset - 1, start, length, trim_set); + res_data.resize_exact(res_data.size() + length + 1); + memcpySmallAllowReadWriteOverflow15(&res_data[res_offset], start, length); + res_offset += length + 1; + res_data[res_offset - 1] = '\0'; + + res_offsets[i] = res_offset; + prev_offset = offsets[i]; + prev_trim_str_offset = trim_offsets[i]; + } + } + void trim(const UInt8 * data, size_t size, const UInt8 *& res_data, size_t & res_size, const std::unordered_set & trim_set) const {