From f253d45d3db5d643a7194ed7df29051d9ecd4151 Mon Sep 17 00:00:00 2001 From: Chris Twiner Date: Thu, 14 Mar 2024 17:40:15 +0100 Subject: [PATCH] #787 #803 - rc4 usage and fix udf with expressionproxy --- build.sbt | 2 +- dataset/src/main/scala/frameless/functions/Udf.scala | 11 +++++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/build.sbt b/build.sbt index a0f8c590..d72fe754 100644 --- a/build.sbt +++ b/build.sbt @@ -12,7 +12,7 @@ val scalacheck = "1.17.0" val scalacheckEffect = "1.0.4" val refinedVersion = "0.11.1" val nakedFSVersion = "0.1.0" -val shimVersion = "0.0.1-RC3" +val shimVersion = "0.0.1-RC4" val Scala212 = "2.12.19" val Scala213 = "2.13.13" diff --git a/dataset/src/main/scala/frameless/functions/Udf.scala b/dataset/src/main/scala/frameless/functions/Udf.scala index 93ba7f11..7aa38937 100644 --- a/dataset/src/main/scala/frameless/functions/Udf.scala +++ b/dataset/src/main/scala/frameless/functions/Udf.scala @@ -2,7 +2,7 @@ package frameless package functions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, NonSQLExpression} +import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionProxy, LeafExpression, NonSQLExpression} import org.apache.spark.sql.catalyst.expressions.codegen._ import Block._ import org.apache.spark.sql.types.DataType @@ -132,6 +132,13 @@ case class FramelessUdf[T, R]( def dataType: DataType = rencoder.catalystRepr + // #803 - SPARK-41991 fixes this for the most part, this is a belts and braces approach + def nonProxy(child: Expression): Expression = + child match { + case p: ExpressionProxy => nonProxy(p.child) + case _ => child + } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { ctx.references += this @@ -145,7 +152,7 @@ case class FramelessUdf[T, R]( val (argsCode, funcArguments) = encoders.zip(children).map { case (encoder, child) => - val eval = child.genCode(ctx) + val eval = nonProxy(child).genCode(ctx) val codeTpe = CodeGenerator.boxedType(encoder.jvmRepr) val argTerm = ctx.freshName("arg") val convert = s"${eval.code}\n$codeTpe $argTerm = ${eval.isNull} ? (($codeTpe)null) : (($codeTpe)(${eval.value}));"