From b1610674df660c986aeb4f923b721ce376a5925a Mon Sep 17 00:00:00 2001 From: Chris Twiner Date: Thu, 14 Mar 2024 20:28:00 +0100 Subject: [PATCH] #787 #803 - rc4 usage and fix udf with expressionproxy - deeply nested also possible --- build.sbt | 2 +- .../main/scala/frameless/functions/Udf.scala | 225 +++++++++++------- 2 files changed, 137 insertions(+), 90 deletions(-) diff --git a/build.sbt b/build.sbt index d72fe754..a3dafd97 100644 --- a/build.sbt +++ b/build.sbt @@ -17,7 +17,7 @@ val shimVersion = "0.0.1-RC4" val Scala212 = "2.12.19" val Scala213 = "2.13.13" -//resolvers in Global += Resolver.mavenLocal +resolvers in Global += Resolver.mavenLocal resolvers in Global += MavenRepository( "sonatype-s01-snapshots", Resolver.SonatypeS01RepositoryRoot + "/snapshots" diff --git a/dataset/src/main/scala/frameless/functions/Udf.scala b/dataset/src/main/scala/frameless/functions/Udf.scala index 117f22b2..aa58cfc1 100644 --- a/dataset/src/main/scala/frameless/functions/Udf.scala +++ b/dataset/src/main/scala/frameless/functions/Udf.scala @@ -2,90 +2,108 @@ package frameless package functions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionProxy, LeafExpression, NonSQLExpression} +import org.apache.spark.sql.catalyst.expressions.{ + Expression, + ExpressionProxy, + LeafExpression, + NonSQLExpression +} import org.apache.spark.sql.catalyst.expressions.codegen._ import Block._ import org.apache.spark.sql.types.DataType import shapeless.syntax.std.tuple._ -/** Documentation marked "apache/spark" is thanks to apache/spark Contributors - * at https://github.com/apache/spark, licensed under Apache v2.0 available at - * http://www.apache.org/licenses/LICENSE-2.0 - */ +/** + * Documentation marked "apache/spark" is thanks to apache/spark Contributors + * at https://github.com/apache/spark, licensed under Apache v2.0 available at + * http://www.apache.org/licenses/LICENSE-2.0 + */ trait Udf { - /** Defines a user-defined function of 1 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the function's signature. - * - * apache/spark - */ - def udf[T, A, R: TypedEncoder](f: A => R): - TypedColumn[T, A] => TypedColumn[T, R] = { + /** + * Defines a user-defined function of 1 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * apache/spark + */ + def udf[T, A, R: TypedEncoder](f: A => R): TypedColumn[T, A] => TypedColumn[T, R] = { u => val scalaUdf = FramelessUdf(f, List(u), TypedEncoder[R]) new TypedColumn[T, R](scalaUdf) } - /** Defines a user-defined function of 2 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the function's signature. - * - * apache/spark - */ - def udf[T, A1, A2, R: TypedEncoder](f: (A1,A2) => R): - (TypedColumn[T, A1], TypedColumn[T, A2]) => TypedColumn[T, R] = { + /** + * Defines a user-defined function of 2 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * apache/spark + */ + def udf[T, A1, A2, R: TypedEncoder](f: (A1, A2) => R): ( + TypedColumn[T, A1], + TypedColumn[T, A2] + ) => TypedColumn[T, R] = { case us => - val scalaUdf = FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R]) + val scalaUdf = + FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R]) new TypedColumn[T, R](scalaUdf) - } + } - /** Defines a user-defined function of 3 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the function's signature. - * - * apache/spark - */ - def udf[T, A1, A2, A3, R: TypedEncoder](f: (A1,A2,A3) => R): - (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3]) => TypedColumn[T, R] = { + /** + * Defines a user-defined function of 3 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * apache/spark + */ + def udf[T, A1, A2, A3, R: TypedEncoder](f: (A1, A2, A3) => R): ( + TypedColumn[T, A1], + TypedColumn[T, A2], + TypedColumn[T, A3] + ) => TypedColumn[T, R] = { case us => - val scalaUdf = FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R]) + val scalaUdf = + FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R]) new TypedColumn[T, R](scalaUdf) - } + } - /** Defines a user-defined function of 4 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the function's signature. - * - * apache/spark - */ - def udf[T, A1, A2, A3, A4, R: TypedEncoder](f: (A1,A2,A3,A4) => R): - (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4]) => TypedColumn[T, R] = { + /** + * Defines a user-defined function of 4 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * apache/spark + */ + def udf[T, A1, A2, A3, A4, R: TypedEncoder](f: (A1, A2, A3, A4) => R): (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4]) => TypedColumn[T, R] = { case us => - val scalaUdf = FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R]) + val scalaUdf = + FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R]) new TypedColumn[T, R](scalaUdf) - } + } - /** Defines a user-defined function of 5 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the function's signature. - * - * apache/spark - */ - def udf[T, A1, A2, A3, A4, A5, R: TypedEncoder](f: (A1,A2,A3,A4,A5) => R): - (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4], TypedColumn[T, A5]) => TypedColumn[T, R] = { + /** + * Defines a user-defined function of 5 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * apache/spark + */ + def udf[T, A1, A2, A3, A4, A5, R: TypedEncoder](f: (A1, A2, A3, A4, A5) => R): (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4], TypedColumn[T, A5]) => TypedColumn[T, R] = { case us => - val scalaUdf = FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R]) + val scalaUdf = + FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R]) new TypedColumn[T, R](scalaUdf) - } + } } /** - * NB: Implementation detail, isn't intended to be directly used. - * - * Our own implementation of `ScalaUDF` from Catalyst compatible with [[TypedEncoder]]. - */ + * NB: Implementation detail, isn't intended to be directly used. + * + * Our own implementation of `ScalaUDF` from Catalyst compatible with [[TypedEncoder]]. + */ case class FramelessUdf[T, R]( - function: AnyRef, - encoders: Seq[TypedEncoder[_]], - children: Seq[Expression], - rencoder: TypedEncoder[R] -) extends Expression with NonSQLExpression { + function: AnyRef, + encoders: Seq[TypedEncoder[_]], + children: Seq[Expression], + rencoder: TypedEncoder[R]) + extends Expression + with NonSQLExpression { override def nullable: Boolean = rencoder.nullable override def toString: String = s"FramelessUdf(${children.mkString(", ")})" @@ -118,10 +136,12 @@ case class FramelessUdf[T, R]( """ val code = CodeFormatter.stripOverlappingComments( - new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) + new CodeAndComment(codeBody, ctx.getPlaceHolderToComments()) + ) val (clazz, _) = CodeGenerator.compile(code) - val codegen = clazz.generate(ctx.references.toArray).asInstanceOf[InternalRow => AnyRef] + val codegen = + clazz.generate(ctx.references.toArray).asInstanceOf[InternalRow => AnyRef] codegen } @@ -136,7 +156,7 @@ case class FramelessUdf[T, R]( def nonProxy(child: Expression): Expression = child transform { case p: ExpressionProxy => p.child - case _ => child + case everythingElse => everythingElse } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -146,29 +166,45 @@ case class FramelessUdf[T, R]( val framelessUdfClassName = classOf[FramelessUdf[_, _]].getName val funcClassName = s"scala.Function${children.size}" val funcExpressionIdx = ctx.references.size - 1 - val funcTerm = ctx.addMutableState(funcClassName, ctx.freshName("udf"), - v => s"$v = ($funcClassName)((($framelessUdfClassName)references" + - s"[$funcExpressionIdx]).function());") - - val (argsCode, funcArguments) = encoders.zip(children).map { - case (encoder, child) => - val eval = nonProxy(child).genCode(ctx) - val codeTpe = CodeGenerator.boxedType(encoder.jvmRepr) - val argTerm = ctx.freshName("arg") - val convert = s"${eval.code}\n$codeTpe $argTerm = ${eval.isNull} ? (($codeTpe)null) : (($codeTpe)(${eval.value}));" + val funcTerm = ctx.addMutableState( + funcClassName, + ctx.freshName("udf"), + v => + s"$v = ($funcClassName)((($framelessUdfClassName)references" + + s"[$funcExpressionIdx]).function());" + ) - (convert, argTerm) - }.unzip + val (argsCode, funcArguments) = encoders + .zip(children) + .map { + case (encoder, child) => + val eval = nonProxy(child).genCode(ctx) + val codeTpe = CodeGenerator.boxedType(encoder.jvmRepr) + val argTerm = ctx.freshName("arg") + val convert = + s"${eval.code}\n$codeTpe $argTerm = ${eval.isNull} ? (($codeTpe)null) : (($codeTpe)(${eval.value}));" + + (convert, argTerm) + } + .unzip val internalTpe = CodeGenerator.boxedType(rencoder.jvmRepr) - val internalTerm = ctx.addMutableState(internalTpe, ctx.freshName("internal")) - val internalNullTerm = ctx.addMutableState("boolean", ctx.freshName("internalNull")) + val internalTerm = + ctx.addMutableState(internalTpe, ctx.freshName("internal")) + val internalNullTerm = + ctx.addMutableState("boolean", ctx.freshName("internalNull")) // CTw - can't inject the term, may have to duplicate old code for parity - val internalExpr = Spark2_4_LambdaVariable(internalTerm, internalNullTerm, rencoder.jvmRepr, true) + val internalExpr = Spark2_4_LambdaVariable( + internalTerm, + internalNullTerm, + rencoder.jvmRepr, + true + ) val resultEval = rencoder.toCatalyst(internalExpr).genCode(ctx) - ev.copy(code = code""" + ev.copy( + code = code""" ${argsCode.mkString("\n")} $internalTerm = @@ -182,21 +218,28 @@ case class FramelessUdf[T, R]( ) } - protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(children = newChildren) + protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression] + ): Expression = copy(children = newChildren) } case class Spark2_4_LambdaVariable( - value: String, - isNull: String, - dataType: DataType, - nullable: Boolean = true) extends LeafExpression with NonSQLExpression { + value: String, + isNull: String, + dataType: DataType, + nullable: Boolean = true) + extends LeafExpression + with NonSQLExpression { - private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType) + private val accessor: (InternalRow, Int) => Any = + InternalRow.getAccessor(dataType) // Interpreted execution of `LambdaVariable` always get the 0-index element from input row. override def eval(input: InternalRow): Any = { - assert(input.numFields == 1, - "The input row of interpreted LambdaVariable should have only 1 field.") + assert( + input.numFields == 1, + "The input row of interpreted LambdaVariable should have only 1 field." + ) if (nullable && input.isNullAt(0)) { null } else { @@ -204,7 +247,10 @@ case class Spark2_4_LambdaVariable( } } - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + override protected def doGenCode( + ctx: CodegenContext, + ev: ExprCode + ): ExprCode = { val isNullValue = if (nullable) { JavaCode.isNullVariable(isNull) } else { @@ -215,12 +261,13 @@ case class Spark2_4_LambdaVariable( } object FramelessUdf { + // Spark needs case class with `children` field to mutate it def apply[T, R]( - function: AnyRef, - cols: Seq[UntypedExpression[T]], - rencoder: TypedEncoder[R] - ): FramelessUdf[T, R] = FramelessUdf( + function: AnyRef, + cols: Seq[UntypedExpression[T]], + rencoder: TypedEncoder[R] + ): FramelessUdf[T, R] = FramelessUdf( function = function, encoders = cols.map(_.uencoder).toList, children = cols.map(x => x.uencoder.fromCatalyst(x.expr)).toList,