Skip to content

Commit

Permalink
[GLUTEN-6279][CH] Introduce JNI safe array (#6280)
Browse files Browse the repository at this point in the history
* jni safe array

* return JString

* better
  • Loading branch information
baibaichen authored Jul 2, 2024
1 parent f9bd490 commit dde87b1
Show file tree
Hide file tree
Showing 10 changed files with 199 additions and 162 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.apache.gluten.vectorized;

import org.apache.gluten.metrics.IMetrics;
import org.apache.gluten.metrics.NativeMetrics;

import org.apache.spark.sql.execution.utils.CHExecUtil;
import org.apache.spark.sql.vectorized.ColumnVector;
Expand Down Expand Up @@ -50,7 +51,7 @@ public String getId() {

private native void nativeCancel(long nativeHandle);

private native IMetrics nativeFetchMetrics(long nativeHandle);
private native String nativeFetchMetrics(long nativeHandle);

@Override
public boolean hasNextInternal() throws IOException {
Expand All @@ -72,8 +73,8 @@ public ColumnarBatch nextInternal() throws IOException {
}

@Override
public IMetrics getMetricsInternal() throws IOException, ClassNotFoundException {
return nativeFetchMetrics(handle);
public IMetrics getMetricsInternal() {
return new NativeMetrics(nativeFetchMetrics(handle));
}

@Override
Expand Down
4 changes: 2 additions & 2 deletions cpp-ch/local-engine/Common/CHUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ String QueryPipelineUtil::explainPipeline(DB::QueryPipeline & pipeline)

using namespace DB;

std::map<std::string, std::string> BackendInitializerUtil::getBackendConfMap(const std::string & plan)
std::map<std::string, std::string> BackendInitializerUtil::getBackendConfMap(const std::string_view plan)
{
std::map<std::string, std::string> ch_backend_conf;
if (plan.empty())
Expand Down Expand Up @@ -972,7 +972,7 @@ void BackendInitializerUtil::init(const std::string & plan)
});
}

void BackendInitializerUtil::updateConfig(const DB::ContextMutablePtr & context, const std::string & plan)
void BackendInitializerUtil::updateConfig(const DB::ContextMutablePtr & context, const std::string_view plan)
{
std::map<std::string, std::string> backend_conf_map = getBackendConfMap(plan);

Expand Down
4 changes: 2 additions & 2 deletions cpp-ch/local-engine/Common/CHUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ class BackendInitializerUtil
/// 1. global level resources like global_context/shared_context, notice that they can only be initialized once in process lifetime
/// 2. session level resources like settings/configs, they can be initialized multiple times following the lifetime of executor/driver
static void init(const std::string & plan);
static void updateConfig(const DB::ContextMutablePtr &, const std::string &);
static void updateConfig(const DB::ContextMutablePtr &, const std::string_view);


// use excel text parser
Expand Down Expand Up @@ -199,7 +199,7 @@ class BackendInitializerUtil
static std::vector<String> wrapDiskPathConfig(const String & path_prefix, const String & path_suffix, Poco::Util::AbstractConfiguration & config);


static std::map<std::string, std::string> getBackendConfMap(const std::string & plan);
static std::map<std::string, std::string> getBackendConfMap(const std::string_view plan);

inline static std::once_flag init_flag;
inline static Poco::Logger * logger;
Expand Down
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Parser/RelMetric.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ const String & RelMetric::getName() const
return name;
}

std::string RelMetricSerializer::serializeRelMetric(RelMetricPtr rel_metric, bool flatten)
std::string RelMetricSerializer::serializeRelMetric(const RelMetricPtr & rel_metric, bool flatten)
{
StringBuffer result;
Writer<StringBuffer> writer(result);
Expand Down
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Parser/RelMetric.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,6 @@ class RelMetric
class RelMetricSerializer
{
public:
static std::string serializeRelMetric(RelMetricPtr rel_metric, bool flatten = true);
static std::string serializeRelMetric(const RelMetricPtr & rel_metric, bool flatten = true);
};
}
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1742,7 +1742,7 @@ std::unique_ptr<LocalExecutor> SerializedPlanParser::createExecutor(DB::QueryPla
context, std::move(query_plan), std::move(pipeline), query_plan->getCurrentDataStream().header.cloneEmpty());
}

QueryPlanPtr SerializedPlanParser::parse(const std::string_view & plan)
QueryPlanPtr SerializedPlanParser::parse(const std::string_view plan)
{
substrait::Plan s_plan;
/// https://stackoverflow.com/questions/52028583/getting-error-parsing-protobuf-data
Expand Down
6 changes: 3 additions & 3 deletions cpp-ch/local-engine/Parser/SerializedPlanParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ class SerializedPlanParser

std::unique_ptr<LocalExecutor> createExecutor(DB::QueryPlanPtr query_plan);

DB::QueryPlanPtr parse(const std::string_view & plan);
DB::QueryPlanPtr parse(const std::string_view plan);
DB::QueryPlanPtr parse(const substrait::Plan & plan);

public:
Expand All @@ -270,7 +270,7 @@ class SerializedPlanParser
///

template <bool JsonPlan>
std::unique_ptr<LocalExecutor> createExecutor(const std::string_view & plan);
std::unique_ptr<LocalExecutor> createExecutor(const std::string_view plan);

DB::QueryPlanStepPtr parseReadRealWithLocalFile(const substrait::ReadRel & rel);
DB::QueryPlanStepPtr parseReadRealWithJavaIter(const substrait::ReadRel & rel);
Expand Down Expand Up @@ -407,7 +407,7 @@ class SerializedPlanParser
};

template <bool JsonPlan>
std::unique_ptr<LocalExecutor> SerializedPlanParser::createExecutor(const std::string_view & plan)
std::unique_ptr<LocalExecutor> SerializedPlanParser::createExecutor(const std::string_view plan)
{
return createExecutor(JsonPlan ? parseJson(plan) : parse(plan));
}
Expand Down
14 changes: 7 additions & 7 deletions cpp-ch/local-engine/jni/jni_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,13 @@ jmethodID GetStaticMethodID(JNIEnv * env, jclass this_class, const char * name,

jstring charTojstring(JNIEnv * env, const char * pat)
{
jclass str_class = (env)->FindClass("Ljava/lang/String;");
jmethodID ctor_id = (env)->GetMethodID(str_class, "<init>", "([BLjava/lang/String;)V");
jsize strSize = static_cast<jsize>(strlen(pat));
jbyteArray bytes = (env)->NewByteArray(strSize);
(env)->SetByteArrayRegion(bytes, 0, strSize, reinterpret_cast<jbyte *>(const_cast<char *>(pat)));
jstring encoding = (env)->NewStringUTF("UTF-8");
jstring result = static_cast<jstring>((env)->NewObject(str_class, ctor_id, bytes, encoding));
const jclass str_class = (env)->FindClass("Ljava/lang/String;");
const jmethodID ctor_id = (env)->GetMethodID(str_class, "<init>", "([BLjava/lang/String;)V");
const jsize str_size = static_cast<jsize>(strlen(pat));
const jbyteArray bytes = (env)->NewByteArray(str_size);
(env)->SetByteArrayRegion(bytes, 0, str_size, reinterpret_cast<jbyte *>(const_cast<char *>(pat)));
const jstring encoding = (env)->NewStringUTF("UTF-8");
const auto result = static_cast<jstring>((env)->NewObject(str_class, ctor_id, bytes, encoding));
env->DeleteLocalRef(bytes);
env->DeleteLocalRef(encoding);
return result;
Expand Down
99 changes: 99 additions & 0 deletions cpp-ch/local-engine/jni/jni_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,4 +141,103 @@ jlong safeCallStaticLongMethod(JNIEnv * env, jclass clazz, jmethodID method_id,
LOCAL_ENGINE_JNI_JMETHOD_END(env)
return ret;
}

// Safe version of the JNI {Get|Release}<PrimitiveType>ArrayElements routines.
// SafeNativeArray releases the managed array elements automatically
// on destruction.

/// Tag identifying a Java primitive array kind. It is used as the non-type
/// template argument that selects a JniPrimitiveArray specialization and a
/// SafeNativeArray instantiation.
///
/// Enumerators are numbered consecutively starting at 0.
enum class JniPrimitiveArrayType {
  kBoolean = 0,  // 0
  kByte,         // 1
  kChar,         // 2
  kShort,        // 3
  kInt,          // 4
  kLong,         // 5
  kFloat,        // 6
  kDouble        // 7
};

// Pastes three preprocessor tokens into a single identifier, e.g.
// CONCATENATE(Get, Byte, ArrayElements) -> GetByteArrayElements.
// Used below to build JNI method names from a type-specific middle part.
#define CONCATENATE(t1, t2, t3) t1##t2##t3

// Generates the JniPrimitiveArray specialization for one Java primitive array kind.
//
//   PRIM_TYPE       - enumerator of JniPrimitiveArrayType being specialized
//   JAVA_TYPE       - JNI array handle type (e.g. jbyteArray)
//   JNI_NATIVE_TYPE - element pointer type returned by JNI (e.g. jbyte*)
//   NATIVE_TYPE     - element pointer type exposed to C++ callers
//   METHOD_VAR      - middle part of the JNI method names (e.g. Byte, giving
//                     GetByteArrayElements / ReleaseByteArrayElements)
//
// The generated struct exposes:
//   get()     - calls Get<METHOD_VAR>ArrayElements with a null isCopy argument.
//   release() - calls Release<METHOD_VAR>ArrayElements with mode JNI_ABORT.
//
// The macro must be expanded after the primary JniPrimitiveArray template has
// been declared, because it emits an explicit specialization of it.
#define DEFINE_PRIMITIVE_ARRAY(PRIM_TYPE, JAVA_TYPE, JNI_NATIVE_TYPE, NATIVE_TYPE, METHOD_VAR) \
template <> \
struct JniPrimitiveArray<JniPrimitiveArrayType::PRIM_TYPE> { \
using JavaType = JAVA_TYPE; \
using JniNativeType = JNI_NATIVE_TYPE; \
using NativeType = NATIVE_TYPE; \
\
static JniNativeType get(JNIEnv* env, JavaType javaArray) { \
return env->CONCATENATE(Get, METHOD_VAR, ArrayElements)(javaArray, nullptr); \
} \
\
static void release(JNIEnv* env, JavaType javaArray, JniNativeType nativeArray) { \
env->CONCATENATE(Release, METHOD_VAR, ArrayElements)(javaArray, nativeArray, JNI_ABORT); \
} \
};

/// Traits type mapping a JniPrimitiveArrayType tag to the matching JNI array
/// handle type, JNI element pointer type, C++ element pointer type, and the
/// get/release JNI calls.
///
/// The primary template is intentionally empty; only the explicit
/// specializations generated by DEFINE_PRIMITIVE_ARRAY below are usable.
template <JniPrimitiveArrayType TYPE>
struct JniPrimitiveArray {};

// Native element types are spelled as plain `float` / `double` rather than
// `float_t` / `double_t`: the latter are only guaranteed to be *at least* as
// wide as float/double (their exact type depends on FLT_EVAL_METHOD), so
// reinterpreting a jfloat*/jdouble* buffer through them is not portable.
DEFINE_PRIMITIVE_ARRAY(kBoolean, jbooleanArray, jboolean*, bool*, Boolean)
DEFINE_PRIMITIVE_ARRAY(kByte, jbyteArray, jbyte*, uint8_t*, Byte)
DEFINE_PRIMITIVE_ARRAY(kChar, jcharArray, jchar*, uint16_t*, Char)
DEFINE_PRIMITIVE_ARRAY(kShort, jshortArray, jshort*, int16_t*, Short)
DEFINE_PRIMITIVE_ARRAY(kInt, jintArray, jint*, int32_t*, Int)
DEFINE_PRIMITIVE_ARRAY(kLong, jlongArray, jlong*, int64_t*, Long)
DEFINE_PRIMITIVE_ARRAY(kFloat, jfloatArray, jfloat*, float*, Float)
DEFINE_PRIMITIVE_ARRAY(kDouble, jdoubleArray, jdouble*, double*, Double)

/// Scope guard around the element buffer of a Java primitive array.
///
/// An instance is created through SafeNativeArray::get(), which acquires the
/// buffer via JniPrimitiveArray<TYPE>::get(); the destructor hands it back via
/// JniPrimitiveArray<TYPE>::release(). The object is neither copyable nor
/// movable, so the buffer is released exactly once.
///
/// @tparam TYPE  which Java primitive array kind is wrapped.
template <JniPrimitiveArrayType TYPE>
class SafeNativeArray {
  using PrimitiveArray = JniPrimitiveArray<TYPE>;
  using JavaArrayType = typename PrimitiveArray::JavaType;
  using JniNativeArrayType = typename PrimitiveArray::JniNativeType;
  using NativeArrayType = typename PrimitiveArray::NativeType;

 public:
  virtual ~SafeNativeArray() {
    // Get<Type>ArrayElements yields nullptr when it fails; in that case no
    // buffer was handed out and there is nothing to give back to the JVM.
    if (nativeArray_ != nullptr) {
      PrimitiveArray::release(env_, javaArray_, nativeArray_);
    }
  }

  SafeNativeArray(const SafeNativeArray&) = delete;
  SafeNativeArray(SafeNativeArray&&) = delete;
  SafeNativeArray& operator=(const SafeNativeArray&) = delete;
  SafeNativeArray& operator=(SafeNativeArray&&) = delete;

  /// Pointer to the first element, viewed as the C++ native element type.
  /// May be nullptr if acquiring the elements failed. Valid only while this
  /// object is alive.
  const NativeArrayType elems() const {
    return reinterpret_cast<const NativeArrayType>(nativeArray_);
  }

  /// Number of elements in the wrapped Java array (queried from the JVM on
  /// every call).
  const jsize length() const {
    return env_->GetArrayLength(javaArray_);
  }

  /// Acquires the elements of `javaArray` and wraps them.
  ///
  /// Copy and move are deleted, so returning by value relies on C++17
  /// guaranteed copy elision of the returned prvalue.
  ///
  /// @param env        JNI environment of the calling thread.
  /// @param javaArray  Java array whose elements are acquired.
  static SafeNativeArray<TYPE> get(JNIEnv* env, JavaArrayType javaArray) {
    JniNativeArrayType nativeArray = PrimitiveArray::get(env, javaArray);
    return SafeNativeArray<TYPE>(env, javaArray, nativeArray);
  }

 private:
  SafeNativeArray(JNIEnv* env, JavaArrayType javaArray, JniNativeArrayType nativeArray)
      : env_(env), javaArray_(javaArray), nativeArray_(nativeArray) {}

  JNIEnv* env_;
  JavaArrayType javaArray_;
  JniNativeArrayType nativeArray_;
};

// Generates an inline free function get<METHOD_VAR>ArrayElementsSafe(env, array)
// that forwards to SafeNativeArray<JniPrimitiveArrayType::PRIM_TYPE>::get().
//
//   PRIM_TYPE  - enumerator of JniPrimitiveArrayType
//   JAVA_TYPE  - JNI array handle type accepted by the generated function
//   METHOD_VAR - middle part of the generated function name (e.g. Byte gives
//                getByteArrayElementsSafe)
#define DEFINE_SAFE_GET_PRIMITIVE_ARRAY_FUNCTIONS(PRIM_TYPE, JAVA_TYPE, METHOD_VAR) \
inline SafeNativeArray<JniPrimitiveArrayType::PRIM_TYPE> CONCATENATE(get, METHOD_VAR, ArrayElementsSafe)( \
JNIEnv * env, JAVA_TYPE array) { \
return SafeNativeArray<JniPrimitiveArrayType::PRIM_TYPE>::get(env, array); \
}

// One helper per Java primitive array kind:
// getBooleanArrayElementsSafe, getByteArrayElementsSafe, ..., getDoubleArrayElementsSafe.
DEFINE_SAFE_GET_PRIMITIVE_ARRAY_FUNCTIONS(kBoolean, jbooleanArray, Boolean)
DEFINE_SAFE_GET_PRIMITIVE_ARRAY_FUNCTIONS(kByte, jbyteArray, Byte)
DEFINE_SAFE_GET_PRIMITIVE_ARRAY_FUNCTIONS(kChar, jcharArray, Char)
DEFINE_SAFE_GET_PRIMITIVE_ARRAY_FUNCTIONS(kShort, jshortArray, Short)
DEFINE_SAFE_GET_PRIMITIVE_ARRAY_FUNCTIONS(kInt, jintArray, Int)
DEFINE_SAFE_GET_PRIMITIVE_ARRAY_FUNCTIONS(kLong, jlongArray, Long)
DEFINE_SAFE_GET_PRIMITIVE_ARRAY_FUNCTIONS(kFloat, jfloatArray, Float)
DEFINE_SAFE_GET_PRIMITIVE_ARRAY_FUNCTIONS(kDouble, jdoubleArray, Double)
}
Loading

0 comments on commit dde87b1

Please sign in to comment.