Skip to content

Commit

Permalink
typelevel#787 typelevel#803 - rc4 usage and fix udf with expressionpr…
Browse files Browse the repository at this point in the history
…oxy - deeply nested also possible
  • Loading branch information
chris-twiner committed Mar 14, 2024
1 parent 7c1e603 commit b161067
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 90 deletions.
2 changes: 1 addition & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ val shimVersion = "0.0.1-RC4"
val Scala212 = "2.12.19"
val Scala213 = "2.13.13"

//resolvers in Global += Resolver.mavenLocal
resolvers in Global += Resolver.mavenLocal
resolvers in Global += MavenRepository(
"sonatype-s01-snapshots",
Resolver.SonatypeS01RepositoryRoot + "/snapshots"
Expand Down
225 changes: 136 additions & 89 deletions dataset/src/main/scala/frameless/functions/Udf.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,90 +2,108 @@ package frameless
package functions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionProxy, LeafExpression, NonSQLExpression}
import org.apache.spark.sql.catalyst.expressions.{
Expression,
ExpressionProxy,
LeafExpression,
NonSQLExpression
}
import org.apache.spark.sql.catalyst.expressions.codegen._
import Block._
import org.apache.spark.sql.types.DataType
import shapeless.syntax.std.tuple._

/** Documentation marked "apache/spark" is thanks to apache/spark Contributors
* at https://github.com/apache/spark, licensed under Apache v2.0 available at
* http://www.apache.org/licenses/LICENSE-2.0
*/
/**
* Documentation marked "apache/spark" is thanks to apache/spark Contributors
* at https://github.com/apache/spark, licensed under Apache v2.0 available at
* http://www.apache.org/licenses/LICENSE-2.0
*/
trait Udf {

/** Defines a user-defined function of 1 arguments as user-defined function (UDF).
* The data types are automatically inferred based on the function's signature.
*
* apache/spark
*/
def udf[T, A, R: TypedEncoder](f: A => R):
TypedColumn[T, A] => TypedColumn[T, R] = {
/**
* Defines a user-defined function of 1 arguments as user-defined function (UDF).
* The data types are automatically inferred based on the function's signature.
*
* apache/spark
*/
def udf[T, A, R: TypedEncoder](f: A => R): TypedColumn[T, A] => TypedColumn[T, R] = {
u =>
val scalaUdf = FramelessUdf(f, List(u), TypedEncoder[R])
new TypedColumn[T, R](scalaUdf)
}

/** Defines a user-defined function of 2 arguments as user-defined function (UDF).
* The data types are automatically inferred based on the function's signature.
*
* apache/spark
*/
def udf[T, A1, A2, R: TypedEncoder](f: (A1,A2) => R):
(TypedColumn[T, A1], TypedColumn[T, A2]) => TypedColumn[T, R] = {
/**
* Defines a user-defined function of 2 arguments as user-defined function (UDF).
* The data types are automatically inferred based on the function's signature.
*
* apache/spark
*/
def udf[T, A1, A2, R: TypedEncoder](f: (A1, A2) => R): (
TypedColumn[T, A1],
TypedColumn[T, A2]
) => TypedColumn[T, R] = {
case us =>
val scalaUdf = FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R])
val scalaUdf =
FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R])
new TypedColumn[T, R](scalaUdf)
}
}

/** Defines a user-defined function of 3 arguments as user-defined function (UDF).
* The data types are automatically inferred based on the function's signature.
*
* apache/spark
*/
def udf[T, A1, A2, A3, R: TypedEncoder](f: (A1,A2,A3) => R):
(TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3]) => TypedColumn[T, R] = {
/**
* Defines a user-defined function of 3 arguments as user-defined function (UDF).
* The data types are automatically inferred based on the function's signature.
*
* apache/spark
*/
def udf[T, A1, A2, A3, R: TypedEncoder](f: (A1, A2, A3) => R): (
TypedColumn[T, A1],
TypedColumn[T, A2],
TypedColumn[T, A3]
) => TypedColumn[T, R] = {
case us =>
val scalaUdf = FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R])
val scalaUdf =
FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R])
new TypedColumn[T, R](scalaUdf)
}
}

/** Defines a user-defined function of 4 arguments as user-defined function (UDF).
* The data types are automatically inferred based on the function's signature.
*
* apache/spark
*/
def udf[T, A1, A2, A3, A4, R: TypedEncoder](f: (A1,A2,A3,A4) => R):
(TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4]) => TypedColumn[T, R] = {
/**
* Defines a user-defined function of 4 arguments as user-defined function (UDF).
* The data types are automatically inferred based on the function's signature.
*
* apache/spark
*/
def udf[T, A1, A2, A3, A4, R: TypedEncoder](f: (A1, A2, A3, A4) => R): (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4]) => TypedColumn[T, R] = {
case us =>
val scalaUdf = FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R])
val scalaUdf =
FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R])
new TypedColumn[T, R](scalaUdf)
}
}

/** Defines a user-defined function of 5 arguments as user-defined function (UDF).
* The data types are automatically inferred based on the function's signature.
*
* apache/spark
*/
def udf[T, A1, A2, A3, A4, A5, R: TypedEncoder](f: (A1,A2,A3,A4,A5) => R):
(TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4], TypedColumn[T, A5]) => TypedColumn[T, R] = {
/**
* Defines a user-defined function of 5 arguments as user-defined function (UDF).
* The data types are automatically inferred based on the function's signature.
*
* apache/spark
*/
def udf[T, A1, A2, A3, A4, A5, R: TypedEncoder](f: (A1, A2, A3, A4, A5) => R): (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4], TypedColumn[T, A5]) => TypedColumn[T, R] = {
case us =>
val scalaUdf = FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R])
val scalaUdf =
FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R])
new TypedColumn[T, R](scalaUdf)
}
}
}

/**
* NB: Implementation detail, isn't intended to be directly used.
*
* Our own implementation of `ScalaUDF` from Catalyst compatible with [[TypedEncoder]].
*/
* NB: Implementation detail, isn't intended to be directly used.
*
* Our own implementation of `ScalaUDF` from Catalyst compatible with [[TypedEncoder]].
*/
case class FramelessUdf[T, R](
function: AnyRef,
encoders: Seq[TypedEncoder[_]],
children: Seq[Expression],
rencoder: TypedEncoder[R]
) extends Expression with NonSQLExpression {
function: AnyRef,
encoders: Seq[TypedEncoder[_]],
children: Seq[Expression],
rencoder: TypedEncoder[R])
extends Expression
with NonSQLExpression {

override def nullable: Boolean = rencoder.nullable
override def toString: String = s"FramelessUdf(${children.mkString(", ")})"
Expand Down Expand Up @@ -118,10 +136,12 @@ case class FramelessUdf[T, R](
"""

val code = CodeFormatter.stripOverlappingComments(
new CodeAndComment(codeBody, ctx.getPlaceHolderToComments()))
new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())
)

val (clazz, _) = CodeGenerator.compile(code)
val codegen = clazz.generate(ctx.references.toArray).asInstanceOf[InternalRow => AnyRef]
val codegen =
clazz.generate(ctx.references.toArray).asInstanceOf[InternalRow => AnyRef]

codegen
}
Expand All @@ -136,7 +156,7 @@ case class FramelessUdf[T, R](
def nonProxy(child: Expression): Expression =
child transform {
case p: ExpressionProxy => p.child
case _ => child
case everythingElse => everythingElse
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
Expand All @@ -146,29 +166,45 @@ case class FramelessUdf[T, R](
val framelessUdfClassName = classOf[FramelessUdf[_, _]].getName
val funcClassName = s"scala.Function${children.size}"
val funcExpressionIdx = ctx.references.size - 1
val funcTerm = ctx.addMutableState(funcClassName, ctx.freshName("udf"),
v => s"$v = ($funcClassName)((($framelessUdfClassName)references" +
s"[$funcExpressionIdx]).function());")

val (argsCode, funcArguments) = encoders.zip(children).map {
case (encoder, child) =>
val eval = nonProxy(child).genCode(ctx)
val codeTpe = CodeGenerator.boxedType(encoder.jvmRepr)
val argTerm = ctx.freshName("arg")
val convert = s"${eval.code}\n$codeTpe $argTerm = ${eval.isNull} ? (($codeTpe)null) : (($codeTpe)(${eval.value}));"
val funcTerm = ctx.addMutableState(
funcClassName,
ctx.freshName("udf"),
v =>
s"$v = ($funcClassName)((($framelessUdfClassName)references" +
s"[$funcExpressionIdx]).function());"
)

(convert, argTerm)
}.unzip
val (argsCode, funcArguments) = encoders
.zip(children)
.map {
case (encoder, child) =>
val eval = nonProxy(child).genCode(ctx)
val codeTpe = CodeGenerator.boxedType(encoder.jvmRepr)
val argTerm = ctx.freshName("arg")
val convert =
s"${eval.code}\n$codeTpe $argTerm = ${eval.isNull} ? (($codeTpe)null) : (($codeTpe)(${eval.value}));"

(convert, argTerm)
}
.unzip

val internalTpe = CodeGenerator.boxedType(rencoder.jvmRepr)
val internalTerm = ctx.addMutableState(internalTpe, ctx.freshName("internal"))
val internalNullTerm = ctx.addMutableState("boolean", ctx.freshName("internalNull"))
val internalTerm =
ctx.addMutableState(internalTpe, ctx.freshName("internal"))
val internalNullTerm =
ctx.addMutableState("boolean", ctx.freshName("internalNull"))
// CTw - can't inject the term, may have to duplicate old code for parity
val internalExpr = Spark2_4_LambdaVariable(internalTerm, internalNullTerm, rencoder.jvmRepr, true)
val internalExpr = Spark2_4_LambdaVariable(
internalTerm,
internalNullTerm,
rencoder.jvmRepr,
true
)

val resultEval = rencoder.toCatalyst(internalExpr).genCode(ctx)

ev.copy(code = code"""
ev.copy(
code = code"""
${argsCode.mkString("\n")}

$internalTerm =
Expand All @@ -182,29 +218,39 @@ case class FramelessUdf[T, R](
)
}

protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(children = newChildren)
protected def withNewChildrenInternal(
newChildren: IndexedSeq[Expression]
): Expression = copy(children = newChildren)
}

case class Spark2_4_LambdaVariable(
value: String,
isNull: String,
dataType: DataType,
nullable: Boolean = true) extends LeafExpression with NonSQLExpression {
value: String,
isNull: String,
dataType: DataType,
nullable: Boolean = true)
extends LeafExpression
with NonSQLExpression {

private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType)
private val accessor: (InternalRow, Int) => Any =
InternalRow.getAccessor(dataType)

// Interpreted execution of `LambdaVariable` always get the 0-index element from input row.
override def eval(input: InternalRow): Any = {
assert(input.numFields == 1,
"The input row of interpreted LambdaVariable should have only 1 field.")
assert(
input.numFields == 1,
"The input row of interpreted LambdaVariable should have only 1 field."
)
if (nullable && input.isNullAt(0)) {
null
} else {
accessor(input, 0)
}
}

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
override protected def doGenCode(
ctx: CodegenContext,
ev: ExprCode
): ExprCode = {
val isNullValue = if (nullable) {
JavaCode.isNullVariable(isNull)
} else {
Expand All @@ -215,12 +261,13 @@ case class Spark2_4_LambdaVariable(
}

object FramelessUdf {

// Spark needs case class with `children` field to mutate it
def apply[T, R](
function: AnyRef,
cols: Seq[UntypedExpression[T]],
rencoder: TypedEncoder[R]
): FramelessUdf[T, R] = FramelessUdf(
function: AnyRef,
cols: Seq[UntypedExpression[T]],
rencoder: TypedEncoder[R]
): FramelessUdf[T, R] = FramelessUdf(
function = function,
encoders = cols.map(_.uencoder).toList,
children = cols.map(x => x.uencoder.fromCatalyst(x.expr)).toList,
Expand Down

0 comments on commit b161067

Please sign in to comment.