Skip to content

Commit

Permalink
typelevel#787 typelevel#803 - rc4 usage and fix udf with expressionproxy
Browse files Browse the repository at this point in the history
  • Loading branch information
chris-twiner committed Mar 14, 2024
1 parent dd10cee commit f253d45
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
2 changes: 1 addition & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ val scalacheck = "1.17.0"
val scalacheckEffect = "1.0.4"
val refinedVersion = "0.11.1"
val nakedFSVersion = "0.1.0"
val shimVersion = "0.0.1-RC3"
val shimVersion = "0.0.1-RC4"

val Scala212 = "2.12.19"
val Scala213 = "2.13.13"
Expand Down
11 changes: 9 additions & 2 deletions dataset/src/main/scala/frameless/functions/Udf.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package frameless
package functions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, NonSQLExpression}
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionProxy, LeafExpression, NonSQLExpression}
import org.apache.spark.sql.catalyst.expressions.codegen._
import Block._
import org.apache.spark.sql.types.DataType
Expand Down Expand Up @@ -132,6 +132,13 @@ case class FramelessUdf[T, R](

def dataType: DataType = rencoder.catalystRepr

// #803 - SPARK-41991 fixes this for the most part, this is a belts and braces approach
def nonProxy(child: Expression): Expression =
child match {
case p: ExpressionProxy => nonProxy(p.child)
case _ => child
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
ctx.references += this

Expand All @@ -145,7 +152,7 @@ case class FramelessUdf[T, R](

val (argsCode, funcArguments) = encoders.zip(children).map {
case (encoder, child) =>
val eval = child.genCode(ctx)
val eval = nonProxy(child).genCode(ctx)
val codeTpe = CodeGenerator.boxedType(encoder.jvmRepr)
val argTerm = ctx.freshName("arg")
val convert = s"${eval.code}\n$codeTpe $argTerm = ${eval.isNull} ? (($codeTpe)null) : (($codeTpe)(${eval.value}));"
Expand Down

0 comments on commit f253d45

Please sign in to comment.