From eee234e398c9418b6f5f93dcfb142e0e0948711f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E6=89=AC?= <654010905@qq.com> Date: Mon, 24 Jun 2024 13:51:42 +0800 Subject: [PATCH 01/30] [GLUTEN-6122] Fix crash when driver send shutdown command to executor #6130 What changes were proposed in this pull request? Fix crash when driver send shutdown command to executor (Fixes: #6122) --- cpp-ch/local-engine/Common/CHUtil.cpp | 7 ++- .../Parser/SerializedPlanParser.cpp | 56 ++++++++++++++++++- .../Parser/SerializedPlanParser.h | 14 ++++- cpp-ch/local-engine/local_engine_jni.cpp | 11 +++- 4 files changed, 81 insertions(+), 7 deletions(-) diff --git a/cpp-ch/local-engine/Common/CHUtil.cpp b/cpp-ch/local-engine/Common/CHUtil.cpp index 937beae99a6b..be66d8ecc509 100644 --- a/cpp-ch/local-engine/Common/CHUtil.cpp +++ b/cpp-ch/local-engine/Common/CHUtil.cpp @@ -750,7 +750,7 @@ void BackendInitializerUtil::initContexts(DB::Context::ConfigurationPtr config) size_t index_uncompressed_cache_size = config->getUInt64("index_uncompressed_cache_size", DEFAULT_INDEX_UNCOMPRESSED_CACHE_MAX_SIZE); double index_uncompressed_cache_size_ratio = config->getDouble("index_uncompressed_cache_size_ratio", DEFAULT_INDEX_UNCOMPRESSED_CACHE_SIZE_RATIO); global_context->setIndexUncompressedCache(index_uncompressed_cache_policy, index_uncompressed_cache_size, index_uncompressed_cache_size_ratio); - + String index_mark_cache_policy = config->getString("index_mark_cache_policy", DEFAULT_INDEX_MARK_CACHE_POLICY); size_t index_mark_cache_size = config->getUInt64("index_mark_cache_size", DEFAULT_INDEX_MARK_CACHE_MAX_SIZE); double index_mark_cache_size_ratio = config->getDouble("index_mark_cache_size_ratio", DEFAULT_INDEX_MARK_CACHE_SIZE_RATIO); @@ -919,7 +919,10 @@ void BackendInitializerUtil::updateConfig(const DB::ContextMutablePtr & context, void BackendFinalizerUtil::finalizeGlobally() { - // Make sure client caches release before ClientCacheRegistry + /// Make sure that all active LocalExecutor stop before spark executor shutdown, otherwise crash map happen. 
+ LocalExecutor::cancelAll(); + + /// Make sure client caches release before ClientCacheRegistry ReadBufferBuilderFactory::instance().clean(); StorageMergeTreeFactory::clear(); auto & global_context = SerializedPlanParser::global_context; diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp index f9ea783a2bbd..70db692c8009 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp @@ -2053,6 +2053,33 @@ void SerializedPlanParser::wrapNullable( SharedContextHolder SerializedPlanParser::shared_context; +std::unordered_map LocalExecutor::executors; +std::mutex LocalExecutor::executors_mutex; + +void LocalExecutor::cancelAll() +{ + std::lock_guard lock{executors_mutex}; + + for (auto & [handle, executor] : executors) + executor->asyncCancel(); + + for (auto & [handle, executor] : executors) + executor->waitCancelFinished(); +} + +void LocalExecutor::addExecutor(LocalExecutor * executor) +{ + std::lock_guard lock{executors_mutex}; + Int64 handle = reinterpret_cast(executor); + executors.emplace(handle, executor); +} + +void LocalExecutor::removeExecutor(Int64 handle) +{ + std::lock_guard lock{executors_mutex}; + executors.erase(handle); +} + LocalExecutor::~LocalExecutor() { if (context->getConfigRef().getBool("dump_pipeline", false)) @@ -2183,8 +2210,35 @@ Block * LocalExecutor::nextColumnar() void LocalExecutor::cancel() { - if (executor) + asyncCancel(); + waitCancelFinished(); +} + +void LocalExecutor::asyncCancel() +{ + if (executor && !is_cancelled) + { + LOG_INFO(&Poco::Logger::get("LocalExecutor"), "Cancel LocalExecutor {}", reinterpret_cast(this)); executor->cancel(); + } +} + +void LocalExecutor::waitCancelFinished() +{ + if (executor && !is_cancelled) + { + Stopwatch watch; + Chunk chunk; + while (executor->pull(chunk)) + ; + is_cancelled = true; + + LOG_INFO( + &Poco::Logger::get("LocalExecutor"), + "Finish cancel LocalExecutor {}, takes {} ms", + reinterpret_cast(this), + watch.elapsedMilliseconds()); + } } Block & LocalExecutor::getHeader() diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.h b/cpp-ch/local-engine/Parser/SerializedPlanParser.h index 8964f42d9d02..71cdca58a6ce 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.h +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.h @@ -412,7 +412,7 @@ class LocalExecutor : public BlockIterator Block * nextColumnar(); bool hasNext(); - /// Stop execution, used when task receives shutdown command or executor receives SIGTERM signal + /// Stop execution and wait for pipeline exit, used when task receives shutdown command or executor receives SIGTERM signal void cancel(); Block & getHeader(); @@ -420,9 +420,16 @@ class LocalExecutor : public BlockIterator void setMetric(RelMetricPtr metric_) { metric = metric_; } void setExtraPlanHolder(std::vector & extra_plan_holder_) { extra_plan_holder = std::move(extra_plan_holder_); } + static void cancelAll(); + static void addExecutor(LocalExecutor * executor); + static void removeExecutor(Int64 handle); + private: std::unique_ptr writeBlockToSparkRow(DB::Block & block); + void asyncCancel(); + void waitCancelFinished(); + /// Dump processor runtime information to log std::string dumpPipeline(); @@ -435,6 +442,11 @@ class LocalExecutor : public BlockIterator DB::QueryPlanPtr current_query_plan; RelMetricPtr metric; std::vector extra_plan_holder; + std::atomic is_cancelled{false}; + + /// Record all active LocalExecutor in current executor to 
cancel them when executor receives shutdown command from driver. + static std::unordered_map executors; + static std::mutex executors_mutex; }; diff --git a/cpp-ch/local-engine/local_engine_jni.cpp b/cpp-ch/local-engine/local_engine_jni.cpp index 256f373c28b5..bbc467879182 100644 --- a/cpp-ch/local-engine/local_engine_jni.cpp +++ b/cpp-ch/local-engine/local_engine_jni.cpp @@ -283,7 +283,8 @@ JNIEXPORT jlong Java_org_apache_gluten_vectorized_ExpressionEvaluatorJniWrapper_ plan_string.assign(reinterpret_cast(plan_address), plan_size); auto query_plan = parser.parse(plan_string); local_engine::LocalExecutor * executor = new local_engine::LocalExecutor(query_context); - LOG_INFO(&Poco::Logger::get("jni"), "Construct LocalExecutor {}", reinterpret_cast(executor)); + local_engine::LocalExecutor::addExecutor(executor); + LOG_INFO(&Poco::Logger::get("jni"), "Construct LocalExecutor {}", reinterpret_cast(executor)); executor->setMetric(parser.getMetric()); executor->setExtraPlanHolder(parser.extra_plan_holder); executor->execute(std::move(query_plan)); @@ -314,17 +315,19 @@ JNIEXPORT jlong Java_org_apache_gluten_vectorized_BatchIterator_nativeCHNext(JNI JNIEXPORT void Java_org_apache_gluten_vectorized_BatchIterator_nativeCancel(JNIEnv * env, jobject /*obj*/, jlong executor_address) { LOCAL_ENGINE_JNI_METHOD_START + local_engine::LocalExecutor::removeExecutor(executor_address); local_engine::LocalExecutor * executor = reinterpret_cast(executor_address); executor->cancel(); - LOG_INFO(&Poco::Logger::get("jni"), "Cancel LocalExecutor {}", reinterpret_cast(executor)); + LOG_INFO(&Poco::Logger::get("jni"), "Cancel LocalExecutor {}", reinterpret_cast(executor)); LOCAL_ENGINE_JNI_METHOD_END(env, ) } JNIEXPORT void Java_org_apache_gluten_vectorized_BatchIterator_nativeClose(JNIEnv * env, jobject /*obj*/, jlong executor_address) { LOCAL_ENGINE_JNI_METHOD_START + local_engine::LocalExecutor::removeExecutor(executor_address); local_engine::LocalExecutor * executor = reinterpret_cast(executor_address); - LOG_INFO(&Poco::Logger::get("jni"), "Finalize LocalExecutor {}", reinterpret_cast(executor)); + LOG_INFO(&Poco::Logger::get("jni"), "Finalize LocalExecutor {}", reinterpret_cast(executor)); delete executor; LOCAL_ENGINE_JNI_METHOD_END(env, ) } @@ -1332,6 +1335,7 @@ Java_org_apache_gluten_vectorized_SimpleExpressionEval_createNativeInstance(JNIE plan_string.assign(reinterpret_cast(plan_address), plan_size); auto query_plan = parser.parse(plan_string); local_engine::LocalExecutor * executor = new local_engine::LocalExecutor(context); + local_engine::LocalExecutor::addExecutor(executor); executor->execute(std::move(query_plan)); env->ReleaseByteArrayElements(plan, plan_address, JNI_ABORT); return reinterpret_cast(executor); @@ -1341,6 +1345,7 @@ Java_org_apache_gluten_vectorized_SimpleExpressionEval_createNativeInstance(JNIE JNIEXPORT void Java_org_apache_gluten_vectorized_SimpleExpressionEval_nativeClose(JNIEnv * env, jclass, jlong instance) { LOCAL_ENGINE_JNI_METHOD_START + local_engine::LocalExecutor::removeExecutor(instance); local_engine::LocalExecutor * executor = reinterpret_cast(instance); delete executor; LOCAL_ENGINE_JNI_METHOD_END(env, ) From e0fcfe586efc7efb3ec0c349d5ca8b2371d969d4 Mon Sep 17 00:00:00 2001 From: Shuai li Date: Mon, 24 Jun 2024 13:55:39 +0800 Subject: [PATCH 02/30] [GLUTEN-6178][CH] Add config to insert remote file system directly #6192 What changes were proposed in this pull request? (Please fill in changes proposed in this fix) (Fixes: #6178) How was this patch tested? 
Test by ut --- ...nClickHouseMergeTreeWriteOnHDFSSuite.scala | 44 ++++++++++++++++++- cpp-ch/local-engine/Common/CHUtil.cpp | 3 +- cpp-ch/local-engine/Common/CHUtil.h | 4 +- .../Disks/ObjectStorages/GlutenDiskHDFS.cpp | 10 ++++- .../Disks/ObjectStorages/GlutenDiskHDFS.h | 2 + .../Mergetree/SparkMergeTreeWriter.cpp | 40 +++++++++-------- .../Storages/Mergetree/SparkMergeTreeWriter.h | 2 + 7 files changed, 83 insertions(+), 22 deletions(-) diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseMergeTreeWriteOnHDFSSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseMergeTreeWriteOnHDFSSuite.scala index 572d0cd50a6e..99b212059966 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseMergeTreeWriteOnHDFSSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseMergeTreeWriteOnHDFSSuite.scala @@ -25,10 +25,12 @@ import org.apache.spark.sql.execution.datasources.v2.clickhouse.metadata.AddMerg import org.apache.commons.io.FileUtils import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.fs.{FileSystem, Path} import java.io.File +import scala.concurrent.duration.DurationInt + // Some sqls' line length exceeds 100 // scalastyle:off line.size.limit @@ -614,5 +616,45 @@ class GlutenClickHouseMergeTreeWriteOnHDFSSuite .count() assertResult(600572)(result) } + + test("test mergetree insert with optimize basic") { + val tableName = "lineitem_mergetree_insert_optimize_basic_hdfs" + val dataPath = s"$HDFS_URL/test/$tableName" + + withSQLConf( + "spark.databricks.delta.optimize.minFileSize" -> "200000000", + "spark.gluten.sql.columnar.backend.ch.runtime_settings.mergetree.merge_after_insert" -> "true", + "spark.gluten.sql.columnar.backend.ch.runtime_settings.mergetree.insert_without_local_storage" -> "true", + "spark.gluten.sql.columnar.backend.ch.runtime_settings.min_insert_block_size_rows" -> "10000" + ) { + spark.sql(s""" + |DROP TABLE IF EXISTS $tableName; + |""".stripMargin) + + spark.sql(s""" + |CREATE TABLE IF NOT EXISTS $tableName + |USING clickhouse + |LOCATION '$dataPath' + |TBLPROPERTIES (storage_policy='__hdfs_main') + | as select * from lineitem + |""".stripMargin) + + val ret = spark.sql(s"select count(*) from $tableName").collect() + assertResult(600572)(ret.apply(0).get(0)) + val conf = new Configuration + conf.set("fs.defaultFS", HDFS_URL) + val fs = FileSystem.get(conf) + + eventually(timeout(60.seconds), interval(2.seconds)) { + val it = fs.listFiles(new Path(dataPath), true) + var files = 0 + while (it.hasNext) { + it.next() + files += 1 + } + assertResult(72)(files) + } + } + } } // scalastyle:off line.size.limit diff --git a/cpp-ch/local-engine/Common/CHUtil.cpp b/cpp-ch/local-engine/Common/CHUtil.cpp index be66d8ecc509..94cd38003bad 100644 --- a/cpp-ch/local-engine/Common/CHUtil.cpp +++ b/cpp-ch/local-engine/Common/CHUtil.cpp @@ -623,7 +623,8 @@ void BackendInitializerUtil::initSettings(std::map & b { /// Initialize default setting. 
settings.set("date_time_input_format", "best_effort"); - settings.set("mergetree.merge_after_insert", true); + settings.set(MERGETREE_MERGE_AFTER_INSERT, true); + settings.set(MERGETREE_INSERT_WITHOUT_LOCAL_STORAGE, false); for (const auto & [key, value] : backend_conf_map) { diff --git a/cpp-ch/local-engine/Common/CHUtil.h b/cpp-ch/local-engine/Common/CHUtil.h index 50de9461f4de..94e0f0168e11 100644 --- a/cpp-ch/local-engine/Common/CHUtil.h +++ b/cpp-ch/local-engine/Common/CHUtil.h @@ -35,7 +35,9 @@ class QueryPlan; namespace local_engine { -static const std::unordered_set BOOL_VALUE_SETTINGS{"mergetree.merge_after_insert"}; +static const String MERGETREE_INSERT_WITHOUT_LOCAL_STORAGE = "mergetree.insert_without_local_storage"; +static const String MERGETREE_MERGE_AFTER_INSERT = "mergetree.merge_after_insert"; +static const std::unordered_set BOOL_VALUE_SETTINGS{MERGETREE_MERGE_AFTER_INSERT, MERGETREE_INSERT_WITHOUT_LOCAL_STORAGE}; static const std::unordered_set LONG_VALUE_SETTINGS{ "optimize.maxfilesize", "optimize.minFileSize", "mergetree.max_num_part_per_merge_task"}; diff --git a/cpp-ch/local-engine/Disks/ObjectStorages/GlutenDiskHDFS.cpp b/cpp-ch/local-engine/Disks/ObjectStorages/GlutenDiskHDFS.cpp index 07a7aa6bd006..f207ad232b4f 100644 --- a/cpp-ch/local-engine/Disks/ObjectStorages/GlutenDiskHDFS.cpp +++ b/cpp-ch/local-engine/Disks/ObjectStorages/GlutenDiskHDFS.cpp @@ -52,7 +52,15 @@ void GlutenDiskHDFS::createDirectories(const String & path) void GlutenDiskHDFS::removeDirectory(const String & path) { DiskObjectStorage::removeDirectory(path); - hdfsDelete(hdfs_object_storage->getHDFSFS(), path.c_str(), 1); + String abs_path = "/" + path; + hdfsDelete(hdfs_object_storage->getHDFSFS(), abs_path.c_str(), 1); +} + +void GlutenDiskHDFS::removeRecursive(const String & path) +{ + DiskObjectStorage::removeRecursive(path); + String abs_path = "/" + path; + hdfsDelete(hdfs_object_storage->getHDFSFS(), abs_path.c_str(), 1); } DiskObjectStoragePtr GlutenDiskHDFS::createDiskObjectStorage() diff --git a/cpp-ch/local-engine/Disks/ObjectStorages/GlutenDiskHDFS.h b/cpp-ch/local-engine/Disks/ObjectStorages/GlutenDiskHDFS.h index 222b9f8928a3..97a99f1deaba 100644 --- a/cpp-ch/local-engine/Disks/ObjectStorages/GlutenDiskHDFS.h +++ b/cpp-ch/local-engine/Disks/ObjectStorages/GlutenDiskHDFS.h @@ -57,6 +57,8 @@ class GlutenDiskHDFS : public DB::DiskObjectStorage void removeDirectory(const String & path) override; + void removeRecursive(const String & path) override; + DB::DiskObjectStoragePtr createDiskObjectStorage() override; std::unique_ptr writeFile(const String& path, size_t buf_size, DB::WriteMode mode, diff --git a/cpp-ch/local-engine/Storages/Mergetree/SparkMergeTreeWriter.cpp b/cpp-ch/local-engine/Storages/Mergetree/SparkMergeTreeWriter.cpp index c1f2391a282c..406f2aaa23df 100644 --- a/cpp-ch/local-engine/Storages/Mergetree/SparkMergeTreeWriter.cpp +++ b/cpp-ch/local-engine/Storages/Mergetree/SparkMergeTreeWriter.cpp @@ -69,11 +69,23 @@ SparkMergeTreeWriter::SparkMergeTreeWriter( , bucket_dir(bucket_dir_) , thread_pool(CurrentMetrics::LocalThread, CurrentMetrics::LocalThreadActive, CurrentMetrics::LocalThreadScheduled, 1, 1, 100000) { + const DB::Settings & settings = context->getSettingsRef(); + merge_after_insert = settings.get(MERGETREE_MERGE_AFTER_INSERT).get(); + insert_without_local_storage = settings.get(MERGETREE_INSERT_WITHOUT_LOCAL_STORAGE).get(); + + Field limit_size_field; + if (settings.tryGet("optimize.minFileSize", limit_size_field)) + merge_min_size = limit_size_field.get() <= 0 ? 
merge_min_size : limit_size_field.get(); + + Field limit_cnt_field; + if (settings.tryGet("mergetree.max_num_part_per_merge_task", limit_cnt_field)) + merge_limit_parts = limit_cnt_field.get() <= 0 ? merge_limit_parts : limit_cnt_field.get(); + dest_storage = MergeTreeRelParser::parseStorage(merge_tree_table, SerializedPlanParser::global_context); + isRemoteStorage = dest_storage->getStoragePolicy()->getAnyDisk()->isRemote(); - if (dest_storage->getStoragePolicy()->getAnyDisk()->isRemote()) + if (useLocalStorage()) { - isRemoteStorage = true; temp_storage = MergeTreeRelParser::copyToDefaultPolicyStorage(merge_tree_table, SerializedPlanParser::global_context); storage = temp_storage; LOG_DEBUG( @@ -86,22 +98,14 @@ SparkMergeTreeWriter::SparkMergeTreeWriter( metadata_snapshot = storage->getInMemoryMetadataPtr(); header = metadata_snapshot->getSampleBlock(); - const DB::Settings & settings = context->getSettingsRef(); squashing = std::make_unique(header, settings.min_insert_block_size_rows, settings.min_insert_block_size_bytes); if (!partition_dir.empty()) extractPartitionValues(partition_dir, partition_values); +} - Field is_merge; - if (settings.tryGet("mergetree.merge_after_insert", is_merge)) - merge_after_insert = is_merge.get(); - - Field limit_size_field; - if (settings.tryGet("optimize.minFileSize", limit_size_field)) - merge_min_size = limit_size_field.get() <= 0 ? merge_min_size : limit_size_field.get(); - - Field limit_cnt_field; - if (settings.tryGet("mergetree.max_num_part_per_merge_task", limit_cnt_field)) - merge_limit_parts = limit_cnt_field.get() <= 0 ? merge_limit_parts : limit_cnt_field.get(); +bool SparkMergeTreeWriter::useLocalStorage() const +{ + return !insert_without_local_storage && isRemoteStorage; } void SparkMergeTreeWriter::write(const DB::Block & block) @@ -161,7 +165,7 @@ void SparkMergeTreeWriter::manualFreeMemory(size_t before_write_memory) // it may alloc memory in current thread, and free on global thread. // Now, wo have not idea to clear global memory by used spark thread tracker. // So we manually correct the memory usage. 
- if (!isRemoteStorage) + if (isRemoteStorage && insert_without_local_storage) return; auto disk = storage->getStoragePolicy()->getAnyDisk(); @@ -219,7 +223,7 @@ void SparkMergeTreeWriter::saveMetadata() void SparkMergeTreeWriter::commitPartToRemoteStorageIfNeeded() { - if (!isRemoteStorage) + if (!useLocalStorage()) return; LOG_DEBUG( @@ -289,8 +293,8 @@ void SparkMergeTreeWriter::finalizeMerge() { for (const auto & disk : storage->getDisks()) { - auto full_path = storage->getFullPathOnDisk(disk); - disk->removeRecursive(full_path + "/" + tmp_part); + auto rel_path = storage->getRelativeDataPath() + "/" + tmp_part; + disk->removeRecursive(rel_path); } }); } diff --git a/cpp-ch/local-engine/Storages/Mergetree/SparkMergeTreeWriter.h b/cpp-ch/local-engine/Storages/Mergetree/SparkMergeTreeWriter.h index 2b07521ede3a..13ac22394477 100644 --- a/cpp-ch/local-engine/Storages/Mergetree/SparkMergeTreeWriter.h +++ b/cpp-ch/local-engine/Storages/Mergetree/SparkMergeTreeWriter.h @@ -79,6 +79,7 @@ class SparkMergeTreeWriter void finalizeMerge(); bool chunkToPart(Chunk && chunk); bool blockToPart(Block & block); + bool useLocalStorage() const; CustomStorageMergeTreePtr storage = nullptr; CustomStorageMergeTreePtr dest_storage = nullptr; @@ -97,6 +98,7 @@ class SparkMergeTreeWriter std::unordered_set tmp_parts; DB::Block header; bool merge_after_insert; + bool insert_without_local_storage; FreeThreadPool thread_pool; size_t merge_min_size = 1024 * 1024 * 1024; size_t merge_limit_parts = 10; From 0ef2f8216b03f5f279d80c71baac30dbdb94199f Mon Sep 17 00:00:00 2001 From: Zhen Li <10524738+zhli1142015@users.noreply.github.com> Date: Mon, 24 Jun 2024 16:35:14 +0800 Subject: [PATCH 03/30] [VL] Support KnownNullable for Spark 3.5 (#6193) [VL] Support KnownNullable for Spark 3.5. 
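Context for this change: Spark 3.5 marks the window() output with the new KnownNullable expression, so without a Sig mapping for it the enclosing projection fails Gluten validation and falls back to vanilla ProjectExec; the test change below asserts exactly that. A rough, self-contained way to reproduce the check outside the suite (assumes a session already configured with the Gluten plugin; matching the node name as a string is a simplification, not the suite's getExecutedPlan helper):

    import org.apache.spark.sql.SparkSession

    object WindowOffloadCheck {
      def main(args: Array[String]): Unit = {
        // Assumes the Gluten plugin and backend jars are already configured on the
        // session (e.g. spark.plugins=org.apache.gluten.GlutenPlugin); this snippet
        // only builds the window() aggregation from the test and inspects its plan.
        val spark = SparkSession.builder().getOrCreate()
        import spark.implicits._

        Seq(("a", java.sql.Timestamp.valueOf("2024-06-24 13:51:42")))
          .toDF("a", "b")
          .createOrReplaceTempView("string_timestamp")

        val df = spark.sql(
          """SELECT a, window.start, window.end, count(*) AS cnt
            |FROM string_timestamp
            |GROUP BY a, window(b, '5 minutes')
            |ORDER BY a, start""".stripMargin)
        df.collect() // materialize so the adaptive plan is finalized

        val plan = df.queryExecution.executedPlan.toString()
        // With KnownNullable mapped in the Spark 3.5 shim, the projection
        // should be offloaded instead of falling back.
        assert(plan.contains("ProjectExecTransformer"), s"projection fell back:\n$plan")
        spark.stop()
      }
    }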
--- .../execution/ScalarFunctionsValidateSuite.scala | 14 +++++++++++++- .../gluten/sql/shims/spark35/Spark35Shims.scala | 4 +++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala index 11eaa3289cab..75b60addfa13 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala @@ -16,6 +16,7 @@ */ package org.apache.gluten.execution +import org.apache.spark.sql.execution.ProjectExec import org.apache.spark.sql.types._ import java.sql.Timestamp @@ -1145,7 +1146,18 @@ class ScalarFunctionsValidateSuite extends FunctionsValidateTest { runQueryAndCompare( "SELECT a, window.start, window.end, count(*) as cnt FROM" + " string_timestamp GROUP by a, window(b, '5 minutes') ORDER BY a, start;") { - checkGlutenOperatorMatch[ProjectExecTransformer] + df => + val executedPlan = getExecutedPlan(df) + assert( + executedPlan.exists(plan => plan.isInstanceOf[ProjectExecTransformer]), + s"Expect ProjectExecTransformer exists " + + s"in executedPlan:\n ${executedPlan.last}" + ) + assert( + !executedPlan.exists(plan => plan.isInstanceOf[ProjectExec]), + s"Expect ProjectExec doesn't exist " + + s"in executedPlan:\n ${executedPlan.last}" + ) } } } diff --git a/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala b/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala index 95571f166ebe..f6feae01a8b2 100644 --- a/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala +++ b/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala @@ -76,7 +76,9 @@ class Spark35Shims extends SparkShims { Sig[SplitPart](ExpressionNames.SPLIT_PART), Sig[Sec](ExpressionNames.SEC), Sig[Csc](ExpressionNames.CSC), - Sig[Empty2Null](ExpressionNames.EMPTY2NULL)) + Sig[KnownNullable](ExpressionNames.KNOWN_NULLABLE), + Sig[Empty2Null](ExpressionNames.EMPTY2NULL) + ) } override def aggregateExpressionMappings: Seq[Sig] = { From f07e348f4dfa5cf72a15a6986fd9524873072cdc Mon Sep 17 00:00:00 2001 From: Gluten Performance Bot <137994563+GlutenPerfBot@users.noreply.github.com> Date: Mon, 24 Jun 2024 16:39:38 +0800 Subject: [PATCH 04/30] [VL] Daily Update Velox Version (2024_06_24) (#6187) f45966f17 by Deepak Majeti, Use separate headers for DWRF Reader Writer registration API (10132) 00485536f by Zac Wen, Switch to storage read if SSD cache load fails (10256) 3c2cc4b26 by Bikramjeet Vig, Fix NaN handling for in-predicate (10115) 18c4d5e2b by Kevin Wilfong, Capture MemoryArbitrationContext and ThreadDebugInfo in AsyncSource and restore them when invoking make (10186) 171174833 by Jimmy Lu, Count IO execution time in ExponentialBackoff retry policy (10286) 24f5aed63 by zhli1142015, Add support for DECIMAL input to greatest and least Spark functions (10195) 8faac7bf2 by zhli1142015, Add log Spark function (10243) c97e7fcc8 by Kevin Wilfong, Fix parallel spills lead to crashes in approx_percentile (10268) 54b2ce9a5 by Reetika Agrawal, Add benchmark for IcebergSplitReader (9849) dcd49ca38 by Krishna Pai, Restrict CAST of string to boolean (9833) ca5e409aa by xiaoxmeng, Only load stripe footer in buffer input support sync load (10276) 652cf372e by Zac Wen, Fix memory cache hit underreporting in ioStats 
(10272) a2366523d by yanngyoung, Add order by plan for memory arbitration fuzzer (10255) a5b443a70 by Wei He, Update header guards in files in velox/external/date to avoid collision (10269) --- cpp/velox/memory/VeloxMemoryManager.cc | 2 +- ep/build-velox/src/get_velox.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/velox/memory/VeloxMemoryManager.cc b/cpp/velox/memory/VeloxMemoryManager.cc index 60c79ffe8725..733eb4c4bc39 100644 --- a/cpp/velox/memory/VeloxMemoryManager.cc +++ b/cpp/velox/memory/VeloxMemoryManager.cc @@ -74,7 +74,7 @@ class ListenableArbitrator : public velox::memory::MemoryArbitrator { uint64_t targetBytes, bool allowSpill, bool allowAbort) override { - velox::memory::ScopedMemoryArbitrationContext ctx(nullptr); + velox::memory::ScopedMemoryArbitrationContext ctx((const velox::memory::MemoryPool*)nullptr); facebook::velox::exec::MemoryReclaimer::Stats status; VELOX_CHECK_EQ(pools.size(), 1, "Gluten only has one root pool"); std::lock_guard l(mutex_); // FIXME: Do we have recursive locking for this mutex? diff --git a/ep/build-velox/src/get_velox.sh b/ep/build-velox/src/get_velox.sh index a0a7baa0da45..d3ecddbdfa9a 100755 --- a/ep/build-velox/src/get_velox.sh +++ b/ep/build-velox/src/get_velox.sh @@ -17,7 +17,7 @@ set -exu VELOX_REPO=https://github.com/oap-project/velox.git -VELOX_BRANCH=2024_06_21 +VELOX_BRANCH=2024_06_24 VELOX_HOME="" #Set on run gluten on HDFS From 1fbdbc41779321db3380bce0807b73389af64e1a Mon Sep 17 00:00:00 2001 From: Chang chen Date: Tue, 25 Jun 2024 07:13:39 +0800 Subject: [PATCH 05/30] [GLUTEN-6067][CH] [Part 2] Support CH backend with Spark3.5 - Prepare for supporting sink transform (#6197) [CH] [Part 2] Support CH backend with Spark3.5 - Prepare for supporting sink transform * [Refactor] remove duplicate codes * Add NativeWriteChecker * [Prepare to commit] getExtendedColumnarPostRules from Spark shim --- .../clickhouse/CHIteratorApi.scala | 143 ++-- .../clickhouse/CHSparkPlanExecApi.scala | 9 - .../execution/CHHashJoinExecTransformer.scala | 3 +- ...lutenClickHouseNativeWriteTableSuite.scala | 612 ++++++++---------- .../GlutenClickHouseTPCHMetricsSuite.scala | 2 +- .../spark/gluten/NativeWriteChecker.scala | 52 ++ .../velox/VeloxSparkPlanExecApi.scala | 9 - cpp-ch/local-engine/Common/CHUtil.cpp | 17 +- cpp-ch/local-engine/Common/CHUtil.h | 12 +- .../Parser/CHColumnToSparkRow.cpp | 2 +- .../Parser/SerializedPlanParser.cpp | 310 ++++----- .../Parser/SerializedPlanParser.h | 39 +- cpp-ch/local-engine/local_engine_jni.cpp | 39 +- .../tests/benchmark_local_engine.cpp | 80 +-- cpp-ch/local-engine/tests/gluten_test_util.h | 18 + .../local-engine/tests/gtest_local_engine.cpp | 22 +- cpp-ch/local-engine/tests/gtest_parser.cpp | 407 ++++-------- .../tests/json/clickhouse_pr_65234.json | 273 ++++++++ .../tests/json/gtest_local_engine_config.json | 269 ++++++++ .../json/read_student_option_schema.csv.json | 77 +++ .../gluten/backendsapi/SparkPlanExecApi.scala | 4 +- .../utils/SubstraitPlanPrinterUtil.scala | 35 +- 22 files changed, 1379 insertions(+), 1055 deletions(-) create mode 100644 backends-clickhouse/src/test/scala/org/apache/spark/gluten/NativeWriteChecker.scala create mode 100644 cpp-ch/local-engine/tests/json/clickhouse_pr_65234.json create mode 100644 cpp-ch/local-engine/tests/json/gtest_local_engine_config.json create mode 100644 cpp-ch/local-engine/tests/json/read_student_option_schema.csv.json diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala index 941237629569..376e46ebe975 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala @@ -16,7 +16,7 @@ */ package org.apache.gluten.backendsapi.clickhouse -import org.apache.gluten.{GlutenConfig, GlutenNumaBindingInfo} +import org.apache.gluten.GlutenNumaBindingInfo import org.apache.gluten.backendsapi.IteratorApi import org.apache.gluten.execution._ import org.apache.gluten.expression.ConverterUtils @@ -61,6 +61,52 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil { StructType(dataSchema) } + private def createNativeIterator( + splitInfoByteArray: Array[Array[Byte]], + wsPlan: Array[Byte], + materializeInput: Boolean, + inputIterators: Seq[Iterator[ColumnarBatch]]): BatchIterator = { + + /** Generate closeable ColumnBatch iterator. */ + val listIterator = + inputIterators + .map { + case i: CloseableCHColumnBatchIterator => i + case it => new CloseableCHColumnBatchIterator(it) + } + .map(it => new ColumnarNativeIterator(it.asJava).asInstanceOf[GeneralInIterator]) + .asJava + new CHNativeExpressionEvaluator().createKernelWithBatchIterator( + wsPlan, + splitInfoByteArray, + listIterator, + materializeInput + ) + } + + private def createCloseIterator( + context: TaskContext, + pipelineTime: SQLMetric, + updateNativeMetrics: IMetrics => Unit, + updateInputMetrics: Option[InputMetricsWrapper => Unit] = None, + nativeIter: BatchIterator): CloseableCHColumnBatchIterator = { + + val iter = new CollectMetricIterator( + nativeIter, + updateNativeMetrics, + updateInputMetrics, + updateInputMetrics.map(_ => context.taskMetrics().inputMetrics).orNull) + + context.addTaskFailureListener( + (ctx, _) => { + if (ctx.isInterrupted()) { + iter.cancel() + } + }) + context.addTaskCompletionListener[Unit](_ => iter.close()) + new CloseableCHColumnBatchIterator(iter, Some(pipelineTime)) + } + // only set file schema for text format table private def setFileSchemaForLocalFiles( localFilesNode: LocalFilesNode, @@ -198,45 +244,24 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil { inputIterators: Seq[Iterator[ColumnarBatch]] = Seq() ): Iterator[ColumnarBatch] = { - assert( + require( inputPartition.isInstanceOf[GlutenPartition], "CH backend only accepts GlutenPartition in GlutenWholeStageColumnarRDD.") - - val transKernel = new CHNativeExpressionEvaluator() - val inBatchIters = new JArrayList[GeneralInIterator](inputIterators.map { - iter => new ColumnarNativeIterator(CHIteratorApi.genCloseableColumnBatchIterator(iter).asJava) - }.asJava) - val splitInfoByteArray = inputPartition .asInstanceOf[GlutenPartition] .splitInfosByteArray - val nativeIter = - transKernel.createKernelWithBatchIterator( - inputPartition.plan, - splitInfoByteArray, - inBatchIters, - false) + val wsPlan = inputPartition.plan + val materializeInput = false - val iter = new CollectMetricIterator( - nativeIter, - updateNativeMetrics, - updateInputMetrics, - context.taskMetrics().inputMetrics) - - context.addTaskFailureListener( - (ctx, _) => { - if (ctx.isInterrupted()) { - iter.cancel() - } - }) - context.addTaskCompletionListener[Unit](_ => iter.close()) - - // TODO: SPARK-25083 remove the type erasure hack in data source scan new InterruptibleIterator( context, - new CloseableCHColumnBatchIterator( - iter.asInstanceOf[Iterator[ColumnarBatch]], 
- Some(pipelineTime))) + createCloseIterator( + context, + pipelineTime, + updateNativeMetrics, + Some(updateInputMetrics), + createNativeIterator(splitInfoByteArray, wsPlan, materializeInput, inputIterators)) + ) } // Generate Iterator[ColumnarBatch] for final stage. @@ -252,52 +277,26 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil { partitionIndex: Int, materializeInput: Boolean): Iterator[ColumnarBatch] = { // scalastyle:on argcount - GlutenConfig.getConf - - val transKernel = new CHNativeExpressionEvaluator() - val columnarNativeIterator = - new JArrayList[GeneralInIterator](inputIterators.map { - iter => - new ColumnarNativeIterator(CHIteratorApi.genCloseableColumnBatchIterator(iter).asJava) - }.asJava) - // we need to complete dependency RDD's firstly - val nativeIterator = transKernel.createKernelWithBatchIterator( - rootNode.toProtobuf.toByteArray, - // Final iterator does not contain scan split, so pass empty split info to native here. - new Array[Array[Byte]](0), - columnarNativeIterator, - materializeInput - ) - - val iter = new CollectMetricIterator(nativeIterator, updateNativeMetrics, null, null) - context.addTaskFailureListener( - (ctx, _) => { - if (ctx.isInterrupted()) { - iter.cancel() - } - }) - context.addTaskCompletionListener[Unit](_ => iter.close()) - new CloseableCHColumnBatchIterator(iter, Some(pipelineTime)) - } -} + // Final iterator does not contain scan split, so pass empty split info to native here. + val splitInfoByteArray = new Array[Array[Byte]](0) + val wsPlan = rootNode.toProtobuf.toByteArray -object CHIteratorApi { - - /** Generate closeable ColumnBatch iterator. */ - def genCloseableColumnBatchIterator(iter: Iterator[ColumnarBatch]): Iterator[ColumnarBatch] = { - iter match { - case _: CloseableCHColumnBatchIterator => iter - case _ => new CloseableCHColumnBatchIterator(iter) - } + // we need to complete dependency RDD's firstly + createCloseIterator( + context, + pipelineTime, + updateNativeMetrics, + None, + createNativeIterator(splitInfoByteArray, wsPlan, materializeInput, inputIterators)) } } class CollectMetricIterator( val nativeIterator: BatchIterator, val updateNativeMetrics: IMetrics => Unit, - val updateInputMetrics: InputMetricsWrapper => Unit, - val inputMetrics: InputMetrics + val updateInputMetrics: Option[InputMetricsWrapper => Unit] = None, + val inputMetrics: InputMetrics = null ) extends Iterator[ColumnarBatch] { private var outputRowCount = 0L private var outputVectorCount = 0L @@ -329,9 +328,7 @@ class CollectMetricIterator( val nativeMetrics = nativeIterator.getMetrics.asInstanceOf[NativeMetrics] nativeMetrics.setFinalOutputMetrics(outputRowCount, outputVectorCount) updateNativeMetrics(nativeMetrics) - if (updateInputMetrics != null) { - updateInputMetrics(inputMetrics) - } + updateInputMetrics.foreach(_(inputMetrics)) metricsUpdated = true } } diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala index 1c83e326eed4..ac3ea61ff810 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala @@ -50,7 +50,6 @@ import org.apache.spark.sql.delta.files.TahoeFileIndex import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.AQEShuffleReadExec import 
org.apache.spark.sql.execution.datasources.{FileFormat, HadoopFsRelation} -import org.apache.spark.sql.execution.datasources.GlutenWriterColumnarRules.NativeWritePostRule import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.datasources.v2.clickhouse.source.DeltaMergeTreeFileFormat import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec} @@ -583,14 +582,6 @@ class CHSparkPlanExecApi extends SparkPlanExecApi { override def genExtendedColumnarTransformRules(): List[SparkSession => Rule[SparkPlan]] = List() - /** - * Generate extended columnar post-rules. - * - * @return - */ - override def genExtendedColumnarPostRules(): List[SparkSession => Rule[SparkPlan]] = - List(spark => NativeWritePostRule(spark)) - override def genInjectPostHocResolutionRules(): List[SparkSession => Rule[LogicalPlan]] = { List() } diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala index a7e7769e7736..da9d9c7586c0 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala @@ -16,7 +16,6 @@ */ package org.apache.gluten.execution -import org.apache.gluten.backendsapi.clickhouse.CHIteratorApi import org.apache.gluten.extension.ValidationResult import org.apache.gluten.utils.{BroadcastHashJoinStrategy, CHJoinValidateUtil, ShuffleHashJoinStrategy} @@ -75,7 +74,7 @@ case class CHBroadcastBuildSideRDD( override def genBroadcastBuildSideIterator(): Iterator[ColumnarBatch] = { CHBroadcastBuildSideCache.getOrBuildBroadcastHashTable(broadcasted, broadcastContext) - CHIteratorApi.genCloseableColumnBatchIterator(Iterator.empty) + Iterator.empty } } diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeWriteTableSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeWriteTableSuite.scala index 9269303d9251..ccf7bb5d5b2a 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeWriteTableSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeWriteTableSuite.scala @@ -21,6 +21,7 @@ import org.apache.gluten.execution.AllDataTypesWithComplexType.genTestData import org.apache.gluten.utils.UTSystemParameters import org.apache.spark.SparkConf +import org.apache.spark.gluten.NativeWriteChecker import org.apache.spark.sql.SparkSession import org.apache.spark.sql.delta.DeltaLog import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper @@ -28,11 +29,14 @@ import org.apache.spark.sql.test.SharedSparkSession import org.scalatest.BeforeAndAfterAll +import scala.reflect.runtime.universe.TypeTag + class GlutenClickHouseNativeWriteTableSuite extends GlutenClickHouseWholeStageTransformerSuite with AdaptiveSparkPlanHelper with SharedSparkSession - with BeforeAndAfterAll { + with BeforeAndAfterAll + with NativeWriteChecker { private var _hiveSpark: SparkSession = _ @@ -114,16 +118,19 @@ class GlutenClickHouseNativeWriteTableSuite def getColumnName(s: String): String = { s.replaceAll("\\(", "_").replaceAll("\\)", "_") } + import collection.immutable.ListMap import java.io.File def writeIntoNewTableWithSql(table_name: String, table_create_sql: String)( fields: Seq[String]): 
Unit = { - spark.sql(table_create_sql) - spark.sql( - s"insert overwrite $table_name select ${fields.mkString(",")}" + - s" from origin_table") + withDestinationTable(table_name, table_create_sql) { + checkNativeWrite( + s"insert overwrite $table_name select ${fields.mkString(",")}" + + s" from origin_table", + checkNative = true) + } } def writeAndCheckRead( @@ -170,82 +177,86 @@ class GlutenClickHouseNativeWriteTableSuite }) } - test("test insert into dir") { - withSQLConf( - ("spark.gluten.sql.native.writer.enabled", "true"), - (GlutenConfig.GLUTEN_ENABLED.key, "true")) { - - val originDF = spark.createDataFrame(genTestData()) - originDF.createOrReplaceTempView("origin_table") + private val fields_ = ListMap( + ("string_field", "string"), + ("int_field", "int"), + ("long_field", "long"), + ("float_field", "float"), + ("double_field", "double"), + ("short_field", "short"), + ("byte_field", "byte"), + ("boolean_field", "boolean"), + ("decimal_field", "decimal(23,12)"), + ("date_field", "date") + ) - val fields: ListMap[String, String] = ListMap( - ("string_field", "string"), - ("int_field", "int"), - ("long_field", "long"), - ("float_field", "float"), - ("double_field", "double"), - ("short_field", "short"), - ("byte_field", "byte"), - ("boolean_field", "boolean"), - ("decimal_field", "decimal(23,12)"), - ("date_field", "date") - ) + def withDestinationTable(table: String, createTableSql: String)(f: => Unit): Unit = { + spark.sql(s"drop table IF EXISTS $table") + spark.sql(s"$createTableSql") + f + } - for (format <- formats) { - spark.sql( - s"insert overwrite local directory '$basePath/test_insert_into_${format}_dir1' " - + s"stored as $format select " - + fields.keys.mkString(",") + - " from origin_table cluster by (byte_field)") - spark.sql( - s"insert overwrite local directory '$basePath/test_insert_into_${format}_dir2' " + - s"stored as $format " + - "select string_field, sum(int_field) as x from origin_table group by string_field") - } + def nativeWrite(f: String => Unit): Unit = { + withSQLConf(("spark.gluten.sql.native.writer.enabled", "true")) { + formats.foreach(f(_)) } } - test("test insert into partition") { - withSQLConf( - ("spark.gluten.sql.native.writer.enabled", "true"), - ("spark.sql.orc.compression.codec", "lz4"), - (GlutenConfig.GLUTEN_ENABLED.key, "true")) { - - val originDF = spark.createDataFrame(genTestData()) - originDF.createOrReplaceTempView("origin_table") - - val fields: ListMap[String, String] = ListMap( - ("string_field", "string"), - ("int_field", "int"), - ("long_field", "long"), - ("float_field", "float"), - ("double_field", "double"), - ("short_field", "short"), - ("byte_field", "byte"), - ("boolean_field", "boolean"), - ("decimal_field", "decimal(23,12)"), - ("date_field", "date") - ) - - for (format <- formats) { - val table_name = table_name_template.format(format) - spark.sql(s"drop table IF EXISTS $table_name") + def nativeWrite2( + f: String => (String, String, String), + extraCheck: (String, String, String) => Unit = null): Unit = nativeWrite { + format => + val (table_name, table_create_sql, insert_sql) = f(format) + withDestinationTable(table_name, table_create_sql) { + checkNativeWrite(insert_sql, checkNative = true) + Option(extraCheck).foreach(_(table_name, table_create_sql, insert_sql)) + } + } - val table_create_sql = - s"create table if not exists $table_name (" + - fields - .map(f => s"${f._1} ${f._2}") - .mkString(",") + - " ) partitioned by (another_date_field date) " + - s"stored as $format" + def nativeWriteWithOriginalView[A <: 
Product: TypeTag]( + data: Seq[A], + viewName: String, + pairs: (String, String)*)(f: String => Unit): Unit = { + val configs = pairs :+ ("spark.gluten.sql.native.writer.enabled", "true") + withSQLConf(configs: _*) { + withTempView(viewName) { + spark.createDataFrame(data).createOrReplaceTempView(viewName) + formats.foreach(f(_)) + } + } + } - spark.sql(table_create_sql) + test("test insert into dir") { + nativeWriteWithOriginalView(genTestData(), "origin_table") { + format => + Seq( + s"""insert overwrite local directory '$basePath/test_insert_into_${format}_dir1' + |stored as $format select ${fields_.keys.mkString(",")} + |from origin_table""".stripMargin, + s"""insert overwrite local directory '$basePath/test_insert_into_${format}_dir2' + |stored as $format select string_field, sum(int_field) as x + |from origin_table group by string_field""".stripMargin + ).foreach(checkNativeWrite(_, checkNative = true)) + } + } - spark.sql( - s"insert into $table_name partition(another_date_field = '2020-01-01') select " - + fields.keys.mkString(",") + - " from origin_table") + test("test insert into partition") { + def destination(format: String): (String, String, String) = { + val table_name = table_name_template.format(format) + val table_create_sql = + s"""create table if not exists $table_name + |(${fields_.map(f => s"${f._1} ${f._2}").mkString(",")}) + |partitioned by (another_date_field date) stored as $format""".stripMargin + val insert_sql = + s"""insert into $table_name partition(another_date_field = '2020-01-01') + | select ${fields_.keys.mkString(",")} from origin_table""".stripMargin + (table_name, table_create_sql, insert_sql) + } + def nativeFormatWrite(format: String): Unit = { + val (table_name, table_create_sql, insert_sql) = destination(format) + withDestinationTable(table_name, table_create_sql) { + checkNativeWrite(insert_sql, checkNative = true) var files = recursiveListFiles(new File(getWarehouseDir + "/" + table_name)) .filter(_.getName.endsWith(s".$format")) if (format == "orc") { @@ -255,154 +266,103 @@ class GlutenClickHouseNativeWriteTableSuite assert(files.head.getAbsolutePath.contains("another_date_field=2020-01-01")) } } + + nativeWriteWithOriginalView( + genTestData(), + "origin_table", + ("spark.sql.orc.compression.codec", "lz4"))(nativeFormatWrite) } test("test CTAS") { - withSQLConf( - ("spark.gluten.sql.native.writer.enabled", "true"), - (GlutenConfig.GLUTEN_ENABLED.key, "true")) { - - val originDF = spark.createDataFrame(genTestData()) - originDF.createOrReplaceTempView("origin_table") - val fields: ListMap[String, String] = ListMap( - ("string_field", "string"), - ("int_field", "int"), - ("long_field", "long"), - ("float_field", "float"), - ("double_field", "double"), - ("short_field", "short"), - ("byte_field", "byte"), - ("boolean_field", "boolean"), - ("decimal_field", "decimal(23,12)"), - ("date_field", "date") - ) - - for (format <- formats) { + nativeWriteWithOriginalView(genTestData(), "origin_table") { + format => val table_name = table_name_template.format(format) - spark.sql(s"drop table IF EXISTS $table_name") val table_create_sql = s"create table $table_name using $format as select " + - fields + fields_ .map(f => s"${f._1}") .mkString(",") + " from origin_table" - spark.sql(table_create_sql) - spark.sql(s"drop table IF EXISTS $table_name") + val insert_sql = + s"create table $table_name as select " + + fields_ + .map(f => s"${f._1}") + .mkString(",") + + " from origin_table" + withDestinationTable(table_name, table_create_sql) { + spark.sql(s"drop 
table IF EXISTS $table_name") - try { - val table_create_sql = - s"create table $table_name as select " + - fields - .map(f => s"${f._1}") - .mkString(",") + - " from origin_table" - spark.sql(table_create_sql) - } catch { - case _: UnsupportedOperationException => // expected - case _: Exception => fail("should not throw exception") + try { + // FIXME: using checkNativeWrite + spark.sql(insert_sql) + } catch { + case _: UnsupportedOperationException => // expected + case e: Exception => fail("should not throw exception", e) + } } - } } } test("test insert into partition, bigo's case which incur InsertIntoHiveTable") { - withSQLConf( - ("spark.gluten.sql.native.writer.enabled", "true"), - ("spark.sql.hive.convertMetastoreParquet", "false"), - ("spark.sql.hive.convertMetastoreOrc", "false"), - (GlutenConfig.GLUTEN_ENABLED.key, "true") - ) { - - val originDF = spark.createDataFrame(genTestData()) - originDF.createOrReplaceTempView("origin_table") - val fields: ListMap[String, String] = ListMap( - ("string_field", "string"), - ("int_field", "int"), - ("long_field", "long"), - ("float_field", "float"), - ("double_field", "double"), - ("short_field", "short"), - ("byte_field", "byte"), - ("boolean_field", "boolean"), - ("decimal_field", "decimal(23,12)"), - ("date_field", "date") - ) - - for (format <- formats) { - val table_name = table_name_template.format(format) - spark.sql(s"drop table IF EXISTS $table_name") - val table_create_sql = s"create table if not exists $table_name (" + fields - .map(f => s"${f._1} ${f._2}") - .mkString(",") + " ) partitioned by (another_date_field string)" + - s"stored as $format" + def destination(format: String): (String, String, String) = { + val table_name = table_name_template.format(format) + val table_create_sql = s"create table if not exists $table_name (" + fields_ + .map(f => s"${f._1} ${f._2}") + .mkString(",") + " ) partitioned by (another_date_field string)" + + s"stored as $format" + val insert_sql = + s"insert overwrite table $table_name " + + "partition(another_date_field = '2020-01-01') select " + + fields_.keys.mkString(",") + " from (select " + fields_.keys.mkString( + ",") + ", row_number() over (order by int_field desc) as rn " + + "from origin_table where float_field > 3 ) tt where rn <= 100" + (table_name, table_create_sql, insert_sql) + } - spark.sql(table_create_sql) - spark.sql( - s"insert overwrite table $table_name " + - "partition(another_date_field = '2020-01-01') select " - + fields.keys.mkString(",") + " from (select " + fields.keys.mkString( - ",") + ", row_number() over (order by int_field desc) as rn " + - "from origin_table where float_field > 3 ) tt where rn <= 100") + def nativeFormatWrite(format: String): Unit = { + val (table_name, table_create_sql, insert_sql) = destination(format) + withDestinationTable(table_name, table_create_sql) { + checkNativeWrite(insert_sql, checkNative = true) val files = recursiveListFiles(new File(getWarehouseDir + "/" + table_name)) .filter(_.getName.startsWith("part")) assert(files.length == 1) assert(files.head.getAbsolutePath.contains("another_date_field=2020-01-01")) } } + + nativeWriteWithOriginalView( + genTestData(), + "origin_table", + ("spark.sql.hive.convertMetastoreParquet", "false"), + ("spark.sql.hive.convertMetastoreOrc", "false"))(nativeFormatWrite) } test("test 1-col partitioned table") { + nativeWrite { + format => + { + val table_name = table_name_template.format(format) + val table_create_sql = + s"create table if not exists $table_name (" + + fields_ + .filterNot(e => 
e._1.equals("date_field")) + .map(f => s"${f._1} ${f._2}") + .mkString(",") + + " ) partitioned by (date_field date) " + + s"stored as $format" - withSQLConf(("spark.gluten.sql.native.writer.enabled", "true")) { - - val fields: ListMap[String, String] = ListMap( - ("string_field", "string"), - ("int_field", "int"), - ("long_field", "long"), - ("float_field", "float"), - ("double_field", "double"), - ("short_field", "short"), - ("byte_field", "byte"), - ("boolean_field", "boolean"), - ("decimal_field", "decimal(23,12)"), - ("date_field", "date") - ) - - for (format <- formats) { - val table_name = table_name_template.format(format) - val table_create_sql = - s"create table if not exists $table_name (" + - fields - .filterNot(e => e._1.equals("date_field")) - .map(f => s"${f._1} ${f._2}") - .mkString(",") + - " ) partitioned by (date_field date) " + - s"stored as $format" - - writeAndCheckRead( - table_name, - writeIntoNewTableWithSql(table_name, table_create_sql), - fields.keys.toSeq) - } + writeAndCheckRead( + table_name, + writeIntoNewTableWithSql(table_name, table_create_sql), + fields_.keys.toSeq) + } } } // even if disable native writer, this UT fail, spark bug??? ignore("test 1-col partitioned table, partitioned by already ordered column") { withSQLConf(("spark.gluten.sql.native.writer.enabled", "false")) { - val fields: ListMap[String, String] = ListMap( - ("string_field", "string"), - ("int_field", "int"), - ("long_field", "long"), - ("float_field", "float"), - ("double_field", "double"), - ("short_field", "short"), - ("byte_field", "byte"), - ("boolean_field", "boolean"), - ("decimal_field", "decimal(23,12)"), - ("date_field", "date") - ) val originDF = spark.createDataFrame(genTestData()) originDF.createOrReplaceTempView("origin_table") @@ -410,7 +370,7 @@ class GlutenClickHouseNativeWriteTableSuite val table_name = table_name_template.format(format) val table_create_sql = s"create table if not exists $table_name (" + - fields + fields_ .filterNot(e => e._1.equals("date_field")) .map(f => s"${f._1} ${f._2}") .mkString(",") + @@ -420,31 +380,27 @@ class GlutenClickHouseNativeWriteTableSuite spark.sql(s"drop table IF EXISTS $table_name") spark.sql(table_create_sql) spark.sql( - s"insert overwrite $table_name select ${fields.mkString(",")}" + + s"insert overwrite $table_name select ${fields_.mkString(",")}" + s" from origin_table order by date_field") } } } test("test 2-col partitioned table") { - withSQLConf( - ("spark.gluten.sql.native.writer.enabled", "true"), - (GlutenConfig.GLUTEN_ENABLED.key, "true")) { - - val fields: ListMap[String, String] = ListMap( - ("string_field", "string"), - ("int_field", "int"), - ("long_field", "long"), - ("float_field", "float"), - ("double_field", "double"), - ("short_field", "short"), - ("boolean_field", "boolean"), - ("decimal_field", "decimal(23,12)"), - ("date_field", "date"), - ("byte_field", "byte") - ) - - for (format <- formats) { + val fields: ListMap[String, String] = ListMap( + ("string_field", "string"), + ("int_field", "int"), + ("long_field", "long"), + ("float_field", "float"), + ("double_field", "double"), + ("short_field", "short"), + ("boolean_field", "boolean"), + ("decimal_field", "decimal(23,12)"), + ("date_field", "date"), + ("byte_field", "byte") + ) + nativeWrite { + format => val table_name = table_name_template.format(format) val table_create_sql = s"create table if not exists $table_name (" + @@ -458,7 +414,6 @@ class GlutenClickHouseNativeWriteTableSuite table_name, writeIntoNewTableWithSql(table_name, 
table_create_sql), fields.keys.toSeq) - } } } @@ -506,25 +461,21 @@ class GlutenClickHouseNativeWriteTableSuite // This test case will be failed with incorrect result randomly, ignore first. ignore("test hive parquet/orc table, all columns being partitioned. ") { - withSQLConf( - ("spark.gluten.sql.native.writer.enabled", "true"), - (GlutenConfig.GLUTEN_ENABLED.key, "true")) { - - val fields: ListMap[String, String] = ListMap( - ("date_field", "date"), - ("timestamp_field", "timestamp"), - ("string_field", "string"), - ("int_field", "int"), - ("long_field", "long"), - ("float_field", "float"), - ("double_field", "double"), - ("short_field", "short"), - ("byte_field", "byte"), - ("boolean_field", "boolean"), - ("decimal_field", "decimal(23,12)") - ) - - for (format <- formats) { + val fields: ListMap[String, String] = ListMap( + ("date_field", "date"), + ("timestamp_field", "timestamp"), + ("string_field", "string"), + ("int_field", "int"), + ("long_field", "long"), + ("float_field", "float"), + ("double_field", "double"), + ("short_field", "short"), + ("byte_field", "byte"), + ("boolean_field", "boolean"), + ("decimal_field", "decimal(23,12)") + ) + nativeWrite { + format => val table_name = table_name_template.format(format) val table_create_sql = s"create table if not exists $table_name (" + @@ -540,20 +491,15 @@ class GlutenClickHouseNativeWriteTableSuite table_name, writeIntoNewTableWithSql(table_name, table_create_sql), fields.keys.toSeq) - } } } - test(("test hive parquet/orc table with aggregated results")) { - withSQLConf( - ("spark.gluten.sql.native.writer.enabled", "true"), - (GlutenConfig.GLUTEN_ENABLED.key, "true")) { - - val fields: ListMap[String, String] = ListMap( - ("sum(int_field)", "bigint") - ) - - for (format <- formats) { + test("test hive parquet/orc table with aggregated results") { + val fields: ListMap[String, String] = ListMap( + ("sum(int_field)", "bigint") + ) + nativeWrite { + format => val table_name = table_name_template.format(format) val table_create_sql = s"create table if not exists $table_name (" + @@ -566,29 +512,12 @@ class GlutenClickHouseNativeWriteTableSuite table_name, writeIntoNewTableWithSql(table_name, table_create_sql), fields.keys.toSeq) - } } } test("test 1-col partitioned + 1-col bucketed table") { - withSQLConf( - ("spark.gluten.sql.native.writer.enabled", "true"), - (GlutenConfig.GLUTEN_ENABLED.key, "true")) { - - val fields: ListMap[String, String] = ListMap( - ("string_field", "string"), - ("int_field", "int"), - ("long_field", "long"), - ("float_field", "float"), - ("double_field", "double"), - ("short_field", "short"), - ("byte_field", "byte"), - ("boolean_field", "boolean"), - ("decimal_field", "decimal(23,12)"), - ("date_field", "date") - ) - - for (format <- formats) { + nativeWrite { + format => // spark write does not support bucketed table // https://issues.apache.org/jira/browse/SPARK-19256 val table_name = table_name_template.format(format) @@ -604,7 +533,7 @@ class GlutenClickHouseNativeWriteTableSuite .bucketBy(2, "byte_field") .saveAsTable(table_name) }, - fields.keys.toSeq + fields_.keys.toSeq ) assert( @@ -614,10 +543,8 @@ class GlutenClickHouseNativeWriteTableSuite .filter(!_.getName.equals("date_field=__HIVE_DEFAULT_PARTITION__")) .head .listFiles() - .filter(!_.isHidden) - .length == 2 + .count(!_.isHidden) == 2 ) // 2 bucket files - } } } @@ -745,8 +672,8 @@ class GlutenClickHouseNativeWriteTableSuite } test("test consecutive blocks having same partition value") { - 
withSQLConf(("spark.gluten.sql.native.writer.enabled", "true")) { - for (format <- formats) { + nativeWrite { + format => val table_name = table_name_template.format(format) spark.sql(s"drop table IF EXISTS $table_name") @@ -760,15 +687,14 @@ class GlutenClickHouseNativeWriteTableSuite .partitionBy("p") .saveAsTable(table_name) - val ret = spark.sql("select sum(id) from " + table_name).collect().apply(0).apply(0) + val ret = spark.sql(s"select sum(id) from $table_name").collect().apply(0).apply(0) assert(ret == 449985000) - } } } test("test decimal with rand()") { - withSQLConf(("spark.gluten.sql.native.writer.enabled", "true")) { - for (format <- formats) { + nativeWrite { + format => val table_name = table_name_template.format(format) spark.sql(s"drop table IF EXISTS $table_name") spark @@ -778,32 +704,30 @@ class GlutenClickHouseNativeWriteTableSuite .format(format) .partitionBy("p") .saveAsTable(table_name) - val ret = spark.sql("select max(p) from " + table_name).collect().apply(0).apply(0) - } + val ret = spark.sql(s"select max(p) from $table_name").collect().apply(0).apply(0) } } test("test partitioned by constant") { - withSQLConf(("spark.gluten.sql.native.writer.enabled", "true")) { - for (format <- formats) { - spark.sql(s"drop table IF EXISTS tmp_123_$format") - spark.sql( - s"create table tmp_123_$format(" + - s"x1 string, x2 bigint,x3 string, x4 bigint, x5 string )" + - s"partitioned by (day date) stored as $format") - - spark.sql( - s"insert into tmp_123_$format partition(day) " + - "select cast(id as string), id, cast(id as string), id, cast(id as string), " + - "'2023-05-09' from range(10000000)") - } + nativeWrite2 { + format => + val table_name = s"tmp_123_$format" + val create_sql = + s"""create table tmp_123_$format( + |x1 string, x2 bigint,x3 string, x4 bigint, x5 string ) + |partitioned by (day date) stored as $format""".stripMargin + val insert_sql = + s"""insert into tmp_123_$format partition(day) + |select cast(id as string), id, cast(id as string), + | id, cast(id as string), '2023-05-09' + |from range(10000000)""".stripMargin + (table_name, create_sql, insert_sql) } } test("test bucketed by constant") { - withSQLConf(("spark.gluten.sql.native.writer.enabled", "true")) { - - for (format <- formats) { + nativeWrite { + format => val table_name = table_name_template.format(format) spark.sql(s"drop table IF EXISTS $table_name") @@ -815,15 +739,13 @@ class GlutenClickHouseNativeWriteTableSuite .bucketBy(2, "p") .saveAsTable(table_name) - val ret = spark.sql("select count(*) from " + table_name).collect().apply(0).apply(0) - } + assertResult(10000000)(spark.table(table_name).count()) } } test("test consecutive null values being partitioned") { - withSQLConf(("spark.gluten.sql.native.writer.enabled", "true")) { - - for (format <- formats) { + nativeWrite { + format => val table_name = table_name_template.format(format) spark.sql(s"drop table IF EXISTS $table_name") @@ -835,14 +757,13 @@ class GlutenClickHouseNativeWriteTableSuite .partitionBy("p") .saveAsTable(table_name) - val ret = spark.sql("select count(*) from " + table_name).collect().apply(0).apply(0) - } + assertResult(30000)(spark.table(table_name).count()) } } test("test consecutive null values being bucketed") { - withSQLConf(("spark.gluten.sql.native.writer.enabled", "true")) { - for (format <- formats) { + nativeWrite { + format => val table_name = table_name_template.format(format) spark.sql(s"drop table IF EXISTS $table_name") @@ -854,78 +775,79 @@ class GlutenClickHouseNativeWriteTableSuite .bucketBy(2, 
"p") .saveAsTable(table_name) - val ret = spark.sql("select count(*) from " + table_name).collect().apply(0).apply(0) - } + assertResult(30000)(spark.table(table_name).count()) } } test("test native write with empty dataset") { - withSQLConf(("spark.gluten.sql.native.writer.enabled", "true")) { - for (format <- formats) { + nativeWrite2( + format => { val table_name = "t_" + format - spark.sql(s"drop table IF EXISTS $table_name") - spark.sql(s"create table $table_name (id int, str string) stored as $format") - spark.sql( - s"insert into $table_name select id, cast(id as string) from range(10)" + - " where id > 100") + ( + table_name, + s"create table $table_name (id int, str string) stored as $format", + s"insert into $table_name select id, cast(id as string) from range(10) where id > 100" + ) + }, + (table_name, _, _) => { + assertResult(0)(spark.table(table_name).count()) } - } + ) } test("test native write with union") { - withSQLConf(("spark.gluten.sql.native.writer.enabled", "true")) { - for (format <- formats) { + nativeWrite { + format => val table_name = "t_" + format - spark.sql(s"drop table IF EXISTS $table_name") - spark.sql(s"create table $table_name (id int, str string) stored as $format") - spark.sql( - s"insert overwrite table $table_name " + - "select id, cast(id as string) from range(10) union all " + - "select 10, '10' from range(10)") - spark.sql( - s"insert overwrite table $table_name " + - "select id, cast(id as string) from range(10) union all " + - "select 10, cast(id as string) from range(10)") - - } + withDestinationTable( + table_name, + s"create table $table_name (id int, str string) stored as $format") { + checkNativeWrite( + s"insert overwrite table $table_name " + + "select id, cast(id as string) from range(10) union all " + + "select 10, '10' from range(10)", + checkNative = true) + checkNativeWrite( + s"insert overwrite table $table_name " + + "select id, cast(id as string) from range(10) union all " + + "select 10, cast(id as string) from range(10)", + checkNative = true + ) + } } } test("test native write and non-native read consistency") { - withSQLConf(("spark.gluten.sql.native.writer.enabled", "true")) { - for (format <- formats) { - val table_name = "t_" + format - spark.sql(s"drop table IF EXISTS $table_name") - spark.sql(s"create table $table_name (id int, name string, info char(4)) stored as $format") - spark.sql( - s"insert overwrite table $table_name " + - "select id, cast(id as string), concat('aaa', cast(id as string)) from range(10)") + nativeWrite2( + { + format => + val table_name = "t_" + format + ( + table_name, + s"create table $table_name (id int, name string, info char(4)) stored as $format", + s"insert overwrite table $table_name " + + "select id, cast(id as string), concat('aaa', cast(id as string)) from range(10)" + ) + }, + (table_name, _, _) => compareResultsAgainstVanillaSpark( s"select * from $table_name", compareResult = true, _ => {}) - } - } + ) } test("GLUTEN-4316: fix crash on dynamic partition inserting") { - withSQLConf( - ("spark.gluten.sql.native.writer.enabled", "true"), - (GlutenConfig.GLUTEN_ENABLED.key, "true")) { - formats.foreach( - format => { - val tbl = "t_" + format - spark.sql(s"drop table IF EXISTS $tbl") - val sql1 = - s"create table $tbl(a int, b map, c struct) " + - s"partitioned by (day string) stored as $format" - val sql2 = s"insert overwrite $tbl partition (day) " + - s"select id as a, str_to_map(concat('t1:','a','&t2:','b'),'&',':'), " + - s"struct('1', null) as c, '2024-01-08' as day from range(10)" - 
spark.sql(sql1) - spark.sql(sql2) - }) + nativeWrite2 { + format => + val tbl = "t_" + format + val sql1 = + s"create table $tbl(a int, b map, c struct) " + + s"partitioned by (day string) stored as $format" + val sql2 = s"insert overwrite $tbl partition (day) " + + s"select id as a, str_to_map(concat('t1:','a','&t2:','b'),'&',':'), " + + s"struct('1', null) as c, '2024-01-08' as day from range(10)" + (tbl, sql1, sql2) } } - } diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/metrics/GlutenClickHouseTPCHMetricsSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/metrics/GlutenClickHouseTPCHMetricsSuite.scala index 09fa3ff109f2..1b3df81667a0 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/metrics/GlutenClickHouseTPCHMetricsSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/metrics/GlutenClickHouseTPCHMetricsSuite.scala @@ -46,7 +46,7 @@ class GlutenClickHouseTPCHMetricsSuite extends GlutenClickHouseTPCHAbstractSuite .set("spark.io.compression.codec", "LZ4") .set("spark.sql.shuffle.partitions", "1") .set("spark.sql.autoBroadcastJoinThreshold", "10MB") - .set("spark.gluten.sql.columnar.backend.ch.runtime_config.logger.level", "DEBUG") + // .set("spark.gluten.sql.columnar.backend.ch.runtime_config.logger.level", "DEBUG") .set( "spark.gluten.sql.columnar.backend.ch.runtime_settings.input_format_parquet_max_block_size", s"$parquetMaxBlockSize") diff --git a/backends-clickhouse/src/test/scala/org/apache/spark/gluten/NativeWriteChecker.scala b/backends-clickhouse/src/test/scala/org/apache/spark/gluten/NativeWriteChecker.scala new file mode 100644 index 000000000000..79616d52d0bc --- /dev/null +++ b/backends-clickhouse/src/test/scala/org/apache/spark/gluten/NativeWriteChecker.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.gluten + +import org.apache.gluten.execution.GlutenClickHouseWholeStageTransformerSuite + +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.execution.datasources.FakeRowAdaptor +import org.apache.spark.sql.util.QueryExecutionListener + +trait NativeWriteChecker extends GlutenClickHouseWholeStageTransformerSuite { + + def checkNativeWrite(sqlStr: String, checkNative: Boolean): Unit = { + var nativeUsed = false + + val queryListener = new QueryExecutionListener { + override def onFailure(f: String, qe: QueryExecution, e: Exception): Unit = {} + override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = { + if (!nativeUsed) { + nativeUsed = if (isSparkVersionGE("3.4")) { + false + } else { + qe.executedPlan.find(_.isInstanceOf[FakeRowAdaptor]).isDefined + } + } + } + } + + try { + spark.listenerManager.register(queryListener) + spark.sql(sqlStr) + spark.sparkContext.listenerBus.waitUntilEmpty() + assertResult(checkNative)(nativeUsed) + } finally { + spark.listenerManager.unregister(queryListener) + } + } +} diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala index 1f868c4c2044..7b8d523a6d27 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala @@ -827,15 +827,6 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi { buf.result } - /** - * Generate extended columnar post-rules. - * - * @return - */ - override def genExtendedColumnarPostRules(): List[SparkSession => Rule[SparkPlan]] = { - SparkShimLoader.getSparkShims.getExtendedColumnarPostRules() ::: List() - } - override def genInjectPostHocResolutionRules(): List[SparkSession => Rule[LogicalPlan]] = { List(ArrowConvertorRule) } diff --git a/cpp-ch/local-engine/Common/CHUtil.cpp b/cpp-ch/local-engine/Common/CHUtil.cpp index 94cd38003bad..ae3f6dbd5208 100644 --- a/cpp-ch/local-engine/Common/CHUtil.cpp +++ b/cpp-ch/local-engine/Common/CHUtil.cpp @@ -77,6 +77,7 @@ namespace ErrorCodes { extern const int BAD_ARGUMENTS; extern const int UNKNOWN_TYPE; +extern const int CANNOT_PARSE_PROTOBUF_SCHEMA; } } @@ -466,17 +467,17 @@ String QueryPipelineUtil::explainPipeline(DB::QueryPipeline & pipeline) using namespace DB; -std::map BackendInitializerUtil::getBackendConfMap(std::string * plan) +std::map BackendInitializerUtil::getBackendConfMap(const std::string & plan) { std::map ch_backend_conf; - if (plan == nullptr) + if (plan.empty()) return ch_backend_conf; /// Parse backend configs from plan extensions do { auto plan_ptr = std::make_unique(); - auto success = plan_ptr->ParseFromString(*plan); + auto success = plan_ptr->ParseFromString(plan); if (!success) break; @@ -841,14 +842,8 @@ void BackendInitializerUtil::initCompiledExpressionCache(DB::Context::Configurat #endif } -void BackendInitializerUtil::init_json(std::string * plan_json) -{ - auto plan_ptr = std::make_unique(); - google::protobuf::util::JsonStringToMessage(plan_json->c_str(), plan_ptr.get()); - return init(new String(plan_ptr->SerializeAsString())); -} -void BackendInitializerUtil::init(std::string * plan) +void BackendInitializerUtil::init(const std::string & plan) { std::map backend_conf_map = getBackendConfMap(plan); DB::Context::ConfigurationPtr config = initConfig(backend_conf_map); @@ -906,7 
+901,7 @@ void BackendInitializerUtil::init(std::string * plan) }); } -void BackendInitializerUtil::updateConfig(const DB::ContextMutablePtr & context, std::string * plan) +void BackendInitializerUtil::updateConfig(const DB::ContextMutablePtr & context, const std::string & plan) { std::map backend_conf_map = getBackendConfMap(plan); diff --git a/cpp-ch/local-engine/Common/CHUtil.h b/cpp-ch/local-engine/Common/CHUtil.h index 94e0f0168e11..245d7b3d15c4 100644 --- a/cpp-ch/local-engine/Common/CHUtil.h +++ b/cpp-ch/local-engine/Common/CHUtil.h @@ -137,9 +137,8 @@ class BackendInitializerUtil /// Initialize two kinds of resources /// 1. global level resources like global_context/shared_context, notice that they can only be initialized once in process lifetime /// 2. session level resources like settings/configs, they can be initialized multiple times following the lifetime of executor/driver - static void init(std::string * plan); - static void init_json(std::string * plan_json); - static void updateConfig(const DB::ContextMutablePtr &, std::string *); + static void init(const std::string & plan); + static void updateConfig(const DB::ContextMutablePtr &, const std::string &); // use excel text parser @@ -196,7 +195,7 @@ class BackendInitializerUtil static void updateNewSettings(const DB::ContextMutablePtr &, const DB::Settings &); - static std::map getBackendConfMap(std::string * plan); + static std::map getBackendConfMap(const std::string & plan); inline static std::once_flag init_flag; inline static Poco::Logger * logger; @@ -283,10 +282,7 @@ class ConcurrentDeque return deq.empty(); } - std::deque unsafeGet() - { - return deq; - } + std::deque unsafeGet() { return deq; } private: std::deque deq; diff --git a/cpp-ch/local-engine/Parser/CHColumnToSparkRow.cpp b/cpp-ch/local-engine/Parser/CHColumnToSparkRow.cpp index 2b4eb824a5fd..5bb66e4b3f9d 100644 --- a/cpp-ch/local-engine/Parser/CHColumnToSparkRow.cpp +++ b/cpp-ch/local-engine/Parser/CHColumnToSparkRow.cpp @@ -453,7 +453,7 @@ std::unique_ptr CHColumnToSparkRow::convertCHColumnToSparkRow(cons if (!block.columns()) throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "A block with empty columns"); std::unique_ptr spark_row_info = std::make_unique(block, masks); - spark_row_info->setBufferAddress(reinterpret_cast(alloc(spark_row_info->getTotalBytes(), 64))); + spark_row_info->setBufferAddress(static_cast(alloc(spark_row_info->getTotalBytes(), 64))); // spark_row_info->setBufferAddress(alignedAlloc(spark_row_info->getTotalBytes(), 64)); memset(spark_row_info->getBufferAddress(), 0, spark_row_info->getTotalBytes()); for (auto col_idx = 0; col_idx < spark_row_info->getNumCols(); col_idx++) diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp index 70db692c8009..3115950cdf09 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp @@ -87,14 +87,14 @@ namespace DB { namespace ErrorCodes { - extern const int LOGICAL_ERROR; - extern const int UNKNOWN_TYPE; - extern const int BAD_ARGUMENTS; - extern const int NO_SUCH_DATA_PART; - extern const int UNKNOWN_FUNCTION; - extern const int CANNOT_PARSE_PROTOBUF_SCHEMA; - extern const int ILLEGAL_TYPE_OF_ARGUMENT; - extern const int INVALID_JOIN_ON_EXPRESSION; +extern const int LOGICAL_ERROR; +extern const int UNKNOWN_TYPE; +extern const int BAD_ARGUMENTS; +extern const int NO_SUCH_DATA_PART; +extern const int UNKNOWN_FUNCTION; +extern const int CANNOT_PARSE_PROTOBUF_SCHEMA; 
+extern const int ILLEGAL_TYPE_OF_ARGUMENT; +extern const int INVALID_JOIN_ON_EXPRESSION; } } @@ -144,16 +144,13 @@ void SerializedPlanParser::parseExtensions( if (extension.has_extension_function()) { function_mapping.emplace( - std::to_string(extension.extension_function().function_anchor()), - extension.extension_function().name()); + std::to_string(extension.extension_function().function_anchor()), extension.extension_function().name()); } } } std::shared_ptr SerializedPlanParser::expressionsToActionsDAG( - const std::vector & expressions, - const Block & header, - const Block & read_schema) + const std::vector & expressions, const Block & header, const Block & read_schema) { auto actions_dag = std::make_shared(blockToNameAndTypeList(header)); NamesWithAliases required_columns; @@ -259,8 +256,8 @@ std::string getDecimalFunction(const substrait::Type_Decimal & decimal, bool nul bool SerializedPlanParser::isReadRelFromJava(const substrait::ReadRel & rel) { - return rel.has_local_files() && rel.local_files().items().size() == 1 && rel.local_files().items().at(0).uri_file().starts_with( - "iterator"); + return rel.has_local_files() && rel.local_files().items().size() == 1 + && rel.local_files().items().at(0).uri_file().starts_with("iterator"); } bool SerializedPlanParser::isReadFromMergeTree(const substrait::ReadRel & rel) @@ -380,13 +377,13 @@ DataTypePtr wrapNullableType(bool nullable, DataTypePtr nested_type) return nested_type; } -QueryPlanPtr SerializedPlanParser::parse(std::unique_ptr plan) +QueryPlanPtr SerializedPlanParser::parse(const substrait::Plan & plan) { - logDebugMessage(*plan, "substrait plan"); - parseExtensions(plan->extensions()); - if (plan->relations_size() == 1) + logDebugMessage(plan, "substrait plan"); + parseExtensions(plan.extensions()); + if (plan.relations_size() == 1) { - auto root_rel = plan->relations().at(0); + auto root_rel = plan.relations().at(0); if (!root_rel.has_root()) { throw Exception(ErrorCodes::BAD_ARGUMENTS, "must have root rel!"); @@ -587,9 +584,7 @@ SerializedPlanParser::getFunctionName(const std::string & function_signature, co { if (args.size() != 2) throw Exception( - ErrorCodes::BAD_ARGUMENTS, - "Spark function extract requires two args, function:{}", - function.ShortDebugString()); + ErrorCodes::BAD_ARGUMENTS, "Spark function extract requires two args, function:{}", function.ShortDebugString()); // Get the first arg: field const auto & extract_field = args.at(0); @@ -705,9 +700,7 @@ void SerializedPlanParser::parseArrayJoinArguments( /// The argument number of arrayJoin(converted from Spark explode/posexplode) should be 1 if (scalar_function.arguments_size() != 1) throw Exception( - ErrorCodes::BAD_ARGUMENTS, - "Argument number of arrayJoin should be 1 instead of {}", - scalar_function.arguments_size()); + ErrorCodes::BAD_ARGUMENTS, "Argument number of arrayJoin should be 1 instead of {}", scalar_function.arguments_size()); auto function_name_copy = function_name; parseFunctionArguments(actions_dag, parsed_args, function_name_copy, scalar_function); @@ -746,11 +739,7 @@ void SerializedPlanParser::parseArrayJoinArguments( } ActionsDAG::NodeRawConstPtrs SerializedPlanParser::parseArrayJoinWithDAG( - const substrait::Expression & rel, - std::vector & result_names, - ActionsDAGPtr actions_dag, - bool keep_result, - bool position) + const substrait::Expression & rel, std::vector & result_names, ActionsDAGPtr actions_dag, bool keep_result, bool position) { if (!rel.has_scalar_function()) throw Exception(ErrorCodes::BAD_ARGUMENTS, "The root of 
expression should be a scalar function:\n {}", rel.DebugString()); @@ -774,7 +763,8 @@ ActionsDAG::NodeRawConstPtrs SerializedPlanParser::parseArrayJoinWithDAG( auto tuple_element_builder = FunctionFactory::instance().get("sparkTupleElement", context); auto tuple_index_type = std::make_shared(); - auto add_tuple_element = [&](const ActionsDAG::Node * tuple_node, size_t i) -> const ActionsDAG::Node * { + auto add_tuple_element = [&](const ActionsDAG::Node * tuple_node, size_t i) -> const ActionsDAG::Node * + { ColumnWithTypeAndName index_col(tuple_index_type->createColumnConst(1, i), tuple_index_type, getUniqueName(std::to_string(i))); const auto * index_node = &actions_dag->addColumn(std::move(index_col)); auto result_name = "sparkTupleElement(" + tuple_node->result_name + ", " + index_node->result_name + ")"; @@ -866,10 +856,7 @@ ActionsDAG::NodeRawConstPtrs SerializedPlanParser::parseArrayJoinWithDAG( } const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG( - const substrait::Expression & rel, - std::string & result_name, - ActionsDAGPtr actions_dag, - bool keep_result) + const substrait::Expression & rel, std::string & result_name, ActionsDAGPtr actions_dag, bool keep_result) { if (!rel.has_scalar_function()) throw Exception(ErrorCodes::BAD_ARGUMENTS, "the root of expression should be a scalar function:\n {}", rel.DebugString()); @@ -884,10 +871,7 @@ const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG( if (auto func_parser = FunctionParserFactory::instance().tryGet(func_name, this)) { LOG_DEBUG( - &Poco::Logger::get("SerializedPlanParser"), - "parse function {} by function parser: {}", - func_name, - func_parser->getName()); + &Poco::Logger::get("SerializedPlanParser"), "parse function {} by function parser: {}", func_name, func_parser->getName()); const auto * result_node = func_parser->parse(scalar_function, actions_dag); if (keep_result) actions_dag->addOrReplaceInOutputs(*result_node); @@ -956,12 +940,10 @@ const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG( UInt32 precision = rel.scalar_function().output_type().decimal().precision(); UInt32 scale = rel.scalar_function().output_type().decimal().scale(); auto uint32_type = std::make_shared(); - new_args.emplace_back( - &actions_dag->addColumn( - ColumnWithTypeAndName(uint32_type->createColumnConst(1, precision), uint32_type, getUniqueName(toString(precision))))); - new_args.emplace_back( - &actions_dag->addColumn( - ColumnWithTypeAndName(uint32_type->createColumnConst(1, scale), uint32_type, getUniqueName(toString(scale))))); + new_args.emplace_back(&actions_dag->addColumn( + ColumnWithTypeAndName(uint32_type->createColumnConst(1, precision), uint32_type, getUniqueName(toString(precision))))); + new_args.emplace_back(&actions_dag->addColumn( + ColumnWithTypeAndName(uint32_type->createColumnConst(1, scale), uint32_type, getUniqueName(toString(scale))))); args = std::move(new_args); } else if (startsWith(function_signature, "make_decimal:")) @@ -976,12 +958,10 @@ const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG( UInt32 precision = rel.scalar_function().output_type().decimal().precision(); UInt32 scale = rel.scalar_function().output_type().decimal().scale(); auto uint32_type = std::make_shared(); - new_args.emplace_back( - &actions_dag->addColumn( - ColumnWithTypeAndName(uint32_type->createColumnConst(1, precision), uint32_type, getUniqueName(toString(precision))))); - new_args.emplace_back( - &actions_dag->addColumn( - 
ColumnWithTypeAndName(uint32_type->createColumnConst(1, scale), uint32_type, getUniqueName(toString(scale))))); + new_args.emplace_back(&actions_dag->addColumn( + ColumnWithTypeAndName(uint32_type->createColumnConst(1, precision), uint32_type, getUniqueName(toString(precision))))); + new_args.emplace_back(&actions_dag->addColumn( + ColumnWithTypeAndName(uint32_type->createColumnConst(1, scale), uint32_type, getUniqueName(toString(scale))))); args = std::move(new_args); } @@ -999,9 +979,8 @@ const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG( actions_dag, function_node, // as stated in isTypeMatched, currently we don't change nullability of the result type - function_node->result_type->isNullable() - ? local_engine::wrapNullableType(true, result_type)->getName() - : local_engine::removeNullable(result_type)->getName(), + function_node->result_type->isNullable() ? local_engine::wrapNullableType(true, result_type)->getName() + : local_engine::removeNullable(result_type)->getName(), function_node->result_name, CastType::accurateOrNull); } @@ -1011,9 +990,8 @@ const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG( actions_dag, function_node, // as stated in isTypeMatched, currently we don't change nullability of the result type - function_node->result_type->isNullable() - ? local_engine::wrapNullableType(true, result_type)->getName() - : local_engine::removeNullable(result_type)->getName(), + function_node->result_type->isNullable() ? local_engine::wrapNullableType(true, result_type)->getName() + : local_engine::removeNullable(result_type)->getName(), function_node->result_name); } } @@ -1159,9 +1137,7 @@ void SerializedPlanParser::parseFunctionArgument( } const ActionsDAG::Node * SerializedPlanParser::parseFunctionArgument( - ActionsDAGPtr & actions_dag, - const std::string & function_name, - const substrait::FunctionArgument & arg) + ActionsDAGPtr & actions_dag, const std::string & function_name, const substrait::FunctionArgument & arg) { const ActionsDAG::Node * res; if (arg.value().has_scalar_function()) @@ -1189,11 +1165,8 @@ std::pair SerializedPlanParser::convertStructFieldType(const } auto type_id = type->getTypeId(); - if (type_id == TypeIndex::UInt8 || type_id == TypeIndex::UInt16 || type_id == TypeIndex::UInt32 - || type_id == TypeIndex::UInt64) - { + if (type_id == TypeIndex::UInt8 || type_id == TypeIndex::UInt16 || type_id == TypeIndex::UInt32 || type_id == TypeIndex::UInt64) return {type, field}; - } UINT_CONVERT(type, field, Int8) UINT_CONVERT(type, field, Int16) UINT_CONVERT(type, field, Int32) @@ -1203,11 +1176,7 @@ std::pair SerializedPlanParser::convertStructFieldType(const } ActionsDAGPtr SerializedPlanParser::parseFunction( - const Block & header, - const substrait::Expression & rel, - std::string & result_name, - ActionsDAGPtr actions_dag, - bool keep_result) + const Block & header, const substrait::Expression & rel, std::string & result_name, ActionsDAGPtr actions_dag, bool keep_result) { if (!actions_dag) actions_dag = std::make_shared(blockToNameAndTypeList(header)); @@ -1217,11 +1186,7 @@ ActionsDAGPtr SerializedPlanParser::parseFunction( } ActionsDAGPtr SerializedPlanParser::parseFunctionOrExpression( - const Block & header, - const substrait::Expression & rel, - std::string & result_name, - ActionsDAGPtr actions_dag, - bool keep_result) + const Block & header, const substrait::Expression & rel, std::string & result_name, ActionsDAGPtr actions_dag, bool keep_result) { if (!actions_dag) actions_dag = 
std::make_shared(blockToNameAndTypeList(header)); @@ -1303,7 +1268,8 @@ ActionsDAGPtr SerializedPlanParser::parseJsonTuple( = &actions_dag->addFunction(json_extract_builder, {json_expr_node, extract_expr_node}, json_extract_result_name); auto tuple_element_builder = FunctionFactory::instance().get("sparkTupleElement", context); auto tuple_index_type = std::make_shared(); - auto add_tuple_element = [&](const ActionsDAG::Node * tuple_node, size_t i) -> const ActionsDAG::Node * { + auto add_tuple_element = [&](const ActionsDAG::Node * tuple_node, size_t i) -> const ActionsDAG::Node * + { ColumnWithTypeAndName index_col(tuple_index_type->createColumnConst(1, i), tuple_index_type, getUniqueName(std::to_string(i))); const auto * index_node = &actions_dag->addColumn(std::move(index_col)); auto result_name = "sparkTupleElement(" + tuple_node->result_name + ", " + index_node->result_name + ")"; @@ -1528,9 +1494,7 @@ std::pair SerializedPlanParser::parseLiteral(const substrait } default: { throw Exception( - ErrorCodes::UNKNOWN_TYPE, - "Unsupported spark literal type {}", - magic_enum::enum_name(literal.literal_type_case())); + ErrorCodes::UNKNOWN_TYPE, "Unsupported spark literal type {}", magic_enum::enum_name(literal.literal_type_case())); } } return std::make_pair(std::move(type), std::move(field)); @@ -1732,8 +1696,7 @@ substrait::ReadRel::ExtensionTable SerializedPlanParser::parseExtensionTable(con { substrait::ReadRel::ExtensionTable extension_table; google::protobuf::io::CodedInputStream coded_in( - reinterpret_cast(split_info.data()), - static_cast(split_info.size())); + reinterpret_cast(split_info.data()), static_cast(split_info.size())); coded_in.SetRecursionLimit(100000); auto ok = extension_table.ParseFromCodedStream(&coded_in); @@ -1747,8 +1710,7 @@ substrait::ReadRel::LocalFiles SerializedPlanParser::parseLocalFiles(const std:: { substrait::ReadRel::LocalFiles local_files; google::protobuf::io::CodedInputStream coded_in( - reinterpret_cast(split_info.data()), - static_cast(split_info.size())); + reinterpret_cast(split_info.data()), static_cast(split_info.size())); coded_in.SetRecursionLimit(100000); auto ok = local_files.ParseFromCodedStream(&coded_in); @@ -1758,10 +1720,44 @@ substrait::ReadRel::LocalFiles SerializedPlanParser::parseLocalFiles(const std:: return local_files; } +std::unique_ptr SerializedPlanParser::createExecutor(DB::QueryPlanPtr query_plan) +{ + Stopwatch stopwatch; + auto * logger = &Poco::Logger::get("SerializedPlanParser"); + const Settings & settings = context->getSettingsRef(); + + QueryPriorities priorities; + auto query_status = std::make_shared( + context, + "", + context->getClientInfo(), + priorities.insert(static_cast(settings.priority)), + CurrentThread::getGroup(), + IAST::QueryKind::Select, + settings, + 0); + + QueryPlanOptimizationSettings optimization_settings{.optimize_plan = settings.query_plan_enable_optimizations}; + auto pipeline_builder = query_plan->buildQueryPipeline( + optimization_settings, + BuildQueryPipelineSettings{ + .actions_settings + = ExpressionActionsSettings{.can_compile_expressions = true, .min_count_to_compile_expression = 3, .compile_expressions = CompileExpressions::yes}, + .process_list_element = query_status}); + QueryPipeline pipeline = QueryPipelineBuilder::getPipeline(std::move(*pipeline_builder)); + LOG_INFO(logger, "build pipeline {} ms", stopwatch.elapsedMicroseconds() / 1000.0); + + LOG_DEBUG( + logger, "clickhouse plan [optimization={}]:\n{}", settings.query_plan_enable_optimizations, 
PlanUtil::explainPlan(*query_plan)); + LOG_DEBUG(logger, "clickhouse pipeline:\n{}", QueryPipelineUtil::explainPipeline(pipeline)); + + return std::make_unique( + context, std::move(query_plan), std::move(pipeline), query_plan->getCurrentDataStream().header.cloneEmpty()); +} -QueryPlanPtr SerializedPlanParser::parse(const std::string & plan) +QueryPlanPtr SerializedPlanParser::parse(const std::string_view & plan) { - auto plan_ptr = std::make_unique(); + substrait::Plan s_plan; /// https://stackoverflow.com/questions/52028583/getting-error-parsing-protobuf-data /// Parsing may fail when the number of recursive layers is large. /// Here, set a limit large enough to avoid this problem. @@ -1769,11 +1765,10 @@ QueryPlanPtr SerializedPlanParser::parse(const std::string & plan) google::protobuf::io::CodedInputStream coded_in(reinterpret_cast(plan.data()), static_cast(plan.size())); coded_in.SetRecursionLimit(100000); - auto ok = plan_ptr->ParseFromCodedStream(&coded_in); - if (!ok) + if (!s_plan.ParseFromCodedStream(&coded_in)) throw Exception(ErrorCodes::CANNOT_PARSE_PROTOBUF_SCHEMA, "Parse substrait::Plan from string failed"); - auto res = parse(std::move(plan_ptr)); + auto res = parse(s_plan); #ifndef NDEBUG PlanUtil::checkOuputType(*res); @@ -1788,17 +1783,16 @@ QueryPlanPtr SerializedPlanParser::parse(const std::string & plan) return res; } -QueryPlanPtr SerializedPlanParser::parseJson(const std::string & json_plan) +QueryPlanPtr SerializedPlanParser::parseJson(const std::string_view & json_plan) { - auto plan_ptr = std::make_unique(); - auto s = google::protobuf::util::JsonStringToMessage(absl::string_view(json_plan), plan_ptr.get()); + substrait::Plan plan; + auto s = google::protobuf::util::JsonStringToMessage(json_plan, &plan); if (!s.ok()) throw Exception(ErrorCodes::CANNOT_PARSE_PROTOBUF_SCHEMA, "Parse substrait::Plan from json string failed: {}", s.ToString()); - return parse(std::move(plan_ptr)); + return parse(plan); } -SerializedPlanParser::SerializedPlanParser(const ContextPtr & context_) - : context(context_) +SerializedPlanParser::SerializedPlanParser(const ContextPtr & context_) : context(context_) { } @@ -1807,13 +1801,10 @@ ContextMutablePtr SerializedPlanParser::global_context = nullptr; Context::ConfigurationPtr SerializedPlanParser::config = nullptr; void SerializedPlanParser::collectJoinKeys( - const substrait::Expression & condition, - std::vector> & join_keys, - int32_t right_key_start) + const substrait::Expression & condition, std::vector> & join_keys, int32_t right_key_start) { auto condition_name = getFunctionName( - function_mapping.at(std::to_string(condition.scalar_function().function_reference())), - condition.scalar_function()); + function_mapping.at(std::to_string(condition.scalar_function().function_reference())), condition.scalar_function()); if (condition_name == "and") { collectJoinKeys(condition.scalar_function().arguments(0).value(), join_keys, right_key_start); @@ -1863,8 +1854,8 @@ ASTPtr ASTParser::parseToAST(const Names & names, const substrait::Expression & auto substrait_name = function_signature.substr(0, function_signature.find(':')); auto func_parser = FunctionParserFactory::instance().tryGet(substrait_name, plan_parser); - String function_name = func_parser ? func_parser->getName() - : SerializedPlanParser::getFunctionName(function_signature, scalar_function); + String function_name + = func_parser ? 
func_parser->getName() : SerializedPlanParser::getFunctionName(function_signature, scalar_function); ASTs ast_args; parseFunctionArgumentsToAST(names, scalar_function, ast_args); @@ -1876,9 +1867,7 @@ ASTPtr ASTParser::parseToAST(const Names & names, const substrait::Expression & } void ASTParser::parseFunctionArgumentsToAST( - const Names & names, - const substrait::Expression_ScalarFunction & scalar_function, - ASTs & ast_args) + const Names & names, const substrait::Expression_ScalarFunction & scalar_function, ASTs & ast_args) { const auto & args = scalar_function.arguments(); @@ -2021,12 +2010,12 @@ ASTPtr ASTParser::parseArgumentToAST(const Names & names, const substrait::Expre } } -void SerializedPlanParser::removeNullableForRequiredColumns(const std::set & require_columns, ActionsDAGPtr actions_dag) +void SerializedPlanParser::removeNullableForRequiredColumns( + const std::set & require_columns, const ActionsDAGPtr & actions_dag) const { for (const auto & item : require_columns) { - const auto * require_node = actions_dag->tryFindInOutputs(item); - if (require_node) + if (const auto * require_node = actions_dag->tryFindInOutputs(item)) { auto function_builder = FunctionFactory::instance().get("assumeNotNull", context); ActionsDAG::NodeRawConstPtrs args = {require_node}; @@ -2037,9 +2026,7 @@ void SerializedPlanParser::removeNullableForRequiredColumns(const std::set & columns, - ActionsDAGPtr actions_dag, - std::map & nullable_measure_names) + const std::vector & columns, ActionsDAGPtr actions_dag, std::map & nullable_measure_names) { for (const auto & item : columns) { @@ -2092,86 +2079,23 @@ LocalExecutor::~LocalExecutor() } } - -void LocalExecutor::execute(QueryPlanPtr query_plan) -{ - Stopwatch stopwatch; - - const Settings & settings = context->getSettingsRef(); - current_query_plan = std::move(query_plan); - auto * logger = &Poco::Logger::get("LocalExecutor"); - - QueryPriorities priorities; - auto query_status = std::make_shared( - context, - "", - context->getClientInfo(), - priorities.insert(static_cast(settings.priority)), - CurrentThread::getGroup(), - IAST::QueryKind::Select, - settings, - 0); - - QueryPlanOptimizationSettings optimization_settings{.optimize_plan = settings.query_plan_enable_optimizations}; - auto pipeline_builder = current_query_plan->buildQueryPipeline( - optimization_settings, - BuildQueryPipelineSettings{ - .actions_settings - = ExpressionActionsSettings{.can_compile_expressions = true, .min_count_to_compile_expression = 3, - .compile_expressions = CompileExpressions::yes}, - .process_list_element = query_status}); - - LOG_DEBUG(logger, "clickhouse plan after optimization:\n{}", PlanUtil::explainPlan(*current_query_plan)); - query_pipeline = QueryPipelineBuilder::getPipeline(std::move(*pipeline_builder)); - LOG_DEBUG(logger, "clickhouse pipeline:\n{}", QueryPipelineUtil::explainPipeline(query_pipeline)); - auto t_pipeline = stopwatch.elapsedMicroseconds(); - - executor = std::make_unique(query_pipeline); - auto t_executor = stopwatch.elapsedMicroseconds() - t_pipeline; - stopwatch.stop(); - LOG_INFO( - logger, - "build pipeline {} ms; create executor {} ms;", - t_pipeline / 1000.0, - t_executor / 1000.0); - - header = current_query_plan->getCurrentDataStream().header.cloneEmpty(); - ch_column_to_spark_row = std::make_unique(); -} - -std::unique_ptr LocalExecutor::writeBlockToSparkRow(Block & block) +std::unique_ptr LocalExecutor::writeBlockToSparkRow(const Block & block) const { return ch_column_to_spark_row->convertCHColumnToSparkRow(block); } bool 
LocalExecutor::hasNext() { - bool has_next; - try + size_t columns = currentBlock().columns(); + if (columns == 0 || isConsumed()) { - size_t columns = currentBlock().columns(); - if (columns == 0 || isConsumed()) - { - auto empty_block = header.cloneEmpty(); - setCurrentBlock(empty_block); - has_next = executor->pull(currentBlock()); - produce(); - } - else - { - has_next = true; - } - } - catch (Exception & e) - { - LOG_ERROR( - &Poco::Logger::get("LocalExecutor"), - "LocalExecutor run query plan failed with message: {}. Plan Explained: \n{}", - e.message(), - PlanUtil::explainPlan(*current_query_plan)); - throw; + auto empty_block = header.cloneEmpty(); + setCurrentBlock(empty_block); + bool has_next = executor->pull(currentBlock()); + produce(); + return has_next; } - return has_next; + return true; } SparkRowInfoPtr LocalExecutor::next() @@ -2246,12 +2170,17 @@ Block & LocalExecutor::getHeader() return header; } -LocalExecutor::LocalExecutor(ContextPtr context_) - : context(context_) +LocalExecutor::LocalExecutor(const ContextPtr & context_, QueryPlanPtr query_plan, QueryPipeline && pipeline, const Block & header_) + : query_pipeline(std::move(pipeline)) + , executor(std::make_unique(query_pipeline)) + , header(header_) + , context(context_) + , ch_column_to_spark_row(std::make_unique()) + , current_query_plan(std::move(query_plan)) { } -std::string LocalExecutor::dumpPipeline() +std::string LocalExecutor::dumpPipeline() const { const auto & processors = query_pipeline.getProcessors(); for (auto & processor : processors) @@ -2275,12 +2204,8 @@ std::string LocalExecutor::dumpPipeline() } NonNullableColumnsResolver::NonNullableColumnsResolver( - const Block & header_, - SerializedPlanParser & parser_, - const substrait::Expression & cond_rel_) - : header(header_) - , parser(parser_) - , cond_rel(cond_rel_) + const Block & header_, SerializedPlanParser & parser_, const substrait::Expression & cond_rel_) + : header(header_), parser(parser_), cond_rel(cond_rel_) { } @@ -2352,8 +2277,7 @@ void NonNullableColumnsResolver::visitNonNullable(const substrait::Expression & } std::string NonNullableColumnsResolver::safeGetFunctionName( - const std::string & function_signature, - const substrait::Expression_ScalarFunction & function) + const std::string & function_signature, const substrait::Expression_ScalarFunction & function) const { try { diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.h b/cpp-ch/local-engine/Parser/SerializedPlanParser.h index 71cdca58a6ce..82e8c4077841 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.h +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.h @@ -218,6 +218,7 @@ DataTypePtr wrapNullableType(bool nullable, DataTypePtr nested_type); std::string join(const ActionsDAG::NodeRawConstPtrs & v, char c); class SerializedPlanParser; +class LocalExecutor; // Give a condition expression `cond_rel_`, found all columns with nullability that must not containt // null after this filter. 
@@ -241,7 +242,7 @@ class NonNullableColumnsResolver void visit(const substrait::Expression & expr); void visitNonNullable(const substrait::Expression & expr); - String safeGetFunctionName(const String & function_signature, const substrait::Expression_ScalarFunction & function); + String safeGetFunctionName(const String & function_signature, const substrait::Expression_ScalarFunction & function) const; }; class SerializedPlanParser @@ -257,11 +258,21 @@ class SerializedPlanParser friend class JoinRelParser; friend class MergeTreeRelParser; + std::unique_ptr createExecutor(DB::QueryPlanPtr query_plan); + + DB::QueryPlanPtr parse(const std::string_view & plan); + DB::QueryPlanPtr parse(const substrait::Plan & plan); + public: explicit SerializedPlanParser(const ContextPtr & context); - DB::QueryPlanPtr parse(const std::string & plan); - DB::QueryPlanPtr parseJson(const std::string & json_plan); - DB::QueryPlanPtr parse(std::unique_ptr plan); + + /// UT only + DB::QueryPlanPtr parseJson(const std::string_view & json_plan); + std::unique_ptr createExecutor(const substrait::Plan & plan) { return createExecutor(parse((plan))); } + /// + + template + std::unique_ptr createExecutor(const std::string_view & plan); DB::QueryPlanStepPtr parseReadRealWithLocalFile(const substrait::ReadRel & rel); DB::QueryPlanStepPtr parseReadRealWithJavaIter(const substrait::ReadRel & rel); @@ -372,7 +383,7 @@ class SerializedPlanParser const ActionsDAG::Node * toFunctionNode(ActionsDAGPtr actions_dag, const String & function, const DB::ActionsDAG::NodeRawConstPtrs & args); // remove nullable after isNotNull - void removeNullableForRequiredColumns(const std::set & require_columns, ActionsDAGPtr actions_dag); + void removeNullableForRequiredColumns(const std::set & require_columns, const ActionsDAGPtr & actions_dag) const; std::string getUniqueName(const std::string & name) { return name + "_" + std::to_string(name_no++); } static std::pair parseLiteral(const substrait::Expression_Literal & literal); void wrapNullable( @@ -394,6 +405,12 @@ class SerializedPlanParser const ActionsDAG::Node * addColumn(DB::ActionsDAGPtr actions_dag, const DataTypePtr & type, const Field & field); }; +template +std::unique_ptr SerializedPlanParser::createExecutor(const std::string_view & plan) +{ + return createExecutor(JsonPlan ? 
parseJson(plan) : parse(plan)); +} + struct SparkBuffer { char * address; @@ -403,16 +420,14 @@ struct SparkBuffer class LocalExecutor : public BlockIterator { public: - LocalExecutor() = default; - explicit LocalExecutor(ContextPtr context); + LocalExecutor(const ContextPtr & context_, QueryPlanPtr query_plan, QueryPipeline && pipeline, const Block & header_); ~LocalExecutor(); - void execute(QueryPlanPtr query_plan); SparkRowInfoPtr next(); Block * nextColumnar(); bool hasNext(); - /// Stop execution and wait for pipeline exit, used when task receives shutdown command or executor receives SIGTERM signal + /// Stop execution, used when task receives shutdown command or executor receives SIGTERM signal void cancel(); Block & getHeader(); @@ -425,13 +440,13 @@ class LocalExecutor : public BlockIterator static void removeExecutor(Int64 handle); private: - std::unique_ptr writeBlockToSparkRow(DB::Block & block); + std::unique_ptr writeBlockToSparkRow(const DB::Block & block) const; void asyncCancel(); void waitCancelFinished(); /// Dump processor runtime information to log - std::string dumpPipeline(); + std::string dumpPipeline() const; QueryPipeline query_pipeline; std::unique_ptr executor; @@ -439,7 +454,7 @@ class LocalExecutor : public BlockIterator ContextPtr context; std::unique_ptr ch_column_to_spark_row; std::unique_ptr spark_buffer; - DB::QueryPlanPtr current_query_plan; + QueryPlanPtr current_query_plan; RelMetricPtr metric; std::vector extra_plan_holder; std::atomic is_cancelled{false}; diff --git a/cpp-ch/local-engine/local_engine_jni.cpp b/cpp-ch/local-engine/local_engine_jni.cpp index bbc467879182..9c642d70ec27 100644 --- a/cpp-ch/local-engine/local_engine_jni.cpp +++ b/cpp-ch/local-engine/local_engine_jni.cpp @@ -224,11 +224,9 @@ JNIEXPORT void JNI_OnUnload(JavaVM * vm, void * /*reserved*/) JNIEXPORT void Java_org_apache_gluten_vectorized_ExpressionEvaluatorJniWrapper_nativeInitNative(JNIEnv * env, jobject, jbyteArray conf_plan) { LOCAL_ENGINE_JNI_METHOD_START - jsize plan_buf_size = env->GetArrayLength(conf_plan); + std::string::size_type plan_buf_size = env->GetArrayLength(conf_plan); jbyte * plan_buf_addr = env->GetByteArrayElements(conf_plan, nullptr); - std::string plan_str; - plan_str.assign(reinterpret_cast(plan_buf_addr), plan_buf_size); - local_engine::BackendInitializerUtil::init(&plan_str); + local_engine::BackendInitializerUtil::init({reinterpret_cast(plan_buf_addr), plan_buf_size}); env->ReleaseByteArrayElements(conf_plan, plan_buf_addr, JNI_ABORT); LOCAL_ENGINE_JNI_METHOD_END(env, ) } @@ -254,11 +252,9 @@ JNIEXPORT jlong Java_org_apache_gluten_vectorized_ExpressionEvaluatorJniWrapper_ auto query_context = local_engine::getAllocator(allocator_id)->query_context; // by task update new configs ( in case of dynamic config update ) - jsize plan_buf_size = env->GetArrayLength(conf_plan); + std::string::size_type plan_buf_size = env->GetArrayLength(conf_plan); jbyte * plan_buf_addr = env->GetByteArrayElements(conf_plan, nullptr); - std::string plan_str; - plan_str.assign(reinterpret_cast(plan_buf_addr), plan_buf_size); - local_engine::BackendInitializerUtil::updateConfig(query_context, &plan_str); + local_engine::BackendInitializerUtil::updateConfig(query_context, {reinterpret_cast(plan_buf_addr), plan_buf_size}); local_engine::SerializedPlanParser parser(query_context); jsize iter_num = env->GetArrayLength(iter_arr); @@ -277,17 +273,14 @@ JNIEXPORT jlong Java_org_apache_gluten_vectorized_ExpressionEvaluatorJniWrapper_ 
parser.addSplitInfo(std::string{reinterpret_cast(split_info_addr), split_info_size}); } - jsize plan_size = env->GetArrayLength(plan); + std::string::size_type plan_size = env->GetArrayLength(plan); jbyte * plan_address = env->GetByteArrayElements(plan, nullptr); - std::string plan_string; - plan_string.assign(reinterpret_cast(plan_address), plan_size); - auto query_plan = parser.parse(plan_string); - local_engine::LocalExecutor * executor = new local_engine::LocalExecutor(query_context); + local_engine::LocalExecutor * executor + = parser.createExecutor({reinterpret_cast(plan_address), plan_size}).release(); local_engine::LocalExecutor::addExecutor(executor); - LOG_INFO(&Poco::Logger::get("jni"), "Construct LocalExecutor {}", reinterpret_cast(executor)); + LOG_INFO(&Poco::Logger::get("jni"), "Construct LocalExecutor {}", reinterpret_cast(executor)); executor->setMetric(parser.getMetric()); executor->setExtraPlanHolder(parser.extra_plan_holder); - executor->execute(std::move(query_plan)); env->ReleaseByteArrayElements(plan, plan_address, JNI_ABORT); env->ReleaseByteArrayElements(conf_plan, plan_buf_addr, JNI_ABORT); return reinterpret_cast(executor); @@ -932,11 +925,10 @@ JNIEXPORT jlong Java_org_apache_spark_sql_execution_datasources_CHDatasourceJniW LOCAL_ENGINE_JNI_METHOD_START auto query_context = local_engine::getAllocator(allocator_id)->query_context; // by task update new configs ( in case of dynamic config update ) - jsize conf_plan_buf_size = env->GetArrayLength(conf_plan); + std::string::size_type conf_plan_buf_size = env->GetArrayLength(conf_plan); jbyte * conf_plan_buf_addr = env->GetByteArrayElements(conf_plan, nullptr); - std::string conf_plan_str; - conf_plan_str.assign(reinterpret_cast(conf_plan_buf_addr), conf_plan_buf_size); - local_engine::BackendInitializerUtil::updateConfig(query_context, &conf_plan_str); + local_engine::BackendInitializerUtil::updateConfig( + query_context, {reinterpret_cast(conf_plan_buf_addr), conf_plan_buf_size}); const auto uuid_str = jstring2string(env, uuid_); const auto task_id = jstring2string(env, task_id_); @@ -1329,14 +1321,11 @@ Java_org_apache_gluten_vectorized_SimpleExpressionEval_createNativeInstance(JNIE local_engine::SerializedPlanParser parser(context); jobject iter = env->NewGlobalRef(input); parser.addInputIter(iter, false); - jsize plan_size = env->GetArrayLength(plan); + std::string::size_type plan_size = env->GetArrayLength(plan); jbyte * plan_address = env->GetByteArrayElements(plan, nullptr); - std::string plan_string; - plan_string.assign(reinterpret_cast(plan_address), plan_size); - auto query_plan = parser.parse(plan_string); - local_engine::LocalExecutor * executor = new local_engine::LocalExecutor(context); + local_engine::LocalExecutor * executor + = parser.createExecutor({reinterpret_cast(plan_address), plan_size}).release(); local_engine::LocalExecutor::addExecutor(executor); - executor->execute(std::move(query_plan)); env->ReleaseByteArrayElements(plan, plan_address, JNI_ABORT); return reinterpret_cast(executor); LOCAL_ENGINE_JNI_METHOD_END(env, -1) diff --git a/cpp-ch/local-engine/tests/benchmark_local_engine.cpp b/cpp-ch/local-engine/tests/benchmark_local_engine.cpp index 89fa4fa961ea..208a3b518d45 100644 --- a/cpp-ch/local-engine/tests/benchmark_local_engine.cpp +++ b/cpp-ch/local-engine/tests/benchmark_local_engine.cpp @@ -154,14 +154,11 @@ DB::ContextMutablePtr global_context; std::move(schema)) .build(); local_engine::SerializedPlanParser parser(global_context); - auto query_plan = parser.parse(std::move(plan)); 
- local_engine::LocalExecutor local_executor; + auto local_executor = parser.createExecutor(*plan); state.ResumeTiming(); - local_executor.execute(std::move(query_plan)); - while (local_executor.hasNext()) - { - local_engine::SparkRowInfoPtr spark_row_info = local_executor.next(); - } + + while (local_executor->hasNext()) + local_engine::SparkRowInfoPtr spark_row_info = local_executor->next(); } } @@ -212,13 +209,12 @@ DB::ContextMutablePtr global_context; std::move(schema)) .build(); local_engine::SerializedPlanParser parser(SerializedPlanParser::global_context); - auto query_plan = parser.parse(std::move(plan)); - local_engine::LocalExecutor local_executor; + auto local_executor = parser.createExecutor(*plan); state.ResumeTiming(); - local_executor.execute(std::move(query_plan)); - while (local_executor.hasNext()) + + while (local_executor->hasNext()) { - Block * block = local_executor.nextColumnar(); + Block * block = local_executor->nextColumnar(); delete block; } } @@ -238,15 +234,10 @@ DB::ContextMutablePtr global_context; std::ifstream t(path); std::string str((std::istreambuf_iterator(t)), std::istreambuf_iterator()); std::cout << "the plan from: " << path << std::endl; - - auto query_plan = parser.parse(str); - local_engine::LocalExecutor local_executor; + auto local_executor = parser.createExecutor(str); state.ResumeTiming(); - local_executor.execute(std::move(query_plan)); - while (local_executor.hasNext()) - { - [[maybe_unused]] auto * x = local_executor.nextColumnar(); - } + while (local_executor->hasNext()) [[maybe_unused]] + auto * x = local_executor->nextColumnar(); } } @@ -282,14 +273,12 @@ DB::ContextMutablePtr global_context; std::move(schema)) .build(); local_engine::SerializedPlanParser parser(SerializedPlanParser::global_context); - auto query_plan = parser.parse(std::move(plan)); - local_engine::LocalExecutor local_executor; + + auto local_executor = parser.createExecutor(*plan); state.ResumeTiming(); - local_executor.execute(std::move(query_plan)); - while (local_executor.hasNext()) - { - local_engine::SparkRowInfoPtr spark_row_info = local_executor.next(); - } + + while (local_executor->hasNext()) + local_engine::SparkRowInfoPtr spark_row_info = local_executor->next(); } } @@ -320,16 +309,13 @@ DB::ContextMutablePtr global_context; .build(); local_engine::SerializedPlanParser parser(SerializedPlanParser::global_context); - auto query_plan = parser.parse(std::move(plan)); - local_engine::LocalExecutor local_executor; - - local_executor.execute(std::move(query_plan)); + auto local_executor = parser.createExecutor(*plan); local_engine::SparkRowToCHColumn converter; - while (local_executor.hasNext()) + while (local_executor->hasNext()) { - local_engine::SparkRowInfoPtr spark_row_info = local_executor.next(); + local_engine::SparkRowInfoPtr spark_row_info = local_executor->next(); state.ResumeTiming(); - auto block = converter.convertSparkRowInfoToCHColumn(*spark_row_info, local_executor.getHeader()); + auto block = converter.convertSparkRowInfoToCHColumn(*spark_row_info, local_executor->getHeader()); state.PauseTiming(); } state.ResumeTiming(); @@ -368,16 +354,13 @@ DB::ContextMutablePtr global_context; std::move(schema)) .build(); local_engine::SerializedPlanParser parser(SerializedPlanParser::global_context); - auto query_plan = parser.parse(std::move(plan)); - local_engine::LocalExecutor local_executor; - - local_executor.execute(std::move(query_plan)); + auto local_executor = parser.createExecutor(*plan); local_engine::SparkRowToCHColumn converter; - while 
(local_executor.hasNext()) + while (local_executor->hasNext()) { - local_engine::SparkRowInfoPtr spark_row_info = local_executor.next(); + local_engine::SparkRowInfoPtr spark_row_info = local_executor->next(); state.ResumeTiming(); - auto block = converter.convertSparkRowInfoToCHColumn(*spark_row_info, local_executor.getHeader()); + auto block = converter.convertSparkRowInfoToCHColumn(*spark_row_info, local_executor->getHeader()); state.PauseTiming(); } state.ResumeTiming(); @@ -485,12 +468,8 @@ DB::ContextMutablePtr global_context; y.reserve(cnt); for (auto _ : state) - { for (i = 0; i < cnt; i++) - { y[i] = add(x[i], i); - } - } } [[maybe_unused]] static void BM_TestSumInline(benchmark::State & state) @@ -504,12 +483,8 @@ DB::ContextMutablePtr global_context; y.reserve(cnt); for (auto _ : state) - { for (i = 0; i < cnt; i++) - { y[i] = x[i] + i; - } - } } [[maybe_unused]] static void BM_TestPlus(benchmark::State & state) @@ -545,9 +520,7 @@ DB::ContextMutablePtr global_context; block.insert(y); auto executable_function = function->prepare(arguments); for (auto _ : state) - { auto result = executable_function->execute(block.getColumnsWithTypeAndName(), type, rows, false); - } } [[maybe_unused]] static void BM_TestPlusEmbedded(benchmark::State & state) @@ -847,9 +820,7 @@ QueryPlanPtr joinPlan(QueryPlanPtr left, QueryPlanPtr right, String left_key, St ASTPtr rkey = std::make_shared(right_key); join->addOnKeys(lkey, rkey, true); for (const auto & column : join->columnsFromJoinedTable()) - { join->addJoinedColumn(column); - } auto left_keys = left->getCurrentDataStream().header.getNamesAndTypesList(); join->addJoinedColumnsAndCorrectTypes(left_keys, true); @@ -920,7 +891,8 @@ BENCHMARK(BM_ParquetRead)->Unit(benchmark::kMillisecond)->Iterations(10); int main(int argc, char ** argv) { - BackendInitializerUtil::init(nullptr); + std::string empty; + BackendInitializerUtil::init(empty); SCOPE_EXIT({ BackendFinalizerUtil::finalizeGlobally(); }); ::benchmark::Initialize(&argc, argv); diff --git a/cpp-ch/local-engine/tests/gluten_test_util.h b/cpp-ch/local-engine/tests/gluten_test_util.h index d4c16e9fbbd8..dba4496d6221 100644 --- a/cpp-ch/local-engine/tests/gluten_test_util.h +++ b/cpp-ch/local-engine/tests/gluten_test_util.h @@ -24,6 +24,7 @@ #include #include #include +#include #include using BlockRowType = DB::ColumnsWithTypeAndName; @@ -60,6 +61,23 @@ AnotherRowType readParquetSchema(const std::string & file); DB::ActionsDAGPtr parseFilter(const std::string & filter, const AnotherRowType & name_and_types); +namespace pb_util +{ +template +std::string JsonStringToBinary(const std::string_view & json) +{ + Message message; + std::string binary; + auto s = google::protobuf::util::JsonStringToMessage(json, &message); + if (!s.ok()) + { + const std::string err_msg{s.message()}; + throw std::runtime_error(err_msg); + } + message.SerializeToString(&binary); + return binary; +} +} } inline DB::DataTypePtr BIGINT() diff --git a/cpp-ch/local-engine/tests/gtest_local_engine.cpp b/cpp-ch/local-engine/tests/gtest_local_engine.cpp index 2d1807841041..962bf9def52e 100644 --- a/cpp-ch/local-engine/tests/gtest_local_engine.cpp +++ b/cpp-ch/local-engine/tests/gtest_local_engine.cpp @@ -16,9 +16,12 @@ */ #include #include +#include +#include + #include -#include #include +#include #include #include #include @@ -28,7 +31,6 @@ #include #include #include -#include #include #include #include @@ -84,13 +86,23 @@ TEST(ReadBufferFromFile, seekBackwards) ASSERT_EQ(x, 8); } +INCBIN(resource_embedded_config_json, 
SOURCE_DIR "/utils/extern-local-engine/tests/json/gtest_local_engine_config.json"); + +namespace DB +{ +void registerOutputFormatParquet(DB::FormatFactory & factory); +} + int main(int argc, char ** argv) { - auto * init = new String("{\"advancedExtensions\":{\"enhancement\":{\"@type\":\"type.googleapis.com/substrait.Expression\",\"literal\":{\"map\":{\"keyValues\":[{\"key\":{\"string\":\"spark.gluten.sql.columnar.backend.ch.runtime_config.logger.level\"},\"value\":{\"string\":\"trace\"}},{\"key\":{\"string\":\"spark.gluten.sql.columnar.backend.ch.runtime_settings.max_bytes_before_external_sort\"},\"value\":{\"string\":\"5368709120\"}},{\"key\":{\"string\":\"spark.hadoop.fs.s3a.endpoint\"},\"value\":{\"string\":\"localhost:9000\"}},{\"key\":{\"string\":\"spark.gluten.sql.columnar.backend.velox.IOThreads\"},\"value\":{\"string\":\"0\"}},{\"key\":{\"string\":\"spark.gluten.sql.columnar.backend.ch.runtime_config.hdfs.input_read_timeout\"},\"value\":{\"string\":\"180000\"}},{\"key\":{\"string\":\"spark.gluten.sql.columnar.backend.ch.runtime_settings.query_plan_enable_optimizations\"},\"value\":{\"string\":\"false\"}},{\"key\":{\"string\":\"spark.gluten.sql.columnar.backend.ch.worker.id\"},\"value\":{\"string\":\"1\"}},{\"key\":{\"string\":\"spark.memory.offHeap.enabled\"},\"value\":{\"string\":\"true\"}},{\"key\":{\"string\":\"spark.hadoop.fs.s3a.iam.role.session.name\"},\"value\":{\"string\":\"\"}},{\"key\":{\"string\":\"spark.gluten.sql.columnar.backend.ch.runtime_config.hdfs.input_connect_timeout\"},\"value\":{\"string\":\"180000\"}},{\"key\":{\"string\":\"spark.gluten.sql.columnar.shuffle.codec\"},\"value\":{\"string\":\"\"}},{\"key\":{\"string\":\"spark.gluten.sql.columnar.backend.ch.runtime_config.local_engine.settings.log_processors_profiles\"},\"value\":{\"string\":\"true\"}},{\"key\":{\"string\":\"spark.gluten.memory.offHeap.size.in.bytes\"},\"value\":{\"string\":\"10737418240\"}},{\"key\":{\"string\":\"spark.gluten.sql.columnar.shuffle.codecBackend\"},\"value\":{\"string\":\"\"}},{\"key\":{\"string\":\"spark.sql.orc.compression.codec\"},\"value\":{\"string\":\"snappy\"}},{\"key\":{\"string\":\"spark.gluten.sql.columnar.backend.ch.runtime_settings.max_bytes_before_external_group_by\"},\"value\":{\"string\":\"5368709120\"}},{\"key\":{\"string\":\"spark.hadoop.input.write.timeout\"},\"value\":{\"string\":\"180000\"}},{\"key\":{\"string\":\"spark.hadoop.fs.s3a.secret.key\"},\"value\":{\"string\":\"\"}},{\"key\":{\"string\":\"spark.hadoop.fs.s3a.access.key\"},\"value\":{\"string\":\"\"}},{\"key\":{\"string\":\"spark.gluten.sql.columnar.backend.ch.runtime_config.hdfs.dfs_client_log_severity\"},\"value\":{\"string\":\"INFO\"}},{\"key\":{\"string\":\"spark.hadoop.fs.s3a.path.style.access\"},\"value\":{\"string\":\"true\"}},{\"key\":{\"string\":\"spark.gluten.sql.columnar.backend.ch.runtime_config.timezone\"},\"value\":{\"string\":\"Asia/Shanghai\"}},{\"key\":{\"string\":\"spark.hadoop.input.read.timeout\"},\"value\":{\"string\":\"180000\"}},{\"key\":{\"string\":\"spark.hadoop.fs.s3a.use.instance.credentials\"},\"value\":{\"string\":\"false\"}},{\"key\":{\"string\":\"spark.gluten.sql.columnar.backend.ch.runtime_settings.output_format_orc_compression_method\"},\"value\":{\"string\":\"snappy\"}},{\"key\":{\"string\":\"spark.hadoop.fs.s3a.iam.role\"},\"value\":{\"string\":\"\"}},{\"key\":{\"string\":\"spark.gluten.memory.task.offHeap.size.in.bytes\"},\"value\":{\"string\":\"10737418240\"}},{\"key\":{\"string\":\"spark.hadoop.input.connect.timeout\"},\"value\":{\"string\":\"180000\"}},{\"key\":{\"
string\":\"spark.hadoop.dfs.client.log.severity\"},\"value\":{\"string\":\"INFO\"}},{\"key\":{\"string\":\"spark.gluten.sql.columnar.backend.velox.SplitPreloadPerDriver\"},\"value\":{\"string\":\"2\"}},{\"key\":{\"string\":\"spark.gluten.sql.columnar.backend.ch.runtime_config.hdfs.input_write_timeout\"},\"value\":{\"string\":\"180000\"}},{\"key\":{\"string\":\"spark.hadoop.fs.s3a.connection.ssl.enabled\"},\"value\":{\"string\":\"false\"}}]}}}}}"); + BackendInitializerUtil::init(test::pb_util::JsonStringToBinary( + {reinterpret_cast(gresource_embedded_config_jsonData), gresource_embedded_config_jsonSize})); + + auto & factory = FormatFactory::instance(); + DB::registerOutputFormatParquet(factory); - BackendInitializerUtil::init_json(std::move(init)); SCOPE_EXIT({ BackendFinalizerUtil::finalizeGlobally(); }); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); -} +} \ No newline at end of file diff --git a/cpp-ch/local-engine/tests/gtest_parser.cpp b/cpp-ch/local-engine/tests/gtest_parser.cpp index cbe41c90c81a..485740191ea3 100644 --- a/cpp-ch/local-engine/tests/gtest_parser.cpp +++ b/cpp-ch/local-engine/tests/gtest_parser.cpp @@ -14,307 +14,140 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include +#include #include -#include #include + using namespace local_engine; using namespace DB; -std::string splitBinaryFromJson(const std::string & json) +// Plan for https://github.com/ClickHouse/ClickHouse/pull/65234 +INCBIN(resource_embedded_pr_65234_json, SOURCE_DIR "/utils/extern-local-engine/tests/json/clickhouse_pr_65234.json"); + +TEST(SerializedPlanParser, PR65234) { - std::string binary; - substrait::ReadRel::LocalFiles local_files; - auto s = google::protobuf::util::JsonStringToMessage(absl::string_view(json), &local_files); - local_files.SerializeToString(&binary); - return binary; + const std::string split + = R"({"items":[{"uriFile":"file:///home/chang/SourceCode/rebase_gluten/backends-clickhouse/target/scala-2.12/test-classes/tests-working-home/tpch-data/supplier/part-00000-16caa751-9774-470c-bd37-5c84c53373c8-c000.snappy.parquet","length":"84633","parquet":{},"schema":{},"metadataColumns":[{}]}]})"; + SerializedPlanParser parser(SerializedPlanParser::global_context); + parser.addSplitInfo(test::pb_util::JsonStringToBinary(split)); + auto query_plan + = parser.parseJson({reinterpret_cast(gresource_embedded_pr_65234_jsonData), gresource_embedded_pr_65234_jsonSize}); } -std::string JsonPlanFor65234() +#include +#include +#include +#include +#include + +Chunk testChunk() { - // Plan for https://github.com/ClickHouse/ClickHouse/pull/65234 - return R"( + auto nameCol = STRING()->createColumn(); + nameCol->insert("one"); + nameCol->insert("two"); + nameCol->insert("three"); + + auto valueCol = UINT()->createColumn(); + valueCol->insert(1); + valueCol->insert(2); + valueCol->insert(3); + MutableColumns x; + x.push_back(std::move(nameCol)); + x.push_back(std::move(valueCol)); + return {std::move(x), 3}; +} + +TEST(LocalExecutor, StorageObjectStorageSink) { - "extensions": [{ - "extensionFunction": { - "functionAnchor": 1, - "name": "is_not_null:str" - } - }, { - "extensionFunction": { - "functionAnchor": 2, - "name": "equal:str_str" - } - }, { - "extensionFunction": { - "functionAnchor": 3, - "name": "is_not_null:i64" - } - }, { - "extensionFunction": { - "name": "and:bool_bool" - } - }], - "relations": [{ - "root": { - "input": { - "project": { - "common": { - "emit": { - "outputMapping": [2] - } - }, - "input": { - 
"filter": { - "common": { - "direct": { - } - }, - "input": { - "read": { - "common": { - "direct": { - } - }, - "baseSchema": { - "names": ["r_regionkey", "r_name"], - "struct": { - "types": [{ - "i64": { - "nullability": "NULLABILITY_NULLABLE" - } - }, { - "string": { - "nullability": "NULLABILITY_NULLABLE" - } - }] - }, - "columnTypes": ["NORMAL_COL", "NORMAL_COL"] - }, - "filter": { - "scalarFunction": { - "outputType": { - "bool": { - "nullability": "NULLABILITY_NULLABLE" - } - }, - "arguments": [{ - "value": { - "scalarFunction": { - "outputType": { - "bool": { - "nullability": "NULLABILITY_NULLABLE" - } - }, - "arguments": [{ - "value": { - "scalarFunction": { - "functionReference": 1, - "outputType": { - "bool": { - "nullability": "NULLABILITY_REQUIRED" - } - }, - "arguments": [{ - "value": { - "selection": { - "directReference": { - "structField": { - "field": 1 - } - } - } - } - }] - } - } - }, { - "value": { - "scalarFunction": { - "functionReference": 2, - "outputType": { - "bool": { - "nullability": "NULLABILITY_NULLABLE" - } - }, - "arguments": [{ - "value": { - "selection": { - "directReference": { - "structField": { - "field": 1 - } - } - } - } - }, { - "value": { - "literal": { - "string": "EUROPE" - } - } - }] - } - } - }] - } - } - }, { - "value": { - "scalarFunction": { - "functionReference": 3, - "outputType": { - "bool": { - "nullability": "NULLABILITY_REQUIRED" - } - }, - "arguments": [{ - "value": { - "selection": { - "directReference": { - "structField": { - } - } - } - } - }] - } - } - }] - } - }, - "advancedExtension": { - "optimization": { - "@type": "type.googleapis.com/google.protobuf.StringValue", - "value": "isMergeTree\u003d0\n" - } - } - } - }, - "condition": { - "scalarFunction": { - "outputType": { - "bool": { - "nullability": "NULLABILITY_NULLABLE" - } - }, - "arguments": [{ - "value": { - "scalarFunction": { - "outputType": { - "bool": { - "nullability": "NULLABILITY_NULLABLE" - } - }, - "arguments": [{ - "value": { - "scalarFunction": { - "functionReference": 1, - "outputType": { - "bool": { - "nullability": "NULLABILITY_REQUIRED" - } - }, - "arguments": [{ - "value": { - "selection": { - "directReference": { - "structField": { - "field": 1 - } - } - } - } - }] - } - } - }, { - "value": { - "scalarFunction": { - "functionReference": 2, - "outputType": { - "bool": { - "nullability": "NULLABILITY_NULLABLE" - } - }, - "arguments": [{ - "value": { - "selection": { - "directReference": { - "structField": { - "field": 1 - } - } - } - } - }, { - "value": { - "literal": { - "string": "EUROPE" - } - } - }] - } - } - }] - } - } - }, { - "value": { - "scalarFunction": { - "functionReference": 3, - "outputType": { - "bool": { - "nullability": "NULLABILITY_REQUIRED" - } - }, - "arguments": [{ - "value": { - "selection": { - "directReference": { - "structField": { - } - } - } - } - }] - } - } - }] - } - } - } - }, - "expressions": [{ - "selection": { - "directReference": { - "structField": { - } - } - } - }] - } - }, - "names": ["r_regionkey#72"], - "outputSchema": { - "types": [{ - "i64": { - "nullability": "NULLABILITY_NULLABLE" - } - }], - "nullability": "NULLABILITY_REQUIRED" - } - } - }] + /// 0. 
Create ObjectStorage for HDFS + auto settings = SerializedPlanParser::global_context->getSettingsRef(); + const std::string query + = R"(CREATE TABLE hdfs_engine_xxxx (name String, value UInt32) ENGINE=HDFS('hdfs://localhost:8020/clickhouse/test2', 'Parquet'))"; + DB::ParserCreateQuery parser; + std::string error_message; + const char * pos = query.data(); + auto ast = DB::tryParseQuery( + parser, + pos, + pos + query.size(), + error_message, + /* hilite = */ false, + "QUERY TEST", + /* allow_multi_statements = */ false, + 0, + settings.max_parser_depth, + settings.max_parser_backtracks, + true); + auto & create = ast->as(); + auto arg = create.storage->children[0]; + const auto * func = arg->as(); + EXPECT_TRUE(func && func->name == "HDFS"); + + DB::StorageHDFSConfiguration config; + StorageObjectStorage::Configuration::initialize(config, arg->children[0]->children, SerializedPlanParser::global_context, false); + + const std::shared_ptr object_storage + = std::dynamic_pointer_cast(config.createObjectStorage(SerializedPlanParser::global_context, false)); + EXPECT_TRUE(object_storage != nullptr); + + RelativePathsWithMetadata files_with_metadata; + object_storage->listObjects("/clickhouse", files_with_metadata, 0); + + /// 1. Create ObjectStorageSink + DB::StorageObjectStorageSink sink{ + object_storage, config.clone(), {}, {{STRING(), "name"}, {UINT(), "value"}}, SerializedPlanParser::global_context, ""}; + + /// 2. Create Chunk + /// 3. comsume + sink.consume(testChunk()); + sink.onFinish(); } -)"; + +namespace DB +{ +SinkToStoragePtr createFilelinkSink( + const StorageMetadataPtr & metadata_snapshot, + const String & table_name_for_log, + const String & path, + CompressionMethod compression_method, + const std::optional & format_settings, + const String & format_name, + const ContextPtr & context, + int flags); } -TEST(SerializedPlanParser, PR65234) +INCBIN(resource_embedded_readcsv_json, SOURCE_DIR "/utils/extern-local-engine/tests/json/read_student_option_schema.csv.json"); +TEST(LocalExecutor, StorageFileSink) { const std::string split - = R"({"items":[{"uriFile":"file:///part-00000-16caa751-9774-470c-bd37-5c84c53373c8-c000.snappy.parquet","length":"84633","parquet":{},"schema":{},"metadataColumns":[{}]}]}")"; + = R"({"items":[{"uriFile":"file:///home/chang/SourceCode/rebase_gluten/backends-velox/src/test/resources/datasource/csv/student_option_schema.csv","length":"56","text":{"fieldDelimiter":",","maxBlockSize":"8192","header":"1"},"schema":{"names":["id","name","language"],"struct":{"types":[{"string":{"nullability":"NULLABILITY_NULLABLE"}},{"string":{"nullability":"NULLABILITY_NULLABLE"}},{"string":{"nullability":"NULLABILITY_NULLABLE"}}]}},"metadataColumns":[{}]}]})"; SerializedPlanParser parser(SerializedPlanParser::global_context); - parser.addSplitInfo(splitBinaryFromJson(split)); - parser.parseJson(JsonPlanFor65234()); -} + parser.addSplitInfo(test::pb_util::JsonStringToBinary(split)); + auto local_executor = parser.createExecutor( + {reinterpret_cast(gresource_embedded_readcsv_jsonData), gresource_embedded_readcsv_jsonSize}); + + while (local_executor->hasNext()) + { + const Block & x = *local_executor->nextColumnar(); + EXPECT_EQ(4, x.rows()); + } + + StorageInMemoryMetadata metadata; + metadata.setColumns(ColumnsDescription::fromNamesAndTypes({{"name", STRING()}, {"value", UINT()}})); + StorageMetadataPtr metadata_ptr = std::make_shared(metadata); + + auto sink = createFilelinkSink( + metadata_ptr, + "test_table", + "/tmp/test_table.parquet", + CompressionMethod::None, + 
{}, + "Parquet", + SerializedPlanParser::global_context, + 0); + + sink->consume(testChunk()); + sink->onFinish(); +} \ No newline at end of file diff --git a/cpp-ch/local-engine/tests/json/clickhouse_pr_65234.json b/cpp-ch/local-engine/tests/json/clickhouse_pr_65234.json new file mode 100644 index 000000000000..1c37b68b7144 --- /dev/null +++ b/cpp-ch/local-engine/tests/json/clickhouse_pr_65234.json @@ -0,0 +1,273 @@ +{ + "extensions": [{ + "extensionFunction": { + "functionAnchor": 1, + "name": "is_not_null:str" + } + }, { + "extensionFunction": { + "functionAnchor": 2, + "name": "equal:str_str" + } + }, { + "extensionFunction": { + "functionAnchor": 3, + "name": "is_not_null:i64" + } + }, { + "extensionFunction": { + "name": "and:bool_bool" + } + }], + "relations": [{ + "root": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [2] + } + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["r_regionkey", "r_name"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }] + }, + "columnTypes": ["NORMAL_COL", "NORMAL_COL"] + }, + "filter": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + } + } + } + }, { + "value": { + "literal": { + "string": "EUROPE" + } + } + }] + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + } + } + } + } + }] + } + } + }] + } + }, + "advancedExtension": { + "optimization": { + "@type": "type.googleapis.com/google.protobuf.StringValue", + "value": "isMergeTree\u003d0\n" + } + } + } + }, + "condition": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + } + } + } + }, { + "value": { + "literal": { + "string": "EUROPE" + } + } + }] + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + 
"functionReference": 3, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + } + } + } + } + }] + } + } + }] + } + } + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + } + } + } + }] + } + }, + "names": ["r_regionkey#72"], + "outputSchema": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + } + }] +} \ No newline at end of file diff --git a/cpp-ch/local-engine/tests/json/gtest_local_engine_config.json b/cpp-ch/local-engine/tests/json/gtest_local_engine_config.json new file mode 100644 index 000000000000..10f0ea3dfdad --- /dev/null +++ b/cpp-ch/local-engine/tests/json/gtest_local_engine_config.json @@ -0,0 +1,269 @@ +{ + "advancedExtensions": { + "enhancement": { + "@type": "type.googleapis.com/substrait.Expression", + "literal": { + "map": { + "keyValues": [ + { + "key": { + "string": "spark.gluten.sql.columnar.backend.ch.runtime_config.logger.level" + }, + "value": { + "string": "test" + } + }, + { + "key": { + "string": "spark.gluten.sql.columnar.backend.ch.runtime_settings.max_bytes_before_external_sort" + }, + "value": { + "string": "5368709120" + } + }, + { + "key": { + "string": "spark.hadoop.fs.s3a.endpoint" + }, + "value": { + "string": "localhost:9000" + } + }, + { + "key": { + "string": "spark.gluten.sql.columnar.backend.velox.IOThreads" + }, + "value": { + "string": "0" + } + }, + { + "key": { + "string": "spark.gluten.sql.columnar.backend.ch.runtime_config.hdfs.input_read_timeout" + }, + "value": { + "string": "180000" + } + }, + { + "key": { + "string": "spark.gluten.sql.columnar.backend.ch.runtime_settings.query_plan_enable_optimizations" + }, + "value": { + "string": "false" + } + }, + { + "key": { + "string": "spark.gluten.sql.columnar.backend.ch.worker.id" + }, + "value": { + "string": "1" + } + }, + { + "key": { + "string": "spark.memory.offHeap.enabled" + }, + "value": { + "string": "true" + } + }, + { + "key": { + "string": "spark.hadoop.fs.s3a.iam.role.session.name" + }, + "value": { + "string": "" + } + }, + { + "key": { + "string": "spark.gluten.sql.columnar.backend.ch.runtime_config.hdfs.input_connect_timeout" + }, + "value": { + "string": "180000" + } + }, + { + "key": { + "string": "spark.gluten.sql.columnar.shuffle.codec" + }, + "value": { + "string": "" + } + }, + { + "key": { + "string": "spark.gluten.sql.columnar.backend.ch.runtime_config.local_engine.settings.log_processors_profiles" + }, + "value": { + "string": "true" + } + }, + { + "key": { + "string": "spark.gluten.memory.offHeap.size.in.bytes" + }, + "value": { + "string": "10737418240" + } + }, + { + "key": { + "string": "spark.gluten.sql.columnar.shuffle.codecBackend" + }, + "value": { + "string": "" + } + }, + { + "key": { + "string": "spark.sql.orc.compression.codec" + }, + "value": { + "string": "snappy" + } + }, + { + "key": { + "string": "spark.gluten.sql.columnar.backend.ch.runtime_settings.max_bytes_before_external_group_by" + }, + "value": { + "string": "5368709120" + } + }, + { + "key": { + "string": "spark.hadoop.input.write.timeout" + }, + "value": { + "string": "180000" + } + }, + { + "key": { + "string": "spark.hadoop.fs.s3a.secret.key" + }, + "value": { + "string": "" + } + }, + { + "key": { + "string": "spark.hadoop.fs.s3a.access.key" + }, + "value": { + "string": "" + } + }, + { + "key": { + "string": 
"spark.gluten.sql.columnar.backend.ch.runtime_config.hdfs.dfs_client_log_severity" + }, + "value": { + "string": "INFO" + } + }, + { + "key": { + "string": "spark.hadoop.fs.s3a.path.style.access" + }, + "value": { + "string": "true" + } + }, + { + "key": { + "string": "spark.gluten.sql.columnar.backend.ch.runtime_config.timezone" + }, + "value": { + "string": "Asia/Shanghai" + } + }, + { + "key": { + "string": "spark.hadoop.input.read.timeout" + }, + "value": { + "string": "180000" + } + }, + { + "key": { + "string": "spark.hadoop.fs.s3a.use.instance.credentials" + }, + "value": { + "string": "false" + } + }, + { + "key": { + "string": "spark.gluten.sql.columnar.backend.ch.runtime_settings.output_format_orc_compression_method" + }, + "value": { + "string": "snappy" + } + }, + { + "key": { + "string": "spark.hadoop.fs.s3a.iam.role" + }, + "value": { + "string": "" + } + }, + { + "key": { + "string": "spark.gluten.memory.task.offHeap.size.in.bytes" + }, + "value": { + "string": "10737418240" + } + }, + { + "key": { + "string": "spark.hadoop.input.connect.timeout" + }, + "value": { + "string": "180000" + } + }, + { + "key": { + "string": "spark.hadoop.dfs.client.log.severity" + }, + "value": { + "string": "INFO" + } + }, + { + "key": { + "string": "spark.gluten.sql.columnar.backend.velox.SplitPreloadPerDriver" + }, + "value": { + "string": "2" + } + }, + { + "key": { + "string": "spark.gluten.sql.columnar.backend.ch.runtime_config.hdfs.input_write_timeout" + }, + "value": { + "string": "180000" + } + }, + { + "key": { + "string": "spark.hadoop.fs.s3a.connection.ssl.enabled" + }, + "value": { + "string": "false" + } + } + ] + } + } + } + } +} \ No newline at end of file diff --git a/cpp-ch/local-engine/tests/json/read_student_option_schema.csv.json b/cpp-ch/local-engine/tests/json/read_student_option_schema.csv.json new file mode 100644 index 000000000000..f9518d39014a --- /dev/null +++ b/cpp-ch/local-engine/tests/json/read_student_option_schema.csv.json @@ -0,0 +1,77 @@ +{ + "relations": [ + { + "root": { + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "id", + "name", + "language" + ], + "struct": { + "types": [ + { + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ] + }, + "columnTypes": [ + "NORMAL_COL", + "NORMAL_COL", + "NORMAL_COL" + ] + }, + "advancedExtension": { + "optimization": { + "@type": "type.googleapis.com/google.protobuf.StringValue", + "value": "isMergeTree=0\n" + } + } + } + }, + "names": [ + "id#20", + "name#21", + "language#22" + ], + "outputSchema": { + "types": [ + { + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + } + } + ] +} \ No newline at end of file diff --git a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala index 9a37c4a40dd1..3ca5e0313924 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala @@ -430,7 +430,9 @@ trait SparkPlanExecApi { * * @return */ - def genExtendedColumnarPostRules(): List[SparkSession => Rule[SparkPlan]] + def 
genExtendedColumnarPostRules(): List[SparkSession => Rule[SparkPlan]] = { + SparkShimLoader.getSparkShims.getExtendedColumnarPostRules() ::: List() + } def genInjectPostHocResolutionRules(): List[SparkSession => Rule[LogicalPlan]] diff --git a/gluten-core/src/main/scala/org/apache/gluten/utils/SubstraitPlanPrinterUtil.scala b/gluten-core/src/main/scala/org/apache/gluten/utils/SubstraitPlanPrinterUtil.scala index 77d5d55f618d..a6ec7cb21fbf 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/utils/SubstraitPlanPrinterUtil.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/utils/SubstraitPlanPrinterUtil.scala @@ -24,37 +24,34 @@ import io.substrait.proto.{NamedStruct, Plan} object SubstraitPlanPrinterUtil extends Logging { - /** Transform Substrait Plan to json format. */ - def substraitPlanToJson(substraintPlan: Plan): String = { + private def typeRegistry( + d: com.google.protobuf.Descriptors.Descriptor): com.google.protobuf.TypeRegistry = { val defaultRegistry = WrappersProto.getDescriptor.getMessageTypes - val registry = com.google.protobuf.TypeRegistry + com.google.protobuf.TypeRegistry .newBuilder() - .add(substraintPlan.getDescriptorForType()) + .add(d) .add(defaultRegistry) .build() - JsonFormat.printer.usingTypeRegistry(registry).print(substraintPlan) + } + private def MessageToJson(message: com.google.protobuf.Message): String = { + val registry = typeRegistry(message.getDescriptorForType) + JsonFormat.printer.usingTypeRegistry(registry).print(message) } - def substraitNamedStructToJson(substraintPlan: NamedStruct): String = { - val defaultRegistry = WrappersProto.getDescriptor.getMessageTypes - val registry = com.google.protobuf.TypeRegistry - .newBuilder() - .add(substraintPlan.getDescriptorForType()) - .add(defaultRegistry) - .build() - JsonFormat.printer.usingTypeRegistry(registry).print(substraintPlan) + /** Transform Substrait Plan to json format. 
*/ + def substraitPlanToJson(substraitPlan: Plan): String = { + MessageToJson(substraitPlan) + } + + def substraitNamedStructToJson(namedStruct: NamedStruct): String = { + MessageToJson(namedStruct) + } /** Transform substrait plan json string to PlanNode */ def jsonToSubstraitPlan(planJson: String): Plan = { try { val builder = Plan.newBuilder() - val defaultRegistry = WrappersProto.getDescriptor.getMessageTypes - val registry = com.google.protobuf.TypeRegistry - .newBuilder() - .add(builder.getDescriptorForType) - .add(defaultRegistry) - .build() + val registry = typeRegistry(builder.getDescriptorForType) JsonFormat.parser().usingTypeRegistry(registry).merge(planJson, builder) builder.build() } catch { From 4c52976e4fce98e861da210f13a85a74d45f386e Mon Sep 17 00:00:00 2001 From: Shuai li Date: Tue, 25 Jun 2024 10:28:39 +0800 Subject: [PATCH 06/30] [GLUTEN-6176][CH] Support aggregate avg return decimal (#6177) * Support aggregate avg return decimal * update version * fix rebase * add ut --- .../GlutenClickHouseDecimalSuite.scala | 5 +- .../AggregateFunctionSparkAvg.cpp | 158 ++++++++++++++++++ cpp-ch/local-engine/Common/CHUtil.cpp | 9 +- cpp-ch/local-engine/Common/CHUtil.h | 5 +- .../local-engine/Common/GlutenDecimalUtils.h | 108 ++++++++++++ cpp-ch/local-engine/Parser/RelParser.cpp | 23 ++- .../org/apache/gluten/GlutenConfig.scala | 8 +- 7 files changed, 303 insertions(+), 13 deletions(-) create mode 100644 cpp-ch/local-engine/AggregateFunctions/AggregateFunctionSparkAvg.cpp create mode 100644 cpp-ch/local-engine/Common/GlutenDecimalUtils.h diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDecimalSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDecimalSuite.scala index 088487101081..7320b7c05152 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDecimalSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDecimalSuite.scala @@ -67,9 +67,9 @@ class GlutenClickHouseDecimalSuite private val decimalTPCHTables: Seq[(DecimalType, Seq[Int])] = Seq.apply( (DecimalType.apply(9, 4), Seq()), // 1: ch decimal avg is float - (DecimalType.apply(18, 8), Seq(1)), + (DecimalType.apply(18, 8), Seq()), // 1: ch decimal avg is float, 3/10: all value is null and compare with limit - (DecimalType.apply(38, 19), Seq(1, 3, 10)) + (DecimalType.apply(38, 19), Seq(3, 10)) ) private def createDecimalTables(dataType: DecimalType): Unit = { @@ -337,7 +337,6 @@ class GlutenClickHouseDecimalSuite allowPrecisionLoss => Range .inclusive(1, 22) - .filter(_ != 17) // Ignore Q17 which include avg .foreach { sql_num => { diff --git a/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionSparkAvg.cpp b/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionSparkAvg.cpp new file mode 100644 index 000000000000..5eb3a0b36057 --- /dev/null +++ b/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionSparkAvg.cpp @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include + +#include + +#include +#include + +namespace DB +{ +struct Settings; + +namespace ErrorCodes +{ + +} +} + +namespace local_engine +{ +using namespace DB; + + +DataTypePtr getSparkAvgReturnType(const DataTypePtr & arg_type) +{ + const UInt32 precision_value = std::min(getDecimalPrecision(*arg_type) + 4, DecimalUtils::max_precision); + const auto scale_value = std::min(getDecimalScale(*arg_type) + 4, precision_value); + return createDecimal(precision_value, scale_value); +} + +template +requires is_decimal +class AggregateFunctionSparkAvg final : public AggregateFunctionAvg +{ +public: + using Base = AggregateFunctionAvg; + + explicit AggregateFunctionSparkAvg(const DataTypes & argument_types_, UInt32 num_scale_, UInt32 round_scale_) + : Base(argument_types_, createResultType(argument_types_, num_scale_, round_scale_), num_scale_) + , num_scale(num_scale_) + , round_scale(round_scale_) + { + } + + DataTypePtr createResultType(const DataTypes & argument_types_, UInt32 num_scale_, UInt32 round_scale_) + { + const DataTypePtr & data_type = argument_types_[0]; + const UInt32 precision_value = std::min(getDecimalPrecision(*data_type) + 4, DecimalUtils::max_precision); + const auto scale_value = std::min(num_scale_ + 4, precision_value); + return createDecimal(precision_value, scale_value); + } + + void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override + { + const DataTypePtr & result_type = this->getResultType(); + auto result_scale = getDecimalScale(*result_type); + WhichDataType which(result_type); + if (which.isDecimal32()) + { + assert_cast &>(to).getData().push_back( + divideDecimalAndUInt(this->data(place), num_scale, result_scale, round_scale)); + } + else if (which.isDecimal64()) + { + assert_cast &>(to).getData().push_back( + divideDecimalAndUInt(this->data(place), num_scale, result_scale, round_scale)); + } + else if (which.isDecimal128()) + { + assert_cast &>(to).getData().push_back( + divideDecimalAndUInt(this->data(place), num_scale, result_scale, round_scale)); + } + else + { + assert_cast &>(to).getData().push_back( + divideDecimalAndUInt(this->data(place), num_scale, result_scale, round_scale)); + } + } + + String getName() const override { return "sparkAvg"; } + +private: + Int128 NO_SANITIZE_UNDEFINED + divideDecimalAndUInt(AvgFraction, UInt64> avg, UInt32 num_scale, UInt32 result_scale, UInt32 round_scale) const + { + auto value = avg.numerator.value; + if (result_scale > num_scale) + { + auto diff = DecimalUtils::scaleMultiplier>(result_scale - num_scale); + value = value * diff; + } + else if (result_scale < num_scale) + { + auto diff = DecimalUtils::scaleMultiplier>(num_scale - result_scale); + value = value / diff; + } + + auto result = value / avg.denominator; + + if (round_scale > result_scale) + return result; + + auto round_diff = DecimalUtils::scaleMultiplier>(result_scale - round_scale); + return (result + round_diff / 2) / round_diff * round_diff; + } + +private: + UInt32 num_scale; + UInt32 round_scale; +}; + +AggregateFunctionPtr 
+createAggregateFunctionSparkAvg(const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings * settings) +{ + assertNoParameters(name, parameters); + assertUnary(name, argument_types); + + AggregateFunctionPtr res; + const DataTypePtr & data_type = argument_types[0]; + if (!isDecimal(data_type)) + throw Exception( + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument for aggregate function {}", data_type->getName(), name); + + bool allowPrecisionLoss = settings->get(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS).get(); + const UInt32 p1 = DB::getDecimalPrecision(*data_type); + const UInt32 s1 = DB::getDecimalScale(*data_type); + auto [p2, s2] = GlutenDecimalUtils::LONG_DECIMAL; + auto [_, round_scale] = GlutenDecimalUtils::dividePrecisionScale(p1, s1, p2, s2, allowPrecisionLoss); + + res.reset(createWithDecimalType(*data_type, argument_types, getDecimalScale(*data_type), round_scale)); + return res; +} + +void registerAggregateFunctionSparkAvg(AggregateFunctionFactory & factory) +{ + factory.registerFunction("sparkAvg", createAggregateFunctionSparkAvg); +} + +} diff --git a/cpp-ch/local-engine/Common/CHUtil.cpp b/cpp-ch/local-engine/Common/CHUtil.cpp index ae3f6dbd5208..588cc1cb2599 100644 --- a/cpp-ch/local-engine/Common/CHUtil.cpp +++ b/cpp-ch/local-engine/Common/CHUtil.cpp @@ -626,6 +626,7 @@ void BackendInitializerUtil::initSettings(std::map & b settings.set("date_time_input_format", "best_effort"); settings.set(MERGETREE_MERGE_AFTER_INSERT, true); settings.set(MERGETREE_INSERT_WITHOUT_LOCAL_STORAGE, false); + settings.set(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS, true); for (const auto & [key, value] : backend_conf_map) { @@ -665,6 +666,11 @@ void BackendInitializerUtil::initSettings(std::map & b settings.set("session_timezone", time_zone_val); LOG_DEBUG(&Poco::Logger::get("CHUtil"), "Set settings key:{} value:{}", "session_timezone", time_zone_val); } + else if (key == DECIMAL_OPERATIONS_ALLOW_PREC_LOSS) + { + settings.set(key, toField(key, value)); + LOG_DEBUG(&Poco::Logger::get("CHUtil"), "Set settings key:{} value:{}", key, value); + } } /// Finally apply some fixed kvs to settings. 
@@ -788,6 +794,7 @@ void BackendInitializerUtil::updateNewSettings(const DB::ContextMutablePtr & con extern void registerAggregateFunctionCombinatorPartialMerge(AggregateFunctionCombinatorFactory &); extern void registerAggregateFunctionsBloomFilter(AggregateFunctionFactory &); +extern void registerAggregateFunctionSparkAvg(AggregateFunctionFactory &); extern void registerFunctions(FunctionFactory &); void registerAllFunctions() @@ -797,7 +804,7 @@ void registerAllFunctions() DB::registerAggregateFunctions(); auto & agg_factory = AggregateFunctionFactory::instance(); registerAggregateFunctionsBloomFilter(agg_factory); - + registerAggregateFunctionSparkAvg(agg_factory); { /// register aggregate function combinators from local_engine auto & factory = AggregateFunctionCombinatorFactory::instance(); diff --git a/cpp-ch/local-engine/Common/CHUtil.h b/cpp-ch/local-engine/Common/CHUtil.h index 245d7b3d15c4..0321d410a7d5 100644 --- a/cpp-ch/local-engine/Common/CHUtil.h +++ b/cpp-ch/local-engine/Common/CHUtil.h @@ -37,7 +37,10 @@ namespace local_engine { static const String MERGETREE_INSERT_WITHOUT_LOCAL_STORAGE = "mergetree.insert_without_local_storage"; static const String MERGETREE_MERGE_AFTER_INSERT = "mergetree.merge_after_insert"; -static const std::unordered_set BOOL_VALUE_SETTINGS{MERGETREE_MERGE_AFTER_INSERT, MERGETREE_INSERT_WITHOUT_LOCAL_STORAGE}; +static const std::string DECIMAL_OPERATIONS_ALLOW_PREC_LOSS = "spark.sql.decimalOperations.allowPrecisionLoss"; + +static const std::unordered_set BOOL_VALUE_SETTINGS{ + MERGETREE_MERGE_AFTER_INSERT, MERGETREE_INSERT_WITHOUT_LOCAL_STORAGE, DECIMAL_OPERATIONS_ALLOW_PREC_LOSS}; static const std::unordered_set LONG_VALUE_SETTINGS{ "optimize.maxfilesize", "optimize.minFileSize", "mergetree.max_num_part_per_merge_task"}; diff --git a/cpp-ch/local-engine/Common/GlutenDecimalUtils.h b/cpp-ch/local-engine/Common/GlutenDecimalUtils.h new file mode 100644 index 000000000000..32af66ec590e --- /dev/null +++ b/cpp-ch/local-engine/Common/GlutenDecimalUtils.h @@ -0,0 +1,108 @@ +/* +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + + +namespace local_engine +{ + +class GlutenDecimalUtils +{ +public: + static constexpr size_t MAX_PRECISION = 38; + static constexpr size_t MAX_SCALE = 38; + static constexpr auto system_Default = std::tuple(MAX_PRECISION, 18); + static constexpr auto user_Default = std::tuple(10, 0); + static constexpr size_t MINIMUM_ADJUSTED_SCALE = 6; + + // The decimal types compatible with other numeric types + static constexpr auto BOOLEAN_DECIMAL = std::tuple(1, 0); + static constexpr auto BYTE_DECIMAL = std::tuple(3, 0); + static constexpr auto SHORT_DECIMAL = std::tuple(5, 0); + static constexpr auto INT_DECIMAL = std::tuple(10, 0); + static constexpr auto LONG_DECIMAL = std::tuple(20, 0); + static constexpr auto FLOAT_DECIMAL = std::tuple(14, 7); + static constexpr auto DOUBLE_DECIMAL = std::tuple(30, 15); + static constexpr auto BIGINT_DECIMAL = std::tuple(MAX_PRECISION, 0); + + static std::tuple adjustPrecisionScale(size_t precision, size_t scale) + { + if (precision <= MAX_PRECISION) + { + // Adjustment only needed when we exceed max precision + return std::tuple(precision, scale); + } + else if (scale < 0) + { + // Decimal can have negative scale (SPARK-24468). In this case, we cannot allow a precision + // loss since we would cause a loss of digits in the integer part. + // In this case, we are likely to meet an overflow. + return std::tuple(GlutenDecimalUtils::MAX_PRECISION, scale); + } + else + { + // Precision/scale exceed maximum precision. Result must be adjusted to MAX_PRECISION. + auto intDigits = precision - scale; + // If original scale is less than MINIMUM_ADJUSTED_SCALE, use original scale value; otherwise + // preserve at least MINIMUM_ADJUSTED_SCALE fractional digits + auto minScaleValue = std::min(scale, GlutenDecimalUtils::MINIMUM_ADJUSTED_SCALE); + // The resulting scale is the maximum between what is available without causing a loss of + // digits for the integer part of the decimal and the minimum guaranteed scale, which is + // computed above + auto adjustedScale = std::max(GlutenDecimalUtils::MAX_PRECISION - intDigits, minScaleValue); + + return std::tuple(GlutenDecimalUtils::MAX_PRECISION, adjustedScale); + } + } + + static std::tuple dividePrecisionScale(size_t p1, size_t s1, size_t p2, size_t s2, bool allowPrecisionLoss) + { + if (allowPrecisionLoss) + { + // Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1) + // Scale: max(6, s1 + p2 + 1) + const size_t intDig = p1 - s1 + s2; + const size_t scale = std::max(MINIMUM_ADJUSTED_SCALE, s1 + p2 + 1); + const size_t precision = intDig + scale; + return adjustPrecisionScale(precision, scale); + } + else + { + auto intDig = std::min(MAX_SCALE, p1 - s1 + s2); + auto decDig = std::min(MAX_SCALE, std::max(static_cast(6), s1 + p2 + 1)); + auto diff = (intDig + decDig) - MAX_SCALE; + if (diff > 0) + { + decDig -= diff / 2 + 1; + intDig = MAX_SCALE - decDig; + } + return std::tuple(intDig + decDig, decDig); + } + } + + static std::tuple widerDecimalType(const size_t p1, const size_t s1, const size_t p2, const size_t s2) + { + // max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2) + auto scale = std::max(s1, s2); + auto range = std::max(p1 - s1, p2 - s2); + return std::tuple(range + scale, scale); + } + +}; + +} diff --git a/cpp-ch/local-engine/Parser/RelParser.cpp b/cpp-ch/local-engine/Parser/RelParser.cpp index 7fc807827109..282339c4d641 100644 --- a/cpp-ch/local-engine/Parser/RelParser.cpp +++ b/cpp-ch/local-engine/Parser/RelParser.cpp @@ -15,12 +15,16 @@ * limitations under the License. 
*/ #include "RelParser.h" + #include +#include + #include +#include #include -#include -#include #include +#include + namespace DB { @@ -38,7 +42,20 @@ AggregateFunctionPtr RelParser::getAggregateFunction( { auto & factory = AggregateFunctionFactory::instance(); auto action = NullsAction::EMPTY; - return factory.get(name, action, arg_types, parameters, properties); + + String function_name = name; + if (name == "avg" && isDecimal(removeNullable(arg_types[0]))) + function_name = "sparkAvg"; + else if (name == "avgPartialMerge") + { + if (auto agg_func = typeid_cast(arg_types[0].get()); + !agg_func->getArgumentsDataTypes().empty() && isDecimal(removeNullable(agg_func->getArgumentsDataTypes()[0]))) + { + function_name = "sparkAvgPartialMerge"; + } + } + + return factory.get(function_name, action, arg_types, parameters, properties); } std::optional RelParser::parseSignatureFunctionName(UInt32 function_ref) diff --git a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala index 148e8cdc067c..4b4e29e7d0fb 100644 --- a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala +++ b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala @@ -718,7 +718,9 @@ object GlutenConfig { GLUTEN_OFFHEAP_SIZE_IN_BYTES_KEY, GLUTEN_TASK_OFFHEAP_SIZE_IN_BYTES_KEY, - GLUTEN_OFFHEAP_ENABLED + GLUTEN_OFFHEAP_ENABLED, + SESSION_LOCAL_TIMEZONE.key, + DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key ) nativeConfMap.putAll(conf.filter(e => keys.contains(e._1)).asJava) @@ -735,10 +737,6 @@ object GlutenConfig { .filter(_._1.startsWith(SPARK_ABFS_ACCOUNT_KEY)) .foreach(entry => nativeConfMap.put(entry._1, entry._2)) - conf - .filter(_._1.startsWith(SQLConf.SESSION_LOCAL_TIMEZONE.key)) - .foreach(entry => nativeConfMap.put(entry._1, entry._2)) - // return nativeConfMap } From cf04f0fe3e338169492a28a35cb3562d5d29cdaa Mon Sep 17 00:00:00 2001 From: Yan Ma Date: Tue, 25 Jun 2024 10:34:50 +0800 Subject: [PATCH 07/30] [GLUTEN-5659][VL] Add more configs for AWS s3 (#5660) Add more configs for AWS s3 spark.gluten.velox.fs.s3a.retry.mode spark.gluten.velox.fs.s3a.connect.timeout spark.hadoop.fs.s3a.retry.limit spark.hadoop.fs.s3a.connection.maximum --- cpp/velox/utils/ConfigExtractor.cc | 23 ++++++++++++++ docs/Configuration.md | 2 ++ .../org/apache/gluten/GlutenConfig.scala | 30 +++++++++++++++++++ 3 files changed, 55 insertions(+) diff --git a/cpp/velox/utils/ConfigExtractor.cc b/cpp/velox/utils/ConfigExtractor.cc index a71f143225b9..816166351c0e 100644 --- a/cpp/velox/utils/ConfigExtractor.cc +++ b/cpp/velox/utils/ConfigExtractor.cc @@ -34,6 +34,13 @@ const bool kVeloxFileHandleCacheEnabledDefault = false; // Log granularity of AWS C++ SDK const std::string kVeloxAwsSdkLogLevel = "spark.gluten.velox.awsSdkLogLevel"; const std::string kVeloxAwsSdkLogLevelDefault = "FATAL"; +// Retry mode for AWS s3 +const std::string kVeloxS3RetryMode = "spark.gluten.velox.fs.s3a.retry.mode"; +const std::string kVeloxS3RetryModeDefault = "legacy"; +// Connection timeout for AWS s3 +const std::string kVeloxS3ConnectTimeout = "spark.gluten.velox.fs.s3a.connect.timeout"; +// Using default fs.s3a.connection.timeout value in hadoop +const std::string kVeloxS3ConnectTimeoutDefault = "200s"; } // namespace namespace gluten { @@ -64,6 +71,10 @@ std::shared_ptr getHiveConfig(std::shared_ptr< bool useInstanceCredentials = conf->get("spark.hadoop.fs.s3a.use.instance.credentials", false); std::string iamRole = conf->get("spark.hadoop.fs.s3a.iam.role", ""); std::string 
iamRoleSessionName = conf->get("spark.hadoop.fs.s3a.iam.role.session.name", ""); + std::string retryMaxAttempts = conf->get("spark.hadoop.fs.s3a.retry.limit", "20"); + std::string retryMode = conf->get(kVeloxS3RetryMode, kVeloxS3RetryModeDefault); + std::string maxConnections = conf->get("spark.hadoop.fs.s3a.connection.maximum", "15"); + std::string connectTimeout = conf->get(kVeloxS3ConnectTimeout, kVeloxS3ConnectTimeoutDefault); std::string awsSdkLogLevel = conf->get(kVeloxAwsSdkLogLevel, kVeloxAwsSdkLogLevelDefault); @@ -79,6 +90,14 @@ std::shared_ptr getHiveConfig(std::shared_ptr< if (envAwsEndpoint != nullptr) { awsEndpoint = std::string(envAwsEndpoint); } + const char* envRetryMaxAttempts = std::getenv("AWS_MAX_ATTEMPTS"); + if (envRetryMaxAttempts != nullptr) { + retryMaxAttempts = std::string(envRetryMaxAttempts); + } + const char* envRetryMode = std::getenv("AWS_RETRY_MODE"); + if (envRetryMode != nullptr) { + retryMode = std::string(envRetryMode); + } if (useInstanceCredentials) { hiveConfMap[facebook::velox::connector::hive::HiveConfig::kS3UseInstanceCredentials] = "true"; @@ -98,6 +117,10 @@ std::shared_ptr getHiveConfig(std::shared_ptr< hiveConfMap[facebook::velox::connector::hive::HiveConfig::kS3SSLEnabled] = sslEnabled ? "true" : "false"; hiveConfMap[facebook::velox::connector::hive::HiveConfig::kS3PathStyleAccess] = pathStyleAccess ? "true" : "false"; hiveConfMap[facebook::velox::connector::hive::HiveConfig::kS3LogLevel] = awsSdkLogLevel; + hiveConfMap[facebook::velox::connector::hive::HiveConfig::kS3MaxAttempts] = retryMaxAttempts; + hiveConfMap[facebook::velox::connector::hive::HiveConfig::kS3RetryMode] = retryMode; + hiveConfMap[facebook::velox::connector::hive::HiveConfig::kS3MaxConnections] = maxConnections; + hiveConfMap[facebook::velox::connector::hive::HiveConfig::kS3ConnectTimeout] = connectTimeout; #endif #ifdef ENABLE_GCS diff --git a/docs/Configuration.md b/docs/Configuration.md index 089675286f68..2c2bd4de11f2 100644 --- a/docs/Configuration.md +++ b/docs/Configuration.md @@ -89,6 +89,8 @@ The following configurations are related to Velox settings. | spark.gluten.sql.columnar.backend.velox.maxCoalescedBytes | Set the max coalesced bytes for velox file scan. | | | spark.gluten.sql.columnar.backend.velox.cachePrefetchMinPct | Set prefetch cache min pct for velox file scan. | | | spark.gluten.velox.awsSdkLogLevel | Log granularity of AWS C++ SDK in velox. | FATAL | +| spark.gluten.velox.fs.s3a.retry.mode | Retry mode for AWS s3 connection error, can be "legacy", "standard" and "adaptive". | legacy | +| spark.gluten.velox.fs.s3a.connect.timeout | Timeout for AWS s3 connection. | 1s | | spark.gluten.sql.columnar.backend.velox.orc.scan.enabled | Enable velox orc scan. If disabled, vanilla spark orc scan will be used. | true | | spark.gluten.sql.complexType.scan.fallback.enabled | Force fallback for complex type scan, including struct, map, array. 
| true | diff --git a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala index 4b4e29e7d0fb..cc2d6ac5fdef 100644 --- a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala +++ b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala @@ -436,6 +436,10 @@ class GlutenConfig(conf: SQLConf) extends Logging { def awsSdkLogLevel: String = conf.getConf(AWS_SDK_LOG_LEVEL) + def awsS3RetryMode: String = conf.getConf(AWS_S3_RETRY_MODE) + + def awsConnectionTimeout: String = conf.getConf(AWS_S3_CONNECT_TIMEOUT) + def enableCastAvgAggregateFunction: Boolean = conf.getConf(COLUMNAR_NATIVE_CAST_AGGREGATE_ENABLED) def enableGlutenCostEvaluator: Boolean = conf.getConf(COST_EVALUATOR_ENABLED) @@ -488,6 +492,10 @@ object GlutenConfig { val SPARK_S3_IAM: String = HADOOP_PREFIX + S3_IAM_ROLE val S3_IAM_ROLE_SESSION_NAME = "fs.s3a.iam.role.session.name" val SPARK_S3_IAM_SESSION_NAME: String = HADOOP_PREFIX + S3_IAM_ROLE_SESSION_NAME + val S3_RETRY_MAX_ATTEMPTS = "fs.s3a.retry.limit" + val SPARK_S3_RETRY_MAX_ATTEMPTS: String = HADOOP_PREFIX + S3_RETRY_MAX_ATTEMPTS + val S3_CONNECTION_MAXIMUM = "fs.s3a.connection.maximum" + val SPARK_S3_CONNECTION_MAXIMUM: String = HADOOP_PREFIX + S3_CONNECTION_MAXIMUM // Hardware acceleraters backend val GLUTEN_SHUFFLE_CODEC_BACKEND = "spark.gluten.sql.columnar.shuffle.codecBackend" @@ -642,6 +650,10 @@ object GlutenConfig { SPARK_S3_USE_INSTANCE_CREDENTIALS, SPARK_S3_IAM, SPARK_S3_IAM_SESSION_NAME, + SPARK_S3_RETRY_MAX_ATTEMPTS, + SPARK_S3_CONNECTION_MAXIMUM, + AWS_S3_CONNECT_TIMEOUT.key, + AWS_S3_RETRY_MODE.key, AWS_SDK_LOG_LEVEL.key, // gcs config SPARK_GCS_STORAGE_ROOT_URL, @@ -693,6 +705,10 @@ object GlutenConfig { (SPARK_S3_USE_INSTANCE_CREDENTIALS, "false"), (SPARK_S3_IAM, ""), (SPARK_S3_IAM_SESSION_NAME, ""), + (SPARK_S3_RETRY_MAX_ATTEMPTS, "20"), + (SPARK_S3_CONNECTION_MAXIMUM, "15"), + (AWS_S3_CONNECT_TIMEOUT.key, AWS_S3_CONNECT_TIMEOUT.defaultValueString), + (AWS_S3_RETRY_MODE.key, AWS_S3_RETRY_MODE.defaultValueString), ( COLUMNAR_VELOX_CONNECTOR_IO_THREADS.key, conf.getOrElse(GLUTEN_NUM_TASK_SLOTS_PER_EXECUTOR_KEY, "-1")), @@ -1941,6 +1957,20 @@ object GlutenConfig { .stringConf .createWithDefault("FATAL") + val AWS_S3_RETRY_MODE = + buildConf("spark.gluten.velox.fs.s3a.retry.mode") + .internal() + .doc("Retry mode for AWS s3 connection error: legacy, standard and adaptive.") + .stringConf + .createWithDefault("legacy") + + val AWS_S3_CONNECT_TIMEOUT = + buildConf("spark.gluten.velox.fs.s3a.connect.timeout") + .internal() + .doc("Timeout for AWS s3 connection.") + .stringConf + .createWithDefault("200s") + val VELOX_ORC_SCAN_ENABLED = buildStaticConf("spark.gluten.sql.columnar.backend.velox.orc.scan.enabled") .internal() From de26ed2dad41d2d1e893c8d1b3ae806385d9972f Mon Sep 17 00:00:00 2001 From: LiuNeng <1398775315@qq.com> Date: Tue, 25 Jun 2024 16:10:05 +0800 Subject: [PATCH 08/30] [CH] Support flatten (#6194) [CH] Support flatten Co-authored-by: liuneng1994 --- .../gluten/utils/CHExpressionUtil.scala | 1 - cpp-ch/clickhouse.version | 3 +- .../Functions/SparkArrayFlatten.cpp | 160 ++++++++++++++++++ .../Parser/SerializedPlanParser.h | 1 + .../clickhouse/ClickHouseTestSettings.scala | 2 +- .../sql/GlutenDataFrameFunctionsSuite.scala | 82 +++++++++ .../clickhouse/ClickHouseTestSettings.scala | 2 +- .../sql/GlutenDataFrameFunctionsSuite.scala | 82 +++++++++ 8 files changed, 329 insertions(+), 4 deletions(-) create mode 100644 
cpp-ch/local-engine/Functions/SparkArrayFlatten.cpp diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala index cf45c1118f13..e9bee84396f8 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala @@ -209,7 +209,6 @@ object CHExpressionUtil { UNIX_MICROS -> DefaultValidator(), TIMESTAMP_MILLIS -> DefaultValidator(), TIMESTAMP_MICROS -> DefaultValidator(), - FLATTEN -> DefaultValidator(), STACK -> DefaultValidator() ) } diff --git a/cpp-ch/clickhouse.version b/cpp-ch/clickhouse.version index 4a3088e54309..54d0a74c5bb4 100644 --- a/cpp-ch/clickhouse.version +++ b/cpp-ch/clickhouse.version @@ -1,3 +1,4 @@ CH_ORG=Kyligence CH_BRANCH=rebase_ch/20240621 -CH_COMMIT=acf666c1c4f +CH_COMMIT=c811cbb985f + diff --git a/cpp-ch/local-engine/Functions/SparkArrayFlatten.cpp b/cpp-ch/local-engine/Functions/SparkArrayFlatten.cpp new file mode 100644 index 000000000000..d39bca5ea104 --- /dev/null +++ b/cpp-ch/local-engine/Functions/SparkArrayFlatten.cpp @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include +#include + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int ILLEGAL_TYPE_OF_ARGUMENT; + extern const int ILLEGAL_COLUMN; +} + +/// arrayFlatten([[1, 2, 3], [4, 5]]) = [1, 2, 3, 4, 5] - flatten array. +class SparkArrayFlatten : public IFunction +{ +public: + static constexpr auto name = "sparkArrayFlatten"; + + static FunctionPtr create(ContextPtr) { return std::make_shared(); } + + size_t getNumberOfArguments() const override { return 1; } + bool useDefaultImplementationForConstants() const override { return true; } + bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; } + + DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override + { + if (!isArray(arguments[0])) + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument of function {}, expected Array", + arguments[0]->getName(), getName()); + + DataTypePtr nested_type = arguments[0]; + nested_type = checkAndGetDataType(removeNullable(nested_type).get())->getNestedType(); + return nested_type; + } + + ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override + { + /** We create an array column with array elements as the most deep elements of nested arrays, + * and construct offsets by selecting elements of most deep offsets by values of ancestor offsets. 
+ * +Example 1: + +Source column: Array(Array(UInt8)): +Row 1: [[1, 2, 3], [4, 5]], Row 2: [[6], [7, 8]] +data: [1, 2, 3], [4, 5], [6], [7, 8] +offsets: 2, 4 +data.data: 1 2 3 4 5 6 7 8 +data.offsets: 3 5 6 8 + +Result column: Array(UInt8): +Row 1: [1, 2, 3, 4, 5], Row 2: [6, 7, 8] +data: 1 2 3 4 5 6 7 8 +offsets: 5 8 + +Result offsets are selected from the most deep (data.offsets) by previous deep (offsets) (and values are decremented by one): +3 5 6 8 + ^ ^ + +Example 2: + +Source column: Array(Array(Array(UInt8))): +Row 1: [[], [[1], [], [2, 3]]], Row 2: [[[4]]] + +most deep data: 1 2 3 4 + +offsets1: 2 3 +offsets2: 0 3 4 +- ^ ^ - select by prev offsets +offsets3: 1 1 3 4 +- ^ ^ - select by prev offsets + +result offsets: 3, 4 +result: Row 1: [1, 2, 3], Row2: [4] + */ + + const ColumnArray * src_col = checkAndGetColumn(arguments[0].column.get()); + + if (!src_col) + throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column {} in argument of function 'arrayFlatten'", + arguments[0].column->getName()); + + const IColumn::Offsets & src_offsets = src_col->getOffsets(); + + ColumnArray::ColumnOffsets::MutablePtr result_offsets_column; + const IColumn::Offsets * prev_offsets = &src_offsets; + const IColumn * prev_data = &src_col->getData(); + bool nullable = prev_data->isNullable(); + // when array has null element, return null + if (nullable) + { + const ColumnNullable * nullable_column = checkAndGetColumn(prev_data); + prev_data = nullable_column->getNestedColumnPtr().get(); + for (size_t i = 0; i < nullable_column->size(); i++) + { + if (nullable_column->isNullAt(i)) + { + auto res= nullable_column->cloneEmpty(); + res->insertManyDefaults(input_rows_count); + return res; + } + } + } + if (isNothing(prev_data->getDataType())) + return prev_data->cloneResized(input_rows_count); + // only flatten one dimension + if (const ColumnArray * next_col = checkAndGetColumn(prev_data)) + { + result_offsets_column = ColumnArray::ColumnOffsets::create(input_rows_count); + + IColumn::Offsets & result_offsets = result_offsets_column->getData(); + + const IColumn::Offsets * next_offsets = &next_col->getOffsets(); + + for (size_t i = 0; i < input_rows_count; ++i) + result_offsets[i] = (*next_offsets)[(*prev_offsets)[i] - 1]; /// -1 array subscript is Ok, see PaddedPODArray + prev_data = &next_col->getData(); + } + + auto res = ColumnArray::create( + prev_data->getPtr(), + result_offsets_column ? 
std::move(result_offsets_column) : src_col->getOffsetsPtr()); + if (nullable) + return makeNullable(res); + return res; + } + +private: + String getName() const override + { + return name; + } +}; + +REGISTER_FUNCTION(SparkArrayFlatten) +{ + factory.registerFunction(); +} + +} diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.h b/cpp-ch/local-engine/Parser/SerializedPlanParser.h index 82e8c4077841..aa18197e5647 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.h +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.h @@ -180,6 +180,7 @@ static const std::map SCALAR_FUNCTIONS {"array", "array"}, {"shuffle", "arrayShuffle"}, {"range", "range"}, /// dummy mapping + {"flatten", "sparkArrayFlatten"}, // map functions {"map", "map"}, diff --git a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index 8572ef54d5c8..1626716805cb 100644 --- a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -172,6 +172,7 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("shuffle function - array for primitive type not containing null") .exclude("shuffle function - array for primitive type containing null") .exclude("shuffle function - array for non-primitive type") + .exclude("flatten function") enableSuite[GlutenDataFrameHintSuite] enableSuite[GlutenDataFrameImplicitsSuite] enableSuite[GlutenDataFrameJoinSuite].exclude( @@ -674,7 +675,6 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("Sequence with default step") .exclude("Reverse") .exclude("elementAt") - .exclude("Flatten") .exclude("ArrayRepeat") .exclude("Array remove") .exclude("Array Distinct") diff --git a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala index 2b0b40790a76..e64f760ab55f 100644 --- a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala +++ b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala @@ -49,4 +49,86 @@ class GlutenDataFrameFunctionsSuite extends DataFrameFunctionsSuite with GlutenS false ) } + + testGluten("flatten function") { + // Test cases with a primitive type + val intDF = Seq( + (Seq(Seq(1, 2, 3), Seq(4, 5), Seq(6))), + (Seq(Seq(1, 2))), + (Seq(Seq(1), Seq.empty)), + (Seq(Seq.empty, Seq(1))) + ).toDF("i") + + val intDFResult = Seq(Row(Seq(1, 2, 3, 4, 5, 6)), Row(Seq(1, 2)), Row(Seq(1)), Row(Seq(1))) + + def testInt(): Unit = { + checkAnswer(intDF.select(flatten($"i")), intDFResult) + checkAnswer(intDF.selectExpr("flatten(i)"), intDFResult) + } + + // Test with local relation, the Project will be evaluated without codegen + testInt() + // Test with cached relation, the Project will be evaluated with codegen + intDF.cache() + testInt() + + // Test cases with non-primitive types + val strDF = Seq( + (Seq(Seq("a", "b"), Seq("c"), Seq("d", "e", "f"))), + (Seq(Seq("a", "b"))), + (Seq(Seq("a", null), Seq(null, "b"), Seq(null, null))), + (Seq(Seq("a"), Seq.empty)), + (Seq(Seq.empty, Seq("a"))) + ).toDF("s") + + val strDFResult = Seq( + Row(Seq("a", "b", "c", "d", "e", "f")), + Row(Seq("a", "b")), + Row(Seq("a", null, null, "b", null, null)), + Row(Seq("a")), + 
Row(Seq("a"))) + + def testString(): Unit = { + checkAnswer(strDF.select(flatten($"s")), strDFResult) + checkAnswer(strDF.selectExpr("flatten(s)"), strDFResult) + } + + // Test with local relation, the Project will be evaluated without codegen + testString() + // Test with cached relation, the Project will be evaluated with codegen + strDF.cache() + testString() + + val arrDF = Seq((1, "a", Seq(1, 2, 3))).toDF("i", "s", "arr") + + def testArray(): Unit = { + checkAnswer( + arrDF.selectExpr("flatten(array(arr, array(null, 5), array(6, null)))"), + Seq(Row(Seq(1, 2, 3, null, 5, 6, null)))) + checkAnswer( + arrDF.selectExpr("flatten(array(array(arr, arr), array(arr)))"), + Seq(Row(Seq(Seq(1, 2, 3), Seq(1, 2, 3), Seq(1, 2, 3))))) + } + + // Test with local relation, the Project will be evaluated without codegen + testArray() + // Test with cached relation, the Project will be evaluated with codegen + arrDF.cache() + testArray() + + // Error test cases + val oneRowDF = Seq((1, "a", Seq(1, 2, 3))).toDF("i", "s", "arr") + intercept[AnalysisException] { + oneRowDF.select(flatten($"arr")) + } + intercept[AnalysisException] { + oneRowDF.select(flatten($"i")) + } + intercept[AnalysisException] { + oneRowDF.select(flatten($"s")) + } + intercept[AnalysisException] { + oneRowDF.selectExpr("flatten(null)") + } + } } diff --git a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index 50e7929e4619..3147c7c3dbf3 100644 --- a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -190,6 +190,7 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("shuffle function - array for primitive type not containing null") .exclude("shuffle function - array for primitive type containing null") .exclude("shuffle function - array for non-primitive type") + .exclude("flatten function") enableSuite[GlutenDataFrameHintSuite] enableSuite[GlutenDataFrameImplicitsSuite] enableSuite[GlutenDataFrameJoinSuite].exclude( @@ -714,7 +715,6 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("Sequence with default step") .exclude("Reverse") .exclude("elementAt") - .exclude("Flatten") .exclude("ArrayRepeat") .exclude("Array remove") .exclude("Array Distinct") diff --git a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala index 2b0b40790a76..e64f760ab55f 100644 --- a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala +++ b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala @@ -49,4 +49,86 @@ class GlutenDataFrameFunctionsSuite extends DataFrameFunctionsSuite with GlutenS false ) } + + testGluten("flatten function") { + // Test cases with a primitive type + val intDF = Seq( + (Seq(Seq(1, 2, 3), Seq(4, 5), Seq(6))), + (Seq(Seq(1, 2))), + (Seq(Seq(1), Seq.empty)), + (Seq(Seq.empty, Seq(1))) + ).toDF("i") + + val intDFResult = Seq(Row(Seq(1, 2, 3, 4, 5, 6)), Row(Seq(1, 2)), Row(Seq(1)), Row(Seq(1))) + + def testInt(): Unit = { + checkAnswer(intDF.select(flatten($"i")), intDFResult) + checkAnswer(intDF.selectExpr("flatten(i)"), intDFResult) + } + + // Test with local relation, the Project will be evaluated 
without codegen + testInt() + // Test with cached relation, the Project will be evaluated with codegen + intDF.cache() + testInt() + + // Test cases with non-primitive types + val strDF = Seq( + (Seq(Seq("a", "b"), Seq("c"), Seq("d", "e", "f"))), + (Seq(Seq("a", "b"))), + (Seq(Seq("a", null), Seq(null, "b"), Seq(null, null))), + (Seq(Seq("a"), Seq.empty)), + (Seq(Seq.empty, Seq("a"))) + ).toDF("s") + + val strDFResult = Seq( + Row(Seq("a", "b", "c", "d", "e", "f")), + Row(Seq("a", "b")), + Row(Seq("a", null, null, "b", null, null)), + Row(Seq("a")), + Row(Seq("a"))) + + def testString(): Unit = { + checkAnswer(strDF.select(flatten($"s")), strDFResult) + checkAnswer(strDF.selectExpr("flatten(s)"), strDFResult) + } + + // Test with local relation, the Project will be evaluated without codegen + testString() + // Test with cached relation, the Project will be evaluated with codegen + strDF.cache() + testString() + + val arrDF = Seq((1, "a", Seq(1, 2, 3))).toDF("i", "s", "arr") + + def testArray(): Unit = { + checkAnswer( + arrDF.selectExpr("flatten(array(arr, array(null, 5), array(6, null)))"), + Seq(Row(Seq(1, 2, 3, null, 5, 6, null)))) + checkAnswer( + arrDF.selectExpr("flatten(array(array(arr, arr), array(arr)))"), + Seq(Row(Seq(Seq(1, 2, 3), Seq(1, 2, 3), Seq(1, 2, 3))))) + } + + // Test with local relation, the Project will be evaluated without codegen + testArray() + // Test with cached relation, the Project will be evaluated with codegen + arrDF.cache() + testArray() + + // Error test cases + val oneRowDF = Seq((1, "a", Seq(1, 2, 3))).toDF("i", "s", "arr") + intercept[AnalysisException] { + oneRowDF.select(flatten($"arr")) + } + intercept[AnalysisException] { + oneRowDF.select(flatten($"i")) + } + intercept[AnalysisException] { + oneRowDF.select(flatten($"s")) + } + intercept[AnalysisException] { + oneRowDF.selectExpr("flatten(null)") + } + } } From 3e5b54e64032ab3e860c1636ab4760557a8a1e96 Mon Sep 17 00:00:00 2001 From: Zhen Li <10524738+zhli1142015@users.noreply.github.com> Date: Tue, 25 Jun 2024 19:15:55 +0800 Subject: [PATCH 09/30] [VL] Fix greatest and least function tests (#6209) [VL] Fix greatest and least function tests. 
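
The duplicated "greatest function" / "least function" cases are removed, and the remaining tests now also cover decimal inputs with different precision and scale that are round-tripped through Parquet before calling greatest()/least(). A rough sketch of the scenario (the path and view name below are placeholders, not literal test code; the real suite additionally asserts the project is offloaded via checkGlutenOperatorMatch[ProjectExecTransformer]):

```
// Decimal columns with mixed precision/scale are written to Parquet,
// read back as a temp view, and evaluated with greatest()/least().
val df = spark.sql(
  "SELECT CAST(5.345 AS DECIMAL(6, 2)) AS a, CAST(5.35 AS DECIMAL(5, 4)) AS b")
df.write.mode("overwrite").parquet("/tmp/decimal_greatest_least") // placeholder path
spark.read.parquet("/tmp/decimal_greatest_least").createOrReplaceTempView("decimal_view")
spark.sql("SELECT greatest(a, b), least(a, b) FROM decimal_view").show()
```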
--- .../ScalarFunctionsValidateSuite.scala | 44 ++++++++++++++----- 1 file changed, 32 insertions(+), 12 deletions(-) diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala index 75b60addfa13..a2baf95ecdc0 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala @@ -157,24 +157,28 @@ class ScalarFunctionsValidateSuite extends FunctionsValidateTest { checkLengthAndPlan(df, 1) } - test("greatest function") { - val df = runQueryAndCompare( - "SELECT greatest(l_orderkey, l_orderkey)" + - "from lineitem limit 1")(checkGlutenOperatorMatch[ProjectExecTransformer]) - } - - test("least function") { - val df = runQueryAndCompare( - "SELECT least(l_orderkey, l_orderkey)" + - "from lineitem limit 1")(checkGlutenOperatorMatch[ProjectExecTransformer]) - } - test("Test greatest function") { runQueryAndCompare( "SELECT greatest(l_orderkey, l_orderkey)" + "from lineitem limit 1") { checkGlutenOperatorMatch[ProjectExecTransformer] } + withTempPath { + path => + spark + .sql("""SELECT * + FROM VALUES (CAST(5.345 AS DECIMAL(6, 2)), CAST(5.35 AS DECIMAL(5, 4))), + (CAST(5.315 AS DECIMAL(6, 2)), CAST(5.355 AS DECIMAL(5, 4))), + (CAST(3.345 AS DECIMAL(6, 2)), CAST(4.35 AS DECIMAL(5, 4))) AS data(a, b);""") + .write + .parquet(path.getCanonicalPath) + + spark.read.parquet(path.getCanonicalPath).createOrReplaceTempView("view") + + runQueryAndCompare("SELECT greatest(a, b) from view") { + checkGlutenOperatorMatch[ProjectExecTransformer] + } + } } test("Test least function") { @@ -183,6 +187,22 @@ class ScalarFunctionsValidateSuite extends FunctionsValidateTest { "from lineitem limit 1") { checkGlutenOperatorMatch[ProjectExecTransformer] } + withTempPath { + path => + spark + .sql("""SELECT * + FROM VALUES (CAST(5.345 AS DECIMAL(6, 2)), CAST(5.35 AS DECIMAL(5, 4))), + (CAST(5.315 AS DECIMAL(6, 2)), CAST(5.355 AS DECIMAL(5, 4))), + (CAST(3.345 AS DECIMAL(6, 2)), CAST(4.35 AS DECIMAL(5, 4))) AS data(a, b);""") + .write + .parquet(path.getCanonicalPath) + + spark.read.parquet(path.getCanonicalPath).createOrReplaceTempView("view") + + runQueryAndCompare("SELECT least(a, b) from view") { + checkGlutenOperatorMatch[ProjectExecTransformer] + } + } } test("Test hash function") { From ad0fb0e718bdd7437360e717a4a3fb0ac8fbc6af Mon Sep 17 00:00:00 2001 From: Rong Ma Date: Tue, 25 Jun 2024 21:31:42 +0800 Subject: [PATCH 10/30] [VL] Fix udf segfault for static build (#6215) --- cpp/velox/symbols.map | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cpp/velox/symbols.map b/cpp/velox/symbols.map index ebd2b9af0096..525faf3526a1 100644 --- a/cpp/velox/symbols.map +++ b/cpp/velox/symbols.map @@ -6,6 +6,8 @@ }; Java_org_apache_gluten_*; + JNI_OnLoad; + JNI_OnUnload; local: # Hide symbols of static dependencies *; From 22475dacf5122aaedc5e85b38ec496dad7a2a7e2 Mon Sep 17 00:00:00 2001 From: Rong Ma Date: Tue, 25 Jun 2024 21:43:35 +0800 Subject: [PATCH 11/30] [VL] Daily Update Velox Version (2024_06_25) (#6204) Velox main changes: ``` 1225f773f by joey.ljy, Add session timezone to Parquet PageReader (#9781) 33cdf0a97 by Wei He, Add custom input generator for lead, lag, nth_value, and ntile in WindowFuzzerTest (#8360) 82a12e165 by Deepak Majeti, Remove setup-centos8.sh (#10249) 7be328cac by Zhenyuan Zhao, Make dwrf support taking custom column 
reader factory (#10267) 1f981ae8f by Orri Erling, Add more size classes (#10139) dc533655f by Masha Basmanova, Add from_unixtime(epoch, hours, minutes) Presto function (#10215) 7f547dbca by Wei He, Add custom result verifiers for min_by and max_by (#9070) 9974a3339 by Wei He, Allow logging input vectors in aggregation fuzzer (#10229) ``` --- .github/workflows/velox_docker.yml | 10 ++++++++++ ep/build-velox/src/get_velox.sh | 2 +- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/.github/workflows/velox_docker.yml b/.github/workflows/velox_docker.yml index 5f64c9f7e0e8..31796c15bdd5 100644 --- a/.github/workflows/velox_docker.yml +++ b/.github/workflows/velox_docker.yml @@ -120,6 +120,12 @@ jobs: with: name: velox-arrow-jar-centos-7-${{github.sha}} path: /root/.m2/repository/org/apache/arrow/ + - name: Setup tzdata + run: | + if [ "${{ matrix.os }}" = "ubuntu:22.04" ]; then + apt-get update + TZ="Etc/GMT" DEBIAN_FRONTEND=noninteractive apt-get install -y tzdata + fi - name: Setup java and maven run: | if [ "${{ matrix.java }}" = "java-17" ]; then @@ -530,6 +536,10 @@ jobs: with: name: velox-arrow-jar-centos-7-${{github.sha}} path: /root/.m2/repository/org/apache/arrow/ + - name: Setup tzdata + run: | + apt-get update + TZ="Etc/GMT" DEBIAN_FRONTEND=noninteractive apt-get install -y tzdata - name: Setup java and maven run: | apt-get update && apt-get install -y openjdk-8-jdk maven wget diff --git a/ep/build-velox/src/get_velox.sh b/ep/build-velox/src/get_velox.sh index d3ecddbdfa9a..06998787d45e 100755 --- a/ep/build-velox/src/get_velox.sh +++ b/ep/build-velox/src/get_velox.sh @@ -17,7 +17,7 @@ set -exu VELOX_REPO=https://github.com/oap-project/velox.git -VELOX_BRANCH=2024_06_24 +VELOX_BRANCH=2024_06_25 VELOX_HOME="" #Set on run gluten on HDFS From 524434826b42fd7e5cfda9b6e023efa656e2c6ae Mon Sep 17 00:00:00 2001 From: James Xu Date: Tue, 25 Jun 2024 22:25:04 +0800 Subject: [PATCH 12/30] [GLUTEN-6219] Fix some code style issue for BasicScanExecTransformer.scala (#6220) Co-authored-by: James Xu --- .../execution/BasicScanExecTransformer.scala | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/BasicScanExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/BasicScanExecTransformer.scala index 3bbd99c50a6a..9d231bbc2891 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/BasicScanExecTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/BasicScanExecTransformer.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.hive.HiveTableScanExecTransformer import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType} import com.google.protobuf.StringValue +import io.substrait.proto.NamedStruct import scala.collection.JavaConverters._ @@ -109,19 +110,19 @@ trait BasicScanExecTransformer extends LeafTransformSupport with BaseDataSource } override protected def doTransform(context: SubstraitContext): TransformContext = { - val output = filteRedundantField(outputAttributes()) + val output = filterRedundantField(outputAttributes()) val typeNodes = ConverterUtils.collectAttributeTypeNodes(output) val nameList = ConverterUtils.collectAttributeNamesWithoutExprId(output) val columnTypeNodes = output.map { attr => if (getPartitionSchema.exists(_.name.equals(attr.name))) { - new ColumnTypeNode(1) + new ColumnTypeNode(NamedStruct.ColumnType.PARTITION_COL_VALUE) } else if 
(SparkShimLoader.getSparkShims.isRowIndexMetadataColumn(attr.name)) { - new ColumnTypeNode(3) + new ColumnTypeNode(NamedStruct.ColumnType.ROWINDEX_COL_VALUE) } else if (attr.isMetadataCol) { - new ColumnTypeNode(2) + new ColumnTypeNode(NamedStruct.ColumnType.METADATA_COL_VALUE) } else { - new ColumnTypeNode(0) + new ColumnTypeNode(NamedStruct.ColumnType.NORMAL_COL_VALUE) } }.asJava // Will put all filter expressions into an AND expression @@ -156,8 +157,8 @@ trait BasicScanExecTransformer extends LeafTransformSupport with BaseDataSource TransformContext(output, output, readNode) } - def filteRedundantField(outputs: Seq[Attribute]): Seq[Attribute] = { - var final_output: List[Attribute] = List() + private def filterRedundantField(outputs: Seq[Attribute]): Seq[Attribute] = { + var finalOutput: List[Attribute] = List() val outputList = outputs.toArray for (i <- outputList.indices) { var dup = false @@ -167,9 +168,9 @@ trait BasicScanExecTransformer extends LeafTransformSupport with BaseDataSource } } if (!dup) { - final_output = final_output :+ outputList(i) + finalOutput = finalOutput :+ outputList(i) } } - final_output + finalOutput } } From 945ac2342202533ccf862e248ef262a377ba1569 Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Wed, 26 Jun 2024 09:57:38 +0800 Subject: [PATCH 13/30] [GLUTEN-6180][VL] Fix NPE if spilling is requested during task creation (#6205) --- .../memory/memtarget/MemoryTargets.java | 2 +- .../arrow/alloc/ArrowBufferAllocators.java | 11 +- .../memory/nmm/NativeMemoryManagers.java | 157 +++++++++--------- .../vectorized/NativePlanEvaluator.java | 2 +- 4 files changed, 91 insertions(+), 81 deletions(-) diff --git a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/MemoryTargets.java b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/MemoryTargets.java index 2d6fc0748464..c3ece743310a 100644 --- a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/MemoryTargets.java +++ b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/MemoryTargets.java @@ -63,6 +63,6 @@ public static MemoryTarget newConsumer( factory = TreeMemoryConsumers.shared(); } - return dynamicOffHeapSizingIfEnabled(factory.newConsumer(tmm, name, spillers, virtualChildren)); + return factory.newConsumer(tmm, name, spillers, virtualChildren); } } diff --git a/gluten-data/src/main/java/org/apache/gluten/memory/arrow/alloc/ArrowBufferAllocators.java b/gluten-data/src/main/java/org/apache/gluten/memory/arrow/alloc/ArrowBufferAllocators.java index efee20e48b83..51f49da704eb 100644 --- a/gluten-data/src/main/java/org/apache/gluten/memory/arrow/alloc/ArrowBufferAllocators.java +++ b/gluten-data/src/main/java/org/apache/gluten/memory/arrow/alloc/ArrowBufferAllocators.java @@ -60,11 +60,12 @@ public static class ArrowBufferAllocatorManager implements TaskResource { listener = new ManagedAllocationListener( MemoryTargets.throwOnOom( - MemoryTargets.newConsumer( - tmm, - "ArrowContextInstance", - Collections.emptyList(), - Collections.emptyMap())), + MemoryTargets.dynamicOffHeapSizingIfEnabled( + MemoryTargets.newConsumer( + tmm, + "ArrowContextInstance", + Collections.emptyList(), + Collections.emptyMap()))), TaskResources.getSharedUsage()); } diff --git a/gluten-data/src/main/java/org/apache/gluten/memory/nmm/NativeMemoryManagers.java b/gluten-data/src/main/java/org/apache/gluten/memory/nmm/NativeMemoryManagers.java index 928f869ba4e1..37456badd42f 100644 --- a/gluten-data/src/main/java/org/apache/gluten/memory/nmm/NativeMemoryManagers.java +++ 
b/gluten-data/src/main/java/org/apache/gluten/memory/nmm/NativeMemoryManagers.java @@ -26,6 +26,8 @@ import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.util.TaskResources; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.util.Arrays; import java.util.Collections; @@ -37,6 +39,7 @@ import java.util.stream.Stream; public final class NativeMemoryManagers { + private static final Logger LOG = LoggerFactory.getLogger(NativeMemoryManagers.class); // TODO: Let all caller support spill. public static NativeMemoryManager contextInstance(String name) { @@ -67,86 +70,92 @@ private static NativeMemoryManager createNativeMemoryManager( final MemoryTarget target = MemoryTargets.throwOnOom( MemoryTargets.overAcquire( - MemoryTargets.newConsumer( - tmm, - name, - // call memory manager's shrink API, if no good then call the spiller - Stream.concat( - Stream.of( - new Spiller() { - @Override - public long spill(MemoryTarget self, long size) { - return Optional.of(out.get()) - .map(nmm -> nmm.shrink(size)) - .orElseThrow( - () -> - new IllegalStateException( - "" - + "Shrink is requested before native " - + "memory manager is created. Try moving " - + "any actions about memory allocation out " - + "from the memory manager constructor.")); - } + MemoryTargets.dynamicOffHeapSizingIfEnabled( + MemoryTargets.newConsumer( + tmm, + name, + // call memory manager's shrink API, if no good then call the spiller + Stream.concat( + Stream.of( + new Spiller() { + @Override + public long spill(MemoryTarget self, long size) { + return Optional.ofNullable(out.get()) + .map(nmm -> nmm.shrink(size)) + .orElseGet( + () -> { + LOG.warn( + "Shrink is requested before native " + + "memory manager is created. Try moving " + + "any actions about memory allocation" + + " out from the memory manager" + + " constructor."); + return 0L; + }); + } - @Override - public Set applicablePhases() { - return Spillers.PHASE_SET_SHRINK_ONLY; - } - }), - spillers.stream()) - .map(spiller -> Spillers.withMinSpillSize(spiller, reservationBlockSize)) - .collect(Collectors.toList()), - Collections.singletonMap( - "single", - new MemoryUsageRecorder() { - @Override - public void inc(long bytes) { - // no-op - } + @Override + public Set applicablePhases() { + return Spillers.PHASE_SET_SHRINK_ONLY; + } + }), + spillers.stream()) + .map( + spiller -> Spillers.withMinSpillSize(spiller, reservationBlockSize)) + .collect(Collectors.toList()), + Collections.singletonMap( + "single", + new MemoryUsageRecorder() { + @Override + public void inc(long bytes) { + // no-op + } - @Override - public long peak() { - throw new UnsupportedOperationException("Not implemented"); - } + @Override + public long peak() { + throw new UnsupportedOperationException("Not implemented"); + } - @Override - public long current() { - throw new UnsupportedOperationException("Not implemented"); - } + @Override + public long current() { + throw new UnsupportedOperationException("Not implemented"); + } - @Override - public MemoryUsageStats toStats() { - return getNativeMemoryManager().collectMemoryUsage(); - } + @Override + public MemoryUsageStats toStats() { + return getNativeMemoryManager().collectMemoryUsage(); + } - private NativeMemoryManager getNativeMemoryManager() { - return Optional.of(out.get()) - .orElseThrow( - () -> - new IllegalStateException( - "" - + "Memory usage stats are requested before native " - + "memory manager is created. 
Try moving any " - + "actions about memory allocation out from the " - + "memory manager constructor.")); - } - })), - MemoryTargets.newConsumer( - tmm, - "OverAcquire.DummyTarget", - Collections.singletonList( - new Spiller() { - @Override - public long spill(MemoryTarget self, long size) { - return self.repay(size); - } + private NativeMemoryManager getNativeMemoryManager() { + return Optional.ofNullable(out.get()) + .orElseThrow( + () -> + new IllegalStateException( + "" + + "Memory usage stats are requested before" + + " native memory manager is created. Try" + + " moving any actions about memory" + + " allocation out from the memory manager" + + " constructor.")); + } + }))), + MemoryTargets.dynamicOffHeapSizingIfEnabled( + MemoryTargets.newConsumer( + tmm, + "OverAcquire.DummyTarget", + Collections.singletonList( + new Spiller() { + @Override + public long spill(MemoryTarget self, long size) { + return self.repay(size); + } - @Override - public Set applicablePhases() { - return Spillers.PHASE_SET_ALL; - } - }), - Collections.emptyMap()), + @Override + public Set applicablePhases() { + return Spillers.PHASE_SET_ALL; + } + }), + Collections.emptyMap())), overAcquiredRatio)); // listener ManagedReservationListener rl = diff --git a/gluten-data/src/main/java/org/apache/gluten/vectorized/NativePlanEvaluator.java b/gluten-data/src/main/java/org/apache/gluten/vectorized/NativePlanEvaluator.java index e54724a599c1..2ac048b2b960 100644 --- a/gluten-data/src/main/java/org/apache/gluten/vectorized/NativePlanEvaluator.java +++ b/gluten-data/src/main/java/org/apache/gluten/vectorized/NativePlanEvaluator.java @@ -71,7 +71,7 @@ public GeneralOutIterator createKernelWithBatchIterator( @Override public long spill(MemoryTarget self, long size) { ColumnarBatchOutIterator instance = - Optional.of(outIterator.get()) + Optional.ofNullable(outIterator.get()) .orElseThrow( () -> new IllegalStateException( From 9d2fcdeaaa2631bb50eedf214d6684f9aa2252ce Mon Sep 17 00:00:00 2001 From: Kerwin Zhang Date: Wed, 26 Jun 2024 11:20:19 +0800 Subject: [PATCH 14/30] [CELEBORN] Fix potential ClassNotFoundException (#6217) --- .../org/apache/spark/shuffle/gluten/celeborn/CelebornUtils.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gluten-celeborn/common/src/main/java/org/apache/spark/shuffle/gluten/celeborn/CelebornUtils.java b/gluten-celeborn/common/src/main/java/org/apache/spark/shuffle/gluten/celeborn/CelebornUtils.java index 4593d019c27e..9dd4e1d1191e 100644 --- a/gluten-celeborn/common/src/main/java/org/apache/spark/shuffle/gluten/celeborn/CelebornUtils.java +++ b/gluten-celeborn/common/src/main/java/org/apache/spark/shuffle/gluten/celeborn/CelebornUtils.java @@ -65,7 +65,7 @@ public static boolean unregisterShuffle( unregisterAppShuffleId.invoke(shuffleIdTracker, shuffleClient, appShuffleId); } return true; - } catch (NoSuchMethodException ex) { + } catch (NoSuchMethodException | ClassNotFoundException ex) { try { if (lifecycleManager != null) { Method unregisterShuffleMethod = From 529d3f403d2f1b1c9bf6105e0468f4a9f3a19d43 Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Wed, 26 Jun 2024 13:45:20 +0800 Subject: [PATCH 15/30] [VL] Add a benchmark to track on iterator facility's performance (#6225) --- .../backendsapi/velox/VeloxIteratorApi.scala | 3 +- .../datasource/ArrowCSVFileFormat.scala | 3 +- .../execution/RowToVeloxColumnarExec.scala | 3 +- .../execution/VeloxAppendBatchesExec.scala | 3 +- .../VeloxBroadcastBuildSideRDD.scala | 2 +- .../execution/VeloxColumnarToRowExec.scala | 2 +- 
.../python/ColumnarArrowEvalPythonExec.scala | 3 +- .../ColumnarCachedBatchSerializer.scala | 3 +- .../datasources/VeloxWriteQueue.scala | 2 +- .../org/apache/gluten/utils/Iterators.scala | 228 ------------------ .../gluten/utils/iterator/Iterators.scala | 53 ++++ .../gluten/utils/iterator/IteratorsV1.scala | 222 +++++++++++++++++ .../utils/{ => iterator}/IteratorSuite.scala | 30 +-- .../utils/iterator/IteratorBenchmark.scala | 129 ++++++++++ .../execution/ColumnarBuildSideRelation.scala | 3 +- .../spark/sql/execution/utils/ExecUtil.scala | 2 +- 16 files changed, 438 insertions(+), 253 deletions(-) delete mode 100644 gluten-core/src/main/scala/org/apache/gluten/utils/Iterators.scala create mode 100644 gluten-core/src/main/scala/org/apache/gluten/utils/iterator/Iterators.scala create mode 100644 gluten-core/src/main/scala/org/apache/gluten/utils/iterator/IteratorsV1.scala rename gluten-core/src/test/scala/org/apache/gluten/utils/{ => iterator}/IteratorSuite.scala (86%) create mode 100644 gluten-core/src/test/scala/org/apache/spark/utils/iterator/IteratorBenchmark.scala diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxIteratorApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxIteratorApi.scala index 880e1e56b852..22862156c6b2 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxIteratorApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxIteratorApi.scala @@ -26,6 +26,7 @@ import org.apache.gluten.substrait.plan.PlanNode import org.apache.gluten.substrait.rel.{LocalFilesBuilder, LocalFilesNode, SplitInfo} import org.apache.gluten.substrait.rel.LocalFilesNode.ReadFileFormat import org.apache.gluten.utils._ +import org.apache.gluten.utils.iterator.Iterators import org.apache.gluten.vectorized._ import org.apache.spark.{SparkConf, TaskContext} @@ -36,7 +37,7 @@ import org.apache.spark.sql.catalyst.util.{DateFormatter, TimestampFormatter} import org.apache.spark.sql.connector.read.InputPartition import org.apache.spark.sql.execution.datasources.{FilePartition, PartitionedFile} import org.apache.spark.sql.execution.metric.SQLMetric -import org.apache.spark.sql.types.{BinaryType, DateType, Decimal, DecimalType, StructType, TimestampType} +import org.apache.spark.sql.types._ import org.apache.spark.sql.utils.OASPackageBridge.InputMetricsWrapper import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.ExecutorManager diff --git a/backends-velox/src/main/scala/org/apache/gluten/datasource/ArrowCSVFileFormat.scala b/backends-velox/src/main/scala/org/apache/gluten/datasource/ArrowCSVFileFormat.scala index 7c3ca8fc8cde..a8e65b0539c7 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/datasource/ArrowCSVFileFormat.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/datasource/ArrowCSVFileFormat.scala @@ -21,7 +21,8 @@ import org.apache.gluten.exception.SchemaMismatchException import org.apache.gluten.execution.RowToVeloxColumnarExec import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators import org.apache.gluten.memory.arrow.pool.ArrowNativeMemoryPool -import org.apache.gluten.utils.{ArrowUtil, Iterators} +import org.apache.gluten.utils.ArrowUtil +import org.apache.gluten.utils.iterator.Iterators import org.apache.gluten.vectorized.ArrowWritableColumnVector import org.apache.spark.TaskContext diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/RowToVeloxColumnarExec.scala 
b/backends-velox/src/main/scala/org/apache/gluten/execution/RowToVeloxColumnarExec.scala index 5c9c5889bd13..d694f15fa9bd 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/execution/RowToVeloxColumnarExec.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/execution/RowToVeloxColumnarExec.scala @@ -22,7 +22,8 @@ import org.apache.gluten.exception.GlutenException import org.apache.gluten.exec.Runtimes import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators import org.apache.gluten.memory.nmm.NativeMemoryManagers -import org.apache.gluten.utils.{ArrowAbiUtil, Iterators} +import org.apache.gluten.utils.ArrowAbiUtil +import org.apache.gluten.utils.iterator.Iterators import org.apache.gluten.vectorized._ import org.apache.spark.broadcast.Broadcast diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxAppendBatchesExec.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxAppendBatchesExec.scala index 8c2834574204..4b4db703de7a 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxAppendBatchesExec.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxAppendBatchesExec.scala @@ -17,7 +17,8 @@ package org.apache.gluten.execution import org.apache.gluten.extension.GlutenPlan -import org.apache.gluten.utils.{Iterators, VeloxBatchAppender} +import org.apache.gluten.utils.VeloxBatchAppender +import org.apache.gluten.utils.iterator.Iterators import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxBroadcastBuildSideRDD.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxBroadcastBuildSideRDD.scala index 17d0522d0732..fe3c0b7e3938 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxBroadcastBuildSideRDD.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxBroadcastBuildSideRDD.scala @@ -16,7 +16,7 @@ */ package org.apache.gluten.execution -import org.apache.gluten.utils.Iterators +import org.apache.gluten.utils.iterator.Iterators import org.apache.spark.{broadcast, SparkContext} import org.apache.spark.sql.execution.joins.BuildSideRelation diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxColumnarToRowExec.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxColumnarToRowExec.scala index 77bf49727283..0d6714d3af92 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxColumnarToRowExec.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxColumnarToRowExec.scala @@ -20,7 +20,7 @@ import org.apache.gluten.columnarbatch.ColumnarBatches import org.apache.gluten.exception.GlutenNotSupportException import org.apache.gluten.extension.ValidationResult import org.apache.gluten.memory.nmm.NativeMemoryManagers -import org.apache.gluten.utils.Iterators +import org.apache.gluten.utils.iterator.Iterators import org.apache.gluten.vectorized.NativeColumnarToRowJniWrapper import org.apache.spark.broadcast.Broadcast diff --git a/backends-velox/src/main/scala/org/apache/spark/api/python/ColumnarArrowEvalPythonExec.scala b/backends-velox/src/main/scala/org/apache/spark/api/python/ColumnarArrowEvalPythonExec.scala index d5639057dac8..88280ff2edde 100644 --- a/backends-velox/src/main/scala/org/apache/spark/api/python/ColumnarArrowEvalPythonExec.scala +++ b/backends-velox/src/main/scala/org/apache/spark/api/python/ColumnarArrowEvalPythonExec.scala @@ -20,7 +20,8 @@ 
import org.apache.gluten.columnarbatch.ColumnarBatches import org.apache.gluten.exception.GlutenException import org.apache.gluten.extension.GlutenPlan import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators -import org.apache.gluten.utils.{Iterators, PullOutProjectHelper} +import org.apache.gluten.utils.PullOutProjectHelper +import org.apache.gluten.utils.iterator.Iterators import org.apache.gluten.vectorized.ArrowWritableColumnVector import org.apache.spark.{ContextAwareIterator, SparkEnv, TaskContext} diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarCachedBatchSerializer.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarCachedBatchSerializer.scala index 7385c53d61b3..cb65b7504bfc 100644 --- a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarCachedBatchSerializer.scala +++ b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarCachedBatchSerializer.scala @@ -23,7 +23,8 @@ import org.apache.gluten.exec.Runtimes import org.apache.gluten.execution.{RowToVeloxColumnarExec, VeloxColumnarToRowExec} import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators import org.apache.gluten.memory.nmm.NativeMemoryManagers -import org.apache.gluten.utils.{ArrowAbiUtil, Iterators} +import org.apache.gluten.utils.ArrowAbiUtil +import org.apache.gluten.utils.iterator.Iterators import org.apache.gluten.vectorized.ColumnarBatchSerializerJniWrapper import org.apache.spark.internal.Logging diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/datasources/VeloxWriteQueue.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/datasources/VeloxWriteQueue.scala index 089db1da1dee..b2905e157554 100644 --- a/backends-velox/src/main/scala/org/apache/spark/sql/execution/datasources/VeloxWriteQueue.scala +++ b/backends-velox/src/main/scala/org/apache/spark/sql/execution/datasources/VeloxWriteQueue.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources import org.apache.gluten.datasource.DatasourceJniWrapper -import org.apache.gluten.utils.Iterators +import org.apache.gluten.utils.iterator.Iterators import org.apache.gluten.vectorized.ColumnarBatchInIterator import org.apache.spark.TaskContext diff --git a/gluten-core/src/main/scala/org/apache/gluten/utils/Iterators.scala b/gluten-core/src/main/scala/org/apache/gluten/utils/Iterators.scala deleted file mode 100644 index 1e3681355d6c..000000000000 --- a/gluten-core/src/main/scala/org/apache/gluten/utils/Iterators.scala +++ /dev/null @@ -1,228 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.gluten.utils - -import org.apache.spark.{InterruptibleIterator, TaskContext} -import org.apache.spark.util.TaskResources - -import java.util.concurrent.TimeUnit -import java.util.concurrent.atomic.AtomicBoolean - -private class PayloadCloser[A](in: Iterator[A])(closeCallback: A => Unit) extends Iterator[A] { - private var closer: Option[() => Unit] = None - - TaskResources.addRecycler("Iterators#PayloadCloser", 100) { - tryClose() - } - - override def hasNext: Boolean = { - tryClose() - in.hasNext - } - - override def next(): A = { - val a: A = in.next() - closer.synchronized { - closer = Some( - () => { - closeCallback.apply(a) - }) - } - a - } - - private def tryClose(): Unit = { - closer.synchronized { - closer match { - case Some(c) => c.apply() - case None => - } - closer = None // make sure the payload is closed once - } - } -} - -private class IteratorCompleter[A](in: Iterator[A])(completionCallback: => Unit) - extends Iterator[A] { - private val completed = new AtomicBoolean(false) - - TaskResources.addRecycler("Iterators#IteratorRecycler", 100) { - tryComplete() - } - - override def hasNext: Boolean = { - val out = in.hasNext - if (!out) { - tryComplete() - } - out - } - - override def next(): A = { - in.next() - } - - private def tryComplete(): Unit = { - if (!completed.compareAndSet(false, true)) { - return // make sure the iterator is completed once - } - completionCallback - } -} - -private class LifeTimeAccumulator[A](in: Iterator[A], onCollected: Long => Unit) - extends Iterator[A] { - private val closed = new AtomicBoolean(false) - private val startTime = System.nanoTime() - - TaskResources.addRecycler("Iterators#LifeTimeAccumulator", 100) { - tryFinish() - } - - override def hasNext: Boolean = { - val out = in.hasNext - if (!out) { - tryFinish() - } - out - } - - override def next(): A = { - in.next() - } - - private def tryFinish(): Unit = { - // pipeline metric should only be calculate once. - if (!closed.compareAndSet(false, true)) { - return - } - val lifeTime = TimeUnit.NANOSECONDS.toMillis( - System.nanoTime() - startTime - ) - onCollected(lifeTime) - } -} - -private class ReadTimeAccumulator[A](in: Iterator[A], onAdded: Long => Unit) extends Iterator[A] { - - override def hasNext: Boolean = { - val prev = System.nanoTime() - val out = in.hasNext - val after = System.nanoTime() - val duration = TimeUnit.NANOSECONDS.toMillis(after - prev) - onAdded(duration) - out - } - - override def next(): A = { - val prev = System.nanoTime() - val out = in.next() - val after = System.nanoTime() - val duration = TimeUnit.NANOSECONDS.toMillis(after - prev) - onAdded(duration) - out - } -} - -/** - * To protect the wrapped iterator to avoid undesired order of calls to its `hasNext` and `next` - * methods. 
- */ -private class InvocationFlowProtection[A](in: Iterator[A]) extends Iterator[A] { - sealed private trait State - private case object Init extends State - private case class HasNextCalled(hasNext: Boolean) extends State - private case object NextCalled extends State - - private var state: State = Init - - override def hasNext: Boolean = { - val out = state match { - case Init | NextCalled => - in.hasNext - case HasNextCalled(lastHasNext) => - lastHasNext - } - state = HasNextCalled(out) - out - } - - override def next(): A = { - val out = state match { - case Init | NextCalled => - if (!in.hasNext) { - throw new IllegalStateException("End of stream") - } - in.next() - case HasNextCalled(lastHasNext) => - if (!lastHasNext) { - throw new IllegalStateException("End of stream") - } - in.next() - } - state = NextCalled - out - } -} - -class WrapperBuilder[A](in: Iterator[A]) { // FIXME how to make the ctor companion-private? - private var wrapped: Iterator[A] = in - - def recyclePayload(closeCallback: (A) => Unit): WrapperBuilder[A] = { - wrapped = new PayloadCloser(wrapped)(closeCallback) - this - } - - def recycleIterator(completionCallback: => Unit): WrapperBuilder[A] = { - wrapped = new IteratorCompleter(wrapped)(completionCallback) - this - } - - def collectLifeMillis(onCollected: Long => Unit): WrapperBuilder[A] = { - wrapped = new LifeTimeAccumulator[A](wrapped, onCollected) - this - } - - def collectReadMillis(onAdded: Long => Unit): WrapperBuilder[A] = { - wrapped = new ReadTimeAccumulator[A](wrapped, onAdded) - this - } - - def asInterruptible(context: TaskContext): WrapperBuilder[A] = { - wrapped = new InterruptibleIterator[A](context, wrapped) - this - } - - def protectInvocationFlow(): WrapperBuilder[A] = { - wrapped = new InvocationFlowProtection[A](wrapped) - this - } - - def create(): Iterator[A] = { - wrapped - } -} - -/** - * Utility class to provide iterator wrappers for non-trivial use cases. E.g. iterators that manage - * payload's lifecycle. - */ -object Iterators { - def wrap[A](in: Iterator[A]): WrapperBuilder[A] = { - new WrapperBuilder[A](in) - } -} diff --git a/gluten-core/src/main/scala/org/apache/gluten/utils/iterator/Iterators.scala b/gluten-core/src/main/scala/org/apache/gluten/utils/iterator/Iterators.scala new file mode 100644 index 000000000000..eedfa66cfeaf --- /dev/null +++ b/gluten-core/src/main/scala/org/apache/gluten/utils/iterator/Iterators.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.utils.iterator + +import org.apache.gluten.utils.iterator.IteratorsV1.WrapperBuilderV1 + +import org.apache.spark.TaskContext + +/** + * Utility class to provide iterator wrappers for non-trivial use cases. E.g. iterators that manage + * payload's lifecycle. 
+ */ +object Iterators { + sealed trait Version + case object V1 extends Version + + private val DEFAULT_VERSION: Version = V1 + + trait WrapperBuilder[A] { + def recyclePayload(closeCallback: (A) => Unit): WrapperBuilder[A] + def recycleIterator(completionCallback: => Unit): WrapperBuilder[A] + def collectLifeMillis(onCollected: Long => Unit): WrapperBuilder[A] + def collectReadMillis(onAdded: Long => Unit): WrapperBuilder[A] + def asInterruptible(context: TaskContext): WrapperBuilder[A] + def protectInvocationFlow(): WrapperBuilder[A] + def create(): Iterator[A] + } + + def wrap[A](in: Iterator[A]): WrapperBuilder[A] = { + wrap(V1, in) + } + + def wrap[A](version: Version, in: Iterator[A]): WrapperBuilder[A] = { + version match { + case V1 => + new WrapperBuilderV1[A](in) + } + } +} diff --git a/gluten-core/src/main/scala/org/apache/gluten/utils/iterator/IteratorsV1.scala b/gluten-core/src/main/scala/org/apache/gluten/utils/iterator/IteratorsV1.scala new file mode 100644 index 000000000000..3e9248c44458 --- /dev/null +++ b/gluten-core/src/main/scala/org/apache/gluten/utils/iterator/IteratorsV1.scala @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.gluten.utils.iterator + +import org.apache.gluten.utils.iterator.Iterators.WrapperBuilder + +import org.apache.spark.{InterruptibleIterator, TaskContext} +import org.apache.spark.util.TaskResources + +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicBoolean + +object IteratorsV1 { + private class PayloadCloser[A](in: Iterator[A])(closeCallback: A => Unit) extends Iterator[A] { + private var closer: Option[() => Unit] = None + + TaskResources.addRecycler("Iterators#PayloadCloser", 100) { + tryClose() + } + + override def hasNext: Boolean = { + tryClose() + in.hasNext + } + + override def next(): A = { + val a: A = in.next() + closer.synchronized { + closer = Some( + () => { + closeCallback.apply(a) + }) + } + a + } + + private def tryClose(): Unit = { + closer.synchronized { + closer match { + case Some(c) => c.apply() + case None => + } + closer = None // make sure the payload is closed once + } + } + } + + private class IteratorCompleter[A](in: Iterator[A])(completionCallback: => Unit) + extends Iterator[A] { + private val completed = new AtomicBoolean(false) + + TaskResources.addRecycler("Iterators#IteratorRecycler", 100) { + tryComplete() + } + + override def hasNext: Boolean = { + val out = in.hasNext + if (!out) { + tryComplete() + } + out + } + + override def next(): A = { + in.next() + } + + private def tryComplete(): Unit = { + if (!completed.compareAndSet(false, true)) { + return // make sure the iterator is completed once + } + completionCallback + } + } + + private class LifeTimeAccumulator[A](in: Iterator[A], onCollected: Long => Unit) + extends Iterator[A] { + private val closed = new AtomicBoolean(false) + private val startTime = System.nanoTime() + + TaskResources.addRecycler("Iterators#LifeTimeAccumulator", 100) { + tryFinish() + } + + override def hasNext: Boolean = { + val out = in.hasNext + if (!out) { + tryFinish() + } + out + } + + override def next(): A = { + in.next() + } + + private def tryFinish(): Unit = { + // pipeline metric should only be calculate once. + if (!closed.compareAndSet(false, true)) { + return + } + val lifeTime = TimeUnit.NANOSECONDS.toMillis( + System.nanoTime() - startTime + ) + onCollected(lifeTime) + } + } + + private class ReadTimeAccumulator[A](in: Iterator[A], onAdded: Long => Unit) extends Iterator[A] { + + override def hasNext: Boolean = { + val prev = System.nanoTime() + val out = in.hasNext + val after = System.nanoTime() + val duration = TimeUnit.NANOSECONDS.toMillis(after - prev) + onAdded(duration) + out + } + + override def next(): A = { + val prev = System.nanoTime() + val out = in.next() + val after = System.nanoTime() + val duration = TimeUnit.NANOSECONDS.toMillis(after - prev) + onAdded(duration) + out + } + } + + /** + * To protect the wrapped iterator to avoid undesired order of calls to its `hasNext` and `next` + * methods. 
+ */ + private class InvocationFlowProtection[A](in: Iterator[A]) extends Iterator[A] { + sealed private trait State + private case object Init extends State + private case class HasNextCalled(hasNext: Boolean) extends State + private case object NextCalled extends State + + private var state: State = Init + + override def hasNext: Boolean = { + val out = state match { + case Init | NextCalled => + in.hasNext + case HasNextCalled(lastHasNext) => + lastHasNext + } + state = HasNextCalled(out) + out + } + + override def next(): A = { + val out = state match { + case Init | NextCalled => + if (!in.hasNext) { + throw new IllegalStateException("End of stream") + } + in.next() + case HasNextCalled(lastHasNext) => + if (!lastHasNext) { + throw new IllegalStateException("End of stream") + } + in.next() + } + state = NextCalled + out + } + } + + class WrapperBuilderV1[A] private[iterator] (in: Iterator[A]) extends WrapperBuilder[A] { + private var wrapped: Iterator[A] = in + + override def recyclePayload(closeCallback: (A) => Unit): WrapperBuilder[A] = { + wrapped = new PayloadCloser(wrapped)(closeCallback) + this + } + + override def recycleIterator(completionCallback: => Unit): WrapperBuilder[A] = { + wrapped = new IteratorCompleter(wrapped)(completionCallback) + this + } + + override def collectLifeMillis(onCollected: Long => Unit): WrapperBuilder[A] = { + wrapped = new LifeTimeAccumulator[A](wrapped, onCollected) + this + } + + override def collectReadMillis(onAdded: Long => Unit): WrapperBuilder[A] = { + wrapped = new ReadTimeAccumulator[A](wrapped, onAdded) + this + } + + override def asInterruptible(context: TaskContext): WrapperBuilder[A] = { + wrapped = new InterruptibleIterator[A](context, wrapped) + this + } + + override def protectInvocationFlow(): WrapperBuilder[A] = { + wrapped = new InvocationFlowProtection[A](wrapped) + this + } + + override def create(): Iterator[A] = { + wrapped + } + } +} diff --git a/gluten-core/src/test/scala/org/apache/gluten/utils/IteratorSuite.scala b/gluten-core/src/test/scala/org/apache/gluten/utils/iterator/IteratorSuite.scala similarity index 86% rename from gluten-core/src/test/scala/org/apache/gluten/utils/IteratorSuite.scala rename to gluten-core/src/test/scala/org/apache/gluten/utils/iterator/IteratorSuite.scala index 389e2adfefd4..1a84d671922d 100644 --- a/gluten-core/src/test/scala/org/apache/gluten/utils/IteratorSuite.scala +++ b/gluten-core/src/test/scala/org/apache/gluten/utils/iterator/IteratorSuite.scala @@ -14,18 +14,25 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.gluten.utils +package org.apache.gluten.utils.iterator + +import org.apache.gluten.utils.iterator.Iterators.{V1, WrapperBuilder} import org.apache.spark.util.TaskResources import org.scalatest.funsuite.AnyFunSuite -class IteratorSuite extends AnyFunSuite { +class IteratorV1Suite extends IteratorSuite { + override protected def wrap[A](in: Iterator[A]): WrapperBuilder[A] = Iterators.wrap(V1, in) +} + +abstract class IteratorSuite extends AnyFunSuite { + protected def wrap[A](in: Iterator[A]): WrapperBuilder[A] + test("Trivial wrapping") { val strings = Array[String]("one", "two", "three") val itr = strings.toIterator - val wrapped = Iterators - .wrap(itr) + val wrapped = wrap(itr) .create() assertResult(strings) { wrapped.toArray @@ -37,8 +44,7 @@ class IteratorSuite extends AnyFunSuite { TaskResources.runUnsafe { val strings = Array[String]("one", "two", "three") val itr = strings.toIterator - val wrapped = Iterators - .wrap(itr) + val wrapped = wrap(itr) .recycleIterator { completeCount += 1 } @@ -56,8 +62,7 @@ class IteratorSuite extends AnyFunSuite { TaskResources.runUnsafe { val strings = Array[String]("one", "two", "three") val itr = strings.toIterator - val _ = Iterators - .wrap(itr) + val _ = wrap(itr) .recycleIterator { completeCount += 1 } @@ -72,8 +77,7 @@ class IteratorSuite extends AnyFunSuite { TaskResources.runUnsafe { val strings = Array[String]("one", "two", "three") val itr = strings.toIterator - val wrapped = Iterators - .wrap(itr) + val wrapped = wrap(itr) .recyclePayload { _: String => closeCount += 1 } .create() assertResult(strings) { @@ -89,8 +93,7 @@ class IteratorSuite extends AnyFunSuite { TaskResources.runUnsafe { val strings = Array[String]("one", "two", "three") val itr = strings.toIterator - val wrapped = Iterators - .wrap(itr) + val wrapped = wrap(itr) .recyclePayload { _: String => closeCount += 1 } .create() assertResult(strings.take(2)) { @@ -115,8 +118,7 @@ class IteratorSuite extends AnyFunSuite { new Object } } - val wrapped = Iterators - .wrap(itr) + val wrapped = wrap(itr) .protectInvocationFlow() .create() wrapped.hasNext diff --git a/gluten-core/src/test/scala/org/apache/spark/utils/iterator/IteratorBenchmark.scala b/gluten-core/src/test/scala/org/apache/spark/utils/iterator/IteratorBenchmark.scala new file mode 100644 index 000000000000..aa69f309aac8 --- /dev/null +++ b/gluten-core/src/test/scala/org/apache/spark/utils/iterator/IteratorBenchmark.scala @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.utils.iterator + +import org.apache.gluten.utils.iterator.Iterators +import org.apache.gluten.utils.iterator.Iterators.V1 + +import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} +import org.apache.spark.util.TaskResources + +object IteratorBenchmark extends BenchmarkBase { + + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + runBenchmark("Iterator Nesting") { + TaskResources.runUnsafe { + val nPayloads: Int = 50000000 // 50 millions + + def makeScalaIterator: Iterator[Any] = { + (0 until nPayloads).view.map { _: Int => new Object }.iterator + } + + def compareIterator(name: String)( + makeGlutenIterator: Iterators.Version => Iterator[Any]): Unit = { + val benchmark = new Benchmark(name, nPayloads, output = output) + benchmark.addCase("Scala Iterator") { + _ => + val count = makeScalaIterator.count(_ => true) + assert(count == nPayloads) + } + benchmark.addCase("Gluten Iterator V1") { + _ => + val count = makeGlutenIterator(V1).count(_ => true) + assert(count == nPayloads) + } + benchmark.run() + } + + compareIterator("0 Levels Nesting") { + version => + Iterators + .wrap(version, makeScalaIterator) + .create() + } + compareIterator("1 Levels Nesting - read") { + version => + Iterators + .wrap(version, makeScalaIterator) + .collectReadMillis { _ => } + .create() + } + compareIterator("5 Levels Nesting - read") { + version => + Iterators + .wrap(version, makeScalaIterator) + .collectReadMillis { _ => } + .collectReadMillis { _ => } + .collectReadMillis { _ => } + .collectReadMillis { _ => } + .collectReadMillis { _ => } + .create() + } + compareIterator("10 Levels Nesting - read") { + version => + Iterators + .wrap(version, makeScalaIterator) + .collectReadMillis { _ => } + .collectReadMillis { _ => } + .collectReadMillis { _ => } + .collectReadMillis { _ => } + .collectReadMillis { _ => } + .collectReadMillis { _ => } + .collectReadMillis { _ => } + .collectReadMillis { _ => } + .collectReadMillis { _ => } + .collectReadMillis { _ => } + .create() + } + compareIterator("1 Levels Nesting - recycle") { + version => + Iterators + .wrap(version, makeScalaIterator) + .recycleIterator {} + .create() + } + compareIterator("5 Levels Nesting - recycle") { + version => + Iterators + .wrap(version, makeScalaIterator) + .recycleIterator {} + .recycleIterator {} + .recycleIterator {} + .recycleIterator {} + .recycleIterator {} + .create() + } + compareIterator("10 Levels Nesting - recycle") { + version => + Iterators + .wrap(version, makeScalaIterator) + .recycleIterator {} + .recycleIterator {} + .recycleIterator {} + .recycleIterator {} + .recycleIterator {} + .recycleIterator {} + .recycleIterator {} + .recycleIterator {} + .recycleIterator {} + .recycleIterator {} + .create() + } + } + } + } +} diff --git a/gluten-data/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala b/gluten-data/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala index 9d9f5ab1765c..840f8618b0b4 100644 --- a/gluten-data/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala +++ b/gluten-data/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala @@ -21,7 +21,8 @@ import org.apache.gluten.exec.Runtimes import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators import org.apache.gluten.memory.nmm.NativeMemoryManagers import org.apache.gluten.sql.shims.SparkShimLoader -import org.apache.gluten.utils.{ArrowAbiUtil, Iterators} +import org.apache.gluten.utils.ArrowAbiUtil 
+import org.apache.gluten.utils.iterator.Iterators import org.apache.gluten.vectorized.{ColumnarBatchSerializerJniWrapper, NativeColumnarToRowJniWrapper} import org.apache.spark.sql.catalyst.InternalRow diff --git a/gluten-data/src/main/scala/org/apache/spark/sql/execution/utils/ExecUtil.scala b/gluten-data/src/main/scala/org/apache/spark/sql/execution/utils/ExecUtil.scala index 083915f12db9..090b8fa2562a 100644 --- a/gluten-data/src/main/scala/org/apache/spark/sql/execution/utils/ExecUtil.scala +++ b/gluten-data/src/main/scala/org/apache/spark/sql/execution/utils/ExecUtil.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.utils import org.apache.gluten.columnarbatch.ColumnarBatches import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators import org.apache.gluten.memory.nmm.NativeMemoryManagers -import org.apache.gluten.utils.Iterators +import org.apache.gluten.utils.iterator.Iterators import org.apache.gluten.vectorized.{ArrowWritableColumnVector, NativeColumnarToRowInfo, NativeColumnarToRowJniWrapper, NativePartitioning} import org.apache.spark.{Partitioner, RangePartitioner, ShuffleDependency} From 9ae34a91379fc833c3db873292e53e589be9e62b Mon Sep 17 00:00:00 2001 From: Rong Ma Date: Wed, 26 Jun 2024 15:17:12 +0800 Subject: [PATCH 16/30] [GLUTEN-5643] Fix the failure when the pre-project of GenerateExec falls back (#6167) --- .../velox/VeloxSparkPlanExecApi.scala | 4 +- .../gluten/expression/DummyExpression.scala | 77 +++++++++++++++++++ .../spark/sql/expression/UDFResolver.scala | 5 +- .../gluten/execution/TestOperator.scala | 24 +++++- cpp/velox/substrait/SubstraitToVeloxPlan.cc | 38 +++++---- .../expression/ExpressionConverter.scala | 2 +- 6 files changed, 131 insertions(+), 19 deletions(-) create mode 100644 backends-velox/src/main/scala/org/apache/gluten/expression/DummyExpression.scala diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala index 7b8d523a6d27..b48da15683e8 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala @@ -852,7 +852,9 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi { Sig[VeloxBloomFilterMightContain](ExpressionNames.MIGHT_CONTAIN), Sig[VeloxBloomFilterAggregate](ExpressionNames.BLOOM_FILTER_AGG), Sig[TransformKeys](TRANSFORM_KEYS), - Sig[TransformValues](TRANSFORM_VALUES) + Sig[TransformValues](TRANSFORM_VALUES), + // For test purpose. + Sig[VeloxDummyExpression](VeloxDummyExpression.VELOX_DUMMY_EXPRESSION) ) } diff --git a/backends-velox/src/main/scala/org/apache/gluten/expression/DummyExpression.scala b/backends-velox/src/main/scala/org/apache/gluten/expression/DummyExpression.scala new file mode 100644 index 000000000000..e2af66b599d3 --- /dev/null +++ b/backends-velox/src/main/scala/org/apache/gluten/expression/DummyExpression.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.expression + +import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow} +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry +import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.types.DataType + +abstract class DummyExpression(child: Expression) extends UnaryExpression with Serializable { + private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType, nullable) + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + defineCodeGen(ctx, ev, c => c) + + override def dataType: DataType = child.dataType + + override def eval(input: InternalRow): Any = { + assert(input.numFields == 1, "The input row of DummyExpression should have only 1 field.") + accessor(input, 0) + } +} + +// Can be used as a wrapper to force fall back the original expression to mock the fallback behavior +// of an supported expression in Gluten which fails native validation. +case class VeloxDummyExpression(child: Expression) + extends DummyExpression(child) + with Transformable { + override def getTransformer( + childrenTransformers: Seq[ExpressionTransformer]): ExpressionTransformer = { + if (childrenTransformers.size != children.size) { + throw new IllegalStateException( + this.getClass.getSimpleName + + ": getTransformer called before children transformer initialized.") + } + + GenericExpressionTransformer( + VeloxDummyExpression.VELOX_DUMMY_EXPRESSION, + childrenTransformers, + this) + } + + override protected def withNewChildInternal(newChild: Expression): Expression = copy(newChild) +} + +object VeloxDummyExpression { + val VELOX_DUMMY_EXPRESSION = "velox_dummy_expression" + + private val identifier = new FunctionIdentifier(VELOX_DUMMY_EXPRESSION) + + def registerFunctions(registry: FunctionRegistry): Unit = { + registry.registerFunction( + identifier, + new ExpressionInfo(classOf[VeloxDummyExpression].getName, VELOX_DUMMY_EXPRESSION), + (e: Seq[Expression]) => VeloxDummyExpression(e.head) + ) + } + + def unregisterFunctions(registry: FunctionRegistry): Unit = { + registry.dropFunction(identifier) + } +} diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala b/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala index 915fc554584c..e45e8b6fa6d7 100644 --- a/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala +++ b/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala @@ -27,7 +27,7 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, ExpressionInfo} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, ExpressionInfo, Unevaluable} 
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.types.DataTypeUtils @@ -94,7 +94,8 @@ case class UDFExpression( dataType: DataType, nullable: Boolean, children: Seq[Expression]) - extends Transformable { + extends Unevaluable + with Transformable { override protected def withNewChildrenInternal( newChildren: IndexedSeq[Expression]): Expression = { this.copy(children = newChildren) diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala index a892b6f313a4..9b47a519cd28 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala @@ -19,6 +19,7 @@ package org.apache.gluten.execution import org.apache.gluten.GlutenConfig import org.apache.gluten.datasource.ArrowCSVFileFormat import org.apache.gluten.execution.datasource.v2.ArrowBatchScanExec +import org.apache.gluten.expression.VeloxDummyExpression import org.apache.gluten.sql.shims.SparkShimLoader import org.apache.spark.SparkConf @@ -45,6 +46,12 @@ class TestOperator extends VeloxWholeStageTransformerSuite with AdaptiveSparkPla override def beforeAll(): Unit = { super.beforeAll() createTPCHNotNullTables() + VeloxDummyExpression.registerFunctions(spark.sessionState.functionRegistry) + } + + override def afterAll(): Unit = { + VeloxDummyExpression.unregisterFunctions(spark.sessionState.functionRegistry) + super.afterAll() } override protected def sparkConf: SparkConf = { @@ -66,14 +73,20 @@ class TestOperator extends VeloxWholeStageTransformerSuite with AdaptiveSparkPla test("select_part_column") { val df = runQueryAndCompare("select l_shipdate, l_orderkey from lineitem limit 1") { - df => { assert(df.schema.fields.length == 2) } + df => + { + assert(df.schema.fields.length == 2) + } } checkLengthAndPlan(df, 1) } test("select_as") { val df = runQueryAndCompare("select l_shipdate as my_col from lineitem limit 1") { - df => { assert(df.schema.fieldNames(0).equals("my_col")) } + df => + { + assert(df.schema.fieldNames(0).equals("my_col")) + } } checkLengthAndPlan(df, 1) } @@ -1074,6 +1087,13 @@ class TestOperator extends VeloxWholeStageTransformerSuite with AdaptiveSparkPla // No ProjectExecTransformer is introduced. 
checkSparkOperatorChainMatch[GenerateExecTransformer, FilterExecTransformer] } + + runQueryAndCompare( + s""" + |SELECT $func(${VeloxDummyExpression.VELOX_DUMMY_EXPRESSION}(a)) from t2; + |""".stripMargin) { + checkGlutenOperatorMatch[GenerateExecTransformer] + } } } } diff --git a/cpp/velox/substrait/SubstraitToVeloxPlan.cc b/cpp/velox/substrait/SubstraitToVeloxPlan.cc index 8b8a9262403c..73047b2f4907 100644 --- a/cpp/velox/substrait/SubstraitToVeloxPlan.cc +++ b/cpp/velox/substrait/SubstraitToVeloxPlan.cc @@ -26,6 +26,7 @@ #include "utils/ConfigExtractor.h" #include "config/GlutenConfig.h" +#include "operators/plannodes/RowVectorStream.h" namespace gluten { namespace { @@ -710,16 +711,23 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait:: namespace { void extractUnnestFieldExpr( - std::shared_ptr projNode, + std::shared_ptr child, int32_t index, std::vector& unnestFields) { - auto name = projNode->names()[index]; - auto expr = projNode->projections()[index]; - auto type = expr->type(); + if (auto projNode = std::dynamic_pointer_cast(child)) { + auto name = projNode->names()[index]; + auto expr = projNode->projections()[index]; + auto type = expr->type(); - auto unnestFieldExpr = std::make_shared(type, name); - VELOX_CHECK_NOT_NULL(unnestFieldExpr, " the key in unnest Operator only support field"); - unnestFields.emplace_back(unnestFieldExpr); + auto unnestFieldExpr = std::make_shared(type, name); + VELOX_CHECK_NOT_NULL(unnestFieldExpr, " the key in unnest Operator only support field"); + unnestFields.emplace_back(unnestFieldExpr); + } else { + auto name = child->outputType()->names()[index]; + auto field = child->outputType()->childAt(index); + auto unnestFieldExpr = std::make_shared(field, name); + unnestFields.emplace_back(unnestFieldExpr); + } } } // namespace @@ -752,10 +760,13 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait:: SubstraitParser::configSetInOptimization(generateRel.advanced_extension(), "injectedProject="); if (injectedProject) { - auto projNode = std::dynamic_pointer_cast(childNode); + // Child should be either ProjectNode or ValueStreamNode in case of project fallback. VELOX_CHECK( - projNode != nullptr && projNode->names().size() > requiredChildOutput.size(), - "injectedProject is true, but the Project is missing or does not have the corresponding projection field") + (std::dynamic_pointer_cast(childNode) != nullptr || + std::dynamic_pointer_cast(childNode) != nullptr) && + childNode->outputType()->size() > requiredChildOutput.size(), + "injectedProject is true, but the ProjectNode or ValueStreamNode (in case of projection fallback)" + " is missing or does not have the corresponding projection field") bool isStack = generateRel.has_advanced_extension() && SubstraitParser::configSetInOptimization(generateRel.advanced_extension(), "isStack="); @@ -768,7 +779,8 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait:: // +- Project [fake_column#128, [1,2,3] AS _pre_0#129] // +- RewrittenNodeWall Scan OneRowRelation[fake_column#128] // The last projection column in GeneratorRel's child(Project) is the column we need to unnest - extractUnnestFieldExpr(projNode, projNode->projections().size() - 1, unnest); + auto index = childNode->outputType()->size() - 1; + extractUnnestFieldExpr(childNode, index, unnest); } else { // For stack function, e.g. 
stack(2, 1,2,3), a sample // input substrait plan is like the following: @@ -782,10 +794,10 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait:: auto generatorFunc = generator.scalar_function(); auto numRows = SubstraitParser::getLiteralValue(generatorFunc.arguments(0).value().literal()); auto numFields = static_cast(std::ceil((generatorFunc.arguments_size() - 1.0) / numRows)); - auto totalProjectCount = projNode->names().size(); + auto totalProjectCount = childNode->outputType()->size(); for (auto i = totalProjectCount - numFields; i < totalProjectCount; ++i) { - extractUnnestFieldExpr(projNode, i, unnest); + extractUnnestFieldExpr(childNode, i, unnest); } } } else { diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala index b7b0889dc1eb..da5625cd45e5 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.hive.HiveUDFTransformer import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -trait Transformable extends Unevaluable { +trait Transformable { def getTransformer(childrenTransformers: Seq[ExpressionTransformer]): ExpressionTransformer } From e30006464e507744a7e433718f5778bb2d58856f Mon Sep 17 00:00:00 2001 From: Rong Ma Date: Wed, 26 Jun 2024 15:36:36 +0800 Subject: [PATCH 17/30] [VL] Daily Update Velox Version (2024_06_26) (#6223) --- ep/build-velox/src/get_velox.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ep/build-velox/src/get_velox.sh b/ep/build-velox/src/get_velox.sh index 06998787d45e..a96719dc10fc 100755 --- a/ep/build-velox/src/get_velox.sh +++ b/ep/build-velox/src/get_velox.sh @@ -17,7 +17,7 @@ set -exu VELOX_REPO=https://github.com/oap-project/velox.git -VELOX_BRANCH=2024_06_25 +VELOX_BRANCH=2024_06_26 VELOX_HOME="" #Set on run gluten on HDFS From a51f6931007256eec10b1a7ae69ef39554ce52f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E6=89=AC?= <654010905@qq.com> Date: Wed, 26 Jun 2024 15:51:18 +0800 Subject: [PATCH 18/30] [GLUTEN-6208][CH] Enable more uts in GlutenStringExpressionsSuite (#6218) --- .../Parser/SerializedPlanParser.cpp | 13 --- .../Parser/SerializedPlanParser.h | 2 - .../Parser/scalar_function_parser/concat.cpp | 79 +++++++++++++++++++ .../clickhouse/ClickHouseTestSettings.scala | 9 --- .../clickhouse/ClickHouseTestSettings.scala | 61 -------------- .../clickhouse/ClickHouseTestSettings.scala | 12 --- .../clickhouse/ClickHouseTestSettings.scala | 12 --- 7 files changed, 79 insertions(+), 109 deletions(-) create mode 100644 cpp-ch/local-engine/Parser/scalar_function_parser/concat.cpp diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp index 3115950cdf09..325ec32dc65f 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp @@ -664,19 +664,6 @@ SerializedPlanParser::getFunctionName(const std::string & function_signature, co else ch_function_name = "reverseUTF8"; } - else if (function_name == "concat") - { - /// 1. ConcatOverloadResolver cannot build arrayConcat for Nullable(Array) type which causes failures when using functions like concat(split()). - /// So we use arrayConcat directly if the output type is array. - /// 2. 
CH ConcatImpl can only accept at least 2 arguments, but Spark concat can accept 1 argument, like concat('a') - /// in such case we use identity function - if (function.output_type().has_list()) - ch_function_name = "arrayConcat"; - else if (args.size() == 1) - ch_function_name = "identity"; - else - ch_function_name = "concat"; - } else ch_function_name = SCALAR_FUNCTIONS.at(function_name); diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.h b/cpp-ch/local-engine/Parser/SerializedPlanParser.h index aa18197e5647..6ce92b558b73 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.h +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.h @@ -127,13 +127,11 @@ static const std::map SCALAR_FUNCTIONS {"trim", ""}, // trimLeft or trimLeftSpark, depends on argument size {"ltrim", ""}, // trimRight or trimRightSpark, depends on argument size {"rtrim", ""}, // trimBoth or trimBothSpark, depends on argument size - {"concat", ""}, /// dummy mapping {"strpos", "positionUTF8"}, {"char_length", "char_length"}, /// Notice: when input argument is binary type, corresponding ch function is length instead of char_length {"replace", "replaceAll"}, {"regexp_replace", "replaceRegexpAll"}, - // {"regexp_extract", "regexpExtract"}, {"regexp_extract_all", "regexpExtractAllSpark"}, {"chr", "char"}, {"rlike", "match"}, diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/concat.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/concat.cpp new file mode 100644 index 000000000000..416fe7741812 --- /dev/null +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/concat.cpp @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int BAD_ARGUMENTS; + extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; +} +} + +namespace local_engine +{ + +class FunctionParserConcat : public FunctionParser +{ +public: + explicit FunctionParserConcat(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {} + ~FunctionParserConcat() override = default; + + static constexpr auto name = "concat"; + + String getName() const override { return name; } + + const ActionsDAG::Node * parse( + const substrait::Expression_ScalarFunction & substrait_func, + ActionsDAGPtr & actions_dag) const override + { + /* + parse concat(args) as: + 1. if output type is array, return arrayConcat(args) + 2. 
otherwise: + 1) if args is empty, return empty string + 2) if args have size 1, return identity(args[0]) + 3) otherwise return concat(args) + */ + auto args = parseFunctionArguments(substrait_func, "", actions_dag); + const auto & output_type = substrait_func.output_type(); + const ActionsDAG::Node * result_node = nullptr; + if (output_type.has_list()) + { + result_node = toFunctionNode(actions_dag, "arrayConcat", args); + } + else + { + if (args.empty()) + result_node = addColumnToActionsDAG(actions_dag, std::make_shared(), ""); + else if (args.size() == 1) + result_node = toFunctionNode(actions_dag, "identity", args); + else + result_node = toFunctionNode(actions_dag, "concat", args); + } + return convertNodeTypeIfNeeded(substrait_func, result_node, actions_dag); + } +}; + +static FunctionParserRegister register_concat; +} diff --git a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index 1626716805cb..d12a40b764f8 100644 --- a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -437,7 +437,6 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("string regex_replace / regex_extract") .exclude("string overlay function") .exclude("binary overlay function") - .exclude("string / binary substring function") .exclude("string parse_url function") enableSuite[GlutenSubquerySuite] .exclude("SPARK-15370: COUNT bug in subquery in subquery in subquery") @@ -894,7 +893,6 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("SPARK-34814: LikeSimplification should handle NULL") enableSuite[GlutenSortOrderExpressionsSuite].exclude("SortPrefix") enableSuite[GlutenStringExpressionsSuite] - .exclude("concat") .exclude("StringComparison") .exclude("Substring") .exclude("string substring_index function") @@ -902,22 +900,15 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("string for ascii") .exclude("base64/unbase64 for string") .exclude("encode/decode for string") - .exclude("soundex unit test") - .exclude("replace") .exclude("overlay for string") .exclude("overlay for byte array") .exclude("translate") - .exclude("FORMAT") - .exclude("SPARK-22603: FormatString should not generate codes beyond 64KB") - .exclude("INSTR") .exclude("LOCATE") .exclude("LPAD/RPAD") .exclude("REPEAT") .exclude("length for string / binary") - .exclude("format_number / FormatNumber") .exclude("ParseUrl") .exclude("SPARK-33468: ParseUrl in ANSI mode should fail if input string is not a valid url") - .exclude("Sentences") .excludeGlutenTest("SPARK-40213: ascii for Latin-1 Supplement characters") enableSuite[GlutenTryCastSuite] .exclude("null cast") diff --git a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index 3147c7c3dbf3..52e7ebcbda49 100644 --- a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -458,7 +458,6 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("string regex_replace / regex_extract") .exclude("string overlay function") .exclude("binary overlay 
function") - .exclude("string / binary substring function") .exclude("string parse_url function") enableSuite[GlutenSubquerySuite] .exclude("SPARK-15370: COUNT bug in subquery in subquery in subquery") @@ -474,58 +473,9 @@ class ClickHouseTestSettings extends BackendTestSettings { enableSuite[GlutenXPathFunctionsSuite] enableSuite[QueryTestSuite] enableSuite[GlutenAnsiCastSuiteWithAnsiModeOff] - .exclude("null cast") .exclude("cast string to date") - .exclude("cast string to timestamp") - .exclude("cast from boolean") - .exclude("cast from int") - .exclude("cast from long") - .exclude("cast from float") - .exclude("cast from double") - .exclude("cast from timestamp") - .exclude("data type casting") - .exclude("cast and add") - .exclude("from decimal") - .exclude("cast from array") - .exclude("cast from map") - .exclude("cast from struct") - .exclude("cast struct with a timestamp field") - .exclude("cast between string and interval") - .exclude("cast string to boolean") - .exclude("SPARK-20302 cast with same structure") - .exclude("SPARK-22500: cast for struct should not generate codes beyond 64KB") - .exclude("SPARK-27671: cast from nested null type in struct") - .exclude("Process Infinity, -Infinity, NaN in case insensitive manner") - .exclude("SPARK-22825 Cast array to string") - .exclude("SPARK-33291: Cast array with null elements to string") - .exclude("SPARK-22973 Cast map to string") - .exclude("SPARK-22981 Cast struct to string") - .exclude("SPARK-33291: Cast struct with null elements to string") - .exclude("SPARK-34667: cast year-month interval to string") - .exclude("SPARK-34668: cast day-time interval to string") - .exclude("SPARK-35698: cast timestamp without time zone to string") .exclude("SPARK-35711: cast timestamp without time zone to timestamp with local time zone") - .exclude("SPARK-35716: cast timestamp without time zone to date type") - .exclude("SPARK-35718: cast date type to timestamp without timezone") - .exclude("SPARK-35719: cast timestamp with local time zone to timestamp without timezone") - .exclude("SPARK-35720: cast string to timestamp without timezone") - .exclude("SPARK-35112: Cast string to day-time interval") - .exclude("SPARK-35111: Cast string to year-month interval") - .exclude("SPARK-35820: Support cast DayTimeIntervalType in different fields") .exclude("SPARK-35819: Support cast YearMonthIntervalType in different fields") - .exclude("SPARK-35768: Take into account year-month interval fields in cast") - .exclude("SPARK-35735: Take into account day-time interval fields in cast") - .exclude("ANSI mode: Throw exception on casting out-of-range value to byte type") - .exclude("ANSI mode: Throw exception on casting out-of-range value to short type") - .exclude("ANSI mode: Throw exception on casting out-of-range value to int type") - .exclude("ANSI mode: Throw exception on casting out-of-range value to long type") - .exclude("Fast fail for cast string type to decimal type in ansi mode") - .exclude("cast a timestamp before the epoch 1970-01-01 00:00:00Z") - .exclude("cast from array III") - .exclude("cast from map II") - .exclude("cast from map III") - .exclude("cast from struct II") - .exclude("cast from struct III") enableSuite[GlutenAnsiCastSuiteWithAnsiModeOn] .exclude("null cast") .exclude("cast string to date") @@ -902,7 +852,6 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("SPARK - 34814: LikeSimplification should handleNULL") enableSuite[GlutenSortOrderExpressionsSuite].exclude("SortPrefix") 
enableSuite[GlutenStringExpressionsSuite] - .exclude("concat") .exclude("StringComparison") .exclude("Substring") .exclude("string substring_index function") @@ -911,24 +860,14 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("string for ascii") .exclude("base64/unbase64 for string") .exclude("encode/decode for string") - .exclude("soundex unit test") - .exclude("replace") .exclude("overlay for string") .exclude("overlay for byte array") .exclude("translate") - .exclude("FORMAT") - .exclude("SPARK-22603: FormatString should not generate codes beyond 64KB") - .exclude("INSTR") .exclude("LOCATE") - .exclude("LPAD/RPAD") .exclude("REPEAT") .exclude("length for string / binary") - .exclude("format_number / FormatNumber") - .exclude("ToNumber: positive tests") - .exclude("ToNumber: negative tests (the input string does not match the format string)") .exclude("ParseUrl") .exclude("SPARK-33468: ParseUrl in ANSI mode should fail if input string is not a valid url") - .exclude("Sentences") enableSuite[GlutenTryCastSuite] .exclude("null cast") .exclude("cast string to date") diff --git a/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index 07af1fa845ca..38ed2c53463b 100644 --- a/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -457,7 +457,6 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("string regex_replace / regex_extract") .exclude("string overlay function") .exclude("binary overlay function") - .exclude("string / binary substring function") .exclude("string parse_url function") enableSuite[GlutenSubquerySuite] .exclude("SPARK-15370: COUNT bug in subquery in subquery in subquery") @@ -756,7 +755,6 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("SPARK - 34814: LikeSimplification should handleNULL") enableSuite[GlutenSortOrderExpressionsSuite].exclude("SortPrefix") enableSuite[GlutenStringExpressionsSuite] - .exclude("concat") .exclude("StringComparison") .exclude("Substring") .exclude("string substring_index function") @@ -766,24 +764,14 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("base64/unbase64 for string") .exclude("encode/decode for string") .exclude("Levenshtein distance") - .exclude("soundex unit test") - .exclude("replace") .exclude("overlay for string") .exclude("overlay for byte array") .exclude("translate") - .exclude("FORMAT") - .exclude("SPARK-22603: FormatString should not generate codes beyond 64KB") - .exclude("INSTR") .exclude("LOCATE") - .exclude("LPAD/RPAD") .exclude("REPEAT") .exclude("length for string / binary") - .exclude("format_number / FormatNumber") - .exclude("ToNumber: positive tests") - .exclude("ToNumber: negative tests (the input string does not match the format string)") .exclude("ParseUrl") .exclude("SPARK-33468: ParseUrl in ANSI mode should fail if input string is not a valid url") - .exclude("Sentences") enableSuite[GlutenDataSourceV2DataFrameSessionCatalogSuite] enableSuite[GlutenDataSourceV2SQLSessionCatalogSuite] enableSuite[GlutenDataSourceV2SQLSuiteV1Filter] diff --git a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala 
index 07af1fa845ca..38ed2c53463b 100644 --- a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -457,7 +457,6 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("string regex_replace / regex_extract") .exclude("string overlay function") .exclude("binary overlay function") - .exclude("string / binary substring function") .exclude("string parse_url function") enableSuite[GlutenSubquerySuite] .exclude("SPARK-15370: COUNT bug in subquery in subquery in subquery") @@ -756,7 +755,6 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("SPARK - 34814: LikeSimplification should handleNULL") enableSuite[GlutenSortOrderExpressionsSuite].exclude("SortPrefix") enableSuite[GlutenStringExpressionsSuite] - .exclude("concat") .exclude("StringComparison") .exclude("Substring") .exclude("string substring_index function") @@ -766,24 +764,14 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("base64/unbase64 for string") .exclude("encode/decode for string") .exclude("Levenshtein distance") - .exclude("soundex unit test") - .exclude("replace") .exclude("overlay for string") .exclude("overlay for byte array") .exclude("translate") - .exclude("FORMAT") - .exclude("SPARK-22603: FormatString should not generate codes beyond 64KB") - .exclude("INSTR") .exclude("LOCATE") - .exclude("LPAD/RPAD") .exclude("REPEAT") .exclude("length for string / binary") - .exclude("format_number / FormatNumber") - .exclude("ToNumber: positive tests") - .exclude("ToNumber: negative tests (the input string does not match the format string)") .exclude("ParseUrl") .exclude("SPARK-33468: ParseUrl in ANSI mode should fail if input string is not a valid url") - .exclude("Sentences") enableSuite[GlutenDataSourceV2DataFrameSessionCatalogSuite] enableSuite[GlutenDataSourceV2SQLSessionCatalogSuite] enableSuite[GlutenDataSourceV2SQLSuiteV1Filter] From 774c66830ba813a5c6231cb1dd504cdd0c862e75 Mon Sep 17 00:00:00 2001 From: KevinyhZou <37431499+KevinyhZou@users.noreply.github.com> Date: Wed, 26 Jun 2024 16:59:48 +0800 Subject: [PATCH 19/30] [GLUTEN-6124][CH]Fix json output diff (#6125) What changes were proposed in this pull request? (Please fill in changes proposed in this fix) (Fixes: #6124) How was this patch tested? 
TEST BY UT --- .../execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala | 5 +++++ cpp-ch/local-engine/Common/CHUtil.cpp | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala index 5040153320fc..118f8418609d 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala @@ -2048,10 +2048,15 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr """ |select to_json(struct(cast(id as string), id, 1.1, 1.1f, 1.1d)) from range(3) |""".stripMargin + val sql1 = + """ + | select to_json(named_struct('name', concat('/val/', id))) from range(3) + |""".stripMargin // cast('nan' as double) output 'NaN' in Spark, 'nan' in CH // cast('inf' as double) output 'Infinity' in Spark, 'inf' in CH // ignore them temporarily runQueryAndCompare(sql)(checkGlutenOperatorMatch[ProjectExecTransformer]) + runQueryAndCompare(sql1)(checkGlutenOperatorMatch[ProjectExecTransformer]) } test("GLUTEN-3501: test json output format with struct contains null value") { diff --git a/cpp-ch/local-engine/Common/CHUtil.cpp b/cpp-ch/local-engine/Common/CHUtil.cpp index 588cc1cb2599..148e78bfbc79 100644 --- a/cpp-ch/local-engine/Common/CHUtil.cpp +++ b/cpp-ch/local-engine/Common/CHUtil.cpp @@ -672,7 +672,6 @@ void BackendInitializerUtil::initSettings(std::map & b LOG_DEBUG(&Poco::Logger::get("CHUtil"), "Set settings key:{} value:{}", key, value); } } - /// Finally apply some fixed kvs to settings. settings.set("join_use_nulls", true); settings.set("input_format_orc_allow_missing_columns", true); @@ -694,6 +693,7 @@ void BackendInitializerUtil::initSettings(std::map & b settings.set("output_format_json_quote_64bit_integers", false); settings.set("output_format_json_quote_denormals", true); settings.set("output_format_json_skip_null_value_in_named_tuples", true); + settings.set("output_format_json_escape_forward_slashes", false); settings.set("function_json_value_return_type_allow_complex", true); settings.set("function_json_value_return_type_allow_nullable", true); settings.set("precise_float_parsing", true); From 10a663c2b86c73490cdaee1d94177cb485c9fe31 Mon Sep 17 00:00:00 2001 From: KevinyhZou <37431499+KevinyhZou@users.noreply.github.com> Date: Wed, 26 Jun 2024 17:02:00 +0800 Subject: [PATCH 20/30] [GLUTEN-6156][CH]Fix least diff (#6155) What changes were proposed in this pull request? (Please fill in changes proposed in this fix) (Fixes: #6156) How was this patch tested? 
test by ut --- ...enClickHouseTPCHSaltNullParquetSuite.scala | 4 +- .../Functions/FunctionGreatestLeast.h | 77 +++++++++++++++++++ .../Functions/SparkFunctionGreatest.cpp | 47 ++--------- .../Functions/SparkFunctionLeast.cpp | 38 +++++++++ .../Parser/SerializedPlanParser.h | 2 +- 5 files changed, 123 insertions(+), 45 deletions(-) create mode 100644 cpp-ch/local-engine/Functions/FunctionGreatestLeast.h create mode 100644 cpp-ch/local-engine/Functions/SparkFunctionLeast.cpp diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala index 118f8418609d..188995f11058 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala @@ -2575,12 +2575,12 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr spark.sql("drop table test_tbl_5096") } - test("GLUTEN-5896: Bug fix greatest diff") { + test("GLUTEN-5896: Bug fix greatest/least diff") { val tbl_create_sql = "create table test_tbl_5896(id bigint, x1 int, x2 int, x3 int) using parquet" val tbl_insert_sql = "insert into test_tbl_5896 values(1, 12, NULL, 13), (2, NULL, NULL, NULL), (3, 11, NULL, NULL), (4, 10, 9, 8)" - val select_sql = "select id, greatest(x1, x2, x3) from test_tbl_5896" + val select_sql = "select id, greatest(x1, x2, x3), least(x1, x2, x3) from test_tbl_5896" spark.sql(tbl_create_sql) spark.sql(tbl_insert_sql) compareResultsAgainstVanillaSpark(select_sql, true, { _ => }) diff --git a/cpp-ch/local-engine/Functions/FunctionGreatestLeast.h b/cpp-ch/local-engine/Functions/FunctionGreatestLeast.h new file mode 100644 index 000000000000..6930c1d75b79 --- /dev/null +++ b/cpp-ch/local-engine/Functions/FunctionGreatestLeast.h @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include +#include + +namespace DB +{ +namespace ErrorCodes +{ + extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; +} +} +namespace local_engine +{ +template +class FunctionGreatestestLeast : public DB::FunctionLeastGreatestGeneric +{ +public: + bool useDefaultImplementationForNulls() const override { return false; } + virtual String getName() const = 0; + +private: + DB::DataTypePtr getReturnTypeImpl(const DB::DataTypes & types) const override + { + if (types.empty()) + throw DB::Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} cannot be called without arguments", getName()); + return makeNullable(getLeastSupertype(types)); + } + + DB::ColumnPtr executeImpl(const DB::ColumnsWithTypeAndName & arguments, const DB::DataTypePtr & result_type, size_t input_rows_count) const override + { + size_t num_arguments = arguments.size(); + DB::Columns converted_columns(num_arguments); + for (size_t arg = 0; arg < num_arguments; ++arg) + converted_columns[arg] = castColumn(arguments[arg], result_type)->convertToFullColumnIfConst(); + auto result_column = result_type->createColumn(); + result_column->reserve(input_rows_count); + for (size_t row_num = 0; row_num < input_rows_count; ++row_num) + { + size_t best_arg = 0; + for (size_t arg = 1; arg < num_arguments; ++arg) + { + if constexpr (kind == DB::LeastGreatest::Greatest) + { + auto cmp_result = converted_columns[arg]->compareAt(row_num, row_num, *converted_columns[best_arg], -1); + if (cmp_result > 0) + best_arg = arg; + } + else + { + auto cmp_result = converted_columns[arg]->compareAt(row_num, row_num, *converted_columns[best_arg], 1); + if (cmp_result < 0) + best_arg = arg; + } + } + result_column->insertFrom(*converted_columns[best_arg], row_num); + } + return result_column; + } +}; + +} diff --git a/cpp-ch/local-engine/Functions/SparkFunctionGreatest.cpp b/cpp-ch/local-engine/Functions/SparkFunctionGreatest.cpp index 9577d65ec5f7..920fe1b9c9cc 100644 --- a/cpp-ch/local-engine/Functions/SparkFunctionGreatest.cpp +++ b/cpp-ch/local-engine/Functions/SparkFunctionGreatest.cpp @@ -14,58 +14,21 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -#include -#include -#include - -namespace DB -{ -namespace ErrorCodes -{ - extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; -} -} +#include namespace local_engine { -class SparkFunctionGreatest : public DB::FunctionLeastGreatestGeneric +class SparkFunctionGreatest : public FunctionGreatestestLeast { public: static constexpr auto name = "sparkGreatest"; static DB::FunctionPtr create(DB::ContextPtr) { return std::make_shared(); } SparkFunctionGreatest() = default; ~SparkFunctionGreatest() override = default; - bool useDefaultImplementationForNulls() const override { return false; } - -private: - DB::DataTypePtr getReturnTypeImpl(const DB::DataTypes & types) const override - { - if (types.empty()) - throw DB::Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} cannot be called without arguments", name); - return makeNullable(getLeastSupertype(types)); - } - - DB::ColumnPtr executeImpl(const DB::ColumnsWithTypeAndName & arguments, const DB::DataTypePtr & result_type, size_t input_rows_count) const override + String getName() const override { - size_t num_arguments = arguments.size(); - DB::Columns converted_columns(num_arguments); - for (size_t arg = 0; arg < num_arguments; ++arg) - converted_columns[arg] = castColumn(arguments[arg], result_type)->convertToFullColumnIfConst(); - auto result_column = result_type->createColumn(); - result_column->reserve(input_rows_count); - for (size_t row_num = 0; row_num < input_rows_count; ++row_num) - { - size_t best_arg = 0; - for (size_t arg = 1; arg < num_arguments; ++arg) - { - auto cmp_result = converted_columns[arg]->compareAt(row_num, row_num, *converted_columns[best_arg], -1); - if (cmp_result > 0) - best_arg = arg; - } - result_column->insertFrom(*converted_columns[best_arg], row_num); - } - return result_column; - } + return name; + } }; REGISTER_FUNCTION(SparkGreatest) diff --git a/cpp-ch/local-engine/Functions/SparkFunctionLeast.cpp b/cpp-ch/local-engine/Functions/SparkFunctionLeast.cpp new file mode 100644 index 000000000000..70aafdf07209 --- /dev/null +++ b/cpp-ch/local-engine/Functions/SparkFunctionLeast.cpp @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +namespace local_engine +{ +class SparkFunctionLeast : public FunctionGreatestestLeast +{ +public: + static constexpr auto name = "sparkLeast"; + static DB::FunctionPtr create(DB::ContextPtr) { return std::make_shared(); } + SparkFunctionLeast() = default; + ~SparkFunctionLeast() override = default; + String getName() const override + { + return name; + } +}; + +REGISTER_FUNCTION(SparkLeast) +{ + factory.registerFunction(); +} +} diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.h b/cpp-ch/local-engine/Parser/SerializedPlanParser.h index 6ce92b558b73..184065836e65 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.h +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.h @@ -105,7 +105,7 @@ static const std::map SCALAR_FUNCTIONS {"sign", "sign"}, {"radians", "radians"}, {"greatest", "sparkGreatest"}, - {"least", "least"}, + {"least", "sparkLeast"}, {"shiftleft", "bitShiftLeft"}, {"shiftright", "bitShiftRight"}, {"check_overflow", "checkDecimalOverflowSpark"}, From d91a316c3e981b3b67bf4316d42fce29dff2708d Mon Sep 17 00:00:00 2001 From: Rong Ma Date: Wed, 26 Jun 2024 19:30:52 +0800 Subject: [PATCH 21/30] [VL][Minor] Fix udf jni signature mismatch (#6212) * fix udf library path failed to get resolved on yarn-cluster * fix signature * Revert "fix udf library path failed to get resolved on yarn-cluster" This reverts commit 11f774a1cc03ff3c0152dfbcee7aaa450bb6157c. --- .../src/main/java/org/apache/gluten/udf/UdfJniWrapper.java | 4 +--- .../scala/org/apache/spark/sql/expression/UDFResolver.scala | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/backends-velox/src/main/java/org/apache/gluten/udf/UdfJniWrapper.java b/backends-velox/src/main/java/org/apache/gluten/udf/UdfJniWrapper.java index 4b609769b2ab..8bfe8bad5c01 100644 --- a/backends-velox/src/main/java/org/apache/gluten/udf/UdfJniWrapper.java +++ b/backends-velox/src/main/java/org/apache/gluten/udf/UdfJniWrapper.java @@ -18,7 +18,5 @@ public class UdfJniWrapper { - public UdfJniWrapper() {} - - public native void getFunctionSignatures(); + public static native void getFunctionSignatures(); } diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala b/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala index e45e8b6fa6d7..8a549c9b4ea9 100644 --- a/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala +++ b/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala @@ -327,7 +327,7 @@ object UDFResolver extends Logging { case None => Seq.empty case Some(_) => - new UdfJniWrapper().getFunctionSignatures() + UdfJniWrapper.getFunctionSignatures() UDFNames.map { name => From 0800596b5d23c55443fd74fa6461c5264619c6b4 Mon Sep 17 00:00:00 2001 From: Yang Zhang Date: Wed, 26 Jun 2024 20:01:17 +0800 Subject: [PATCH 22/30] [VL] Make jni debug workspace configurable (#6228) --- .../gluten/backendsapi/velox/VeloxListenerApi.scala | 7 +++---- .../org/apache/gluten/vectorized/JniWorkspace.java | 4 ++-- .../main/scala/org/apache/gluten/GlutenConfig.scala | 12 +++++++++--- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala index 41b56804b50b..81f06478cbb6 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala +++ 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala @@ -161,10 +161,9 @@ class VeloxListenerApi extends ListenerApi { private def initialize(conf: SparkConf, isDriver: Boolean): Unit = { SparkDirectoryUtil.init(conf) UDFResolver.resolveUdfConf(conf, isDriver = isDriver) - val debugJni = conf.getBoolean(GlutenConfig.GLUTEN_DEBUG_MODE, defaultValue = false) && - conf.getBoolean(GlutenConfig.GLUTEN_DEBUG_KEEP_JNI_WORKSPACE, defaultValue = false) - if (debugJni) { - JniWorkspace.enableDebug() + if (conf.getBoolean(GlutenConfig.GLUTEN_DEBUG_KEEP_JNI_WORKSPACE, defaultValue = false)) { + val debugDir = conf.get(GlutenConfig.GLUTEN_DEBUG_KEEP_JNI_WORKSPACE_DIR) + JniWorkspace.enableDebug(debugDir) } val loader = JniWorkspace.getDefault.libLoader diff --git a/gluten-core/src/main/java/org/apache/gluten/vectorized/JniWorkspace.java b/gluten-core/src/main/java/org/apache/gluten/vectorized/JniWorkspace.java index a7c12387a221..810c945d35ab 100644 --- a/gluten-core/src/main/java/org/apache/gluten/vectorized/JniWorkspace.java +++ b/gluten-core/src/main/java/org/apache/gluten/vectorized/JniWorkspace.java @@ -75,14 +75,14 @@ private static JniWorkspace createDefault() { } } - public static void enableDebug() { + public static void enableDebug(String debugDir) { // Preserve the JNI libraries even after process exits. // This is useful for debugging native code if the debug symbols were embedded in // the libraries. synchronized (DEFAULT_INSTANCE_INIT_LOCK) { if (DEBUG_INSTANCE == null) { final File tempRoot = - Paths.get("/tmp").resolve("gluten-jni-debug-" + UUID.randomUUID()).toFile(); + Paths.get(debugDir).resolve("gluten-jni-debug-" + UUID.randomUUID()).toFile(); try { FileUtils.forceMkdir(tempRoot); } catch (IOException e) { diff --git a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala index cc2d6ac5fdef..89933cc58a4d 100644 --- a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala +++ b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala @@ -391,8 +391,7 @@ class GlutenConfig(conf: SQLConf) extends Logging { conf.getConf(COLUMNAR_VELOX_MEMORY_USE_HUGE_PAGES) def debug: Boolean = conf.getConf(DEBUG_ENABLED) - def debugKeepJniWorkspace: Boolean = - conf.getConf(DEBUG_ENABLED) && conf.getConf(DEBUG_KEEP_JNI_WORKSPACE) + def debugKeepJniWorkspace: Boolean = conf.getConf(DEBUG_KEEP_JNI_WORKSPACE) def taskStageId: Int = conf.getConf(BENCHMARK_TASK_STAGEID) def taskPartitionId: Int = conf.getConf(BENCHMARK_TASK_PARTITIONID) def taskId: Long = conf.getConf(BENCHMARK_TASK_TASK_ID) @@ -553,6 +552,7 @@ object GlutenConfig { val GLUTEN_DEBUG_MODE = "spark.gluten.sql.debug" val GLUTEN_DEBUG_KEEP_JNI_WORKSPACE = "spark.gluten.sql.debug.keepJniWorkspace" + val GLUTEN_DEBUG_KEEP_JNI_WORKSPACE_DIR = "spark.gluten.sql.debug.keepJniWorkspaceDir" // Added back to Spark Conf during executor initialization val GLUTEN_NUM_TASK_SLOTS_PER_EXECUTOR_KEY = "spark.gluten.numTaskSlotsPerExecutor" @@ -1580,11 +1580,17 @@ object GlutenConfig { .createWithDefault(false) val DEBUG_KEEP_JNI_WORKSPACE = - buildConf(GLUTEN_DEBUG_KEEP_JNI_WORKSPACE) + buildStaticConf(GLUTEN_DEBUG_KEEP_JNI_WORKSPACE) .internal() .booleanConf .createWithDefault(false) + val DEBUG_KEEP_JNI_WORKSPACE_DIR = + buildStaticConf(GLUTEN_DEBUG_KEEP_JNI_WORKSPACE_DIR) + .internal() + .stringConf + .createWithDefault("/tmp") + val BENCHMARK_TASK_STAGEID = buildConf("spark.gluten.sql.benchmark_task.stageId") .internal() 
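
A minimal usage sketch for the two configuration keys introduced by the patch above (`spark.gluten.sql.debug.keepJniWorkspace` and `spark.gluten.sql.debug.keepJniWorkspaceDir`). It is not part of the patch: the object name, the `local[1]` master and the `/opt/gluten-debug` path are illustrative only, and it assumes Gluten is already enabled through its usual plugin configuration. Since `keepJniWorkspace` becomes a static conf here, both settings would have to be in place before the SparkSession starts (e.g. via spark-submit or as below):

import org.apache.spark.sql.SparkSession

object KeepJniWorkspaceExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]") // illustrative; any deployment mode works
      // Preserve the extracted JNI libraries after the process exits, for native debugging.
      .config("spark.gluten.sql.debug.keepJniWorkspace", "true")
      // Directory that will host the preserved gluten-jni-debug-<uuid> workspace (default: /tmp).
      .config("spark.gluten.sql.debug.keepJniWorkspaceDir", "/opt/gluten-debug") // hypothetical path
      .getOrCreate()

    // ... run queries, then inspect the preserved workspace directory ...
    spark.stop()
  }
}
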
From dc6abe54a246a1f789bfdd54b2bd7c1f2bf239ab Mon Sep 17 00:00:00 2001 From: PHILO-HE Date: Thu, 27 Jun 2024 09:02:35 +0800 Subject: [PATCH 23/30] [VL] Link lib jemalloc produced by custom building (#4747) Co-authored-by: BInwei Yang --- cpp/CMake/Buildjemalloc_pic.cmake | 74 +++++++++++++++++ cpp/CMake/Findjemalloc_pic.cmake | 78 +++++------------- cpp/core/CMakeLists.txt | 10 --- cpp/velox/CMakeLists.txt | 11 +++ cpp/velox/memory/VeloxMemoryManager.cc | 7 ++ dev/vcpkg/CONTRIBUTING.md | 6 +- .../ports/jemalloc/fix-configure-ac.patch | 13 +++ dev/vcpkg/ports/jemalloc/portfile.cmake | 79 +++++++++++++++++++ dev/vcpkg/ports/jemalloc/preprocessor.patch | 12 +++ dev/vcpkg/ports/jemalloc/vcpkg.json | 8 ++ docs/get-started/build-guide.md | 2 +- 11 files changed, 226 insertions(+), 74 deletions(-) create mode 100644 cpp/CMake/Buildjemalloc_pic.cmake create mode 100644 dev/vcpkg/ports/jemalloc/fix-configure-ac.patch create mode 100644 dev/vcpkg/ports/jemalloc/portfile.cmake create mode 100644 dev/vcpkg/ports/jemalloc/preprocessor.patch create mode 100644 dev/vcpkg/ports/jemalloc/vcpkg.json diff --git a/cpp/CMake/Buildjemalloc_pic.cmake b/cpp/CMake/Buildjemalloc_pic.cmake new file mode 100644 index 000000000000..7c2316ea9540 --- /dev/null +++ b/cpp/CMake/Buildjemalloc_pic.cmake @@ -0,0 +1,74 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Building Jemalloc +macro(build_jemalloc) + message(STATUS "Building Jemalloc from Source") + + if(DEFINED ENV{GLUTEN_JEMALLOC_URL}) + set(JEMALLOC_SOURCE_URL "$ENV{GLUTEN_JEMALLOC_URL}") + else() + set(JEMALLOC_BUILD_VERSION "5.2.1") + set(JEMALLOC_SOURCE_URL + "https://github.com/jemalloc/jemalloc/releases/download/${JEMALLOC_BUILD_VERSION}/jemalloc-${JEMALLOC_BUILD_VERSION}.tar.bz2" + "https://github.com/ursa-labs/thirdparty/releases/download/latest/jemalloc-${JEMALLOC_BUILD_VERSION}.tar.bz2" + ) + endif() + + set(JEMALLOC_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/jemalloc_ep-install") + set(JEMALLOC_LIB_DIR "${JEMALLOC_PREFIX}/lib") + set(JEMALLOC_INCLUDE_DIR "${JEMALLOC_PREFIX}/include") + set(JEMALLOC_STATIC_LIB + "${JEMALLOC_LIB_DIR}/${CMAKE_STATIC_LIBRARY_PREFIX}jemalloc_pic${CMAKE_STATIC_LIBRARY_SUFFIX}" + ) + set(JEMALLOC_INCLUDE "${JEMALLOC_PREFIX}/include") + set(JEMALLOC_CONFIGURE_ARGS + "AR=${CMAKE_AR}" + "CC=${CMAKE_C_COMPILER}" + "--prefix=${JEMALLOC_PREFIX}" + "--libdir=${JEMALLOC_LIB_DIR}" + "--with-jemalloc-prefix=je_gluten_" + "--with-private-namespace=je_gluten_private_" + "--without-export" + "--disable-shared" + "--disable-cxx" + "--disable-libdl" + # For fixing an issue when loading native lib: cannot allocate memory in + # static TLS block. 
+ "--disable-initial-exec-tls" + "CFLAGS=-fPIC" + "CXXFLAGS=-fPIC") + set(JEMALLOC_BUILD_COMMAND ${MAKE} ${MAKE_BUILD_ARGS}) + ExternalProject_Add( + jemalloc_ep + URL ${JEMALLOC_SOURCE_URL} + PATCH_COMMAND touch doc/jemalloc.3 doc/jemalloc.html + CONFIGURE_COMMAND "./configure" ${JEMALLOC_CONFIGURE_ARGS} + BUILD_COMMAND ${JEMALLOC_BUILD_COMMAND} + BUILD_IN_SOURCE 1 + BUILD_BYPRODUCTS "${JEMALLOC_STATIC_LIB}" + INSTALL_COMMAND make install) + + file(MAKE_DIRECTORY "${JEMALLOC_INCLUDE_DIR}") + add_library(jemalloc::libjemalloc STATIC IMPORTED) + set_target_properties( + jemalloc::libjemalloc + PROPERTIES INTERFACE_LINK_LIBRARIES Threads::Threads + IMPORTED_LOCATION "${JEMALLOC_STATIC_LIB}" + INTERFACE_INCLUDE_DIRECTORIES "${JEMALLOC_INCLUDE_DIR}") + add_dependencies(jemalloc::libjemalloc jemalloc_ep) +endmacro() diff --git a/cpp/CMake/Findjemalloc_pic.cmake b/cpp/CMake/Findjemalloc_pic.cmake index fae9f0d7ad80..ca7b7d213dfc 100644 --- a/cpp/CMake/Findjemalloc_pic.cmake +++ b/cpp/CMake/Findjemalloc_pic.cmake @@ -17,67 +17,25 @@ # Find Jemalloc macro(find_jemalloc) - # Find the existing Protobuf + # Find the existing jemalloc set(CMAKE_FIND_LIBRARY_SUFFIXES ".a") - find_package(jemalloc_pic) - if("${Jemalloc_LIBRARY}" STREQUAL "Jemalloc_LIBRARY-NOTFOUND") - message(FATAL_ERROR "Jemalloc Library Not Found") - endif() - set(PROTOC_BIN ${Jemalloc_PROTOC_EXECUTABLE}) -endmacro() - -# Building Jemalloc -macro(build_jemalloc) - message(STATUS "Building Jemalloc from Source") - - if(DEFINED ENV{GLUTEN_JEMALLOC_URL}) - set(JEMALLOC_SOURCE_URL "$ENV{GLUTEN_JEMALLOC_URL}") + # Find from vcpkg-installed lib path. + find_library( + JEMALLOC_LIBRARY + NAMES jemalloc_pic + PATHS + ${CMAKE_CURRENT_BINARY_DIR}/../../../dev/vcpkg/vcpkg_installed/x64-linux-avx/lib/ + NO_DEFAULT_PATH) + if("${JEMALLOC_LIBRARY}" STREQUAL "JEMALLOC_LIBRARY-NOTFOUND") + message(STATUS "Jemalloc Library Not Found.") + set(JEMALLOC_NOT_FOUND TRUE) else() - set(JEMALLOC_BUILD_VERSION "5.2.1") - set(JEMALLOC_SOURCE_URL - "https://github.com/jemalloc/jemalloc/releases/download/${JEMALLOC_BUILD_VERSION}/jemalloc-${JEMALLOC_BUILD_VERSION}.tar.bz2" - "https://github.com/ursa-labs/thirdparty/releases/download/latest/jemalloc-${JEMALLOC_BUILD_VERSION}.tar.bz2" - ) + message(STATUS "Found jemalloc: ${JEMALLOC_LIBRARY}") + find_path(JEMALLOC_INCLUDE_DIR jemalloc/jemalloc.h) + add_library(jemalloc::libjemalloc STATIC IMPORTED) + set_target_properties( + jemalloc::libjemalloc + PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "${JEMALLOC_INCLUDE_DIR}" + IMPORTED_LOCATION "${JEMALLOC_LIBRARY}") endif() - - set(JEMALLOC_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/jemalloc_ep-install") - set(JEMALLOC_LIB_DIR "${JEMALLOC_PREFIX}/lib") - set(JEMALLOC_INCLUDE_DIR "${JEMALLOC_PREFIX}/include") - set(JEMALLOC_STATIC_LIB - "${JEMALLOC_LIB_DIR}/${CMAKE_STATIC_LIBRARY_PREFIX}jemalloc_pic${CMAKE_STATIC_LIBRARY_SUFFIX}" - ) - set(JEMALLOC_INCLUDE "${JEMALLOC_PREFIX}/include") - set(JEMALLOC_CONFIGURE_ARGS - "AR=${CMAKE_AR}" - "CC=${CMAKE_C_COMPILER}" - "--prefix=${JEMALLOC_PREFIX}" - "--libdir=${JEMALLOC_LIB_DIR}" - "--with-jemalloc-prefix=je_gluten_" - "--with-private-namespace=je_gluten_private_" - "--without-export" - "--disable-shared" - "--disable-cxx" - "--disable-libdl" - "--disable-initial-exec-tls" - "CFLAGS=-fPIC" - "CXXFLAGS=-fPIC") - set(JEMALLOC_BUILD_COMMAND ${MAKE} ${MAKE_BUILD_ARGS}) - ExternalProject_Add( - jemalloc_ep - URL ${JEMALLOC_SOURCE_URL} - PATCH_COMMAND touch doc/jemalloc.3 doc/jemalloc.html - CONFIGURE_COMMAND "./configure" 
${JEMALLOC_CONFIGURE_ARGS} - BUILD_COMMAND ${JEMALLOC_BUILD_COMMAND} - BUILD_IN_SOURCE 1 - BUILD_BYPRODUCTS "${JEMALLOC_STATIC_LIB}" - INSTALL_COMMAND make install) - - file(MAKE_DIRECTORY "${JEMALLOC_INCLUDE_DIR}") - add_library(jemalloc::libjemalloc STATIC IMPORTED) - set_target_properties( - jemalloc::libjemalloc - PROPERTIES INTERFACE_LINK_LIBRARIES Threads::Threads - IMPORTED_LOCATION "${JEMALLOC_STATIC_LIB}" - INTERFACE_INCLUDE_DIRECTORIES "${JEMALLOC_INCLUDE_DIR}") - add_dependencies(jemalloc::libjemalloc protobuf_ep) endmacro() diff --git a/cpp/core/CMakeLists.txt b/cpp/core/CMakeLists.txt index 4d7c30402985..e17d13581105 100644 --- a/cpp/core/CMakeLists.txt +++ b/cpp/core/CMakeLists.txt @@ -300,16 +300,6 @@ target_include_directories( set_target_properties(gluten PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${root_directory}/releases) -include(Findjemalloc_pic) -# Build Jemalloc -if(BUILD_JEMALLOC) - build_jemalloc(${STATIC_JEMALLOC}) - message(STATUS "Building Jemalloc: ${STATIC_JEMALLOC}") -else() # - find_jemalloc() - message(STATUS "Use existing Jemalloc libraries") -endif() - if(BUILD_TESTS) add_subdirectory(tests) endif() diff --git a/cpp/velox/CMakeLists.txt b/cpp/velox/CMakeLists.txt index c2d690a7e055..716a5f68a91c 100644 --- a/cpp/velox/CMakeLists.txt +++ b/cpp/velox/CMakeLists.txt @@ -576,6 +576,17 @@ find_package(Folly REQUIRED CONFIG) target_include_directories(velox PUBLIC ${GTEST_INCLUDE_DIRS} ${PROTOBUF_INCLUDE}) +if(BUILD_JEMALLOC) + include(Findjemalloc_pic) + find_jemalloc() + if(JEMALLOC_NOT_FOUND) + include(Buildjemalloc_pic) + build_jemalloc() + endif() + add_definitions(-DENABLE_JEMALLOC) + target_link_libraries(velox PUBLIC jemalloc::libjemalloc) +endif() + target_link_libraries(velox PUBLIC gluten) add_velox_dependencies() diff --git a/cpp/velox/memory/VeloxMemoryManager.cc b/cpp/velox/memory/VeloxMemoryManager.cc index 733eb4c4bc39..efd165b736be 100644 --- a/cpp/velox/memory/VeloxMemoryManager.cc +++ b/cpp/velox/memory/VeloxMemoryManager.cc @@ -16,6 +16,10 @@ */ #include "VeloxMemoryManager.h" +#ifdef ENABLE_JEMALLOC +#include +#endif + #include "velox/common/memory/MallocAllocator.h" #include "velox/common/memory/MemoryPool.h" #include "velox/exec/MemoryReclaimer.h" @@ -326,6 +330,9 @@ VeloxMemoryManager::~VeloxMemoryManager() { usleep(waitMs * 1000); accumulatedWaitMs += waitMs; } +#ifdef ENABLE_JEMALLOC + je_gluten_malloc_stats_print(NULL, NULL, NULL); +#endif } } // namespace gluten diff --git a/dev/vcpkg/CONTRIBUTING.md b/dev/vcpkg/CONTRIBUTING.md index b725f0b50fc5..719bc91db066 100644 --- a/dev/vcpkg/CONTRIBUTING.md +++ b/dev/vcpkg/CONTRIBUTING.md @@ -13,7 +13,7 @@ Please init vcpkg env first: Vcpkg already maintains a lot of libraries. You can find them by vcpkg cli. -(NOTE: Please always use cli beacause [packages on vcpkg.io](https://vcpkg.io/en/packages.html) is outdate). +(NOTE: Please always use cli because [packages on vcpkg.io](https://vcpkg.io/en/packages.html) is outdate). ``` $ ./.vcpkg/vcpkg search folly @@ -28,7 +28,7 @@ folly[zlib] Support zlib for compression folly[zstd] Support zstd for compression ``` -`[...]` means additional features. Then add depend into [vcpkg.json](./vcpkg.json). +`[...]` means additional features. Then add the dependency into [vcpkg.json](./vcpkg.json). ``` json { @@ -144,7 +144,7 @@ See [vcpkg.json reference](https://learn.microsoft.com/en-us/vcpkg/reference/vcp `portfile.cmake` is a cmake script describing how to build and install the package. 
A typical portfile has 3 stages: -**Download and perpare source**: +**Download and prepare source**: ``` cmake # Download from Github diff --git a/dev/vcpkg/ports/jemalloc/fix-configure-ac.patch b/dev/vcpkg/ports/jemalloc/fix-configure-ac.patch new file mode 100644 index 000000000000..7799dfb9e80e --- /dev/null +++ b/dev/vcpkg/ports/jemalloc/fix-configure-ac.patch @@ -0,0 +1,13 @@ +diff --git a/configure.ac b/configure.ac +index f6d25f334..3115504e2 100644 +--- a/configure.ac ++++ b/configure.ac +@@ -1592,7 +1592,7 @@ fi + [enable_uaf_detection="0"] + ) + if test "x$enable_uaf_detection" = "x1" ; then +- AC_DEFINE([JEMALLOC_UAF_DETECTION], [ ]) ++ AC_DEFINE([JEMALLOC_UAF_DETECTION], [ ], ["enable UAF"]) + fi + AC_SUBST([enable_uaf_detection]) + diff --git a/dev/vcpkg/ports/jemalloc/portfile.cmake b/dev/vcpkg/ports/jemalloc/portfile.cmake new file mode 100644 index 000000000000..6cac12ca3b7c --- /dev/null +++ b/dev/vcpkg/ports/jemalloc/portfile.cmake @@ -0,0 +1,79 @@ +vcpkg_from_github( + OUT_SOURCE_PATH SOURCE_PATH + REPO jemalloc/jemalloc + REF 54eaed1d8b56b1aa528be3bdd1877e59c56fa90c + SHA512 527bfbf5db9a5c2b7b04df4785b6ae9d445cff8cb17298bf3e550c88890d2bd7953642d8efaa417580610508279b527d3a3b9e227d17394fd2013c88cb7ae75a + HEAD_REF master + PATCHES + fix-configure-ac.patch + preprocessor.patch +) +if(VCPKG_TARGET_IS_WINDOWS) + set(opts "ac_cv_search_log=none required" + "--without-private-namespace" + "--with-jemalloc-prefix=je_gluten_" + "--with-private-namespace=je_gluten_private_" + "--without-export" + "--disable-shared" + "--disable-cxx" + "--disable-libdl" + # For fixing an issue when loading native lib: cannot allocate memory in static TLS block. + "--disable-initial-exec-tls" + "CFLAGS=-fPIC" + "CXXFLAGS=-fPIC") +else() + set(opts + "--with-jemalloc-prefix=je_gluten_" + "--with-private-namespace=je_gluten_private_" + "--without-export" + "--disable-shared" + "--disable-cxx" + "--disable-libdl" + # For fixing an issue when loading native lib: cannot allocate memory in static TLS block. 
+ "--disable-initial-exec-tls" + "CFLAGS=-fPIC" + "CXXFLAGS=-fPIC") +endif() + +vcpkg_configure_make( + SOURCE_PATH "${SOURCE_PATH}" + AUTOCONFIG + NO_WRAPPERS + OPTIONS ${opts} +) + +vcpkg_install_make() + +if(VCPKG_TARGET_IS_WINDOWS) + file(COPY "${SOURCE_PATH}/include/msvc_compat/strings.h" DESTINATION "${CURRENT_PACKAGES_DIR}/include/jemalloc/msvc_compat") + vcpkg_replace_string("${CURRENT_PACKAGES_DIR}/include/jemalloc/jemalloc.h" "" "\"msvc_compat/strings.h\"") + if(VCPKG_LIBRARY_LINKAGE STREQUAL "dynamic") + file(COPY "${CURRENT_BUILDTREES_DIR}/${TARGET_TRIPLET}-rel/lib/jemalloc.lib" DESTINATION "${CURRENT_PACKAGES_DIR}/lib") + file(MAKE_DIRECTORY "${CURRENT_PACKAGES_DIR}/bin") + file(RENAME "${CURRENT_PACKAGES_DIR}/lib/jemalloc.dll" "${CURRENT_PACKAGES_DIR}/bin/jemalloc.dll") + endif() + if(NOT VCPKG_BUILD_TYPE) + if(VCPKG_LIBRARY_LINKAGE STREQUAL "dynamic") + file(COPY "${CURRENT_BUILDTREES_DIR}/${TARGET_TRIPLET}-dbg/lib/jemalloc.lib" DESTINATION "${CURRENT_PACKAGES_DIR}/debug/lib") + file(MAKE_DIRECTORY "${CURRENT_PACKAGES_DIR}/debug/bin") + file(RENAME "${CURRENT_PACKAGES_DIR}/debug/lib/jemalloc.dll" "${CURRENT_PACKAGES_DIR}/debug/bin/jemalloc.dll") + endif() + endif() + if(VCPKG_LIBRARY_LINKAGE STREQUAL "static") + vcpkg_replace_string("${CURRENT_PACKAGES_DIR}/lib/pkgconfig/jemalloc.pc" "install_suffix=" "install_suffix=_s") + if(NOT VCPKG_BUILD_TYPE) + vcpkg_replace_string("${CURRENT_PACKAGES_DIR}/debug/lib/pkgconfig/jemalloc.pc" "install_suffix=" "install_suffix=_s") + endif() + endif() +endif() + +vcpkg_fixup_pkgconfig() + +vcpkg_copy_pdbs() + +file(REMOVE_RECURSE "${CURRENT_PACKAGES_DIR}/debug/include") +file(REMOVE_RECURSE "${CURRENT_PACKAGES_DIR}/debug/share") +file(REMOVE_RECURSE "${CURRENT_PACKAGES_DIR}/tools") + +# Handle copyright +file(INSTALL "${SOURCE_PATH}/COPYING" DESTINATION "${CURRENT_PACKAGES_DIR}/share/${PORT}" RENAME copyright) diff --git a/dev/vcpkg/ports/jemalloc/preprocessor.patch b/dev/vcpkg/ports/jemalloc/preprocessor.patch new file mode 100644 index 000000000000..6e6e2d1403fb --- /dev/null +++ b/dev/vcpkg/ports/jemalloc/preprocessor.patch @@ -0,0 +1,12 @@ +diff --git a/configure.ac b/configure.ac +index 3115504e2..ffb504b08 100644 +--- a/configure.ac ++++ b/configure.ac +@@ -749,6 +749,7 @@ case "${host}" in + so="dll" + if test "x$je_cv_msvc" = "xyes" ; then + importlib="lib" ++ JE_APPEND_VS(CPPFLAGS, -DJEMALLOC_NO_PRIVATE_NAMESPACE) + DSO_LDFLAGS="-LD" + EXTRA_LDFLAGS="-link -DEBUG" + CTARGET='-Fo$@' diff --git a/dev/vcpkg/ports/jemalloc/vcpkg.json b/dev/vcpkg/ports/jemalloc/vcpkg.json new file mode 100644 index 000000000000..007e05b931c9 --- /dev/null +++ b/dev/vcpkg/ports/jemalloc/vcpkg.json @@ -0,0 +1,8 @@ +{ + "name": "jemalloc", + "version": "5.3.0", + "port-version": 1, + "description": "jemalloc is a general purpose malloc(3) implementation that emphasizes fragmentation avoidance and scalable concurrency support", + "homepage": "https://jemalloc.net/", + "license": "BSD-2-Clause" +} diff --git a/docs/get-started/build-guide.md b/docs/get-started/build-guide.md index 3db2244ba229..b2e4b9560301 100644 --- a/docs/get-started/build-guide.md +++ b/docs/get-started/build-guide.md @@ -14,7 +14,7 @@ Please set them via `--`, e.g. `--build_type=Release`. | build_tests | Build gluten cpp tests. | OFF | | build_examples | Build udf example. | OFF | | build_benchmarks | Build gluten cpp benchmarks. | OFF | -| build_jemalloc | Build with jemalloc. | ON | +| build_jemalloc | Build with jemalloc. | OFF | | build_protobuf | Build protobuf lib. 
| ON | | enable_qat | Enable QAT for shuffle data de/compression. | OFF | | enable_iaa | Enable IAA for shuffle data de/compression. | OFF | From ac227ded59b5e1b7913bcd137bd46b52c7532303 Mon Sep 17 00:00:00 2001 From: PHILO-HE Date: Thu, 27 Jun 2024 09:16:17 +0800 Subject: [PATCH 24/30] [VL] Remove the registry for Velox's prestosql scalar functions (#5202) --- .../functions/RegistrationAllFunctions.cc | 16 ++++++++++------ cpp/velox/substrait/SubstraitParser.cc | 12 +----------- 2 files changed, 11 insertions(+), 17 deletions(-) diff --git a/cpp/velox/operators/functions/RegistrationAllFunctions.cc b/cpp/velox/operators/functions/RegistrationAllFunctions.cc index b827690d1cdf..638dbcccff0c 100644 --- a/cpp/velox/operators/functions/RegistrationAllFunctions.cc +++ b/cpp/velox/operators/functions/RegistrationAllFunctions.cc @@ -26,7 +26,6 @@ #include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h" #include "velox/functions/prestosql/registration/RegistrationFunctions.h" #include "velox/functions/prestosql/window/WindowFunctionsRegistration.h" -#include "velox/functions/sparksql/Bitwise.h" #include "velox/functions/sparksql/Hash.h" #include "velox/functions/sparksql/Rand.h" #include "velox/functions/sparksql/Register.h" @@ -35,6 +34,14 @@ using namespace facebook; +namespace facebook::velox::functions { +void registerPrestoVectorFunctions() { + // Presto function. To be removed. + VELOX_REGISTER_VECTOR_FUNCTION(udf_arrays_overlap, "arrays_overlap"); + VELOX_REGISTER_VECTOR_FUNCTION(udf_transform_keys, "transform_keys"); + VELOX_REGISTER_VECTOR_FUNCTION(udf_transform_values, "transform_values"); +} +} // namespace facebook::velox::functions namespace gluten { namespace { void registerFunctionOverwrite() { @@ -67,19 +74,16 @@ void registerFunctionOverwrite() { velox::exec::registerFunctionCallToSpecialForm( kRowConstructorWithAllNull, std::make_unique(kRowConstructorWithAllNull)); - velox::functions::sparksql::registerBitwiseFunctions("spark_"); velox::functions::registerBinaryIntegral({"check_add"}); velox::functions::registerBinaryIntegral({"check_subtract"}); velox::functions::registerBinaryIntegral({"check_multiply"}); velox::functions::registerBinaryIntegral({"check_divide"}); + + velox::functions::registerPrestoVectorFunctions(); } } // namespace void registerAllFunctions() { - // The registration order matters. Spark sql functions are registered after - // presto sql functions to overwrite the registration for same named - // functions. - velox::functions::prestosql::registerAllScalarFunctions(); velox::functions::sparksql::registerFunctions(""); velox::aggregate::prestosql::registerAllAggregateFunctions( "", true /*registerCompanionFunctions*/, false /*onlyPrestoSignatures*/, true /*overwrite*/); diff --git a/cpp/velox/substrait/SubstraitParser.cc b/cpp/velox/substrait/SubstraitParser.cc index 5555ecfef954..b842914ca933 100644 --- a/cpp/velox/substrait/SubstraitParser.cc +++ b/cpp/velox/substrait/SubstraitParser.cc @@ -391,23 +391,13 @@ std::unordered_map SubstraitParser::substraitVeloxFunc {"named_struct", "row_constructor"}, {"bit_or", "bitwise_or_agg"}, {"bit_and", "bitwise_and_agg"}, - {"bitwise_and", "spark_bitwise_and"}, - {"bitwise_not", "spark_bitwise_not"}, - {"bitwise_or", "spark_bitwise_or"}, - {"bitwise_xor", "spark_bitwise_xor"}, - // TODO: the below registry for rand functions can be removed - // after presto function registry is removed. 
- {"rand", "spark_rand"}, {"murmur3hash", "hash_with_seed"}, {"xxhash64", "xxhash64_with_seed"}, {"modulus", "remainder"}, {"date_format", "format_datetime"}, {"collect_set", "set_agg"}, - {"forall", "all_match"}, - {"exists", "any_match"}, {"negative", "unaryminus"}, - {"get_array_item", "get"}, - {"arrays_zip", "zip"}}; + {"get_array_item", "get"}}; const std::unordered_map SubstraitParser::typeMap_ = { {"bool", "BOOLEAN"}, From 32808dd22a0384d0e0bb5011bf2393710a4d5942 Mon Sep 17 00:00:00 2001 From: Kerwin Zhang Date: Thu, 27 Jun 2024 09:58:25 +0800 Subject: [PATCH 25/30] [CELEBORN] Upgrade celeborn to 0.4.1 to support scala 2.13-based compilation (#6226) --- .github/workflows/velox_docker.yml | 6 +++--- docs/get-started/ClickHouse.md | 12 ++++++------ .../gluten/celeborn/CelebornShuffleManager.java | 8 +++++++- .../shuffle/gluten/celeborn/CelebornUtils.java | 14 ++++++++++++-- pom.xml | 2 +- tools/gluten-it/pom.xml | 4 ++-- 6 files changed, 31 insertions(+), 15 deletions(-) diff --git a/.github/workflows/velox_docker.yml b/.github/workflows/velox_docker.yml index 31796c15bdd5..d110d0a6d223 100644 --- a/.github/workflows/velox_docker.yml +++ b/.github/workflows/velox_docker.yml @@ -521,7 +521,7 @@ jobs: fail-fast: false matrix: spark: ["spark-3.2"] - celeborn: ["celeborn-0.4.0", "celeborn-0.3.2"] + celeborn: ["celeborn-0.4.1", "celeborn-0.3.2-incubating"] runs-on: ubuntu-20.04 container: ubuntu:22.04 steps: @@ -557,8 +557,8 @@ jobs: fi echo "EXTRA_PROFILE: ${EXTRA_PROFILE}" cd /opt && mkdir -p celeborn && \ - wget https://archive.apache.org/dist/incubator/celeborn/${{ matrix.celeborn }}-incubating/apache-${{ matrix.celeborn }}-incubating-bin.tgz && \ - tar xzf apache-${{ matrix.celeborn }}-incubating-bin.tgz -C /opt/celeborn --strip-components=1 && cd celeborn && \ + wget https://archive.apache.org/dist/celeborn/${{ matrix.celeborn }}/apache-${{ matrix.celeborn }}-bin.tgz && \ + tar xzf apache-${{ matrix.celeborn }}-bin.tgz -C /opt/celeborn --strip-components=1 && cd celeborn && \ mv ./conf/celeborn-env.sh.template ./conf/celeborn-env.sh && \ bash -c "echo -e 'CELEBORN_MASTER_MEMORY=4g\nCELEBORN_WORKER_MEMORY=4g\nCELEBORN_WORKER_OFFHEAP_MEMORY=8g' > ./conf/celeborn-env.sh" && \ bash -c "echo -e 'celeborn.worker.commitFiles.threads 128\nceleborn.worker.sortPartition.threads 64' > ./conf/celeborn-defaults.conf" && \ diff --git a/docs/get-started/ClickHouse.md b/docs/get-started/ClickHouse.md index 4352a99e55f9..ab24de7a4fd6 100644 --- a/docs/get-started/ClickHouse.md +++ b/docs/get-started/ClickHouse.md @@ -679,13 +679,13 @@ spark.shuffle.manager=org.apache.spark.shuffle.gluten.celeborn.CelebornShuffleMa quickly start a celeborn cluster ```shell -wget https://archive.apache.org/dist/incubator/celeborn/celeborn-0.3.0-incubating/apache-celeborn-0.3.0-incubating-bin.tgz && \ -tar -zxvf apache-celeborn-0.3.0-incubating-bin.tgz && \ -mv apache-celeborn-0.3.0-incubating-bin/conf/celeborn-defaults.conf.template apache-celeborn-0.3.0-incubating-bin/conf/celeborn-defaults.conf && \ -mv apache-celeborn-0.3.0-incubating-bin/conf/log4j2.xml.template apache-celeborn-0.3.0-incubating-bin/conf/log4j2.xml && \ +wget https://archive.apache.org/dist/celeborn/celeborn-0.3.2-incubating/apache-celeborn-0.3.2-incubating-bin.tgz && \ +tar -zxvf apache-celeborn-0.3.2-incubating-bin.tgz && \ +mv apache-celeborn-0.3.2-incubating-bin/conf/celeborn-defaults.conf.template apache-celeborn-0.3.2-incubating-bin/conf/celeborn-defaults.conf && \ +mv apache-celeborn-0.3.2-incubating-bin/conf/log4j2.xml.template 
apache-celeborn-0.3.2-incubating-bin/conf/log4j2.xml && \ mkdir /opt/hadoop && chmod 777 /opt/hadoop && \ -echo -e "celeborn.worker.flusher.threads 4\nceleborn.worker.storage.dirs /tmp\nceleborn.worker.monitor.disk.enabled false" > apache-celeborn-0.3.0-incubating-bin/conf/celeborn-defaults.conf && \ -bash apache-celeborn-0.3.0-incubating-bin/sbin/start-master.sh && bash apache-celeborn-0.3.0-incubating-bin/sbin/start-worker.sh +echo -e "celeborn.worker.flusher.threads 4\nceleborn.worker.storage.dirs /tmp\nceleborn.worker.monitor.disk.enabled false" > apache-celeborn-0.3.2-incubating-bin/conf/celeborn-defaults.conf && \ +bash apache-celeborn-0.3.2-incubating-bin/sbin/start-master.sh && bash apache-celeborn-0.3.2-incubating-bin/sbin/start-worker.sh ``` ### Columnar shuffle mode diff --git a/gluten-celeborn/common/src/main/java/org/apache/spark/shuffle/gluten/celeborn/CelebornShuffleManager.java b/gluten-celeborn/common/src/main/java/org/apache/spark/shuffle/gluten/celeborn/CelebornShuffleManager.java index f454cf00c656..d196691d1b14 100644 --- a/gluten-celeborn/common/src/main/java/org/apache/spark/shuffle/gluten/celeborn/CelebornShuffleManager.java +++ b/gluten-celeborn/common/src/main/java/org/apache/spark/shuffle/gluten/celeborn/CelebornShuffleManager.java @@ -217,7 +217,13 @@ public boolean unregisterShuffle(int shuffleId) { } } return CelebornUtils.unregisterShuffle( - lifecycleManager, shuffleClient, shuffleIdTracker, shuffleId, appUniqueId, isDriver()); + lifecycleManager, + shuffleClient, + shuffleIdTracker, + shuffleId, + appUniqueId, + throwsFetchFailure, + isDriver()); } @Override diff --git a/gluten-celeborn/common/src/main/java/org/apache/spark/shuffle/gluten/celeborn/CelebornUtils.java b/gluten-celeborn/common/src/main/java/org/apache/spark/shuffle/gluten/celeborn/CelebornUtils.java index 9dd4e1d1191e..6b4229ad3037 100644 --- a/gluten-celeborn/common/src/main/java/org/apache/spark/shuffle/gluten/celeborn/CelebornUtils.java +++ b/gluten-celeborn/common/src/main/java/org/apache/spark/shuffle/gluten/celeborn/CelebornUtils.java @@ -49,11 +49,21 @@ public static boolean unregisterShuffle( Object shuffleIdTracker, int appShuffleId, String appUniqueId, + boolean throwsFetchFailure, boolean isDriver) { try { - // for Celeborn 0.4.0 try { - if (lifecycleManager != null) { + try { + // for Celeborn 0.4.1 + if (lifecycleManager != null) { + Method unregisterAppShuffle = + lifecycleManager + .getClass() + .getMethod("unregisterAppShuffle", int.class, boolean.class); + unregisterAppShuffle.invoke(lifecycleManager, appShuffleId, throwsFetchFailure); + } + } catch (NoSuchMethodException ex) { + // for Celeborn 0.4.0 Method unregisterAppShuffle = lifecycleManager.getClass().getMethod("unregisterAppShuffle", int.class); unregisterAppShuffle.invoke(lifecycleManager, appShuffleId); diff --git a/pom.xml b/pom.xml index 81ce0e5d462a..887839ce5fc0 100644 --- a/pom.xml +++ b/pom.xml @@ -53,7 +53,7 @@ delta-core 2.4.0 24 - 0.3.2-incubating + 0.4.1 0.8.0 15.0.0 15.0.0-gluten diff --git a/tools/gluten-it/pom.xml b/tools/gluten-it/pom.xml index 3f1760069792..71db637a8403 100644 --- a/tools/gluten-it/pom.xml +++ b/tools/gluten-it/pom.xml @@ -21,7 +21,7 @@ 3.4.2 2.12 3 - 0.3.0-incubating + 0.3.2-incubating 0.8.0 1.2.0-SNAPSHOT 32.0.1-jre @@ -167,7 +167,7 @@ celeborn-0.4 - 0.4.0-incubating + 0.4.1 From 3a42e8fbd3797390a554839ef10e6f9b073460d6 Mon Sep 17 00:00:00 2001 From: Kerwin Zhang Date: Thu, 27 Jun 2024 11:49:15 +0800 Subject: [PATCH 26/30] [CELEBORN] Add config to control celeborn fallback for CI 
(#6230) --- .../gluten/celeborn/CelebornShuffleManager.java | 12 +++++++++--- .../main/scala/org/apache/gluten/GlutenConfig.scala | 10 ++++++++++ .../org/apache/gluten/integration/Constants.scala | 1 + 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/gluten-celeborn/common/src/main/java/org/apache/spark/shuffle/gluten/celeborn/CelebornShuffleManager.java b/gluten-celeborn/common/src/main/java/org/apache/spark/shuffle/gluten/celeborn/CelebornShuffleManager.java index d196691d1b14..63fb0cc1b9bd 100644 --- a/gluten-celeborn/common/src/main/java/org/apache/spark/shuffle/gluten/celeborn/CelebornShuffleManager.java +++ b/gluten-celeborn/common/src/main/java/org/apache/spark/shuffle/gluten/celeborn/CelebornShuffleManager.java @@ -16,6 +16,7 @@ */ package org.apache.spark.shuffle.gluten.celeborn; +import org.apache.gluten.GlutenConfig; import org.apache.gluten.backendsapi.BackendsApiManager; import org.apache.gluten.exception.GlutenException; @@ -194,9 +195,14 @@ public ShuffleHandle registerShuffle( if (dependency instanceof ColumnarShuffleDependency) { if (fallbackPolicyRunner.applyAllFallbackPolicy( lifecycleManager, dependency.partitioner().numPartitions())) { - logger.warn("Fallback to ColumnarShuffleManager!"); - columnarShuffleIds.add(shuffleId); - return columnarShuffleManager().registerShuffle(shuffleId, dependency); + if (GlutenConfig.getConf().enableCelebornFallback()) { + logger.warn("Fallback to ColumnarShuffleManager!"); + columnarShuffleIds.add(shuffleId); + return columnarShuffleManager().registerShuffle(shuffleId, dependency); + } else { + throw new GlutenException( + "The Celeborn service(Master: " + celebornConf.masterHost() + ") is unavailable"); + } } else { return registerCelebornShuffleHandle(shuffleId, dependency); } diff --git a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala index 89933cc58a4d..58b99a7f3064 100644 --- a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala +++ b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala @@ -447,6 +447,8 @@ class GlutenConfig(conf: SQLConf) extends Logging { conf.getConf(DYNAMIC_OFFHEAP_SIZING_ENABLED) def enableHiveFileFormatWriter: Boolean = conf.getConf(NATIVE_HIVEFILEFORMAT_WRITER_ENABLED) + + def enableCelebornFallback: Boolean = conf.getConf(CELEBORN_FALLBACK_ENABLED) } object GlutenConfig { @@ -2049,4 +2051,12 @@ object GlutenConfig { .doubleConf .checkValue(v => v >= 0 && v <= 1, "offheap sizing memory fraction must between [0, 1]") .createWithDefault(0.6) + + val CELEBORN_FALLBACK_ENABLED = + buildStaticConf("spark.gluten.sql.columnar.shuffle.celeborn.fallback.enabled") + .internal() + .doc("If enabled, fall back to ColumnarShuffleManager when celeborn service is unavailable." 
+ + "Otherwise, throw an exception.") + .booleanConf + .createWithDefault(true) } diff --git a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/Constants.scala b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/Constants.scala index 50766f3a91d1..e680ce9d5dda 100644 --- a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/Constants.scala +++ b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/Constants.scala @@ -44,6 +44,7 @@ object Constants { val VELOX_WITH_CELEBORN_CONF: SparkConf = new SparkConf(false) .set("spark.gluten.sql.columnar.forceShuffledHashJoin", "true") + .set("spark.gluten.sql.columnar.shuffle.celeborn.fallback.enabled", "false") .set("spark.sql.parquet.enableVectorizedReader", "true") .set("spark.plugins", "org.apache.gluten.GlutenPlugin") .set( From b65ecced292b9defdc10d6d5e3a46f43c28fd84c Mon Sep 17 00:00:00 2001 From: PHILO-HE Date: Thu, 27 Jun 2024 14:06:26 +0800 Subject: [PATCH 27/30] [VL] Remove useless function registering code (#6245) --- .../functions/RegistrationAllFunctions.cc | 7 ------- .../gluten/expression/ExpressionConverter.scala | 16 ++++++++-------- .../gluten/expression/ExpressionNames.scala | 8 ++++---- 3 files changed, 12 insertions(+), 19 deletions(-) diff --git a/cpp/velox/operators/functions/RegistrationAllFunctions.cc b/cpp/velox/operators/functions/RegistrationAllFunctions.cc index 638dbcccff0c..6b6564fa4aa3 100644 --- a/cpp/velox/operators/functions/RegistrationAllFunctions.cc +++ b/cpp/velox/operators/functions/RegistrationAllFunctions.cc @@ -52,9 +52,6 @@ void registerFunctionOverwrite() { velox::registerFunction({"round"}); velox::registerFunction({"round"}); velox::registerFunction({"round"}); - // TODO: the below rand function registry can be removed after presto function registry is removed. 
- velox::registerFunction>({"spark_rand"}); - velox::registerFunction>({"spark_rand"}); auto kRowConstructorWithNull = RowConstructorWithNullCallToSpecialForm::kRowConstructorWithNull; velox::exec::registerVectorFunction( @@ -74,10 +71,6 @@ void registerFunctionOverwrite() { velox::exec::registerFunctionCallToSpecialForm( kRowConstructorWithAllNull, std::make_unique(kRowConstructorWithAllNull)); - velox::functions::registerBinaryIntegral({"check_add"}); - velox::functions::registerBinaryIntegral({"check_subtract"}); - velox::functions::registerBinaryIntegral({"check_multiply"}); - velox::functions::registerBinaryIntegral({"check_divide"}); velox::functions::registerPrestoVectorFunctions(); } diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala index da5625cd45e5..d5222cfc6350 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala @@ -564,7 +564,7 @@ object ExpressionConverter extends SQLConfHelper with Logging { replaceWithExpressionTransformerInternal(a.left, attributeSeq, expressionsMap), replaceWithExpressionTransformerInternal(a.right, attributeSeq, expressionsMap), tryEval, - ExpressionNames.CHECK_ADD + ExpressionNames.CHECKED_ADD ) case tryEval @ TryEval(a: Subtract) => BackendsApiManager.getSparkPlanExecApiInstance.genTryArithmeticTransformer( @@ -572,7 +572,7 @@ object ExpressionConverter extends SQLConfHelper with Logging { replaceWithExpressionTransformerInternal(a.left, attributeSeq, expressionsMap), replaceWithExpressionTransformerInternal(a.right, attributeSeq, expressionsMap), tryEval, - ExpressionNames.CHECK_SUBTRACT + ExpressionNames.CHECKED_SUBTRACT ) case tryEval @ TryEval(a: Divide) => BackendsApiManager.getSparkPlanExecApiInstance.genTryArithmeticTransformer( @@ -580,7 +580,7 @@ object ExpressionConverter extends SQLConfHelper with Logging { replaceWithExpressionTransformerInternal(a.left, attributeSeq, expressionsMap), replaceWithExpressionTransformerInternal(a.right, attributeSeq, expressionsMap), tryEval, - ExpressionNames.CHECK_DIVIDE + ExpressionNames.CHECKED_DIVIDE ) case tryEval @ TryEval(a: Multiply) => BackendsApiManager.getSparkPlanExecApiInstance.genTryArithmeticTransformer( @@ -588,7 +588,7 @@ object ExpressionConverter extends SQLConfHelper with Logging { replaceWithExpressionTransformerInternal(a.left, attributeSeq, expressionsMap), replaceWithExpressionTransformerInternal(a.right, attributeSeq, expressionsMap), tryEval, - ExpressionNames.CHECK_MULTIPLY + ExpressionNames.CHECKED_MULTIPLY ) case a: Add => BackendsApiManager.getSparkPlanExecApiInstance.genArithmeticTransformer( @@ -596,7 +596,7 @@ object ExpressionConverter extends SQLConfHelper with Logging { replaceWithExpressionTransformerInternal(a.left, attributeSeq, expressionsMap), replaceWithExpressionTransformerInternal(a.right, attributeSeq, expressionsMap), a, - ExpressionNames.CHECK_ADD + ExpressionNames.CHECKED_ADD ) case a: Subtract => BackendsApiManager.getSparkPlanExecApiInstance.genArithmeticTransformer( @@ -604,7 +604,7 @@ object ExpressionConverter extends SQLConfHelper with Logging { replaceWithExpressionTransformerInternal(a.left, attributeSeq, expressionsMap), replaceWithExpressionTransformerInternal(a.right, attributeSeq, expressionsMap), a, - ExpressionNames.CHECK_SUBTRACT + ExpressionNames.CHECKED_SUBTRACT ) case a: Multiply => 
BackendsApiManager.getSparkPlanExecApiInstance.genArithmeticTransformer( @@ -612,7 +612,7 @@ object ExpressionConverter extends SQLConfHelper with Logging { replaceWithExpressionTransformerInternal(a.left, attributeSeq, expressionsMap), replaceWithExpressionTransformerInternal(a.right, attributeSeq, expressionsMap), a, - ExpressionNames.CHECK_MULTIPLY + ExpressionNames.CHECKED_MULTIPLY ) case a: Divide => BackendsApiManager.getSparkPlanExecApiInstance.genArithmeticTransformer( @@ -620,7 +620,7 @@ object ExpressionConverter extends SQLConfHelper with Logging { replaceWithExpressionTransformerInternal(a.left, attributeSeq, expressionsMap), replaceWithExpressionTransformerInternal(a.right, attributeSeq, expressionsMap), a, - ExpressionNames.CHECK_DIVIDE + ExpressionNames.CHECKED_DIVIDE ) case tryEval: TryEval => // This is a placeholder to handle try_eval(other expressions). diff --git a/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala b/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala index 2be3fad9d39d..8317e28b58bb 100644 --- a/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala +++ b/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala @@ -83,10 +83,10 @@ object ExpressionNames { final val IS_NAN = "isnan" final val NANVL = "nanvl" final val TRY_EVAL = "try" - final val CHECK_ADD = "check_add" - final val CHECK_SUBTRACT = "check_subtract" - final val CHECK_DIVIDE = "check_divide" - final val CHECK_MULTIPLY = "check_multiply" + final val CHECKED_ADD = "checked_add" + final val CHECKED_SUBTRACT = "checked_subtract" + final val CHECKED_DIVIDE = "checked_divide" + final val CHECKED_MULTIPLY = "checked_multiply" // SparkSQL String functions final val ASCII = "ascii" From 51b1901a797c8dc43f11793fd9b679d2477a69aa Mon Sep 17 00:00:00 2001 From: exmy Date: Thu, 27 Jun 2024 14:34:02 +0800 Subject: [PATCH 28/30] [GLUTEN-6235][CH] Fix crash on ExpandTransform::work() (#6238) [CH] Fix crash on ExpandTransform::work() --- .../GlutenClickHouseHiveTableSuite.scala | 25 +++++++++++++++++++ .../local-engine/Operator/ExpandTransform.cpp | 2 +- .../execution/BasicScanExecTransformer.scala | 19 +------------- .../execution/AbstractHiveTableScanExec.scala | 2 +- .../execution/AbstractHiveTableScanExec.scala | 2 +- .../execution/AbstractHiveTableScanExec.scala | 2 +- .../execution/AbstractHiveTableScanExec.scala | 2 +- 7 files changed, 31 insertions(+), 23 deletions(-) diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseHiveTableSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseHiveTableSuite.scala index 9b52f6a8cb53..4e190c087920 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseHiveTableSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseHiveTableSuite.scala @@ -1252,4 +1252,29 @@ class GlutenClickHouseHiveTableSuite } spark.sql("drop table test_tbl_3452") } + + test("GLUTEN-6235: Fix crash on ExpandTransform::work()") { + val tbl = "test_tbl_6235" + sql(s"drop table if exists $tbl") + val createSql = + s""" + |create table $tbl + |stored as textfile + |as select 1 as a1, 2 as a2, 3 as a3, 4 as a4, 5 as a5, 6 as a6, 7 as a7, 8 as a8, 9 as a9 + |""".stripMargin + sql(createSql) + val select_sql = + s""" + |select + |a5,a6,a7,a8,a3,a4,a9 + |,count(distinct a2) as a2 + |,count(distinct a1) as a1 + |,count(distinct if(a3=1,a2,null)) as 
a33 + |,count(distinct if(a4=2,a1,null)) as a43 + |from $tbl + |group by a5,a6,a7,a8,a3,a4,a9 with cube + |""".stripMargin + compareResultsAgainstVanillaSpark(select_sql, true, { _ => }) + sql(s"drop table if exists $tbl") + } } diff --git a/cpp-ch/local-engine/Operator/ExpandTransform.cpp b/cpp-ch/local-engine/Operator/ExpandTransform.cpp index 106c38e2d8c3..f5787163c5a1 100644 --- a/cpp-ch/local-engine/Operator/ExpandTransform.cpp +++ b/cpp-ch/local-engine/Operator/ExpandTransform.cpp @@ -104,7 +104,7 @@ void ExpandTransform::work() if (kind == EXPAND_FIELD_KIND_SELECTION) { - const auto & original_col = original_cols[field.get()]; + const auto & original_col = original_cols.at(field.get()); if (type->isNullable() == original_col->isNullable()) { cols.push_back(original_col); diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/BasicScanExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/BasicScanExecTransformer.scala index 9d231bbc2891..64071fb14c0c 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/BasicScanExecTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/BasicScanExecTransformer.scala @@ -110,7 +110,7 @@ trait BasicScanExecTransformer extends LeafTransformSupport with BaseDataSource } override protected def doTransform(context: SubstraitContext): TransformContext = { - val output = filterRedundantField(outputAttributes()) + val output = outputAttributes() val typeNodes = ConverterUtils.collectAttributeTypeNodes(output) val nameList = ConverterUtils.collectAttributeNamesWithoutExprId(output) val columnTypeNodes = output.map { @@ -156,21 +156,4 @@ trait BasicScanExecTransformer extends LeafTransformSupport with BaseDataSource context.nextOperatorId(this.nodeName)) TransformContext(output, output, readNode) } - - private def filterRedundantField(outputs: Seq[Attribute]): Seq[Attribute] = { - var finalOutput: List[Attribute] = List() - val outputList = outputs.toArray - for (i <- outputList.indices) { - var dup = false - for (j <- 0 until i) { - if (outputList(i).name == outputList(j).name) { - dup = true - } - } - if (!dup) { - finalOutput = finalOutput :+ outputList(i) - } - } - finalOutput - } } diff --git a/shims/spark32/src/main/scala/org/apache/spark/sql/hive/execution/AbstractHiveTableScanExec.scala b/shims/spark32/src/main/scala/org/apache/spark/sql/hive/execution/AbstractHiveTableScanExec.scala index 95106b4edba1..46b59ac306c2 100644 --- a/shims/spark32/src/main/scala/org/apache/spark/sql/hive/execution/AbstractHiveTableScanExec.scala +++ b/shims/spark32/src/main/scala/org/apache/spark/sql/hive/execution/AbstractHiveTableScanExec.scala @@ -75,7 +75,7 @@ abstract private[hive] class AbstractHiveTableScanExec( override val output: Seq[Attribute] = { // Retrieve the original attributes based on expression ID so that capitalization matches. 
- requestedAttributes.map(originalAttributes) + requestedAttributes.map(originalAttributes).distinct } // Bind all partition key attribute references in the partition pruning predicate for later diff --git a/shims/spark33/src/main/scala/org/apache/spark/sql/hive/execution/AbstractHiveTableScanExec.scala b/shims/spark33/src/main/scala/org/apache/spark/sql/hive/execution/AbstractHiveTableScanExec.scala index 78f5ff7f1be1..dd095f0ff247 100644 --- a/shims/spark33/src/main/scala/org/apache/spark/sql/hive/execution/AbstractHiveTableScanExec.scala +++ b/shims/spark33/src/main/scala/org/apache/spark/sql/hive/execution/AbstractHiveTableScanExec.scala @@ -75,7 +75,7 @@ abstract private[hive] class AbstractHiveTableScanExec( override val output: Seq[Attribute] = { // Retrieve the original attributes based on expression ID so that capitalization matches. - requestedAttributes.map(originalAttributes) + requestedAttributes.map(originalAttributes).distinct } // Bind all partition key attribute references in the partition pruning predicate for later diff --git a/shims/spark34/src/main/scala/org/apache/spark/sql/hive/execution/AbstractHiveTableScanExec.scala b/shims/spark34/src/main/scala/org/apache/spark/sql/hive/execution/AbstractHiveTableScanExec.scala index 77f15ac57087..87aba00b0f59 100644 --- a/shims/spark34/src/main/scala/org/apache/spark/sql/hive/execution/AbstractHiveTableScanExec.scala +++ b/shims/spark34/src/main/scala/org/apache/spark/sql/hive/execution/AbstractHiveTableScanExec.scala @@ -77,7 +77,7 @@ abstract private[hive] class AbstractHiveTableScanExec( override val output: Seq[Attribute] = { // Retrieve the original attributes based on expression ID so that capitalization matches. - requestedAttributes.map(originalAttributes) + requestedAttributes.map(originalAttributes).distinct } // Bind all partition key attribute references in the partition pruning predicate for later diff --git a/shims/spark35/src/main/scala/org/apache/spark/sql/hive/execution/AbstractHiveTableScanExec.scala b/shims/spark35/src/main/scala/org/apache/spark/sql/hive/execution/AbstractHiveTableScanExec.scala index 77f15ac57087..87aba00b0f59 100644 --- a/shims/spark35/src/main/scala/org/apache/spark/sql/hive/execution/AbstractHiveTableScanExec.scala +++ b/shims/spark35/src/main/scala/org/apache/spark/sql/hive/execution/AbstractHiveTableScanExec.scala @@ -77,7 +77,7 @@ abstract private[hive] class AbstractHiveTableScanExec( override val output: Seq[Attribute] = { // Retrieve the original attributes based on expression ID so that capitalization matches. - requestedAttributes.map(originalAttributes) + requestedAttributes.map(originalAttributes).distinct } // Bind all partition key attribute references in the partition pruning predicate for later From e71a0c414ecd2595b166a03ee381845f9977302c Mon Sep 17 00:00:00 2001 From: LiuNeng <1398775315@qq.com> Date: Thu, 27 Jun 2024 15:06:11 +0800 Subject: [PATCH 29/30] [CH] Support use dynamic disk path #6232 What changes were proposed in this pull request? Support use dynamic disk path spark.gluten.sql.columnar.backend.ch.runtime_config.use_current_directory_as_tmp=true disk.metadata_path and cache.path are automatically mapped to the current directory spark.gluten.sql.columnar.backend.ch.runtime_config.reuse_disk_cache=false Add the current pid number to disk.metadata_path and cache.path How was this patch tested? 
unit tests --- cpp-ch/local-engine/Common/CHUtil.cpp | 75 +++++++++++++++++++++++++++ cpp-ch/local-engine/Common/CHUtil.h | 4 ++ 2 files changed, 79 insertions(+) diff --git a/cpp-ch/local-engine/Common/CHUtil.cpp b/cpp-ch/local-engine/Common/CHUtil.cpp index 148e78bfbc79..76c71ce752d6 100644 --- a/cpp-ch/local-engine/Common/CHUtil.cpp +++ b/cpp-ch/local-engine/Common/CHUtil.cpp @@ -17,6 +17,7 @@ #include "CHUtil.h" #include +#include #include #include #include @@ -527,6 +528,50 @@ std::map BackendInitializerUtil::getBackendConfMap(con return ch_backend_conf; } +std::vector BackendInitializerUtil::wrapDiskPathConfig( + const String & path_prefix, + const String & path_suffix, + Poco::Util::AbstractConfiguration & config) +{ + std::vector changed_paths; + if (path_prefix.empty() && path_suffix.empty()) + return changed_paths; + Poco::Util::AbstractConfiguration::Keys disks; + std::unordered_set disk_types = {"s3", "hdfs_gluten", "cache"}; + config.keys("storage_configuration.disks", disks); + + std::ranges::for_each( + disks, + [&](const auto & disk_name) + { + String disk_prefix = "storage_configuration.disks." + disk_name; + String disk_type = config.getString(disk_prefix + ".type", ""); + if (!disk_types.contains(disk_type)) + return; + if (disk_type == "cache") + { + String path = config.getString(disk_prefix + ".path", ""); + if (!path.empty()) + { + String final_path = path_prefix + path + path_suffix; + config.setString(disk_prefix + ".path", final_path); + changed_paths.emplace_back(final_path); + } + } + else if (disk_type == "s3" || disk_type == "hdfs_gluten") + { + String metadata_path = config.getString(disk_prefix + ".metadata_path", ""); + if (!metadata_path.empty()) + { + String final_path = path_prefix + metadata_path + path_suffix; + config.setString(disk_prefix + ".metadata_path", final_path); + changed_paths.emplace_back(final_path); + } + } + }); + return changed_paths; +} + DB::Context::ConfigurationPtr BackendInitializerUtil::initConfig(std::map & backend_conf_map) { DB::Context::ConfigurationPtr config; @@ -566,6 +611,25 @@ DB::Context::ConfigurationPtr BackendInitializerUtil::initConfig(std::mapsetString(CH_TASK_MEMORY, backend_conf_map.at(GLUTEN_TASK_OFFHEAP)); } + const bool use_current_directory_as_tmp = config->getBool("use_current_directory_as_tmp", false); + char buffer[PATH_MAX]; + if (use_current_directory_as_tmp && getcwd(buffer, sizeof(buffer)) != nullptr) + { + wrapDiskPathConfig(String(buffer), "", *config); + } + + const bool reuse_disk_cache = config->getBool("reuse_disk_cache", true); + + if (!reuse_disk_cache) + { + String pid = std::to_string(static_cast(getpid())); + auto path_need_clean = wrapDiskPathConfig("", "/" + pid, *config); + std::lock_guard lock(BackendFinalizerUtil::paths_mutex); + BackendFinalizerUtil::paths_need_to_clean.insert( + BackendFinalizerUtil::paths_need_to_clean.end(), + path_need_clean.begin(), + path_need_clean.end()); + } return config; } @@ -936,12 +1000,23 @@ void BackendFinalizerUtil::finalizeGlobally() global_context.reset(); shared_context.reset(); } + std::lock_guard lock(paths_mutex); + std::ranges::for_each(paths_need_to_clean, [](const auto & path) + { + if (fs::exists(path)) + fs::remove_all(path); + }); + paths_need_to_clean.clear(); } void BackendFinalizerUtil::finalizeSessionally() { } +std::vector BackendFinalizerUtil::paths_need_to_clean; + +std::mutex BackendFinalizerUtil::paths_mutex; + Int64 DateTimeUtil::currentTimeMillis() { return timeInMilliseconds(std::chrono::system_clock::now()); diff --git 
a/cpp-ch/local-engine/Common/CHUtil.h b/cpp-ch/local-engine/Common/CHUtil.h index 0321d410a7d5..1198cfa2195d 100644 --- a/cpp-ch/local-engine/Common/CHUtil.h +++ b/cpp-ch/local-engine/Common/CHUtil.h @@ -196,6 +196,7 @@ class BackendInitializerUtil static void registerAllFactories(); static void applyGlobalConfigAndSettings(DB::Context::ConfigurationPtr, DB::Settings &); static void updateNewSettings(const DB::ContextMutablePtr &, const DB::Settings &); + static std::vector wrapDiskPathConfig(const String & path_prefix, const String & path_suffix, Poco::Util::AbstractConfiguration & config); static std::map getBackendConfMap(const std::string & plan); @@ -212,6 +213,9 @@ class BackendFinalizerUtil /// Release session level resources like StorageJoinBuilder. Invoked every time executor/driver shutdown. static void finalizeSessionally(); + + static std::vector paths_need_to_clean; + static std::mutex paths_mutex; }; // Ignore memory track, memory should free before IgnoreMemoryTracker deconstruction From 7bf6cd41062c0c6c60a9b17a81674828e02d6a6b Mon Sep 17 00:00:00 2001 From: Gluten Performance Bot <137994563+GlutenPerfBot@users.noreply.github.com> Date: Thu, 27 Jun 2024 16:15:15 +0800 Subject: [PATCH 30/30] [VL] Daily Update Velox Version (2024_06_27) (#6242) 43cb72a1e by Masha Basmanova, Add support for minus(timestamp with tz, timestamp with tz) Presto function (#10327) 51f86b176 by Masha Basmanova, Add support for TIMESTAMP WITH TIME ZONE inputs to least/greatest Presto functions (#10328) 9857a2eb5 by zhli1142015, Fix NaN handling in Spark In function (#10259) 5a150cec1 by wypb, Fix AbfsReadFile::Impl::preadv to return the length of read. (#10320) 31800f52d by Masha Basmanova, Add support for VARBINARY input to from_base64 Presto function (#10325) 7add4bf64 by Jialiang Tan, Add stats reporter to op test base (#10296) c03401967 by lingbin, Simplify assertion in AllocationTraits (#10322) 6af663c42 by xiaoxmeng, Fix flaky async data cache shutdown test (#10318) 1ae622476 by Yoav Helfman, Fix IOStats for Nimble (#10216) 2a140f9d6 by Jialiang Tan, Shorten AggregationTest.maxSpillBytes from 2m to 2s (#10317) 05222475a by Kevin Wilfong, Ensure the shared HashStringAllocator isn't mutated by Aggregations during spilling (#10309) 622d31ac5 by Kevin Wilfong, Fix race conditions in AsyncDataCache AccessStats (#10312) 136d66be0 by Kevin Wilfong, Fix race condition in MemoryArbitrationFuzzer (#10314) a4e0b6a1f by Kevin Wilfong, Ignore TSAN errors in WindowFuzzer (#10315) --- ep/build-velox/src/get_velox.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ep/build-velox/src/get_velox.sh b/ep/build-velox/src/get_velox.sh index a96719dc10fc..237757d818d0 100755 --- a/ep/build-velox/src/get_velox.sh +++ b/ep/build-velox/src/get_velox.sh @@ -17,7 +17,7 @@ set -exu VELOX_REPO=https://github.com/oap-project/velox.git -VELOX_BRANCH=2024_06_26 +VELOX_BRANCH=2024_06_27 VELOX_HOME="" #Set on run gluten on HDFS
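
The new configuration keys introduced in this series — the ClickHouse-backend disk settings from #6232 (`use_current_directory_as_tmp`, `reuse_disk_cache`) and the Celeborn fallback switch from #6230 — are ordinary Spark conf entries. Below is a minimal sketch of setting them on a `SparkConf`, in the same style as the `Constants.scala` change above; the object name and the chosen values are illustrative only and are not part of any patch in this series.

``` scala
import org.apache.spark.SparkConf

object GlutenConfExample {
  // Illustrative values only; the key names come from the patches in this series.
  val conf: SparkConf = new SparkConf(false)
    // CH backend: map disk.metadata_path and cache.path under the current working directory.
    .set("spark.gluten.sql.columnar.backend.ch.runtime_config.use_current_directory_as_tmp", "true")
    // CH backend: suffix those paths with the executor pid so separate runs do not reuse the disk cache.
    .set("spark.gluten.sql.columnar.backend.ch.runtime_config.reuse_disk_cache", "false")
    // Celeborn: throw instead of falling back to ColumnarShuffleManager when the service is unavailable.
    .set("spark.gluten.sql.columnar.shuffle.celeborn.fallback.enabled", "false")
}
```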