Skip to content

Commit

Permalink
[Fix] Improve Enum Substring Alias Handling (#1098)
Browse files Browse the repository at this point in the history
## Summary:

This pull request introduces improvements to the handling of enum
substring aliases in the deserializer. The update modifies the string
match strategy to favor the longest substring match, enhancing the
accuracy and robustness of enum alias recognition.

## Changes

- Updated the string match strategy in the deserializer to prioritize
the longest substring match for enum aliases.

## Benefits

- Improved handling of enum aliases, reducing potential mismatches.
- 
- Enhances overall functionality and user experience when working with
enums.
    
Related Issue: #1085

---------

Co-authored-by: Vaibhav Gupta <[email protected]>
Co-authored-by: Antonio Sarosi <[email protected]>
  • Loading branch information
3 people authored Oct 26, 2024
1 parent 23e590b commit 0c5cbd4
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 53 deletions.
116 changes: 67 additions & 49 deletions engine/baml-lib/jsonish/src/deserializer/coercer/match_string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
//!
//! Used mostly for matching enum variants or literal strings.
use std::cmp::Ordering;
use std::{cmp::Ordering, collections::HashMap};

use anyhow::Result;
use baml_types::FieldType;
Expand Down Expand Up @@ -125,10 +125,10 @@ fn try_match_only_once(
/// Heuristic string match algorithm.
///
/// The algorithm is case sensitive so for case insensitive matches it must
/// recieve lowercase strings. This algorithm will first try to look for exact
/// receive lowercase strings. This algorithm will first try to look for exact
/// matches in the input string, if it doesn't find any it will look for
/// substring matches and return the one with the most matches. Whether that is
/// an ambigous match or not is up to the caller to decide.
/// an ambiguous match or not is up to the caller to decide.
fn string_match_strategy<'c>(
value_str: &str,
candidates: &'c [(&'c str, Vec<String>)],
Expand All @@ -142,60 +142,78 @@ fn string_match_strategy<'c>(
}
}

// Now find all the candidates which occur in the value, by frequency.
let mut result = Vec::from_iter(candidates.iter().filter_map(|(variant, valid_names)| {
// Check how many counts of the variant are in the value.
let match_count_pos = valid_names
.iter()
.filter_map(|valid_name| {
// Match ocurrences of valid name.
let matches = value_str.match_indices(valid_name);
// Return (count, first_idx)
matches.fold(None, |acc, (idx, _)| match acc {
Some((count, prev_idx)) => Some((count + 1, prev_idx)),
None => Some((1, idx)),
})
})
.reduce(|a, b| match a.0.cmp(&b.0) {
// Return the one with more matches.
Ordering::Less => b,
Ordering::Greater => a,
// Return the one that matches earlier
Ordering::Equal => match a.1.cmp(&b.1) {
Ordering::Less => a,
_ => b,
},
});

match_count_pos.map(|(count, pos)| (count, pos, variant))
}));
// (start_index, end_index, valid_name, variant)
// TODO: Consider using a struct with named fields instead of a 4-tuple.
let mut all_matches: Vec<(usize, usize, &'c str, &'c str)> = Vec::new();

// Look for substrings of valid values
for (variant, valid_names) in candidates {
for valid_name in valid_names {
for (start_idx, _) in value_str.match_indices(valid_name) {
let end_idx = start_idx + valid_name.len();
all_matches.push((start_idx, end_idx, valid_name, variant));
}
}
}

// Sort by max count, then min pos.
result.sort_by(|a, b| match a.0.cmp(&b.0) {
Ordering::Less => Ordering::Greater,
Ordering::Greater => Ordering::Less,
Ordering::Equal => a.1.cmp(&b.1),
// No substring match at all for any variant, early return.
if all_matches.is_empty() {
return None;
}

// Sort by position and length
all_matches.sort_by(|a, b| {
match a.0.cmp(&b.0) {
Ordering::Equal => b.1.cmp(&a.1), // Longer first
ordering => ordering, // Less or Greater stays the same
}
});

// Filter for max count.
let max_count = result.first().map(|r| r.0).unwrap_or(0);
result.retain(|r| r.0 == max_count);
// Filter out overlapping matches
let mut filtered_matches = Vec::new();
let mut last_end = 0;

// Return the best match if there is one.
if let Some((_, _, candidate)) = result.first() {
for current_match in all_matches {
if current_match.0 >= last_end {
// No overlap with previous match
last_end = current_match.1;
filtered_matches.push(current_match);
}
}

// Count occurrences of each variant in non-overlapping matches.
// (count, variant)
let mut variant_counts = HashMap::<&'c str, usize>::new();
for (_, _, _, variant) in &filtered_matches {
if let Some(count) = variant_counts.get_mut(*variant) {
// Increment count if variant already exists.
*count += 1;
} else {
// Add new variant.
variant_counts.insert(variant, 1);
}
}

// Return the best match if there is one
if let Some((best_match, max_count)) = variant_counts
.iter()
.max_by(|(_, count_a), (_, count_b)| count_a.cmp(count_b))
{
flags.add_flag(Flag::SubstringMatch(value_str.into()));

// Add flag for multiple matches.
if result.len() > 1 {
flags.add_flag(Flag::StrMatchOneFromMany(
result
.iter()
.map(|(count, _, candidate)| ((*count) as usize, candidate.to_string()))
.collect(),
));
// Find all variants with the same count
let ties: Vec<_> = variant_counts
.iter()
.filter(|(_, count)| *count == max_count)
.map(|(variant, count)| (variant.to_string(), *count))
.collect();

// If there are multiple matches, add a flag
if ties.len() > 1 {
flags.add_flag(Flag::StrMatchOneFromMany(ties));
}

return Some(candidate);
return Some(best_match);
}

// No match found.
Expand Down
7 changes: 4 additions & 3 deletions engine/baml-lib/jsonish/src/deserializer/deserialize_flags.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ pub enum Flag {
FirstMatch(usize, Vec<Result<BamlValueWithFlags, ParsingError>>),
UnionMatch(usize, Vec<Result<BamlValueWithFlags, ParsingError>>),

StrMatchOneFromMany(Vec<(usize, String)>),
/// `[(value, count)]`
StrMatchOneFromMany(Vec<(String, usize)>),

DefaultFromNoValue,
DefaultButHadValue(crate::jsonish::Value),
Expand Down Expand Up @@ -177,8 +178,8 @@ impl std::fmt::Display for Flag {
}
Flag::StrMatchOneFromMany(values) => {
write!(f, "Enum one from many: ")?;
for (idx, value) in values {
writeln!(f, "Item {}: {}", idx, value)?;
for (value, count) in values {
writeln!(f, "Item {value}: {count}")?;
}
}
Flag::DefaultButHadUnparseableValue(value) => {
Expand Down
5 changes: 4 additions & 1 deletion engine/baml-lib/jsonish/src/deserializer/score.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@ impl WithScore for Flag {
Flag::FirstMatch(_, _) => 1,
// No penalty for picking an option from a union
Flag::UnionMatch(_, _) => 0,
Flag::StrMatchOneFromMany(i) => i.into_iter().map(|(i, _)| *i as i32).sum::<i32>(),
Flag::StrMatchOneFromMany(values) => values
.into_iter()
.map(|(_, count)| *count as i32)
.sum::<i32>(),
Flag::StringToBool(_) => 1,
Flag::StringToNull(_) => 1,
Flag::StringToChar(_) => 1,
Expand Down
37 changes: 37 additions & 0 deletions engine/baml-lib/jsonish/src/tests/test_enum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,3 +256,40 @@ test_deserializer!(
FieldType::List(FieldType::Enum("Category".to_string()).into()),
["ONE", "TWO", "THREE"]
);

test_deserializer!(
test_numerical_enum,
r#"
enum TaxReturnFormType {
F9325 @alias("9325")
F9465 @alias("9465")
F1040 @alias("1040")
F1040X @alias("1040-X")
}
"#,
r#"
(such as 1040-X, 1040, etc.) or any payment vouchers.
Based on the criteria provided, this page does not qualify as a tax return form page. Therefore, the appropriate response is:
```json
null
```
This indicates that there is no relevant tax return form type present on the page.
"#,
FieldType::Enum("TaxReturnFormType".to_string()).as_optional(),
null
);

test_failing_deserializer!(
test_ambiguous_substring_enum,
r#"
enum Car {
A @alias("car")
B @alias("car-2")
}
"#,
"The answer is not car or car-2!",
FieldType::Enum("Car".to_string())
);

0 comments on commit 0c5cbd4

Please sign in to comment.