From 4be0828e6e6afa6d9ab67958f5ef5fbe6814252d Mon Sep 17 00:00:00 2001 From: Uros Stankovic Date: Thu, 16 May 2024 19:51:51 +0800 Subject: [PATCH] [SPARK-48288] Add source data type for connector cast expression Currently, V2ExpressionBuilder will build connector.Cast expression from catalyst.Cast expression. Catalyst cast have expression data type, but connector cast does not have it. Since some casts are not allowed on external engine, we need to know source and target data type, since we want finer granularity to block some unsupported casts. ### What changes were proposed in this pull request? Add source data type to connector `Cast` expression ### Why are the changes needed? We need finer granularity to allow implementors of `SQLBuilder` to disable some unsupported casts. ### Does this PR introduce _any_ user-facing change? Yes, visitCast function is changed, and it needs to be overriden again. ### How was this patch tested? No tests made. Simple code change. ### Was this patch authored or co-authored using generative AI tooling? No Closes #46596 from urosstan-db/SPARK-48288-Add-source-data-type-to-connector-cast-expression. Authored-by: Uros Stankovic Signed-off-by: Wenchen Fan --- .../spark/sql/connector/expressions/Cast.java | 18 +++++++++++++++++- .../connector/util/V2ExpressionSQLBuilder.java | 6 +++--- .../catalyst/util/V2ExpressionBuilder.scala | 2 +- .../apache/spark/sql/jdbc/JdbcDialects.scala | 4 ++-- 4 files changed, 23 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Cast.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Cast.java index 291f94ec75a8a..25d0c0466aca4 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Cast.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Cast.java @@ -29,14 +29,30 @@ @Evolving public class Cast extends ExpressionWithToString { private Expression expression; + + /** + * Original data type of given expression + */ + private DataType expressionDataType; + + /** + * Target data type, i.e. data type in which expression will be cast + */ private DataType dataType; + @Deprecated public Cast(Expression expression, DataType dataType) { + this(expression, null, dataType); + } + + public Cast(Expression expression, DataType expressionDataType, DataType targetDataType) { this.expression = expression; - this.dataType = dataType; + this.expressionDataType = expressionDataType; + this.dataType = targetDataType; } public Expression expression() { return expression; } + public DataType expressionDataType() { return expressionDataType; } public DataType dataType() { return dataType; } @Override diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java index 11f4389245d9a..14e2112b7201a 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java @@ -78,7 +78,7 @@ public String build(Expression expr) { } else if (expr instanceof NamedReference namedReference) { return visitNamedReference(namedReference); } else if (expr instanceof Cast cast) { - return visitCast(build(cast.expression()), cast.dataType()); + return visitCast(build(cast.expression()), cast.expressionDataType(), cast.dataType()); } else if (expr instanceof Extract extract) { return visitExtract(extract.field(), build(extract.source())); } else if (expr instanceof SortOrder sortOrder) { @@ -230,8 +230,8 @@ protected String visitBinaryArithmetic(String name, String l, String r) { return l + " " + name + " " + r; } - protected String visitCast(String l, DataType dataType) { - return "CAST(" + l + " AS " + dataType.typeName() + ")"; + protected String visitCast(String expr, DataType exprDataType, DataType targetDataType) { + return "CAST(" + expr + " AS " + targetDataType.typeName() + ")"; } protected String visitAnd(String name, String l, String r) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala index 398f21e01b806..ca04991b50fc2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -95,7 +95,7 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { } case Cast(child, dataType, _, evalMode) if evalMode == EvalMode.ANSI || Cast.canUpCast(child.dataType, dataType) => - generateExpression(child).map(v => new V2Cast(v, dataType)) + generateExpression(child).map(v => new V2Cast(v, child.dataType, dataType)) case AggregateExpression(aggregateFunction, Complete, isDistinct, None, _) => generateAggregateFunc(aggregateFunction, isDistinct) case Abs(_, true) => generateExpressionWithName("ABS", expr, isPredicate) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 5f69d18cad756..4ebe73292f11e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -391,10 +391,10 @@ abstract class JdbcDialect extends Serializable with Logging { quoteIdentifier(namedRef.fieldNames.head) } - override def visitCast(l: String, dataType: DataType): String = { + override def visitCast(expr: String, exprDataType: DataType, dataType: DataType): String = { val databaseTypeDefinition = getJDBCType(dataType).map(_.databaseTypeDefinition).getOrElse(dataType.typeName) - s"CAST($l AS $databaseTypeDefinition)" + s"CAST($expr AS $databaseTypeDefinition)" } override def visitSQLFunction(funcName: String, inputs: Array[String]): String = {