Skip to content

Commit

Permalink
[CH] A simple job scheduler for merge tree cache sync load (apache#6842)
Browse files Browse the repository at this point in the history
What changes were proposed in this pull request?
When the cache is loaded synchronously, loading can take longer than the Spark RPC timeout. This change introduces an asynchronous job mechanism: triggering a load returns a job id, and synchronous loading is implemented by polling that job's status. Unified exception handling is also added.

How was this patch tested?
unit tests

(If this patch involves UI changes, please attach a screenshot; otherwise, remove this)
  • Loading branch information
liuneng1994 authored and shamirchen committed Oct 14, 2024
1 parent 9d805f3 commit 9de431a
Show file tree
Hide file tree
Showing 13 changed files with 581 additions and 122 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,15 @@
import java.util.Set;

public class CHNativeCacheManager {
public static void cacheParts(String table, Set<String> columns, boolean async) {
nativeCacheParts(table, String.join(",", columns), async);
/**
 * Hands a cache request for the given table to the native library.
 *
 * @param table table description, forwarded unchanged to the native call
 * @param columns column names; joined with "," before being passed to native code
 * @return the string returned by the native call (the executor endpoint treats it as a job id)
 */
public static String cacheParts(String table, Set<String> columns) {
return nativeCacheParts(table, String.join(",", columns));
}

private static native void nativeCacheParts(String table, String columns, boolean async);
private static native String nativeCacheParts(String table, String columns);

/**
 * Looks up the state of a cache job through the native library.
 *
 * @param jobId job identifier, forwarded unchanged to the native call
 * @return the {@link CacheResult} produced by the native call
 */
public static CacheResult getCacheStatus(String jobId) {
return nativeGetCacheStatus(jobId);
}

private static native CacheResult nativeGetCacheStatus(String jobId);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.execution;

/**
 * Result of a cache job: a {@link Status} plus a free-form message.
 *
 * <p>Instances are used as RPC reply payloads (the executor endpoint replies with the value
 * returned by {@code CHNativeCacheManager.getCacheStatus}), so the class implements
 * {@link java.io.Serializable}; Java serialization of a non-serializable object fails with
 * {@link java.io.NotSerializableException}.
 */
public class CacheResult implements java.io.Serializable {
  private static final long serialVersionUID = 1L;

  /** Job state, identified by a stable integer code. */
  public enum Status {
    RUNNING(0),
    SUCCESS(1),
    ERROR(2);

    private final int value;

    Status(int value) {
      this.value = value;
    }

    /** Returns the integer code of this status. */
    public int getValue() {
      return value;
    }

    /**
     * Maps an integer code back to its {@link Status}.
     *
     * @param value integer code (0, 1 or 2)
     * @return the matching status
     * @throws IllegalArgumentException if no status uses the given code
     */
    public static Status fromInt(int value) {
      for (Status candidate : Status.values()) {
        if (candidate.getValue() == value) {
          return candidate;
        }
      }
      throw new IllegalArgumentException("No enum constant for value: " + value);
    }
  }

  private final Status status;
  private final String message;

  /**
   * @param status integer status code, converted with {@link Status#fromInt(int)}
   * @param message message attached to the result; stored as given
   * @throws IllegalArgumentException if {@code status} is not a known code
   */
  public CacheResult(int status, String message) {
    this.status = Status.fromInt(status);
    this.message = message;
  }

  /** Returns the job status. */
  public Status getStatus() {
    return status;
  }

  /** Returns the message supplied at construction time. */
  public String getMessage() {
    return message;
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,6 @@ class GlutenExecutorEndpoint(val executorId: String, val conf: SparkConf)
hashIds.forEach(
resource_id => CHBroadcastBuildSideCache.invalidateBroadcastHashtable(resource_id))
}
case GlutenMergeTreeCacheLoad(mergeTreeTable, columns) =>
CHNativeCacheManager.cacheParts(mergeTreeTable, columns, true)

case e =>
logError(s"Received unexpected message. $e")
Expand All @@ -74,12 +72,16 @@ class GlutenExecutorEndpoint(val executorId: String, val conf: SparkConf)
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case GlutenMergeTreeCacheLoad(mergeTreeTable, columns) =>
try {
CHNativeCacheManager.cacheParts(mergeTreeTable, columns, false)
context.reply(CacheLoadResult(true))
val jobId = CHNativeCacheManager.cacheParts(mergeTreeTable, columns)
context.reply(CacheJobInfo(status = true, jobId))
} catch {
case _: Exception =>
context.reply(CacheLoadResult(false, s"executor: $executorId cache data failed."))
context.reply(
CacheJobInfo(status = false, "", s"executor: $executorId cache data failed."))
}
case GlutenMergeTreeCacheLoadStatus(jobId) =>
val status = CHNativeCacheManager.getCacheStatus(jobId)
context.reply(status)
case e =>
logError(s"Received unexpected message. $e")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,12 @@ object GlutenRpcMessages {
case class GlutenCleanExecutionResource(executionId: String, broadcastHashIds: util.Set[String])
extends GlutenRpcMessage

// for mergetree cache
/**
 * Request handled by the executor endpoint to cache MergeTree parts.
 *
 * @param mergeTreeTable table string passed unchanged to CHNativeCacheManager.cacheParts
 * @param columns columns passed to CHNativeCacheManager.cacheParts
 */
case class GlutenMergeTreeCacheLoad(mergeTreeTable: String, columns: util.Set[String])
extends GlutenRpcMessage

case class CacheLoadResult(success: Boolean, reason: String = "") extends GlutenRpcMessage
/**
 * Request for the state of a cache job; the executor endpoint answers it with the
 * CacheResult returned by CHNativeCacheManager.getCacheStatus.
 *
 * Extends GlutenRpcMessage so it is declared the same way as the other messages in this object.
 *
 * @param jobId identifier of the job to query
 */
case class GlutenMergeTreeCacheLoadStatus(jobId: String) extends GlutenRpcMessage

/**
 * Reply to GlutenMergeTreeCacheLoad.
 *
 * @param status true when the executor obtained a job id, false when triggering the job threw
 * @param jobId value returned by CHNativeCacheManager.cacheParts; "" in the failure reply
 * @param reason failure description; defaults to ""
 */
case class CacheJobInfo(status: Boolean, jobId: String, reason: String = "")
extends GlutenRpcMessage
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,20 @@
package org.apache.spark.sql.execution.commands

import org.apache.gluten.exception.GlutenException
import org.apache.gluten.execution.CacheResult
import org.apache.gluten.execution.CacheResult.Status
import org.apache.gluten.expression.ConverterUtils
import org.apache.gluten.substrait.rel.ExtensionTableBuilder

import org.apache.spark.affinity.CHAffinity
import org.apache.spark.rpc.GlutenDriverEndpoint
import org.apache.spark.rpc.GlutenRpcMessages.{CacheLoadResult, GlutenMergeTreeCacheLoad}
import org.apache.spark.rpc.GlutenRpcMessages.{CacheJobInfo, GlutenMergeTreeCacheLoad, GlutenMergeTreeCacheLoadStatus}
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, GreaterThanOrEqual, IsNotNull, Literal}
import org.apache.spark.sql.delta._
import org.apache.spark.sql.execution.command.LeafRunnableCommand
import org.apache.spark.sql.execution.commands.GlutenCHCacheDataCommand.toExecutorId
import org.apache.spark.sql.execution.commands.GlutenCHCacheDataCommand.{checkExecutorId, collectJobTriggerResult, toExecutorId, waitAllJobFinish, waitRpcResults}
import org.apache.spark.sql.execution.datasources.v2.clickhouse.metadata.AddMergeTreeParts
import org.apache.spark.sql.types.{BooleanType, StringType}
import org.apache.spark.util.ThreadUtils
Expand Down Expand Up @@ -106,7 +108,8 @@ case class GlutenCHCacheDataCommand(
}

val selectedAddFiles = if (tsfilter.isDefined) {
val allParts = DeltaAdapter.snapshotFilesForScan(snapshot, Seq.empty, Seq.empty, false)
val allParts =
DeltaAdapter.snapshotFilesForScan(snapshot, Seq.empty, Seq.empty, keepNumRecords = false)
allParts.files.filter(_.modificationTime >= tsfilter.get.toLong).toSeq
} else if (partitionColumn.isDefined && partitionValue.isDefined) {
val partitionColumns = snapshot.metadata.partitionSchema.fieldNames
Expand All @@ -126,10 +129,12 @@ case class GlutenCHCacheDataCommand(
snapshot,
Seq(partitionColumnAttr),
Seq(isNotNullExpr, greaterThanOrEqual),
false)
keepNumRecords = false)
.files
} else {
DeltaAdapter.snapshotFilesForScan(snapshot, Seq.empty, Seq.empty, false).files
DeltaAdapter
.snapshotFilesForScan(snapshot, Seq.empty, Seq.empty, keepNumRecords = false)
.files
}

val executorIdsToAddFiles =
Expand All @@ -151,17 +156,15 @@ case class GlutenCHCacheDataCommand(

if (locations.isEmpty) {
// non soft affinity
executorIdsToAddFiles
.get(GlutenCHCacheDataCommand.ALL_EXECUTORS)
.get
executorIdsToAddFiles(GlutenCHCacheDataCommand.ALL_EXECUTORS)
.append(mergeTreePart)
} else {
locations.foreach(
executor => {
if (!executorIdsToAddFiles.contains(executor)) {
executorIdsToAddFiles.put(executor, new ArrayBuffer[AddMergeTreeParts]())
}
executorIdsToAddFiles.get(executor).get.append(mergeTreePart)
executorIdsToAddFiles(executor).append(mergeTreePart)
})
}
})
Expand Down Expand Up @@ -201,87 +204,112 @@ case class GlutenCHCacheDataCommand(
executorIdsToParts.put(executorId, extensionTableNode.getExtensionTableStr)
}
})

// send rpc call
val futureList = ArrayBuffer[(String, Future[CacheJobInfo])]()
if (executorIdsToParts.contains(GlutenCHCacheDataCommand.ALL_EXECUTORS)) {
// send all parts to all executors
val tableMessage = executorIdsToParts.get(GlutenCHCacheDataCommand.ALL_EXECUTORS).get
if (asynExecute) {
GlutenDriverEndpoint.executorDataMap.forEach(
(executorId, executor) => {
executor.executorEndpointRef.send(
GlutenMergeTreeCacheLoad(tableMessage, selectedColumns.toSet.asJava))
})
Seq(Row(true, ""))
} else {
val futureList = ArrayBuffer[Future[CacheLoadResult]]()
val resultList = ArrayBuffer[CacheLoadResult]()
GlutenDriverEndpoint.executorDataMap.forEach(
(executorId, executor) => {
futureList.append(
executor.executorEndpointRef.ask[CacheLoadResult](
val tableMessage = executorIdsToParts(GlutenCHCacheDataCommand.ALL_EXECUTORS)
GlutenDriverEndpoint.executorDataMap.forEach(
(executorId, executor) => {
futureList.append(
(
executorId,
executor.executorEndpointRef.ask[CacheJobInfo](
GlutenMergeTreeCacheLoad(tableMessage, selectedColumns.toSet.asJava)
))
})
futureList.foreach(
f => {
resultList.append(ThreadUtils.awaitResult(f, Duration.Inf))
})
if (resultList.exists(!_.success)) {
Seq(Row(false, resultList.filter(!_.success).map(_.reason).mkString(";")))
} else {
Seq(Row(true, ""))
}
}
)))
})
} else {
if (asynExecute) {
executorIdsToParts.foreach(
value => {
val executorData = GlutenDriverEndpoint.executorDataMap.get(toExecutorId(value._1))
if (executorData != null) {
executorData.executorEndpointRef.send(
GlutenMergeTreeCacheLoad(value._2, selectedColumns.toSet.asJava))
} else {
throw new GlutenException(
s"executor ${value._1} not found," +
s" all executors are ${GlutenDriverEndpoint.executorDataMap.toString}")
}
})
Seq(Row(true, ""))
} else {
val futureList = ArrayBuffer[Future[CacheLoadResult]]()
val resultList = ArrayBuffer[CacheLoadResult]()
executorIdsToParts.foreach(
value => {
val executorData = GlutenDriverEndpoint.executorDataMap.get(toExecutorId(value._1))
if (executorData != null) {
futureList.append(
executorData.executorEndpointRef.ask[CacheLoadResult](
GlutenMergeTreeCacheLoad(value._2, selectedColumns.toSet.asJava)
))
} else {
throw new GlutenException(
s"executor ${value._1} not found," +
s" all executors are ${GlutenDriverEndpoint.executorDataMap.toString}")
}
})
futureList.foreach(
f => {
resultList.append(ThreadUtils.awaitResult(f, Duration.Inf))
})
if (resultList.exists(!_.success)) {
Seq(Row(false, resultList.filter(!_.success).map(_.reason).mkString(";")))
} else {
Seq(Row(true, ""))
}
}
executorIdsToParts.foreach(
value => {
checkExecutorId(value._1)
val executorData = GlutenDriverEndpoint.executorDataMap.get(toExecutorId(value._1))
futureList.append(
(
value._1,
executorData.executorEndpointRef.ask[CacheJobInfo](
GlutenMergeTreeCacheLoad(value._2, selectedColumns.toSet.asJava)
)))
})
}
val resultList = waitRpcResults(futureList)
if (asynExecute) {
val res = collectJobTriggerResult(resultList)
Seq(Row(res._1, res._2.mkString(";")))
} else {
val res = waitAllJobFinish(resultList)
Seq(Row(res._1, res._2))
}
}

}

object GlutenCHCacheDataCommand {
val ALL_EXECUTORS = "allExecutors"
private val ALL_EXECUTORS = "allExecutors"

/** Returns the last "_"-separated segment of `executorId` (the whole string if it has no "_"). */
private def toExecutorId(executorId: String): String = {
  val segments = executorId.split("_")
  segments.last
}

/**
 * Polls the executors until the triggered cache jobs have finished.
 *
 * Jobs are checked one at a time in list order, every 5 seconds. As soon as a failure has
 * been recorded (either while triggering or while polling), the remaining jobs are not polled.
 *
 * @param jobs pairs of (executor id, reply to the trigger request)
 * @return (true, "") when every job succeeded; otherwise false and the failure messages
 *         joined with ";"
 */
def waitAllJobFinish(jobs: ArrayBuffer[(String, CacheJobInfo)]): (Boolean, String) = {
  val res = collectJobTriggerResult(jobs)
  var status = res._1
  val messages = res._2
  jobs.foreach(
    job => {
      if (status) {
        var complete = false
        while (!complete) {
          Thread.sleep(5000)
          // The executor may have been removed since the job was triggered; the map then
          // returns null, which is reported as a failure instead of raising an NPE.
          Option(GlutenDriverEndpoint.executorDataMap.get(toExecutorId(job._1))) match {
            case None =>
              status = false
              messages.append(s"executor : ${job._1}, failed with message: executor not found;")
              complete = true
            case Some(executorData) =>
              val statusFuture = executorData.executorEndpointRef
                .ask[CacheResult](GlutenMergeTreeCacheLoadStatus(job._2.jobId))
              val result = ThreadUtils.awaitResult(statusFuture, Duration.Inf)
              result.getStatus match {
                case Status.ERROR =>
                  status = false
                  // Interpolate the values into one entry; passing them as extra arguments
                  // to append would add them as separate entries.
                  messages.append(
                    s"executor : ${job._1}, failed with message: ${result.getMessage};")
                  complete = true
                case Status.SUCCESS =>
                  complete = true
                case _ =>
                // still running
              }
          }
        }
      }
    })
  (status, messages.mkString(";"))
}

/**
 * Summarises the replies to the trigger requests.
 *
 * Returns a pair: whether every reply reported success, and a buffer holding the `reason`
 * of each failed reply in input order.
 */
private def collectJobTriggerResult(jobs: ArrayBuffer[(String, CacheJobInfo)]) = {
  val failureReasons = jobs.collect {
    case (_, info) if !info.status => info.reason
  }
  // One reason is collected per failed reply, so "no reasons" means "all succeeded".
  (failureReasons.isEmpty, failureReasons)
}

/**
 * Function that awaits each pending reply in list order (without a timeout of its own) and
 * returns the executor ids paired with the received replies.
 */
private def waitRpcResults = (futureList: ArrayBuffer[(String, Future[CacheJobInfo])]) => {
  futureList.map {
    case (executorId, pendingReply) =>
      (executorId, ThreadUtils.awaitResult(pendingReply, Duration.Inf))
  }
}

/**
 * Verifies that the executor named by `executorId` is present in
 * GlutenDriverEndpoint.executorDataMap.
 *
 * Throws GlutenException when the executor is not present.
 */
private def checkExecutorId(executorId: String): Unit = {
  val lookupKey = toExecutorId(executorId)
  val isKnown = GlutenDriverEndpoint.executorDataMap.containsKey(lookupKey)
  if (isKnown) {
    ()
  } else {
    throw new GlutenException(
      s"executor $executorId not found," +
        s" all executors are ${GlutenDriverEndpoint.executorDataMap.toString}")
  }
}

}
1 change: 1 addition & 0 deletions cpp-ch/local-engine/Common/CHUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -979,6 +979,7 @@ void BackendInitializerUtil::init(const std::string_view plan)
// Init the table metadata cache map
StorageMergeTreeFactory::init_cache_map();

JobScheduler::initialize(SerializedPlanParser::global_context);
CacheManager::initialize(SerializedPlanParser::global_context);

std::call_once(
Expand Down
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Common/ConcurrentMap.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
*/
#pragma once

#include <mutex>
#include <shared_mutex>
#include <unordered_map>

namespace local_engine
Expand Down
Loading

0 comments on commit 9de431a

Please sign in to comment.