Skip to content

Commit

Permalink
Type-narrowing for if blocks (#1313)
Browse files Browse the repository at this point in the history
Closes #1267

Use the `Scope` machinery in the jinja static analyzer to narrow the
types of values that have been used in the predicate expression of an
`if` block.

For example:

```
class Foo {
  x int
}

function UseOptionalFoo(inp: Foo?) {
  client ...
  prompt #"
    {% if inp %}
      {{ inp.x }}
    {%endif %}
  "#
}
```

Since `inp` has been checked for existence, we would like `inp.x` not to
raise a warning. We know it has an `x` field because `inp` can no
longer be null: we are in a scope in which `inp` is truthy.

This PR does some basic analysis on predicates to update the typing
context with these narrowed types. The following scenarios are
supported:

 - `if foo` narrows `Foo | None` to `Foo` in the if body.
 - `if not foo` narrows in the else body.
 - `if foo and bar` narrows `foo` and `bar`
 - `if not foo and bar` narrows `foo` in else and `bar` in the if body

The screenshot shows a combination of these effects:
<img width="573" alt="Screenshot 2025-01-11 at 12 01 05 AM"
src="https://github.com/user-attachments/assets/2f458f9f-035f-40f2-9956-42b9961ce426"
/>


<!-- ELLIPSIS_HIDDEN -->


----

> [!IMPORTANT]
> Implements type-narrowing for `if` blocks in Jinja static analyzer
using `Scope` machinery to update typing context based on predicate
truthiness.
> 
>   - **Behavior**:
> - Implements type-narrowing for `if` blocks in `stmt.rs` using `Scope`
machinery.
> - Handles scenarios: `if foo`, `if not foo`, `if foo and bar`, `if not
foo and bar`.
>   - **Functions**:
> - Adds `predicate_implications()` to determine type implications based
on predicates.
>     - Adds `truthy()` to derive truthy type versions.
>   - **Tests**:
> - Adds tests for `truthy_union` and `implication_from_nullable` in
`stmt.rs`.
> 
> <sup>This description was created by </sup>[<img alt="Ellipsis"
src="https://img.shields.io/badge/Ellipsis-blue?color=175173">](https://www.ellipsis.dev?ref=BoundaryML%2Fbaml&utm_source=github&utm_medium=referral)<sup>
for e6ce8f4. It will automatically
update as commits are pushed.</sup>


<!-- ELLIPSIS_HIDDEN -->

---------

Co-authored-by: aaronvg <[email protected]>
  • Loading branch information
imalsogreg and aaronvg authored Jan 11, 2025
1 parent 15a70ab commit 546f58f
Showing 1 changed file with 159 additions and 1 deletion.
160 changes: 159 additions & 1 deletion engine/baml-lib/jinja/src/evaluate_type/stmt.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use minijinja::machinery::ast::{self, Stmt};
use minijinja::machinery::ast::{self, Stmt, UnaryOpKind};

use crate::evaluate_type::types::Type;

Expand Down Expand Up @@ -92,10 +92,19 @@ fn track_walk(node: &ast::Stmt<'_>, state: &mut PredefinedTypes) {
ast::Stmt::IfCond(stmt) => {
let _expr_type = evaluate_type(&stmt.expr, state);

let true_bindings = predicate_implications(&stmt.expr, state, true);
let false_bindings = predicate_implications(&stmt.expr, state, false);

// Record variables in each branch and their types (fuse them if they are the same)
state.start_branch();
true_bindings
.into_iter()
.for_each(|(k, v)| state.add_variable(k.as_str(), v));
stmt.true_body.iter().for_each(|x| track_walk(x, state));
state.start_else_branch();
false_bindings
.into_iter()
.for_each(|(k, v)| state.add_variable(k.as_str(), v));
stmt.false_body.iter().for_each(|x| track_walk(x, state));
state.resolve_branch();
}
Expand Down Expand Up @@ -130,3 +139,152 @@ pub fn get_variable_types(stmt: &Stmt, state: &mut PredefinedTypes) -> Vec<TypeE
track_walk(stmt, state);
state.errors().to_vec()
}

/// For a given predicate, find all the implications on the contained types if
/// the truthiness of the predicate is equal to the `branch` parameter.
///
/// Each returned pair is `(variable name, narrowed type)`. A variable is only
/// reported when `truthy` yields a narrowed type for its resolved type.
///
/// For example, in the context where `a: Number | null`, the expr `a` being
/// truthy implies `a: Number`, so `predicate_implications(a, ctx, true)`
/// returns `[("a", Number)]`. Likewise `predicate_implications(!a, ctx, false)`
/// returns `[("a", Number)]`, because if `!a` is false then `a` is truthy.
///
/// More examples (all assuming `branch: true`):
///
/// Γ: { a: Number | null, b: Number | null }
/// (a && b) -> [(a: Number), (b: Number)]
///
/// Γ: { a: Number | null }
/// (!!a) -> [(a: Number)]
///
/// Γ: { a: Number | null }
/// (!a) -> [(a: None)]
///
/// Γ: { a: Number | null }
/// (a && true) -> [(a: Number)]
///
/// Γ: { a: Number | null }
/// (a && false) -> [(a: Number)]  (constants contribute nothing; the branch
///                                 is unreachable anyway)
///
/// With `branch: false`, a conjunction yields no implications: `a && b` being
/// false only tells us that *at least one* operand is falsy, so neither operand
/// can be narrowed on its own. All other binary operators and expression kinds
/// yield no implications.
pub fn predicate_implications<'a>(
    expr: &'a ast::Expr<'a>,
    context: &'a mut PredefinedTypes,
    branch: bool,
) -> Vec<(String, Type)> {
    use ast::Expr::*;
    match expr {
        Var(var_name) => context
            .resolve(var_name.id)
            .and_then(|var_type| truthy(&var_type))
            .map_or(vec![], |truthy_type| {
                // NOTE(review): the falsy side is narrowed to `None` even though
                // values such as `0` or `""` are presumably falsy too — confirm
                // this approximation is acceptable for non-class unions.
                let narrowed = if branch { truthy_type } else { Type::None };
                vec![(var_name.id.to_string(), narrowed)]
            }),
        UnaryOp(unary_op) => {
            let next_branch = match unary_op.op {
                // `not x` is truthy exactly when `x` is falsy.
                UnaryOpKind::Not => !branch,
                // Negation is treated as preserving the operand's truthiness.
                UnaryOpKind::Neg => branch,
            };
            predicate_implications(&unary_op.expr, context, next_branch)
        }
        BinOp(binary_op) => match binary_op.op {
            // Only a *true* conjunction pins down both operands. Propagating
            // `branch == false` to both sides would wrongly claim that both are
            // falsy in the else body.
            ast::BinOpKind::ScAnd if branch => {
                let mut implications = predicate_implications(&binary_op.left, context, true);
                implications.extend(predicate_implications(&binary_op.right, context, true));
                implications
            }
            _ => vec![],
        },
        _ => vec![],
    }
}

/// Type-narrowing by truthiness. The truthy version of a value's
/// type is a new type that would be implied by the value being truthy.
/// For example, `truthy( Number | null )` is `Number`, because if some
/// value `a: Number | null` is truthy, we can conclude that `a: Number`.
///
/// Some types like `Number` offer no additional information if they
/// are truthy - in these cases we return None.
pub fn truthy(ty: &Type) -> Option<Type> {
match ty {
Type::Unknown => None,
Type::Undefined => None,
Type::None => None,
Type::Int => None,
Type::Float => None,
Type::Number => None,
Type::String => None,
Type::Bool => None,
Type::Literal(_) => None,
Type::List(_) => None,
Type::Map(_, _) => None,
Type::Tuple(_) => None,
Type::Union(variants) => {
let truthy_variants: Vec<Type> = variants
.iter()
.filter(|variant| !NULLISH.contains(variant))
.cloned()
.collect();
match truthy_variants.len() {
0 => None,
1 => Some(truthy_variants[0].clone()),
_ => Some(Type::Union(truthy_variants)),
}
}
Type::Both(x, y) => match (truthy(x), truthy(y)) {
(None, None) => None,
(Some(truthy_x), None) => Some(truthy_x),
(None, Some(truthy_y)) => Some(truthy_y),
(Some(truthy_x), Some(truthy_y)) => {
Some(Type::Both(Box::new(truthy_x), Box::new(truthy_y)))
}
},
Type::ClassRef(_) => None,
Type::FunctionRef(_) => None,
Type::Alias { resolved, .. } => truthy(resolved),
Type::RecursiveTypeAlias(_) => None,
Type::Image => None,
Type::Audio => None,
}
}

/// Variants that `truthy` strips from a union when narrowing by truthiness.
const NULLISH: [Type; 2] = [Type::Undefined, Type::None];

#[cfg(test)]
mod tests {
    use ast::{Expr, Spanned, Var};
    use minijinja::machinery::Span;

    use crate::JinjaContext;

    use super::*;

    /// A union of a class and `Undefined` narrows to just the class.
    #[test]
    fn truthy_union() {
        let foo = Type::ClassRef("Foo".to_string());
        let nullable_foo = Type::Union(vec![foo.clone(), Type::Undefined]);
        assert_eq!(truthy(&nullable_foo), Some(foo));
    }

    /// A truthy `a: Int | None` implies exactly one binding, `a: Int`.
    #[test]
    fn implication_from_nullable() {
        let mut context = PredefinedTypes::default(JinjaContext::Prompt);
        context.add_variable("a", Type::Union(vec![Type::Int, Type::None]));

        let expr = Expr::Var(Spanned::new(Var { id: "a" }, Span::default()));
        let implications = predicate_implications(&expr, &mut context, true);

        assert_eq!(
            implications,
            vec![("a".to_string(), Type::Int)],
            "Expected singleton list with Type::Int"
        );
    }
}

0 comments on commit 546f58f

Please sign in to comment.