From 8dae5f4952b1dc84a5a71fac595d9401ed2f641f Mon Sep 17 00:00:00 2001 From: lwz9103 Date: Fri, 23 Aug 2024 14:29:56 +0800 Subject: [PATCH] [GLUTEN-6989][CH] Support RTrim with const source column --- ...GlutenClickhouseStringFunctionsSuite.scala | 38 ++++ .../Functions/SparkFunctionTrim.cpp | 172 ++++++++---------- 2 files changed, 114 insertions(+), 96 deletions(-) diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseStringFunctionsSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseStringFunctionsSuite.scala index 98c0c2b35f202..88a5439be42eb 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseStringFunctionsSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseStringFunctionsSuite.scala @@ -49,6 +49,44 @@ class GlutenClickhouseStringFunctionsSuite extends GlutenClickHouseWholeStageTra } } + test("GLUTEN-6989: rtrim support source column const") { + withTable("trim") { + sql("create table trim(trim_col String, src_col String) using parquet") + sql(""" + |insert into trim values + | ('aba', 'a'),('bba', 'b'),('abcdef', 'abcd'), + | (null, '123'),('123', null), ('', 'aaa'), ('bbb', '') + |""".stripMargin) + + val sql0 = "select rtrim('aba', 'a') from trim order by src_col" + val sql1 = "select rtrim(trim_col, src_col) from trim order by src_col" + val sql2 = "select rtrim(trim_col, 'NSSS') from trim order by src_col" + val sql3 = "select rtrim(trim_col, '') from trim order by src_col" + val sql4 = "select rtrim('', 'AAA') from trim order by src_col" + val sql5 = "select rtrim('', src_col) from trim order by src_col" + val sql6 = "select rtrim('ttt', src_col) from trim order by src_col" + + runQueryAndCompare(sql0) { _ => } + runQueryAndCompare(sql1) { _ => } + runQueryAndCompare(sql2) { _ => } + runQueryAndCompare(sql3) { _ => } + runQueryAndCompare(sql4) { _ => } + runQueryAndCompare(sql5) { _ => } + runQueryAndCompare(sql6) { _ => } + + // test other trim functions + val sql7 = "SELECT trim(LEADING trim_col FROM src_col) from trim" + val sql8 = "SELECT trim(LEADING trim_col FROM 'NSB') from trim" + val sql9 = "SELECT trim(TRAILING trim_col FROM src_col) from trim" + val sql10 = "SELECT trim(TRAILING trim_col FROM '') from trim" + runQueryAndCompare(sql7) { _ => } + runQueryAndCompare(sql8) { _ => } + runQueryAndCompare(sql9) { _ => } + runQueryAndCompare(sql10) { _ => } + + } + } + test("GLUTEN-5897: fix regexp_extract with bracket") { withTable("regexp_extract_bracket") { sql("create table regexp_extract_bracket(a String) using parquet") diff --git a/cpp-ch/local-engine/Functions/SparkFunctionTrim.cpp b/cpp-ch/local-engine/Functions/SparkFunctionTrim.cpp index 88ed3f635672d..9f7a7647fcb79 100644 --- a/cpp-ch/local-engine/Functions/SparkFunctionTrim.cpp +++ b/cpp-ch/local-engine/Functions/SparkFunctionTrim.cpp @@ -104,127 +104,107 @@ namespace } ColumnPtr - executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & /*result_type*/, size_t input_rows_count) const override + executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override { - const ColumnString * src_str_col = checkAndGetColumn(arguments[0].column.get()); - if (!src_str_col) - throw Exception(ErrorCodes::ILLEGAL_COLUMN, "First argument of function {} must be String", getName()); - + const ColumnString * src_col = checkAndGetColumn(arguments[0].column.get()); + const ColumnConst * src_const_col = checkAndGetColumnConst(arguments[0].column.get()); + const ColumnString * trim_col = checkAndGetColumn(arguments[1].column.get()); + const ColumnConst * trim_const_col = checkAndGetColumnConst(arguments[1].column.get()); + + String src_const_str; + String trim_const_str; + if (src_const_col) + src_const_str = src_const_col->getValue(); + if (trim_const_col) + trim_const_str = trim_const_col->getValue(); + if (trim_const_col && trim_const_str.empty()) { + return arguments[0].column->cloneResized(input_rows_count); + } - if (const auto * trim_const_str_col = checkAndGetColumnConst(arguments[1].column.get())) + if (src_const_col && trim_const_col) { - String trim_str = trim_const_str_col->getValue(); - if (trim_str.empty()) - return src_str_col->cloneResized(input_rows_count); + const char * dst; + size_t dst_size; + std::unordered_set trim_set(trim_const_str.begin(), trim_const_str.end()); + trim(src_const_str.c_str(), src_const_str.size(), dst, dst_size, trim_set); + return result_type->createColumnConst(input_rows_count, String(dst, dst_size)); + } - auto res_col = ColumnString::create(); - res_col->reserve(input_rows_count); - executeVector(src_str_col->getChars(), src_str_col->getOffsets(), res_col->getChars(), res_col->getOffsets(), trim_str); + auto res_col = ColumnString::create(); + ColumnString::Chars & res_data = res_col->getChars(); + ColumnString::Offsets & res_offsets = res_col->getOffsets(); + res_offsets.resize_exact(input_rows_count); + + if (src_const_col) + { + res_data.reserve_exact(src_const_str.size() * input_rows_count); + for (size_t row = 0; row < input_rows_count; ++row) + { + StringRef trim_str_ref = trim_col->getDataAt(row); + std::unordered_set trim_set(trim_str_ref.data, trim_str_ref.data + trim_str_ref.size); + executeRow(src_const_str.c_str(), src_const_str.size(), res_data, res_offsets, row, trim_set); + } return std::move(res_col); } - else if (const auto * trim_str_col = checkAndGetColumn(arguments[1].column.get())) + + if (trim_const_col) { - auto res_col = ColumnString::create(); - res_col->reserve(input_rows_count); - - executeVector( - src_str_col->getChars(), - src_str_col->getOffsets(), - res_col->getChars(), - res_col->getOffsets(), - trim_str_col->getChars(), - trim_str_col->getOffsets()); + res_data.reserve_exact(src_col->getChars().size()); + std::unordered_set trim_set(trim_const_str.begin(), trim_const_str.end()); + for (size_t row = 0; row < input_rows_count; ++row) + { + StringRef src_str_ref = src_col->getDataAt(row); + executeRow(src_str_ref.data, src_str_ref.size, res_data, res_offsets, row, trim_set); + } return std::move(res_col); } - throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Second argument of function {} must be String or Const String", getName()); - } - - private: - void executeVector( - const ColumnString::Chars & data, - const ColumnString::Offsets & offsets, - ColumnString::Chars & res_data, - ColumnString::Offsets & res_offsets, - const String & trim_str) const - { - res_data.reserve_exact(data.size()); - - size_t rows = offsets.size(); - res_offsets.resize_exact(rows); - - size_t prev_offset = 0; - size_t res_offset = 0; - - const UInt8 * start; - size_t length; - std::unordered_set trim_set(trim_str.begin(), trim_str.end()); - for (size_t i = 0; i < rows; ++i) + // Both columns are not constant + res_data.reserve(src_col->getChars().size()); + for (size_t row = 0; row < input_rows_count; ++row) { - trim(reinterpret_cast(&data[prev_offset]), offsets[i] - prev_offset - 1, start, length, trim_set); - res_data.resize_exact(res_data.size() + length + 1); - memcpySmallAllowReadWriteOverflow15(&res_data[res_offset], start, length); - res_offset += length + 1; - res_data[res_offset - 1] = '\0'; - - res_offsets[i] = res_offset; - prev_offset = offsets[i]; + StringRef src_str_ref = src_col->getDataAt(row); + StringRef trim_str_ref = trim_col->getDataAt(row); + std::unordered_set trim_set(trim_str_ref.data, trim_str_ref.data + trim_str_ref.size); + executeRow(src_str_ref.data, src_str_ref.size, res_data, res_offsets, row, trim_set); } + return std::move(res_col); } - void executeVector( - const ColumnString::Chars & data, - const ColumnString::Offsets & offsets, + private: + void executeRow( + const char * src, + size_t src_size, ColumnString::Chars & res_data, ColumnString::Offsets & res_offsets, - const ColumnString::Chars & trim_data, - const ColumnString::Offsets & trim_offsets) const + size_t & row, + const std::unordered_set & trim_set) const { - res_data.reserve_exact(data.size()); - - size_t rows = offsets.size(); - res_offsets.resize_exact(rows); - - size_t prev_offset = 0; - size_t prev_trim_str_offset = 0; - size_t res_offset = 0; - - const UInt8 * start; - size_t length; - - for (size_t i = 0; i < rows; ++i) - { - std::unordered_set trim_set( - &trim_data[prev_trim_str_offset], &trim_data[prev_trim_str_offset] + trim_offsets[i] - prev_trim_str_offset - 1); - - trim(reinterpret_cast(&data[prev_offset]), offsets[i] - prev_offset - 1, start, length, trim_set); - res_data.resize_exact(res_data.size() + length + 1); - memcpySmallAllowReadWriteOverflow15(&res_data[res_offset], start, length); - res_offset += length + 1; - res_data[res_offset - 1] = '\0'; - - res_offsets[i] = res_offset; - prev_offset = offsets[i]; - prev_trim_str_offset = trim_offsets[i]; - } + const char * dst; + size_t dst_size; + trim(src, src_size, dst, dst_size, trim_set); + size_t res_offset = row > 0 ? res_offsets[row - 1] : 0; + res_data.resize_exact(res_data.size() + dst_size + 1); + memcpySmallAllowReadWriteOverflow15(&res_data[res_offset], dst, dst_size); + res_offset += dst_size + 1; + res_data[res_offset - 1] = '\0'; + res_offsets[row] = res_offset; } - void - trim(const UInt8 * data, size_t size, const UInt8 *& res_data, size_t & res_size, const std::unordered_set & trim_set) const + void trim(const char * src, size_t src_size, const char *& dst, size_t & dst_size, const std::unordered_set & trim_set) const { - const char * char_data = reinterpret_cast(data); - const char * char_end = char_data + size; + const char * src_end = src + src_size; if constexpr (TrimMode::trim_left) - while (char_data < char_end && trim_set.contains(*char_data)) - ++char_data; + while (src < src_end && trim_set.contains(*src)) + ++src; if constexpr (TrimMode::trim_right) - while (char_data < char_end && trim_set.contains(*(char_end - 1))) - --char_end; + while (src < src_end && trim_set.contains(*(src_end - 1))) + --src_end; - res_data = reinterpret_cast(char_data); - res_size = char_end - char_data; + dst = const_cast(src); + dst_size = src_end - src; } };