From 423faa1af5a594b7f78f7bb5620e3146a8989da5 Mon Sep 17 00:00:00 2001 From: hellovai Date: Sat, 29 Jun 2024 03:37:39 -0700 Subject: [PATCH] Add improved static analysis for jinja (#734) * Includes better testing for ctx.output_format, unions, filters --- .../functions_v2/prompt_errors/prompt1.baml | 6 - .../baml-lib/jinja/src/evaluate_type/expr.rs | 265 +++++++++++++++--- .../baml-lib/jinja/src/evaluate_type/mod.rs | 57 +++- .../jinja/src/evaluate_type/test_expr.rs | 42 ++- .../baml-lib/jinja/src/evaluate_type/types.rs | 116 ++++++-- 5 files changed, 413 insertions(+), 73 deletions(-) diff --git a/engine/baml-lib/baml/tests/validation_files/functions_v2/prompt_errors/prompt1.baml b/engine/baml-lib/baml/tests/validation_files/functions_v2/prompt_errors/prompt1.baml index c8314407e..76396e2ce 100644 --- a/engine/baml-lib/baml/tests/validation_files/functions_v2/prompt_errors/prompt1.baml +++ b/engine/baml-lib/baml/tests/validation_files/functions_v2/prompt_errors/prompt1.baml @@ -50,9 +50,3 @@ function Bar1(a: string) -> int { // 23 | prompt #" // 24 | {{ Foo(a) }} // | -// warning: 'b' is a (function Foo | function Foo2), expected function -// --> functions_v2/prompt_errors/prompt1.baml:37 -// | -// 36 | -// 37 | {{ b() }} -// | diff --git a/engine/baml-lib/jinja/src/evaluate_type/expr.rs b/engine/baml-lib/jinja/src/evaluate_type/expr.rs index 1b92d1bfa..8362fdf10 100644 --- a/engine/baml-lib/jinja/src/evaluate_type/expr.rs +++ b/engine/baml-lib/jinja/src/evaluate_type/expr.rs @@ -8,6 +8,90 @@ use super::{ ScopeTracker, TypeError, }; +fn parse_as_function_call<'a>( + expr: &ast::Spanned, + state: &mut ScopeTracker, + types: &PredefinedTypes, + t: &Type, +) -> (Type, Vec) { + match t { + Type::FunctionRef(name) => { + let mut positional_args = Vec::new(); + let mut kwargs = HashMap::new(); + for arg in &expr.args { + match arg { + ast::Expr::Kwargs(kkwargs) => { + for (k, v) in &kkwargs.pairs { + let t = tracker_visit_expr(v, state, types); + kwargs.insert(*k, t); + 
} + } + _ => { + let t = tracker_visit_expr(arg, state, types); + positional_args.push(t); + } + } + } + + types.check_function_args((&name, expr), &positional_args, &kwargs) + } + Type::Both(x, y) => { + match (x.as_ref(), y.as_ref()) { + (Type::FunctionRef(_), Type::FunctionRef(_)) => {} + (Type::FunctionRef(_), _) => return parse_as_function_call(expr, state, types, x), + (_, Type::FunctionRef(_)) => return parse_as_function_call(expr, state, types, y), + _ => {} + } + + let (t1, e1) = parse_as_function_call(expr, state, types, x); + let (t2, e2) = parse_as_function_call(expr, state, types, y); + match (e1.is_empty(), e2.is_empty()) { + (true, true) => (Type::merge([t1, t2]), vec![]), + (true, false) => (t1, e1), + (false, true) => (t2, e2), + (false, false) => ( + Type::merge([t1, t2]), + e1.into_iter().chain(e2.into_iter()).collect(), + ), + } + } + Type::Union(items) => { + let items = items + .iter() + .map(|x| parse_as_function_call(expr, state, types, x)) + .reduce(|acc, x| { + let (t1, e1) = acc; + let (t2, e2) = x; + ( + Type::merge([t1, t2]), + e1.into_iter().chain(e2.into_iter()).collect(), + ) + }); + match items { + Some(x) => x, + None => ( + Type::Unknown, + vec![TypeError::new_invalid_type( + &expr.expr, + t, + "function", + expr.span(), + )], + ), + } + } + _ => ( + Type::Unknown, + vec![TypeError::new_invalid_type( + &expr.expr, + t, + "function", + expr.span(), + )], + ), + } +} + fn tracker_visit_expr<'a>( expr: &ast::Expr<'a>, state: &mut ScopeTracker, @@ -79,16 +163,145 @@ fn tracker_visit_expr<'a>( .unwrap_or(Type::Unknown); Type::merge([true_expr, false_expr]) } - ast::Expr::Filter(expr) => match expr.name { - "items" => { - let inner = tracker_visit_expr(expr.expr.as_ref().unwrap(), state, types); - match inner { + ast::Expr::Filter(expr) => { + // Filters have a name + let inner = tracker_visit_expr(expr.expr.as_ref().unwrap(), state, types); + + let mut ensure_type = |error_string: &str| { + 
state.errors.push(TypeError::new_invalid_type( + expr.expr.as_ref().unwrap(), + &inner, + error_string, + expr.span(), + )); + }; + + let valid_filters = vec![ + "abs", + "attrs", + "batch", + "bool", + "capitalize", + "escape", + "first", + "last", + "default", + "float", + "indent", + "int", + "dictsort", + "items", + "join", + "length", + "list", + "lower", + "upper", + "map", + "max", + "min", + "pprint", + "reject", + "rejectattr", + "replace", + "reverse", + "round", + "safe", + "select", + "selectattr", + "slice", + "sort", + "split", + "title", + "tojson", + "json", + "trim", + "unique", + "urlencode", + ]; + match expr.name { + "abs" => { + if !inner.matches(&Type::Number) { // report only inputs that cannot be a number + ensure_type("number"); + } + Type::Number + } + "attrs" => Type::Unknown, + "batch" => Type::Unknown, + "bool" => Type::Bool, + "capitalize" | "escape" => { + if !inner.matches(&Type::String) { // report only inputs that cannot be a string + ensure_type("string"); + } + Type::String + } + "first" | "last" => match inner { + Type::List(t) => Type::merge([*t, Type::None]), + Type::Unknown => Type::Unknown, + _ => { + ensure_type("list"); + Type::Unknown + } + }, + "default" => Type::Unknown, + "float" => Type::Float, + "indent" => Type::String, + "int" => Type::Int, + "dictsort" | "items" => match inner { Type::Map(k, v) => Type::List(Box::new(Type::Tuple(vec![*k, *v]))), - _ => Type::Unknown, + Type::ClassRef(_) => { + Type::List(Box::new(Type::Tuple(vec![Type::String, Type::Unknown]))) + } + _ => { + ensure_type("map or class"); + Type::Unknown + } + }, + "join" => Type::String, + "length" => match inner { + Type::List(_) | Type::String | Type::ClassRef(_) | Type::Map(_, _) => Type::Int, + Type::Unknown => Type::Unknown, + _ => { + ensure_type("list, string, class or map"); + Type::Unknown + } + }, + "list" => Type::List(Box::new(Type::Unknown)), + "lower" | "upper" => { + if !inner.matches(&Type::String) { // report only inputs that cannot be a string + ensure_type("string"); + } + Type::String + } + "map" => Type::Unknown, + "max" => Type::Unknown, + "min" => Type::Unknown, 
+ "pprint" => Type::Unknown, + "reject" => Type::Unknown, + "rejectattr" => Type::Unknown, + "replace" => Type::String, + "reverse" => Type::Unknown, + "round" => Type::Float, + "safe" => Type::String, + "select" => Type::Unknown, + "selectattr" => Type::Unknown, + "slice" => Type::Unknown, + "sort" => Type::Unknown, + "split" => Type::List(Box::new(Type::String)), + "title" => Type::String, + "tojson" | "json" => Type::String, + "trim" => Type::String, + "unique" => Type::Unknown, + "urlencode" => Type::String, + other => { + state.errors.push(TypeError::new_invalid_filter( + other, + expr.span(), + &valid_filters, + )); + Type::Unknown } } - _ => Type::Unknown, - }, + } ast::Expr::Test(expr) => { let _test = tracker_visit_expr(&expr.expr, state, types); // TODO: Check for type compatibility @@ -122,41 +335,9 @@ fn tracker_visit_expr<'a>( ast::Expr::Slice(_slice) => Type::Unknown, ast::Expr::Call(expr) => { let func = tracker_visit_expr(&expr.expr, state, types); - - match func { - Type::FunctionRef(name) => { - // lets segregate positional and keyword arguments - let mut positional_args = Vec::new(); - let mut kwargs = HashMap::new(); - for arg in &expr.args { - match arg { - ast::Expr::Kwargs(kkwargs) => { - for (k, v) in &kkwargs.pairs { - let t = tracker_visit_expr(v, state, types); - kwargs.insert(*k, t); - } - } - _ => { - let t = tracker_visit_expr(arg, state, types); - positional_args.push(t); - } - } - } - - let res = types.check_function_args((&name, expr), &positional_args, &kwargs); - state.errors.extend(res.1); - res.0 - } - t => { - state.errors.push(TypeError::new_invalid_type( - &expr.expr, - &t, - "function", - expr.span(), - )); - Type::Unknown - } - } + let (t, errs) = parse_as_function_call(expr, state, types, &func); + state.errors.extend(errs); + t } ast::Expr::List(expr) => { let inner = Type::merge( diff --git a/engine/baml-lib/jinja/src/evaluate_type/mod.rs b/engine/baml-lib/jinja/src/evaluate_type/mod.rs index 8dfb4f4d4..aafa1530c 100644 
--- a/engine/baml-lib/jinja/src/evaluate_type/mod.rs +++ b/engine/baml-lib/jinja/src/evaluate_type/mod.rs @@ -7,6 +7,7 @@ mod test_expr; mod test_stmt; mod types; +use std::collections::HashSet; use std::fmt::Debug; use std::ops::Index; @@ -143,11 +144,57 @@ impl TypeError { } } - fn new_unknown_arg(func: &str, span: Span, name: &str) -> Self { - Self { - message: format!("Function '{}' does not have an argument '{}'", func, name), - span, - } + fn new_unknown_arg(func: &str, span: Span, name: &str, valid_args: HashSet<&String>) -> Self { + let names = valid_args.into_iter().collect::>(); + let mut close_names = sort_by_match(name, &names, Some(3)); + close_names.sort(); + let close_names = close_names; + + let message = if close_names.is_empty() { + // If no names are close enough, suggest nothing or provide a generic message + format!("Function '{}' does not have an argument '{}'.", func, name) + } else if close_names.len() == 1 { + // If there's only one close name, suggest it + format!( + "Function '{}' does not have an argument '{}'. Did you mean '{}'?", + func, name, close_names[0] + ) + } else { + // If there are multiple close names, suggest them all + let suggestions = close_names.join("', '"); + format!( + "Function '{}' does not have an argument '{}'. Did you mean one of these: '{}'?", + func, name, suggestions + ) + }; + + Self { message, span } + } + + fn new_invalid_filter(name: &str, span: Span, valid_filters: &Vec<&str>) -> Self { + let mut close_names = sort_by_match(name, valid_filters, Some(5)); + close_names.sort(); + let close_names = close_names; + + let message = if close_names.is_empty() { + // If no names are close enough, suggest nothing or provide a generic message + format!("Filter '{}' does not exist", name) + } else if close_names.len() == 1 { + // If there's only one close name, suggest it + format!( + "Filter '{}' does not exist. 
Did you mean '{}'?", + name, close_names[0] + ) + } else { + // If there are multiple close names, suggest them all + let suggestions = close_names.join("', '"); + format!( + "Filter '{}' does not exist. Did you mean one of these: '{}'?", + name, suggestions + ) + }; + + Self { message: format!("{message}\n\nSee: https://docs.rs/minijinja/latest/minijinja/filters/index.html#functions for the complete list"), span } } fn new_invalid_type(expr: &Expr, got: &Type, expected: &str, span: Span) -> Self { diff --git a/engine/baml-lib/jinja/src/evaluate_type/test_expr.rs b/engine/baml-lib/jinja/src/evaluate_type/test_expr.rs index a8eaa2028..9c484c392 100644 --- a/engine/baml-lib/jinja/src/evaluate_type/test_expr.rs +++ b/engine/baml-lib/jinja/src/evaluate_type/test_expr.rs @@ -100,6 +100,11 @@ fn test_ifexpr() { Type::Union(vec![Type::Number, Type::String]) ); + assert_eq!( + assert_evaluates_to!("1 if true else '2'", &types), + Type::Union(vec![Type::String, Type::Number]) + ); + types.add_function("AnotherFunc", Type::Float, vec![("arg".into(), Type::Bool)]); types.add_variable("BasicTest", Type::Int); @@ -192,7 +197,42 @@ fn test_call_function() { assert_fails_to!("AnotherFunc(true, arg2='1', arg4=1)", &types), vec![ "Function 'AnotherFunc' expects argument 'arg3'", - "Function 'AnotherFunc' does not have an argument 'arg4'" + "Function 'AnotherFunc' does not have an argument 'arg4'. Did you mean 'arg3'?" 
] ); } + +#[test] +fn test_output_format() { + let types = PredefinedTypes::default(); + assert_eq!( + assert_evaluates_to!("ctx.output_format(prefix='hi')", &types), + Type::String + ); + + assert_eq!( + assert_evaluates_to!("ctx.output_format(prefix='1', or_splitter='1')", &types), + Type::String + ); + + assert_eq!( + assert_evaluates_to!( + "ctx.output_format(prefix='1', enum_value_prefix=none)", + &types + ), + Type::String + ); + + assert_eq!( + assert_fails_to!( + "ctx.output_format(prefix='1', always_hoist_enums=1)", + &types + ), + vec!["Function 'baml::OutputFormat' expects argument 'always_hoist_enums' to be of type (bool | none), but got number"] + ); + + assert_eq!( + assert_fails_to!("ctx.output_format(prefix='1', unknown=1)", &types), + vec!["Function 'baml::OutputFormat' does not have an argument 'unknown'. Did you mean one of these: 'always_hoist_enums', 'enum_value_prefix', 'or_splitter'?"] + ); +} diff --git a/engine/baml-lib/jinja/src/evaluate_type/types.rs b/engine/baml-lib/jinja/src/evaluate_type/types.rs index 24bd62c5f..c92aa8756 100644 --- a/engine/baml-lib/jinja/src/evaluate_type/types.rs +++ b/engine/baml-lib/jinja/src/evaluate_type/types.rs @@ -1,5 +1,9 @@ use core::panic; -use std::{collections::HashMap, ops::BitOr}; +use std::{ + collections::{HashMap, HashSet}, + ops::BitOr, + vec, +}; use minijinja::machinery::{ ast::{Call, Spanned}, @@ -23,6 +27,8 @@ pub enum Type { Map(Box, Box), Tuple(Vec), Union(Vec), + // It is simultaneously two types, whichever fits best + Both(Box, Box), ClassRef(String), FunctionRef(String), Image, @@ -31,25 +37,37 @@ pub enum Type { impl PartialEq for Type { fn eq(&self, other: &Self) -> bool { - match (self, other) { + self.matches(other) + } +} + +impl Eq for Type {} + +impl Type { + pub fn matches(&self, r: &Self) -> bool { + match (self, r) { (Self::Unknown, Self::Unknown) => true, (Self::Unknown, _) => true, (_, Self::Unknown) => true, (Self::Number, Self::Int | Self::Float) => true, (Self::Int | 
Self::Float, Self::Number) => true, - (Self::List(l0), Self::List(r0)) => l0 == r0, - (Self::Map(l0, l1), Self::Map(r0, r1)) => l0 == r0 && l1 == r1, - (Self::Union(l0), Self::Union(r0)) => l0 == r0, + (Self::List(l0), Self::List(r0)) => l0.matches(r0), + (Self::Map(l0, l1), Self::Map(r0, r1)) => l0.matches(r0) && l1.matches(r1), + (Self::Union(l0), Self::Union(r0)) => { + // Sort l0 and r0 to make sure the order doesn't matter + let mut l0 = l0.clone(); + let mut r0 = r0.clone(); + l0.sort(); + r0.sort(); + l0 == r0 + } + (l0, Self::Union(r0)) => r0.iter().any(|x| l0.matches(x)), (Self::ClassRef(l0), Self::ClassRef(r0)) => l0 == r0, (Self::FunctionRef(l0), Self::FunctionRef(r0)) => l0 == r0, - _ => core::mem::discriminant(self) == core::mem::discriminant(other), + _ => core::mem::discriminant(self) == core::mem::discriminant(r), } } -} - -impl Eq for Type {} -impl Type { pub fn name(&self) -> String { match self { Type::Unknown => "".into(), @@ -70,6 +88,7 @@ impl Type { "({})", v.iter().map(|x| x.name()).collect::>().join(" | ") ), + Type::Both(l, r) => format!("{} & {}", l.name(), r.name()), Type::ClassRef(name) => format!("class {}", name), Type::FunctionRef(name) => format!("function {}", name), Type::Image => "image".into(), @@ -77,6 +96,14 @@ impl Type { } } + pub fn is_optional(&self) -> bool { + match self { + Type::None => true, + Type::Union(v) => v.iter().any(|x| x.is_optional()), + _ => false, + } + } + pub fn merge<'a, I>(v: I) -> Type where I: IntoIterator, @@ -158,10 +185,33 @@ impl PredefinedTypes { pub fn default() -> Self { Self { - functions: HashMap::from([( - "baml::Chat".into(), - (Type::String, vec![("role".into(), Type::String)]), - )]), + functions: HashMap::from([ + ( + "baml::Chat".into(), + (Type::String, vec![("role".into(), Type::String)]), + ), + ( + "baml::OutputFormat".into(), + ( + Type::String, + vec![ + ("prefix".into(), Type::merge(vec![Type::String, Type::None])), + ( + "or_splitter".into(), + Type::merge(vec![Type::String, 
Type::None]), + ), + ( + "enum_value_prefix".into(), + Type::merge(vec![Type::String, Type::None]), + ), + ( + "always_hoist_enums".into(), + Type::merge(vec![Type::Bool, Type::None]), + ), + ], + ), + ), + ]), classes: HashMap::from([ ( "baml::Client".into(), @@ -173,10 +223,16 @@ impl PredefinedTypes { ( "baml::Context".into(), HashMap::from([ - ("output_format".into(), Type::String), + ( + "output_format".into(), + Type::Both( + Type::String.into(), + Type::FunctionRef("baml::OutputFormat".into()).into(), + ), + ), ("client".into(), Type::ClassRef("baml::Client".into())), ( - "env".into(), + "tags".into(), Type::Map(Box::new(Type::String), Box::new(Type::String)), ), ]), @@ -385,8 +441,20 @@ impl PredefinedTypes { let (ret, args) = val.unwrap(); let mut errors = Vec::new(); + // Check how many args are required. + let mut optional_args = vec![]; + for (name, t) in args.iter().rev() { + if !t.is_optional() { + break; + } + optional_args.push(name); + } + let required_args = args.len() - optional_args.len(); + // Check count - if positional_args.len() + kwargs.len() != args.len() { + if positional_args.len() + kwargs.len() < required_args + || (positional_args.len() + kwargs.len()) > args.len() + { errors.push(TypeError::new_wrong_arg_count( func, span, @@ -394,9 +462,11 @@ impl PredefinedTypes { positional_args.len() + kwargs.len(), )); } else { + let mut unused_args = args.iter().map(|(name, _)| name).collect::>(); // Check types for (i, (name, t)) in args.iter().enumerate() { if i < positional_args.len() { + unused_args.remove(name); let arg_t = &positional_args[i]; if arg_t != t { errors.push(TypeError::new_wrong_arg_type( @@ -410,6 +480,7 @@ impl PredefinedTypes { } } else { if let Some(arg_t) = kwargs.get(name.as_str()) { + unused_args.remove(name); if arg_t != t { errors.push(TypeError::new_wrong_arg_type( func, @@ -421,14 +492,21 @@ impl PredefinedTypes { )); } } else { - errors.push(TypeError::new_missing_arg(func, span, name)); + if 
!optional_args.contains(&name) { + errors.push(TypeError::new_missing_arg(func, span, name)); + } } } } kwargs.iter().for_each(|(name, _)| { if !args.iter().any(|(arg_name, _)| arg_name == name) { - errors.push(TypeError::new_unknown_arg(func, span, name)); + errors.push(TypeError::new_unknown_arg( + func, + span, + name, + unused_args.clone(), + )); } }); }