diff --git a/mediapipe/tasks/cc/genai/inference/utils/llm_utils/BUILD b/mediapipe/tasks/cc/genai/inference/utils/llm_utils/BUILD index c61e0c6052..b91bd90391 100644 --- a/mediapipe/tasks/cc/genai/inference/utils/llm_utils/BUILD +++ b/mediapipe/tasks/cc/genai/inference/utils/llm_utils/BUILD @@ -58,6 +58,7 @@ cc_library( hdrs = ["metadata_utils.h"], deps = [ "//mediapipe/tasks/cc/genai/inference/proto:llm_params_cc_proto", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/mediapipe/tasks/cc/genai/inference/utils/llm_utils/metadata_utils.h b/mediapipe/tasks/cc/genai/inference/utils/llm_utils/metadata_utils.h index 1d07f5c3c5..f4e7abfd51 100644 --- a/mediapipe/tasks/cc/genai/inference/utils/llm_utils/metadata_utils.h +++ b/mediapipe/tasks/cc/genai/inference/utils/llm_utils/metadata_utils.h @@ -15,6 +15,7 @@ #ifndef MEDIAPIPE_TASKS_GENAI_INFERENCE_UTILS_LLM_UTILS_METADATA_UTILS_H_ #define MEDIAPIPE_TASKS_GENAI_INFERENCE_UTILS_LLM_UTILS_METADATA_UTILS_H_ +#include "absl/algorithm/container.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "mediapipe/tasks/cc/genai/inference/proto/llm_params.pb.h" @@ -39,8 +40,11 @@ inline bool RequireBytesToUnicodeMapping( } inline bool RequireFp32Model(odml::infra::proto::LlmModelType model_type) { - return model_type == odml::infra::proto::LLM_MODEL_TYPE_PHI_2 || - model_type == odml::infra::proto::LLM_MODEL_TYPE_FALCON_RW_1B; + constexpr odml::infra::proto::LlmModelType kFp32Models[] = { + odml::infra::proto::LLM_MODEL_TYPE_PHI_2, + odml::infra::proto::LLM_MODEL_TYPE_FALCON_RW_1B, + }; + return absl::c_linear_search(kFp32Models, model_type); } } // namespace mediapipe::tasks::genai::llm_utils