Skip to content

Commit

Permalink
remove async rpc call
Browse files Browse the repository at this point in the history
  • Loading branch information
liuneng1994 committed Aug 15, 2024
1 parent d57a35d commit e4c6177
Show file tree
Hide file tree
Showing 8 changed files with 89 additions and 104 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
import java.util.Set;

public class CHNativeCacheManager {
public static String cacheParts(String table, Set<String> columns, boolean async) {
return nativeCacheParts(table, String.join(",", columns), async);
public static String cacheParts(String table, Set<String> columns) {
return nativeCacheParts(table, String.join(",", columns));
}

private static native String nativeCacheParts(String table, String columns, boolean async);
private static native String nativeCacheParts(String table, String columns);

public static CacheResult getCacheStatus(String jobId) {
return nativeGetCacheStatus(jobId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,6 @@ class GlutenExecutorEndpoint(val executorId: String, val conf: SparkConf)
hashIds.forEach(
resource_id => CHBroadcastBuildSideCache.invalidateBroadcastHashtable(resource_id))
}
case GlutenMergeTreeCacheLoad(mergeTreeTable, columns) =>
CHNativeCacheManager.cacheParts(mergeTreeTable, columns, false)

case e =>
logError(s"Received unexpected message. $e")
Expand All @@ -74,7 +72,7 @@ class GlutenExecutorEndpoint(val executorId: String, val conf: SparkConf)
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case GlutenMergeTreeCacheLoad(mergeTreeTable, columns) =>
try {
val jobId = CHNativeCacheManager.cacheParts(mergeTreeTable, columns, false)
val jobId = CHNativeCacheManager.cacheParts(mergeTreeTable, columns)
context.reply(CacheJobInfo(status = true, jobId))
} catch {
case _: Exception =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, GreaterThanOrEqual, IsNotNull, Literal}
import org.apache.spark.sql.delta._
import org.apache.spark.sql.execution.command.LeafRunnableCommand
import org.apache.spark.sql.execution.commands.GlutenCHCacheDataCommand.{toExecutorId, waitAllJobFinish}
import org.apache.spark.sql.execution.commands.GlutenCHCacheDataCommand.{checkExecutorId, collectJobTriggerResult, toExecutorId, waitAllJobFinish, waitRpcResults}
import org.apache.spark.sql.execution.datasources.v2.clickhouse.metadata.AddMergeTreeParts
import org.apache.spark.sql.types.{BooleanType, StringType}
import org.apache.spark.util.ThreadUtils
Expand Down Expand Up @@ -102,14 +102,14 @@ case class GlutenCHCacheDataCommand(
selectedColuman.get
.filter(allColumns.contains(_))
.map(ConverterUtils.normalizeColName)
.toSeq
} else {
allColumns.map(ConverterUtils.normalizeColName)
}

val selectedAddFiles = if (tsfilter.isDefined) {
val allParts = DeltaAdapter.snapshotFilesForScan(snapshot, Seq.empty, Seq.empty, false)
allParts.files.filter(_.modificationTime >= tsfilter.get.toLong).toSeq
val allParts =
DeltaAdapter.snapshotFilesForScan(snapshot, Seq.empty, Seq.empty, keepNumRecords = false)
allParts.files.filter(_.modificationTime >= tsfilter.get.toLong)
} else if (partitionColumn.isDefined && partitionValue.isDefined) {
val partitionColumns = snapshot.metadata.partitionSchema.fieldNames
require(
Expand All @@ -128,10 +128,12 @@ case class GlutenCHCacheDataCommand(
snapshot,
Seq(partitionColumnAttr),
Seq(isNotNullExpr, greaterThanOrEqual),
false)
keepNumRecords = false)
.files
} else {
DeltaAdapter.snapshotFilesForScan(snapshot, Seq.empty, Seq.empty, false).files
DeltaAdapter
.snapshotFilesForScan(snapshot, Seq.empty, Seq.empty, keepNumRecords = false)
.files
}

val executorIdsToAddFiles =
Expand All @@ -153,17 +155,15 @@ case class GlutenCHCacheDataCommand(

if (locations.isEmpty) {
// non soft affinity
executorIdsToAddFiles
.get(GlutenCHCacheDataCommand.ALL_EXECUTORS)
.get
executorIdsToAddFiles(GlutenCHCacheDataCommand.ALL_EXECUTORS)
.append(mergeTreePart)
} else {
locations.foreach(
executor => {
if (!executorIdsToAddFiles.contains(executor)) {
executorIdsToAddFiles.put(executor, new ArrayBuffer[AddMergeTreeParts]())
}
executorIdsToAddFiles.get(executor).get.append(mergeTreePart)
executorIdsToAddFiles(executor).append(mergeTreePart)
})
}
})
Expand All @@ -174,7 +174,7 @@ case class GlutenCHCacheDataCommand(
val executorId = value._1
if (parts.nonEmpty) {
val onePart = parts(0)
val partNameList = parts.map(_.name).toSeq
val partNameList = parts.map(_.name)
// starts and lengths is useless for write
val partRanges = Seq.range(0L, partNameList.length).map(_ => long2Long(0L)).asJava

Expand Down Expand Up @@ -203,97 +203,54 @@ case class GlutenCHCacheDataCommand(
executorIdsToParts.put(executorId, extensionTableNode.getExtensionTableStr)
}
})

// send rpc call
val futureList = ArrayBuffer[(String, Future[CacheJobInfo])]()
if (executorIdsToParts.contains(GlutenCHCacheDataCommand.ALL_EXECUTORS)) {
// send all parts to all executors
val tableMessage = executorIdsToParts.get(GlutenCHCacheDataCommand.ALL_EXECUTORS).get
if (asynExecute) {
GlutenDriverEndpoint.executorDataMap.forEach(
(_, executor) => {
executor.executorEndpointRef.send(
GlutenMergeTreeCacheLoad(tableMessage, selectedColumns.toSet.asJava))
})
Seq(Row(true, ""))
} else {
val futureList = ArrayBuffer[(String, Future[CacheJobInfo])]()
val resultList = ArrayBuffer[(String, CacheJobInfo)]()
GlutenDriverEndpoint.executorDataMap.forEach(
(executorId, executor) => {
futureList.append(
(
executorId,
executor.executorEndpointRef.ask[CacheJobInfo](
GlutenMergeTreeCacheLoad(tableMessage, selectedColumns.toSet.asJava)
)))
})
futureList.foreach(
f => {
resultList.append((f._1, ThreadUtils.awaitResult(f._2, Duration.Inf)))
})

val res = waitAllJobFinish(resultList)
Seq(Row(res._1, res._2))
}
val tableMessage = executorIdsToParts(GlutenCHCacheDataCommand.ALL_EXECUTORS)
GlutenDriverEndpoint.executorDataMap.forEach(
(executorId, executor) => {
futureList.append(
(
executorId,
executor.executorEndpointRef.ask[CacheJobInfo](
GlutenMergeTreeCacheLoad(tableMessage, selectedColumns.toSet.asJava)
)))
})
} else {
def checkExecutorId(executorId: String): Unit = {
if (!GlutenDriverEndpoint.executorDataMap.containsKey(toExecutorId(executorId))) {
throw new GlutenException(
s"executor $executorId not found," +
s" all executors are ${GlutenDriverEndpoint.executorDataMap.toString}")
}
}
if (asynExecute) {
executorIdsToParts.foreach(
value => {
checkExecutorId(value._1)
val executorData = GlutenDriverEndpoint.executorDataMap.get(toExecutorId(value._1))
executorData.executorEndpointRef.send(
GlutenMergeTreeCacheLoad(value._2, selectedColumns.toSet.asJava))
})
Seq(Row(true, ""))
} else {
val futureList = ArrayBuffer[(String, Future[CacheJobInfo])]()
val resultList = ArrayBuffer[(String, CacheJobInfo)]()
executorIdsToParts.foreach(
value => {
checkExecutorId(value._1)
val executorData = GlutenDriverEndpoint.executorDataMap.get(toExecutorId(value._1))
futureList.append(
(
value._1,
executorData.executorEndpointRef.ask[CacheJobInfo](
GlutenMergeTreeCacheLoad(value._2, selectedColumns.toSet.asJava)
)))
})
futureList.foreach(
f => {
resultList.append((f._1, ThreadUtils.awaitResult(f._2, Duration.Inf)))
})
val res = waitAllJobFinish(resultList)
Seq(Row(res._1, res._2))
}
executorIdsToParts.foreach(
value => {
checkExecutorId(value._1)
val executorData = GlutenDriverEndpoint.executorDataMap.get(toExecutorId(value._1))
futureList.append(
(
value._1,
executorData.executorEndpointRef.ask[CacheJobInfo](
GlutenMergeTreeCacheLoad(value._2, selectedColumns.toSet.asJava)
)))
})
}
val resultList = waitRpcResults(futureList)
if (asynExecute) {
val res = collectJobTriggerResult(resultList)
Seq(Row(res._1, res._2.mkString(";")))
} else {
val res = waitAllJobFinish(resultList)
Seq(Row(res._1, res._2))
}
}

}

object GlutenCHCacheDataCommand {
val ALL_EXECUTORS = "allExecutors"
private val ALL_EXECUTORS = "allExecutors"

private def toExecutorId(executorId: String): String =
executorId.split("_").last

def waitAllJobFinish(jobs: ArrayBuffer[(String, CacheJobInfo)]): (Boolean, String) = {
var status = true
val messages = ArrayBuffer[String]()
jobs.foreach(
job => {
if (!job._2.status) {
messages.append(job._2.reason)
status = false
}
})

val res = collectJobTriggerResult(jobs)
var status = res._1
val messages = res._2
jobs.foreach(
job => {
if (status) {
Expand All @@ -309,7 +266,7 @@ object GlutenCHCacheDataCommand {
case Status.ERROR =>
status = false
messages.append(
s"executor : {}, failed with message: {}",
s"executor : {}, failed with message: {};",
job._1,
result.getMessage)
complete = true
Expand All @@ -324,4 +281,34 @@ object GlutenCHCacheDataCommand {
(status, messages.mkString(";"))
}

private def collectJobTriggerResult(jobs: ArrayBuffer[(String, CacheJobInfo)]) = {
var status = true
val messages = ArrayBuffer[String]()
jobs.foreach(
job => {
if (!job._2.status) {
messages.append(job._2.reason)
status = false
}
})
(status, messages)
}

private def waitRpcResults = (futureList: ArrayBuffer[(String, Future[CacheJobInfo])]) => {
val resultList = ArrayBuffer[(String, CacheJobInfo)]()
futureList.foreach(
f => {
resultList.append((f._1, ThreadUtils.awaitResult(f._2, Duration.Inf)))
})
resultList
}

private def checkExecutorId(executorId: String): Unit = {
if (!GlutenDriverEndpoint.executorDataMap.containsKey(toExecutorId(executorId))) {
throw new GlutenException(
s"executor $executorId not found," +
s" all executors are ${GlutenDriverEndpoint.executorDataMap.toString}")
}
}

}
4 changes: 2 additions & 2 deletions cpp-ch/local-engine/Storages/Cache/CacheManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ Task CacheManager::cachePart(const MergeTreeTable& table, const MergeTreePart& p
return std::move(task);
}

JobId CacheManager::cacheParts(const String& table_def, const std::unordered_set<String>& columns, const bool async)
JobId CacheManager::cacheParts(const String& table_def, const std::unordered_set<String>& columns)
{
auto table = parseMergeTreeTableString(table_def);
JobId id = toString(UUIDHelpers::generateV4());
Expand All @@ -141,7 +141,7 @@ JobId CacheManager::cacheParts(const String& table_def, const std::unordered_set
job.addTask(cachePart(table, part, columns));
}
auto& scheduler = JobScheduler::instance();
scheduler.scheduleJob(std::move(job), async);
scheduler.scheduleJob(std::move(job));
return id;
}

Expand Down
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Storages/Cache/CacheManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class CacheManager {
static CacheManager & instance();
static void initialize(DB::ContextMutablePtr context);
Task cachePart(const MergeTreeTable& table, const MergeTreePart& part, const std::unordered_set<String>& columns);
JobId cacheParts(const String& table_def, const std::unordered_set<String>& columns, bool async = true);
JobId cacheParts(const String& table_def, const std::unordered_set<String>& columns);
static jobject getCacheStatus(JNIEnv * env, const String& jobId);
private:
CacheManager() = default;
Expand Down
4 changes: 2 additions & 2 deletions cpp-ch/local-engine/Storages/Cache/JobScheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ void JobScheduler::initialize(DB::ContextPtr context)

}

JobId JobScheduler::scheduleJob(Job && job, bool auto_remove)
JobId JobScheduler::scheduleJob(Job&& job)
{
cleanFinishedJobs();
if (job_details.contains(job.id))
Expand All @@ -80,7 +80,7 @@ JobId JobScheduler::scheduleJob(Job && job, bool auto_remove)
job_detail.task_results.emplace_back(TaskResult());
auto & task_result = job_detail.task_results.back();
thread_pool->scheduleOrThrow(
[&, clean_job = auto_remove]()
[&]()
{
SCOPE_EXIT({
job_detail.remain_tasks->fetch_sub(1, std::memory_order::acquire);
Expand Down
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Storages/Cache/JobScheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ class JobScheduler

static void initialize(DB::ContextPtr context);

JobId scheduleJob(Job && job, bool auto_remove);
JobId scheduleJob(Job&& job);

std::optional<JobSatus> getJobSatus(const JobId& job_id);

Expand Down
4 changes: 2 additions & 2 deletions cpp-ch/local-engine/local_engine_jni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1270,7 +1270,7 @@ JNIEXPORT void Java_org_apache_gluten_utils_TestExceptionUtils_generateNativeExc



JNIEXPORT jstring Java_org_apache_gluten_execution_CHNativeCacheManager_nativeCacheParts(JNIEnv * env, jobject, jstring table_, jstring columns_, jboolean async_)
JNIEXPORT jstring Java_org_apache_gluten_execution_CHNativeCacheManager_nativeCacheParts(JNIEnv * env, jobject, jstring table_, jstring columns_)
{
LOCAL_ENGINE_JNI_METHOD_START
auto table_def = jstring2string(env, table_);
Expand All @@ -1281,7 +1281,7 @@ JNIEXPORT jstring Java_org_apache_gluten_execution_CHNativeCacheManager_nativeCa
{
column_set.insert(col);
}
auto id = local_engine::CacheManager::instance().cacheParts(table_def, column_set, async_);
auto id = local_engine::CacheManager::instance().cacheParts(table_def, column_set);
return local_engine::charTojstring(env, id.c_str());
LOCAL_ENGINE_JNI_METHOD_END(env, nullptr);
}
Expand Down

0 comments on commit e4c6177

Please sign in to comment.