diff --git a/backends-clickhouse/src/main/java/org/apache/gluten/execution/CHNativeCacheManager.java b/backends-clickhouse/src/main/java/org/apache/gluten/execution/CHNativeCacheManager.java index f5f75dc1dca6d..0f2d669fc393c 100644 --- a/backends-clickhouse/src/main/java/org/apache/gluten/execution/CHNativeCacheManager.java +++ b/backends-clickhouse/src/main/java/org/apache/gluten/execution/CHNativeCacheManager.java @@ -19,9 +19,15 @@ import java.util.Set; public class CHNativeCacheManager { - public static void cacheParts(String table, Set columns, boolean async) { - nativeCacheParts(table, String.join(",", columns), async); + public static String cacheParts(String table, Set columns, boolean async) { + return nativeCacheParts(table, String.join(",", columns), async); } - private static native void nativeCacheParts(String table, String columns, boolean async); + private static native String nativeCacheParts(String table, String columns, boolean async); + + public static CacheResult getCacheStatus(String jobId) { + return nativeGetCacheStatus(jobId); + } + + private static native CacheResult nativeGetCacheStatus(String jobId); } diff --git a/backends-clickhouse/src/main/java/org/apache/gluten/execution/CacheResult.java b/backends-clickhouse/src/main/java/org/apache/gluten/execution/CacheResult.java new file mode 100644 index 0000000000000..0fa69e0d0b1fd --- /dev/null +++ b/backends-clickhouse/src/main/java/org/apache/gluten/execution/CacheResult.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.execution; + +public class CacheResult { + public enum Status { + RUNNING(0), + SUCCESS(1), + ERROR(2); + + private final int value; + + Status(int value) { + this.value = value; + } + + public int getValue() { + return value; + } + + public static Status fromInt(int value) { + for (Status myEnum : Status.values()) { + if (myEnum.getValue() == value) { + return myEnum; + } + } + throw new IllegalArgumentException("No enum constant for value: " + value); + } + } + + private final Status status; + private final String message; + + public CacheResult(int status, String message) { + this.status = Status.fromInt(status); + this.message = message; + } + + public Status getStatus() { + return status; + } + + public String getMessage() { + return message; + } +} diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/rpc/GlutenExecutorEndpoint.scala b/backends-clickhouse/src/main/scala/org/apache/spark/rpc/GlutenExecutorEndpoint.scala index 4d90ab6533ba7..a0d727ce84dff 100644 --- a/backends-clickhouse/src/main/scala/org/apache/spark/rpc/GlutenExecutorEndpoint.scala +++ b/backends-clickhouse/src/main/scala/org/apache/spark/rpc/GlutenExecutorEndpoint.scala @@ -65,7 +65,7 @@ class GlutenExecutorEndpoint(val executorId: String, val conf: SparkConf) resource_id => CHBroadcastBuildSideCache.invalidateBroadcastHashtable(resource_id)) } case GlutenMergeTreeCacheLoad(mergeTreeTable, columns) => - CHNativeCacheManager.cacheParts(mergeTreeTable, columns, true) + CHNativeCacheManager.cacheParts(mergeTreeTable, columns, false) case e => logError(s"Received unexpected message. $e") @@ -74,12 +74,16 @@ class GlutenExecutorEndpoint(val executorId: String, val conf: SparkConf) override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case GlutenMergeTreeCacheLoad(mergeTreeTable, columns) => try { - CHNativeCacheManager.cacheParts(mergeTreeTable, columns, false) - context.reply(CacheLoadResult(true)) + val jobId = CHNativeCacheManager.cacheParts(mergeTreeTable, columns, false) + context.reply(CacheJobInfo(status = true, jobId)) } catch { case _: Exception => - context.reply(CacheLoadResult(false, s"executor: $executorId cache data failed.")) + context.reply( + CacheJobInfo(status = false, "", s"executor: $executorId cache data failed.")) } + case GlutenMergeTreeCacheLoadStatus(jobId) => + val status = CHNativeCacheManager.getCacheStatus(jobId) + context.reply(status) case e => logError(s"Received unexpected message. $e") } diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/rpc/GlutenRpcMessages.scala b/backends-clickhouse/src/main/scala/org/apache/spark/rpc/GlutenRpcMessages.scala index d675d705f10a2..800b15b9949b0 100644 --- a/backends-clickhouse/src/main/scala/org/apache/spark/rpc/GlutenRpcMessages.scala +++ b/backends-clickhouse/src/main/scala/org/apache/spark/rpc/GlutenRpcMessages.scala @@ -35,8 +35,12 @@ object GlutenRpcMessages { case class GlutenCleanExecutionResource(executionId: String, broadcastHashIds: util.Set[String]) extends GlutenRpcMessage + // for mergetree cache case class GlutenMergeTreeCacheLoad(mergeTreeTable: String, columns: util.Set[String]) extends GlutenRpcMessage - case class CacheLoadResult(success: Boolean, reason: String = "") extends GlutenRpcMessage + case class GlutenMergeTreeCacheLoadStatus(jobId: String) + + case class CacheJobInfo(status: Boolean, jobId: String, reason: String = "") + extends GlutenRpcMessage } diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/commands/GlutenCHCacheDataCommand.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/commands/GlutenCHCacheDataCommand.scala index 1e6b024063b6d..ad504ab22b224 100644 --- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/commands/GlutenCHCacheDataCommand.scala +++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/commands/GlutenCHCacheDataCommand.scala @@ -17,18 +17,20 @@ package org.apache.spark.sql.execution.commands import org.apache.gluten.exception.GlutenException +import org.apache.gluten.execution.CacheResult +import org.apache.gluten.execution.CacheResult.Status import org.apache.gluten.expression.ConverterUtils import org.apache.gluten.substrait.rel.ExtensionTableBuilder import org.apache.spark.affinity.CHAffinity import org.apache.spark.rpc.GlutenDriverEndpoint -import org.apache.spark.rpc.GlutenRpcMessages.{CacheLoadResult, GlutenMergeTreeCacheLoad} +import org.apache.spark.rpc.GlutenRpcMessages.{CacheJobInfo, GlutenMergeTreeCacheLoad, GlutenMergeTreeCacheLoadStatus} import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, GreaterThanOrEqual, IsNotNull, Literal} import org.apache.spark.sql.delta._ import org.apache.spark.sql.execution.command.LeafRunnableCommand -import org.apache.spark.sql.execution.commands.GlutenCHCacheDataCommand.toExecutorId +import org.apache.spark.sql.execution.commands.GlutenCHCacheDataCommand.{toExecutorId, waitAllJobFinish} import org.apache.spark.sql.execution.datasources.v2.clickhouse.metadata.AddMergeTreeParts import org.apache.spark.sql.types.{BooleanType, StringType} import org.apache.spark.util.ThreadUtils @@ -208,72 +210,68 @@ case class GlutenCHCacheDataCommand( val tableMessage = executorIdsToParts.get(GlutenCHCacheDataCommand.ALL_EXECUTORS).get if (asynExecute) { GlutenDriverEndpoint.executorDataMap.forEach( - (executorId, executor) => { + (_, executor) => { executor.executorEndpointRef.send( GlutenMergeTreeCacheLoad(tableMessage, selectedColumns.toSet.asJava)) }) Seq(Row(true, "")) } else { - val futureList = ArrayBuffer[Future[CacheLoadResult]]() - val resultList = ArrayBuffer[CacheLoadResult]() + val futureList = ArrayBuffer[(String, Future[CacheJobInfo])]() + val resultList = ArrayBuffer[(String, CacheJobInfo)]() GlutenDriverEndpoint.executorDataMap.forEach( (executorId, executor) => { futureList.append( - executor.executorEndpointRef.ask[CacheLoadResult]( - GlutenMergeTreeCacheLoad(tableMessage, selectedColumns.toSet.asJava) - )) + ( + executorId, + executor.executorEndpointRef.ask[CacheJobInfo]( + GlutenMergeTreeCacheLoad(tableMessage, selectedColumns.toSet.asJava) + ))) }) futureList.foreach( f => { - resultList.append(ThreadUtils.awaitResult(f, Duration.Inf)) + resultList.append((f._1, ThreadUtils.awaitResult(f._2, Duration.Inf))) }) - if (resultList.exists(!_.success)) { - Seq(Row(false, resultList.filter(!_.success).map(_.reason).mkString(";"))) - } else { - Seq(Row(true, "")) - } + + val res = waitAllJobFinish(resultList) + Seq(Row(res._1, res._2)) } } else { + def checkExecutorId(executorId: String): Unit = { + if (!GlutenDriverEndpoint.executorDataMap.containsKey(toExecutorId(executorId))) { + throw new GlutenException( + s"executor $executorId not found," + + s" all executors are ${GlutenDriverEndpoint.executorDataMap.toString}") + } + } if (asynExecute) { executorIdsToParts.foreach( value => { + checkExecutorId(value._1) val executorData = GlutenDriverEndpoint.executorDataMap.get(toExecutorId(value._1)) - if (executorData != null) { - executorData.executorEndpointRef.send( - GlutenMergeTreeCacheLoad(value._2, selectedColumns.toSet.asJava)) - } else { - throw new GlutenException( - s"executor ${value._1} not found," + - s" all executors are ${GlutenDriverEndpoint.executorDataMap.toString}") - } + executorData.executorEndpointRef.send( + GlutenMergeTreeCacheLoad(value._2, selectedColumns.toSet.asJava)) }) Seq(Row(true, "")) } else { - val futureList = ArrayBuffer[Future[CacheLoadResult]]() - val resultList = ArrayBuffer[CacheLoadResult]() + val futureList = ArrayBuffer[(String, Future[CacheJobInfo])]() + val resultList = ArrayBuffer[(String, CacheJobInfo)]() executorIdsToParts.foreach( value => { + checkExecutorId(value._1) val executorData = GlutenDriverEndpoint.executorDataMap.get(toExecutorId(value._1)) - if (executorData != null) { - futureList.append( - executorData.executorEndpointRef.ask[CacheLoadResult]( + futureList.append( + ( + value._1, + executorData.executorEndpointRef.ask[CacheJobInfo]( GlutenMergeTreeCacheLoad(value._2, selectedColumns.toSet.asJava) - )) - } else { - throw new GlutenException( - s"executor ${value._1} not found," + - s" all executors are ${GlutenDriverEndpoint.executorDataMap.toString}") - } + ))) }) futureList.foreach( f => { - resultList.append(ThreadUtils.awaitResult(f, Duration.Inf)) + resultList.append((f._1, ThreadUtils.awaitResult(f._2, Duration.Inf))) }) - if (resultList.exists(!_.success)) { - Seq(Row(false, resultList.filter(!_.success).map(_.reason).mkString(";"))) - } else { - Seq(Row(true, "")) - } + val res = waitAllJobFinish(resultList) + Seq(Row(res._1, res._2)) } } } @@ -284,4 +282,46 @@ object GlutenCHCacheDataCommand { private def toExecutorId(executorId: String): String = executorId.split("_").last + + def waitAllJobFinish(jobs: ArrayBuffer[(String, CacheJobInfo)]): (Boolean, String) = { + var status = true + val messages = ArrayBuffer[String]() + jobs.foreach( + job => { + if (!job._2.status) { + messages.append(job._2.reason) + status = false + } + }) + + jobs.foreach( + job => { + if (status) { + var complete = false + while (!complete) { + Thread.sleep(5000) + val future_result = GlutenDriverEndpoint.executorDataMap + .get(toExecutorId(job._1)) + .executorEndpointRef + .ask[CacheResult](GlutenMergeTreeCacheLoadStatus(job._2.jobId)) + val result = ThreadUtils.awaitResult(future_result, Duration.Inf) + result.getStatus match { + case Status.ERROR => + status = false + messages.append( + s"executor : {}, failed with message: {}", + job._1, + result.getMessage) + complete = true + case Status.SUCCESS => + complete = true + case _ => + // still running + } + } + } + }) + (status, messages.mkString(";")) + } + } diff --git a/cpp-ch/local-engine/Common/CHUtil.cpp b/cpp-ch/local-engine/Common/CHUtil.cpp index 35b4f0c97806f..e2f9477a6d64f 100644 --- a/cpp-ch/local-engine/Common/CHUtil.cpp +++ b/cpp-ch/local-engine/Common/CHUtil.cpp @@ -976,6 +976,7 @@ void BackendInitializerUtil::init(const std::string_view plan) // Init the table metadata cache map StorageMergeTreeFactory::init_cache_map(); + JobScheduler::initialize(SerializedPlanParser::global_context); CacheManager::initialize(SerializedPlanParser::global_context); std::call_once( diff --git a/cpp-ch/local-engine/Common/ConcurrentMap.h b/cpp-ch/local-engine/Common/ConcurrentMap.h index 1719d9b255eaa..2db35102215ae 100644 --- a/cpp-ch/local-engine/Common/ConcurrentMap.h +++ b/cpp-ch/local-engine/Common/ConcurrentMap.h @@ -16,7 +16,7 @@ */ #pragma once -#include +#include #include namespace local_engine diff --git a/cpp-ch/local-engine/Common/GlutenConfig.h b/cpp-ch/local-engine/Common/GlutenConfig.h index abb7295adc0d0..d5db56d0ea4d7 100644 --- a/cpp-ch/local-engine/Common/GlutenConfig.h +++ b/cpp-ch/local-engine/Common/GlutenConfig.h @@ -162,5 +162,19 @@ struct MergeTreeConfig return config; } }; + +struct GlutenJobSchedulerConfig +{ + inline static const String JOB_SCHEDULER_MAX_THREADS = "job_scheduler_max_threads"; + + size_t job_scheduler_max_threads = 10; + + static GlutenJobSchedulerConfig loadFromContext(DB::ContextPtr context) + { + GlutenJobSchedulerConfig config; + config.job_scheduler_max_threads = context->getConfigRef().getUInt64(JOB_SCHEDULER_MAX_THREADS, 10); + return config; + } +}; } diff --git a/cpp-ch/local-engine/Storages/Cache/CacheManager.cpp b/cpp-ch/local-engine/Storages/Cache/CacheManager.cpp index d2c7b06810db6..6753c1391cd86 100644 --- a/cpp-ch/local-engine/Storages/Cache/CacheManager.cpp +++ b/cpp-ch/local-engine/Storages/Cache/CacheManager.cpp @@ -26,12 +26,13 @@ #include #include #include -#include #include #include #include #include +#include + namespace DB { namespace ErrorCodes @@ -49,6 +50,16 @@ extern const Metric LocalThreadScheduled; namespace local_engine { + +jclass CacheManager::cache_result_class = nullptr; +jmethodID CacheManager::cache_result_constructor = nullptr; + +void CacheManager::initJNI(JNIEnv * env) +{ + cache_result_class = CreateGlobalClassReference(env, "Lorg/apache/gluten/execution/CacheResult;"); + cache_result_constructor = GetMethodID(env, cache_result_class, "", "(ILjava/lang/String;)V"); +} + CacheManager & CacheManager::instance() { static CacheManager cache_manager; @@ -59,13 +70,6 @@ void CacheManager::initialize(DB::ContextMutablePtr context_) { auto & manager = instance(); manager.context = context_; - manager.thread_pool = std::make_unique( - CurrentMetrics::LocalThread, - CurrentMetrics::LocalThreadActive, - CurrentMetrics::LocalThreadScheduled, - manager.context->getConfigRef().getInt("cache_sync_max_threads", 10), - 0, - 0); } struct CacheJobContext @@ -73,17 +77,16 @@ struct CacheJobContext MergeTreeTable table; }; -void CacheManager::cachePart(const MergeTreeTable& table, const MergeTreePart& part, const std::unordered_set & columns, std::shared_ptr latch) +Task CacheManager::cachePart(const MergeTreeTable& table, const MergeTreePart& part, const std::unordered_set & columns) { CacheJobContext job_context{table}; job_context.table.parts.clear(); job_context.table.parts.push_back(part); job_context.table.snapshot_id = ""; - auto job = [job_detail = job_context, context = this->context, read_columns = columns, latch = latch]() + Task task = [job_detail = job_context, context = this->context, read_columns = columns]() { try { - SCOPE_EXIT({ if (latch) latch->count_down();}); auto storage = MergeTreeRelParser::parseStorage(job_detail.table, context, true); auto storage_snapshot = std::make_shared(*storage, storage->getInMemoryMetadataPtr()); NamesAndTypesList names_and_types_list; @@ -113,8 +116,7 @@ void CacheManager::cachePart(const MergeTreeTable& table, const MergeTreePart& p PullingPipelineExecutor executor(pipeline); while (true) { - Chunk chunk; - if (!executor.pull(chunk)) + if (Chunk chunk; !executor.pull(chunk)) break; } LOG_INFO(getLogger("CacheManager"), "Load cache of table {}.{} part {} success.", job_detail.table.database, job_detail.table.table, job_detail.table.parts.front().name); @@ -122,22 +124,58 @@ void CacheManager::cachePart(const MergeTreeTable& table, const MergeTreePart& p catch (std::exception& e) { LOG_ERROR(getLogger("CacheManager"), "Load cache of table {}.{} part {} failed.\n {}", job_detail.table.database, job_detail.table.table, job_detail.table.parts.front().name, e.what()); + std::rethrow_exception(std::current_exception()); } }; LOG_INFO(getLogger("CacheManager"), "Loading cache of table {}.{} part {}", job_context.table.database, job_context.table.table, job_context.table.parts.front().name); - thread_pool->scheduleOrThrowOnError(std::move(job)); + return std::move(task); } -void CacheManager::cacheParts(const String& table_def, const std::unordered_set& columns, bool async) +JobId CacheManager::cacheParts(const String& table_def, const std::unordered_set& columns, const bool async) { auto table = parseMergeTreeTableString(table_def); - std::shared_ptr latch = nullptr; - if (!async) latch = std::make_shared(table.parts.size()); + JobId id = toString(UUIDHelpers::generateV4()); + Job job(id); for (const auto & part : table.parts) { - cachePart(table, part, columns, latch); + job.addTask(cachePart(table, part, columns)); + } + auto& scheduler = JobScheduler::instance(); + scheduler.scheduleJob(std::move(job), async); + return id; +} + +jobject CacheManager::getCacheStatus(JNIEnv * env, const String & jobId) +{ + auto& scheduler = JobScheduler::instance(); + auto job_status = scheduler.getJobSatus(jobId); + int status = 0; + String message; + if (job_status.has_value()) + { + switch (job_status.value().status) + { + case JobSatus::RUNNING: + status = 0; + break; + case JobSatus::FINISHED: + status = 1; + break; + case JobSatus::FAILED: + status = 2; + for (const auto & msg : job_status->messages) + { + message.append(msg); + message.append(";"); + } + break; + } + } + else + { + status = 2; + message = fmt::format("job {} not found", jobId); } - if (latch) - latch->wait(); + return env->NewObject(cache_result_class, cache_result_constructor, status, charTojstring(env, message.c_str())); } } \ No newline at end of file diff --git a/cpp-ch/local-engine/Storages/Cache/CacheManager.h b/cpp-ch/local-engine/Storages/Cache/CacheManager.h index a303b7b7fc63e..650c70e76ea05 100644 --- a/cpp-ch/local-engine/Storages/Cache/CacheManager.h +++ b/cpp-ch/local-engine/Storages/Cache/CacheManager.h @@ -16,29 +16,32 @@ */ #pragma once #include -#include - +#include +#include namespace local_engine { struct MergeTreePart; struct MergeTreeTable; + + + /*** * Manage the cache of the MergeTree, mainly including meta.bin, data.bin, metadata.gluten */ class CacheManager { public: + static jclass cache_result_class; + static jmethodID cache_result_constructor; + static void initJNI(JNIEnv* env); + static CacheManager & instance(); static void initialize(DB::ContextMutablePtr context); - void cachePart(const MergeTreeTable& table, const MergeTreePart& part, const std::unordered_set& columns, std::shared_ptr latch = nullptr); - void cacheParts(const String& table_def, const std::unordered_set& columns, bool async = true); + Task cachePart(const MergeTreeTable& table, const MergeTreePart& part, const std::unordered_set& columns); + JobId cacheParts(const String& table_def, const std::unordered_set& columns, bool async = true); + static jobject getCacheStatus(JNIEnv * env, const String& jobId); private: CacheManager() = default; - - std::unique_ptr thread_pool; DB::ContextMutablePtr context; - std::unordered_map policy_to_disk; - std::unordered_map disk_to_metadisk; - std::unordered_map policy_to_cache; }; } \ No newline at end of file diff --git a/cpp-ch/local-engine/Storages/Cache/JobScheduler.cpp b/cpp-ch/local-engine/Storages/Cache/JobScheduler.cpp new file mode 100644 index 0000000000000..2e5b33c54b57c --- /dev/null +++ b/cpp-ch/local-engine/Storages/Cache/JobScheduler.cpp @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +#include "JobScheduler.h" + +#include +#include +#include +#include + +namespace DB +{ +namespace ErrorCodes +{ +extern const int BAD_ARGUMENTS; +} +} + +namespace CurrentMetrics +{ +extern const Metric LocalThread; +extern const Metric LocalThreadActive; +extern const Metric LocalThreadScheduled; +} + +namespace local_engine +{ +std::shared_ptr global_job_scheduler = nullptr; + +void JobScheduler::initialize(DB::ContextPtr context) +{ + auto config = GlutenJobSchedulerConfig::loadFromContext(context); + instance().thread_pool = std::make_unique( + CurrentMetrics::LocalThread, + CurrentMetrics::LocalThreadActive, + CurrentMetrics::LocalThreadScheduled, + config.job_scheduler_max_threads, + 0, + 0); + +} + +JobId JobScheduler::scheduleJob(Job && job, bool auto_remove) +{ + cleanFinishedJobs(); + if (job_details.contains(job.id)) + { + throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "job {} exists.", job.id); + } + size_t task_num = job.tasks.size(); + auto job_id = job.id; + std::vector task_results; + task_results.reserve(task_num); + JobContext job_context = {std::move(job), std::make_unique(task_num), std::move(task_results)}; + { + std::lock_guard lock(job_details_mutex); + job_details.emplace(job_id, std::move(job_context)); + } + LOG_INFO(logger, "schedule job {}", job_id); + + auto & job_detail = job_details.at(job_id); + + for (auto & task : job_detail.job.tasks) + { + job_detail.task_results.emplace_back(TaskResult()); + auto & task_result = job_detail.task_results.back(); + thread_pool->scheduleOrThrow( + [&, clean_job = auto_remove]() + { + SCOPE_EXIT({ + job_detail.remain_tasks->fetch_sub(1, std::memory_order::acquire); + if (job_detail.isFinished()) + { + addFinishedJob(job_detail.job.id); + } + }); + try + { + task(); + task_result.status = TaskResult::Status::SUCCESS; + } + catch (std::exception & e) + { + task_result.status = TaskResult::Status::FAILED; + task_result.message = e.what(); + } + }); + } + return job_id; +} + +std::optional JobScheduler::getJobSatus(const JobId & job_id) +{ + if (!job_details.contains(job_id)) + { + return std::nullopt; + } + std::optional res; + auto & job_context = job_details.at(job_id); + if (job_context.isFinished()) + { + std::vector messages; + for (auto & task_result : job_context.task_results) + { + if (task_result.status == TaskResult::Status::FAILED) + { + messages.push_back(task_result.message); + } + } + if (messages.empty()) + res = JobSatus::success(); + else + res= JobSatus::failed(messages); + } + else + res = JobSatus::running(); + return res; +} + +void JobScheduler::cleanupJob(const JobId & job_id) +{ + LOG_INFO(logger, "clean job {}", job_id); + job_details.erase(job_id); +} + +void JobScheduler::addFinishedJob(const JobId & job_id) +{ + std::lock_guard lock(finished_job_mutex); + auto job = std::make_pair(job_id, Stopwatch()); + finished_job.emplace_back(job); +} + +void JobScheduler::cleanFinishedJobs() +{ + std::lock_guard lock(finished_job_mutex); + for (auto it = finished_job.begin(); it != finished_job.end();) + { + // clean finished job after 5 minutes + if (it->second.elapsedSeconds() > 60 * 5) + { + cleanupJob(it->first); + it = finished_job.erase(it); + } + else + ++it; + } +} +} \ No newline at end of file diff --git a/cpp-ch/local-engine/Storages/Cache/JobScheduler.h b/cpp-ch/local-engine/Storages/Cache/JobScheduler.h new file mode 100644 index 0000000000000..69333e9f641e3 --- /dev/null +++ b/cpp-ch/local-engine/Storages/Cache/JobScheduler.h @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include +#include +#include +#include + +namespace local_engine +{ + +using JobId = String; +using Task = std::function; + +class Job +{ + friend class JobScheduler; +public: + explicit Job(const JobId& id) + : id(id) + { + } + + void addTask(Task&& task) + { + tasks.emplace_back(task); + } + +private: + JobId id; + std::vector tasks; +}; + + + +struct JobSatus +{ + enum Status + { + RUNNING, + FINISHED, + FAILED + }; + Status status; + std::vector messages; + + static JobSatus success() + { + return JobSatus{FINISHED}; + } + + static JobSatus running() + { + return JobSatus{RUNNING}; + } + + static JobSatus failed(const std::vector & messages) + { + return JobSatus{FAILED, messages}; + } +}; + +struct TaskResult +{ + enum Status + { + SUCCESS, + FAILED, + RUNNING + }; + Status status = RUNNING; + String message; +}; + +class JobContext +{ +public: + Job job; + std::unique_ptr remain_tasks = std::make_unique(); + std::vector task_results; + + bool isFinished() + { + return remain_tasks->load(std::memory_order::relaxed) == 0; + } +}; + +class JobScheduler +{ +public: + static JobScheduler& instance() + { + static JobScheduler global_job_scheduler; + return global_job_scheduler; + } + + static void initialize(DB::ContextPtr context); + + JobId scheduleJob(Job && job, bool auto_remove); + + std::optional getJobSatus(const JobId& job_id); + + void cleanupJob(const JobId& job_id); + + void addFinishedJob(const JobId& job_id); + + void cleanFinishedJobs(); +private: + JobScheduler() = default; + std::unique_ptr thread_pool; + std::unordered_map job_details; + std::mutex job_details_mutex; + + std::vector> finished_job; + std::mutex finished_job_mutex; + LoggerPtr logger = getLogger("JobScheduler"); +}; +} diff --git a/cpp-ch/local-engine/local_engine_jni.cpp b/cpp-ch/local-engine/local_engine_jni.cpp index 828556b4abf66..6fe9775fc7dca 100644 --- a/cpp-ch/local-engine/local_engine_jni.cpp +++ b/cpp-ch/local-engine/local_engine_jni.cpp @@ -163,6 +163,7 @@ JNIEXPORT jint JNI_OnLoad(JavaVM * vm, void * /*reserved*/) env, local_engine::SparkRowToCHColumn::spark_row_interator_class, "nextBatch", "()Ljava/nio/ByteBuffer;"); local_engine::BroadCastJoinBuilder::init(env); + local_engine::CacheManager::initJNI(env); local_engine::JNIUtils::vm = vm; return JNI_VERSION_1_8; @@ -1269,7 +1270,7 @@ JNIEXPORT void Java_org_apache_gluten_utils_TestExceptionUtils_generateNativeExc -JNIEXPORT void Java_org_apache_gluten_execution_CHNativeCacheManager_nativeCacheParts(JNIEnv * env, jobject, jstring table_, jstring columns_, jboolean async_) +JNIEXPORT jstring Java_org_apache_gluten_execution_CHNativeCacheManager_nativeCacheParts(JNIEnv * env, jobject, jstring table_, jstring columns_, jboolean async_) { LOCAL_ENGINE_JNI_METHOD_START auto table_def = jstring2string(env, table_); @@ -1280,10 +1281,17 @@ JNIEXPORT void Java_org_apache_gluten_execution_CHNativeCacheManager_nativeCache { column_set.insert(col); } - local_engine::CacheManager::instance().cacheParts(table_def, column_set, async_); - LOCAL_ENGINE_JNI_METHOD_END(env, ); + auto id = local_engine::CacheManager::instance().cacheParts(table_def, column_set, async_); + return local_engine::charTojstring(env, id.c_str()); + LOCAL_ENGINE_JNI_METHOD_END(env, nullptr); } +JNIEXPORT jobject Java_org_apache_gluten_execution_CHNativeCacheManager_nativeGetCacheStatus(JNIEnv * env, jobject, jstring id) +{ + LOCAL_ENGINE_JNI_METHOD_START + return local_engine::CacheManager::instance().getCacheStatus(env, jstring2string(env, id)); + LOCAL_ENGINE_JNI_METHOD_END(env, nullptr); +} #ifdef __cplusplus }