Skip to content

Commit

Permalink
[SPARK-48288] Add source data type for connector cast expression
Browse files Browse the repository at this point in the history
Currently,
V2ExpressionBuilder will build connector.Cast expression from catalyst.Cast expression.
Catalyst cast have expression data type, but connector cast does not have it.
Since some casts are not allowed on external engine, we need to know source and target data type, since we want finer granularity to block some unsupported casts.

### What changes were proposed in this pull request?
Add source data type to connector `Cast` expression

### Why are the changes needed?
We need finer granularity to allow implementors of `SQLBuilder` to disable some unsupported casts.

### Does this PR introduce _any_ user-facing change?
Yes, visitCast function is changed, and it needs to be overriden again.

### How was this patch tested?
No tests made. Simple code change.

### Was this patch authored or co-authored using generative AI tooling?
No

Closes apache#46596 from urosstan-db/SPARK-48288-Add-source-data-type-to-connector-cast-expression.

Authored-by: Uros Stankovic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
urosstan-db authored and cloud-fan committed May 16, 2024
1 parent fa83d0f commit 4be0828
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,30 @@
@Evolving
public class Cast extends ExpressionWithToString {
private Expression expression;

/**
* Original data type of given expression
*/
private DataType expressionDataType;

/**
* Target data type, i.e. data type in which expression will be cast
*/
private DataType dataType;

@Deprecated
public Cast(Expression expression, DataType dataType) {
this(expression, null, dataType);
}

public Cast(Expression expression, DataType expressionDataType, DataType targetDataType) {
this.expression = expression;
this.dataType = dataType;
this.expressionDataType = expressionDataType;
this.dataType = targetDataType;
}

public Expression expression() { return expression; }
public DataType expressionDataType() { return expressionDataType; }
public DataType dataType() { return dataType; }

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ public String build(Expression expr) {
} else if (expr instanceof NamedReference namedReference) {
return visitNamedReference(namedReference);
} else if (expr instanceof Cast cast) {
return visitCast(build(cast.expression()), cast.dataType());
return visitCast(build(cast.expression()), cast.expressionDataType(), cast.dataType());
} else if (expr instanceof Extract extract) {
return visitExtract(extract.field(), build(extract.source()));
} else if (expr instanceof SortOrder sortOrder) {
Expand Down Expand Up @@ -230,8 +230,8 @@ protected String visitBinaryArithmetic(String name, String l, String r) {
return l + " " + name + " " + r;
}

protected String visitCast(String l, DataType dataType) {
return "CAST(" + l + " AS " + dataType.typeName() + ")";
protected String visitCast(String expr, DataType exprDataType, DataType targetDataType) {
return "CAST(" + expr + " AS " + targetDataType.typeName() + ")";
}

protected String visitAnd(String name, String l, String r) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) {
}
case Cast(child, dataType, _, evalMode)
if evalMode == EvalMode.ANSI || Cast.canUpCast(child.dataType, dataType) =>
generateExpression(child).map(v => new V2Cast(v, dataType))
generateExpression(child).map(v => new V2Cast(v, child.dataType, dataType))
case AggregateExpression(aggregateFunction, Complete, isDistinct, None, _) =>
generateAggregateFunc(aggregateFunction, isDistinct)
case Abs(_, true) => generateExpressionWithName("ABS", expr, isPredicate)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -391,10 +391,10 @@ abstract class JdbcDialect extends Serializable with Logging {
quoteIdentifier(namedRef.fieldNames.head)
}

override def visitCast(l: String, dataType: DataType): String = {
override def visitCast(expr: String, exprDataType: DataType, dataType: DataType): String = {
val databaseTypeDefinition =
getJDBCType(dataType).map(_.databaseTypeDefinition).getOrElse(dataType.typeName)
s"CAST($l AS $databaseTypeDefinition)"
s"CAST($expr AS $databaseTypeDefinition)"
}

override def visitSQLFunction(funcName: String, inputs: Array[String]): String = {
Expand Down

0 comments on commit 4be0828

Please sign in to comment.