[SPARK-50581][SQL] fix support for UDAF in Dataset.observe()
### What changes were proposed in this pull request?

Mark the `inputProjection` field, along with the other lazily initialized non-serializable lazy vals (the encoder-derived serializers/deserializers, `outputEncoder`, `bufferRow`, and the `deterministic` flag), as `@transient` in org.apache.spark.sql.execution.aggregate.ScalaAggregator.

### Why are the changes needed?

To support UDAF in Dataset.observe(), which currently fails with a serialization exception.
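
Roughly, the failure mode is that observe() ships the aggregate expression to executors, and ScalaAggregator's lazily built helpers (for example the generated `UnsafeProjection` behind `inputProjection`) are not Java-serializable once they have been materialized on the driver, so serialization throws `NotSerializableException`. Marking the lazy vals `@transient` omits them from the serialized form, and Scala re-runs the lazy initializer on first use after deserialization, so executors simply rebuild them. A minimal, Spark-free sketch of that idiom (all class and field names below are made up for illustration, not taken from the patch):

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

// Stand-in for a non-serializable helper such as a generated UnsafeProjection.
class NonSerializableHelper {
  def project(x: Long): Long = x
}

// Stand-in for a serializable expression that builds the helper lazily.
// Without @transient, serializing an instance whose lazy val has been forced
// throws java.io.NotSerializableException; with @transient the field is skipped
// and rebuilt on first use after deserialization.
class ExprLike extends Serializable {
  @transient private lazy val helper = new NonSerializableHelper
  def eval(x: Long): Long = helper.project(x)
}

object TransientLazyValDemo {
  // Java-serialization round trip, standing in for driver -> executor shipping.
  private def roundTrip[T <: AnyRef](value: T): T = {
    val buf = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(buf)
    out.writeObject(value)
    out.close()
    val in = new ObjectInputStream(new ByteArrayInputStream(buf.toByteArray))
    in.readObject().asInstanceOf[T]
  }

  def main(args: Array[String]): Unit = {
    val expr = new ExprLike
    expr.eval(1L)                  // force the lazy val on the "driver" side
    val copy = roundTrip(expr)     // succeeds only because the field is @transient
    println(copy.eval(41L))        // lazy val is re-initialized on the copy; prints 41
  }
}
```

Removing the `@transient` annotation in this sketch makes `roundTrip` fail with `java.io.NotSerializableException`, which is essentially the failure observe() hit before this change.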

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Added test.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes apache#49190 from tomsisso/support-udaf-in-observe.

Authored-by: tomsisso <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
toms-definity authored and HyukjinKwon committed Dec 16, 2024
1 parent 0faf9d5 commit 44ab349
Showing 2 changed files with 55 additions and 9 deletions.
@@ -501,25 +501,25 @@ case class ScalaAggregator[IN, BUF, OUT](
     with Logging {
 
   // input and buffer encoders are resolved by ResolveEncodersInScalaAgg
-  private[this] lazy val inputDeserializer = inputEncoder.createDeserializer()
-  private[this] lazy val bufferSerializer = bufferEncoder.createSerializer()
-  private[this] lazy val bufferDeserializer = bufferEncoder.createDeserializer()
-  private[this] lazy val outputEncoder = encoderFor(agg.outputEncoder)
-  private[this] lazy val outputSerializer = outputEncoder.createSerializer()
+  @transient private[this] lazy val inputDeserializer = inputEncoder.createDeserializer()
+  @transient private[this] lazy val bufferSerializer = bufferEncoder.createSerializer()
+  @transient private[this] lazy val bufferDeserializer = bufferEncoder.createDeserializer()
+  @transient private[this] lazy val outputEncoder = encoderFor(agg.outputEncoder)
+  @transient private[this] lazy val outputSerializer = outputEncoder.createSerializer()
 
   def dataType: DataType = outputEncoder.objSerializer.dataType
 
   def inputTypes: Seq[DataType] = inputEncoder.schema.map(_.dataType)
 
-  override lazy val deterministic: Boolean = isDeterministic
+  @transient override lazy val deterministic: Boolean = isDeterministic
 
   def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ScalaAggregator[IN, BUF, OUT] =
     copy(mutableAggBufferOffset = newMutableAggBufferOffset)
 
   def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ScalaAggregator[IN, BUF, OUT] =
     copy(inputAggBufferOffset = newInputAggBufferOffset)
 
-  private[this] lazy val inputProjection = UnsafeProjection.create(children)
+  @transient private[this] lazy val inputProjection = UnsafeProjection.create(children)
 
   def createAggregationBuffer(): BUF = agg.zero
 
@@ -533,7 +533,7 @@ case class ScalaAggregator[IN, BUF, OUT](
     if (outputEncoder.isSerializedAsStruct) row else row.get(0, dataType)
   }
 
-  private[this] lazy val bufferRow = new UnsafeRow(bufferEncoder.namedExpressions.length)
+  @transient private[this] lazy val bufferRow = new UnsafeRow(bufferEncoder.namedExpressions.length)
 
   def serialize(agg: BUF): Array[Byte] =
     bufferSerializer(agg).asInstanceOf[UnsafeRow].getBytes()
@@ -22,14 +22,15 @@ import java.lang.{Long => JLong}
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark._
-import org.apache.spark.sql.{functions, Dataset, QueryTest, Row, SparkSession}
+import org.apache.spark.sql.{functions, Dataset, Encoder, Encoders, QueryTest, Row, SparkSession}
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
 import org.apache.spark.sql.execution.{QueryExecution, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, LeafRunnableCommand}
 import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
 import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
+import org.apache.spark.sql.expressions.Aggregator
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.StringType
@@ -339,6 +340,51 @@ class DataFrameCallbackSuite extends QueryTest
     }
   }
 
+  test("SPARK-50581: support observe with udaf") {
+    withUserDefinedFunction(("someUdaf", true)) {
+      spark.udf.register("someUdaf", functions.udaf(new Aggregator[JLong, JLong, JLong] {
+        def zero: JLong = 0L
+        def reduce(b: JLong, a: JLong): JLong = a + b
+        def merge(b1: JLong, b2: JLong): JLong = b1 + b2
+        def finish(r: JLong): JLong = r
+        def bufferEncoder: Encoder[JLong] = Encoders.LONG
+        def outputEncoder: Encoder[JLong] = Encoders.LONG
+      }))
+
+      val df = spark.range(100)
+
+      val metricMaps = ArrayBuffer.empty[Map[String, Row]]
+      val listener = new QueryExecutionListener {
+        override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
+          if (qe.observedMetrics.nonEmpty) {
+            metricMaps += qe.observedMetrics
+          }
+        }
+
+        override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
+          // No-op
+        }
+      }
+      try {
+        spark.listenerManager.register(listener)
+
+        // udaf usage in observe is not working (serialization exception)
+        df.observe(
+          name = "my_metrics",
+          expr("someUdaf(id)").as("agg")
+        )
+          .collect()
+
+        sparkContext.listenerBus.waitUntilEmpty()
+        assert(metricMaps.size === 1)
+        assert(metricMaps.head("my_metrics") === Row(4950L))
+
+      } finally {
+        spark.listenerManager.unregister(listener)
+      }
+    }
+  }
+
   private def validateObservedMetrics(df: Dataset[JLong]): Unit = {
     val metricMaps = ArrayBuffer.empty[Map[String, Row]]
     val listener = new QueryExecutionListener {
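
As a usage note (not part of this commit), the same scenario can be consumed without a QueryExecutionListener via the Observation helper available since Spark 3.3. The sketch below mirrors the new test's sum aggregator; the wiring is illustrative and assumes a Spark build that includes this fix:

```scala
import org.apache.spark.sql.{Encoder, Encoders, Observation, SparkSession}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.{col, udaf}

// Top-level aggregator (plain sum over longs), mirroring the one in the new test.
class LongSumAgg extends Aggregator[Long, Long, Long] {
  def zero: Long = 0L
  def reduce(b: Long, a: Long): Long = a + b
  def merge(b1: Long, b2: Long): Long = b1 + b2
  def finish(r: Long): Long = r
  def bufferEncoder: Encoder[Long] = Encoders.scalaLong
  def outputEncoder: Encoder[Long] = Encoders.scalaLong
}

object ObserveUdafExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("observe-udaf").getOrCreate()

    val metrics = Observation("my_metrics")   // named handle for the observed metrics row

    // With ScalaAggregator's lazy fields marked @transient, a UDAF column can be
    // used directly as an observed metric expression.
    spark.range(100)
      .observe(metrics, udaf(new LongSumAgg)(col("id")).as("agg"))
      .collect()

    println(metrics.get("agg"))               // expected: 4950
    spark.stop()
  }
}
```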
