Skip to content

Commit

Permalink
Type jinja constant ints and bools as literals
Browse files Browse the repository at this point in the history
  • Loading branch information
imalsogreg committed Nov 1, 2024
1 parent 408736f commit a2180b3
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 9 deletions.
15 changes: 13 additions & 2 deletions engine/baml-lib/jinja/src/evaluate_type/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::collections::HashMap;

use baml_types::LiteralValue;
use minijinja::machinery::ast;
use std::str::FromStr;

use super::{
pretty_print::pretty_print,
Expand Down Expand Up @@ -371,7 +372,12 @@ fn infer_const_type(v: &minijinja::value::Value) -> Type {
match v.kind() {
minijinja::value::ValueKind::Undefined => Type::Undefined,
minijinja::value::ValueKind::None => Type::None,
minijinja::value::ValueKind::Bool => Type::Bool,
minijinja::value::ValueKind::Bool => {
match bool::from_str(&v.to_string()) {
Ok(b) => Type::Literal(LiteralValue::Bool(b)),
Err(_) => Type::Bool,
}
},
minijinja::value::ValueKind::String => Type::Literal(LiteralValue::String(v.to_string())),
minijinja::value::ValueKind::Seq => {
let list = v.as_seq().unwrap();
Expand Down Expand Up @@ -411,7 +417,12 @@ fn infer_const_type(v: &minijinja::value::Value) -> Type {
}
minijinja::value::ValueKind::Map => Type::Unknown,
// We don't handle these types
minijinja::value::ValueKind::Number => Type::Number,
minijinja::value::ValueKind::Number => {
match i64::from_str(&v.to_string()) {
Ok(i) => Type::Literal(LiteralValue::Int(i)),
Err(_) => Type::Number,
}
},
minijinja::value::ValueKind::Bytes => Type::Undefined,
}
}
Expand Down
18 changes: 12 additions & 6 deletions engine/baml-lib/jinja/src/evaluate_type/test_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,17 +94,23 @@ fn test_ifexpr() {
let mut types = PredefinedTypes::default(JinjaContext::Prompt);
assert_eq!(
assert_evaluates_to!("1 if true else 2", &types),
Type::Number
Type::Union(vec![
Type::Literal(LiteralValue::Int(1)),
Type::Literal(LiteralValue::Int(2))
])
);

assert_eq!(
assert_evaluates_to!("1 if true else '2'", &types),
Type::Union(vec![Type::Number, Type::Literal(LiteralValue::String("2".to_string()))])
Type::Union(vec![Type::Literal(LiteralValue::String("2".to_string())), Type::Literal(LiteralValue::Int(1))])
);

assert_eq!(
assert_evaluates_to!("'1' if true else 2", &types),
Type::Union(vec![Type::Number, Type::Literal(LiteralValue::String("1".to_string()))])
Type::Union(vec![
Type::Literal(LiteralValue::String("1".to_string())),
Type::Literal(LiteralValue::Int(2))
])
);

types.add_function("AnotherFunc", Type::Float, vec![("arg".into(), Type::Bool)]);
Expand Down Expand Up @@ -144,7 +150,7 @@ fn test_call_function() {
assert_eq!(assert_evaluates_to!("SomeFunc(true)", &types), Type::Float);
assert_eq!(
assert_fails_to!("SomeFunc(arg=1)", &types),
vec!["Function 'SomeFunc' expects argument 'arg' to be of type bool, but got number"]
vec!["Function 'SomeFunc' expects argument 'arg' to be of type bool, but got literal[1]"]
);

types.add_function(
Expand All @@ -166,7 +172,7 @@ fn test_call_function() {
assert_fails_to!("AnotherFunc(arg=SomeFunc(true) ~ 1, arg2=1)", &types),
vec![
"Function 'AnotherFunc' expects argument 'arg' to be of type bool, but got string",
"Function 'AnotherFunc' expects argument 'arg2' to be of type string, but got number"
"Function 'AnotherFunc' expects argument 'arg2' to be of type string, but got literal[1]"
]
);

Expand Down Expand Up @@ -249,7 +255,7 @@ fn test_output_format() {
"ctx.output_format(prefix='1', always_hoist_enums=1)",
&types
),
vec!["Function 'baml::OutputFormat' expects argument 'always_hoist_enums' to be of type (none | bool), but got number"]
vec!["Function 'baml::OutputFormat' expects argument 'always_hoist_enums' to be of type (none | bool), but got literal[1]"]
);

assert_eq!(
Expand Down
2 changes: 1 addition & 1 deletion engine/baml-lib/jinja/src/evaluate_type/test_stmt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ fn if_else() {
"#
.trim(),
types,
vec![r#"Function 'Foo' expects argument 'arg' to be of type string, but got (undefined | number | literal["2"])"#]
vec![r#"Function 'Foo' expects argument 'arg' to be of type string, but got (undefined | literal["2"] | literal[1])"#]
);

let mut types = PredefinedTypes::default(JinjaContext::Prompt);
Expand Down

0 comments on commit a2180b3

Please sign in to comment.