From 676769cde4a15c624e74c1eba0f59bb1f7561faa Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 11 Nov 2024 13:49:09 -0700 Subject: [PATCH] chore: Refactor binary and math expression serde code (#1069) * refactor binary expr * refactor math expr --- .../apache/comet/serde/QueryPlanSerde.scala | 328 +++++++++--------- 1 file changed, 164 insertions(+), 164 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index c47c2b6a5..2a86c5c36 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -932,26 +932,26 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim handleCast(child, inputs, dt, timeZoneId, evalMode(c)) case add @ Add(left, right, _) if supportedDataType(left.dataType) => - createMathExpression(left, right, inputs, add.dataType, getFailOnError(add)).map { - expr => - ExprOuterClass.Expr - .newBuilder() - .setAdd(expr) - .build() - } + createMathExpression( + left, + right, + inputs, + add.dataType, + getFailOnError(add), + (builder, mathExpr) => builder.setAdd(mathExpr)) case add @ Add(left, _, _) if !supportedDataType(left.dataType) => withInfo(add, s"Unsupported datatype ${left.dataType}") None case sub @ Subtract(left, right, _) if supportedDataType(left.dataType) => - createMathExpression(left, right, inputs, sub.dataType, getFailOnError(sub)).map { - expr => - ExprOuterClass.Expr - .newBuilder() - .setSubtract(expr) - .build() - } + createMathExpression( + left, + right, + inputs, + sub.dataType, + getFailOnError(sub), + (builder, mathExpr) => builder.setSubtract(mathExpr)) case sub @ Subtract(left, _, _) if !supportedDataType(left.dataType) => withInfo(sub, s"Unsupported datatype ${left.dataType}") @@ -959,13 +959,13 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim case mul @ Multiply(left, right, _) if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) => - createMathExpression(left, right, inputs, mul.dataType, getFailOnError(mul)).map { - expr => - ExprOuterClass.Expr - .newBuilder() - .setMultiply(expr) - .build() - } + createMathExpression( + left, + right, + inputs, + mul.dataType, + getFailOnError(mul), + (builder, mathExpr) => builder.setMultiply(mathExpr)) case mul @ Multiply(left, _, _) => if (!supportedDataType(left.dataType)) { @@ -983,13 +983,13 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim // For now, use NullIf to swap zeros with nulls. val rightExpr = nullIfWhenPrimitive(right) - createMathExpression(left, rightExpr, inputs, div.dataType, getFailOnError(div)).map { - expr => - ExprOuterClass.Expr - .newBuilder() - .setDivide(expr) - .build() - } + createMathExpression( + left, + rightExpr, + inputs, + div.dataType, + getFailOnError(div), + (builder, mathExpr) => builder.setDivide(mathExpr)) case div @ Divide(left, _, _) => if (!supportedDataType(left.dataType)) { @@ -1004,13 +1004,13 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) => val rightExpr = nullIfWhenPrimitive(right) - createMathExpression(left, rightExpr, inputs, rem.dataType, getFailOnError(rem)).map { - expr => - ExprOuterClass.Expr - .newBuilder() - .setRemainder(expr) - .build() - } + createMathExpression( + left, + rightExpr, + inputs, + rem.dataType, + getFailOnError(rem), + (builder, mathExpr) => builder.setRemainder(mathExpr)) case rem @ Remainder(left, _, _) => if (!supportedDataType(left.dataType)) { @@ -1022,68 +1022,60 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim None case EqualTo(left, right) => - createBinaryExpr(left, right, inputs).map { builder => - ExprOuterClass.Expr - .newBuilder() - .setEq(builder) - .build() - } + createBinaryExpr( + left, + right, + inputs, + (builder, binaryExpr) => builder.setEq(binaryExpr)) case Not(EqualTo(left, right)) => - createBinaryExpr(left, right, inputs).map { builder => - ExprOuterClass.Expr - .newBuilder() - .setNeq(builder) - .build() - } + createBinaryExpr( + left, + right, + inputs, + (builder, binaryExpr) => builder.setNeq(binaryExpr)) case EqualNullSafe(left, right) => - createBinaryExpr(left, right, inputs).map { builder => - ExprOuterClass.Expr - .newBuilder() - .setEqNullSafe(builder) - .build() - } + createBinaryExpr( + left, + right, + inputs, + (builder, binaryExpr) => builder.setEqNullSafe(binaryExpr)) case Not(EqualNullSafe(left, right)) => - createBinaryExpr(left, right, inputs).map { builder => - ExprOuterClass.Expr - .newBuilder() - .setNeqNullSafe(builder) - .build() - } + createBinaryExpr( + left, + right, + inputs, + (builder, binaryExpr) => builder.setNeqNullSafe(binaryExpr)) case GreaterThan(left, right) => - createBinaryExpr(left, right, inputs).map { builder => - ExprOuterClass.Expr - .newBuilder() - .setGt(builder) - .build() - } + createBinaryExpr( + left, + right, + inputs, + (builder, binaryExpr) => builder.setGt(binaryExpr)) case GreaterThanOrEqual(left, right) => - createBinaryExpr(left, right, inputs).map { builder => - ExprOuterClass.Expr - .newBuilder() - .setGtEq(builder) - .build() - } + createBinaryExpr( + left, + right, + inputs, + (builder, binaryExpr) => builder.setGtEq(binaryExpr)) case LessThan(left, right) => - createBinaryExpr(left, right, inputs).map { builder => - ExprOuterClass.Expr - .newBuilder() - .setLt(builder) - .build() - } + createBinaryExpr( + left, + right, + inputs, + (builder, binaryExpr) => builder.setLt(binaryExpr)) case LessThanOrEqual(left, right) => - createBinaryExpr(left, right, inputs).map { builder => - ExprOuterClass.Expr - .newBuilder() - .setLtEq(builder) - .build() - } + createBinaryExpr( + left, + right, + inputs, + (builder, binaryExpr) => builder.setLtEq(binaryExpr)) case Literal(value, dataType) if supportedDataType(dataType, allowStruct = value == null) => @@ -1220,12 +1212,11 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim case Like(left, right, escapeChar) => if (escapeChar == '\\') { - createBinaryExpr(left, right, inputs).map { builder => - ExprOuterClass.Expr - .newBuilder() - .setLike(builder) - .build() - } + createBinaryExpr( + left, + right, + inputs, + (builder, binaryExpr) => builder.setLike(binaryExpr)) } else { // TODO custom escape char withInfo(expr, s"custom escape character $escapeChar not supported in LIKE") @@ -1250,36 +1241,32 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim return None } - createBinaryExpr(left, right, inputs).map { builder => - ExprOuterClass.Expr - .newBuilder() - .setRlike(builder) - .build() - } + createBinaryExpr( + left, + right, + inputs, + (builder, binaryExpr) => builder.setRlike(binaryExpr)) case StartsWith(left, right) => - createBinaryExpr(left, right, inputs).map { builder => - ExprOuterClass.Expr - .newBuilder() - .setStartsWith(builder) - .build() - } + createBinaryExpr( + left, + right, + inputs, + (builder, binaryExpr) => builder.setStartsWith(binaryExpr)) case EndsWith(left, right) => - createBinaryExpr(left, right, inputs).map { builder => - ExprOuterClass.Expr - .newBuilder() - .setEndsWith(builder) - .build() - } + createBinaryExpr( + left, + right, + inputs, + (builder, binaryExpr) => builder.setEndsWith(binaryExpr)) case Contains(left, right) => - createBinaryExpr(left, right, inputs).map { builder => - ExprOuterClass.Expr - .newBuilder() - .setContains(builder) - .build() - } + createBinaryExpr( + left, + right, + inputs, + (builder, binaryExpr) => builder.setContains(binaryExpr)) case StringSpace(child) => createUnaryExpr( @@ -1460,20 +1447,18 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim } case And(left, right) => - createBinaryExpr(left, right, inputs).map { builder => - ExprOuterClass.Expr - .newBuilder() - .setAnd(builder) - .build() - } + createBinaryExpr( + left, + right, + inputs, + (builder, binaryExpr) => builder.setAnd(binaryExpr)) case Or(left, right) => - createBinaryExpr(left, right, inputs).map { builder => - ExprOuterClass.Expr - .newBuilder() - .setOr(builder) - .build() - } + createBinaryExpr( + left, + right, + inputs, + (builder, binaryExpr) => builder.setOr(binaryExpr)) case UnaryExpression(child) if expr.prettyName == "promote_precision" => // `UnaryExpression` includes `PromotePrecision` for Spark 3.3 @@ -1910,31 +1895,28 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim } case BitwiseAnd(left, right) => - createBinaryExpr(left, right, inputs).map { builder => - ExprOuterClass.Expr - .newBuilder() - .setBitwiseAnd(builder) - .build() - } + createBinaryExpr( + left, + right, + inputs, + (builder, binaryExpr) => builder.setBitwiseAnd(binaryExpr)) case BitwiseNot(child) => createUnaryExpr(child, inputs, (builder, unaryExpr) => builder.setBitwiseNot(unaryExpr)) case BitwiseOr(left, right) => - createBinaryExpr(left, right, inputs).map { builder => - ExprOuterClass.Expr - .newBuilder() - .setBitwiseOr(builder) - .build() - } + createBinaryExpr( + left, + right, + inputs, + (builder, binaryExpr) => builder.setBitwiseOr(binaryExpr)) case BitwiseXor(left, right) => - createBinaryExpr(left, right, inputs).map { builder => - ExprOuterClass.Expr - .newBuilder() - .setBitwiseXor(builder) - .build() - } + createBinaryExpr( + left, + right, + inputs, + (builder, binaryExpr) => builder.setBitwiseXor(binaryExpr)) case ShiftRight(left, right) => // DataFusion bitwise shift right expression requires @@ -1945,12 +1927,11 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim right } - createBinaryExpr(left, rightExpression, inputs).map { builder => - ExprOuterClass.Expr - .newBuilder() - .setBitwiseShiftRight(builder) - .build() - } + createBinaryExpr( + left, + rightExpression, + inputs, + (builder, binaryExpr) => builder.setBitwiseShiftRight(binaryExpr)) case ShiftLeft(left, right) => // DataFusion bitwise shift right expression requires @@ -1961,13 +1942,11 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim right } - createBinaryExpr(left, rightExpression, inputs).map { builder => - ExprOuterClass.Expr - .newBuilder() - .setBitwiseShiftLeft(builder) - .build() - } - + createBinaryExpr( + left, + rightExpression, + inputs, + (builder, binaryExpr) => builder.setBitwiseShiftLeft(binaryExpr)) case In(value, list) => in(expr, value, list, inputs, false) @@ -2307,16 +2286,27 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim def createBinaryExpr( left: Expression, right: Expression, - inputs: Seq[Attribute]): Option[ExprOuterClass.BinaryExpr] = { + inputs: Seq[Attribute], + f: ( + ExprOuterClass.Expr.Builder, + ExprOuterClass.BinaryExpr) => ExprOuterClass.Expr.Builder) + : Option[ExprOuterClass.Expr] = { val leftExpr = exprToProtoInternal(left, inputs) val rightExpr = exprToProtoInternal(right, inputs) if (leftExpr.isDefined && rightExpr.isDefined) { + // create the generic BinaryExpr message + val inner = ExprOuterClass.BinaryExpr + .newBuilder() + .setLeft(leftExpr.get) + .setRight(rightExpr.get) + .build() + // call the user-supplied function to wrap BinaryExpr in a top-level Expr + // such as Expr.And or Expr.Or Some( - ExprOuterClass.BinaryExpr - .newBuilder() - .setLeft(leftExpr.get) - .setRight(rightExpr.get) - .build()) + f( + ExprOuterClass.Expr + .newBuilder(), + inner).build()) } else { withInfo(expr, left, right) None @@ -2328,11 +2318,14 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim right: Expression, inputs: Seq[Attribute], dataType: DataType, - failOnError: Boolean): Option[ExprOuterClass.MathExpr] = { + failOnError: Boolean, + f: (ExprOuterClass.Expr.Builder, ExprOuterClass.MathExpr) => ExprOuterClass.Expr.Builder) + : Option[ExprOuterClass.Expr] = { val leftExpr = exprToProtoInternal(left, inputs) val rightExpr = exprToProtoInternal(right, inputs) if (leftExpr.isDefined && rightExpr.isDefined) { + // create the generic MathExpr message val builder = ExprOuterClass.MathExpr.newBuilder() builder.setLeft(leftExpr.get) builder.setRight(rightExpr.get) @@ -2340,7 +2333,14 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim serializeDataType(dataType).foreach { t => builder.setReturnType(t) } - Some(builder.build()) + val inner = builder.build() + // call the user-supplied function to wrap MathExpr in a top-level Expr + // such as Expr.Add or Expr.Divide + Some( + f( + ExprOuterClass.Expr + .newBuilder(), + inner).build()) } else { withInfo(expr, left, right) None