From 05eb6c8cc9c1ebbe292f07624a6769dc3c22e133 Mon Sep 17 00:00:00 2001 From: lgbo Date: Wed, 16 Oct 2024 14:08:07 +0800 Subject: [PATCH] rewrite get_json_object in singular_or_list (#7551) --- .../GlutenClickhouseFunctionSuite.scala | 26 +++++++++++++++++++ .../Rewriter/ExpressionRewriter.h | 8 ++++++ 2 files changed, 34 insertions(+) diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseFunctionSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseFunctionSuite.scala index ce8761469c05..3d7b922e7b1b 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseFunctionSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseFunctionSuite.scala @@ -269,4 +269,30 @@ class GlutenClickhouseFunctionSuite extends GlutenClickHouseTPCHAbstractSuite { } } + test("GLUTEN-7550 get_json_object in IN") { + withTable("test_7550") { + sql("create table test_7550(a string) using parquet") + val insert_sql = + """ + |insert into test_7550 values('{\'a\':\'1\'}'),('{\'a\':\'2\'}'),('{\'a\':\'3\'}') + |""".stripMargin + sql(insert_sql) + compareResultsAgainstVanillaSpark( + """ + |select a, get_json_object(a, '$.a') in ('1', '2') from test_7550 + |""".stripMargin, + true, + { _ => } + ) + compareResultsAgainstVanillaSpark( + """ + |select a in ('1', '2') from test_7550 + |where get_json_object(a, '$.a') in ('1', '2') + |""".stripMargin, + true, + { _ => } + ) + } + } + } diff --git a/cpp-ch/local-engine/Rewriter/ExpressionRewriter.h b/cpp-ch/local-engine/Rewriter/ExpressionRewriter.h index ab0967eb1248..8d4fcaca4420 100644 --- a/cpp-ch/local-engine/Rewriter/ExpressionRewriter.h +++ b/cpp-ch/local-engine/Rewriter/ExpressionRewriter.h @@ -104,6 +104,10 @@ class GetJsonObjectFunctionWriter : public RelRewriter prepareOnExpression(if_then.else_()); break; } + case substrait::Expression::RexTypeCase::kSingularOrList: { + prepareOnExpression(expr.singular_or_list().value()); + break; + } case substrait::Expression::RexTypeCase::kScalarFunction: { const auto & scalar_function_pb = expr.scalar_function(); auto function_signature_name_opt = parser_context->getFunctionNameInSignature(scalar_function_pb); @@ -160,6 +164,10 @@ class GetJsonObjectFunctionWriter : public RelRewriter rewriteExpression(*if_then->mutable_else_()); break; } + case substrait::Expression::RexTypeCase::kSingularOrList: { + rewriteExpression(*expr.mutable_singular_or_list()->mutable_value()); + break; + } case substrait::Expression::RexTypeCase::kScalarFunction: { auto & scalar_function_pb = *expr.mutable_scalar_function(); if (scalar_function_pb.arguments().empty())