diff --git a/core/src/main/scala/org/apache/spark/util/MetricUtils.scala b/core/src/main/scala/org/apache/spark/util/MetricUtils.scala
new file mode 100644
index 0000000000000..cff3fd9f56912
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/MetricUtils.scala
@@ -0,0 +1,89 @@
+package org.apache.spark.util
+
+import java.text.NumberFormat
+import java.util.{Arrays, Locale}
+
+import scala.concurrent.duration._
+
+import org.apache.spark.SparkException
+import org.apache.spark.util.Utils
+
+object MetricUtils {
+
+  val SUM_METRIC: String = "sum"
+  val SIZE_METRIC: String = "size"
+  val TIMING_METRIC: String = "timing"
+  val NS_TIMING_METRIC: String = "nsTiming"
+  val AVERAGE_METRIC: String = "average"
+  private val baseForAvgMetric: Int = 1000
+  private val METRICS_NAME_SUFFIX = "(min, med, max (stageId: taskId))"
+
+  private def toNumberFormat(value: Long): String = {
+    val numberFormat = NumberFormat.getNumberInstance(Locale.US)
+    numberFormat.format(value.toDouble / baseForAvgMetric)
+  }
+
+  def metricNeedsMax(metricsType: String): Boolean = {
+    metricsType != SUM_METRIC
+  }
+
+  /**
+   * A function that defines how we aggregate the final accumulator results among all tasks,
+   * and represent them as a string for a SQL physical operator.
+   */
+  def stringValue(metricsType: String, values: Array[Long], maxMetrics: Array[Long]): String = {
+    // taskInfo = "(driver)" OR (stage ${stageId}.${attemptId}: task $taskId)
+    val taskInfo = if (maxMetrics.isEmpty) {
+      "(driver)"
+    } else {
+      s"(stage ${maxMetrics(1)}.${maxMetrics(2)}: task ${maxMetrics(3)})"
+    }
+    if (metricsType == SUM_METRIC) {
+      val numberFormat = NumberFormat.getIntegerInstance(Locale.US)
+      numberFormat.format(values.sum)
+    } else if (metricsType == AVERAGE_METRIC) {
+      val validValues = values.filter(_ > 0)
+      // When there is only one metric value (or none), there is no need to display
+      // max/min/median. This is common for driver-side SQL metrics.
+      if (validValues.length <= 1) {
+        toNumberFormat(validValues.headOption.getOrElse(0))
+      } else {
+        val Seq(min, med, max) = {
+          Arrays.sort(validValues)
+          Seq(
+            toNumberFormat(validValues(0)),
+            toNumberFormat(validValues(validValues.length / 2)),
+            toNumberFormat(validValues(validValues.length - 1)))
+        }
+        s"$METRICS_NAME_SUFFIX:\n($min, $med, $max $taskInfo)"
+      }
+    } else {
+      val strFormat: Long => String = if (metricsType == SIZE_METRIC) {
+        Utils.bytesToString
+      } else if (metricsType == TIMING_METRIC) {
+        Utils.msDurationToString
+      } else if (metricsType == NS_TIMING_METRIC) {
+        duration => Utils.msDurationToString(duration.nanos.toMillis)
+      } else {
+        throw SparkException.internalError(s"unexpected metrics type: $metricsType")
+      }
+
+      val validValues = values.filter(_ >= 0)
+      // When there is only one metric value (or none), there is no need to display
+      // max/min/median. This is common for driver-side SQL metrics.
+      if (validValues.length <= 1) {
+        strFormat(validValues.headOption.getOrElse(0))
+      } else {
+        val Seq(sum, min, med, max) = {
+          Arrays.sort(validValues)
+          Seq(
+            strFormat(validValues.sum),
+            strFormat(validValues(0)),
+            strFormat(validValues(validValues.length / 2)),
+            strFormat(validValues(validValues.length - 1)))
+        }
+        s"total $METRICS_NAME_SUFFIX\n$sum ($min, $med, $max $taskInfo)"
+      }
+    }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonCustomMetric.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonCustomMetric.scala
index 7551cd04f20f6..2db2ff74374ca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonCustomMetric.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonCustomMetric.scala
@@ -17,8 +17,9 @@
 package org.apache.spark.sql.execution.datasources.v2.python
 
 import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric}
-import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.python.PythonSQLMetrics
+import org.apache.spark.util.MetricUtils
 
 
 class PythonCustomMetric(
@@ -28,7 +29,7 @@ class PythonCustomMetric(
   def this() = this(null, null)
 
   override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = {
-    SQLMetrics.stringValue("size", taskMetrics, Array.empty[Long])
+    MetricUtils.stringValue("size", taskMetrics, Array.empty[Long])
   }
 }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
index f0c1c0900c7f5..3dc5040c3d3ec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
@@ -17,19 +17,14 @@
 
 package org.apache.spark.sql.execution.metric
 
-import java.text.NumberFormat
-import java.util.{Arrays, Locale}
-
-import scala.concurrent.duration._
-
 import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}
 
-import org.apache.spark.{SparkContext, SparkException}
+import org.apache.spark.SparkContext
 import org.apache.spark.scheduler.AccumulableInfo
 import org.apache.spark.sql.connector.metric.CustomMetric
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates
-import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, Utils}
+import org.apache.spark.util.{AccumulatorContext, AccumulatorV2}
 import org.apache.spark.util.AccumulatorContext.internOption
 
 /**
@@ -72,7 +69,7 @@ class SQLMetric(
 
   // This is used to filter out metrics. Metrics with value equal to initValue should
   // be filtered out, since they are either invalid or safe to filter without changing
-  // the aggregation defined in [[SQLMetrics.stringValue]].
+  // the aggregation defined in [[MetricUtils.stringValue]].
   // Note that we don't use 0 here since we may want to collect 0 metrics for
   // calculating min, max, etc. See SPARK-11013.
   override def isZero: Boolean = _value == initValue
@@ -106,7 +103,7 @@ class SQLMetric(
       SQLMetrics.cachedSQLAccumIdentifier)
   }
 
-  // We should provide the raw value which can be -1, so that `SQLMetrics.stringValue` can correctly
+  // We should provide the raw value, which can be -1, so `MetricUtils.stringValue` can correctly
   // filter out the invalid -1 values.
   override def toInfoUpdate: AccumulableInfo = {
     AccumulableInfo(id, name, internOption(Some(_value)), None, true, true,
@@ -203,77 +200,6 @@ object SQLMetrics {
     acc
   }
 
-  private def toNumberFormat(value: Long): String = {
-    val numberFormat = NumberFormat.getNumberInstance(Locale.US)
-    numberFormat.format(value.toDouble / baseForAvgMetric)
-  }
-
-  def metricNeedsMax(metricsType: String): Boolean = {
-    metricsType != SUM_METRIC
-  }
-
-  private val METRICS_NAME_SUFFIX = "(min, med, max (stageId: taskId))"
-
-  /**
-   * A function that defines how we aggregate the final accumulator results among all tasks,
-   * and represent it in string for a SQL physical operator.
-   */
-  def stringValue(metricsType: String, values: Array[Long], maxMetrics: Array[Long]): String = {
-    // taskInfo = "(driver)" OR (stage ${stageId}.${attemptId}: task $taskId)
-    val taskInfo = if (maxMetrics.isEmpty) {
-      "(driver)"
-    } else {
-      s"(stage ${maxMetrics(1)}.${maxMetrics(2)}: task ${maxMetrics(3)})"
-    }
-    if (metricsType == SUM_METRIC) {
-      val numberFormat = NumberFormat.getIntegerInstance(Locale.US)
-      numberFormat.format(values.sum)
-    } else if (metricsType == AVERAGE_METRIC) {
-      val validValues = values.filter(_ > 0)
-      // When there are only 1 metrics value (or None), no need to display max/min/median. This is
-      // common for driver-side SQL metrics.
-      if (validValues.length <= 1) {
-        toNumberFormat(validValues.headOption.getOrElse(0))
-      } else {
-        val Seq(min, med, max) = {
-          Arrays.sort(validValues)
-          Seq(
-            toNumberFormat(validValues(0)),
-            toNumberFormat(validValues(validValues.length / 2)),
-            toNumberFormat(validValues(validValues.length - 1)))
-        }
-        s"$METRICS_NAME_SUFFIX:\n($min, $med, $max $taskInfo)"
-      }
-    } else {
-      val strFormat: Long => String = if (metricsType == SIZE_METRIC) {
-        Utils.bytesToString
-      } else if (metricsType == TIMING_METRIC) {
-        Utils.msDurationToString
-      } else if (metricsType == NS_TIMING_METRIC) {
-        duration => Utils.msDurationToString(duration.nanos.toMillis)
-      } else {
-        throw SparkException.internalError(s"unexpected metrics type: $metricsType")
-      }
-
-      val validValues = values.filter(_ >= 0)
-      // When there are only 1 metrics value (or None), no need to display max/min/median. This is
-      // common for driver-side SQL metrics.
-      if (validValues.length <= 1) {
-        strFormat(validValues.headOption.getOrElse(0))
-      } else {
-        val Seq(sum, min, med, max) = {
-          Arrays.sort(validValues)
-          Seq(
-            strFormat(validValues.sum),
-            strFormat(validValues(0)),
-            strFormat(validValues(validValues.length / 2)),
-            strFormat(validValues(validValues.length - 1)))
-        }
-        s"total $METRICS_NAME_SUFFIX\n$sum ($min, $med, $max $taskInfo)"
-      }
-    }
-  }
-
   def postDriverMetricsUpdatedByValue(
       sc: SparkContext,
       executionId: String,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
index 3c8c7edfeb069..f680860231f01 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.execution.metric._
 import org.apache.spark.sql.internal.StaticSQLConf._
 import org.apache.spark.status.{ElementTrackingStore, KVUtils, LiveEntity}
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{MetricUtils, Utils}
 import org.apache.spark.util.collection.OpenHashMap
 
 class SQLAppStatusListener(
@@ -235,7 +235,7 @@ class SQLAppStatusListener(
         }
       }.getOrElse(
         // Built-in SQLMetric
-        SQLMetrics.stringValue(m.metricType, _, _)
+        MetricUtils.stringValue(m.metricType, _, _)
       )
       (m.accumulatorId, metricAggMethod)
     }.toMap
@@ -554,7 +554,7 @@ private class LiveStageMetrics(
   /**
    * Task metrics values for the stage. Maps the metric ID to the metric values for each
    * index. For each metric ID, there will be the same number of values as the number
-   * of indices. This relies on `SQLMetrics.stringValue` treating 0 as a neutral value,
+   * of indices. This relies on `MetricUtils.stringValue` treating 0 as a neutral value,
    * independent of the actual metric type.
    */
   private val taskMetrics = new ConcurrentHashMap[Long, Array[Long]]()
@@ -601,7 +601,7 @@ private class LiveStageMetrics(
     val metricValues = taskMetrics.computeIfAbsent(acc.id, _ => new Array(numTasks))
     metricValues(taskIdx) = value
 
-    if (SQLMetrics.metricNeedsMax(accumIdsToMetricType(acc.id))) {
+    if (MetricUtils.metricNeedsMax(accumIdsToMetricType(acc.id))) {
       val maxMetricsTaskId = metricsIdToMaxTaskValue.computeIfAbsent(acc.id,
         _ => Array(value, taskId))
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala
index e63ff019a2b6c..256d9d156c183 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala
@@ -57,7 +57,7 @@ import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.status.{AppStatusStore, ElementTrackingStore}
-import org.apache.spark.util.{AccumulatorMetadata, JsonProtocol, LongAccumulator, SerializableConfiguration, Utils}
+import org.apache.spark.util.{AccumulatorMetadata, JsonProtocol, LongAccumulator, MetricUtils, SerializableConfiguration, Utils}
 import org.apache.spark.util.kvstore.InMemoryStore
 
 
@@ -597,9 +597,9 @@ abstract class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTes
     val metrics = statusStore.executionMetrics(execId)
     val driverMetric = physicalPlan.metrics("dummy")
     val driverMetric2 = physicalPlan.metrics("dummy2")
-    val expectedValue = SQLMetrics.stringValue(driverMetric.metricType,
+    val expectedValue = MetricUtils.stringValue(driverMetric.metricType,
       Array(expectedAccumValue), Array.empty[Long])
-    val expectedValue2 = SQLMetrics.stringValue(driverMetric2.metricType,
+    val expectedValue2 = MetricUtils.stringValue(driverMetric2.metricType,
       Array(expectedAccumValue2), Array.empty[Long])
 
     assert(metrics.contains(driverMetric.id))
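As a quick illustration of what this move enables, here is a minimal sketch of calling the relocated helper directly from core code, outside `sql/core`, the same way `PythonCustomMetric.aggregateTaskMetrics` now does in this patch. The object name, sample values, and stage/task IDs below are illustrative only and not part of the change; it assumes a `spark-core` build containing the new `MetricUtils` on the classpath.

```scala
import org.apache.spark.util.MetricUtils

object MetricUtilsSketch {
  def main(args: Array[String]): Unit = {
    // Per-task byte counts for a "size" metric. maxMetrics carries
    // (maxValue, stageId, stageAttemptId, taskId) for the task that produced the max.
    val taskValues = Array(1024L, 4096L, 2048L)
    val maxMetrics = Array(4096L, 0L, 0L, 2L)

    // "size" metrics need the max-producing task, so metricNeedsMax is true for them.
    assert(MetricUtils.metricNeedsMax(MetricUtils.SIZE_METRIC))

    // Renders "total (min, med, max (stageId: taskId))" followed by the formatted sizes.
    println(MetricUtils.stringValue(MetricUtils.SIZE_METRIC, taskValues, maxMetrics))

    // With an empty maxMetrics array the result is labelled "(driver)", as used for
    // driver-side SQL metrics.
    println(MetricUtils.stringValue(MetricUtils.SIZE_METRIC, taskValues, Array.empty[Long]))
  }
}
```

The behaviour shown is unchanged from the old `SQLMetrics.stringValue`; only the location of the aggregation logic moves, so callers in core and connector modules can reuse it without depending on `sql/core`.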