Skip to content

Commit

Permalink
Prefer case sensitive match over case insensitive (#1063)
Browse files Browse the repository at this point in the history
Related to #860, improves upon #1056.
<!-- ELLIPSIS_HIDDEN -->


----

> [!IMPORTANT]
> Enhance `match_string` to prefer case-sensitive matches, update
> `coerce_bool` for case variations, and adjust tests accordingly.
> 
>   - **Behavior**:
>     - `match_string` in `match_string.rs` now prioritizes case-sensitive matches over case-insensitive ones.
>     - Handles punctuation by stripping it and retrying the match.
>     - `coerce_bool` in `coerce_primitive.rs` updated to handle different case variations of boolean strings.
>   - **Tests**:
>     - Added `case_sensitive_non_ambiguous_match` test in `test_enum.rs` for case-sensitive matching.
>     - Added `case_insensitive_ambiguous_match` failing test in `test_enum.rs` to ensure ambiguous matches are not incorrectly resolved.
>     - Updated existing tests in `test_enum.rs` to align with new matching logic.
> 
> <sup>This description was created by </sup>[<img alt="Ellipsis"
src="https://img.shields.io/badge/Ellipsis-blue?color=175173">](https://www.ellipsis.dev?ref=BoundaryML%2Fbaml&utm_source=github&utm_medium=referral)<sup>
for 8fdc5b4. It will automatically
update as commits are pushed.</sup>


<!-- ELLIPSIS_HIDDEN -->
  • Loading branch information
antoniosarosi authored Oct 19, 2024
1 parent 830b0cb commit cd6b141
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,11 @@ pub(super) fn coerce_bool(
target,
Some(value),
&[
("true", vec!["true".into()]),
("false", vec!["false".into()]),
("true", vec!["true".into(), "True".into(), "TRUE".into()]),
(
"false",
vec!["false".into(), "False".into(), "FALSE".into()],
),
],
) {
Ok(val) => match val.value().as_str() {
Expand Down
63 changes: 41 additions & 22 deletions engine/baml-lib/jsonish/src/deserializer/coercer/match_string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use crate::{

use super::ParsingContext;

/// Heuristic match of different possible values against an input string.
pub(super) fn match_string(
parsing_context: &ParsingContext,
target: &FieldType,
Expand All @@ -35,6 +36,7 @@ pub(super) fn match_string(

let mut flags = DeserializerConditions::new();

// Grab context.
let jsonish_string = match value {
jsonish::Value::String(s) => s.clone(),
jsonish::Value::AnyOf(_, s) => {
Expand All @@ -47,25 +49,42 @@ pub(super) fn match_string(
}
};

// Trim whitespaces.
let match_context = jsonish_string.trim();

// First attempt, case sensitive match ignoring possible punctuation.
if let Some(string_match) = string_match_strategy(&match_context, &candidates, &mut flags) {
return try_match_only_once(parsing_context, target, string_match, flags);
}

// Strip punctuation and try again.
let match_context = strip_punctuation(match_context);

let candidates = candidates
.iter()
.map(|(candidate, valid_values)| {
(
*candidate,
valid_values.iter().map(|v| strip_punctuation(v)).collect(),
)
})
.collect::<Vec<_>>();
// TODO: If the candidates don't contain any punctuation themselves there's
// no point in removing the punctuation from the input string and running
// the entire algorithm again because it should've already matched the
// substrings in the previous attempt. This can be optimized.
let mut candidates = Vec::from_iter(candidates.iter().map(|(candidate, valid_values)| {
let stripped_valid_values = valid_values.iter().map(|v| strip_punctuation(v)).collect();
(*candidate, stripped_valid_values)
}));

// Second attempt, case sensitive match without punctuation.
if let Some(string_match) = string_match_strategy(&match_context, &candidates, &mut flags) {
return try_match_only_once(parsing_context, target, string_match, flags);
}

// Last hope, case insensitive match without punctuation. This could yield
// wrong results since the name of a candidate could appear as a "normal"
// word used by the LLM to explain the output.
let match_context = match_context.to_lowercase();

// TODO: Consider adding a flag for case insensitive match.
candidates.iter_mut().for_each(|(_, valid_values)| {
valid_values.iter_mut().for_each(|v| *v = v.to_lowercase());
});

// There goes our last hope :)
if let Some(string_match) = string_match_strategy(&match_context, &candidates, &mut flags) {
return try_match_only_once(parsing_context, target, string_match, flags);
}
Expand All @@ -79,6 +98,9 @@ fn strip_punctuation(s: &str) -> String {
.collect::<String>()
}

/// Helper function to return a single string match result.
///
/// Multiple results will yield an error.
fn try_match_only_once(
parsing_context: &ParsingContext<'_>,
target: &FieldType,
Expand All @@ -100,37 +122,34 @@ fn try_match_only_once(
Ok((string_match.to_string(), flags).into())
}

/// Heuristic string match algorithm.
///
/// The algorithm is case sensitive, so for case insensitive matches it must
/// receive lowercase strings. This algorithm will first try to look for exact
/// matches in the input string; if it doesn't find any, it will look for
/// substring matches and return the one with the most matches. Whether that is
/// an ambiguous match or not is up to the caller to decide.
fn string_match_strategy<'c>(
value_str: &str,
candidates: &'c [(&'c str, Vec<String>)],
flags: &mut DeserializerConditions,
) -> Option<&'c str> {
// Try and look for an exact match against valid values.
for (candidate, valid_values) in candidates {
// Consider adding a flag for case insensitive match.
if valid_values
.iter()
.any(|v| v.eq_ignore_ascii_case(value_str))
{
if valid_values.iter().any(|v| v == value_str) {
// We did nothing fancy, so no extra flags.
return Some(candidate);
}
}

// We'll match everything using lower case and then return the original
// variants.
let case_insensitive_str = value_str.to_lowercase();

// Now find all the candidates which occur in the value, by frequency.
let mut result = Vec::from_iter(candidates.iter().filter_map(|(variant, valid_names)| {
// Check how many counts of the variant are in the value.
let match_count_pos = valid_names
.iter()
.filter_map(|valid_name| {
// Convert to lower.
let case_insensitive_name = valid_name.to_lowercase();
// Match against full lower case input.
let matches = case_insensitive_str.match_indices(&case_insensitive_name);
// Match occurrences of valid name.
let matches = value_str.match_indices(valid_name);
// Return (count, first_idx)
matches.fold(None, |acc, (idx, _)| match acc {
Some((count, prev_idx)) => Some((count + 1, prev_idx)),
Expand Down
15 changes: 15 additions & 0 deletions engine/baml-lib/jsonish/src/tests/test_enum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,21 @@ test_deserializer!(
"ONE"
);

// The input contains the exact-case token `TWO` alongside the lowercase word
// "one"; the expected deserialized value is "TWO".
// NOTE(review): presumably ENUM_FILE declares `ONE` and `TWO` variants, so
// only `TWO` matches case-sensitively and no case-insensitive fallback is
// needed — confirm against the ENUM_FILE definition.
test_deserializer!(
case_sensitive_non_ambiguous_match,
ENUM_FILE,
r#"TWO" is one of the correct answers."#,
FieldType::Enum("Category".to_string()),
"TWO"
);

// The input contains the mixed-case "Two" and the lowercase "one"; no expected
// value is supplied because this case is registered as a failing one.
// NOTE(review): presumably neither token matches an ENUM_FILE variant
// case-sensitively, while both match case-insensitively, making the result
// ambiguous — confirm against the ENUM_FILE definition.
test_failing_deserializer!(
case_insensitive_ambiguous_match,
ENUM_FILE,
r#"Two" is one of the correct answers."#,
FieldType::Enum("Category".to_string())
);

test_failing_deserializer!(
from_string_with_extra_text_after_3,
ENUM_FILE,
Expand Down

0 comments on commit cd6b141

Please sign in to comment.