
Commit

Push initial restructure
costas-db committed Jan 23, 2025
1 parent af2768d commit 62fbb58
Showing 5 changed files with 104 additions and 88 deletions.
89 changes: 89 additions & 0 deletions core/src/main/scala/org/apache/spark/util/MetricUtils.scala
@@ -0,0 +1,89 @@
package org.apache.spark.util

import java.text.NumberFormat
import java.util.{Arrays, Locale}

import scala.concurrent.duration._

import org.apache.spark.SparkException
import org.apache.spark.util.Utils

object MetricUtils {

  val SUM_METRIC: String = "sum"
  val SIZE_METRIC: String = "size"
  val TIMING_METRIC: String = "timing"
  val NS_TIMING_METRIC: String = "nsTiming"
  val AVERAGE_METRIC: String = "average"
  private val baseForAvgMetric: Int = 1000
  private val METRICS_NAME_SUFFIX = "(min, med, max (stageId: taskId))"

  private def toNumberFormat(value: Long): String = {
    val numberFormat = NumberFormat.getNumberInstance(Locale.US)
    numberFormat.format(value.toDouble / baseForAvgMetric)
  }

  def metricNeedsMax(metricsType: String): Boolean = {
    metricsType != SUM_METRIC
  }

  /**
   * A function that defines how we aggregate the final accumulator results among all tasks,
   * and represents them as a string for a SQL physical operator.
   */
  def stringValue(metricsType: String, values: Array[Long], maxMetrics: Array[Long]): String = {
    // taskInfo = "(driver)" OR (stage ${stageId}.${attemptId}: task $taskId)
    val taskInfo = if (maxMetrics.isEmpty) {
      "(driver)"
    } else {
      s"(stage ${maxMetrics(1)}.${maxMetrics(2)}: task ${maxMetrics(3)})"
    }
    if (metricsType == SUM_METRIC) {
      val numberFormat = NumberFormat.getIntegerInstance(Locale.US)
      numberFormat.format(values.sum)
    } else if (metricsType == AVERAGE_METRIC) {
      val validValues = values.filter(_ > 0)
      // When there is only one metric value (or none), there is no need to display
      // min/med/max. This is common for driver-side SQL metrics.
      if (validValues.length <= 1) {
        toNumberFormat(validValues.headOption.getOrElse(0))
      } else {
        val Seq(min, med, max) = {
          Arrays.sort(validValues)
          Seq(
            toNumberFormat(validValues(0)),
            toNumberFormat(validValues(validValues.length / 2)),
            toNumberFormat(validValues(validValues.length - 1)))
        }
        s"$METRICS_NAME_SUFFIX:\n($min, $med, $max $taskInfo)"
      }
    } else {
      val strFormat: Long => String = if (metricsType == SIZE_METRIC) {
        Utils.bytesToString
      } else if (metricsType == TIMING_METRIC) {
        Utils.msDurationToString
      } else if (metricsType == NS_TIMING_METRIC) {
        duration => Utils.msDurationToString(duration.nanos.toMillis)
      } else {
        throw SparkException.internalError(s"unexpected metrics type: $metricsType")
      }

      val validValues = values.filter(_ >= 0)
      // When there is only one metric value (or none), there is no need to display
      // min/med/max. This is common for driver-side SQL metrics.
      if (validValues.length <= 1) {
        strFormat(validValues.headOption.getOrElse(0))
      } else {
        val Seq(sum, min, med, max) = {
          Arrays.sort(validValues)
          Seq(
            strFormat(validValues.sum),
            strFormat(validValues(0)),
            strFormat(validValues(validValues.length / 2)),
            strFormat(validValues(validValues.length - 1)))
        }
        s"total $METRICS_NAME_SUFFIX\n$sum ($min, $med, $max $taskInfo)"
      }
    }
  }
}
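For illustration only (not part of this commit): a minimal sketch of how the relocated helper is invoked. The metric values and stage/attempt/task identifiers below are hypothetical; the output follows the string templates in stringValue above.

import org.apache.spark.util.MetricUtils

// "sum" metrics are simply totalled and integer-formatted, e.g. "6".
MetricUtils.stringValue("sum", Array(1L, 2L, 3L), Array.empty[Long])

// "size" metrics with several valid values render as
// "total (min, med, max (stageId: taskId))\n<sum> (<min>, <med>, <max> (stage 0.0: task 2))".
// Only indices 1..3 of maxMetrics (stageId, attemptId, taskId) are read by stringValue.
MetricUtils.stringValue("size", Array(1024L, 2048L, 4096L), Array(4096L, 0L, 0L, 2L))

// With no maxMetrics at all, the task info falls back to "(driver)".
MetricUtils.stringValue("size", Array(1024L, 2048L), Array.empty[Long])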
Changed file: PythonCustomMetric (org.apache.spark.sql.execution.datasources.v2.python)
@@ -17,8 +17,9 @@
package org.apache.spark.sql.execution.datasources.v2.python

import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.python.PythonSQLMetrics
import org.apache.spark.util.MetricUtils


class PythonCustomMetric(
@@ -28,7 +29,7 @@ class PythonCustomMetric(
def this() = this(null, null)

override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = {
SQLMetrics.stringValue("size", taskMetrics, Array.empty[Long])
MetricUtils.stringValue("size", taskMetrics, Array.empty[Long])
}
}

Changed file: SQLMetrics.scala (org.apache.spark.sql.execution.metric)
@@ -17,19 +17,16 @@

package org.apache.spark.sql.execution.metric

import java.text.NumberFormat
import java.util.{Arrays, Locale}

import scala.concurrent.duration._
// import scala.concurrent.duration._

import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}

import org.apache.spark.{SparkContext, SparkException}
import org.apache.spark.SparkContext
import org.apache.spark.scheduler.AccumulableInfo
import org.apache.spark.sql.connector.metric.CustomMetric
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates
import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, Utils}
import org.apache.spark.util.{AccumulatorContext, AccumulatorV2}
import org.apache.spark.util.AccumulatorContext.internOption

/**
@@ -72,7 +69,7 @@ class SQLMetric(

// This is used to filter out metrics. Metrics with value equal to initValue should
// be filtered out, since they are either invalid or safe to filter without changing
// the aggregation defined in [[SQLMetrics.stringValue]].
// the aggregation defined in [[MetricUtils.stringValue]].
// Note that we don't use 0 here since we may want to collect 0 metrics for
// calculating min, max, etc. See SPARK-11013.
override def isZero: Boolean = _value == initValue
@@ -106,7 +103,7 @@ class SQLMetric(
SQLMetrics.cachedSQLAccumIdentifier)
}

// We should provide the raw value which can be -1, so that `SQLMetrics.stringValue` can correctly
// We should provide the raw value which can be -1, so that `MetricUtils.stringValue` can correctly
// filter out the invalid -1 values.
override def toInfoUpdate: AccumulableInfo = {
AccumulableInfo(id, name, internOption(Some(_value)), None, true, true,
@@ -203,77 +200,6 @@ object SQLMetrics {
acc
}

private def toNumberFormat(value: Long): String = {
val numberFormat = NumberFormat.getNumberInstance(Locale.US)
numberFormat.format(value.toDouble / baseForAvgMetric)
}

def metricNeedsMax(metricsType: String): Boolean = {
metricsType != SUM_METRIC
}

private val METRICS_NAME_SUFFIX = "(min, med, max (stageId: taskId))"

/**
* A function that defines how we aggregate the final accumulator results among all tasks,
* and represent it in string for a SQL physical operator.
*/
def stringValue(metricsType: String, values: Array[Long], maxMetrics: Array[Long]): String = {
// taskInfo = "(driver)" OR (stage ${stageId}.${attemptId}: task $taskId)
val taskInfo = if (maxMetrics.isEmpty) {
"(driver)"
} else {
s"(stage ${maxMetrics(1)}.${maxMetrics(2)}: task ${maxMetrics(3)})"
}
if (metricsType == SUM_METRIC) {
val numberFormat = NumberFormat.getIntegerInstance(Locale.US)
numberFormat.format(values.sum)
} else if (metricsType == AVERAGE_METRIC) {
val validValues = values.filter(_ > 0)
// When there are only 1 metrics value (or None), no need to display max/min/median. This is
// common for driver-side SQL metrics.
if (validValues.length <= 1) {
toNumberFormat(validValues.headOption.getOrElse(0))
} else {
val Seq(min, med, max) = {
Arrays.sort(validValues)
Seq(
toNumberFormat(validValues(0)),
toNumberFormat(validValues(validValues.length / 2)),
toNumberFormat(validValues(validValues.length - 1)))
}
s"$METRICS_NAME_SUFFIX:\n($min, $med, $max $taskInfo)"
}
} else {
val strFormat: Long => String = if (metricsType == SIZE_METRIC) {
Utils.bytesToString
} else if (metricsType == TIMING_METRIC) {
Utils.msDurationToString
} else if (metricsType == NS_TIMING_METRIC) {
duration => Utils.msDurationToString(duration.nanos.toMillis)
} else {
throw SparkException.internalError(s"unexpected metrics type: $metricsType")
}

val validValues = values.filter(_ >= 0)
// When there are only 1 metrics value (or None), no need to display max/min/median. This is
// common for driver-side SQL metrics.
if (validValues.length <= 1) {
strFormat(validValues.headOption.getOrElse(0))
} else {
val Seq(sum, min, med, max) = {
Arrays.sort(validValues)
Seq(
strFormat(validValues.sum),
strFormat(validValues(0)),
strFormat(validValues(validValues.length / 2)),
strFormat(validValues(validValues.length - 1)))
}
s"total $METRICS_NAME_SUFFIX\n$sum ($min, $med, $max $taskInfo)"
}
}
}

def postDriverMetricsUpdatedByValue(
sc: SparkContext,
executionId: String,
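Side note (not part of the diff): the comment above about providing the raw value, which can be -1, relies on the filter step in MetricUtils.stringValue. A minimal sketch with made-up task timings:

import org.apache.spark.util.MetricUtils

// A timing metric initialised to -1 that a task never touched is reported as -1;
// the `values.filter(_ >= 0)` step drops it, so only the 10 ms and 30 ms entries
// contribute to the total/min/med/max line.
MetricUtils.stringValue("timing", Array(-1L, 10L, 30L), Array.empty[Long])
// => roughly "total (min, med, max (stageId: taskId))\n40 ms (10 ms, 30 ms, 30 ms (driver))"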
Changed file: SQLAppStatusListener.scala (org.apache.spark.sql.execution.ui)
@@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.metric._
import org.apache.spark.sql.internal.StaticSQLConf._
import org.apache.spark.status.{ElementTrackingStore, KVUtils, LiveEntity}
import org.apache.spark.util.Utils
import org.apache.spark.util.{MetricUtils, Utils}
import org.apache.spark.util.collection.OpenHashMap

class SQLAppStatusListener(
@@ -235,7 +235,7 @@ class SQLAppStatusListener(
}
}.getOrElse(
// Built-in SQLMetric
SQLMetrics.stringValue(m.metricType, _, _)
MetricUtils.stringValue(m.metricType, _, _)
)
(m.accumulatorId, metricAggMethod)
}.toMap
@@ -554,7 +554,7 @@ private class LiveStageMetrics(
/**
* Task metrics values for the stage. Maps the metric ID to the metric values for each
* index. For each metric ID, there will be the same number of values as the number
* of indices. This relies on `SQLMetrics.stringValue` treating 0 as a neutral value,
* of indices. This relies on `MetricUtils.stringValue` treating 0 as a neutral value,
* independent of the actual metric type.
*/
private val taskMetrics = new ConcurrentHashMap[Long, Array[Long]]()
@@ -601,7 +601,7 @@ private class LiveStageMetrics(
val metricValues = taskMetrics.computeIfAbsent(acc.id, _ => new Array(numTasks))
metricValues(taskIdx) = value

if (SQLMetrics.metricNeedsMax(accumIdsToMetricType(acc.id))) {
if (MetricUtils.metricNeedsMax(accumIdsToMetricType(acc.id))) {
val maxMetricsTaskId = metricsIdToMaxTaskValue.computeIfAbsent(acc.id, _ => Array(value,
taskId))

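Side note (not part of the diff): a minimal sketch of the gating above. Only non-"sum" metric types track a per-task maximum, and the array eventually handed to stringValue is ordered so that indices 1..3 are stageId, attemptId and taskId. All identifiers below are hypothetical.

import org.apache.spark.util.MetricUtils

val (maxValue, stageId, attemptId, taskId) = (1234L, 3L, 0L, 42L) // hypothetical values
val maxMetrics =
  if (MetricUtils.metricNeedsMax("timing")) {
    // stringValue reads indices 1..3 as (stageId, attemptId, taskId)
    Array(maxValue, stageId, attemptId, taskId)
  } else {
    // "sum" metrics skip the per-task max bookkeeping entirely
    Array.empty[Long]
  }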
Changed file: SQLAppStatusListenerSuite.scala
@@ -57,7 +57,7 @@ import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.status.{AppStatusStore, ElementTrackingStore}
import org.apache.spark.util.{AccumulatorMetadata, JsonProtocol, LongAccumulator, SerializableConfiguration, Utils}
import org.apache.spark.util.{AccumulatorMetadata, JsonProtocol, LongAccumulator, MetricUtils, SerializableConfiguration, Utils}
import org.apache.spark.util.kvstore.InMemoryStore


@@ -597,9 +597,9 @@ abstract class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTes
val metrics = statusStore.executionMetrics(execId)
val driverMetric = physicalPlan.metrics("dummy")
val driverMetric2 = physicalPlan.metrics("dummy2")
val expectedValue = SQLMetrics.stringValue(driverMetric.metricType,
val expectedValue = MetricUtils.stringValue(driverMetric.metricType,
Array(expectedAccumValue), Array.empty[Long])
val expectedValue2 = SQLMetrics.stringValue(driverMetric2.metricType,
val expectedValue2 = MetricUtils.stringValue(driverMetric2.metricType,
Array(expectedAccumValue2), Array.empty[Long])

assert(metrics.contains(driverMetric.id))
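Side note (not part of the diff): the expected values above exercise the single-value path of stringValue, since each driver metric carries exactly one accumulator value and no maxMetrics. A hypothetical example:

import org.apache.spark.util.MetricUtils

// With a single valid value and no per-task max info, stringValue skips the
// min/med/max breakdown and returns just the formatted value.
MetricUtils.stringValue("size", Array(2048L), Array.empty[Long])
// => a single human-readable size such as "2.0 KiB"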
