diff --git a/utils/local-engine/Parser/SparkRowToCHColumn.h b/utils/local-engine/Parser/SparkRowToCHColumn.h index 49330e1b0813..c303bdedf7a8 100644 --- a/utils/local-engine/Parser/SparkRowToCHColumn.h +++ b/utils/local-engine/Parser/SparkRowToCHColumn.h @@ -8,6 +8,7 @@ #include #include #include +#include namespace local_engine { @@ -142,9 +143,9 @@ class SparkRowToCHColumn int attached; JNIEnv * env = JNIUtils::getENV(&attached); - while (env->CallBooleanMethod(java_iter, spark_row_interator_hasNext)) + while (safeCallBooleanMethod(env, java_iter, spark_row_interator_hasNext)) { - jobject rows_buf = env->CallObjectMethod(java_iter, spark_row_iterator_nextBatch); + jobject rows_buf = safeCallObjectMethod(env, java_iter, spark_row_iterator_nextBatch); auto * rows_buf_ptr = static_cast(env->GetDirectBufferAddress(rows_buf)); int len = *(reinterpret_cast(rows_buf_ptr)); diff --git a/utils/local-engine/Shuffle/NativeSplitter.cpp b/utils/local-engine/Shuffle/NativeSplitter.cpp index c3bd8a158513..cb912af8093f 100644 --- a/utils/local-engine/Shuffle/NativeSplitter.cpp +++ b/utils/local-engine/Shuffle/NativeSplitter.cpp @@ -1,8 +1,9 @@ #include "NativeSplitter.h" #include #include +#include "Common/Exception.h" #include - +#include namespace local_engine { @@ -120,7 +121,7 @@ bool NativeSplitter::inputHasNext() { int attached; JNIEnv * env = JNIUtils::getENV(&attached); - bool next = env->CallBooleanMethod(input, iterator_has_next); + bool next = safeCallBooleanMethod(env, input, iterator_has_next); if (attached) { JNIUtils::detachCurrentThread(); @@ -132,7 +133,7 @@ int64_t NativeSplitter::inputNext() { int attached; JNIEnv * env = JNIUtils::getENV(&attached); - int64_t result = env->CallLongMethod(input, iterator_next); + int64_t result = safeCallLongMethod(env, input, iterator_next); if (attached) { JNIUtils::detachCurrentThread(); diff --git a/utils/local-engine/Shuffle/ShuffleReader.cpp b/utils/local-engine/Shuffle/ShuffleReader.cpp index ec42f53b6a92..a9d696eb6286 100644 --- a/utils/local-engine/Shuffle/ShuffleReader.cpp +++ b/utils/local-engine/Shuffle/ShuffleReader.cpp @@ -2,6 +2,7 @@ #include #include #include +#include using namespace DB; @@ -57,7 +58,7 @@ int ReadBufferFromJavaInputStream::readFromJava() buf = static_cast(env->NewGlobalRef(local_buf)); env->DeleteLocalRef(local_buf); } - jint count = env->CallIntMethod(java_in, ShuffleReader::input_stream_read, buf); + jint count = safeCallIntMethod(env, java_in, ShuffleReader::input_stream_read, buf); if (count > 0) { env->GetByteArrayRegion(buf, 0, count, reinterpret_cast(internal_buffer.begin())); diff --git a/utils/local-engine/Shuffle/WriteBufferFromJavaOutputStream.cpp b/utils/local-engine/Shuffle/WriteBufferFromJavaOutputStream.cpp index e7ba216b3fc4..e012ff957e25 100644 --- a/utils/local-engine/Shuffle/WriteBufferFromJavaOutputStream.cpp +++ b/utils/local-engine/Shuffle/WriteBufferFromJavaOutputStream.cpp @@ -1,5 +1,6 @@ #include "WriteBufferFromJavaOutputStream.h" #include +#include namespace local_engine { @@ -17,7 +18,7 @@ void WriteBufferFromJavaOutputStream::nextImpl() { size_t copy_num = std::min(offset() - bytes_write, buffer_size); env->SetByteArrayRegion(buffer, 0 , copy_num, reinterpret_cast(this->working_buffer.begin() + bytes_write)); - env->CallVoidMethod(output_stream, output_stream_write, buffer, 0, copy_num); + safeCallVoidMethod(env, output_stream, output_stream_write, buffer, 0, copy_num); bytes_write += copy_num; } if (attached) @@ -42,7 +43,7 @@ void WriteBufferFromJavaOutputStream::finalizeImpl() next(); int attached; JNIEnv * env = JNIUtils::getENV(&attached); - env->CallVoidMethod(output_stream, output_stream_flush); + safeCallVoidMethod(env, output_stream, output_stream_flush); if (attached) { JNIUtils::detachCurrentThread(); diff --git a/utils/local-engine/Storages/SourceFromJavaIter.cpp b/utils/local-engine/Storages/SourceFromJavaIter.cpp index d27c2cb5deaa..da086f10c94a 100644 --- a/utils/local-engine/Storages/SourceFromJavaIter.cpp +++ b/utils/local-engine/Storages/SourceFromJavaIter.cpp @@ -3,6 +3,7 @@ #include #include #include +#include namespace local_engine { @@ -14,11 +15,11 @@ DB::Chunk SourceFromJavaIter::generate() { int attached; JNIEnv * env = JNIUtils::getENV(&attached); - jboolean has_next = env->CallBooleanMethod(java_iter,serialized_record_batch_iterator_hasNext); + jboolean has_next = safeCallBooleanMethod(env, java_iter,serialized_record_batch_iterator_hasNext); DB::Chunk result; if (has_next) { - jbyteArray block = static_cast(env->CallObjectMethod(java_iter, serialized_record_batch_iterator_next)); + jbyteArray block = static_cast(safeCallObjectMethod(env, java_iter, serialized_record_batch_iterator_next)); DB::Block * data = reinterpret_cast(byteArrayToLong(env, block)); if (data->rows() > 0) { diff --git a/utils/local-engine/jni/jni_common.h b/utils/local-engine/jni/jni_common.h index 2a7b74f625b9..02d92cd154ca 100644 --- a/utils/local-engine/jni/jni_common.h +++ b/utils/local-engine/jni/jni_common.h @@ -1,5 +1,16 @@ #pragma once +#include +#include #include +#include + +namespace DB +{ +namespace ErrorCodes +{ + extern const int LOGICAL_ERROR; +} +} namespace local_engine { @@ -13,4 +24,56 @@ jmethodID GetStaticMethodID(JNIEnv * env, jclass this_class, const char * name, jstring charTojstring(JNIEnv* env, const char* pat); +#define LOCAL_ENGINE_JNI_JMETHOD_START +#define LOCAL_ENGINE_JNI_JMETHOD_END(env) \ + if ((env)->ExceptionCheck())\ + {\ + (env)->ExceptionDescribe();\ + (env)->ExceptionClear();\ + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Call java method failed");\ + } + +template +jobject safeCallObjectMethod(JNIEnv * env, jobject obj, jmethodID method_id, Args ... args) +{ + LOCAL_ENGINE_JNI_JMETHOD_START + auto ret = env->CallObjectMethod(obj, method_id, args...); + LOCAL_ENGINE_JNI_JMETHOD_END(env) + return ret; +} + +template +jboolean safeCallBooleanMethod(JNIEnv * env, jobject obj, jmethodID method_id, Args ... args) +{ + LOCAL_ENGINE_JNI_JMETHOD_START + auto ret = env->CallBooleanMethod(obj, method_id, args...); + LOCAL_ENGINE_JNI_JMETHOD_END(env); + return ret; +} + +template +jlong safeCallLongMethod(JNIEnv * env, jobject obj, jmethodID method_id, Args ... args) +{ + LOCAL_ENGINE_JNI_JMETHOD_START + auto ret = env->CallLongMethod(obj, method_id, args...); + LOCAL_ENGINE_JNI_JMETHOD_END(env); + return ret; +} + +template +jint safeCallIntMethod(JNIEnv * env, jobject obj, jmethodID method_id, Args ... args) +{ + LOCAL_ENGINE_JNI_JMETHOD_START + auto ret = env->CallIntMethod(obj, method_id, args...); + LOCAL_ENGINE_JNI_JMETHOD_END(env); + return ret; +} + +template +void safeCallVoidMethod(JNIEnv * env, jobject obj, jmethodID method_id, Args ... args) +{ + LOCAL_ENGINE_JNI_JMETHOD_START + env->CallVoidMethod(obj, method_id, args...); + LOCAL_ENGINE_JNI_JMETHOD_END(env); +} } diff --git a/utils/local-engine/local_engine_jni.cpp b/utils/local-engine/local_engine_jni.cpp index 0f112e188f7f..461970fb0860 100644 --- a/utils/local-engine/local_engine_jni.cpp +++ b/utils/local-engine/local_engine_jni.cpp @@ -58,7 +58,7 @@ std::string jstring2string(JNIEnv * env, jstring jStr) jclass string_class = env->GetObjectClass(jStr); jmethodID get_bytes = env->GetMethodID(string_class, "getBytes", "(Ljava/lang/String;)[B"); - jbyteArray string_jbytes = static_cast(env->CallObjectMethod(jStr, get_bytes, env->NewStringUTF("UTF-8"))); + jbyteArray string_jbytes = static_cast(local_engine::safeCallObjectMethod(env, jStr, get_bytes, env->NewStringUTF("UTF-8"))); size_t length = static_cast(env->GetArrayLength(string_jbytes)); jbyte * p_bytes = env->GetByteArrayElements(string_jbytes, nullptr);