Skip to content

Commit

Permalink
[GLUTEN-6989][CH] Support RTrim with const source column
Browse files Browse the repository at this point in the history
  • Loading branch information
lwz9103 committed Aug 23, 2024
1 parent 860c9c3 commit b1bc696
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 96 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,44 @@ class GlutenClickhouseStringFunctionsSuite extends GlutenClickHouseWholeStageTra
}
}

test("GLUTEN-6989: rtrim support source column const") {
withTable("trim") {
sql("create table trim(trim_col String, src_col String) using parquet")
sql("""
|insert into trim values
| ('aba', 'a'),('bba', 'b'),('abcdef', 'abcd'),
| (null, '123'),('123', null), ('', 'aaa'), ('bbb', '')
|""".stripMargin)

val sql0 = "select rtrim('aba', 'a') from trim order by src_col"
val sql1 = "select rtrim(trim_col, src_col) from trim order by src_col"
val sql2 = "select rtrim(trim_col, 'NSSS') from trim order by src_col"
val sql3 = "select rtrim(trim_col, '') from trim order by src_col"
val sql4 = "select rtrim('', 'AAA') from trim order by src_col"
val sql5 = "select rtrim('', src_col) from trim order by src_col"
val sql6 = "select rtrim('ttt', src_col) from trim order by src_col"

runQueryAndCompare(sql0) { _ => }
runQueryAndCompare(sql1) { _ => }
runQueryAndCompare(sql2) { _ => }
runQueryAndCompare(sql3) { _ => }
runQueryAndCompare(sql4) { _ => }
runQueryAndCompare(sql5) { _ => }
runQueryAndCompare(sql6) { _ => }

// test other trim functions
val sql7 = "SELECT trim(LEADING trim_col FROM src_col) from trim"
val sql8 = "SELECT trim(LEADING trim_col FROM 'NSB') from trim"
val sql9 = "SELECT trim(TRAILING trim_col FROM src_col) from trim"
val sql10 = "SELECT trim(TRAILING trim_col FROM '') from trim"
runQueryAndCompare(sql7) { _ => }
runQueryAndCompare(sql8) { _ => }
runQueryAndCompare(sql9) { _ => }
runQueryAndCompare(sql10) { _ => }

}
}

test("GLUTEN-5897: fix regexp_extract with bracket") {
withTable("regexp_extract_bracket") {
sql("create table regexp_extract_bracket(a String) using parquet")
Expand Down
172 changes: 76 additions & 96 deletions cpp-ch/local-engine/Functions/SparkFunctionTrim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,127 +104,107 @@ namespace
}

ColumnPtr
executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & /*result_type*/, size_t input_rows_count) const override
executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override
{
const ColumnString * src_str_col = checkAndGetColumn<ColumnString>(arguments[0].column.get());
if (!src_str_col)
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "First argument of function {} must be String", getName());

const ColumnString * src_col = checkAndGetColumn<ColumnString>(arguments[0].column.get());
const ColumnConst * src_const_col = checkAndGetColumnConst<ColumnString>(arguments[0].column.get());
const ColumnString * trim_col = checkAndGetColumn<ColumnString>(arguments[1].column.get());
const ColumnConst * trim_const_col = checkAndGetColumnConst<ColumnString>(arguments[1].column.get());

String src_const_str;
String trim_const_str;
if (src_const_col)
src_const_str = src_const_col->getValue<String>();
if (trim_const_col)
trim_const_str = trim_const_col->getValue<String>();
if (trim_const_col && trim_const_str.empty()) {
return arguments[0].column->cloneResized(input_rows_count);
}

if (const auto * trim_const_str_col = checkAndGetColumnConst<ColumnString>(arguments[1].column.get()))
if (src_const_col && trim_const_col)
{
String trim_str = trim_const_str_col->getValue<String>();
if (trim_str.empty())
return src_str_col->cloneResized(input_rows_count);
const char * dst;
size_t dst_size;
std::unordered_set<char> trim_set(trim_const_str.begin(), trim_const_str.end());
trim(src_const_str.c_str(), src_const_str.size(), dst, dst_size, trim_set);
return result_type->createColumnConst(input_rows_count, String(dst, dst_size));
}

auto res_col = ColumnString::create();
res_col->reserve(input_rows_count);
executeVector(src_str_col->getChars(), src_str_col->getOffsets(), res_col->getChars(), res_col->getOffsets(), trim_str);
auto res_col = ColumnString::create();
ColumnString::Chars & res_data = res_col->getChars();
ColumnString::Offsets & res_offsets = res_col->getOffsets();
res_offsets.resize_exact(input_rows_count);

if (src_const_col)
{
res_data.reserve_exact(src_const_str.size() * input_rows_count);
for (size_t row = 0; row < input_rows_count; ++row)
{
StringRef trim_str_ref = trim_col->getDataAt(row);
std::unordered_set<char> trim_set(trim_str_ref.data, trim_str_ref.data + trim_str_ref.size);
executeRow(src_const_str.c_str(), src_const_str.size(), res_data, res_offsets, row, trim_set);
}
return std::move(res_col);
}
else if (const auto * trim_str_col = checkAndGetColumn<ColumnString>(arguments[1].column.get()))

if (trim_const_col)
{
auto res_col = ColumnString::create();
res_col->reserve(input_rows_count);

executeVector(
src_str_col->getChars(),
src_str_col->getOffsets(),
res_col->getChars(),
res_col->getOffsets(),
trim_str_col->getChars(),
trim_str_col->getOffsets());
res_data.reserve_exact(src_col->getChars().size());
std::unordered_set<char> trim_set(trim_const_str.begin(), trim_const_str.end());
for (size_t row = 0; row < input_rows_count; ++row)
{
StringRef src_str_ref = src_col->getDataAt(row);
executeRow(src_str_ref.data, src_str_ref.size, res_data, res_offsets, row, trim_set);
}
return std::move(res_col);
}

throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Second argument of function {} must be String or Const String", getName());
}

private:
void executeVector(
const ColumnString::Chars & data,
const ColumnString::Offsets & offsets,
ColumnString::Chars & res_data,
ColumnString::Offsets & res_offsets,
const String & trim_str) const
{
res_data.reserve_exact(data.size());

size_t rows = offsets.size();
res_offsets.resize_exact(rows);

size_t prev_offset = 0;
size_t res_offset = 0;

const UInt8 * start;
size_t length;
std::unordered_set<char> trim_set(trim_str.begin(), trim_str.end());
for (size_t i = 0; i < rows; ++i)
// Both columns are not constant
res_data.reserve(src_col->getChars().size());
for (size_t row = 0; row < input_rows_count; ++row)
{
trim(reinterpret_cast<const UInt8 *>(&data[prev_offset]), offsets[i] - prev_offset - 1, start, length, trim_set);
res_data.resize_exact(res_data.size() + length + 1);
memcpySmallAllowReadWriteOverflow15(&res_data[res_offset], start, length);
res_offset += length + 1;
res_data[res_offset - 1] = '\0';

res_offsets[i] = res_offset;
prev_offset = offsets[i];
StringRef src_str_ref = src_col->getDataAt(row);
StringRef trim_str_ref = trim_col->getDataAt(row);
std::unordered_set<char> trim_set(trim_str_ref.data, trim_str_ref.data + trim_str_ref.size);
executeRow(src_str_ref.data, src_str_ref.size, res_data, res_offsets, row, trim_set);
}
return std::move(res_col);
}

void executeVector(
const ColumnString::Chars & data,
const ColumnString::Offsets & offsets,
private:
void executeRow(
const char * src,
size_t src_size,
ColumnString::Chars & res_data,
ColumnString::Offsets & res_offsets,
const ColumnString::Chars & trim_data,
const ColumnString::Offsets & trim_offsets) const
size_t & row,
const std::unordered_set<char> & trim_set) const
{
res_data.reserve_exact(data.size());

size_t rows = offsets.size();
res_offsets.resize_exact(rows);

size_t prev_offset = 0;
size_t prev_trim_str_offset = 0;
size_t res_offset = 0;

const UInt8 * start;
size_t length;

for (size_t i = 0; i < rows; ++i)
{
std::unordered_set<char> trim_set(
&trim_data[prev_trim_str_offset], &trim_data[prev_trim_str_offset] + trim_offsets[i] - prev_trim_str_offset - 1);

trim(reinterpret_cast<const UInt8 *>(&data[prev_offset]), offsets[i] - prev_offset - 1, start, length, trim_set);
res_data.resize_exact(res_data.size() + length + 1);
memcpySmallAllowReadWriteOverflow15(&res_data[res_offset], start, length);
res_offset += length + 1;
res_data[res_offset - 1] = '\0';

res_offsets[i] = res_offset;
prev_offset = offsets[i];
prev_trim_str_offset = trim_offsets[i];
}
const char * dst;
size_t dst_size;
trim(src, src_size, dst, dst_size, trim_set);
size_t res_offset = row > 0 ? res_offsets[row - 1] : 0;
res_data.resize_exact(res_data.size() + dst_size + 1);
memcpySmallAllowReadWriteOverflow15(&res_data[res_offset], dst, dst_size);
res_offset += dst_size + 1;
res_data[res_offset - 1] = '\0';
res_offsets[row] = res_offset;
}

void
trim(const UInt8 * data, size_t size, const UInt8 *& res_data, size_t & res_size, const std::unordered_set<char> & trim_set) const
void trim(const char * src, size_t src_size, const char *& dst, size_t & dst_size, const std::unordered_set<char> & trim_set) const
{
const char * char_data = reinterpret_cast<const char *>(data);
const char * char_end = char_data + size;
const char * src_end = src + src_size;

if constexpr (TrimMode::trim_left)
while (char_data < char_end && trim_set.contains(*char_data))
++char_data;
while (src < src_end && trim_set.contains(*src))
++src;

if constexpr (TrimMode::trim_right)
while (char_data < char_end && trim_set.contains(*(char_end - 1)))
--char_end;
while (src < src_end && trim_set.contains(*(src_end - 1)))
--src_end;

res_data = reinterpret_cast<const UInt8 *>(char_data);
res_size = char_end - char_data;
dst = const_cast<char *>(src);
dst_size = src_end - src;
}
};

Expand Down

0 comments on commit b1bc696

Please sign in to comment.