Skip to content

Commit

Permalink
Merge pull request #4943 from kinaryml:c-image-embedder-api
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 580618718
  • Loading branch information
copybara-github committed Nov 8, 2023
2 parents 000314a + c442d61 commit d4d3076
Show file tree
Hide file tree
Showing 14 changed files with 1,048 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,29 @@ void CppConvertToEmbeddingResult(
}
}

// Copies the fields of the C `Embedding` struct `in` into the C++ container
// `out`.
//
// Both value buffers are read as `in.values_count` elements. A buffer whose
// pointer is null is skipped and a null `in.head_name` is ignored, so the
// corresponding fields of a pre-populated `out` keep their previous contents.
void CppConvertToCppEmbedding(
    const Embedding& in,  // C struct as input
    mediapipe::tasks::components::containers::Embedding* out) {
  const auto value_count = in.values_count;

  // Float values, if the C struct carries any.
  const auto* float_begin = in.float_embedding;
  if (float_begin != nullptr) {
    out->float_embedding.assign(float_begin, float_begin + value_count);
  }

  // Quantized values, if the C struct carries any.
  const auto* quantized_begin = in.quantized_embedding;
  if (quantized_begin != nullptr) {
    out->quantized_embedding.assign(quantized_begin,
                                    quantized_begin + value_count);
  }

  out->head_index = in.head_index;

  // The head name is optional on the C side; only copy it when it is set.
  if (in.head_name != nullptr) {
    out->head_name = std::string(in.head_name);
  }
}

void CppCloseEmbeddingResult(EmbeddingResult* in) {
for (uint32_t i = 0; i < in->embeddings_count; ++i) {
auto embedding_in = in->embeddings[i];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ void CppConvertToEmbeddingResult(
const mediapipe::tasks::components::containers::EmbeddingResult& in,
EmbeddingResult* out);

// Converts the C `Embedding` struct `in` into the C++ container `out`: copies
// the float and/or quantized values (for whichever pointer is non-null), the
// head index and, when present, the head name.
void CppConvertToCppEmbedding(
const Embedding& in,
mediapipe::tasks::components::containers::Embedding* out);

// NOTE(review): presumably releases the buffers owned by a single C
// `Embedding` — confirm against the definition in the .cc file.
void CppCloseEmbedding(Embedding* in);

// Walks the `embeddings_count` entries of `in->embeddings`.
// NOTE(review): presumably releases each entry and then the array itself —
// confirm against the definition in the .cc file.
void CppCloseEmbeddingResult(EmbeddingResult* in);
Expand Down
1 change: 1 addition & 0 deletions mediapipe/tasks/c/text/text_embedder/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ cc_library(
"//mediapipe/tasks/c/components/processors:embedder_options_converter",
"//mediapipe/tasks/c/core:base_options",
"//mediapipe/tasks/c/core:base_options_converter",
"//mediapipe/tasks/cc/components/containers:embedding_result",
"//mediapipe/tasks/cc/text/text_embedder",
"@com_google_absl//absl/log:absl_log",
"@com_google_absl//absl/status",
Expand Down
28 changes: 28 additions & 0 deletions mediapipe/tasks/c/text/text_embedder/text_embedder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,26 @@ limitations under the License.

#include "absl/log/absl_log.h"
#include "absl/status/status.h"
#include "mediapipe/tasks/c/components/containers/embedding_result.h"
#include "mediapipe/tasks/c/components/containers/embedding_result_converter.h"
#include "mediapipe/tasks/c/components/processors/embedder_options_converter.h"
#include "mediapipe/tasks/c/core/base_options_converter.h"
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
#include "mediapipe/tasks/cc/text/text_embedder/text_embedder.h"

namespace mediapipe::tasks::c::text::text_embedder {

namespace {

using ::mediapipe::tasks::c::components::containers::CppCloseEmbeddingResult;
using ::mediapipe::tasks::c::components::containers::CppConvertToCppEmbedding;
using ::mediapipe::tasks::c::components::containers::
CppConvertToEmbeddingResult;
using ::mediapipe::tasks::c::components::processors::
CppConvertToEmbedderOptions;
using ::mediapipe::tasks::c::core::CppConvertToBaseOptions;
using ::mediapipe::tasks::text::text_embedder::TextEmbedder;
// Alias for the C++ embedding container, so it is not confused with the C
// `Embedding` struct used by the exported functions.
using CppEmbedding = ::mediapipe::tasks::components::containers::Embedding;

int CppProcessError(absl::Status status, char** error_msg) {
if (error_msg) {
Expand Down Expand Up @@ -91,6 +95,24 @@ int CppTextEmbedderClose(void* embedder, char** error_msg) {
return 0;
}

int CppTextEmbedderCosineSimilarity(const Embedding& u, const Embedding& v,
double* similarity, char** error_msg) {
CppEmbedding cpp_u;
CppConvertToCppEmbedding(u, &cpp_u);
CppEmbedding cpp_v;
CppConvertToCppEmbedding(v, &cpp_v);
auto status_or_similarity =
mediapipe::tasks::text::text_embedder::TextEmbedder::CosineSimilarity(
cpp_u, cpp_v);
if (status_or_similarity.ok()) {
*similarity = status_or_similarity.value();
} else {
ABSL_LOG(ERROR) << "Cannot compute cosine similarity.";
return CppProcessError(status_or_similarity.status(), error_msg);
}
return 0;
}

} // namespace mediapipe::tasks::c::text::text_embedder

extern "C" {
Expand All @@ -116,4 +138,10 @@ int text_embedder_close(void* embedder, char** error_ms) {
embedder, error_ms);
}

int text_embedder_cosine_similarity(const Embedding& u, const Embedding& v,
double* similarity, char** error_msg) {
return mediapipe::tasks::c::text::text_embedder::
CppTextEmbedderCosineSimilarity(u, v, similarity, error_msg);
}

} // extern "C"
11 changes: 11 additions & 0 deletions mediapipe/tasks/c/text/text_embedder/text_embedder.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,17 @@ MP_EXPORT void text_embedder_close_result(TextEmbedderResult* result);
// allocated for the error message.
MP_EXPORT int text_embedder_close(void* embedder, char** error_msg);

// Utility function to compute cosine similarity [1] between two embeddings.
// On success, returns 0 and stores the result in `*similarity`.
// May return an InvalidArgumentError if e.g. the embeddings are of different
// types (quantized vs. float), have different sizes, or have an L2-norm of
// 0. In that case an error code is returned and the error is reported through
// `error_msg`, as for the other functions in this header.
// NOTE(review): `u` and `v` are C++ references, which a C compiler cannot
// parse even though this declaration sits inside the extern "C" block —
// confirm whether C callers are expected.
//
// [1]: https://en.wikipedia.org/wiki/Cosine_similarity
MP_EXPORT int text_embedder_cosine_similarity(const Embedding& u,
const Embedding& v,
double* similarity,
char** error_msg);

#ifdef __cplusplus
} // extern C
#endif
Expand Down
43 changes: 41 additions & 2 deletions mediapipe/tasks/c/text/text_embedder/text_embedder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,12 @@ using testing::HasSubstr;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/";
constexpr char kTestBertModelPath[] =
"mobilebert_embedding_with_metadata.tflite";
constexpr char kTestString[] = "It's beautiful outside.";
constexpr char kTestString0[] =
"When you go to this restaurant, they hold the pancake upside-down "
"before they hand it to you. It's a great gimmick.";
constexpr char kTestString1[] =
"Let's make a plan to steal the declaration of independence.";
constexpr float kPrecision = 1e-3;

std::string GetFullPath(absl::string_view file_name) {
return JoinPath("./", kTestDataDirectory, file_name);
Expand All @@ -52,14 +57,48 @@ TEST(TextEmbedderTest, SmokeTest) {
EXPECT_NE(embedder, nullptr);

TextEmbedderResult result;
text_embedder_embed(embedder, kTestString, &result, /* error_msg */ nullptr);
text_embedder_embed(embedder, kTestString0, &result, /* error_msg */ nullptr);
EXPECT_EQ(result.embeddings_count, 1);
EXPECT_EQ(result.embeddings[0].values_count, 512);

text_embedder_close_result(&result);
text_embedder_close(embedder, /* error_msg */ nullptr);
}

// Embeds two different sentences with the MobileBERT test model (no L2
// normalization, no quantization) and checks that the cosine similarity of the
// two embeddings matches the expected reference value within `kPrecision`.
TEST(TextEmbedderTest, SucceedsWithCosineSimilarity) {
  const std::string model_path = GetFullPath(kTestBertModelPath);
  TextEmbedderOptions options = {
      /* base_options= */ {/* model_asset_buffer= */ nullptr,
                           /* model_asset_buffer_count= */ 0,
                           /* model_asset_path= */ model_path.c_str()},
      /* embedder_options= */
      {/* l2_normalize= */ false,
       /* quantize= */ false}};

  void* embedder = text_embedder_create(&options,
                                        /* error_msg */ nullptr);
  // Fatal assertion: every call below needs a valid embedder handle.
  ASSERT_NE(embedder, nullptr);

  // Extract both embeddings.
  TextEmbedderResult result0;
  text_embedder_embed(embedder, kTestString0, &result0,
                      /* error_msg */ nullptr);
  TextEmbedderResult result1;
  text_embedder_embed(embedder, kTestString1, &result1,
                      /* error_msg */ nullptr);

  // Check cosine similarity.
  double similarity = 0.0;
  const int status = text_embedder_cosine_similarity(
      result0.embeddings[0], result1.embeddings[0], &similarity,
      /* error_msg */ nullptr);
  EXPECT_EQ(status, 0);
  const double expected_similarity = 0.98077;
  // EXPECT_NEAR compares in floating point. An unqualified abs() may resolve
  // to the integer overload, truncating the difference to 0 and making the
  // check pass for any difference below 1.
  EXPECT_NEAR(similarity, expected_similarity, kPrecision);

  text_embedder_close_result(&result0);
  text_embedder_close_result(&result1);
  text_embedder_close(embedder, /* error_msg */ nullptr);
}

TEST(TextEmbedderTest, ErrorHandling) {
// It is an error to set neither the asset buffer nor the path.
TextEmbedderOptions options = {
Expand Down
22 changes: 22 additions & 0 deletions mediapipe/tasks/c/vision/core/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

package(default_visibility = ["//mediapipe/tasks:internal"])

licenses(["notice"])

# Header-only target (no sources) that exposes the C image and running-mode
# declarations from common.h.
cc_library(
name = "common",
hdrs = ["common.h"],
)
68 changes: 68 additions & 0 deletions mediapipe/tasks/c/vision/core/common.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef MEDIAPIPE_TASKS_C_VISION_CORE_COMMON_H_
#define MEDIAPIPE_TASKS_C_VISION_CORE_COMMON_H_

#include <cstdint>

#ifdef __cplusplus
extern "C" {
#endif

// Supported image formats.
// The enumerators are unscoped and carry explicit values that are not
// contiguous (SBGRA is 11).
// NOTE(review): presumably the values mirror the image format enum used by
// the C++ framework — confirm before adding or renumbering entries.
enum ImageFormat {
UNKNOWN = 0,
SRGB = 1,
SRGBA = 2,
GRAY8 = 3,
SBGRA = 11 // compatible with Flutter `bgra8888` format.
};

// Supported processing modes.
// Values start at 1; no enumerator has the value 0.
// NOTE(review): presumably IMAGE is for independent still images, VIDEO for
// timestamped video frames and LIVE_STREAM for asynchronous live input —
// confirm against the task headers that consume this enum.
enum RunningMode {
IMAGE = 1,
VIDEO = 2,
LIVE_STREAM = 3,
};

// Structure to hold image frame.
// The pixel data is referenced through a read-only pointer; the struct has no
// length or stride field.
// NOTE(review): presumably the buffer size is implied by `format`, `width`
// and `height` (in pixels) and the caller keeps the buffer alive — confirm.
// NOTE(review): `uint8_t` comes from the <cstdint> include in this header,
// which is a C++-only header; a plain C consumer would need <stdint.h>.
struct ImageFrame {
enum ImageFormat format;
const uint8_t* image_buffer;
int width;
int height;
};

// TODO: Add GPU buffer declaration and processing logic for it.
// Placeholder: currently carries only the dimensions, with no handle to any
// GPU resource.
struct GpuBuffer {
int width;
int height;
};

// The object to contain an image, realizes `OneOf` concept.
// `type` is the tag for the anonymous union below.
// NOTE(review): presumably IMAGE_FRAME means `image_frame` is the active
// member and GPU_BUFFER means `gpu_buffer` is — confirm against the code that
// reads this struct.
struct MpImage {
enum { IMAGE_FRAME, GPU_BUFFER } type;
union {
struct ImageFrame image_frame;
struct GpuBuffer gpu_buffer;
};
};

#ifdef __cplusplus
} // extern C
#endif

#endif // MEDIAPIPE_TASKS_C_VISION_CORE_COMMON_H_
1 change: 1 addition & 0 deletions mediapipe/tasks/c/vision/image_classifier/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ cc_library(
"//mediapipe/tasks/c/components/processors:classifier_options_converter",
"//mediapipe/tasks/c/core:base_options",
"//mediapipe/tasks/c/core:base_options_converter",
"//mediapipe/tasks/c/vision/core:common",
"//mediapipe/tasks/cc/vision/core:running_mode",
"//mediapipe/tasks/cc/vision/image_classifier",
"//mediapipe/tasks/cc/vision/utils:image_utils",
Expand Down
71 changes: 29 additions & 42 deletions mediapipe/tasks/c/vision/image_classifier/image_classifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,10 @@ limitations under the License.
#ifndef MEDIAPIPE_TASKS_C_VISION_IMAGE_CLASSIFIER_IMAGE_CLASSIFIER_H_
#define MEDIAPIPE_TASKS_C_VISION_IMAGE_CLASSIFIER_IMAGE_CLASSIFIER_H_

#include <cstdint>

#include "mediapipe/tasks/c/components/containers/classification_result.h"
#include "mediapipe/tasks/c/components/processors/classifier_options.h"
#include "mediapipe/tasks/c/core/base_options.h"
#include "mediapipe/tasks/c/vision/core/common.h"

#ifndef MP_EXPORT
#define MP_EXPORT __attribute__((visibility("default")))
Expand All @@ -32,46 +31,7 @@ extern "C" {

typedef ClassificationResult ImageClassifierResult;

// Supported image formats.
enum ImageFormat {
UNKNOWN = 0,
SRGB = 1,
SRGBA = 2,
GRAY8 = 3,
SBGRA = 11 // compatible with Flutter `bgra8888` format.
};

// Supported processing modes.
enum RunningMode {
IMAGE = 1,
VIDEO = 2,
LIVE_STREAM = 3,
};

// Structure to hold image frame.
struct ImageFrame {
enum ImageFormat format;
const uint8_t* image_buffer;
int width;
int height;
};

// TODO: Add GPU buffer declaration and processing logic for it.
struct GpuBuffer {
int width;
int height;
};

// The object to contain an image, realizes `OneOf` concept.
struct MpImage {
enum { IMAGE_FRAME, GPU_BUFFER } type;
union {
struct ImageFrame image_frame;
struct GpuBuffer gpu_buffer;
};
};

// The options for configuring a Mediapipe image classifier task.
// The options for configuring a MediaPipe image classifier task.
struct ImageClassifierOptions {
// Base options for configuring MediaPipe Tasks, such as specifying the model
// file with metadata, accelerator options, op resolver, etc.
Expand Down Expand Up @@ -122,12 +82,39 @@ MP_EXPORT int image_classifier_classify_image(void* classifier,
ImageClassifierResult* result,
char** error_msg);

// Performs image classification on the provided video frame.
// Only use this method when the ImageClassifier is created with the video
// running mode.
// The image can be of any size with format RGB or RGBA. It's required to
// provide the video frame's timestamp (in milliseconds). The input timestamps
// must be monotonically increasing.
// If an error occurs, returns an error code and sets the error parameter to
// an error message (if `error_msg` is not `nullptr`). You must free the memory
// allocated for the error message.
MP_EXPORT int image_classifier_classify_for_video(void* classifier,
const MpImage* image,
int64_t timestamp_ms,
ImageClassifierResult* result,
char** error_msg);

// Sends live image data to image classification, and the results will be
// available via the `result_callback` provided in the ImageClassifierOptions.
// Only use this method when the ImageClassifier is created with the live
// stream running mode.
// The image can be of any size with format RGB or RGBA. It's required to
// provide a timestamp (in milliseconds) to indicate when the input image is
// sent to the image classifier. The input timestamps must be monotonically
// increasing.
// The `result_callback` provides:
// - The classification results as an ImageClassifierResult object.
// - The const reference to the corresponding input image that the image
// classifier runs on. Note that the const reference to the image will no
// longer be valid when the callback returns. To access the image data
// outside of the callback, callers need to make a copy of the image.
// - The input timestamp in milliseconds.
// If an error occurs, returns an error code and sets the error parameter to
// an error message (if `error_msg` is not `nullptr`). You must free the memory
// allocated for the error message.
MP_EXPORT int image_classifier_classify_async(void* classifier,
const MpImage* image,
int64_t timestamp_ms,
Expand Down
Loading

0 comments on commit d4d3076

Please sign in to comment.