diff --git a/cpp/core/compute/ExecutionCtx.cc b/cpp/core/compute/ExecutionCtx.cc
index 1ca64cd93f16..bf56592580b5 100644
--- a/cpp/core/compute/ExecutionCtx.cc
+++ b/cpp/core/compute/ExecutionCtx.cc
@@ -20,28 +20,50 @@
 namespace gluten {
 
-static ExecutionCtxFactoryContext* getExecutionCtxFactoryContext() {
-  static ExecutionCtxFactoryContext* executionCtxFactoryCtx = new ExecutionCtxFactoryContext;
-  return executionCtxFactoryCtx;
-}
+namespace {
+class FactoryRegistry {
+ public:
+  void registerFactory(const std::string& kind, ExecutionCtx::Factory factory) {
+    std::lock_guard<std::mutex> l(mutex_);
+    GLUTEN_CHECK(map_.find(kind) == map_.end(), "ExecutionCtx factory already registered for " + kind);
+    map_[kind] = std::move(factory);
+  }
+
+  ExecutionCtx::Factory& getFactory(const std::string& kind) {
+    std::lock_guard<std::mutex> l(mutex_);
+    GLUTEN_CHECK(map_.find(kind) != map_.end(), "ExecutionCtx factory not registered for " + kind);
+    return map_[kind];
+  }
+
+  bool unregisterFactory(const std::string& kind) {
+    std::lock_guard<std::mutex> l(mutex_);
+    GLUTEN_CHECK(map_.find(kind) != map_.end(), "ExecutionCtx factory not registered for " + kind);
+    return map_.erase(kind);
+  }
+
+ private:
+  std::mutex mutex_;
+  std::unordered_map<std::string, ExecutionCtx::Factory> map_;
+};
 
-void setExecutionCtxFactory(
-    ExecutionCtxFactoryWithConf factory,
-    const std::unordered_map<std::string, std::string>& sparkConfs) {
-  getExecutionCtxFactoryContext()->set(factory, sparkConfs);
-  DEBUG_OUT << "Set execution context factory with conf." << std::endl;
+FactoryRegistry& executionCtxFactories() {
+  static FactoryRegistry registry;
+  return registry;
 }
+} // namespace
 
-void setExecutionCtxFactory(ExecutionCtxFactory factory) {
-  getExecutionCtxFactoryContext()->set(factory);
-  DEBUG_OUT << "Set execution context factory." << std::endl;
+void ExecutionCtx::registerFactory(const std::string& kind, ExecutionCtx::Factory factory) {
+  executionCtxFactories().registerFactory(kind, std::move(factory));
 }
 
-ExecutionCtx* createExecutionCtx() {
-  return getExecutionCtxFactoryContext()->create();
+ExecutionCtx* ExecutionCtx::create(
+    const std::string& kind,
+    const std::unordered_map<std::string, std::string>& sessionConf) {
+  auto& factory = executionCtxFactories().getFactory(kind);
+  return factory(sessionConf);
 }
 
-void releaseExecutionCtx(ExecutionCtx* executionCtx) {
+void ExecutionCtx::release(ExecutionCtx* executionCtx) {
   delete executionCtx;
 }
diff --git a/cpp/core/compute/ExecutionCtx.h b/cpp/core/compute/ExecutionCtx.h
index c28b1c4335bf..f0fc7717bfd8 100644
--- a/cpp/core/compute/ExecutionCtx.h
+++ b/cpp/core/compute/ExecutionCtx.h
@@ -47,6 +47,13 @@ struct SparkTaskInfo {
 /// ExecutionCtx is stateful and manager all kinds of native resources' lifecycle during execute a computation fragment.
 class ExecutionCtx : public std::enable_shared_from_this<ExecutionCtx> {
  public:
+  using Factory = std::function<ExecutionCtx*(const std::unordered_map<std::string, std::string>&)>;
+  static void registerFactory(const std::string& kind, Factory factory);
+  static ExecutionCtx* create(
+      const std::string& kind,
+      const std::unordered_map<std::string, std::string>& sessionConf = {});
+  static void release(ExecutionCtx*);
+
   ExecutionCtx() = default;
   ExecutionCtx(const std::unordered_map<std::string, std::string>& confMap) : confMap_(confMap) {}
   virtual ~ExecutionCtx() = default;
@@ -156,7 +163,7 @@ class ExecutionCtx : public std::enable_shared_from_this<ExecutionCtx> {
     throw GlutenException("Not implement getNonPartitionedColumnarBatch");
   }
 
-  std::unordered_map<std::string, std::string> getConfMap() {
+  const std::unordered_map<std::string, std::string>& getConfMap() {
     return confMap_;
   }
 
@@ -168,79 +175,6 @@ class ExecutionCtx : public std::enable_shared_from_this<ExecutionCtx> {
   ::substrait::Plan substraitPlan_;
   SparkTaskInfo taskInfo_;
   // static conf map
-  std::unordered_map<std::string, std::string> confMap_;
+  const std::unordered_map<std::string, std::string> confMap_;
 };
-
-using ExecutionCtxFactoryWithConf = ExecutionCtx* (*)(const std::unordered_map<std::string, std::string>&);
-using ExecutionCtxFactory = ExecutionCtx* (*)();
-
-struct ExecutionCtxFactoryContext {
-  std::mutex mutex;
-
-  enum {
-    kExecutionCtxFactoryInvalid,
-    kExecutionCtxFactoryDefault,
-    kExecutionCtxFactoryWithConf
-  } type = kExecutionCtxFactoryInvalid;
-
-  union {
-    ExecutionCtxFactoryWithConf backendFactoryWithConf;
-    ExecutionCtxFactory backendFactory;
-  };
-
-  std::unordered_map<std::string, std::string> sparkConf_;
-
-  void set(ExecutionCtxFactoryWithConf factory, const std::unordered_map<std::string, std::string>& sparkConf) {
-    std::lock_guard<std::mutex> lockGuard(mutex);
-
-    if (type != kExecutionCtxFactoryInvalid) {
-      assert(false);
-      abort();
-      return;
-    }
-
-    type = kExecutionCtxFactoryWithConf;
-    backendFactoryWithConf = factory;
-    this->sparkConf_.clear();
-    for (auto& x : sparkConf) {
-      this->sparkConf_[x.first] = x.second;
-    }
-  }
-
-  void set(ExecutionCtxFactory factory) {
-    std::lock_guard<std::mutex> lockGuard(mutex);
-    if (type != kExecutionCtxFactoryInvalid) {
-      assert(false);
-      abort();
-      return;
-    }
-
-    type = kExecutionCtxFactoryDefault;
-    backendFactory = factory;
-  }
-
-  ExecutionCtx* create() {
-    std::lock_guard<std::mutex> lockGuard(mutex);
-    if (type == kExecutionCtxFactoryInvalid) {
-      assert(false);
-      abort();
-      return nullptr;
-    } else if (type == kExecutionCtxFactoryWithConf) {
-      return backendFactoryWithConf(sparkConf_);
-    } else {
-      return backendFactory();
-    }
-  }
-};
-
-void setExecutionCtxFactory(
-    ExecutionCtxFactoryWithConf factory,
-    const std::unordered_map<std::string, std::string>& sparkConf);
-
-void setExecutionCtxFactory(ExecutionCtxFactory factory);
-
-ExecutionCtx* createExecutionCtx();
-
-void releaseExecutionCtx(ExecutionCtx*);
-
 } // namespace gluten
diff --git a/cpp/core/jni/JniWrapper.cc b/cpp/core/jni/JniWrapper.cc
index e8bdaae15485..43c9ee62a69e 100644
--- a/cpp/core/jni/JniWrapper.cc
+++ b/cpp/core/jni/JniWrapper.cc
@@ -314,9 +314,13 @@ void JNI_OnUnload(JavaVM* vm, void* reserved) {
 JNIEXPORT jlong JNICALL Java_io_glutenproject_exec_ExecutionCtxJniWrapper_createExecutionCtx( // NOLINT
     JNIEnv* env,
-    jclass) {
+    jclass,
+    jstring jbackendType,
+    jbyteArray sessionConf) {
   JNI_METHOD_START
-  auto executionCtx = gluten::createExecutionCtx();
+  auto backendType = jStringToCString(env, jbackendType);
+  auto sparkConf = gluten::getConfMap(env, sessionConf);
+  auto executionCtx = gluten::ExecutionCtx::create(backendType, sparkConf);
   return reinterpret_cast<jlong>(executionCtx);
   JNI_METHOD_END(kInvalidResourceHandle)
 }
 
@@ -328,7 +332,7 @@ JNIEXPORT void JNICALL Java_io_glutenproject_exec_ExecutionCtxJniWrapper_release
   JNI_METHOD_START
   auto executionCtx = jniCastOrThrow<ExecutionCtx>(ctxHandle);
-  gluten::releaseExecutionCtx(executionCtx);
+  gluten::ExecutionCtx::release(executionCtx);
   JNI_METHOD_END()
 }
 
@@ -343,8 +347,7 @@ Java_io_glutenproject_vectorized_PlanEvaluatorJniWrapper_nativeCreateKernelWithI
     jint partitionId,
     jlong taskId,
     jboolean saveInput,
-    jstring spillDir,
-    jbyteArray confArr) {
+    jstring spillDir) {
   JNI_METHOD_START
 
   auto ctx = gluten::getExecutionCtx(env, wrapper);
@@ -356,8 +359,7 @@ Java_io_glutenproject_vectorized_PlanEvaluatorJniWrapper_nativeCreateKernelWithI
 
   auto planSize = env->GetArrayLength(planArr);
   ctx->parsePlan(planData, planSize, {stageId, partitionId, taskId});
-
-  auto confs = getConfMap(env, confArr);
+  auto& conf = ctx->getConfMap();
 
   // Handle the Java iters
   jsize itersLen = env->GetArrayLength(iterArr);
@@ -365,12 +367,12 @@ Java_io_glutenproject_vectorized_PlanEvaluatorJniWrapper_nativeCreateKernelWithI
   for (int idx = 0; idx < itersLen; idx++) {
     std::shared_ptr<ArrowWriter> writer = nullptr;
     if (saveInput) {
-      auto dir = confs[kGlutenSaveDir];
+      auto dir = conf.at(kGlutenSaveDir);
       std::filesystem::path f{dir};
       if (!std::filesystem::exists(f)) {
        throw gluten::GlutenException("Save input path " + dir + " does not exists");
       }
-      auto file = confs[kGlutenSaveDir] + "/input_" + std::to_string(taskId) + "_" + std::to_string(idx) + "_" +
+      auto file = conf.at(kGlutenSaveDir) + "/input_" + std::to_string(taskId) + "_" + std::to_string(idx) + "_" +
          std::to_string(partitionId) + ".parquet";
       writer = std::make_shared<ArrowWriter>(file);
     }
@@ -380,7 +382,7 @@ Java_io_glutenproject_vectorized_PlanEvaluatorJniWrapper_nativeCreateKernelWithI
     inputIters.push_back(std::move(resultIter));
   }
 
-  return ctx->createResultIterator(memoryManager, spillDirStr, inputIters, confs);
+  return ctx->createResultIterator(memoryManager, spillDirStr, inputIters, conf);
   JNI_METHOD_END(kInvalidResourceHandle)
 }
 
@@ -1066,13 +1068,13 @@ JNIEXPORT jlong JNICALL Java_io_glutenproject_datasource_velox_DatasourceJniWrap
     // Only inspect the schema and not write
     handle = ctx->createDatasource(jStringToCString(env, filePath), memoryManager, nullptr);
   } else {
-    auto sparkOptions = gluten::getConfMap(env, options);
-    auto sparkConf = ctx->getConfMap();
-    sparkOptions.insert(sparkConf.begin(), sparkConf.end());
+    auto datasourceOptions = gluten::getConfMap(env, options);
+    auto& sparkConf = ctx->getConfMap();
+    datasourceOptions.insert(sparkConf.begin(), sparkConf.end());
     auto schema = gluten::arrowGetOrThrow(arrow::ImportSchema(reinterpret_cast<struct ArrowSchema*>(cSchema)));
     handle = ctx->createDatasource(jStringToCString(env, filePath), memoryManager, schema);
     auto datasource = ctx->getDatasource(handle);
-    datasource->init(sparkOptions);
+    datasource->init(datasourceOptions);
   }
 
   return handle;
@@ -1214,7 +1216,8 @@ JNIEXPORT jlong JNICALL Java_io_glutenproject_memory_alloc_NativeMemoryAllocator
 JNIEXPORT jlong JNICALL Java_io_glutenproject_memory_nmm_NativeMemoryManager_create( // NOLINT
     JNIEnv* env,
     jclass,
-    jstring jname,
+    jstring jbackendType,
+    jstring jnmmName,
     jlong allocatorId,
     jlong reservationBlockSize,
     jobject jlistener) {
@@ -1235,11 +1238,12 @@ JNIEXPORT jlong JNICALL Java_io_glutenproject_memory_nmm_NativeMemoryManager_cre
     listener = std::make_unique(std::move(listener));
   }
 
-  auto name = jStringToCString(env, jname);
+  auto name = jStringToCString(env, jnmmName);
+  auto backendType = jStringToCString(env, jbackendType);
   // TODO: move memory manager into ExecutionCtx then we can use more general ExecutionCtx.
-  auto executionCtx = gluten::createExecutionCtx();
+  auto executionCtx = gluten::ExecutionCtx::create(backendType);
   auto manager = executionCtx->createMemoryManager(name, *allocator, std::move(listener));
-  gluten::releaseExecutionCtx(executionCtx);
+  gluten::ExecutionCtx::release(executionCtx);
   return reinterpret_cast<jlong>(manager);
   JNI_METHOD_END(kInvalidResourceHandle)
 }
diff --git a/cpp/velox/benchmarks/BenchmarkUtils.cc b/cpp/velox/benchmarks/BenchmarkUtils.cc
index 29fdfb0a3eb8..1c973ec2423a 100644
--- a/cpp/velox/benchmarks/BenchmarkUtils.cc
+++ b/cpp/velox/benchmarks/BenchmarkUtils.cc
@@ -35,14 +35,14 @@
 namespace {
 
 std::unordered_map<std::string, std::string> bmConfMap = {{gluten::kSparkBatchSize, FLAGS_batch_size}};
 
-gluten::ExecutionCtx* veloxExecutionCtxFactory(const std::unordered_map<std::string, std::string>& sparkConfs) {
-  return new gluten::VeloxExecutionCtx(sparkConfs);
+gluten::ExecutionCtx* veloxExecutionCtxFactory(const std::unordered_map<std::string, std::string>& sparkConf) {
+  return new gluten::VeloxExecutionCtx(sparkConf);
 }
 
 } // anonymous namespace
 
 void initVeloxBackend(std::unordered_map<std::string, std::string>& conf) {
-  gluten::setExecutionCtxFactory(veloxExecutionCtxFactory, conf);
+  gluten::ExecutionCtx::registerFactory(gluten::kVeloxExecutionCtxKind, veloxExecutionCtxFactory);
   gluten::VeloxBackend::create(conf);
 }
diff --git a/cpp/velox/benchmarks/GenericBenchmark.cc b/cpp/velox/benchmarks/GenericBenchmark.cc
index ad47dda6633b..0f81e646592b 100644
--- a/cpp/velox/benchmarks/GenericBenchmark.cc
+++ b/cpp/velox/benchmarks/GenericBenchmark.cc
@@ -125,7 +125,7 @@ auto BM_Generic = [](::benchmark::State& state,
     setCpu(state.thread_index());
   }
   auto memoryManager = getDefaultMemoryManager();
-  auto executionCtx = gluten::createExecutionCtx();
+  auto executionCtx = ExecutionCtx::create(kVeloxExecutionCtxKind);
   const auto& filePath = getExampleFilePath(substraitJsonFile);
   auto plan = getPlanFromFile(filePath);
   auto startTime = std::chrono::steady_clock::now();
@@ -206,7 +206,7 @@ auto BM_Generic = [](::benchmark::State& state,
     auto statsStr = facebook::velox::exec::printPlanWithStats(*planNode, task->taskStats(), true);
     std::cout << statsStr << std::endl;
   }
-  gluten::releaseExecutionCtx(executionCtx);
+  ExecutionCtx::release(executionCtx);
 
   auto endTime = std::chrono::steady_clock::now();
   auto duration = std::chrono::duration_cast(endTime - startTime).count();
diff --git a/cpp/velox/benchmarks/ParquetWriteBenchmark.cc b/cpp/velox/benchmarks/ParquetWriteBenchmark.cc
index 1f12bc0234d6..33e88267e039 100644
--- a/cpp/velox/benchmarks/ParquetWriteBenchmark.cc
+++ b/cpp/velox/benchmarks/ParquetWriteBenchmark.cc
@@ -256,7 +256,7 @@ class GoogleBenchmarkVeloxParquetWriteCacheScanBenchmark : public GoogleBenchmar
     // reuse the ParquetWriteConverter for batches caused system % increase a lot
     auto fileName = "velox_parquet_write.parquet";
 
-    auto executionCtx = gluten::createExecutionCtx();
+    auto executionCtx = ExecutionCtx::create(kVeloxExecutionCtxKind);
     auto memoryManager = getDefaultMemoryManager();
     auto veloxPool = memoryManager->getAggregateMemoryPool();
 
@@ -292,7 +292,7 @@ class GoogleBenchmarkVeloxParquetWriteCacheScanBenchmark : public GoogleBenchmar
         benchmark::Counter(initTime, benchmark::Counter::kAvgThreads, benchmark::Counter::OneK::kIs1000);
     state.counters["write_time"] =
         benchmark::Counter(writeTime, benchmark::Counter::kAvgThreads, benchmark::Counter::OneK::kIs1000);
-    gluten::releaseExecutionCtx(executionCtx);
+    ExecutionCtx::release(executionCtx);
   }
 };
diff --git a/cpp/velox/benchmarks/QueryBenchmark.cc b/cpp/velox/benchmarks/QueryBenchmark.cc
index bfdd5c41d877..7d26f08c591a 100644
--- a/cpp/velox/benchmarks/QueryBenchmark.cc
+++ b/cpp/velox/benchmarks/QueryBenchmark.cc
@@ -78,7 +78,7 @@ auto BM = [](::benchmark::State& state,
   auto plan = getPlanFromFile(filePath);
 
   auto memoryManager = getDefaultMemoryManager();
-  auto executionCtx = gluten::createExecutionCtx();
+  auto executionCtx = ExecutionCtx::create(kVeloxExecutionCtxKind);
   auto veloxPool = memoryManager->getAggregateMemoryPool();
 
   std::vector> scanInfos;
@@ -109,7 +109,7 @@ auto BM = [](::benchmark::State& state,
       std::cout << maybeBatch.ValueOrDie()->ToString() << std::endl;
     }
   }
-  gluten::releaseExecutionCtx(executionCtx);
+  ExecutionCtx::release(executionCtx);
 };
 
 #define orc_reader_decimal 1
diff --git a/cpp/velox/compute/VeloxBackend.cc b/cpp/velox/compute/VeloxBackend.cc
index 2860483676f1..a5700302ccbe 100644
--- a/cpp/velox/compute/VeloxBackend.cc
+++ b/cpp/velox/compute/VeloxBackend.cc
@@ -117,7 +117,7 @@ void VeloxBackend::printConf(const std::unordered_map<std::string, std::string>&
 }
 
 void VeloxBackend::init(const std::unordered_map<std::string, std::string>& conf) {
-  // In spark, planner takes care the parititioning and sorting, so the rows are sorted.
+  // In spark, planner takes care the partitioning and sorting, so the rows are sorted.
   // There is no need to sort the rows in window op again.
   FLAGS_SkipRowSortInWindowOp = true;
   // Avoid creating too many shared leaf pools.
@@ -266,7 +266,7 @@ void VeloxBackend::initJolFilesystem(const std::unordered_map<std
     maxSpillFileSize = std::stol(got->second);
   }
-  // FIMXE It's known that if spill compression is disabled, the actual spill file size may
+  // FIXME It's known that if spill compression is disabled, the actual spill file size may
   // in crease beyond this limit a little (maximum 64 rows which is by default
   // one compression page)
   gluten::registerJolFileSystem(maxSpillFileSize);
@@ -353,24 +353,18 @@ void VeloxBackend::initUdf(const std::unordered_map<std::string, std::string>& c
   }
 }
 
+std::unique_ptr<VeloxBackend> VeloxBackend::instance_ = nullptr;
+
 void VeloxBackend::create(const std::unordered_map<std::string, std::string>& conf) {
-  std::lock_guard<std::mutex> lockGuard(mutex_);
-  if (instance_ != nullptr) {
-    assert(false);
-    throw gluten::GlutenException("VeloxBackend already set");
-  }
-  instance_.reset(new gluten::VeloxBackend(conf));
+  instance_ = std::unique_ptr<VeloxBackend>(new gluten::VeloxBackend(conf));
 }
 
-std::shared_ptr<VeloxBackend> VeloxBackend::get() {
-  std::lock_guard<std::mutex> lockGuard(mutex_);
-  if (instance_ == nullptr) {
-    LOG(INFO) << "VeloxBackend not set, using default VeloxBackend instance. This should only happen in test code.";
-    static const std::unordered_map<std::string, std::string> kEmptyConf;
-    static std::shared_ptr<VeloxBackend> defaultInstance{new gluten::VeloxBackend(kEmptyConf)};
-    return defaultInstance;
+VeloxBackend* VeloxBackend::get() {
+  if (!instance_) {
+    LOG(WARNING) << "VeloxBackend instance is null, please invoke VeloxBackend#create before use.";
+    throw GlutenException("VeloxBackend instance is null.");
   }
-  return instance_;
+  return instance_.get();
 }
 
 } // namespace gluten
diff --git a/cpp/velox/compute/VeloxBackend.h b/cpp/velox/compute/VeloxBackend.h
index c2d3eca38d94..686862333ea0 100644
--- a/cpp/velox/compute/VeloxBackend.h
+++ b/cpp/velox/compute/VeloxBackend.h
@@ -46,7 +46,7 @@ class VeloxBackend {
 
   static void create(const std::unordered_map<std::string, std::string>& conf);
 
-  static std::shared_ptr<VeloxBackend> get();
+  static VeloxBackend* get();
 
   facebook::velox::memory::MemoryAllocator* getAsyncDataCache() const;
 
@@ -68,8 +68,7 @@ class VeloxBackend {
     return "cache." + boost::lexical_cast<std::string>(boost::uuids::random_generator()()) + ".";
   }
 
-  inline static std::shared_ptr<VeloxBackend> instance_;
-  inline static std::mutex mutex_;
+  static std::unique_ptr<VeloxBackend> instance_;
 
   // Instance of AsyncDataCache used for all large allocations.
   std::shared_ptr asyncDataCache_ =
diff --git a/cpp/velox/compute/VeloxExecutionCtx.h b/cpp/velox/compute/VeloxExecutionCtx.h
index 12eea5e9a9eb..46be6535d32a 100644
--- a/cpp/velox/compute/VeloxExecutionCtx.h
+++ b/cpp/velox/compute/VeloxExecutionCtx.h
@@ -30,6 +30,9 @@
 
 namespace gluten {
 
+// This kind string must be the same as VeloxBackend#name on the Java side.
+inline static const std::string kVeloxExecutionCtxKind{"velox"};
+
 class VeloxExecutionCtx final : public ExecutionCtx {
  public:
   explicit VeloxExecutionCtx(const std::unordered_map<std::string, std::string>& confMap);
diff --git a/cpp/velox/jni/VeloxJniWrapper.cc b/cpp/velox/jni/VeloxJniWrapper.cc
index f2fc20f77181..809e1c9f8439 100644
--- a/cpp/velox/jni/VeloxJniWrapper.cc
+++ b/cpp/velox/jni/VeloxJniWrapper.cc
@@ -34,11 +34,9 @@ using namespace facebook;
 
 namespace {
-
-gluten::ExecutionCtx* veloxExecutionCtxFactory(const std::unordered_map<std::string, std::string>& sparkConfs) {
-  return new gluten::VeloxExecutionCtx(sparkConfs);
+gluten::ExecutionCtx* veloxExecutionCtxFactory(const std::unordered_map<std::string, std::string>& sessionConf) {
+  return new gluten::VeloxExecutionCtx(sessionConf);
 }
-
 } // namespace
 
 #ifdef __cplusplus
@@ -77,11 +75,11 @@ void JNI_OnUnload(JavaVM* vm, void*) {
 JNIEXPORT void JNICALL Java_io_glutenproject_init_NativeBackendInitializer_initialize( // NOLINT
     JNIEnv* env,
     jclass,
-    jbyteArray planArray) {
+    jbyteArray conf) {
   JNI_METHOD_START
-  auto sparkConfs = gluten::getConfMap(env, planArray);
-  gluten::setExecutionCtxFactory(veloxExecutionCtxFactory, sparkConfs);
-  gluten::VeloxBackend::create(sparkConfs);
+  auto sparkConf = gluten::getConfMap(env, conf);
+  gluten::ExecutionCtx::registerFactory(gluten::kVeloxExecutionCtxKind, veloxExecutionCtxFactory);
+  gluten::VeloxBackend::create(sparkConf);
   JNI_METHOD_END()
 }
diff --git a/cpp/velox/tests/ExecutionCtxTest.cc b/cpp/velox/tests/ExecutionCtxTest.cc
index 48c047ae2b3f..8242135c545d 100644
--- a/cpp/velox/tests/ExecutionCtxTest.cc
+++ b/cpp/velox/tests/ExecutionCtxTest.cc
@@ -23,6 +23,8 @@ namespace gluten {
 
 class DummyExecutionCtx final : public ExecutionCtx {
  public:
+  DummyExecutionCtx(const std::unordered_map<std::string, std::string>& conf) : ExecutionCtx(conf) {}
+
   ResourceHandle createResultIterator(
       MemoryManager* memoryManager,
       const std::string& spillDir,
@@ -147,19 +149,19 @@ class DummyExecutionCtx final : public ExecutionCtx {
   };
 };
 
-static ExecutionCtx* DummyExecutionCtxFactory() {
-  return new DummyExecutionCtx();
+static ExecutionCtx* DummyExecutionCtxFactory(const std::unordered_map<std::string, std::string> conf) {
+  return new DummyExecutionCtx(conf);
 }
 
 TEST(TestExecutionCtx, CreateExecutionCtx) {
-  setExecutionCtxFactory(DummyExecutionCtxFactory);
-  auto executionCtx = createExecutionCtx();
+  ExecutionCtx::registerFactory("DUMMY", DummyExecutionCtxFactory);
+  auto executionCtx = ExecutionCtx::create("DUMMY");
   ASSERT_EQ(typeid(*executionCtx), typeid(DummyExecutionCtx));
-  releaseExecutionCtx(executionCtx);
+  ExecutionCtx::release(executionCtx);
 }
 
 TEST(TestExecutionCtx, GetResultIterator) {
-  auto executionCtx = std::make_shared<DummyExecutionCtx>();
+  auto executionCtx = std::make_shared<DummyExecutionCtx>(std::unordered_map<std::string, std::string>());
   auto handle = executionCtx->createResultIterator(nullptr, "/tmp/test-spill", {}, {});
   auto iter = executionCtx->getResultIterator(handle);
   ASSERT_TRUE(iter->hasNext());
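
For reference, the registration flow exercised by the test above is the same one the Velox backend now goes through at initialization. A minimal usage sketch against the new ExecutionCtx API follows; the surrounding function name and include paths are illustrative and not part of this patch:

#include <string>
#include <unordered_map>

#include "compute/VeloxBackend.h"
#include "compute/VeloxExecutionCtx.h"

// Illustrative only: register the backend factory once per process, then create and
// release one ExecutionCtx per session/task, passing the session conf through the factory.
void initAndRunOnce(const std::unordered_map<std::string, std::string>& sparkConf) {
  gluten::ExecutionCtx::registerFactory(
      gluten::kVeloxExecutionCtxKind,
      [](const std::unordered_map<std::string, std::string>& sessionConf) -> gluten::ExecutionCtx* {
        return new gluten::VeloxExecutionCtx(sessionConf);
      });
  gluten::VeloxBackend::create(sparkConf);

  auto* executionCtx = gluten::ExecutionCtx::create(gluten::kVeloxExecutionCtxKind, sparkConf);
  // The conf now travels with the context itself (the const confMap_ above) instead of
  // being re-sent through JNI with every kernel creation.
  const auto& conf = executionCtx->getConfMap();
  (void)conf;
  gluten::ExecutionCtx::release(executionCtx);
}
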
diff --git a/gluten-data/src/main/java/io/glutenproject/exec/ExecutionCtxJniWrapper.java b/gluten-data/src/main/java/io/glutenproject/exec/ExecutionCtxJniWrapper.java
index 5206014b72de..2b36a12a130c 100644
--- a/gluten-data/src/main/java/io/glutenproject/exec/ExecutionCtxJniWrapper.java
+++ b/gluten-data/src/main/java/io/glutenproject/exec/ExecutionCtxJniWrapper.java
@@ -20,7 +20,7 @@ public class ExecutionCtxJniWrapper {
 
   private ExecutionCtxJniWrapper() {}
 
-  public static native long createExecutionCtx();
+  public static native long createExecutionCtx(String backendType, byte[] sessionConf);
 
   public static native void releaseExecutionCtx(long handle);
 }
diff --git a/gluten-data/src/main/java/io/glutenproject/memory/nmm/NativeMemoryManager.java b/gluten-data/src/main/java/io/glutenproject/memory/nmm/NativeMemoryManager.java
index c7d0d8184563..33a7871d5d7b 100644
--- a/gluten-data/src/main/java/io/glutenproject/memory/nmm/NativeMemoryManager.java
+++ b/gluten-data/src/main/java/io/glutenproject/memory/nmm/NativeMemoryManager.java
@@ -17,6 +17,7 @@
 package io.glutenproject.memory.nmm;
 
 import io.glutenproject.GlutenConfig;
+import io.glutenproject.backendsapi.BackendsApiManager;
 import io.glutenproject.memory.alloc.NativeMemoryAllocators;
 
 import org.apache.spark.util.TaskResource;
@@ -43,7 +44,10 @@ public static NativeMemoryManager create(String name, ReservationListener listen
     long allocatorId = NativeMemoryAllocators.getDefault().globalInstance().getNativeInstanceId();
     long reservationBlockSize = GlutenConfig.getConf().memoryReservationBlockSize();
     return new NativeMemoryManager(
-        name, create(name, allocatorId, reservationBlockSize, listener), listener);
+        name,
+        create(
+            BackendsApiManager.getBackendName(), name, allocatorId, reservationBlockSize, listener),
+        listener);
   }
 
   public long getNativeInstanceHandle() {
@@ -61,7 +65,11 @@ public long shrink(long size) {
   private static native long shrink(long memoryManagerId, long size);
 
   private static native long create(
-      String name, long allocatorId, long reservationBlockSize, ReservationListener listener);
+      String backendType,
+      String name,
+      long allocatorId,
+      long reservationBlockSize,
+      ReservationListener listener);
 
   private static native void release(long memoryManagerId);
 
diff --git a/gluten-data/src/main/java/io/glutenproject/vectorized/NativePlanEvaluator.java b/gluten-data/src/main/java/io/glutenproject/vectorized/NativePlanEvaluator.java
index 9b008e9612a7..e8a3a79002b7 100644
--- a/gluten-data/src/main/java/io/glutenproject/vectorized/NativePlanEvaluator.java
+++ b/gluten-data/src/main/java/io/glutenproject/vectorized/NativePlanEvaluator.java
@@ -16,29 +16,19 @@
  */
 package io.glutenproject.vectorized;
 
-import io.glutenproject.GlutenConfig;
 import io.glutenproject.backendsapi.BackendsApiManager;
 import io.glutenproject.exec.ExecutionCtx;
 import io.glutenproject.exec.ExecutionCtxs;
 import io.glutenproject.memory.nmm.NativeMemoryManagers;
-import io.glutenproject.substrait.expression.ExpressionBuilder;
-import io.glutenproject.substrait.expression.StringMapNode;
-import io.glutenproject.substrait.extensions.AdvancedExtensionNode;
-import io.glutenproject.substrait.extensions.ExtensionBuilder;
-import io.glutenproject.substrait.plan.PlanBuilder;
-import io.glutenproject.substrait.plan.PlanNode;
 import io.glutenproject.utils.DebugUtil;
 import io.glutenproject.validate.NativePlanValidationInfo;
 
-import com.google.protobuf.Any;
 import io.substrait.proto.Plan;
 import org.apache.spark.TaskContext;
-import org.apache.spark.sql.internal.SQLConf;
 import org.apache.spark.util.SparkDirectoryUtil;
 
 import java.io.IOException;
 import java.util.List;
-import java.util.Map;
 import java.util.Optional;
 import java.util.UUID;
 import java.util.concurrent.atomic.AtomicReference;
@@ -64,13 +54,6 @@ public NativePlanValidationInfo doNativeValidateWithFailureReason(byte[] subPlan
     return jniWrapper.nativeValidateWithFailureReason(subPlan);
   }
 
-  private PlanNode buildNativeConfNode(Map<String, String> confs) {
-    StringMapNode stringMapNode = ExpressionBuilder.makeStringMap(confs);
-    AdvancedExtensionNode extensionNode =
-        ExtensionBuilder.makeAdvancedExtension(Any.pack(stringMapNode.toProtobuf()));
-    return PlanBuilder.makePlan(extensionNode);
-  }
-
   // Used by WholeStageTransform to create the native computing pipeline and
   // return a columnar result iterator.
   public GeneralOutIterator createKernelWithBatchIterator(
@@ -107,13 +90,7 @@ public GeneralOutIterator createKernelWithBatchIterator(
             TaskContext.getPartitionId(),
             TaskContext.get().taskAttemptId(),
             DebugUtil.saveInputToFile(),
-            BackendsApiManager.getSparkPlanExecApiInstance().rewriteSpillPath(spillDirPath),
-            buildNativeConfNode(
-                    GlutenConfig.getNativeSessionConf(
-                        BackendsApiManager.getSettings().getBackendConfigPrefix(),
-                        SQLConf.get().getAllConfs()))
-                .toProtobuf()
-                .toByteArray());
+            BackendsApiManager.getSparkPlanExecApiInstance().rewriteSpillPath(spillDirPath));
     outIterator.set(createOutIterator(ExecutionCtxs.contextInstance(), iterHandle));
     return outIterator.get();
   }
diff --git a/gluten-data/src/main/java/io/glutenproject/vectorized/PlanEvaluatorJniWrapper.java b/gluten-data/src/main/java/io/glutenproject/vectorized/PlanEvaluatorJniWrapper.java
index 765e2c988fd4..cabb86c4b38a 100644
--- a/gluten-data/src/main/java/io/glutenproject/vectorized/PlanEvaluatorJniWrapper.java
+++ b/gluten-data/src/main/java/io/glutenproject/vectorized/PlanEvaluatorJniWrapper.java
@@ -68,8 +68,7 @@ public native long nativeCreateKernelWithIterator(
       int partitionId,
       long taskId,
       boolean saveInputToFile,
-      String spillDir,
-      byte[] confPlan)
+      String spillDir)
       throws RuntimeException;
 
   /** Create a native compute kernel and return a row iterator. */
diff --git a/gluten-data/src/main/scala/io/glutenproject/exec/ExecutionCtx.scala b/gluten-data/src/main/scala/io/glutenproject/exec/ExecutionCtx.scala
index 02a144730642..4c61f5ed4924 100644
--- a/gluten-data/src/main/scala/io/glutenproject/exec/ExecutionCtx.scala
+++ b/gluten-data/src/main/scala/io/glutenproject/exec/ExecutionCtx.scala
@@ -16,11 +16,22 @@
  */
 package io.glutenproject.exec
 
+import io.glutenproject.GlutenConfig
+import io.glutenproject.backendsapi.BackendsApiManager
+import io.glutenproject.init.JniUtils
+
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.TaskResource
 
 class ExecutionCtx private[exec] () extends TaskResource {
 
-  private val handle = ExecutionCtxJniWrapper.createExecutionCtx()
+  private val handle = ExecutionCtxJniWrapper.createExecutionCtx(
+    BackendsApiManager.getBackendName,
+    JniUtils.toNativeConf(
+      GlutenConfig.getNativeSessionConf(
+        BackendsApiManager.getSettings.getBackendConfigPrefix,
+        SQLConf.get.getAllConfs))
+  )
 
   def getHandle: Long = handle
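
With the conf byte array removed from nativeCreateKernelWithIterator, native code that previously read the per-call conf can obtain the same settings from the context, since the Scala ExecutionCtx constructor above now feeds the session conf through ExecutionCtx::create. A small hedged sketch, using only the API introduced in this patch (the helper name is illustrative):

#include <string>
#include <unordered_map>

#include "compute/ExecutionCtx.h"

// Illustrative helper: read a session setting from the context rather than from a
// per-call conf byte array, which nativeCreateKernelWithIterator no longer receives.
std::string confOrDefault(gluten::ExecutionCtx* ctx, const std::string& key, const std::string& fallback) {
  // Populated from GlutenConfig.getNativeSessionConf on the JVM side via ExecutionCtx::create.
  const auto& conf = ctx->getConfMap();
  auto it = conf.find(key);
  return it == conf.end() ? fallback : it->second;
}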