Skip to content

Commit

Permalink
[VL] Move session config to ExecutionCtx
Browse files Browse the repository at this point in the history
  • Loading branch information
Yohahaha authored Oct 12, 2023
1 parent 537c096 commit dfea063
Show file tree
Hide file tree
Showing 17 changed files with 131 additions and 180 deletions.
52 changes: 37 additions & 15 deletions cpp/core/compute/ExecutionCtx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,28 +20,50 @@

namespace gluten {

static ExecutionCtxFactoryContext* getExecutionCtxFactoryContext() {
static ExecutionCtxFactoryContext* executionCtxFactoryCtx = new ExecutionCtxFactoryContext;
return executionCtxFactoryCtx;
}
namespace {
class FactoryRegistry {
public:
void registerFactory(const std::string& kind, ExecutionCtx::Factory factory) {
std::lock_guard<std::mutex> l(mutex_);
GLUTEN_CHECK(map_.find(kind) == map_.end(), "ExecutionCtx factory already registered for " + kind);
map_[kind] = std::move(factory);
}

ExecutionCtx::Factory& getFactory(const std::string& kind) {
std::lock_guard<std::mutex> l(mutex_);
GLUTEN_CHECK(map_.find(kind) != map_.end(), "ExecutionCtx factory not registered for " + kind);
return map_[kind];
}

bool unregisterFactory(const std::string& kind) {
std::lock_guard<std::mutex> l(mutex_);
GLUTEN_CHECK(map_.find(kind) != map_.end(), "ExecutionCtx factory not registered for " + kind);
return map_.erase(kind);
}

private:
std::mutex mutex_;
std::unordered_map<std::string, ExecutionCtx::Factory> map_;
};

void setExecutionCtxFactory(
ExecutionCtxFactoryWithConf factory,
const std::unordered_map<std::string, std::string>& sparkConfs) {
getExecutionCtxFactoryContext()->set(factory, sparkConfs);
DEBUG_OUT << "Set execution context factory with conf." << std::endl;
FactoryRegistry& executionCtxFactories() {
static FactoryRegistry registry;
return registry;
}
} // namespace

void setExecutionCtxFactory(ExecutionCtxFactory factory) {
getExecutionCtxFactoryContext()->set(factory);
DEBUG_OUT << "Set execution context factory." << std::endl;
void ExecutionCtx::registerFactory(const std::string& kind, ExecutionCtx::Factory factory) {
executionCtxFactories().registerFactory(kind, std::move(factory));
}

ExecutionCtx* createExecutionCtx() {
return getExecutionCtxFactoryContext()->create();
ExecutionCtx* ExecutionCtx::create(
const std::string& kind,
const std::unordered_map<std::string, std::string>& sessionConf) {
auto& factory = executionCtxFactories().getFactory(kind);
return factory(sessionConf);
}

void releaseExecutionCtx(ExecutionCtx* executionCtx) {
void ExecutionCtx::release(ExecutionCtx* executionCtx) {
delete executionCtx;
}

Expand Down
84 changes: 9 additions & 75 deletions cpp/core/compute/ExecutionCtx.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,13 @@ struct SparkTaskInfo {
/// ExecutionCtx is stateful and manager all kinds of native resources' lifecycle during execute a computation fragment.
class ExecutionCtx : public std::enable_shared_from_this<ExecutionCtx> {
public:
using Factory = std::function<ExecutionCtx*(const std::unordered_map<std::string, std::string>&)>;
static void registerFactory(const std::string& kind, Factory factory);
static ExecutionCtx* create(
const std::string& kind,
const std::unordered_map<std::string, std::string>& sessionConf = {});
static void release(ExecutionCtx*);

ExecutionCtx() = default;
ExecutionCtx(const std::unordered_map<std::string, std::string>& confMap) : confMap_(confMap) {}
virtual ~ExecutionCtx() = default;
Expand Down Expand Up @@ -156,7 +163,7 @@ class ExecutionCtx : public std::enable_shared_from_this<ExecutionCtx> {
throw GlutenException("Not implement getNonPartitionedColumnarBatch");
}

std::unordered_map<std::string, std::string> getConfMap() {
const std::unordered_map<std::string, std::string>& getConfMap() {
return confMap_;
}

Expand All @@ -168,79 +175,6 @@ class ExecutionCtx : public std::enable_shared_from_this<ExecutionCtx> {
::substrait::Plan substraitPlan_;
SparkTaskInfo taskInfo_;
// static conf map
std::unordered_map<std::string, std::string> confMap_;
const std::unordered_map<std::string, std::string> confMap_;
};

using ExecutionCtxFactoryWithConf = ExecutionCtx* (*)(const std::unordered_map<std::string, std::string>&);
using ExecutionCtxFactory = ExecutionCtx* (*)();

struct ExecutionCtxFactoryContext {
std::mutex mutex;

enum {
kExecutionCtxFactoryInvalid,
kExecutionCtxFactoryDefault,
kExecutionCtxFactoryWithConf
} type = kExecutionCtxFactoryInvalid;

union {
ExecutionCtxFactoryWithConf backendFactoryWithConf;
ExecutionCtxFactory backendFactory;
};

std::unordered_map<std::string, std::string> sparkConf_;

void set(ExecutionCtxFactoryWithConf factory, const std::unordered_map<std::string, std::string>& sparkConf) {
std::lock_guard<std::mutex> lockGuard(mutex);

if (type != kExecutionCtxFactoryInvalid) {
assert(false);
abort();
return;
}

type = kExecutionCtxFactoryWithConf;
backendFactoryWithConf = factory;
this->sparkConf_.clear();
for (auto& x : sparkConf) {
this->sparkConf_[x.first] = x.second;
}
}

void set(ExecutionCtxFactory factory) {
std::lock_guard<std::mutex> lockGuard(mutex);
if (type != kExecutionCtxFactoryInvalid) {
assert(false);
abort();
return;
}

type = kExecutionCtxFactoryDefault;
backendFactory = factory;
}

ExecutionCtx* create() {
std::lock_guard<std::mutex> lockGuard(mutex);
if (type == kExecutionCtxFactoryInvalid) {
assert(false);
abort();
return nullptr;
} else if (type == kExecutionCtxFactoryWithConf) {
return backendFactoryWithConf(sparkConf_);
} else {
return backendFactory();
}
}
};

void setExecutionCtxFactory(
ExecutionCtxFactoryWithConf factory,
const std::unordered_map<std::string, std::string>& sparkConf);

void setExecutionCtxFactory(ExecutionCtxFactory factory);

ExecutionCtx* createExecutionCtx();

void releaseExecutionCtx(ExecutionCtx*);

} // namespace gluten
40 changes: 22 additions & 18 deletions cpp/core/jni/JniWrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -314,9 +314,13 @@ void JNI_OnUnload(JavaVM* vm, void* reserved) {

JNIEXPORT jlong JNICALL Java_io_glutenproject_exec_ExecutionCtxJniWrapper_createExecutionCtx( // NOLINT
JNIEnv* env,
jclass) {
jclass,
jstring jbackendType,
jbyteArray sessionConf) {
JNI_METHOD_START
auto executionCtx = gluten::createExecutionCtx();
auto backendType = jStringToCString(env, jbackendType);
auto sparkConf = gluten::getConfMap(env, sessionConf);
auto executionCtx = gluten::ExecutionCtx::create(backendType, sparkConf);
return reinterpret_cast<jlong>(executionCtx);
JNI_METHOD_END(kInvalidResourceHandle)
}
Expand All @@ -328,7 +332,7 @@ JNIEXPORT void JNICALL Java_io_glutenproject_exec_ExecutionCtxJniWrapper_release
JNI_METHOD_START
auto executionCtx = jniCastOrThrow<ExecutionCtx>(ctxHandle);

gluten::releaseExecutionCtx(executionCtx);
gluten::ExecutionCtx::release(executionCtx);
JNI_METHOD_END()
}

Expand All @@ -343,8 +347,7 @@ Java_io_glutenproject_vectorized_PlanEvaluatorJniWrapper_nativeCreateKernelWithI
jint partitionId,
jlong taskId,
jboolean saveInput,
jstring spillDir,
jbyteArray confArr) {
jstring spillDir) {
JNI_METHOD_START

auto ctx = gluten::getExecutionCtx(env, wrapper);
Expand All @@ -356,21 +359,20 @@ Java_io_glutenproject_vectorized_PlanEvaluatorJniWrapper_nativeCreateKernelWithI
auto planSize = env->GetArrayLength(planArr);

ctx->parsePlan(planData, planSize, {stageId, partitionId, taskId});

auto confs = getConfMap(env, confArr);
auto& conf = ctx->getConfMap();

// Handle the Java iters
jsize itersLen = env->GetArrayLength(iterArr);
std::vector<std::shared_ptr<ResultIterator>> inputIters;
for (int idx = 0; idx < itersLen; idx++) {
std::shared_ptr<ArrowWriter> writer = nullptr;
if (saveInput) {
auto dir = confs[kGlutenSaveDir];
auto dir = conf.at(kGlutenSaveDir);
std::filesystem::path f{dir};
if (!std::filesystem::exists(f)) {
throw gluten::GlutenException("Save input path " + dir + " does not exists");
}
auto file = confs[kGlutenSaveDir] + "/input_" + std::to_string(taskId) + "_" + std::to_string(idx) + "_" +
auto file = conf.at(kGlutenSaveDir) + "/input_" + std::to_string(taskId) + "_" + std::to_string(idx) + "_" +
std::to_string(partitionId) + ".parquet";
writer = std::make_shared<ArrowWriter>(file);
}
Expand All @@ -380,7 +382,7 @@ Java_io_glutenproject_vectorized_PlanEvaluatorJniWrapper_nativeCreateKernelWithI
inputIters.push_back(std::move(resultIter));
}

return ctx->createResultIterator(memoryManager, spillDirStr, inputIters, confs);
return ctx->createResultIterator(memoryManager, spillDirStr, inputIters, conf);
JNI_METHOD_END(kInvalidResourceHandle)
}

Expand Down Expand Up @@ -1066,13 +1068,13 @@ JNIEXPORT jlong JNICALL Java_io_glutenproject_datasource_velox_DatasourceJniWrap
// Only inspect the schema and not write
handle = ctx->createDatasource(jStringToCString(env, filePath), memoryManager, nullptr);
} else {
auto sparkOptions = gluten::getConfMap(env, options);
auto sparkConf = ctx->getConfMap();
sparkOptions.insert(sparkConf.begin(), sparkConf.end());
auto datasourceOptions = gluten::getConfMap(env, options);
auto& sparkConf = ctx->getConfMap();
datasourceOptions.insert(sparkConf.begin(), sparkConf.end());
auto schema = gluten::arrowGetOrThrow(arrow::ImportSchema(reinterpret_cast<struct ArrowSchema*>(cSchema)));
handle = ctx->createDatasource(jStringToCString(env, filePath), memoryManager, schema);
auto datasource = ctx->getDatasource(handle);
datasource->init(sparkOptions);
datasource->init(datasourceOptions);
}

return handle;
Expand Down Expand Up @@ -1214,7 +1216,8 @@ JNIEXPORT jlong JNICALL Java_io_glutenproject_memory_alloc_NativeMemoryAllocator
JNIEXPORT jlong JNICALL Java_io_glutenproject_memory_nmm_NativeMemoryManager_create( // NOLINT
JNIEnv* env,
jclass,
jstring jname,
jstring jbackendType,
jstring jnmmName,
jlong allocatorId,
jlong reservationBlockSize,
jobject jlistener) {
Expand All @@ -1235,11 +1238,12 @@ JNIEXPORT jlong JNICALL Java_io_glutenproject_memory_nmm_NativeMemoryManager_cre
listener = std::make_unique<BacktraceAllocationListener>(std::move(listener));
}

auto name = jStringToCString(env, jname);
auto name = jStringToCString(env, jnmmName);
auto backendType = jStringToCString(env, jbackendType);
// TODO: move memory manager into ExecutionCtx then we can use more general ExecutionCtx.
auto executionCtx = gluten::createExecutionCtx();
auto executionCtx = gluten::ExecutionCtx::create(backendType);
auto manager = executionCtx->createMemoryManager(name, *allocator, std::move(listener));
gluten::releaseExecutionCtx(executionCtx);
gluten::ExecutionCtx::release(executionCtx);
return reinterpret_cast<jlong>(manager);
JNI_METHOD_END(kInvalidResourceHandle)
}
Expand Down
6 changes: 3 additions & 3 deletions cpp/velox/benchmarks/BenchmarkUtils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@ namespace {

std::unordered_map<std::string, std::string> bmConfMap = {{gluten::kSparkBatchSize, FLAGS_batch_size}};

gluten::ExecutionCtx* veloxExecutionCtxFactory(const std::unordered_map<std::string, std::string>& sparkConfs) {
return new gluten::VeloxExecutionCtx(sparkConfs);
gluten::ExecutionCtx* veloxExecutionCtxFactory(const std::unordered_map<std::string, std::string>& sparkConf) {
return new gluten::VeloxExecutionCtx(sparkConf);
}

} // anonymous namespace

void initVeloxBackend(std::unordered_map<std::string, std::string>& conf) {
gluten::setExecutionCtxFactory(veloxExecutionCtxFactory, conf);
gluten::ExecutionCtx::registerFactory(gluten::kVeloxExecutionCtxKind, veloxExecutionCtxFactory);
gluten::VeloxBackend::create(conf);
}

Expand Down
4 changes: 2 additions & 2 deletions cpp/velox/benchmarks/GenericBenchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ auto BM_Generic = [](::benchmark::State& state,
setCpu(state.thread_index());
}
auto memoryManager = getDefaultMemoryManager();
auto executionCtx = gluten::createExecutionCtx();
auto executionCtx = ExecutionCtx::create(kVeloxExecutionCtxKind);
const auto& filePath = getExampleFilePath(substraitJsonFile);
auto plan = getPlanFromFile(filePath);
auto startTime = std::chrono::steady_clock::now();
Expand Down Expand Up @@ -206,7 +206,7 @@ auto BM_Generic = [](::benchmark::State& state,
auto statsStr = facebook::velox::exec::printPlanWithStats(*planNode, task->taskStats(), true);
std::cout << statsStr << std::endl;
}
gluten::releaseExecutionCtx(executionCtx);
ExecutionCtx::release(executionCtx);

auto endTime = std::chrono::steady_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::nanoseconds>(endTime - startTime).count();
Expand Down
4 changes: 2 additions & 2 deletions cpp/velox/benchmarks/ParquetWriteBenchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ class GoogleBenchmarkVeloxParquetWriteCacheScanBenchmark : public GoogleBenchmar
// reuse the ParquetWriteConverter for batches caused system % increase a lot
auto fileName = "velox_parquet_write.parquet";

auto executionCtx = gluten::createExecutionCtx();
auto executionCtx = ExecutionCtx::create(kVeloxExecutionCtxKind);
auto memoryManager = getDefaultMemoryManager();
auto veloxPool = memoryManager->getAggregateMemoryPool();

Expand Down Expand Up @@ -292,7 +292,7 @@ class GoogleBenchmarkVeloxParquetWriteCacheScanBenchmark : public GoogleBenchmar
benchmark::Counter(initTime, benchmark::Counter::kAvgThreads, benchmark::Counter::OneK::kIs1000);
state.counters["write_time"] =
benchmark::Counter(writeTime, benchmark::Counter::kAvgThreads, benchmark::Counter::OneK::kIs1000);
gluten::releaseExecutionCtx(executionCtx);
ExecutionCtx::release(executionCtx);
}
};

Expand Down
4 changes: 2 additions & 2 deletions cpp/velox/benchmarks/QueryBenchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ auto BM = [](::benchmark::State& state,
auto plan = getPlanFromFile(filePath);

auto memoryManager = getDefaultMemoryManager();
auto executionCtx = gluten::createExecutionCtx();
auto executionCtx = ExecutionCtx::create(kVeloxExecutionCtxKind);
auto veloxPool = memoryManager->getAggregateMemoryPool();

std::vector<std::shared_ptr<SplitInfo>> scanInfos;
Expand Down Expand Up @@ -109,7 +109,7 @@ auto BM = [](::benchmark::State& state,
std::cout << maybeBatch.ValueOrDie()->ToString() << std::endl;
}
}
gluten::releaseExecutionCtx(executionCtx);
ExecutionCtx::release(executionCtx);
};

#define orc_reader_decimal 1
Expand Down
Loading

0 comments on commit dfea063

Please sign in to comment.