diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Cast.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Cast.java index 291f94ec75a8a..25d0c0466aca4 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Cast.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Cast.java @@ -29,14 +29,30 @@ @Evolving public class Cast extends ExpressionWithToString { private Expression expression; + + /** + * Original data type of given expression + */ + private DataType expressionDataType; + + /** + * Target data type, i.e. data type in which expression will be cast + */ private DataType dataType; + @Deprecated public Cast(Expression expression, DataType dataType) { + this(expression, null, dataType); + } + + public Cast(Expression expression, DataType expressionDataType, DataType targetDataType) { this.expression = expression; - this.dataType = dataType; + this.expressionDataType = expressionDataType; + this.dataType = targetDataType; } public Expression expression() { return expression; } + public DataType expressionDataType() { return expressionDataType; } public DataType dataType() { return dataType; } @Override diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java index 11f4389245d9a..14e2112b7201a 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java @@ -78,7 +78,7 @@ public String build(Expression expr) { } else if (expr instanceof NamedReference namedReference) { return visitNamedReference(namedReference); } else if (expr instanceof Cast cast) { - return visitCast(build(cast.expression()), cast.dataType()); + return visitCast(build(cast.expression()), cast.expressionDataType(), cast.dataType()); } else if (expr instanceof Extract extract) { return visitExtract(extract.field(), build(extract.source())); } else if (expr instanceof SortOrder sortOrder) { @@ -230,8 +230,8 @@ protected String visitBinaryArithmetic(String name, String l, String r) { return l + " " + name + " " + r; } - protected String visitCast(String l, DataType dataType) { - return "CAST(" + l + " AS " + dataType.typeName() + ")"; + protected String visitCast(String expr, DataType exprDataType, DataType targetDataType) { + return "CAST(" + expr + " AS " + targetDataType.typeName() + ")"; } protected String visitAnd(String name, String l, String r) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala index 398f21e01b806..ca04991b50fc2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -95,7 +95,7 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { } case Cast(child, dataType, _, evalMode) if evalMode == EvalMode.ANSI || Cast.canUpCast(child.dataType, dataType) => - generateExpression(child).map(v => new V2Cast(v, dataType)) + generateExpression(child).map(v => new V2Cast(v, child.dataType, dataType)) case AggregateExpression(aggregateFunction, Complete, isDistinct, None, _) => generateAggregateFunc(aggregateFunction, isDistinct) case Abs(_, true) => generateExpressionWithName("ABS", expr, isPredicate) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 5f69d18cad756..4ebe73292f11e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -391,10 +391,10 @@ abstract class JdbcDialect extends Serializable with Logging { quoteIdentifier(namedRef.fieldNames.head) } - override def visitCast(l: String, dataType: DataType): String = { + override def visitCast(expr: String, exprDataType: DataType, dataType: DataType): String = { val databaseTypeDefinition = getJDBCType(dataType).map(_.databaseTypeDefinition).getOrElse(dataType.typeName) - s"CAST($l AS $databaseTypeDefinition)" + s"CAST($expr AS $databaseTypeDefinition)" } override def visitSQLFunction(funcName: String, inputs: Array[String]): String = {