diff --git a/engine/baml-lib/jinja/src/evaluate_type/expr.rs b/engine/baml-lib/jinja/src/evaluate_type/expr.rs index 1dfb645d1..c52d9833f 100644 --- a/engine/baml-lib/jinja/src/evaluate_type/expr.rs +++ b/engine/baml-lib/jinja/src/evaluate_type/expr.rs @@ -1,6 +1,8 @@ use std::collections::HashMap; +use baml_types::LiteralValue; use minijinja::machinery::ast; +use std::str::FromStr; use super::{ pretty_print::pretty_print, @@ -370,8 +372,13 @@ fn infer_const_type(v: &minijinja::value::Value) -> Type { match v.kind() { minijinja::value::ValueKind::Undefined => Type::Undefined, minijinja::value::ValueKind::None => Type::None, - minijinja::value::ValueKind::Bool => Type::Bool, - minijinja::value::ValueKind::String => Type::String, + minijinja::value::ValueKind::Bool => { + match bool::from_str(&v.to_string()) { + Ok(b) => Type::Literal(LiteralValue::Bool(b)), + Err(_) => Type::Bool, + } + }, + minijinja::value::ValueKind::String => Type::Literal(LiteralValue::String(v.to_string())), minijinja::value::ValueKind::Seq => { let list = v.as_seq().unwrap(); match list.item_count() { @@ -410,7 +417,12 @@ fn infer_const_type(v: &minijinja::value::Value) -> Type { } minijinja::value::ValueKind::Map => Type::Unknown, // We don't handle these types - minijinja::value::ValueKind::Number => Type::Number, + minijinja::value::ValueKind::Number => { + match i64::from_str(&v.to_string()) { + Ok(i) => Type::Literal(LiteralValue::Int(i)), + Err(_) => Type::Number, + } + }, minijinja::value::ValueKind::Bytes => Type::Undefined, } } diff --git a/engine/baml-lib/jinja/src/evaluate_type/test_expr.rs b/engine/baml-lib/jinja/src/evaluate_type/test_expr.rs index 4980f549d..fceace569 100644 --- a/engine/baml-lib/jinja/src/evaluate_type/test_expr.rs +++ b/engine/baml-lib/jinja/src/evaluate_type/test_expr.rs @@ -1,3 +1,4 @@ +use baml_types::LiteralValue; use minijinja::machinery::parse_expr; use crate::evaluate_type::{ @@ -93,17 +94,23 @@ fn test_ifexpr() { let mut types = PredefinedTypes::default(JinjaContext::Prompt); assert_eq!( assert_evaluates_to!("1 if true else 2", &types), - Type::Number + Type::Union(vec![ + Type::Literal(LiteralValue::Int(1)), + Type::Literal(LiteralValue::Int(2)) + ]) ); assert_eq!( assert_evaluates_to!("1 if true else '2'", &types), - Type::Union(vec![Type::Number, Type::String]) + Type::Union(vec![Type::Literal(LiteralValue::String("2".to_string())), Type::Literal(LiteralValue::Int(1))]) ); assert_eq!( assert_evaluates_to!("'1' if true else 2", &types), - Type::Union(vec![Type::Number, Type::String]) + Type::Union(vec![ + Type::Literal(LiteralValue::String("1".to_string())), + Type::Literal(LiteralValue::Int(2)) + ]) ); types.add_function("AnotherFunc", Type::Float, vec![("arg".into(), Type::Bool)]); @@ -143,7 +150,7 @@ fn test_call_function() { assert_eq!(assert_evaluates_to!("SomeFunc(true)", &types), Type::Float); assert_eq!( assert_fails_to!("SomeFunc(arg=1)", &types), - vec!["Function 'SomeFunc' expects argument 'arg' to be of type bool, but got number"] + vec!["Function 'SomeFunc' expects argument 'arg' to be of type bool, but got literal[1]"] ); types.add_function( @@ -158,14 +165,14 @@ fn test_call_function() { assert_eq!( assert_fails_to!("AnotherFunc(arg='true', arg2='1')", &types), - vec!["Function 'AnotherFunc' expects argument 'arg' to be of type bool, but got string"] + vec![r#"Function 'AnotherFunc' expects argument 'arg' to be of type bool, but got literal["true"]"#] ); assert_eq!( assert_fails_to!("AnotherFunc(arg=SomeFunc(true) ~ 1, arg2=1)", &types), vec![ "Function 'AnotherFunc' expects argument 'arg' to be of type bool, but got string", - "Function 'AnotherFunc' expects argument 'arg2' to be of type string, but got number" + "Function 'AnotherFunc' expects argument 'arg2' to be of type string, but got literal[1]" ] ); @@ -201,6 +208,25 @@ fn test_call_function() { "Function 'AnotherFunc' does not have an argument 'arg4'. Did you mean 'arg3'?" ] ); + + types.add_function( + "TakesLiteralFoo", + Type::Float, + vec![ + ("arg".to_string(), + Type::Union(vec![ + Type::Literal(LiteralValue::String("Foo".to_string())), + Type::Literal(LiteralValue::String("Bar".to_string())) + ]) + ) + ] + ); + + assert_eq!( + assert_evaluates_to!("TakesLiteralFoo('Foo')", &types), + Type::Float + ); + } #[test] @@ -229,7 +255,7 @@ fn test_output_format() { "ctx.output_format(prefix='1', always_hoist_enums=1)", &types ), - vec!["Function 'baml::OutputFormat' expects argument 'always_hoist_enums' to be of type (none | bool), but got number"] + vec!["Function 'baml::OutputFormat' expects argument 'always_hoist_enums' to be of type (none | bool), but got literal[1]"] ); assert_eq!( diff --git a/engine/baml-lib/jinja/src/evaluate_type/test_stmt.rs b/engine/baml-lib/jinja/src/evaluate_type/test_stmt.rs index 2a4bea856..de15f6f8e 100644 --- a/engine/baml-lib/jinja/src/evaluate_type/test_stmt.rs +++ b/engine/baml-lib/jinja/src/evaluate_type/test_stmt.rs @@ -246,7 +246,7 @@ fn if_else() { "# .trim(), types, - vec!["Function 'Foo' expects argument 'arg' to be of type string, but got (undefined | number | string)"] + vec![r#"Function 'Foo' expects argument 'arg' to be of type string, but got (undefined | literal["2"] | literal[1])"#] ); let mut types = PredefinedTypes::default(JinjaContext::Prompt);