Skip to content

Commit

Permalink
Change approach to use QueryPlanSerde instead of modifying the planner
Browse files Browse the repository at this point in the history
  • Loading branch information
lithium323 committed Jul 13, 2024
1 parent 3f6da50 commit 8c802c4
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 44 deletions.
42 changes: 1 addition & 41 deletions native/core/src/execution/datafusion/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1654,18 +1654,7 @@ impl PhysicalPlanner {
data_type,
));

let is_unary = expr.args.len() == 1;
let unary_arg = expr
.args
.first()
.and_then(|x| self.create_expr(x, input_schema.clone()).ok());

let log_functions = ["ln", "log2", "log10"];

match fun_name.as_str() {
log if log_functions.contains(&log) && is_unary => spark_log(scalar_expr, unary_arg),
_ => Ok(scalar_expr),
}
Ok(scalar_expr)
}
}

Expand Down Expand Up @@ -1825,35 +1814,6 @@ fn rewrite_physical_expr(
Ok(expr.rewrite(&mut rewriter).data()?)
}


/// Modifies the physical expression for `log` functions so that it is defined as null on numbers
/// less than or equal to 0. This matches Spark and Hive behavior, where values less than or
/// equal to 0 eval to null, instead of NaN or -Infinity
/// Wraps a physical `log`-family expression so that inputs less than or equal
/// to 0 evaluate to null rather than NaN or -Infinity. This matches Spark and
/// Hive semantics for `ln`, `log2`, and `log10`.
fn spark_log(
    datafusion_expr: Arc<dyn PhysicalExpr>,
    arg: Option<Arc<dyn PhysicalExpr>>,
) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
    // If the first argument could not be resolved, leave the expression as-is.
    let Some(first_arg) = arg else {
        return Ok(datafusion_expr);
    };

    // Predicate: first_arg <= 0.0
    let non_positive = Arc::new(BinaryExpr::new(
        first_arg,
        DataFusionOperator::LtEq,
        Arc::new(DataFusionLiteral::new(ScalarValue::Float64(Some(0.0)))),
    ));
    let null_literal = Arc::new(DataFusionLiteral::new(ScalarValue::Float64(None)));

    // Non-positive inputs eval to null in Hive/Spark, instead of NaN or -Infinity.
    Ok(Arc::new(IfExpr::new(non_positive, null_literal, datafusion_expr)))
}


fn from_protobuf_eval_mode(value: i32) -> Result<EvalMode, prost::DecodeError> {
match spark_expression::EvalMode::try_from(value)? {
spark_expression::EvalMode::Legacy => Ok(EvalMode::Legacy),
Expand Down
14 changes: 11 additions & 3 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1703,18 +1703,21 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
optExprWithInfo(optExpr, expr, child)
}

/// The expression for `log` functions is defined as null on numbers less than or equal
/// to 0. This matches Spark and Hive behavior, where non positive values eval to null
/// instead of NaN or -Infinity
case Log(child) =>
val childExpr = exprToProtoInternal(child, inputs)
val childExpr = exprToProtoInternal(nullIfNegative(child), inputs)
val optExpr = scalarExprToProto("ln", childExpr)
optExprWithInfo(optExpr, expr, child)

case Log10(child) =>
val childExpr = exprToProtoInternal(child, inputs)
val childExpr = exprToProtoInternal(nullIfNegative(child), inputs)
val optExpr = scalarExprToProto("log10", childExpr)
optExprWithInfo(optExpr, expr, child)

case Log2(child) =>
val childExpr = exprToProtoInternal(child, inputs)
val childExpr = exprToProtoInternal(nullIfNegative(child), inputs)
val optExpr = scalarExprToProto("log2", childExpr)
optExprWithInfo(optExpr, expr, child)

Expand Down Expand Up @@ -2393,6 +2396,11 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
expression
}

/// Wraps `expression` so that values less than or equal to zero evaluate to
/// null, matching Hive/Spark `log` semantics.
/// NOTE(review): despite the name, this also nulls out zero (uses `<=`), not
/// just negative values — confirm the name is intentional before renaming.
def nullIfNegative(expression: Expression): Expression = {
  val nullLiteral = Literal.create(null, expression.dataType)
  // Literal.default yields the zero value for the expression's data type.
  If(LessThanOrEqual(expression, Literal.default(expression.dataType)), nullLiteral, expression)
}

/**
* Returns true if given datatype is supported as a key in DataFusion sort merge join.
*/
Expand Down

0 comments on commit 8c802c4

Please sign in to comment.