Skip to content

Commit

Permalink
[bugfix] literals in template strings (#1132)
Browse files Browse the repository at this point in the history
<!-- ELLIPSIS_HIDDEN -->



> [!IMPORTANT]
> Fixes handling of literals in template strings by updating type
inference and test cases to correctly identify and evaluate literal
values.
> 
>   - **Behavior**:
> - Update `infer_const_type` in `expr.rs` to return `Type::Literal` for
`Bool`, `String`, and `Number` kinds.
> - Modify error messages in `test_expr.rs` and `test_stmt.rs` to
reflect literal types in function argument errors.
>   - **Tests**:
> - Update `test_ifexpr`, `test_call_function`, and `test_output_format`
in `test_expr.rs` to check for `Type::Literal`.
> - Update `if_else` in `test_stmt.rs` to handle literal types in
conditional expressions.
> 
> <sup>This description was created by </sup>[<img alt="Ellipsis"
src="https://img.shields.io/badge/Ellipsis-blue?color=175173">](https://www.ellipsis.dev?ref=BoundaryML%2Fbaml&utm_source=github&utm_medium=referral)<sup>
for a2180b3. It will automatically
update as commits are pushed.</sup>

<!-- ELLIPSIS_HIDDEN -->
  • Loading branch information
imalsogreg authored Nov 1, 2024
1 parent 0b4a9bc commit b8a221f
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 11 deletions.
18 changes: 15 additions & 3 deletions engine/baml-lib/jinja/src/evaluate_type/expr.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use std::collections::HashMap;

use baml_types::LiteralValue;
use minijinja::machinery::ast;
use std::str::FromStr;

use super::{
pretty_print::pretty_print,
Expand Down Expand Up @@ -370,8 +372,13 @@ fn infer_const_type(v: &minijinja::value::Value) -> Type {
match v.kind() {
minijinja::value::ValueKind::Undefined => Type::Undefined,
minijinja::value::ValueKind::None => Type::None,
minijinja::value::ValueKind::Bool => Type::Bool,
minijinja::value::ValueKind::String => Type::String,
minijinja::value::ValueKind::Bool => {
match bool::from_str(&v.to_string()) {
Ok(b) => Type::Literal(LiteralValue::Bool(b)),
Err(_) => Type::Bool,
}
},
minijinja::value::ValueKind::String => Type::Literal(LiteralValue::String(v.to_string())),
minijinja::value::ValueKind::Seq => {
let list = v.as_seq().unwrap();
match list.item_count() {
Expand Down Expand Up @@ -410,7 +417,12 @@ fn infer_const_type(v: &minijinja::value::Value) -> Type {
}
minijinja::value::ValueKind::Map => Type::Unknown,
// We don't handle these types
minijinja::value::ValueKind::Number => Type::Number,
minijinja::value::ValueKind::Number => {
match i64::from_str(&v.to_string()) {
Ok(i) => Type::Literal(LiteralValue::Int(i)),
Err(_) => Type::Number,
}
},
minijinja::value::ValueKind::Bytes => Type::Undefined,
}
}
Expand Down
40 changes: 33 additions & 7 deletions engine/baml-lib/jinja/src/evaluate_type/test_expr.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use baml_types::LiteralValue;
use minijinja::machinery::parse_expr;

use crate::evaluate_type::{
Expand Down Expand Up @@ -93,17 +94,23 @@ fn test_ifexpr() {
let mut types = PredefinedTypes::default(JinjaContext::Prompt);
assert_eq!(
assert_evaluates_to!("1 if true else 2", &types),
Type::Number
Type::Union(vec![
Type::Literal(LiteralValue::Int(1)),
Type::Literal(LiteralValue::Int(2))
])
);

assert_eq!(
assert_evaluates_to!("1 if true else '2'", &types),
Type::Union(vec![Type::Number, Type::String])
Type::Union(vec![Type::Literal(LiteralValue::String("2".to_string())), Type::Literal(LiteralValue::Int(1))])
);

assert_eq!(
assert_evaluates_to!("'1' if true else 2", &types),
Type::Union(vec![Type::Number, Type::String])
Type::Union(vec![
Type::Literal(LiteralValue::String("1".to_string())),
Type::Literal(LiteralValue::Int(2))
])
);

types.add_function("AnotherFunc", Type::Float, vec![("arg".into(), Type::Bool)]);
Expand Down Expand Up @@ -143,7 +150,7 @@ fn test_call_function() {
assert_eq!(assert_evaluates_to!("SomeFunc(true)", &types), Type::Float);
assert_eq!(
assert_fails_to!("SomeFunc(arg=1)", &types),
vec!["Function 'SomeFunc' expects argument 'arg' to be of type bool, but got number"]
vec!["Function 'SomeFunc' expects argument 'arg' to be of type bool, but got literal[1]"]
);

types.add_function(
Expand All @@ -158,14 +165,14 @@ fn test_call_function() {

assert_eq!(
assert_fails_to!("AnotherFunc(arg='true', arg2='1')", &types),
vec!["Function 'AnotherFunc' expects argument 'arg' to be of type bool, but got string"]
vec![r#"Function 'AnotherFunc' expects argument 'arg' to be of type bool, but got literal["true"]"#]
);

assert_eq!(
assert_fails_to!("AnotherFunc(arg=SomeFunc(true) ~ 1, arg2=1)", &types),
vec![
"Function 'AnotherFunc' expects argument 'arg' to be of type bool, but got string",
"Function 'AnotherFunc' expects argument 'arg2' to be of type string, but got number"
"Function 'AnotherFunc' expects argument 'arg2' to be of type string, but got literal[1]"
]
);

Expand Down Expand Up @@ -201,6 +208,25 @@ fn test_call_function() {
"Function 'AnotherFunc' does not have an argument 'arg4'. Did you mean 'arg3'?"
]
);

types.add_function(
"TakesLiteralFoo",
Type::Float,
vec![
("arg".to_string(),
Type::Union(vec![
Type::Literal(LiteralValue::String("Foo".to_string())),
Type::Literal(LiteralValue::String("Bar".to_string()))
])
)
]
);

assert_eq!(
assert_evaluates_to!("TakesLiteralFoo('Foo')", &types),
Type::Float
);

}

#[test]
Expand Down Expand Up @@ -229,7 +255,7 @@ fn test_output_format() {
"ctx.output_format(prefix='1', always_hoist_enums=1)",
&types
),
vec!["Function 'baml::OutputFormat' expects argument 'always_hoist_enums' to be of type (none | bool), but got number"]
vec!["Function 'baml::OutputFormat' expects argument 'always_hoist_enums' to be of type (none | bool), but got literal[1]"]
);

assert_eq!(
Expand Down
2 changes: 1 addition & 1 deletion engine/baml-lib/jinja/src/evaluate_type/test_stmt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ fn if_else() {
"#
.trim(),
types,
vec!["Function 'Foo' expects argument 'arg' to be of type string, but got (undefined | number | string)"]
vec![r#"Function 'Foo' expects argument 'arg' to be of type string, but got (undefined | literal["2"] | literal[1])"#]
);

let mut types = PredefinedTypes::default(JinjaContext::Prompt);
Expand Down

0 comments on commit b8a221f

Please sign in to comment.