Skip to content

Commit

Permalink
Add improved static analysis for jinja (#734)
Browse files Browse the repository at this point in the history
* Includes better testing for ctx.output_format,
  unions, filters
  • Loading branch information
hellovai authored Jun 29, 2024
1 parent 528f242 commit 423faa1
Show file tree
Hide file tree
Showing 5 changed files with 413 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,3 @@ function Bar1(a: string) -> int {
// 23 | prompt #"
// 24 | {{ Foo(a) }}
// |
// warning: 'b' is a (function Foo | function Foo2), expected function
// --> functions_v2/prompt_errors/prompt1.baml:37
// |
// 36 |
// 37 | {{ b() }}
// |
265 changes: 223 additions & 42 deletions engine/baml-lib/jinja/src/evaluate_type/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,90 @@ use super::{
ScopeTracker, TypeError,
};

fn parse_as_function_call<'a>(
expr: &ast::Spanned<ast::Call>,
state: &mut ScopeTracker,
types: &PredefinedTypes,
t: &Type,
) -> (Type, Vec<TypeError>) {
match t {
Type::FunctionRef(name) => {
let mut positional_args = Vec::new();
let mut kwargs = HashMap::new();
for arg in &expr.args {
match arg {
ast::Expr::Kwargs(kkwargs) => {
for (k, v) in &kkwargs.pairs {
let t = tracker_visit_expr(v, state, types);
kwargs.insert(*k, t);
}
}
_ => {
let t = tracker_visit_expr(arg, state, types);
positional_args.push(t);
}
}
}

types.check_function_args((&name, expr), &positional_args, &kwargs)
}
Type::Both(x, y) => {
match (x.as_ref(), y.as_ref()) {
(Type::FunctionRef(_), Type::FunctionRef(_)) => {}
(Type::FunctionRef(_), _) => return parse_as_function_call(expr, state, types, x),
(_, Type::FunctionRef(_)) => return parse_as_function_call(expr, state, types, y),
_ => {}
}

let (t1, e1) = parse_as_function_call(expr, state, types, x);
let (t2, e2) = parse_as_function_call(expr, state, types, y);
match (e1.is_empty(), e2.is_empty()) {
(true, true) => (Type::merge([t1, t2]), vec![]),
(true, false) => (t1, e1),
(false, true) => (t2, e2),
(false, false) => (
Type::merge([t1, t2]),
e1.into_iter().chain(e2.into_iter()).collect(),
),
}
}
Type::Union(items) => {
let items = items
.iter()
.map(|x| parse_as_function_call(expr, state, types, x))
.reduce(|acc, x| {
let (t1, e1) = acc;
let (t2, e2) = x;
(
Type::merge([t1, t2]),
e1.into_iter().chain(e2.into_iter()).collect(),
)
});
match items {
Some(x) => x,
None => (
Type::Unknown,
vec![TypeError::new_invalid_type(
&expr.expr,
t,
"function",
expr.span(),
)],
),
}
}
_ => (
Type::Unknown,
vec![TypeError::new_invalid_type(
&expr.expr,
t,
"function",
expr.span(),
)],
),
}
}

fn tracker_visit_expr<'a>(
expr: &ast::Expr<'a>,
state: &mut ScopeTracker,
Expand Down Expand Up @@ -79,16 +163,145 @@ fn tracker_visit_expr<'a>(
.unwrap_or(Type::Unknown);
Type::merge([true_expr, false_expr])
}
ast::Expr::Filter(expr) => match expr.name {
"items" => {
let inner = tracker_visit_expr(expr.expr.as_ref().unwrap(), state, types);
match inner {
ast::Expr::Filter(expr) => {
// Filters have a name
let inner = tracker_visit_expr(expr.expr.as_ref().unwrap(), state, types);

let mut ensure_type = |error_string: &str| {
state.errors.push(TypeError::new_invalid_type(
expr.expr.as_ref().unwrap(),
&inner,
error_string,
expr.span(),
));
};

let valid_filters = vec![
"abs",
"attrs",
"batch",
"bool",
"capitalize",
"escape",
"first",
"last",
"default",
"float",
"indent",
"int",
"dictsort",
"items",
"join",
"length",
"list",
"lower",
"upper",
"map",
"max",
"min",
"pprint",
"reject",
"rejectattr",
"replace",
"reverse",
"round",
"safe",
"select",
"selectattr",
"slice",
"sort",
"split",
"title",
"tojson",
"json",
"trim",
"unique",
"urlencode",
];
match expr.name {
"abs" => {
if inner.matches(&Type::Number) {
ensure_type("number");
}
Type::Number
}
"attrs" => Type::Unknown,
"batch" => Type::Unknown,
"bool" => Type::Bool,
"capitalize" | "escape" => {
if inner.matches(&Type::String) {
ensure_type("string");
}
Type::String
}
"first" | "last" => match inner {
Type::List(t) => Type::merge([*t, Type::None]),
Type::Unknown => Type::Unknown,
_ => {
ensure_type("list");
Type::Unknown
}
},
"default" => Type::Unknown,
"float" => Type::Float,
"indent" => Type::String,
"int" => Type::Int,
"dictsort" | "items" => match inner {
Type::Map(k, v) => Type::List(Box::new(Type::Tuple(vec![*k, *v]))),
_ => Type::Unknown,
Type::ClassRef(_) => {
Type::List(Box::new(Type::Tuple(vec![Type::String, Type::Unknown])))
}
_ => {
ensure_type("map or class");
Type::Unknown
}
},
"join" => Type::String,
"length" => match inner {
Type::List(_) | Type::String | Type::ClassRef(_) | Type::Map(_, _) => Type::Int,
Type::Unknown => Type::Unknown,
_ => {
ensure_type("list, string, class or map");
Type::Unknown
}
},
"list" => Type::List(Box::new(Type::Unknown)),
"lower" | "upper" => {
if inner.matches(&Type::String) {
ensure_type("string");
}
Type::String
}
"map" => Type::Unknown,
"max" => Type::Unknown,
"min" => Type::Unknown,
"pprint" => Type::Unknown,
"reject" => Type::Unknown,
"rejectattr" => Type::Unknown,
"replace" => Type::String,
"reverse" => Type::Unknown,
"round" => Type::Float,
"safe" => Type::String,
"select" => Type::Unknown,
"selectattr" => Type::Unknown,
"slice" => Type::Unknown,
"sort" => Type::Unknown,
"split" => Type::List(Box::new(Type::String)),
"title" => Type::String,
"tojson" | "json" => Type::String,
"trim" => Type::String,
"unique" => Type::Unknown,
"urlencode" => Type::String,
other => {
state.errors.push(TypeError::new_invalid_filter(
other,
expr.span(),
&valid_filters,
));
Type::Unknown
}
}
_ => Type::Unknown,
},
}
ast::Expr::Test(expr) => {
let _test = tracker_visit_expr(&expr.expr, state, types);
// TODO: Check for type compatibility
Expand Down Expand Up @@ -122,41 +335,9 @@ fn tracker_visit_expr<'a>(
ast::Expr::Slice(_slice) => Type::Unknown,
ast::Expr::Call(expr) => {
let func = tracker_visit_expr(&expr.expr, state, types);

match func {
Type::FunctionRef(name) => {
// lets segregate positional and keyword arguments
let mut positional_args = Vec::new();
let mut kwargs = HashMap::new();
for arg in &expr.args {
match arg {
ast::Expr::Kwargs(kkwargs) => {
for (k, v) in &kkwargs.pairs {
let t = tracker_visit_expr(v, state, types);
kwargs.insert(*k, t);
}
}
_ => {
let t = tracker_visit_expr(arg, state, types);
positional_args.push(t);
}
}
}

let res = types.check_function_args((&name, expr), &positional_args, &kwargs);
state.errors.extend(res.1);
res.0
}
t => {
state.errors.push(TypeError::new_invalid_type(
&expr.expr,
&t,
"function",
expr.span(),
));
Type::Unknown
}
}
let (t, errs) = parse_as_function_call(expr, state, types, &func);
state.errors.extend(errs);
t
}
ast::Expr::List(expr) => {
let inner = Type::merge(
Expand Down
57 changes: 52 additions & 5 deletions engine/baml-lib/jinja/src/evaluate_type/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ mod test_expr;
mod test_stmt;
mod types;

use std::collections::HashSet;
use std::fmt::Debug;
use std::ops::Index;

Expand Down Expand Up @@ -143,11 +144,57 @@ impl TypeError {
}
}

fn new_unknown_arg(func: &str, span: Span, name: &str) -> Self {
Self {
message: format!("Function '{}' does not have an argument '{}'", func, name),
span,
}
fn new_unknown_arg(func: &str, span: Span, name: &str, valid_args: HashSet<&String>) -> Self {
let names = valid_args.into_iter().collect::<Vec<_>>();
let mut close_names = sort_by_match(name, &names, Some(3));
close_names.sort();
let close_names = close_names;

let message = if close_names.is_empty() {
// If no names are close enough, suggest nothing or provide a generic message
format!("Function '{}' does not have an argument '{}'.", func, name)
} else if close_names.len() == 1 {
// If there's only one close name, suggest it
format!(
"Function '{}' does not have an argument '{}'. Did you mean '{}'?",
func, name, close_names[0]
)
} else {
// If there are multiple close names, suggest them all
let suggestions = close_names.join("', '");
format!(
"Function '{}' does not have an argument '{}'. Did you mean one of these: '{}'?",
func, name, suggestions
)
};

Self { message, span }
}

fn new_invalid_filter(name: &str, span: Span, valid_filters: &Vec<&str>) -> Self {
let mut close_names = sort_by_match(name, valid_filters, Some(5));
close_names.sort();
let close_names = close_names;

let message = if close_names.is_empty() {
// If no names are close enough, suggest nothing or provide a generic message
format!("Filter '{}' does not exist", name)
} else if close_names.len() == 1 {
// If there's only one close name, suggest it
format!(
"Filter '{}' does not exist. Did you mean '{}'?",
name, close_names[0]
)
} else {
// If there are multiple close names, suggest them all
let suggestions = close_names.join("', '");
format!(
"Filter '{}' does not exist. Did you mean one of these: '{}'?",
name, suggestions
)
};

Self { message: format!("{message}\n\nSee: https://docs.rs/minijinja/latest/minijinja/filters/index.html#functions for the compelete list"), span }
}

fn new_invalid_type(expr: &Expr, got: &Type, expected: &str, span: Span) -> Self {
Expand Down
Loading

0 comments on commit 423faa1

Please sign in to comment.