From 3d268328eb4b0b064a9560ce77eae7a36a8cd817 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Wed, 8 Jan 2025 13:00:36 -0800
Subject: [PATCH] Return error code and file error message in C API for both
 PredictSync and PredictAsync

PiperOrigin-RevId: 713388507
---
 .../genai/inference/c/llm_inference_engine.h  |  8 ++--
 .../inference/c/llm_inference_engine_cpu.cc   | 47 ++++++++++++++-----
 .../c/llm_inference_engine_cpu_main.cc        | 12 +++--
 .../core/sources/GenAiInferenceError.swift    | 14 +++++-
 .../genai/core/sources/LlmSessionRunner.swift | 18 +++++--
 .../sources/LlmInference+Session.swift        |  2 +-
 .../google/mediapipe/tasks/core/jni/llm.cc    | 21 +++++++--
 7 files changed, 94 insertions(+), 28 deletions(-)

diff --git a/mediapipe/tasks/cc/genai/inference/c/llm_inference_engine.h b/mediapipe/tasks/cc/genai/inference/c/llm_inference_engine.h
index 91f9a08c70..11040b3477 100644
--- a/mediapipe/tasks/cc/genai/inference/c/llm_inference_engine.h
+++ b/mediapipe/tasks/cc/genai/inference/c/llm_inference_engine.h
@@ -186,8 +186,9 @@ ODML_EXPORT int LlmInferenceEngine_Session_AddImage(
 
 // Return the generated output based on the previously added query chunks in
 // sync mode.
-ODML_EXPORT LlmResponseContext
-LlmInferenceEngine_Session_PredictSync(LlmInferenceEngine_Session* session);
+ODML_EXPORT int LlmInferenceEngine_Session_PredictSync(
+    LlmInferenceEngine_Session* session, LlmResponseContext* response_context,
+    char** error_msg);
 
 // Run callback function in async mode.
 // The callback will be invoked multiple times until `response_context.done`
@@ -195,8 +196,9 @@ LlmInferenceEngine_Session_PredictSync(LlmInferenceEngine_Session* session);
 // each invocation to free memory.
 // The callback context can be a pointer to any user defined data structure as
 // it is passed to the callback unmodified.
-ODML_EXPORT void LlmInferenceEngine_Session_PredictAsync(
+ODML_EXPORT int LlmInferenceEngine_Session_PredictAsync(
     LlmInferenceEngine_Session* session, void* callback_context,
+    char** error_msg,
     void (*callback)(void* callback_context,
                      LlmResponseContext* response_context));
 
diff --git a/mediapipe/tasks/cc/genai/inference/c/llm_inference_engine_cpu.cc b/mediapipe/tasks/cc/genai/inference/c/llm_inference_engine_cpu.cc
index 7c1e5bd233..6939f55b54 100644
--- a/mediapipe/tasks/cc/genai/inference/c/llm_inference_engine_cpu.cc
+++ b/mediapipe/tasks/cc/genai/inference/c/llm_inference_engine_cpu.cc
@@ -591,11 +591,15 @@ ODML_EXPORT int LlmInferenceEngine_Session_AddImage(
   return 12;
 }
 
-LlmResponseContext LlmInferenceEngine_Session_PredictSync(
-    LlmInferenceEngine_Session* session) {
-  LlmInferenceEngine_Session_PredictAsync(
-      session, nullptr,
+int LlmInferenceEngine_Session_PredictSync(LlmInferenceEngine_Session* session,
+                                           LlmResponseContext* response_context,
+                                           char** error_msg) {
+  auto status = LlmInferenceEngine_Session_PredictAsync(
+      session, nullptr, error_msg,
       [](void* callback_context, LlmResponseContext* response_context) {});
+  if (status != 0) {
+    return status;
+  }
 
   auto cpu_session = reinterpret_cast<LlmInferenceEngineCpu_Session*>(session);
   pthread_join(cpu_session->work_id, nullptr);
@@ -604,31 +608,46 @@ LlmResponseContext LlmInferenceEngine_Session_PredictSync(
 
   char** result = (char**)malloc(sizeof(char*) * 1);
   if (result == nullptr) {
-    ABSL_LOG(FATAL) << "Failed to allocate result for cpu session.";
+    *error_msg = strdup("Failed to allocate result for cpu session.");
+    return static_cast<int>(absl::StatusCode::kResourceExhausted);
   }
 
   result[0] = (char*)malloc(sizeof(char*) * (final_output.size() + 1));
   if (result[0] == nullptr) {
-    ABSL_LOG(FATAL) << "Failed to allocate result for cpu session.";
+    *error_msg = strdup("Failed to allocate result for cpu session.");
+    return static_cast<int>(absl::StatusCode::kResourceExhausted);
   }
   snprintf(result[0], final_output.size() + 1, "%s", final_output.c_str());
 
-  LlmResponseContext response_context = {
-      .response_array = result,
-      .response_count = 1,
-      .done = true,
-  };
+  response_context->response_array = result;
+  response_context->response_count = 1;
+  response_context->done = true;
 
-  return response_context;
+  return 0;
 }
 
-void LlmInferenceEngine_Session_PredictAsync(
+int LlmInferenceEngine_Session_PredictAsync(
     LlmInferenceEngine_Session* session, void* callback_context,
+    char** error_msg,
     void (*callback)(void* callback_context,
                      LlmResponseContext* response_context)) {
+  if (session == nullptr) {
+    *error_msg = strdup("Session is null.");
+    return static_cast<int>(absl::StatusCode::kInvalidArgument);
+  }
+  if (callback == nullptr) {
+    *error_msg = strdup("Callback is null.");
+    return static_cast<int>(absl::StatusCode::kInvalidArgument);
+  }
+
   auto cpu_session = reinterpret_cast<LlmInferenceEngineCpu_Session*>(session);
+  if (cpu_session == nullptr) {
+    *error_msg = strdup("Provided session is not a CPU session.");
+    return static_cast<int>(absl::StatusCode::kInvalidArgument);
+  }
+
   cpu_session->cpu_callback = [=](std::string responses) -> void {
     char** result = (char**)malloc(sizeof(char*) * 1);
     if (result == nullptr) {
@@ -656,6 +675,8 @@ void LlmInferenceEngine_Session_PredictAsync(
   cpu_session->work_id = work_id;
   pthread_create(&cpu_session->work_id, nullptr, start_llm_function,
                  cpu_session);
+
+  return 0;
 }
 
 int LlmInferenceEngine_Session_Clone(
diff --git a/mediapipe/tasks/cc/genai/inference/c/llm_inference_engine_cpu_main.cc b/mediapipe/tasks/cc/genai/inference/c/llm_inference_engine_cpu_main.cc
index 3dc269d101..7825a8fe3a 100644
--- a/mediapipe/tasks/cc/genai/inference/c/llm_inference_engine_cpu_main.cc
+++ b/mediapipe/tasks/cc/genai/inference/c/llm_inference_engine_cpu_main.cc
@@ -162,9 +162,15 @@ int main(int argc, char** argv) {
   // ABSL_LOG(INFO) << "Number of tokens for input prompt: " << num_tokens;
 
   ABSL_LOG(INFO) << "PredictAsync";
-  LlmInferenceEngine_Session_PredictAsync(llm_engine_session,
-                                          /*callback_context=*/nullptr,
-                                          async_callback_print);
+  error_code = LlmInferenceEngine_Session_PredictAsync(
+      llm_engine_session,
+      /*callback_context=*/nullptr, &error_msg, async_callback_print);
+  if (error_code) {
+    ABSL_LOG(ERROR) << "Failed to predict asyncously: "
+                    << std::string(error_msg);
+    free(error_msg);
+    return EXIT_FAILURE;
+  }
 
   // Optional to use the following for the sync version.
   // ABSL_LOG(INFO) << "PredictSync";
diff --git a/mediapipe/tasks/ios/genai/core/sources/GenAiInferenceError.swift b/mediapipe/tasks/ios/genai/core/sources/GenAiInferenceError.swift
index db6e535392..804182c3a7 100644
--- a/mediapipe/tasks/ios/genai/core/sources/GenAiInferenceError.swift
+++ b/mediapipe/tasks/ios/genai/core/sources/GenAiInferenceError.swift
@@ -22,6 +22,8 @@ public enum GenAiInferenceError: Error {
   case failedToInitializeSession(String?)
   case failedToInitializeEngine(String?)
   case failedToAddQueryToSession(String, String?)
+  case failedToPredictSync(String?)
+  case failedToPredictAsync(String?)
   case failedToCloneSession(String?)
 }
 
@@ -50,6 +52,12 @@ extension GenAiInferenceError: LocalizedError {
     case .failedToAddQueryToSession(let query, let message):
       let explanation = message.flatMap { $0 } ?? "An internal error occurred."
       return "Failed to add query: \(query) to LlmInference session: \(explanation)"
+    case .failedToPredictSync(let message):
+      let explanation = message.flatMap { $0 } ?? "An internal error occurred."
+      return "Failed to predict sync: \(explanation)"
+    case .failedToPredictAsync(let message):
+      let explanation = message.flatMap { $0 } ?? "An internal error occurred."
+      return "Failed to predict async: \(explanation)"
     case .failedToCloneSession(let message):
       let explanation = message.flatMap { $0 } ?? "An internal error occurred."
       return "Failed to clone LlmInference session: \(explanation)"
@@ -77,8 +85,12 @@ extension GenAiInferenceError: CustomNSError {
       return 4
     case .failedToAddQueryToSession:
       return 5
-    case .failedToCloneSession:
+    case .failedToPredictSync:
       return 6
+    case .failedToPredictAsync:
+      return 7
+    case .failedToCloneSession:
+      return 8
     }
   }
 }
diff --git a/mediapipe/tasks/ios/genai/core/sources/LlmSessionRunner.swift b/mediapipe/tasks/ios/genai/core/sources/LlmSessionRunner.swift
index 62bf6cde6b..5e45bbb67e 100644
--- a/mediapipe/tasks/ios/genai/core/sources/LlmSessionRunner.swift
+++ b/mediapipe/tasks/ios/genai/core/sources/LlmSessionRunner.swift
@@ -57,10 +57,17 @@ final class LlmSessionRunner {
   /// - Returns: Array of `String` responses from the LLM.
   /// - Throws: An error if the LLM's response is invalid.
   func predict() throws -> [String] {
+    var cErrorMessage: UnsafeMutablePointer<CChar>? = nil
     /// No safe guards for the call since the C++ APIs only throw fatal errors.
     /// `LlmInferenceEngine_Session_PredictSync()` will always return a `LlmResponseContext` if the
    /// call completes.
-    var responseContext = LlmInferenceEngine_Session_PredictSync(cLlmSession)
+    var responseContext = LlmResponseContext()
+    guard
+      LlmInferenceEngine_Session_PredictSync(cLlmSession, &responseContext, &cErrorMessage)
+        == StatusCode.success.rawValue
+    else {
+      throw GenAiInferenceError.failedToPredictSync(String(allocatedCErrorMessage: cErrorMessage))
+    }
 
     defer {
       withUnsafeMutablePointer(to: &responseContext) {
@@ -88,11 +95,12 @@ final class LlmSessionRunner {
   func predictAsync(
     progress: @escaping (_ partialResult: [String]?, _ error: Error?) -> Void,
     completion: @escaping (() -> Void)
-  ) {
+  ) throws {
+    var cErrorMessage: UnsafeMutablePointer<CChar>? = nil
     let callbackInfo = CallbackInfo(progress: progress, completion: completion)
     let callbackContext = UnsafeMutableRawPointer(Unmanaged.passRetained(callbackInfo).toOpaque())
-    LlmInferenceEngine_Session_PredictAsync(cLlmSession, callbackContext) {
+    let errorCode = LlmInferenceEngine_Session_PredictAsync(cLlmSession, callbackContext, &cErrorMessage) {
       context, responseContext in
       guard let cContext = context else {
         return
       }
@@ -135,6 +143,10 @@ final class LlmSessionRunner {
         cCallbackInfo.completion()
       }
     }
+
+    guard errorCode == StatusCode.success.rawValue else {
+      throw GenAiInferenceError.failedToPredictAsync(String(allocatedCErrorMessage: cErrorMessage))
+    }
   }
 
   /// Invokes the C LLM session to tokenize an input prompt using a pre-existing processor and
diff --git a/mediapipe/tasks/ios/genai/inference/sources/LlmInference+Session.swift b/mediapipe/tasks/ios/genai/inference/sources/LlmInference+Session.swift
index a01b80e20c..01ced05ed6 100644
--- a/mediapipe/tasks/ios/genai/inference/sources/LlmInference+Session.swift
+++ b/mediapipe/tasks/ios/genai/inference/sources/LlmInference+Session.swift
@@ -172,7 +172,7 @@ extension LlmInference {
       /// Used to make a decision about whitespace stripping.
       var receivedFirstToken = true
 
-      llmSessionRunner.predictAsync(
+      try llmSessionRunner.predictAsync(
         progress: { partialResponseStrings, error in
           guard let responseStrings = partialResponseStrings,
             let humanReadableLlmResponse = Session.humanReadableString(
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/llm.cc b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/llm.cc
index 89ba4d3563..bf30df1cb4 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/llm.cc
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/llm.cc
@@ -249,8 +249,15 @@ JNIEXPORT void JNICALL JNI_METHOD(nativeAddImage)(JNIEnv* env, jclass thiz,
 
 JNIEXPORT jbyteArray JNICALL
 JNI_METHOD(nativePredictSync)(JNIEnv* env, jclass thiz, jlong session_handle) {
-  LlmResponseContext response_context = LlmInferenceEngine_Session_PredictSync(
-      reinterpret_cast<void*>(session_handle));
+  char* error_msg = nullptr;
+  LlmResponseContext response_context;
+  int error_code = LlmInferenceEngine_Session_PredictSync(
+      reinterpret_cast<void*>(session_handle), &response_context, &error_msg);
+  if (error_code) {
+    ThrowIfError(env, absl::InternalError(absl::StrCat(
+                          "Failed to predict sync: %s", error_msg)));
+    free(error_msg);
+  }
   const jbyteArray response_bytes = ToByteArray(env, response_context);
   LlmInferenceEngine_CloseResponseContext(&response_context);
   return response_bytes;
@@ -276,9 +283,15 @@ JNIEXPORT void JNICALL JNI_METHOD(nativeRemoveCallback)(JNIEnv* env,
 JNIEXPORT void JNICALL JNI_METHOD(nativePredictAsync)(JNIEnv* env, jclass thiz,
                                                       jlong session_handle,
                                                       jobject callback_ref) {
-  LlmInferenceEngine_Session_PredictAsync(
+  char* error_msg = nullptr;
+  int error_code = LlmInferenceEngine_Session_PredictAsync(
       reinterpret_cast<void*>(session_handle),
-      reinterpret_cast<void*>(callback_ref), &ProcessAsyncResponse);
+      reinterpret_cast<void*>(callback_ref), &error_msg, &ProcessAsyncResponse);
+  if (error_code) {
+    ThrowIfError(env, absl::InternalError(absl::StrCat(
+                          "Failed to predict async: %s", error_msg)));
+    free(error_msg);
+  }
 }
 
 JNIEXPORT jint JNICALL JNI_METHOD(nativeSizeInTokens)(JNIEnv* env, jclass thiz,
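
Below is a minimal, hypothetical caller-side sketch (not part of the patch) of how C code might consume the updated API: it checks the int status returned by LlmInferenceEngine_Session_PredictSync / LlmInferenceEngine_Session_PredictAsync and frees the heap-allocated error_msg on failure. The `session` value and the engine/session creation steps are assumed to come from the existing setup APIs and are elided here; the helper names RunPredictSync, RunPredictAsync, and OnAsyncResponse are illustrative only.

// Caller-side sketch (illustrative, not part of the patch). Assumes a valid
// `session` obtained from the existing LlmInferenceEngine creation calls.
#include <stdio.h>
#include <stdlib.h>

#include "mediapipe/tasks/cc/genai/inference/c/llm_inference_engine.h"

// Async callback: a real caller would read response_context->response_array
// here; per the header, LlmInferenceEngine_CloseResponseContext must be
// called after each invocation to free memory.
static void OnAsyncResponse(void* callback_context,
                            LlmResponseContext* response_context) {
  (void)callback_context;
  LlmInferenceEngine_CloseResponseContext(response_context);
}

static int RunPredictSync(LlmInferenceEngine_Session* session) {
  LlmResponseContext response_context;
  char* error_msg = NULL;
  int error_code = LlmInferenceEngine_Session_PredictSync(
      session, &response_context, &error_msg);
  if (error_code != 0) {
    // On failure the engine fills `error_msg` via strdup(); the caller owns
    // it and must free it.
    fprintf(stderr, "PredictSync failed (%d): %s\n", error_code,
            error_msg ? error_msg : "unknown error");
    free(error_msg);
    return error_code;
  }
  // On success the response context is filled in place; read the responses,
  // then close it to release the memory.
  for (int i = 0; i < response_context.response_count; ++i) {
    printf("%s\n", response_context.response_array[i]);
  }
  LlmInferenceEngine_CloseResponseContext(&response_context);
  return 0;
}

static int RunPredictAsync(LlmInferenceEngine_Session* session) {
  char* error_msg = NULL;
  int error_code = LlmInferenceEngine_Session_PredictAsync(
      session, /*callback_context=*/NULL, &error_msg, &OnAsyncResponse);
  if (error_code != 0) {
    fprintf(stderr, "PredictAsync failed (%d): %s\n", error_code,
            error_msg ? error_msg : "unknown error");
    free(error_msg);
  }
  return error_code;
}

As in the updated llm_inference_engine_cpu_main.cc and llm.cc call sites above, a zero return means success, non-zero values correspond to absl::StatusCode values, and error_msg is only populated on failure.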