Skip to content

Commit

Permalink
improve: catch java exception in c++ (#143)
Browse files Browse the repository at this point in the history
  • Loading branch information
lgbo-ustc authored Oct 10, 2022
1 parent d6a6cc9 commit 1eed312
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 11 deletions.
5 changes: 3 additions & 2 deletions utils/local-engine/Parser/SparkRowToCHColumn.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <Parser/CHColumnToSparkRow.h>
#include <base/StringRef.h>
#include <Common/JNIUtils.h>
#include <jni/jni_common.h>

namespace local_engine
{
Expand Down Expand Up @@ -142,9 +143,9 @@ class SparkRowToCHColumn

int attached;
JNIEnv * env = JNIUtils::getENV(&attached);
while (env->CallBooleanMethod(java_iter, spark_row_interator_hasNext))
while (safeCallBooleanMethod(env, java_iter, spark_row_interator_hasNext))
{
jobject rows_buf = env->CallObjectMethod(java_iter, spark_row_iterator_nextBatch);
jobject rows_buf = safeCallObjectMethod(env, java_iter, spark_row_iterator_nextBatch);
auto * rows_buf_ptr = static_cast<char*>(env->GetDirectBufferAddress(rows_buf));
int len = *(reinterpret_cast<int*>(rows_buf_ptr));

Expand Down
7 changes: 4 additions & 3 deletions utils/local-engine/Shuffle/NativeSplitter.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
#include "NativeSplitter.h"
#include <Functions/FunctionFactory.h>
#include <Parser/SerializedPlanParser.h>
#include "Common/Exception.h"
#include <Common/JNIUtils.h>

#include <jni/jni_common.h>

namespace local_engine
{
Expand Down Expand Up @@ -120,7 +121,7 @@ bool NativeSplitter::inputHasNext()
{
int attached;
JNIEnv * env = JNIUtils::getENV(&attached);
bool next = env->CallBooleanMethod(input, iterator_has_next);
bool next = safeCallBooleanMethod(env, input, iterator_has_next);
if (attached)
{
JNIUtils::detachCurrentThread();
Expand All @@ -132,7 +133,7 @@ int64_t NativeSplitter::inputNext()
{
int attached;
JNIEnv * env = JNIUtils::getENV(&attached);
int64_t result = env->CallLongMethod(input, iterator_next);
int64_t result = safeCallLongMethod(env, input, iterator_next);
if (attached)
{
JNIUtils::detachCurrentThread();
Expand Down
3 changes: 2 additions & 1 deletion utils/local-engine/Shuffle/ShuffleReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include <Common/DebugUtils.h>
#include <Common/Stopwatch.h>
#include <Common/JNIUtils.h>
#include <jni/jni_common.h>

using namespace DB;

Expand Down Expand Up @@ -57,7 +58,7 @@ int ReadBufferFromJavaInputStream::readFromJava()
buf = static_cast<jbyteArray>(env->NewGlobalRef(local_buf));
env->DeleteLocalRef(local_buf);
}
jint count = env->CallIntMethod(java_in, ShuffleReader::input_stream_read, buf);
jint count = safeCallIntMethod(env, java_in, ShuffleReader::input_stream_read, buf);
if (count > 0)
{
env->GetByteArrayRegion(buf, 0, count, reinterpret_cast<jbyte *>(internal_buffer.begin()));
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "WriteBufferFromJavaOutputStream.h"
#include <Common/JNIUtils.h>
#include <jni/jni_common.h>

namespace local_engine
{
Expand All @@ -17,7 +18,7 @@ void WriteBufferFromJavaOutputStream::nextImpl()
{
size_t copy_num = std::min(offset() - bytes_write, buffer_size);
env->SetByteArrayRegion(buffer, 0 , copy_num, reinterpret_cast<const jbyte *>(this->working_buffer.begin() + bytes_write));
env->CallVoidMethod(output_stream, output_stream_write, buffer, 0, copy_num);
safeCallVoidMethod(env, output_stream, output_stream_write, buffer, 0, copy_num);
bytes_write += copy_num;
}
if (attached)
Expand All @@ -42,7 +43,7 @@ void WriteBufferFromJavaOutputStream::finalizeImpl()
next();
int attached;
JNIEnv * env = JNIUtils::getENV(&attached);
env->CallVoidMethod(output_stream, output_stream_flush);
safeCallVoidMethod(env, output_stream, output_stream_flush);
if (attached)
{
JNIUtils::detachCurrentThread();
Expand Down
5 changes: 3 additions & 2 deletions utils/local-engine/Storages/SourceFromJavaIter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <Common/DebugUtils.h>
#include <Common/JNIUtils.h>
#include <Columns/ColumnNullable.h>
#include <jni/jni_common.h>

namespace local_engine
{
Expand All @@ -14,11 +15,11 @@ DB::Chunk SourceFromJavaIter::generate()
{
int attached;
JNIEnv * env = JNIUtils::getENV(&attached);
jboolean has_next = env->CallBooleanMethod(java_iter,serialized_record_batch_iterator_hasNext);
jboolean has_next = safeCallBooleanMethod(env, java_iter,serialized_record_batch_iterator_hasNext);
DB::Chunk result;
if (has_next)
{
jbyteArray block = static_cast<jbyteArray>(env->CallObjectMethod(java_iter, serialized_record_batch_iterator_next));
jbyteArray block = static_cast<jbyteArray>(safeCallObjectMethod(env, java_iter, serialized_record_batch_iterator_next));
DB::Block * data = reinterpret_cast<DB::Block *>(byteArrayToLong(env, block));
if (data->rows() > 0)
{
Expand Down
63 changes: 63 additions & 0 deletions utils/local-engine/jni/jni_common.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,16 @@
#pragma once
#include <exception>
#include <stdexcept>
#include <jni.h>
#include <Common/Exception.h>

namespace DB
{
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
}
}

namespace local_engine
{
Expand All @@ -13,4 +24,56 @@ jmethodID GetStaticMethodID(JNIEnv * env, jclass this_class, const char * name,

jstring charTojstring(JNIEnv* env, const char* pat);

#define LOCAL_ENGINE_JNI_JMETHOD_START
#define LOCAL_ENGINE_JNI_JMETHOD_END(env) \
if ((env)->ExceptionCheck())\
{\
(env)->ExceptionDescribe();\
(env)->ExceptionClear();\
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Call java method failed");\
}

template <typename ... Args>
jobject safeCallObjectMethod(JNIEnv * env, jobject obj, jmethodID method_id, Args ... args)
{
LOCAL_ENGINE_JNI_JMETHOD_START
auto ret = env->CallObjectMethod(obj, method_id, args...);
LOCAL_ENGINE_JNI_JMETHOD_END(env)
return ret;
}

template <typename ... Args>
jboolean safeCallBooleanMethod(JNIEnv * env, jobject obj, jmethodID method_id, Args ... args)
{
LOCAL_ENGINE_JNI_JMETHOD_START
auto ret = env->CallBooleanMethod(obj, method_id, args...);
LOCAL_ENGINE_JNI_JMETHOD_END(env);
return ret;
}

template <typename ... Args>
jlong safeCallLongMethod(JNIEnv * env, jobject obj, jmethodID method_id, Args ... args)
{
LOCAL_ENGINE_JNI_JMETHOD_START
auto ret = env->CallLongMethod(obj, method_id, args...);
LOCAL_ENGINE_JNI_JMETHOD_END(env);
return ret;
}

template <typename ... Args>
jint safeCallIntMethod(JNIEnv * env, jobject obj, jmethodID method_id, Args ... args)
{
LOCAL_ENGINE_JNI_JMETHOD_START
auto ret = env->CallIntMethod(obj, method_id, args...);
LOCAL_ENGINE_JNI_JMETHOD_END(env);
return ret;
}

template <typename ... Args>
void safeCallVoidMethod(JNIEnv * env, jobject obj, jmethodID method_id, Args ... args)
{
LOCAL_ENGINE_JNI_JMETHOD_START
env->CallVoidMethod(obj, method_id, args...);
LOCAL_ENGINE_JNI_JMETHOD_END(env);
}
}
2 changes: 1 addition & 1 deletion utils/local-engine/local_engine_jni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ std::string jstring2string(JNIEnv * env, jstring jStr)

jclass string_class = env->GetObjectClass(jStr);
jmethodID get_bytes = env->GetMethodID(string_class, "getBytes", "(Ljava/lang/String;)[B");
jbyteArray string_jbytes = static_cast<jbyteArray>(env->CallObjectMethod(jStr, get_bytes, env->NewStringUTF("UTF-8")));
jbyteArray string_jbytes = static_cast<jbyteArray>(local_engine::safeCallObjectMethod(env, jStr, get_bytes, env->NewStringUTF("UTF-8")));

size_t length = static_cast<size_t>(env->GetArrayLength(string_jbytes));
jbyte * p_bytes = env->GetByteArrayElements(string_jbytes, nullptr);
Expand Down

0 comments on commit 1eed312

Please sign in to comment.