diff --git a/mediapipe/tasks/c/components/containers/BUILD b/mediapipe/tasks/c/components/containers/BUILD
index 0d89c820ec..4bb5808731 100644
--- a/mediapipe/tasks/c/components/containers/BUILD
+++ b/mediapipe/tasks/c/components/containers/BUILD
@@ -98,3 +98,26 @@ cc_test(
         "@com_google_googletest//:gtest_main",
     ],
 )
+
+# Converter between the C++ LanguageDetectorResult and the C API
+# LanguageDetectorResult struct, plus the matching cleanup routine.
+cc_library(
+    name = "language_detection_result_converter",
+    srcs = ["language_detection_result_converter.cc"],
+    hdrs = ["language_detection_result_converter.h"],
+    deps = [
+        "//mediapipe/tasks/c/text/language_detector",
+        "//mediapipe/tasks/cc/text/language_detector",
+    ],
+)
+
+# Unit tests for the conversion and cleanup routines of the converter above.
+cc_test(
+    name = "language_detection_result_converter_test",
+    srcs = ["language_detection_result_converter_test.cc"],
+    linkstatic = 1,
+    deps = [
+        ":language_detection_result_converter",
+        "//mediapipe/framework/port:gtest",
+        "//mediapipe/tasks/c/text/language_detector",
+        "//mediapipe/tasks/cc/text/language_detector",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
diff --git a/mediapipe/tasks/c/components/containers/language_detection_result_converter.cc b/mediapipe/tasks/c/components/containers/language_detection_result_converter.cc
new file mode 100644
index 0000000000..89b112e454
--- /dev/null
+++ b/mediapipe/tasks/c/components/containers/language_detection_result_converter.cc
@@ -0,0 +1,56 @@
+/* Copyright 2023 The MediaPipe Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mediapipe/tasks/c/components/containers/language_detection_result_converter.h"
+
+#include <cstdint>
+#include <cstdlib>
+#include <cstring>
+
+#include "mediapipe/tasks/c/text/language_detector/language_detector.h"
+#include "mediapipe/tasks/cc/text/language_detector/language_detector.h"
+
+namespace mediapipe::tasks::c::components::containers {
+
+// Deep-copies the C++ detection result `in` into the C struct `out`.
+//
+// `out->predictions` is allocated with `new[]` (or set to nullptr when `in` is
+// empty) and each `language_code` is duplicated with `strdup`. Release the
+// memory with CppCloseLanguageDetectionResult(). `out` must not be null.
+void CppConvertToLanguageDetectionResult(
+    const mediapipe::tasks::text::language_detector::LanguageDetectorResult& in,
+    LanguageDetectorResult* out) {
+  out->predictions_count = in.size();
+  out->predictions =
+      out->predictions_count
+          ? new LanguageDetectorPrediction[out->predictions_count]
+          : nullptr;
+
+  for (uint32_t i = 0; i < out->predictions_count; ++i) {
+    // Bind by reference: a by-value copy would also copy the language string.
+    const auto& language_detection_prediction_in = in[i];
+    auto& language_detection_prediction_out = out->predictions[i];
+    language_detection_prediction_out.probability =
+        language_detection_prediction_in.probability;
+    language_detection_prediction_out.language_code =
+        strdup(language_detection_prediction_in.language_code.c_str());
+  }
+}
+
+// Frees everything CppConvertToLanguageDetectionResult() allocated inside
+// `in` and leaves `in` empty (null array, zero count). The struct itself is
+// not freed. `in` must not be null.
+void CppCloseLanguageDetectionResult(LanguageDetectorResult* in) {
+  for (uint32_t i = 0; i < in->predictions_count; ++i) {
+    // Reference, not a copy, so clearing the pointer updates the array entry.
+    auto& prediction_in = in->predictions[i];
+
+    free(prediction_in.language_code);
+    prediction_in.language_code = nullptr;
+  }
+  delete[] in->predictions;
+  in->predictions = nullptr;
+  // Keep the count in sync with the (now null) array so that closing the same
+  // result twice does not index through a null pointer.
+  in->predictions_count = 0;
+}
+
+}  // namespace mediapipe::tasks::c::components::containers
diff --git a/mediapipe/tasks/c/components/containers/language_detection_result_converter.h b/mediapipe/tasks/c/components/containers/language_detection_result_converter.h
new file mode 100644
index 0000000000..c9cfd55bdb
--- /dev/null
+++ b/mediapipe/tasks/c/components/containers/language_detection_result_converter.h
@@ -0,0 +1,32 @@
+/* Copyright 2023 The MediaPipe Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_LANGUAGE_DETECTION_RESULT_CONVERTER_H_
+#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_LANGUAGE_DETECTION_RESULT_CONVERTER_H_
+
+#include "mediapipe/tasks/c/text/language_detector/language_detector.h"
+#include "mediapipe/tasks/cc/text/language_detector/language_detector.h"
+
+namespace mediapipe::tasks::c::components::containers {
+
+// Copies the C++ detection result `in` into the C struct `out`.
+// `out->predictions` and every `language_code` string are heap-allocated by
+// this call; release them with CppCloseLanguageDetectionResult().
+void CppConvertToLanguageDetectionResult(
+    const mediapipe::tasks::text::language_detector::LanguageDetectorResult& in,
+    LanguageDetectorResult* out);
+
+// Frees the `language_code` strings and the `predictions` array held by `in`
+// and sets `in->predictions` to nullptr. Does not free `in` itself.
+void CppCloseLanguageDetectionResult(LanguageDetectorResult* in);
+
+}  // namespace mediapipe::tasks::c::components::containers
+
+#endif  // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_LANGUAGE_DETECTION_RESULT_CONVERTER_H_
diff --git a/mediapipe/tasks/c/components/containers/language_detection_result_converter_test.cc b/mediapipe/tasks/c/components/containers/language_detection_result_converter_test.cc
new file mode 100644
index 0000000000..633b77eae1
--- /dev/null
+++ b/mediapipe/tasks/c/components/containers/language_detection_result_converter_test.cc
@@ -0,0 +1,54 @@
+/* Copyright 2023 The MediaPipe Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mediapipe/tasks/c/components/containers/language_detection_result_converter.h"
+
+#include "mediapipe/framework/port/gtest.h"
+#include "mediapipe/tasks/c/text/language_detector/language_detector.h"
+#include "mediapipe/tasks/cc/text/language_detector/language_detector.h"
+
+namespace mediapipe::tasks::c::components::containers {
+
+// Verifies that every field of every prediction is copied into the C struct.
+TEST(LanguageDetectionResultConverterTest,
+     ConvertsLanguageDetectionResultCustomResult) {
+  mediapipe::tasks::text::language_detector::LanguageDetectorResult
+      cpp_detector_result = {{/* language_code= */ "fr",
+                              /* probability= */ 0.5},
+                             {/* language_code= */ "en",
+                              /* probability= */ 0.5}};
+
+  LanguageDetectorResult c_detector_result = {};
+  CppConvertToLanguageDetectionResult(cpp_detector_result, &c_detector_result);
+  // ASSERT (not EXPECT): the checks below index into the array.
+  ASSERT_NE(c_detector_result.predictions, nullptr);
+  ASSERT_EQ(c_detector_result.predictions_count, 2);
+  // Compare string contents; EXPECT_EQ/EXPECT_NE on `char*` would compare
+  // pointer addresses and say nothing about the converted text.
+  EXPECT_STREQ(c_detector_result.predictions[0].language_code, "fr");
+  EXPECT_FLOAT_EQ(c_detector_result.predictions[0].probability, 0.5);
+  EXPECT_STREQ(c_detector_result.predictions[1].language_code, "en");
+  EXPECT_FLOAT_EQ(c_detector_result.predictions[1].probability, 0.5);
+
+  CppCloseLanguageDetectionResult(&c_detector_result);
+}
+
+// Verifies that closing a converted result releases and clears the array.
+TEST(LanguageDetectionResultConverterTest, FreesMemory) {
+  mediapipe::tasks::text::language_detector::LanguageDetectorResult
+      cpp_detector_result = {{"fr", 0.5}};
+
+  LanguageDetectorResult c_detector_result = {};
+  CppConvertToLanguageDetectionResult(cpp_detector_result, &c_detector_result);
+  EXPECT_NE(c_detector_result.predictions, nullptr);
+
+  CppCloseLanguageDetectionResult(&c_detector_result);
+  EXPECT_EQ(c_detector_result.predictions, nullptr);
+}
+
+}  // namespace mediapipe::tasks::c::components::containers
diff --git a/mediapipe/tasks/c/text/language_detector/BUILD b/mediapipe/tasks/c/text/language_detector/BUILD
new file mode 100644
index 0000000000..9a3ce21e70
--- /dev/null
+++ b/mediapipe/tasks/c/text/language_detector/BUILD
@@ -0,0 +1,93 @@
+# Copyright 2023 The MediaPipe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(default_visibility = ["//mediapipe/tasks:internal"])
+
+licenses(["notice"])
+
+cc_library(
+    name = "language_detector_lib",
+    srcs = ["language_detector.cc"],
+    hdrs = ["language_detector.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//mediapipe/tasks/c/components/containers:language_detection_result_converter",
+        "//mediapipe/tasks/c/components/processors:classifier_options",
+        "//mediapipe/tasks/c/components/processors:classifier_options_converter",
+        "//mediapipe/tasks/c/core:base_options",
+        "//mediapipe/tasks/c/core:base_options_converter",
+        "//mediapipe/tasks/cc/text/language_detector",
+        "@com_google_absl//absl/log:absl_log",
+        "@com_google_absl//absl/status",
+    ],
+    alwayslink = 1,
+)
+
+# bazel build -c opt --linkopt -s --strip always --define MEDIAPIPE_DISABLE_GPU=1 \
+# //mediapipe/tasks/c/text/language_detector:liblanguage_detector.so
+cc_binary(
+    name = "liblanguage_detector.so",
+    linkopts = [
+        "-Wl,-soname=liblanguage_detector.so",
+        "-fvisibility=hidden",
+    ],
+    linkshared = True,
+    tags = [
+        "manual",
+        "nobuilder",
+        "notap",
+    ],
+    deps = [":language_detector_lib"],
+)
+
+# bazel build --config darwin_arm64 -c opt --strip always --define MEDIAPIPE_DISABLE_GPU=1 \
+# //mediapipe/tasks/c/text/language_detector:liblanguage_detector.dylib
+cc_binary(
+    name = "liblanguage_detector.dylib",
+    linkopts = [
+        "-Wl,-install_name,liblanguage_detector.dylib",
+        "-fvisibility=hidden",
+    ],
+    linkshared = True,
+    tags = [
+        "manual",
+        "nobuilder",
+        "notap",
+    ],
+    deps = [":language_detector_lib"],
+)
+
+cc_library(
+    name = "language_detector",
+    hdrs = ["language_detector.h"],
+    deps = [
+        "//mediapipe/tasks/c/components/processors:classifier_options",
+        "//mediapipe/tasks/c/core:base_options",
+    ],
+)
+
+cc_test(
+    name = "language_detector_test",
+    srcs = ["language_detector_test.cc"],
+    data = ["//mediapipe/tasks/testdata/text:language_detector"],
+    linkstatic = 1,
+    deps = [
+        ":language_detector_lib",
+        "//mediapipe/framework/deps:file_path",
+        "//mediapipe/framework/port:gtest",
+        "@com_google_absl//absl/flags:flag",
+        "@com_google_absl//absl/strings",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
diff --git a/mediapipe/tasks/c/text/language_detector/language_detector.cc b/mediapipe/tasks/c/text/language_detector/language_detector.cc
new file mode 100644
index 0000000000..c71433fdc6
--- /dev/null
+++ b/mediapipe/tasks/c/text/language_detector/language_detector.cc
@@ -0,0 +1,124 @@
+/* Copyright 2023 The MediaPipe Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mediapipe/tasks/c/text/language_detector/language_detector.h"
+
+#include <cstring>
+#include <memory>
+#include <utility>
+
+#include "absl/log/absl_log.h"
+#include "absl/status/status.h"
+#include "mediapipe/tasks/c/components/containers/language_detection_result_converter.h"
+#include "mediapipe/tasks/c/components/processors/classifier_options_converter.h"
+#include "mediapipe/tasks/c/core/base_options_converter.h"
+#include "mediapipe/tasks/cc/text/language_detector/language_detector.h"
+
+namespace mediapipe::tasks::c::text::language_detector {
+
+namespace {
+
+using ::mediapipe::tasks::c::components::containers::
+    CppCloseLanguageDetectionResult;
+using ::mediapipe::tasks::c::components::containers::
+    CppConvertToLanguageDetectionResult;
+using ::mediapipe::tasks::c::components::processors::
+    CppConvertToClassifierOptions;
+using ::mediapipe::tasks::c::core::CppConvertToBaseOptions;
+using ::mediapipe::tasks::text::language_detector::LanguageDetector;
+
+// Returns the raw code of `status`. When `error_msg` is non-null, also stores
+// a `strdup`-allocated copy of the status text there; the caller frees it.
+int CppProcessError(const absl::Status& status, char** error_msg) {
+  if (error_msg) {
+    *error_msg = strdup(status.ToString().c_str());
+  }
+  return status.raw_code();
+}
+
+}  // namespace
+
+// Builds a C++ LanguageDetector from the C `options`. Returns an owning raw
+// pointer on success; on failure logs, reports through `error_msg` and returns
+// nullptr.
+LanguageDetector* CppLanguageDetectorCreate(
+    const LanguageDetectorOptions& options, char** error_msg) {
+  auto cpp_options = std::make_unique<
+      ::mediapipe::tasks::text::language_detector::LanguageDetectorOptions>();
+
+  CppConvertToBaseOptions(options.base_options, &cpp_options->base_options);
+  CppConvertToClassifierOptions(options.classifier_options,
+                                &cpp_options->classifier_options);
+
+  auto detector = LanguageDetector::Create(std::move(cpp_options));
+  if (!detector.ok()) {
+    ABSL_LOG(ERROR) << "Failed to create LanguageDetector: "
+                    << detector.status();
+    CppProcessError(detector.status(), error_msg);
+    return nullptr;
+  }
+  return detector->release();
+}
+
+// Runs detection on `utf8_str` and converts the outcome into `result`.
+// Returns 0 on success, otherwise the status code (see CppProcessError).
+// `result` is only written on success.
+int CppLanguageDetectorDetect(void* detector, const char* utf8_str,
+                              LanguageDetectorResult* result,
+                              char** error_msg) {
+  auto cpp_detector = static_cast<LanguageDetector*>(detector);
+  auto cpp_result = cpp_detector->Detect(utf8_str);
+  if (!cpp_result.ok()) {
+    ABSL_LOG(ERROR) << "Language Detection failed: " << cpp_result.status();
+    return CppProcessError(cpp_result.status(), error_msg);
+  }
+
+  CppConvertToLanguageDetectionResult(*cpp_result, result);
+  return 0;
+}
+
+// Releases the memory held inside `result` (not `result` itself).
+void CppLanguageDetectorCloseResult(LanguageDetectorResult* result) {
+  CppCloseLanguageDetectionResult(result);
+}
+
+// Closes and deletes the detector. Returns 0 on success. If Close() fails the
+// status code is returned and the detector is NOT deleted.
+int CppLanguageDetectorClose(void* detector, char** error_msg) {
+  auto cpp_detector = static_cast<LanguageDetector*>(detector);
+  auto result = cpp_detector->Close();
+  if (!result.ok()) {
+    ABSL_LOG(ERROR) << "Failed to close LanguageDetector: " << result;
+    return CppProcessError(result, error_msg);
+  }
+  delete cpp_detector;
+  return 0;
+}
+
+}  // namespace mediapipe::tasks::c::text::language_detector
+
+// C entry points: thin forwarders to the C++ implementations above.
+extern "C" {
+
+void* language_detector_create(struct LanguageDetectorOptions* options,
+                               char** error_msg) {
+  return mediapipe::tasks::c::text::language_detector::
+      CppLanguageDetectorCreate(*options, error_msg);
+}
+
+int language_detector_detect(void* detector, const char* utf8_str,
+                             LanguageDetectorResult* result, char** error_msg) {
+  return mediapipe::tasks::c::text::language_detector::
+      CppLanguageDetectorDetect(detector, utf8_str, result, error_msg);
+}
+
+void language_detector_close_result(LanguageDetectorResult* result) {
+  mediapipe::tasks::c::text::language_detector::CppLanguageDetectorCloseResult(
+      result);
+}
+
+int language_detector_close(void* detector, char** error_msg) {
+  return mediapipe::tasks::c::text::language_detector::CppLanguageDetectorClose(
+      detector, error_msg);
+}
+
+}  // extern "C"
diff --git a/mediapipe/tasks/c/text/language_detector/language_detector.h b/mediapipe/tasks/c/text/language_detector/language_detector.h
new file mode 100644
index 0000000000..f1c85069f6
--- /dev/null
+++ b/mediapipe/tasks/c/text/language_detector/language_detector.h
@@ -0,0 +1,91 @@
+/*
Copyright 2023 The MediaPipe Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef MEDIAPIPE_TASKS_C_TEXT_LANGUAGE_DETECTOR_LANGUAGE_DETECTOR_H_
+#define MEDIAPIPE_TASKS_C_TEXT_LANGUAGE_DETECTOR_LANGUAGE_DETECTOR_H_
+
+#include <stdint.h>
+
+#include "mediapipe/tasks/c/components/processors/classifier_options.h"
+#include "mediapipe/tasks/c/core/base_options.h"
+
+#ifndef MP_EXPORT
+#define MP_EXPORT __attribute__((visibility("default")))
+#endif  // MP_EXPORT
+
+// C has no default arguments, so the `nullptr` default for the trailing
+// `error_msg` parameters is only emitted when compiling as C++. C callers
+// must pass the argument explicitly (it may be NULL).
+#ifdef __cplusplus
+#define MP_LANGUAGE_DETECTOR_DEFAULT_NULL = nullptr
+#else
+#define MP_LANGUAGE_DETECTOR_DEFAULT_NULL
+#endif  // __cplusplus
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// A language code and its probability.
+struct LanguageDetectorPrediction {
+  // An i18n language / locale code, e.g. "en" for English, "uz" for Uzbek,
+  // "ja-Latn" for Japanese (romaji).
+  char* language_code;
+
+  float probability;
+};
+
+// Task output.
+struct LanguageDetectorResult {
+  struct LanguageDetectorPrediction* predictions;
+
+  // The count of predictions.
+  uint32_t predictions_count;
+};
+
+// The options for configuring a MediaPipe language detector task.
+struct LanguageDetectorOptions {
+  // Base options for configuring MediaPipe Tasks, such as specifying the model
+  // file with metadata, accelerator options, op resolver, etc.
+  struct BaseOptions base_options;
+
+  // Options for configuring the detector behavior, such as score threshold,
+  // number of results, etc.
+  struct ClassifierOptions classifier_options;
+};
+
+// Creates a LanguageDetector from the provided `options`.
+// Returns a pointer to the language detector on success.
+// If an error occurs, returns `nullptr` and sets the error parameter to an
+// error message (if `error_msg` is not nullptr). You must free the memory
+// allocated for the error message.
+MP_EXPORT void* language_detector_create(
+    struct LanguageDetectorOptions* options,
+    char** error_msg MP_LANGUAGE_DETECTOR_DEFAULT_NULL);
+
+// Performs language detection on the input `utf8_str`. Returns `0` on success.
+// If an error occurs, returns an error code and sets the error parameter to an
+// error message (if `error_msg` is not nullptr). You must free the memory
+// allocated for the error message.
+MP_EXPORT int language_detector_detect(
+    void* detector, const char* utf8_str, struct LanguageDetectorResult* result,
+    char** error_msg MP_LANGUAGE_DETECTOR_DEFAULT_NULL);
+
+// Frees the memory allocated inside a LanguageDetectorResult result. Does not
+// free the result pointer itself.
+MP_EXPORT void language_detector_close_result(
+    struct LanguageDetectorResult* result);
+
+// Shuts down the LanguageDetector when all the work is done. Frees all memory.
+// If an error occurs, returns an error code and sets the error parameter to an
+// error message (if `error_msg` is not nullptr). You must free the memory
+// allocated for the error message.
+MP_EXPORT int language_detector_close(
+    void* detector, char** error_msg MP_LANGUAGE_DETECTOR_DEFAULT_NULL);
+
+#ifdef __cplusplus
+}  // extern C
+#endif
+
+// The helper macro is an implementation detail of this header.
+#undef MP_LANGUAGE_DETECTOR_DEFAULT_NULL
+
+#endif  // MEDIAPIPE_TASKS_C_TEXT_LANGUAGE_DETECTOR_LANGUAGE_DETECTOR_H_
diff --git a/mediapipe/tasks/c/text/language_detector/language_detector_test.cc b/mediapipe/tasks/c/text/language_detector/language_detector_test.cc
new file mode 100644
index 0000000000..b8653e616f
--- /dev/null
+++ b/mediapipe/tasks/c/text/language_detector/language_detector_test.cc
@@ -0,0 +1,87 @@
+/* Copyright 2023 The MediaPipe Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mediapipe/tasks/c/text/language_detector/language_detector.h"
+
+#include <cstdlib>
+#include <string>
+
+#include "absl/flags/flag.h"
+#include "absl/strings/string_view.h"
+#include "mediapipe/framework/deps/file_path.h"
+#include "mediapipe/framework/port/gmock.h"
+#include "mediapipe/framework/port/gtest.h"
+
+namespace {
+
+using ::mediapipe::file::JoinPath;
+using ::testing::HasSubstr;
+
+constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/";
+constexpr char kTestLanguageDetectorModelPath[] = "language_detector.tflite";
+constexpr char kTestString[] =
+    "Il y a beaucoup de bouches qui parlent et fort peu "
+    "de têtes qui pensent.";
+constexpr float kPrecision = 1e-6;
+
+// Resolves `file_name` relative to the test data directory.
+std::string GetFullPath(absl::string_view file_name) {
+  return JoinPath("./", kTestDataDirectory, file_name);
+}
+
+// End-to-end check: create a detector, classify a French sentence, clean up.
+TEST(LanguageDetectorTest, SmokeTest) {
+  std::string model_path = GetFullPath(kTestLanguageDetectorModelPath);
+  LanguageDetectorOptions options = {
+      /* base_options= */ {/* model_asset_buffer= */ nullptr,
+                           /* model_asset_path= */ model_path.c_str()},
+      /* classifier_options= */
+      {/* display_names_locale= */ nullptr,
+       /* max_results= */ -1,
+       /* score_threshold= */ 0.0,
+       /* category_allowlist= */ nullptr,
+       /* category_allowlist_count= */ 0,
+       /* category_denylist= */ nullptr,
+       /* category_denylist_count= */ 0},
+  };
+
+  void* detector = language_detector_create(&options);
+  // ASSERT: every call below dereferences the detector.
+  ASSERT_NE(detector, nullptr);
+
+  LanguageDetectorResult result = {};
+  // Only read `result` when detection succeeded and produced a prediction.
+  ASSERT_EQ(language_detector_detect(detector, kTestString, &result), 0);
+  ASSERT_GE(result.predictions_count, 1u);
+  EXPECT_STREQ(result.predictions[0].language_code, "fr");
+  EXPECT_NEAR(result.predictions[0].probability, 0.999781, kPrecision);
+
+  language_detector_close_result(&result);
+  EXPECT_EQ(language_detector_close(detector), 0);
+}
+
+// Creation without a model must fail and report INVALID_ARGUMENT.
+TEST(LanguageDetectorTest, ErrorHandling) {
+  // It is an error to set neither the asset buffer nor the path.
+  LanguageDetectorOptions options = {
+      /* base_options= */ {/* model_asset_buffer= */ nullptr,
+                           /* model_asset_path= */ nullptr},
+      /* classifier_options= */ {},
+  };
+
+  // Initialised so a missing message is detected, not read as garbage.
+  char* error_msg = nullptr;
+  void* detector = language_detector_create(&options, &error_msg);
+  EXPECT_EQ(detector, nullptr);
+
+  ASSERT_NE(error_msg, nullptr);
+  EXPECT_THAT(error_msg, HasSubstr("INVALID_ARGUMENT"));
+
+  free(error_msg);
+}
+
+}  // namespace