Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 596028214
  • Loading branch information
schmidt-sebastian authored and copybara-github committed Jan 5, 2024
1 parent 9a207e3 commit eb937cf
Show file tree
Hide file tree
Showing 8 changed files with 601 additions and 2 deletions.
34 changes: 32 additions & 2 deletions mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@ android_library(

android_library(
name = "core_java",
srcs = glob(["*.java"]),
srcs = glob(
["*.java"],
exclude = ["LlmTaskRunner.java"],
),
javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF",
],
Expand All @@ -41,7 +44,6 @@ android_library(
"//mediapipe/calculators/tensor:inference_calculator_java_proto_lite",
"//mediapipe/framework:calculator_java_proto_lite",
"//mediapipe/framework:calculator_options_java_proto_lite",
"//mediapipe/gpu:gpu_origin_java_proto_lite",
"//mediapipe/java/com/google/mediapipe/framework:android_framework_no_mff",
"//mediapipe/tasks/cc/core/proto:acceleration_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
Expand All @@ -53,6 +55,34 @@ android_library(
],
)

android_library(
name = "llm",
javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF",
],
manifest = "AndroidManifest.xml",
exports = [
":llm_java",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:llm",
],
deps = ["@maven//:com_google_guava_guava"],
)

android_library(
name = "llm_java",
srcs = ["LlmTaskRunner.java"],
javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF",
],
manifest = "AndroidManifest.xml",
deps = [
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/proto:llm_options_java_proto_lite",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/proto:llm_response_context_java_proto_lite",
"//third_party/java/protobuf:protobuf_lite",
"@maven//:com_google_guava_guava",
],
)

android_library(
name = "logging",
srcs = glob(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.mediapipe.tasks.core;

import com.google.mediapipe.tasks.core.jni.LlmOptionsProto.LlmModelParameters;
import com.google.mediapipe.tasks.core.jni.LlmOptionsProto.LlmSessionConfig;
import com.google.mediapipe.tasks.core.jni.LlmResponseContextProto.LlmResponseContext;
import com.google.protobuf.ExtensionRegistryLite;
import com.google.protobuf.InvalidProtocolBufferException;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;

/**
* Internal Task Runner class for all LLM Tasks.
*
* @hide
*/
public final class LlmTaskRunner implements AutoCloseable {
private final long sessionHandle;
private final Optional<Function<List<String>, Void>> resultListener;
private final long callbackHandle;

public LlmTaskRunner(
LlmModelParameters modelParameters,
LlmSessionConfig sessionConfig,
Optional<Function<List<String>, Void>> resultListener) {
this.sessionHandle =
nativeCreateSession(modelParameters.toByteArray(), sessionConfig.toByteArray());

this.resultListener = resultListener;
if (resultListener.isPresent()) {
this.callbackHandle = nativeRegisterCallback(this);
} else {
this.callbackHandle = 0;
}
}

/** Invokes the LLM with the provided input and waits for the result. */
public List<String> predictSync(String input) {
byte[] responseBytes = nativePredictSync(sessionHandle, input);
return parseResponse(responseBytes);
}

/** Invokes the LLM with the provided input and calls the callback with the result. */
public void predictAsync(String input) {
if (callbackHandle == 0) {
throw new IllegalStateException("No result listener provided.");
}
nativePredictAsync(sessionHandle, callbackHandle, input);
}

private List<String> parseResponse(byte[] reponse) {
try {
LlmResponseContext result =
LlmResponseContext.parseFrom(reponse, ExtensionRegistryLite.getGeneratedRegistry());
return result.getResponsesList();
} catch (InvalidProtocolBufferException e) {
throw new IllegalStateException("Failed to parse response", e);
}
}

private void onAsyncResponse(byte[] responseBytes) {
resultListener.get().apply(parseResponse(responseBytes));
}

@Override
public void close() {
if (callbackHandle != 0) {
nativeRemoveCallback(callbackHandle);
}
nativeDeleteSession(sessionHandle);
}

private static native long nativeCreateSession(byte[] modelParameters, byte[] sessionConfig);

private static native void nativeDeleteSession(long sessionPointer);

private static native byte[] nativePredictSync(long sessionPointer, String input);

private static native long nativeRegisterCallback(Object callback);

private static native void nativeRemoveCallback(long callbackHandle);

private static native void nativePredictAsync(
long sessionPointer, long callbackContextHandle, String input);
}
23 changes: 23 additions & 0 deletions mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,26 @@ cc_library_with_tflite(
}),
alwayslink = 1,
)

cc_library(
name = "llm",
srcs = ["llm.cc"],
hdrs = ["llm.h"],
deps = [
"//mediapipe/java/com/google/mediapipe/framework/jni:jni_util",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/proto:llm_options_cc_proto",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/proto:llm_response_context_cc_proto",
"//third_party/odml/infra/genai/inference/c:libllm_inference_engine",
"//third_party/odml/infra/genai/inference/c:libllm_inference_engine_deps",
"@com_google_absl//absl/status",
] + select({
"//mediapipe:android": [],
}),
alwayslink = 1,
)

cc_binary(
name = "llm_jni",
linkshared = 1,
deps = [":llm"],
)
Loading

0 comments on commit eb937cf

Please sign in to comment.