Return error code and fill error message in C API for both PredictSync and PredictAsync

PiperOrigin-RevId: 713388507
MediaPipe Team authored and copybara-github committed Jan 8, 2025
1 parent 32d28e7 commit 3d26832
Showing 7 changed files with 94 additions and 28 deletions.
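
For orientation, a minimal sketch (not part of this commit) of how a C++ caller might use the new PredictSync signature shown in the header diff below. The wrapper name RunPredictSync is illustrative, and `session` is assumed to be a valid session with query chunks already added:

#include <cstdio>
#include <cstdlib>

#include "mediapipe/tasks/cc/genai/inference/c/llm_inference_engine.h"

// Sketch: run a blocking prediction with the new signature and surface errors.
int RunPredictSync(LlmInferenceEngine_Session* session) {
  LlmResponseContext response_context;
  char* error_msg = nullptr;
  int error_code = LlmInferenceEngine_Session_PredictSync(
      session, &response_context, &error_msg);
  if (error_code != 0) {
    // error_msg is strdup-allocated by the engine; the caller frees it.
    fprintf(stderr, "PredictSync failed (%d): %s\n", error_code,
            error_msg ? error_msg : "unknown error");
    free(error_msg);
    return error_code;
  }
  for (int i = 0; i < response_context.response_count; ++i) {
    printf("%s", response_context.response_array[i]);
  }
  LlmInferenceEngine_CloseResponseContext(&response_context);
  return 0;
}
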
8 changes: 5 additions & 3 deletions mediapipe/tasks/cc/genai/inference/c/llm_inference_engine.h
@@ -186,17 +186,19 @@ ODML_EXPORT int LlmInferenceEngine_Session_AddImage(

// Return the generated output based on the previously added query chunks in
// sync mode.
ODML_EXPORT LlmResponseContext
LlmInferenceEngine_Session_PredictSync(LlmInferenceEngine_Session* session);
ODML_EXPORT int LlmInferenceEngine_Session_PredictSync(
LlmInferenceEngine_Session* session, LlmResponseContext* response_context,
char** error_msg);

// Run callback function in async mode.
// The callback will be invoked multiple times until `response_context.done`
// is `true`. You need to invoke `LlmInferenceEngine_CloseResponseContext` after
// each invocation to free memory.
// The callback context can be a pointer to any user defined data structure as
// it is passed to the callback unmodified.
ODML_EXPORT void LlmInferenceEngine_Session_PredictAsync(
ODML_EXPORT int LlmInferenceEngine_Session_PredictAsync(
LlmInferenceEngine_Session* session, void* callback_context,
char** error_msg,
void (*callback)(void* callback_context,
LlmResponseContext* response_context));

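Based on the new declaration and its doc comment above, a hedged sketch of the async calling convention: the callback is invoked repeatedly until `response_context->done` is true, and each response context must be closed to free memory. OnAsyncResponse and StartPredictAsync are illustrative names, and session setup is elided:

#include <cstdio>
#include <cstdlib>

#include "mediapipe/tasks/cc/genai/inference/c/llm_inference_engine.h"

// Illustrative callback: prints each partial response, then releases it as the
// header comment requires.
static void OnAsyncResponse(void* /*callback_context*/,
                            LlmResponseContext* response_context) {
  for (int i = 0; i < response_context->response_count; ++i) {
    printf("%s", response_context->response_array[i]);
  }
  LlmInferenceEngine_CloseResponseContext(response_context);
}

// Sketch: start an async prediction and check the new error code.
int StartPredictAsync(LlmInferenceEngine_Session* session) {
  char* error_msg = nullptr;
  int error_code = LlmInferenceEngine_Session_PredictAsync(
      session, /*callback_context=*/nullptr, &error_msg, OnAsyncResponse);
  if (error_code != 0) {
    fprintf(stderr, "PredictAsync failed (%d): %s\n", error_code,
            error_msg ? error_msg : "unknown error");
    free(error_msg);
  }
  return error_code;
}
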
47 changes: 34 additions & 13 deletions mediapipe/tasks/cc/genai/inference/c/llm_inference_engine_cpu.cc
@@ -591,11 +591,15 @@ ODML_EXPORT int LlmInferenceEngine_Session_AddImage(
return 12;
}

LlmResponseContext LlmInferenceEngine_Session_PredictSync(
LlmInferenceEngine_Session* session) {
LlmInferenceEngine_Session_PredictAsync(
session, nullptr,
int LlmInferenceEngine_Session_PredictSync(LlmInferenceEngine_Session* session,
LlmResponseContext* response_context,
char** error_msg) {
auto status = LlmInferenceEngine_Session_PredictAsync(
session, nullptr, error_msg,
[](void* callback_context, LlmResponseContext* response_context) {});
if (status != 0) {
return status;
}

auto cpu_session = reinterpret_cast<LlmInferenceEngineCpu_Session*>(session);
pthread_join(cpu_session->work_id, nullptr);
@@ -604,31 +608,46 @@ LlmResponseContext LlmInferenceEngine_Session_PredictSync(

char** result = (char**)malloc(sizeof(char*) * 1);
if (result == nullptr) {
ABSL_LOG(FATAL) << "Failed to allocate result for cpu session.";
*error_msg = strdup("Failed to allocate result for cpu session.");
return static_cast<int>(absl::StatusCode::kResourceExhausted);
}

result[0] = (char*)malloc(sizeof(char*) * (final_output.size() + 1));
if (result[0] == nullptr) {
ABSL_LOG(FATAL) << "Failed to allocate result for cpu session.";
*error_msg = strdup("Failed to allocate result for cpu session.");
return static_cast<int>(absl::StatusCode::kResourceExhausted);
}

snprintf(result[0], final_output.size() + 1, "%s", final_output.c_str());

LlmResponseContext response_context = {
.response_array = result,
.response_count = 1,
.done = true,
};
response_context->response_array = result;
response_context->response_count = 1;
response_context->done = true;

return response_context;
return 0;
}

void LlmInferenceEngine_Session_PredictAsync(
int LlmInferenceEngine_Session_PredictAsync(
LlmInferenceEngine_Session* session, void* callback_context,
char** error_msg,
void (*callback)(void* callback_context,
LlmResponseContext* response_context)) {
if (session == nullptr) {
*error_msg = strdup("Session is null.");
return static_cast<int>(absl::StatusCode::kInvalidArgument);
}
if (callback == nullptr) {
*error_msg = strdup("Callback is null.");
return static_cast<int>(absl::StatusCode::kInvalidArgument);
}

auto cpu_session = reinterpret_cast<LlmInferenceEngineCpu_Session*>(session);

if (cpu_session == nullptr) {
*error_msg = strdup("Provided session is not a CPU session.");
return static_cast<int>(absl::StatusCode::kInvalidArgument);
}

cpu_session->cpu_callback = [=](std::string responses) -> void {
char** result = (char**)malloc(sizeof(char*) * 1);
if (result == nullptr) {
@@ -656,6 +675,8 @@ void LlmInferenceEngine_Session_PredictAsync(
cpu_session->work_id = work_id;
pthread_create(&cpu_session->work_id, nullptr, start_llm_function,
cpu_session);

return 0;
}

int LlmInferenceEngine_Session_Clone(
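
The CPU implementation above reports failures as the integer value of an absl::StatusCode (0 for success) plus a strdup'd message written to *error_msg. A small helper following that convention could look like the sketch below; StatusToErrorCode is an illustrative name, not something this commit adds:

#include <cstring>
#include <string>

#include "absl/status/status.h"

// Illustrative helper: converts an absl::Status into the C API's
// (return code, *error_msg) convention used by PredictSync/PredictAsync.
static int StatusToErrorCode(const absl::Status& status, char** error_msg) {
  if (status.ok()) return 0;
  if (error_msg != nullptr) {
    *error_msg = strdup(std::string(status.message()).c_str());
  }
  return static_cast<int>(status.code());
}
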
@@ -162,9 +162,15 @@ int main(int argc, char** argv) {
// ABSL_LOG(INFO) << "Number of tokens for input prompt: " << num_tokens;

ABSL_LOG(INFO) << "PredictAsync";
LlmInferenceEngine_Session_PredictAsync(llm_engine_session,
/*callback_context=*/nullptr,
async_callback_print);
error_code = LlmInferenceEngine_Session_PredictAsync(
llm_engine_session,
/*callback_context=*/nullptr, &error_msg, async_callback_print);
if (error_code) {
ABSL_LOG(ERROR) << "Failed to predict asyncously: "
<< std::string(error_msg);
free(error_msg);
return EXIT_FAILURE;
}

// Optional to use the following for the sync version.
// ABSL_LOG(INFO) << "PredictSync";
Expand Down
14 changes: 13 additions & 1 deletion mediapipe/tasks/ios/genai/core/sources/GenAiInferenceError.swift
@@ -22,6 +22,8 @@ public enum GenAiInferenceError: Error {
case failedToInitializeSession(String?)
case failedToInitializeEngine(String?)
case failedToAddQueryToSession(String, String?)
case failedToPredictSync(String?)
case failedToPredictAsync(String?)
case failedToCloneSession(String?)
}

@@ -50,6 +52,12 @@ extension GenAiInferenceError: LocalizedError {
case .failedToAddQueryToSession(let query, let message):
let explanation = message.flatMap { $0 } ?? "An internal error occurred."
return "Failed to add query: \(query) to LlmInference session: \(explanation)"
case .failedToPredictSync(let message):
let explanation = message.flatMap { $0 } ?? "An internal error occurred."
return "Failed to predict sync: \(explanation)"
case .failedToPredictAsync(let message):
let explanation = message.flatMap { $0 } ?? "An internal error occurred."
return "Failed to predict async: \(explanation)"
case .failedToCloneSession(let message):
let explanation = message.flatMap { $0 } ?? "An internal error occurred."
return "Failed to clone LlmInference session: \(explanation)"
@@ -77,8 +85,12 @@ extension GenAiInferenceError: CustomNSError {
return 4
case .failedToAddQueryToSession:
return 5
case .failedToCloneSession:
case .failedToPredictSync:
return 6
case .failedToPredictAsync:
return 7
case .failedToCloneSession:
return 8
}
}
}
18 changes: 15 additions & 3 deletions mediapipe/tasks/ios/genai/core/sources/LlmSessionRunner.swift
@@ -57,10 +57,17 @@ final class LlmSessionRunner {
/// - Returns: Array of `String` responses from the LLM.
/// - Throws: An error if the LLM's response is invalid.
func predict() throws -> [String] {
var cErrorMessage: UnsafeMutablePointer<CChar>? = nil
/// No safe guards for the call since the C++ APIs only throw fatal errors.
/// `LlmInferenceEngine_Session_PredictSync()` will always return a `LlmResponseContext` if the
/// call completes.
var responseContext = LlmInferenceEngine_Session_PredictSync(cLlmSession)
var responseContext = LlmResponseContext()
guard
LlmInferenceEngine_Session_PredictSync(cLlmSession, &responseContext, &cErrorMessage)
== StatusCode.success.rawValue
else {
throw GenAiInferenceError.failedToPredictSync(String(allocatedCErrorMessage: cErrorMessage))
}

defer {
withUnsafeMutablePointer(to: &responseContext) {
@@ -88,11 +95,12 @@ final class LlmSessionRunner {
func predictAsync(
progress: @escaping (_ partialResult: [String]?, _ error: Error?) -> Void,
completion: @escaping (() -> Void)
) {
) throws {
var cErrorMessage: UnsafeMutablePointer<CChar>? = nil
let callbackInfo = CallbackInfo(progress: progress, completion: completion)
let callbackContext = UnsafeMutableRawPointer(Unmanaged.passRetained(callbackInfo).toOpaque())

LlmInferenceEngine_Session_PredictAsync(cLlmSession, callbackContext) {
let errorCode = LlmInferenceEngine_Session_PredictAsync(cLlmSession, callbackContext, &cErrorMessage) {
context, responseContext in
guard let cContext = context else {
return
@@ -135,6 +143,10 @@ final class LlmSessionRunner {
cCallbackInfo.completion()
}
}

guard errorCode == StatusCode.success.rawValue else {
throw GenAiInferenceError.failedToPredictAsync(String(allocatedCErrorMessage: cErrorMessage))
}
}

/// Invokes the C LLM session to tokenize an input prompt using a pre-existing processor and
Expand Up @@ -172,7 +172,7 @@ extension LlmInference {
/// Used to make a decision about whitespace stripping.
var receivedFirstToken = true

llmSessionRunner.predictAsync(
try llmSessionRunner.predictAsync(
progress: { partialResponseStrings, error in
guard let responseStrings = partialResponseStrings,
let humanReadableLlmResponse = Session.humanReadableString(
21 changes: 17 additions & 4 deletions mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/llm.cc
@@ -249,8 +249,15 @@ JNIEXPORT void JNICALL JNI_METHOD(nativeAddImage)(JNIEnv* env, jclass thiz,

JNIEXPORT jbyteArray JNICALL
JNI_METHOD(nativePredictSync)(JNIEnv* env, jclass thiz, jlong session_handle) {
LlmResponseContext response_context = LlmInferenceEngine_Session_PredictSync(
reinterpret_cast<void*>(session_handle));
char* error_msg = nullptr;
LlmResponseContext response_context;
int error_code = LlmInferenceEngine_Session_PredictSync(
reinterpret_cast<void*>(session_handle), &response_context, &error_msg);
if (error_code) {
ThrowIfError(env, absl::InternalError(absl::StrCat(
"Failed to predict sync: ", error_msg)));
free(error_msg);
}
const jbyteArray response_bytes = ToByteArray(env, response_context);
LlmInferenceEngine_CloseResponseContext(&response_context);
return response_bytes;
@@ -276,9 +283,15 @@ JNIEXPORT void JNICALL JNI_METHOD(nativeRemoveCallback)(JNIEnv* env,
JNIEXPORT void JNICALL JNI_METHOD(nativePredictAsync)(JNIEnv* env, jclass thiz,
jlong session_handle,
jobject callback_ref) {
LlmInferenceEngine_Session_PredictAsync(
char* error_msg = nullptr;
int error_code = LlmInferenceEngine_Session_PredictAsync(
reinterpret_cast<LlmInferenceEngine_Session*>(session_handle),
reinterpret_cast<void*>(callback_ref), &ProcessAsyncResponse);
reinterpret_cast<void*>(callback_ref), &error_msg, &ProcessAsyncResponse);
if (error_code) {
ThrowIfError(env, absl::InternalError(absl::StrCat(
"Failed to predict async: ", error_msg)));
free(error_msg);
}
}

JNIEXPORT jint JNICALL JNI_METHOD(nativeSizeInTokens)(JNIEnv* env, jclass thiz,
