diff --git a/engine/baml-lib/jsonish/src/deserializer/coercer/coerce_primitive.rs b/engine/baml-lib/jsonish/src/deserializer/coercer/coerce_primitive.rs index a520221cb..0041652c2 100644 --- a/engine/baml-lib/jsonish/src/deserializer/coercer/coerce_primitive.rs +++ b/engine/baml-lib/jsonish/src/deserializer/coercer/coerce_primitive.rs @@ -236,8 +236,11 @@ pub(super) fn coerce_bool( target, Some(value), &[ - ("true", vec!["true".into()]), - ("false", vec!["false".into()]), + ("true", vec!["true".into(), "True".into(), "TRUE".into()]), + ( + "false", + vec!["false".into(), "False".into(), "FALSE".into()], + ), ], ) { Ok(val) => match val.value().as_str() { diff --git a/engine/baml-lib/jsonish/src/deserializer/coercer/match_string.rs b/engine/baml-lib/jsonish/src/deserializer/coercer/match_string.rs index 44274ede0..a5475642d 100644 --- a/engine/baml-lib/jsonish/src/deserializer/coercer/match_string.rs +++ b/engine/baml-lib/jsonish/src/deserializer/coercer/match_string.rs @@ -18,6 +18,7 @@ use crate::{ use super::ParsingContext; +/// Heuristic match of different possible values against an input string. pub(super) fn match_string( parsing_context: &ParsingContext, target: &FieldType, @@ -35,6 +36,7 @@ pub(super) fn match_string( let mut flags = DeserializerConditions::new(); + // Grab context. let jsonish_string = match value { jsonish::Value::String(s) => s.clone(), jsonish::Value::AnyOf(_, s) => { @@ -47,8 +49,10 @@ pub(super) fn match_string( } }; + // Trim whitespace. let match_context = jsonish_string.trim(); + // First attempt, case sensitive match ignoring possible punctuation. if let Some(string_match) = string_match_strategy(&match_context, &candidates, &mut flags) { return try_match_only_once(parsing_context, target, string_match, flags); } @@ -56,16 +60,31 @@ pub(super) fn match_string( // Strip punctuation and try again. 
let match_context = strip_punctuation(match_context); - let candidates = candidates - .iter() - .map(|(candidate, valid_values)| { - ( - *candidate, - valid_values.iter().map(|v| strip_punctuation(v)).collect(), - ) - }) - .collect::>(); + // TODO: If the candidates don't contain any punctuation themselves there's + // no point in removing the punctuation from the input string and running + // the entire algorithm again because it should've already matched the + // substrings in the previous attempt. This can be optimized. + let mut candidates = Vec::from_iter(candidates.iter().map(|(candidate, valid_values)| { + let stripped_valid_values = valid_values.iter().map(|v| strip_punctuation(v)).collect(); + (*candidate, stripped_valid_values) + })); + + // Second attempt, case sensitive match without punctuation. + if let Some(string_match) = string_match_strategy(&match_context, &candidates, &mut flags) { + return try_match_only_once(parsing_context, target, string_match, flags); + } + // Last hope, case insensitive match without punctuation. This could yield + // wrong results since the name of a candidate could appear as a "normal" + // word used by the LLM to explain the output. + let match_context = match_context.to_lowercase(); + + // TODO: Consider adding a flag for case insensitive match. + candidates.iter_mut().for_each(|(_, valid_values)| { + valid_values.iter_mut().for_each(|v| *v = v.to_lowercase()); + }); + + // There goes our last hope :) if let Some(string_match) = string_match_strategy(&match_context, &candidates, &mut flags) { return try_match_only_once(parsing_context, target, string_match, flags); } @@ -79,6 +98,9 @@ fn strip_punctuation(s: &str) -> String { .collect::() } +/// Helper function to return a single string match result. +/// +/// Multiple results will yield an error. 
fn try_match_only_once( parsing_context: &ParsingContext<'_>, target: &FieldType, @@ -100,6 +122,13 @@ fn try_match_only_once( Ok((string_match.to_string(), flags).into()) } +/// Heuristic string match algorithm. +/// +/// The algorithm is case sensitive so for case insensitive matches it must +/// receive lowercase strings. This algorithm will first try to look for exact +/// matches in the input string, if it doesn't find any it will look for +/// substring matches and return the one with the most matches. Whether that is +/// an ambiguous match or not is up to the caller to decide. fn string_match_strategy<'c>( value_str: &str, candidates: &'c [(&'c str, Vec)], @@ -107,30 +136,20 @@ fn string_match_strategy<'c>( ) -> Option<&'c str> { // Try and look for an exact match against valid values. for (candidate, valid_values) in candidates { - // Consider adding a flag for case insensitive match. - if valid_values - .iter() - .any(|v| v.eq_ignore_ascii_case(value_str)) - { + if valid_values.iter().any(|v| v == value_str) { // We did nothing fancy, so no extra flags. return Some(candidate); } } - // We'll match evetything using lower case and then return the original - // variants. - let case_insensitive_str = value_str.to_lowercase(); - // Now find all the candidates which occur in the value, by frequency. let mut result = Vec::from_iter(candidates.iter().filter_map(|(variant, valid_names)| { // Check how many counts of the variant are in the value. let match_count_pos = valid_names .iter() .filter_map(|valid_name| { - // Convert to lower. - let case_insensitive_name = valid_name.to_lowercase(); - // Match against full lower case input. - let matches = case_insensitive_str.match_indices(&case_insensitive_name); + // Match occurrences of valid name. 
+ let matches = value_str.match_indices(valid_name); // Return (count, first_idx) matches.fold(None, |acc, (idx, _)| match acc { Some((count, prev_idx)) => Some((count + 1, prev_idx)), diff --git a/engine/baml-lib/jsonish/src/tests/test_enum.rs b/engine/baml-lib/jsonish/src/tests/test_enum.rs index ab124ea97..6e97cffa4 100644 --- a/engine/baml-lib/jsonish/src/tests/test_enum.rs +++ b/engine/baml-lib/jsonish/src/tests/test_enum.rs @@ -96,6 +96,21 @@ test_deserializer!( "ONE" ); +test_deserializer!( + case_sensitive_non_ambiguous_match, + ENUM_FILE, + r#"TWO" is one of the correct answers."#, + FieldType::Enum("Category".to_string()), + "TWO" +); + +test_failing_deserializer!( + case_insensitive_ambiguous_match, + ENUM_FILE, + r#"Two" is one of the correct answers."#, + FieldType::Enum("Category".to_string()) +); + test_failing_deserializer!( from_string_with_extra_text_after_3, ENUM_FILE,