From 0c5cbd4ae03d2bc836ee4b61a7df638855bb72ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20C=C3=A1rdenas?= <78029302+miguelcsx@users.noreply.github.com> Date: Sat, 26 Oct 2024 17:26:59 -0500 Subject: [PATCH] [Fix] Improve Enum Substring Alias Handling (#1098) ## Summary: This pull request introduces improvements to the handling of enum substring aliases in the deserializer. The update modifies the string match strategy to favor the longest substring match, enhancing the accuracy and robustness of enum alias recognition. ## Changes - Updated the string match strategy in the deserializer to prioritize the longest substring match for enum aliases. ## Benefits - Improved handling of enum aliases, reducing potential mismatches. - - Enhances overall functionality and user experience when working with enums. Related Issue: #1085 --------- Co-authored-by: Vaibhav Gupta Co-authored-by: Antonio Sarosi --- .../src/deserializer/coercer/match_string.rs | 116 ++++++++++-------- .../src/deserializer/deserialize_flags.rs | 7 +- .../jsonish/src/deserializer/score.rs | 5 +- .../baml-lib/jsonish/src/tests/test_enum.rs | 37 ++++++ 4 files changed, 112 insertions(+), 53 deletions(-) diff --git a/engine/baml-lib/jsonish/src/deserializer/coercer/match_string.rs b/engine/baml-lib/jsonish/src/deserializer/coercer/match_string.rs index a5475642d..096e79bab 100644 --- a/engine/baml-lib/jsonish/src/deserializer/coercer/match_string.rs +++ b/engine/baml-lib/jsonish/src/deserializer/coercer/match_string.rs @@ -2,7 +2,7 @@ //! //! Used mostly for matching enum variants or literal strings. -use std::cmp::Ordering; +use std::{cmp::Ordering, collections::HashMap}; use anyhow::Result; use baml_types::FieldType; @@ -125,10 +125,10 @@ fn try_match_only_once( /// Heuristic string match algorithm. /// /// The algorithm is case sensitive so for case insensitive matches it must -/// recieve lowercase strings. This algorithm will first try to look for exact +/// receive lowercase strings. This algorithm will first try to look for exact /// matches in the input string, if it doesn't find any it will look for /// substring matches and return the one with the most matches. Whether that is -/// an ambigous match or not is up to the caller to decide. +/// an ambiguous match or not is up to the caller to decide. fn string_match_strategy<'c>( value_str: &str, candidates: &'c [(&'c str, Vec)], @@ -142,60 +142,78 @@ fn string_match_strategy<'c>( } } - // Now find all the candidates which occur in the value, by frequency. - let mut result = Vec::from_iter(candidates.iter().filter_map(|(variant, valid_names)| { - // Check how many counts of the variant are in the value. - let match_count_pos = valid_names - .iter() - .filter_map(|valid_name| { - // Match ocurrences of valid name. - let matches = value_str.match_indices(valid_name); - // Return (count, first_idx) - matches.fold(None, |acc, (idx, _)| match acc { - Some((count, prev_idx)) => Some((count + 1, prev_idx)), - None => Some((1, idx)), - }) - }) - .reduce(|a, b| match a.0.cmp(&b.0) { - // Return the one with more matches. - Ordering::Less => b, - Ordering::Greater => a, - // Return the one that matches earlier - Ordering::Equal => match a.1.cmp(&b.1) { - Ordering::Less => a, - _ => b, - }, - }); - - match_count_pos.map(|(count, pos)| (count, pos, variant)) - })); + // (start_index, end_index, valid_name, variant) + // TODO: Consider using a struct with named fields instead of a 4-tuple. + let mut all_matches: Vec<(usize, usize, &'c str, &'c str)> = Vec::new(); + + // Look for substrings of valid values + for (variant, valid_names) in candidates { + for valid_name in valid_names { + for (start_idx, _) in value_str.match_indices(valid_name) { + let end_idx = start_idx + valid_name.len(); + all_matches.push((start_idx, end_idx, valid_name, variant)); + } + } + } - // Sort by max count, then min pos. - result.sort_by(|a, b| match a.0.cmp(&b.0) { - Ordering::Less => Ordering::Greater, - Ordering::Greater => Ordering::Less, - Ordering::Equal => a.1.cmp(&b.1), + // No substring match at all for any variant, early return. + if all_matches.is_empty() { + return None; + } + + // Sort by position and length + all_matches.sort_by(|a, b| { + match a.0.cmp(&b.0) { + Ordering::Equal => b.1.cmp(&a.1), // Longer first + ordering => ordering, // Less or Greater stays the same + } }); - // Filter for max count. - let max_count = result.first().map(|r| r.0).unwrap_or(0); - result.retain(|r| r.0 == max_count); + // Filter out overlapping matches + let mut filtered_matches = Vec::new(); + let mut last_end = 0; - // Return the best match if there is one. - if let Some((_, _, candidate)) = result.first() { + for current_match in all_matches { + if current_match.0 >= last_end { + // No overlap with previous match + last_end = current_match.1; + filtered_matches.push(current_match); + } + } + + // Count occurrences of each variant in non-overlapping matches. + // (count, variant) + let mut variant_counts = HashMap::<&'c str, usize>::new(); + for (_, _, _, variant) in &filtered_matches { + if let Some(count) = variant_counts.get_mut(*variant) { + // Increment count if variant already exists. + *count += 1; + } else { + // Add new variant. + variant_counts.insert(variant, 1); + } + } + + // Return the best match if there is one + if let Some((best_match, max_count)) = variant_counts + .iter() + .max_by(|(_, count_a), (_, count_b)| count_a.cmp(count_b)) + { flags.add_flag(Flag::SubstringMatch(value_str.into())); - // Add flag for multiple matches. - if result.len() > 1 { - flags.add_flag(Flag::StrMatchOneFromMany( - result - .iter() - .map(|(count, _, candidate)| ((*count) as usize, candidate.to_string())) - .collect(), - )); + // Find all variants with the same count + let ties: Vec<_> = variant_counts + .iter() + .filter(|(_, count)| *count == max_count) + .map(|(variant, count)| (variant.to_string(), *count)) + .collect(); + + // If there are multiple matches, add a flag + if ties.len() > 1 { + flags.add_flag(Flag::StrMatchOneFromMany(ties)); } - return Some(candidate); + return Some(best_match); } // No match found. diff --git a/engine/baml-lib/jsonish/src/deserializer/deserialize_flags.rs b/engine/baml-lib/jsonish/src/deserializer/deserialize_flags.rs index e555e0d59..5aab40624 100644 --- a/engine/baml-lib/jsonish/src/deserializer/deserialize_flags.rs +++ b/engine/baml-lib/jsonish/src/deserializer/deserialize_flags.rs @@ -27,7 +27,8 @@ pub enum Flag { FirstMatch(usize, Vec>), UnionMatch(usize, Vec>), - StrMatchOneFromMany(Vec<(usize, String)>), + /// `[(value, count)]` + StrMatchOneFromMany(Vec<(String, usize)>), DefaultFromNoValue, DefaultButHadValue(crate::jsonish::Value), @@ -177,8 +178,8 @@ impl std::fmt::Display for Flag { } Flag::StrMatchOneFromMany(values) => { write!(f, "Enum one from many: ")?; - for (idx, value) in values { - writeln!(f, "Item {}: {}", idx, value)?; + for (value, count) in values { + writeln!(f, "Item {value}: {count}")?; } } Flag::DefaultButHadUnparseableValue(value) => { diff --git a/engine/baml-lib/jsonish/src/deserializer/score.rs b/engine/baml-lib/jsonish/src/deserializer/score.rs index fdf9748fe..95bb6a5b0 100644 --- a/engine/baml-lib/jsonish/src/deserializer/score.rs +++ b/engine/baml-lib/jsonish/src/deserializer/score.rs @@ -58,7 +58,10 @@ impl WithScore for Flag { Flag::FirstMatch(_, _) => 1, // No penalty for picking an option from a union Flag::UnionMatch(_, _) => 0, - Flag::StrMatchOneFromMany(i) => i.into_iter().map(|(i, _)| *i as i32).sum::(), + Flag::StrMatchOneFromMany(values) => values + .into_iter() + .map(|(_, count)| *count as i32) + .sum::(), Flag::StringToBool(_) => 1, Flag::StringToNull(_) => 1, Flag::StringToChar(_) => 1, diff --git a/engine/baml-lib/jsonish/src/tests/test_enum.rs b/engine/baml-lib/jsonish/src/tests/test_enum.rs index 6e97cffa4..1164bdb61 100644 --- a/engine/baml-lib/jsonish/src/tests/test_enum.rs +++ b/engine/baml-lib/jsonish/src/tests/test_enum.rs @@ -256,3 +256,40 @@ test_deserializer!( FieldType::List(FieldType::Enum("Category".to_string()).into()), ["ONE", "TWO", "THREE"] ); + +test_deserializer!( + test_numerical_enum, + r#" +enum TaxReturnFormType { + F9325 @alias("9325") + F9465 @alias("9465") + F1040 @alias("1040") + F1040X @alias("1040-X") +} +"#, + r#" +(such as 1040-X, 1040, etc.) or any payment vouchers. + +Based on the criteria provided, this page does not qualify as a tax return form page. Therefore, the appropriate response is: + +```json +null +``` + +This indicates that there is no relevant tax return form type present on the page. + "#, + FieldType::Enum("TaxReturnFormType".to_string()).as_optional(), + null +); + +test_failing_deserializer!( + test_ambiguous_substring_enum, + r#" + enum Car { + A @alias("car") + B @alias("car-2") + } + "#, + "The answer is not car or car-2!", + FieldType::Enum("Car".to_string()) +);