Skip to content

Commit

Permalink
Make substring match algorithm case insensitive (#1056)
Browse files Browse the repository at this point in the history
Fixes #860 
<!-- ELLIPSIS_HIDDEN -->

----

> [!IMPORTANT]
> Make substring match algorithm case insensitive and update tests for
boolean, enum, and literal string matches.
> 
>   - **Behavior**:
> - Make substring match algorithm case insensitive by converting
strings to lowercase in `coerce_primitive.rs` and `match_string.rs`.
> - Fix boolean coercion in `coerce_bool()` to handle case-insensitive
matches for "true" and "false".
>   - **Tests**:
> - Add tests in `test_basics.rs` for case-insensitive boolean matches
and handling of text around boolean values.
> - Add tests in `test_enum.rs` for case-insensitive enum matches and
handling of text around enum values.
> - Add tests in `test_literals.rs` for case-insensitive literal string
matches and handling of text around literal values.
> 
> <sup>This description was created by </sup>[<img alt="Ellipsis"
src="https://img.shields.io/badge/Ellipsis-blue?color=175173">](https://www.ellipsis.dev?ref=BoundaryML%2Fbaml&utm_source=github&utm_medium=referral)<sup>
for ef5be3c. It will automatically
update as commits are pushed.</sup>

<!-- ELLIPSIS_HIDDEN -->
  • Loading branch information
antoniosarosi authored Oct 18, 2024
1 parent f09d943 commit fa2c477
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -223,12 +223,12 @@ pub(super) fn coerce_bool(
if let Some(value) = value {
match value {
crate::jsonish::Value::Boolean(b) => Ok(BamlValueWithFlags::Bool((*b).into())),
crate::jsonish::Value::String(s) => match s.as_str() {
crate::jsonish::Value::String(s) => match s.to_lowercase().as_str() {
"true" => Ok(BamlValueWithFlags::Bool(
(true, Flag::StringToBool(s.clone())).into(),
)),
"false" => Ok(BamlValueWithFlags::Bool(
(true, Flag::StringToBool(s.clone())).into(),
(false, Flag::StringToBool(s.clone())).into(),
)),
_ => {
match super::match_string::match_string(
Expand All @@ -255,7 +255,7 @@ pub(super) fn coerce_bool(
},
crate::jsonish::Value::Array(items) => {
coerce_array_to_singular(ctx, target, &items.iter().collect::<Vec<_>>(), &|value| {
coerce_float(ctx, target, Some(value))
coerce_bool(ctx, target, Some(value))
})
}
_ => Err(ctx.error_unexpected_type(target, value)),
Expand Down
57 changes: 31 additions & 26 deletions engine/baml-lib/jsonish/src/deserializer/coercer/match_string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,34 +117,39 @@ fn string_match_strategy<'c>(
}
}

// We'll match evetything using lower case and then return the original
// variants.
let case_insensitive_str = value_str.to_lowercase();

// Now find all the candidates which occur in the value, by frequency.
let mut result = candidates
.iter()
.filter_map(|(variant, valid_names)| {
// Check how many counts of the variant are in the value.
let match_count_pos = valid_names
.iter()
.filter_map(|valid_name| {
let matches = value_str.match_indices(valid_name);
// Return (count, first_idx)
matches.fold(None, |acc, (idx, _)| match acc {
Some((count, prev_idx)) => Some((count + 1, prev_idx)),
None => Some((1, idx)),
})
let mut result = Vec::from_iter(candidates.iter().filter_map(|(variant, valid_names)| {
// Check how many counts of the variant are in the value.
let match_count_pos = valid_names
.iter()
.filter_map(|valid_name| {
// Convert to lower.
let case_insensitive_name = valid_name.to_lowercase();
// Match against full lower case input.
let matches = case_insensitive_str.match_indices(&case_insensitive_name);
// Return (count, first_idx)
matches.fold(None, |acc, (idx, _)| match acc {
Some((count, prev_idx)) => Some((count + 1, prev_idx)),
None => Some((1, idx)),
})
.reduce(|a, b| match a.0.cmp(&b.0) {
// Return the one with more matches.
Ordering::Less => b,
Ordering::Greater => a,
// Return the one that matches earlier
Ordering::Equal => match a.1.cmp(&b.1) {
Ordering::Less => a,
_ => b,
},
});
match_count_pos.map(|(count, pos)| (count, pos, variant))
})
.collect::<Vec<_>>();
})
.reduce(|a, b| match a.0.cmp(&b.0) {
// Return the one with more matches.
Ordering::Less => b,
Ordering::Greater => a,
// Return the one that matches earlier
Ordering::Equal => match a.1.cmp(&b.1) {
Ordering::Less => a,
_ => b,
},
});

match_count_pos.map(|(count, pos)| (count, pos, variant))
}));

// Sort by max count, then min pos.
result.sort_by(|a, b| match a.0.cmp(&b.0) {
Expand Down
31 changes: 31 additions & 0 deletions engine/baml-lib/jsonish/src/tests/test_basics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,44 @@ test_deserializer!(
[true]
);

test_deserializer!(
test_bool_wrapped_mismatched_case,
EMPTY_FILE,
"The answer is True",
FieldType::bool().as_list(),
[true]
);

test_deserializer!(
test_bool_wrapped_mismatched_case_preceded_by_text,
EMPTY_FILE,
"The tax return you provided has section for dependents.\n\nAnswer: **True**",
FieldType::bool(),
true
);

test_deserializer!(
test_bool_mismatched_case_followed_by_text,
EMPTY_FILE,
r#"False.\n\nThe statement "2 + 2 = 5" is mathematically incorrect. The correct sum of 2 + 2 is 4, not 5."#,
FieldType::bool(),
false
);

test_failing_deserializer!(
test_ambiguous_bool,
EMPTY_FILE,
"The answer is true or false",
FieldType::bool()
);

test_failing_deserializer!(
test_elaborate_ambiguous_bool,
EMPTY_FILE,
r#"False. The statement "2 + 2 = 5" is not accurate according to basic arithmetic. In standard arithmetic, the sum of 2 and 2 is equal to 4, not 5. Therefore, the statement does not hold true."#,
FieldType::bool()
);

test_deserializer!(
test_float,
EMPTY_FILE,
Expand Down
32 changes: 32 additions & 0 deletions engine/baml-lib/jsonish/src/tests/test_enum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,14 @@ TWO
}
"#;

const PASCAL_CASE_ENUM_FILE: &str = r#"
// Enums
enum PascalCaseCategory {
One
Two
}
"#;

test_deserializer!(
test_enum,
ENUM_FILE,
Expand Down Expand Up @@ -56,6 +64,30 @@ test_deserializer!(
"ONE"
);

test_deserializer!(
from_string_and_case_mismatch,
ENUM_FILE,
"The answer is One",
FieldType::Enum("Category".to_string()),
"ONE"
);

test_deserializer!(
from_string_and_case_mismatch_wrapped,
ENUM_FILE,
"**one** is the answer",
FieldType::Enum("Category".to_string()),
"ONE"
);

test_deserializer!(
from_string_and_case_mismatch_upper,
PASCAL_CASE_ENUM_FILE,
"**ONE** is the answer",
FieldType::Enum("PascalCaseCategory".to_string()),
"One"
);

test_deserializer!(
from_string_with_extra_text_after_2,
ENUM_FILE,
Expand Down
42 changes: 42 additions & 0 deletions engine/baml-lib/jsonish/src/tests/test_literals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,14 @@ test_deserializer!(
"TWO"
);

test_deserializer!(
test_literal_string_preceded_by_extra_text_case_mismatch,
EMPTY_FILE,
"The answer is Two",
FieldType::Literal(LiteralValue::String("TWO".into())),
"TWO"
);

test_deserializer!(
test_literal_string_followed_by_extra_text,
EMPTY_FILE,
Expand All @@ -90,6 +98,14 @@ test_deserializer!(
"TWO"
);

test_deserializer!(
test_literal_string_followed_by_extra_text_case_mismatch,
EMPTY_FILE,
"Two is the answer",
FieldType::Literal(LiteralValue::String("TWO".into())),
"TWO"
);

test_deserializer!(
test_literal_string_with_quotes_preceded_by_extra_text,
EMPTY_FILE,
Expand All @@ -98,6 +114,14 @@ test_deserializer!(
"TWO"
);

test_deserializer!(
test_literal_string_with_quotes_preceded_by_extra_text_case_mismatch,
EMPTY_FILE,
r#"The answer is "two""#,
FieldType::Literal(LiteralValue::String("TWO".into())),
"TWO"
);

test_deserializer!(
test_literal_string_with_quotes_followed_by_extra_text,
EMPTY_FILE,
Expand All @@ -106,6 +130,24 @@ test_deserializer!(
"TWO"
);

test_deserializer!(
test_literal_string_with_quotes_followed_by_extra_text_case_mismatch,
EMPTY_FILE,
r#""Two" is the answer"#,
FieldType::Literal(LiteralValue::String("TWO".into())),
"TWO"
);

test_deserializer!(
test_literal_string_case_mismatch_upper,
EMPTY_FILE,
// Came up with this example unintentioanlly but this causes ambiguity
// issues with unions ("two" | "one"), see the TODO at the end of this file.
r#"The ansewr "TWO" is the correct one"#,
FieldType::Literal(LiteralValue::String("two".into())),
"two"
);

test_deserializer!(
test_literal_string_with_special_characters,
EMPTY_FILE,
Expand Down

0 comments on commit fa2c477

Please sign in to comment.