diff --git a/engine/Cargo.lock b/engine/Cargo.lock index 8414efd68..02a170e0f 100644 --- a/engine/Cargo.lock +++ b/engine/Cargo.lock @@ -955,6 +955,7 @@ dependencies = [ "log", "mime", "mime_guess", + "minijinja", "notify-debouncer-full", "pin-project-lite", "pretty_assertions", @@ -1010,6 +1011,7 @@ dependencies = [ "indoc", "internal-baml-codegen", "internal-baml-core", + "itertools 0.13.0", "js-sys", "jsonish", "log", diff --git a/engine/baml-lib/baml-core/src/ir/jinja_helpers.rs b/engine/baml-lib/baml-core/src/ir/jinja_helpers.rs index b5d3a636c..56a46b7d1 100644 --- a/engine/baml-lib/baml-core/src/ir/jinja_helpers.rs +++ b/engine/baml-lib/baml-core/src/ir/jinja_helpers.rs @@ -50,7 +50,7 @@ fn sum_filter(value: Vec) -> Value { /// E.g. `"a|length > 2"` with context `{"a": [1, 2, 3]}` will return `"true"`. pub fn render_expression( expression: &JinjaExpression, - ctx: &HashMap, + ctx: &HashMap, ) -> anyhow::Result { let env = get_env(); // In rust string literals, `{` is escaped as `{{`. @@ -66,8 +66,8 @@ pub fn evaluate_predicate( this: &BamlValue, predicate_expression: &JinjaExpression, ) -> Result { - let ctx: HashMap = - [("this".to_string(), this.clone())].into_iter().collect(); + let ctx: HashMap = + HashMap::from([("this".to_string(), minijinja::Value::from_serialize(this))]); match render_expression(&predicate_expression, &ctx)?.as_ref() { "true" => Ok(true), "false" => Ok(false), @@ -87,11 +87,12 @@ mod tests { "a".to_string(), BamlValue::List( vec![BamlValue::Int(1), BamlValue::Int(2), BamlValue::Int(3)].into(), - ), + ) + .into(), ), ( "b".to_string(), - BamlValue::String("(123)456-7890".to_string()), + BamlValue::String("(123)456-7890".to_string()).into(), ), ] .into_iter() @@ -118,11 +119,12 @@ mod tests { "a".to_string(), BamlValue::List( vec![BamlValue::Int(1), BamlValue::Int(2), BamlValue::Int(3)].into(), - ), + ) + .into(), ), ( "b".to_string(), - BamlValue::String("(123)456-7890".to_string()), + BamlValue::String("(123)456-7890".to_string()).into(), ), ] .into_iter() @@ -151,16 +153,12 @@ mod tests { fn test_sum_filter() { let ctx = vec![].into_iter().collect(); assert_eq!( - render_expression(&JinjaExpression( - r#"[1,2]|sum"#.to_string() - ), &ctx).unwrap(), + render_expression(&JinjaExpression(r#"[1,2]|sum"#.to_string()), &ctx).unwrap(), "3" ); assert_eq!( - render_expression(&JinjaExpression( - r#"[1,2.5]|sum"#.to_string() - ), &ctx).unwrap(), + render_expression(&JinjaExpression(r#"[1,2.5]|sum"#.to_string()), &ctx).unwrap(), "3.5" ); } diff --git a/engine/baml-lib/baml-core/src/ir/repr.rs b/engine/baml-lib/baml-core/src/ir/repr.rs index 3d2db25ed..fb6d71fa1 100644 --- a/engine/baml-lib/baml-core/src/ir/repr.rs +++ b/engine/baml-lib/baml-core/src/ir/repr.rs @@ -8,10 +8,11 @@ use internal_baml_parser_database::{ walkers::{ ClassWalker, ClientSpec as AstClientSpec, ClientWalker, ConfigurationWalker, EnumValueWalker, EnumWalker, FieldWalker, FunctionWalker, TemplateStringWalker, + Walker as AstWalker, }, Attributes, ParserDatabase, PromptAst, RetryPolicyStrategy, }; -use internal_baml_schema_ast::ast::SubType; +use internal_baml_schema_ast::ast::{SubType, ValExpId}; use baml_types::JinjaExpression; use internal_baml_schema_ast::ast::{self, FieldArity, WithName, WithSpan}; @@ -676,8 +677,14 @@ impl WithRepr for EnumWalker<'_> { fn repr(&self, db: &ParserDatabase) -> Result { Ok(Enum { name: self.name().to_string(), - values: self.values().map(|w| (w.node(db).map(|v| (v, w.documentation().map(|s| Docstring(s.to_string())))))).collect::,_>>()?, - docstring: self.get_documentation().map(|s| Docstring(s)) + values: self + .values() + .map(|w| { + w.node(db) + .map(|v| (v, w.documentation().map(|s| Docstring(s.to_string())))) + }) + .collect::, _>>()?, + docstring: self.get_documentation().map(|s| Docstring(s)), }) } } @@ -722,7 +729,6 @@ impl WithRepr for FieldWalker<'_> { docstring: self.get_documentation().map(|s| Docstring(s)), }) } - } type ClassId = String; @@ -774,7 +780,7 @@ impl WithRepr for ClassWalker<'_> { .collect::>>()?, None => Vec::new(), }, - docstring: self.get_documentation().map(|s| Docstring(s)) + docstring: self.get_documentation().map(|s| Docstring(s)), }) } } @@ -1110,14 +1116,23 @@ pub struct TestCase { pub name: String, pub functions: Vec>, pub args: IndexMap, + pub constraints: Vec, } impl WithRepr for (&ConfigurationWalker<'_>, usize) { fn attributes(&self, _db: &ParserDatabase) -> NodeAttributes { let span = self.0.test_case().functions[self.1].1.clone(); + let constraints = self + .0 + .test_case() + .constraints + .iter() + .map(|(c, _, _)| c) + .cloned() + .collect(); NodeAttributes { meta: IndexMap::new(), - constraints: Vec::new(), + constraints, span: Some(span), } } @@ -1131,10 +1146,17 @@ impl WithRepr for (&ConfigurationWalker<'_>, usize) { impl WithRepr for ConfigurationWalker<'_> { fn attributes(&self, _db: &ParserDatabase) -> NodeAttributes { + let constraints = self + .test_case() + .constraints + .iter() + .map(|(c, _, _)| c) + .cloned() + .collect(); NodeAttributes { meta: IndexMap::new(), span: Some(self.span().clone()), - constraints: Vec::new(), + constraints, } } @@ -1151,6 +1173,12 @@ impl WithRepr for ConfigurationWalker<'_> { .map(|(k, (_, v))| Ok((k.clone(), v.repr(db)?))) .collect::>>()?, functions, + constraints: as WithRepr>::attributes( + self, db, + ) + .constraints + .into_iter() + .collect::>(), }) } } @@ -1223,7 +1251,8 @@ mod tests { #[test] fn test_docstrings() { - let ir = make_test_ir(r#" + let ir = make_test_ir( + r#" /// Foo class. class Foo { /// Bar field. @@ -1243,7 +1272,9 @@ mod tests { THIRD } - "#).unwrap(); + "#, + ) + .unwrap(); // Test class docstrings let foo = ir.find_class("Foo").as_ref().unwrap().clone().elem(); @@ -1252,7 +1283,7 @@ mod tests { [field1, field2] => { assert_eq!(field1.elem.docstring.as_ref().unwrap().0, "Bar field."); assert_eq!(field2.elem.docstring.as_ref().unwrap().0, "Baz field."); - }, + } _ => { panic!("Expected 2 fields"); } @@ -1260,7 +1291,10 @@ mod tests { // Test enum docstrings let test_enum = ir.find_enum("TestEnum").as_ref().unwrap().clone().elem(); - assert_eq!(test_enum.docstring.as_ref().unwrap().0.as_str(), "Test enum."); + assert_eq!( + test_enum.docstring.as_ref().unwrap().0.as_str(), + "Test enum." + ); match test_enum.values.as_slice() { [val1, val2, val3] => { assert_eq!(val1.0.elem.0, "FIRST"); @@ -1269,10 +1303,41 @@ mod tests { assert_eq!(val2.1.as_ref().unwrap().0, "Second variant."); assert_eq!(val3.0.elem.0, "THIRD"); assert!(val3.1.is_none()); - }, + } _ => { panic!("Expected 3 enum values"); } } } + + #[test] + fn test_block_attributes() { + let ir = make_test_ir( + r##" + client GPT4 { + provider openai + options { + model gpt-4o + api_key env.OPENAI_API_KEY + } + } + function Foo(a: int) -> int { + client GPT4 + prompt #"Double the number {{ a }}"# + } + + test Foo() { + functions [Foo] + args { + a 10 + } + @@assert( {{ result == 20 }} ) + } + "##, + ) + .unwrap(); + let function = ir.find_function("Foo").unwrap(); + let walker = ir.find_test(&function, "Foo").unwrap(); + assert_eq!(walker.item.1.elem.constraints.len(), 1); + } } diff --git a/engine/baml-lib/baml-core/src/ir/walker.rs b/engine/baml-lib/baml-core/src/ir/walker.rs index 034bf2e3a..86e09bc61 100644 --- a/engine/baml-lib/baml-core/src/ir/walker.rs +++ b/engine/baml-lib/baml-core/src/ir/walker.rs @@ -260,9 +260,9 @@ impl Expression { } Expression::JinjaExpression(expr) => { // TODO: do not coerce all context values to strings. - let jinja_context: HashMap = env_values + let jinja_context: HashMap = env_values .iter() - .map(|(k, v)| (k.clone(), BamlValue::String(v.clone()))) + .map(|(k, v)| (k.clone(), v.clone().into())) .collect(); let res_string = render_expression(&expr, &jinja_context)?; Ok(BamlValue::String(res_string)) diff --git a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations.rs b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations.rs index c5c6acb3f..4dfc72aef 100644 --- a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations.rs +++ b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations.rs @@ -5,6 +5,7 @@ mod cycle; mod enums; mod functions; mod template_strings; +mod tests; mod types; use baml_types::GeneratorOutputType; @@ -22,6 +23,7 @@ pub(super) fn validate(ctx: &mut Context<'_>) { clients::validate(ctx); template_strings::validate(ctx); configurations::validate(ctx); + tests::validate(ctx); let generators = load_generators_from_ast(ctx.db.ast(), ctx.diagnostics); let codegen_targets: HashSet = generators.into_iter().filter_map(|generator| match generator { diff --git a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/tests.rs b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/tests.rs new file mode 100644 index 000000000..288f43992 --- /dev/null +++ b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/tests.rs @@ -0,0 +1,94 @@ +use baml_types::{Constraint, ConstraintLevel}; +use internal_baml_diagnostics::{DatamodelError, DatamodelWarning, Span}; +use internal_baml_jinja_types::{validate_expression, JinjaContext, PredefinedTypes, Type}; + +use crate::validate::validation_pipeline::context::Context; + +pub(super) fn validate(ctx: &mut Context<'_>) { + let tests = ctx.db.walk_test_cases().collect::>(); + tests.iter().for_each(|walker| { + let constraints = &walker.test_case().constraints; + let args = &walker.test_case().args; + let mut check_names: Vec = Vec::new(); + for ( + Constraint { + label, + level, + expression, + }, + constraint_span, + expr_span, + ) in constraints.iter() + { + let mut defined_types = PredefinedTypes::default(JinjaContext::Parsing); + defined_types.add_variable("this", Type::Unknown); + defined_types.add_class( + "Checks", + check_names + .iter() + .map(|check_name| (check_name.clone(), Type::Unknown)) + .collect(), + ); + defined_types.add_class( + "_", + vec![ + ("checks".to_string(), Type::ClassRef("Checks".to_string())), + ("result".to_string(), Type::Unknown), + ("latency_ms".to_string(), Type::Number), + ] + .into_iter() + .collect(), + ); + defined_types.add_variable("_", Type::ClassRef("_".to_string())); + args.keys() + .for_each(|arg_name| defined_types.add_variable(arg_name, Type::Unknown)); + match (level, label) { + (ConstraintLevel::Check, Some(check_name)) => { + check_names.push(check_name.to_string()); + } + _ => {} + } + match validate_expression(expression.0.as_str(), &mut defined_types) { + Ok(_) => {} + Err(e) => { + if let Some(e) = e.parsing_errors { + let range = match e.range() { + Some(range) => range, + None => { + ctx.push_error(DatamodelError::new_validation_error( + &format!("Error parsing jinja template: {}", e), + expr_span.clone(), + )); + continue; + } + }; + + let start_offset = expr_span.start + range.start; + let end_offset = expr_span.start + range.end; + + let span = Span::new( + expr_span.file.clone(), + start_offset as usize, + end_offset as usize, + ); + + ctx.push_error(DatamodelError::new_validation_error( + &format!("Error parsing jinja template: {}", e), + span, + )) + } else { + e.errors.iter().for_each(|t| { + let tspan = t.span(); + let span = Span::new( + expr_span.file.clone(), + expr_span.start + tspan.start_offset as usize, + expr_span.start + tspan.end_offset as usize, + ); + ctx.push_warning(DatamodelWarning::new(t.message().to_string(), span)) + }) + } + } + } + } + }); +} diff --git a/engine/baml-lib/baml-types/src/constraint.rs b/engine/baml-lib/baml-types/src/constraint.rs index 16abad0a3..372125d7f 100644 --- a/engine/baml-lib/baml-types/src/constraint.rs +++ b/engine/baml-lib/baml-types/src/constraint.rs @@ -27,7 +27,7 @@ pub enum ConstraintLevel { } /// The user-visible schema for a failed check. -#[derive(Clone, Debug, serde::Serialize)] +#[derive(Clone, Debug, serde::Serialize, PartialEq, Eq)] pub struct ResponseCheck { pub name: String, pub expression: String, diff --git a/engine/baml-lib/parser-database/src/attributes/constraint.rs b/engine/baml-lib/parser-database/src/attributes/constraint.rs index c54a921d5..75ca839bd 100644 --- a/engine/baml-lib/parser-database/src/attributes/constraint.rs +++ b/engine/baml-lib/parser-database/src/attributes/constraint.rs @@ -1,9 +1,70 @@ use baml_types::{Constraint, ConstraintLevel}; use internal_baml_diagnostics::{DatamodelError, Span}; -use internal_baml_schema_ast::ast::{Attribute, Expression}; +use internal_baml_schema_ast::ast::{Argument, Attribute, Expression}; use crate::{context::Context, types::Attributes}; +/// Interpret an attribute as a constraint, the whole constraint's span, +/// and the span of the constraint's jinja expression. +pub fn attribute_as_constraint( + attribute: &Attribute, +) -> (Option<(Constraint, Span, Span)>, Vec) { + let span = attribute.span.clone(); + let mut datamodel_errors = Vec::new(); + let attribute_name = attribute.name.to_string(); + let arguments: Vec = attribute + .arguments + .arguments + .iter() + .map(|Argument { value, .. }| value) + .cloned() + .collect(); + + let level = match attribute_name.as_str() { + "assert" => ConstraintLevel::Assert, + "check" => ConstraintLevel::Check, + _ => { + return (None, datamodel_errors); + } + }; + + let (label, expression, expr_span) = match arguments.as_slice() { + [Expression::JinjaExpressionValue(expression, expr_span)] => { + if level == ConstraintLevel::Check { + datamodel_errors.push(DatamodelError::new_attribute_validation_error( + "Checks must specify a label.", + attribute_name.as_str(), + span.clone(), + )); + } + (None, expression.clone(), expr_span.clone()) + } + [Expression::Identifier(label), Expression::JinjaExpressionValue(expression, expr_span)] => { + ( + Some(label.to_string()), + expression.clone(), + expr_span.clone(), + ) + } + _ => { + datamodel_errors.push( + DatamodelError::new_attribute_validation_error( + "Checks and asserts may have either a label and an expression, or a lone expression.", + attribute_name.as_str(), + span + ) + ); + return (None, datamodel_errors); + } + }; + let constraint = Constraint { + label, + expression, + level, + }; + (Some((constraint, span, expr_span)), datamodel_errors) +} + pub(super) fn visit_constraint_attributes( attribute_name: String, span: Span, diff --git a/engine/baml-lib/parser-database/src/attributes/mod.rs b/engine/baml-lib/parser-database/src/attributes/mod.rs index 95daebcba..7d27531fc 100644 --- a/engine/baml-lib/parser-database/src/attributes/mod.rs +++ b/engine/baml-lib/parser-database/src/attributes/mod.rs @@ -1,7 +1,7 @@ use internal_baml_schema_ast::ast::{Top, TopId, TypeExpId, TypeExpressionBlock}; mod alias; -mod constraint; +pub mod constraint; mod description; mod to_string_attribute; use crate::interner::StringId; diff --git a/engine/baml-lib/parser-database/src/tarjan.rs b/engine/baml-lib/parser-database/src/tarjan.rs index e3559f39e..935969e95 100644 --- a/engine/baml-lib/parser-database/src/tarjan.rs +++ b/engine/baml-lib/parser-database/src/tarjan.rs @@ -1,4 +1,7 @@ //! Tarjan's strongly connected components algorithm for cycle detection. +//! +//! This is used in parser_database to detect cycles in BAML types +//! that reference each other recursively. use std::{ cmp, diff --git a/engine/baml-lib/parser-database/src/types/configurations.rs b/engine/baml-lib/parser-database/src/types/configurations.rs index 95f88c71c..6b09b7b4e 100644 --- a/engine/baml-lib/parser-database/src/types/configurations.rs +++ b/engine/baml-lib/parser-database/src/types/configurations.rs @@ -1,11 +1,17 @@ +use baml_types::Constraint; use internal_baml_diagnostics::{DatamodelError, DatamodelWarning, Span}; -use internal_baml_schema_ast::ast::{ValExpId, ValueExprBlock, WithIdentifier, WithName, WithSpan}; +use internal_baml_schema_ast::ast::{ + Attribute, ValExpId, ValueExprBlock, WithIdentifier, WithName, WithSpan, +}; use regex::Regex; use std::collections::HashSet; +use crate::attributes::constraint::attribute_as_constraint; use crate::{coerce, coerce_array, coerce_expression::coerce_map, context::Context}; -use super::{ContantDelayStrategy, ExponentialBackoffStrategy, RetryPolicy, RetryPolicyStrategy}; +use super::{ + Attributes, ContantDelayStrategy, ExponentialBackoffStrategy, RetryPolicy, RetryPolicyStrategy, +}; fn dedent(s: &str) -> String { // Find the shortest indentation in the string (that's not an empty line). @@ -288,6 +294,18 @@ pub(crate) fn visit_test_case<'db>( )), }); + let constraints: Vec<(Constraint, Span, Span)> = config + .attributes + .iter() + .filter_map(|attribute| { + let (maybe_constraint, errors) = attribute_as_constraint(attribute); + for error in errors { + ctx.push_error(error); + } + maybe_constraint + }) + .collect(); + match (functions, args) { (None, _) => ctx.push_error(DatamodelError::new_validation_error( "Missing `functions` property", @@ -304,6 +322,7 @@ pub(crate) fn visit_test_case<'db>( functions, args, args_field_span: args_field_span.clone(), + constraints, }, ); } diff --git a/engine/baml-lib/parser-database/src/types/mod.rs b/engine/baml-lib/parser-database/src/types/mod.rs index ddb13b4ea..9cfb09664 100644 --- a/engine/baml-lib/parser-database/src/types/mod.rs +++ b/engine/baml-lib/parser-database/src/types/mod.rs @@ -5,6 +5,7 @@ use crate::coerce; use crate::types::configurations::visit_test_case; use crate::{context::Context, DatamodelError}; +use baml_types::Constraint; use indexmap::IndexMap; use internal_baml_diagnostics::Span; use internal_baml_prompt_parser::ast::{ChatBlock, PrinterBlock, Variable}; @@ -138,6 +139,7 @@ pub struct TestCase { // The span is the span of the argument (the expression has its own span) pub args: IndexMap, pub args_field_span: Span, + pub constraints: Vec<(Constraint, Span, Span)>, } #[derive(Debug, Clone)] diff --git a/engine/baml-lib/schema-ast/src/ast/value_expression_block.rs b/engine/baml-lib/schema-ast/src/ast/value_expression_block.rs index ccebc25fa..7619cc12f 100644 --- a/engine/baml-lib/schema-ast/src/ast/value_expression_block.rs +++ b/engine/baml-lib/schema-ast/src/ast/value_expression_block.rs @@ -64,7 +64,7 @@ pub struct BlockArgs { pub(crate) span: Span, } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub enum ValueExprBlockType { Function, Client, diff --git a/engine/baml-lib/schema-ast/src/parser/datamodel.pest b/engine/baml-lib/schema-ast/src/parser/datamodel.pest index 3373c5024..f0c3ec5c4 100644 --- a/engine/baml-lib/schema-ast/src/parser/datamodel.pest +++ b/engine/baml-lib/schema-ast/src/parser/datamodel.pest @@ -23,7 +23,7 @@ field_type_with_attr = { field_type ~ (NEWLINE? ~ (field_attribute | trailing_co value_expression_keyword = { FUNCTION_KEYWORD | TEST_KEYWORD | CLIENT_KEYWORD | RETRY_POLICY_KEYWORD | GENERATOR_KEYWORD } value_expression_block = { value_expression_keyword ~ identifier ~ named_argument_list? ~ ARROW? ~ field_type_chain? ~ SPACER_TEXT ~ BLOCK_OPEN ~ value_expression_contents ~ BLOCK_CLOSE } value_expression_contents = { - (value_expression | comment_block | empty_lines | BLOCK_LEVEL_CATCH_ALL)* + (value_expression | comment_block | block_attribute | empty_lines | BLOCK_LEVEL_CATCH_ALL)* } value_expression = { identifier ~ expression? ~ (NEWLINE? ~ field_attribute)* ~ trailing_comment? } diff --git a/engine/baml-lib/schema-ast/src/parser/parse_value_expression_block.rs b/engine/baml-lib/schema-ast/src/parser/parse_value_expression_block.rs index c4544cc9e..bd1341de8 100644 --- a/engine/baml-lib/schema-ast/src/parser/parse_value_expression_block.rs +++ b/engine/baml-lib/schema-ast/src/parser/parse_value_expression_block.rs @@ -1,10 +1,5 @@ use super::{ - helpers::{parsing_catch_all, Pair}, - parse_comments::*, - parse_field::parse_value_expr, - parse_identifier::parse_identifier, - parse_named_args_list::{parse_function_arg, parse_named_argument_list}, - Rule, + helpers::{parsing_catch_all, Pair}, parse_attribute::parse_attribute, parse_comments::*, parse_field::parse_value_expr, parse_identifier::parse_identifier, parse_named_args_list::{parse_function_arg, parse_named_argument_list}, Rule }; use crate::ast::*; @@ -17,7 +12,7 @@ pub(crate) fn parse_value_expression_block( ) -> Result { let pair_span = pair.as_span(); let mut name: Option = None; - let attributes: Vec = Vec::new(); + let mut attributes: Vec = Vec::new(); let mut input = None; let mut output = None; let mut fields: Vec> = vec![]; @@ -85,6 +80,30 @@ pub(crate) fn parse_value_expression_block( } Rule::comment_block => pending_field_comment = Some(item), + Rule::block_attribute => { + let span = item.as_span(); + let attribute = parse_attribute(item, false, diagnostics); + let value_is_test = sub_type == Some(ValueExprBlockType::Test); + let attribute_name = attribute.name.to_string(); + let attribute_is_constraint = &attribute_name == "check" || &attribute_name == "assert"; + + // Only tests may have block attributes, and the only valid block attributes + // are checks/asserts. + if value_is_test && attribute_is_constraint { + // value_expression_block is compatible with the attribute + attributes.push(attribute); + } else if !value_is_test { + diagnostics.push_error(DatamodelError::new_validation_error( + &format!("Only Tests may contain block-level attributes"), + diagnostics.span(span), + )) + } else { + diagnostics.push_error(DatamodelError::new_validation_error( + &format!("Tests may only contain 'check' or 'assert' attributes"), + diagnostics.span(span), + )) + } + } Rule::empty_lines => {} Rule::BLOCK_LEVEL_CATCH_ALL => { diagnostics.push_error(DatamodelError::new_validation_error( diff --git a/engine/baml-runtime/Cargo.toml b/engine/baml-runtime/Cargo.toml index 680bcd932..798c8070e 100644 --- a/engine/baml-runtime/Cargo.toml +++ b/engine/baml-runtime/Cargo.toml @@ -43,6 +43,7 @@ baml-types = { path = "../baml-lib/baml-types" } internal-baml-core = { path = "../baml-lib/baml-core" } internal-baml-jinja = { path = "../baml-lib/jinja-runtime" } log.workspace = true +minijinja.workspace = true pin-project-lite.workspace = true reqwest-eventsource = "0.6.0" scopeguard.workspace = true diff --git a/engine/baml-runtime/src/constraints.rs b/engine/baml-runtime/src/constraints.rs new file mode 100644 index 000000000..d28a86092 --- /dev/null +++ b/engine/baml-runtime/src/constraints.rs @@ -0,0 +1,469 @@ +use baml_types::{BamlValue, BamlValueWithMeta, Constraint, ConstraintLevel, ResponseCheck}; +use internal_baml_core::ir::jinja_helpers::{evaluate_predicate, render_expression}; +use jsonish::BamlValueWithFlags; + +use anyhow::Result; +use indexmap::IndexMap; +use minijinja; +use std::{collections::HashMap, fmt}; + +use crate::internal::llm_client::LLMCompleteResponse; + +/// Evaluate a list of constraints to be applied to a `BamlValueWithFlags`, in +/// the order that the constraints were specified by the user. +/// +/// When a check in a test is evaluated, its results are added to the context +/// so that future constraints can refer to it. +pub fn evaluate_test_constraints( + args: &IndexMap, + value: &BamlValueWithMeta>, + response: &LLMCompleteResponse, + constraints: Vec, +) -> TestConstraintsResult { + // Fold over all the constraints, updating both our success state, and + // our jinja context full of Check results. + // Finally, return the success state. + constraints + .into_iter() + .fold(Accumulator::new(), |acc, constraint| { + step_constraints(args, value, response, acc, constraint) + }) + .result +} + +/// The result of running a series of block-level constraints within a test. +#[derive(Clone, Debug, PartialEq)] +pub enum TestConstraintsResult { + /// Constraint testing finished with the following check + /// results, and optionally a failing assert. + Completed { + checks: Vec<(String, bool)>, + failed_assert: Option, + }, + + /// There was a problem evaluating a constraint. + InternalError { details: String }, +} + +/// State update helper functions. +impl TestConstraintsResult { + pub fn empty() -> Self { + TestConstraintsResult::Completed { + checks: Vec::new(), + failed_assert: None, + } + } + fn checks(self) -> Vec<(String, bool)> { + match self { + TestConstraintsResult::Completed { checks, .. } => checks, + _ => Vec::new(), + } + } + fn add_check_result(self, name: String, result: bool) -> Self { + match self { + TestConstraintsResult::Completed { mut checks, .. } => { + checks.push((name, result)); + TestConstraintsResult::Completed { + checks, + failed_assert: None, + } + } + _ => self, + } + } + fn fail_assert(self, name: Option) -> Self { + match self { + TestConstraintsResult::Completed { checks, .. } => TestConstraintsResult::Completed { + checks, + failed_assert: Some(name.unwrap_or("".to_string())), + }, + _ => self, + } + } +} + +/// The state that we track as we iterate over constraints in the test block. +struct Accumulator { + pub result: TestConstraintsResult, + pub check_results: Vec<(String, minijinja::Value)>, +} + +impl Accumulator { + pub fn new() -> Self { + Accumulator { + result: TestConstraintsResult::Completed { + checks: Vec::new(), + failed_assert: None, + }, + check_results: Vec::new(), + } + } +} + +/// The accumultator function, for running a single constraint +/// and updating the success state and the jinja context. +fn step_constraints( + args: &IndexMap, + value: &BamlValueWithMeta>, + response: &LLMCompleteResponse, + acc: Accumulator, + constraint: Constraint, +) -> Accumulator { + // Short-circuit if we have already had a hard failure. We can skip + // the work in the rest of this function if we have already encountered + // a hard failure. + let already_failed = matches!( + acc.result, + TestConstraintsResult::Completed { + failed_assert: Some(_), + .. + } + ) || matches!(acc.result, TestConstraintsResult::InternalError { .. }); + if already_failed { + return acc; + } + + let mut check_results: Vec<(String, minijinja::Value)> = acc.check_results.clone(); + let check_results_for_jinja = check_results.iter().cloned().collect::>(); + let underscore = minijinja::Value::from_serialize( + vec![ + ("result", minijinja::Value::from_serialize(value)), + ( + "latency_ms", + minijinja::Value::from_serialize(response.latency.as_millis()), + ), + ( + "checks", + minijinja::Value::from_serialize(check_results_for_jinja), + ), + ] + .into_iter() + .collect::>(), + ); + + let ctx = vec![ + ("_".to_string(), underscore), + ("this".to_string(), minijinja::Value::from_serialize(value)), + ] + .into_iter() + .chain( + args.iter() + .map(|(name, value)| (name.to_string(), minijinja::Value::from_serialize(value))), + ) + .collect(); + + let constraint_result_str = render_expression(&constraint.expression, &ctx); + let bool_result_or_internal_error: Result = + match constraint_result_str.as_ref().map(|s| s.as_str()) { + Ok("true") => Ok(true), + Ok("false") => Ok(false), + Ok("") => Ok(false), + Ok(x) => Err(format!("Expected true or false, got {x}.")), + Err(e) => Err(format!("Constraint error: {e:?}")), + }; + + // After running the constraint, we update the checks available in the + // minijinja context. + use ConstraintLevel::*; + + // The next value of the accumulator depends on several factors: + // - Whether we are processing a Check or an Assert. + // - Whether the constraint has a name or not. + // - The current accumulator state. + // In this match block, we use the result + match ( + constraint.level, + constraint.label, + bool_result_or_internal_error, + ) { + // A check ran to completion and succeeded or failed + // (i.e. returned a bool). This updates both the checks jinja context + // and the status. + (Check, Some(check_name), Ok(check_passed)) => { + check_results.push((check_name.clone(), check_passed.into())); + let mut new_checks = match acc.result { + TestConstraintsResult::Completed { checks, .. } => checks, + _ => Vec::new(), + }; + new_checks.push((check_name, check_passed)); + let result = TestConstraintsResult::Completed { + checks: new_checks, + failed_assert: None, + }; + return Accumulator { + result, + check_results, + }; + } + + // Internal error always produces a hard error. + (_, _, Err(e)) => { + return Accumulator { + result: TestConstraintsResult::InternalError { details: e }, + check_results: acc.check_results, + }; + } + + // A check without a name has no effect, and should never be observed, because + // the parser enforces that all checks are named. + (Check, None, _) => { + log::warn!( + "Encountered a check without a name: {:?}", + constraint.expression + ); + return acc; + } + + // A passing assert has no effect. + (Assert, _, Ok(true)) => { + return acc; + } + + // A failing assert is a hard error. + (Assert, maybe_name, Ok(false)) => { + let result = acc.result.fail_assert(maybe_name); + return Accumulator { + result, + check_results, + }; + } + }; +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::internal::llm_client::{LLMCompleteResponse, LLMCompleteResponseMetadata}; + use baml_types::{ + BamlValueWithMeta, Constraint, ConstraintLevel, JinjaExpression, ResponseCheck, + }; + use internal_baml_jinja::RenderedPrompt; + + use std::collections::HashMap; + + /// Construct a value to use as a test fixture. + /// It aims to combine a mix of: + /// - top-level vs. nested constraints + /// - asserts vs. checks + /// - successes vs. failures + /// + /// Roughly this schema: + /// { + /// "name": { + /// value: "Greg", + /// meta: [ + /// (@assert(good_name, {{ this|length > 0}}), true), + /// (@check(long_name, {{ this|length > 4}}), false), + /// ]}}, + /// "kids": { + /// value: [ + /// { name: { + /// value: "Tao", + /// meta: (same meta as top-level name) + /// }, + /// age: 6 + /// }, + /// { name: { + /// value: "Ellie", + /// meta: (same meta as top-level name, but no failing check) + /// }, + /// age: 3 + /// } + /// ], + /// "meta": [ + /// (@check(has_kids, {{ this|length > 0 }}), true) + /// ] + /// } + /// } + fn mk_value() -> BamlValueWithMeta> { + fn mk_name(name: &str) -> BamlValueWithMeta> { + let meta = vec![ + ResponseCheck { + name: "good_name".to_string(), + expression: "this|length > 0".to_string(), + status: "succeeded".to_string(), + }, + ResponseCheck { + name: "long_name".to_string(), + expression: "this|length > 4".to_string(), + status: if name.len() > 4 { + "succeeded".to_string() + } else { + "failed".to_string() + }, + }, + ]; + BamlValueWithMeta::String(name.to_string(), meta) + } + + fn mk_child(name: &str, age: i64) -> BamlValueWithMeta> { + BamlValueWithMeta::Class( + "child".to_string(), + vec![ + ("name".to_string(), mk_name(name)), + ("age".to_string(), BamlValueWithMeta::Int(age, vec![])), + ] + .into_iter() + .collect(), + vec![], + ) + } + + BamlValueWithMeta::Class( + "parent".to_string(), + vec![ + ("name".to_string(), mk_name("Greg")), + ( + "kids".to_string(), + BamlValueWithMeta::List(vec![mk_child("Tao", 6), mk_child("Ellie", 3)], vec![]), + ), + ] + .into_iter() + .collect(), + vec![], + ) + } + + fn mk_response() -> LLMCompleteResponse { + LLMCompleteResponse { + client: "test_client".to_string(), + model: "test_model".to_string(), + prompt: RenderedPrompt::Completion(String::new()), + request_options: HashMap::new(), + content: String::new(), + start_time: web_time::SystemTime::UNIX_EPOCH, + latency: web_time::Duration::from_millis(500), + metadata: LLMCompleteResponseMetadata { + baml_is_complete: true, + finish_reason: None, + prompt_tokens: None, + output_tokens: None, + total_tokens: None, + }, + } + } + + fn mk_check(label: &str, expr: &str) -> Constraint { + Constraint { + label: Some(label.to_string()), + level: ConstraintLevel::Check, + expression: JinjaExpression(expr.to_string()), + } + } + + fn mk_assert(label: &str, expr: &str) -> Constraint { + Constraint { + label: Some(label.to_string()), + level: ConstraintLevel::Assert, + expression: JinjaExpression(expr.to_string()), + } + } + + fn run_pipeline(constraints: &[Constraint]) -> TestConstraintsResult { + let args = IndexMap::new(); + let value = mk_value(); + let constraints = constraints.into(); + let response = mk_response(); + evaluate_test_constraints(&args, &value, &response, constraints) + } + + #[test] + fn basic_test_constraints() { + let res = run_pipeline(&[mk_assert("has_kids", "_.result.kids|length > 0")]); + assert_eq!( + res, + TestConstraintsResult::Completed { + checks: vec![], + failed_assert: None, + } + ); + } + + #[test] + fn test_dependencies() { + let res = run_pipeline(&[ + mk_check("has_kids", "_.result.kids|length > 0"), + mk_check("not_too_many", "this.kids.length < 100"), + mk_assert("both_pass", "_.checks.has_kids and _.checks.not_too_many"), + ]); + assert_eq!( + res, + TestConstraintsResult::Completed { + checks: vec![ + ("has_kids".to_string(), true), + ("not_too_many".to_string(), true), + ], + failed_assert: None + } + ); + } + + #[test] + fn test_dependencies_non_check() { + let res = run_pipeline(&[ + mk_assert("has_kids", "_.result.kids|length > 0"), + mk_check("not_too_many", "this.kids.length < 100"), + mk_assert("both_pass", "_.checks.has_kids and _.checks.not_too_many"), + ]); + // This constraint set should fail because `has_kids` is an assert, not + // a check, therefore it doesn't get a field in `checks`. + assert_eq!( + res, + TestConstraintsResult::Completed { + checks: vec![("not_too_many".to_string(), true),], + failed_assert: Some("both_pass".to_string()) + } + ); + } + + #[test] + fn test_fast_is_sufficient() { + let res = run_pipeline(&[ + mk_check("has_kids", "_.result.kids|length > 0"), + mk_check("not_too_many", "this.kids.length < 100"), + mk_check("both_pass", "_.checks.has_kids and _.checks.not_too_many"), + mk_assert("either_or", "_.checks.both_pass or _.latency_ms < 1000"), + ]); + assert_eq!( + res, + TestConstraintsResult::Completed { + checks: vec![ + ("has_kids".to_string(), true), + ("not_too_many".to_string(), true), + ("both_pass".to_string(), true), + ], + failed_assert: None + } + ); + } + + #[test] + fn test_failing_checks() { + let res = run_pipeline(&[ + mk_check("has_kids", "_.result.kids|length > 0"), + mk_check("not_too_many", "this.kids.length < 100"), + mk_assert("both_pass", "_.checks.has_kids and _.checks.not_too_many"), + mk_check("no_kids", "this.kids|length == 0"), + mk_check("way_too_many", "this.kids|length > 1000"), + ]); + assert_eq!( + res, + TestConstraintsResult::Completed { + checks: vec![ + ("has_kids".to_string(), true), + ("not_too_many".to_string(), true), + ("no_kids".to_string(), false), + ("way_too_many".to_string(), false) + ], + failed_assert: None + } + ); + } + + #[test] + fn test_internal_error() { + let res = run_pipeline(&[mk_check("faulty", "__.result.kids|length > 0")]); + // This test fails because there is a typo: `__` (double underscore). + assert!(matches!(res, TestConstraintsResult::InternalError { .. })); + } +} diff --git a/engine/baml-runtime/src/internal/llm_client/mod.rs b/engine/baml-runtime/src/internal/llm_client/mod.rs index d61ab0caf..352d0dbee 100644 --- a/engine/baml-runtime/src/internal/llm_client/mod.rs +++ b/engine/baml-runtime/src/internal/llm_client/mod.rs @@ -27,10 +27,10 @@ use wasm_bindgen::JsValue; pub type ResponseBamlValue = BamlValueWithMeta>; /// Validate a parsed value, checking asserts and checks. -pub fn parsed_value_to_response(baml_value: &BamlValueWithFlags) -> Result { +pub fn parsed_value_to_response(baml_value: &BamlValueWithFlags) -> ResponseBamlValue { let baml_value_with_meta: BamlValueWithMeta> = baml_value.clone().into(); - Ok(baml_value_with_meta.map_meta(|cs| { + baml_value_with_meta.map_meta(|cs| { cs.iter() .map(|(label, expr, result)| { let status = (if *result { "succeeded" } else { "failed" }).to_string(); @@ -41,7 +41,7 @@ pub fn parsed_value_to_response(baml_value: &BamlValueWithFlags) -> Result (Some(Ok(v.clone())), Some(parsed_value_to_response(&v))), + Some(Ok(v)) => (Some(Ok(v.clone())), Some(Ok(parsed_value_to_response(&v)))), Some(Err(e)) => (None, Some(Err(e))), None => (None, None), }; diff --git a/engine/baml-runtime/src/internal/llm_client/orchestrator/stream.rs b/engine/baml-runtime/src/internal/llm_client/orchestrator/stream.rs index 680733c24..74750a8bd 100644 --- a/engine/baml-runtime/src/internal/llm_client/orchestrator/stream.rs +++ b/engine/baml-runtime/src/internal/llm_client/orchestrator/stream.rs @@ -66,7 +66,7 @@ where LLMResponse::Success(s) => { let parsed = partial_parse_fn(&s.content); let (parsed, response_value) = match parsed { - Ok(v) => (Some(Ok(v.clone())), Some(parsed_value_to_response(&v))), + Ok(v) => (Some(Ok(v.clone())), Some(Ok(parsed_value_to_response(&v)))), Err(e) => (None, Some(Err(e))), }; on_event(FunctionResult::new( @@ -103,7 +103,7 @@ where _ => None, }; let (parsed_response, response_value) = match parsed_response { - Some(Ok(v)) => (Some(Ok(v.clone())), Some(parsed_value_to_response(&v))), + Some(Ok(v)) => (Some(Ok(v.clone())), Some(Ok(parsed_value_to_response(&v)))), Some(Err(e)) => (None, Some(Err(e))), None => (None, None), }; diff --git a/engine/baml-runtime/src/lib.rs b/engine/baml-runtime/src/lib.rs index d4659ead4..5308ed79c 100644 --- a/engine/baml-runtime/src/lib.rs +++ b/engine/baml-runtime/src/lib.rs @@ -8,6 +8,7 @@ pub(crate) mod internal; #[cfg(not(target_arch = "wasm32"))] pub mod cli; pub mod client_registry; +pub mod constraints; pub mod errors; pub mod request; mod runtime; @@ -25,6 +26,7 @@ use anyhow::Result; use baml_types::BamlMap; use baml_types::BamlValue; +use baml_types::Constraint; use cfg_if::cfg_if; use client_registry::ClientRegistry; use indexmap::IndexMap; @@ -62,6 +64,9 @@ pub use internal_baml_core::internal_baml_diagnostics; pub use internal_baml_core::internal_baml_diagnostics::Diagnostics as DiagnosticsError; pub use internal_baml_core::ir::{scope_diagnostics, FieldType, IRHelper, TypeValue}; +use crate::constraints::{evaluate_test_constraints, TestConstraintsResult}; +use crate::internal::llm_client::LLMResponse; + #[cfg(not(target_arch = "wasm32"))] static TOKIO_SINGLETON: OnceLock>> = OnceLock::new(); @@ -179,13 +184,27 @@ impl BamlRuntime { } impl BamlRuntime { + pub fn get_test_params_and_constraints( + &self, + function_name: &str, + test_name: &str, + ctx: &RuntimeContext, + ) -> Result<(BamlMap, Vec)> { + let params = self.inner.get_test_params(function_name, test_name, ctx)?; + let constraints = self + .inner + .get_test_constraints(function_name, test_name, &ctx)?; + Ok((params, constraints)) + } + pub fn get_test_params( &self, function_name: &str, test_name: &str, ctx: &RuntimeContext, ) -> Result> { - self.inner.get_test_params(function_name, test_name, ctx) + let (params, _) = self.get_test_params_and_constraints(function_name, test_name, ctx)?; + Ok(params) } pub async fn run_test( @@ -200,40 +219,50 @@ impl BamlRuntime { { let span = self.tracer.start_span(test_name, ctx, &Default::default()); - let response = match ctx.create_ctx(None, None) { - Ok(rctx) => { - let params = self.get_test_params(function_name, test_name, &rctx); - match params { - Ok(params) => match ctx.create_ctx(None, None) { - Ok(rctx_stream) => { - let stream = self.inner.stream_function_impl( - function_name.into(), - ¶ms, - self.tracer.clone(), - rctx_stream, - #[cfg(not(target_arch = "wasm32"))] - self.async_runtime.clone(), - ); - match stream { - Ok(mut stream) => { - let (response, span) = - stream.run(on_event, ctx, None, None).await; - response.map(|res| TestResponse { - function_response: res, - function_span: span, - }) - } - Err(e) => Err(e), - } - } - Err(e) => Err(e), - }, - Err(e) => Err(e), + let run_to_response = || async { + let rctx = ctx.create_ctx(None, None)?; + let (params, constraints) = + self.get_test_params_and_constraints(function_name, test_name, &rctx)?; + let rctx_stream = ctx.create_ctx(None, None)?; + let mut stream = self.inner.stream_function_impl( + function_name.into(), + ¶ms, + self.tracer.clone(), + rctx_stream, + #[cfg(not(target_arch = "wasm32"))] + self.async_runtime.clone(), + )?; + let (response_res, span_uuid) = stream.run(on_event, ctx, None, None).await; + let res = response_res?; + let (_, llm_resp, _, val) = res + .event_chain() + .iter() + .last() + .context("Expected non-empty event chain")?; + let complete_resp = match llm_resp { + LLMResponse::Success(complete_llm_response) => Ok(complete_llm_response), + _ => Err(anyhow::anyhow!("LLM Response was not successful")), + }?; + let test_constraints_result = if constraints.is_empty() { + TestConstraintsResult::empty() + } else { + match val { + Some(Ok(value)) => { + evaluate_test_constraints(¶ms, &value, &complete_resp, constraints) + } + _ => TestConstraintsResult::empty(), } - } - Err(e) => Err(e), + }; + let test_response = Ok(TestResponse { + function_response: res, + function_span: span_uuid, + constraints_result: test_constraints_result, + }); + test_response }; + let response = run_to_response().await; + let mut target_id = None; if let Some(span) = span { #[cfg(not(target_arch = "wasm32"))] diff --git a/engine/baml-runtime/src/runtime/runtime_interface.rs b/engine/baml-runtime/src/runtime/runtime_interface.rs index 62771412c..db850ad73 100644 --- a/engine/baml-runtime/src/runtime/runtime_interface.rs +++ b/engine/baml-runtime/src/runtime/runtime_interface.rs @@ -24,7 +24,7 @@ use crate::{ RuntimeContext, RuntimeInterface, }; use anyhow::{Context, Result}; -use baml_types::{BamlMap, BamlValue}; +use baml_types::{BamlMap, BamlValue, Constraint}; use internal_baml_core::{ internal_baml_diagnostics::SourceFile, ir::{ @@ -281,6 +281,14 @@ impl InternalRuntimeInterface for InternalBamlRuntime { Err(e) => return Err(anyhow::anyhow!("Unable to resolve test params: {:?}", e)), } } + + fn get_test_constraints( + &self, function_name: &str, test_name: &str, ctx: &RuntimeContext + ) -> Result> { + let func = self.get_function(function_name, ctx)?; + let walker = self.ir().find_test(&func, test_name)?; + Ok(walker.item.1.elem.constraints.clone()) + } } impl RuntimeConstructor for InternalBamlRuntime { diff --git a/engine/baml-runtime/src/runtime_interface.rs b/engine/baml-runtime/src/runtime_interface.rs index 396e646f9..fd1149007 100644 --- a/engine/baml-runtime/src/runtime_interface.rs +++ b/engine/baml-runtime/src/runtime_interface.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use baml_types::{BamlMap, BamlValue}; +use baml_types::{BamlMap, BamlValue, Constraint}; use internal_baml_core::internal_baml_diagnostics::Diagnostics; use internal_baml_core::ir::repr::ClientSpec; use internal_baml_core::ir::{repr::IntermediateRepr, FunctionWalker}; @@ -159,4 +159,11 @@ pub trait InternalRuntimeInterface { test_name: &str, ctx: &RuntimeContext, ) -> Result>; + + fn get_test_constraints( + &self, + function_name: &str, + test_name: &str, + ctx: &RuntimeContext + ) -> Result>; } diff --git a/engine/baml-runtime/src/types/response.rs b/engine/baml-runtime/src/types/response.rs index 02ef5d4cd..f0581091d 100644 --- a/engine/baml-runtime/src/types/response.rs +++ b/engine/baml-runtime/src/types/response.rs @@ -1,5 +1,6 @@ pub use crate::internal::llm_client::LLMResponse; use crate::{ + constraints::TestConstraintsResult, errors::ExposedError, internal::llm_client::{orchestrator::OrchestrationScope, ResponseBamlValue}, }; @@ -182,9 +183,11 @@ impl FunctionResult { } } +#[derive(Debug)] pub struct TestResponse { pub function_response: FunctionResult, pub function_span: Option, + pub constraints_result: TestConstraintsResult, } impl std::fmt::Display for TestResponse { @@ -196,6 +199,7 @@ impl std::fmt::Display for TestResponse { #[derive(Debug, PartialEq, Eq)] pub enum TestStatus<'a> { Pass, + NeedsHumanEval(Vec), Fail(TestFailReason<'a>), } @@ -203,6 +207,9 @@ impl From> for BamlValue { fn from(status: TestStatus) -> Self { match status { TestStatus::Pass => BamlValue::String("pass".to_string()), + TestStatus::NeedsHumanEval(checks) => { + BamlValue::String(format!("checks need human evaluation: {:?}", checks)) + } TestStatus::Fail(r) => BamlValue::String(format!("failed! {:?}", r)), } } @@ -210,9 +217,13 @@ impl From> for BamlValue { #[derive(Debug)] pub enum TestFailReason<'a> { - TestUnspecified(&'a anyhow::Error), + TestUnspecified(anyhow::Error), TestLLMFailure(&'a LLMResponse), TestParseFailure(&'a anyhow::Error), + TestConstraintsFailure { + checks: Vec<(String, bool)>, + failed_assert: Option, + }, } impl PartialEq for TestFailReason<'_> { @@ -235,7 +246,26 @@ impl TestResponse { let func_res = &self.function_response; if let Some(parsed) = func_res.result_with_constraints() { if parsed.is_ok() { - TestStatus::Pass + match self.constraints_result.clone() { + TestConstraintsResult::InternalError { details } => { + TestStatus::Fail(TestFailReason::TestUnspecified(anyhow::anyhow!(details))) + } + TestConstraintsResult::Completed { + checks, + failed_assert, + } => { + let n_failed_checks: usize = + checks.iter().filter(|(_, pass)| !pass).count(); + if failed_assert.is_some() || n_failed_checks > 0 { + TestStatus::Fail(TestFailReason::TestConstraintsFailure { + checks, + failed_assert, + }) + } else { + TestStatus::Pass + } + } + } } else { TestStatus::Fail(TestFailReason::TestParseFailure( parsed.as_ref().unwrap_err(), diff --git a/engine/baml-schema-wasm/Cargo.toml b/engine/baml-schema-wasm/Cargo.toml index 4ca8c00a3..d83bcede2 100644 --- a/engine/baml-schema-wasm/Cargo.toml +++ b/engine/baml-schema-wasm/Cargo.toml @@ -36,6 +36,7 @@ wasm-bindgen-futures = "0.4.42" wasm-logger = { version = "0.2.0" } web-time.workspace = true either = "1.8.1" +itertools = "0.13.0" [dependencies.web-sys] version = "0.3.69" diff --git a/engine/baml-schema-wasm/src/runtime_wasm/mod.rs b/engine/baml-schema-wasm/src/runtime_wasm/mod.rs index 1e72eb67b..62557b82e 100644 --- a/engine/baml-schema-wasm/src/runtime_wasm/mod.rs +++ b/engine/baml-schema-wasm/src/runtime_wasm/mod.rs @@ -20,6 +20,7 @@ use jsonish::deserializer::deserialize_flags::Flag; use jsonish::BamlValueWithFlags; use baml_runtime::internal::llm_client::orchestrator::ExecutionScope; +use itertools::join; use js_sys::Promise; use js_sys::Uint8Array; use serde::{Deserialize, Serialize}; @@ -397,6 +398,7 @@ pub struct WasmFunctionResponse { } #[wasm_bindgen] +#[derive(Debug)] pub struct WasmTestResponse { test_response: anyhow::Result, span: Option, @@ -415,10 +417,13 @@ pub struct WasmParsedTestResponse { } #[wasm_bindgen] +#[derive(Clone, Debug)] pub enum TestStatus { Passed, LLMFailure, ParseFailure, + ConstraintsFailed, + AssertFailed, UnableToRun, } @@ -501,7 +506,19 @@ impl WasmFunctionResponse { } } +// TODO: What is supposed to happen with the serialized baml_value? +// That value has checks nested inside. Are they meant to be removed +// during flattening? Or duplicated into the top-level list of checks? fn flatten_checks(value: &BamlValueWithFlags) -> (serde_json::Value, usize) { + // // Note: (Greg) depending on the goal of this function, we may be able + // // to replace most of it like this: + // let value_with_meta: BamlValueWithMeta> = parsed_value_to_response(value); + // let n_checks: usize = value_with_meta.iter().map(|node| node.meta().len()).sum(); + // let bare_baml_value: BamlValue = value_with_meta.into(); + // let json_value: serde_json::Value = serde_json::to_value(bare_baml_value).unwrap_or( + // "Error converting value to JSON".into() + // ); + type J = serde_json::Value; let checks = value @@ -511,12 +528,7 @@ fn flatten_checks(value: &BamlValueWithFlags) -> (serde_json::Value, usize) { .flat_map(|f| match f { Flag::ConstraintResults(c) => c .iter() - .map(|(label, _expr, b)| { - ( - label.clone(), - *b, - ) - }) + .map(|(label, _expr, b)| (label.clone(), *b)) .collect::>(), _ => vec![], }) @@ -580,10 +592,20 @@ impl WasmTestResponse { match &self.test_response { Ok(t) => match t.status() { baml_runtime::TestStatus::Pass => TestStatus::Passed, + baml_runtime::TestStatus::NeedsHumanEval(_) => TestStatus::ConstraintsFailed, baml_runtime::TestStatus::Fail(r) => match r { baml_runtime::TestFailReason::TestUnspecified(_) => TestStatus::UnableToRun, baml_runtime::TestFailReason::TestLLMFailure(_) => TestStatus::LLMFailure, baml_runtime::TestFailReason::TestParseFailure(_) => TestStatus::ParseFailure, + baml_runtime::TestFailReason::TestConstraintsFailure { + failed_assert, .. + } => { + if failed_assert.is_some() { + TestStatus::AssertFailed + } else { + TestStatus::ConstraintsFailed + } + } }, }, Err(_) => TestStatus::UnableToRun, @@ -645,6 +667,10 @@ impl WasmTestResponse { Ok(r) => match r.status() { baml_runtime::TestStatus::Pass => None, baml_runtime::TestStatus::Fail(r) => r.render_error(), + baml_runtime::TestStatus::NeedsHumanEval(checks) => Some(format!( + "Checks require human validation: {}", + join(checks, ", ") + )), }, Err(e) => Some(format!("{e:#}")), } @@ -759,6 +785,23 @@ impl WithRenderError for baml_runtime::TestFailReason<'_> { baml_runtime::TestFailReason::TestUnspecified(e) => Some(format!("{e:#}")), baml_runtime::TestFailReason::TestLLMFailure(f) => f.render_error(), baml_runtime::TestFailReason::TestParseFailure(e) => Some(format!("{e:#}")), + baml_runtime::TestFailReason::TestConstraintsFailure { + checks, + failed_assert, + } => { + let checks_msg = if checks.len() > 0 { + let check_msgs = checks.into_iter().map(|(name, pass)| { + format!("{name}: {}", if *pass { "Passed" } else { "Failed" }) + }); + format!("Check results:\n{}", join(check_msgs, "\n")) + } else { + String::new() + }; + let assert_msg = failed_assert + .as_ref() + .map_or("".to_string(), |name| format!("\nFailed assert: {name}")); + Some(format!("{checks_msg}{assert_msg}")) + } } } } diff --git a/typescript/playground-common/src/baml_wasm_web/test_uis/testHooks.ts b/typescript/playground-common/src/baml_wasm_web/test_uis/testHooks.ts index 4c7647808..a294054e6 100644 --- a/typescript/playground-common/src/baml_wasm_web/test_uis/testHooks.ts +++ b/typescript/playground-common/src/baml_wasm_web/test_uis/testHooks.ts @@ -10,7 +10,7 @@ export const showTestsAtom = atom(false) export const showClientGraphAtom = atom(false) export type TestStatusType = 'queued' | 'running' | 'done' | 'error' -export type DoneTestStatusType = 'passed' | 'llm_failed' | 'parse_failed' | 'error' +export type DoneTestStatusType = 'passed' | 'llm_failed' | 'parse_failed' | 'constraints_failed' | 'error' export type TestState = | { status: 'queued' @@ -44,11 +44,45 @@ export const statusCountAtom = atom({ passed: 0, llm_failed: 0, parse_failed: 0, + constraints_failed: 0, error: 0, }, error: 0, }) +/// This atom will track the state of the full test suite. +/// 'unknown` means tests haven't been run yet. `pass` means +/// they have all run to completion. +/// 'warn' means at least one check has failed, and `fail` +/// means at least one assert has failed, or an internal error +/// occurred. +export type TestSuiteSummary = 'pass' | 'warn' | 'fail' | 'unknown' +export const testSuiteSummaryAtom = atom('unknown') + +/// For an old summary and a new result, compute the new summary. +/// The new summary will overwrite the old, unless the old one +/// has higher priority. +function updateTestSuiteState(old_result: TestSuiteSummary, new_result: TestSuiteSummary): TestSuiteSummary { + function priority(x: TestSuiteSummary): number { + switch (x) { + case 'unknown': + return 0 + case 'pass': + return 1 + case 'warn': + return 2 + case 'fail': + return 3 + } + } + + if (priority(new_result) > priority(old_result)) { + return new_result + } else { + return old_result + } +} + export const useRunHooks = () => { const isRunning = useAtomValue(isRunningAtom) @@ -68,6 +102,7 @@ export const useRunHooks = () => { } set(isRunningAtom, true) set(showTestsAtom, true) + set(testSuiteSummaryAtom, 'unknown') vscode.postMessage({ command: 'telemetry', @@ -92,6 +127,7 @@ export const useRunHooks = () => { passed: 0, llm_failed: 0, parse_failed: 0, + constraints_failed: 0, error: 0, }, error: 0, @@ -144,7 +180,7 @@ export const useRunHooks = () => { const { res, elapsed } = result.value // console.log('result', i, result.value.res.llm_response(), 'batch[i]', batch[i]) - let status = res.status() + let status: Number = res.status() let response_status: DoneTestStatusType = 'error' if (status === 0) { response_status = 'passed' @@ -152,6 +188,8 @@ export const useRunHooks = () => { response_status = 'llm_failed' } else if (status === 2) { response_status = 'parse_failed' + } else if (status === 3 || status === 4) { + response_status = 'constraints_failed' } else { response_status = 'error' } @@ -171,6 +209,23 @@ export const useRunHooks = () => { running: prev.running - 1, } }) + + let newTestSuiteStatus: TestSuiteSummary = 'unknown' + if (status === 0) { + newTestSuiteStatus = 'pass' + } else if (status === 1) { + newTestSuiteStatus = 'fail' + } else if (status === 2) { + newTestSuiteStatus = 'fail' + } else if (status === 3) { + newTestSuiteStatus = 'warn' + } else if (status === 4) { + newTestSuiteStatus = 'fail' + } + + let currentSummary = get(testSuiteSummaryAtom) + let updatedSummary = updateTestSuiteState(currentSummary, newTestSuiteStatus) + set(testSuiteSummaryAtom, updatedSummary) } else { set(testStatusAtom(batch[i]), { status: 'error', message: `${result.reason}` }) set(statusCountAtom, (prev) => { diff --git a/typescript/playground-common/src/baml_wasm_web/test_uis/test_result.tsx b/typescript/playground-common/src/baml_wasm_web/test_uis/test_result.tsx index b73114f8e..abd80b5cb 100644 --- a/typescript/playground-common/src/baml_wasm_web/test_uis/test_result.tsx +++ b/typescript/playground-common/src/baml_wasm_web/test_uis/test_result.tsx @@ -9,6 +9,7 @@ import { runningTestsAtom, statusCountAtom, testStatusAtom, + testSuiteSummaryAtom, DoneTestStatusType, useRunHooks, showTestsAtom, @@ -57,6 +58,8 @@ const TestStatusMessage: React.FC<{ testStatus: DoneTestStatusType }> = ({ testS return
LLM Failed
case 'parse_failed': return
Parse Failed
+ case 'constraints_failed': + return
Constraints Failed
case 'error': return
Unable to run
} @@ -98,8 +101,10 @@ const TestStatusIcon: React.FC<{ ) } -type FilterValues = 'queued' | 'running' | 'error' | 'llm_failed' | 'parse_failed' | 'passed' -const filterAtom = atom(new Set(['running', 'error', 'llm_failed', 'parse_failed', 'passed'])) +type FilterValues = 'queued' | 'running' | 'error' | 'llm_failed' | 'parse_failed' | 'constraints_failed' | 'passed' +const filterAtom = atom( + new Set(['running', 'error', 'llm_failed', 'parse_failed', 'constraints_failed', 'passed']), +) const checkFilter = (filter: Set, status: TestStatusType, test_status?: DoneTestStatusType) => { if (filter.size === 0) { @@ -218,7 +223,9 @@ const ParsedTestResult: React.FC<{ doneStatus: string; parsed?: WasmParsedTestRe ) : ( <> - {failure &&
{failure}
} + {failure && doneStatus === 'parse_failed' && ( +
{failure}
+ )} {parsed !== undefined && ( <> { count={statusCounts.done.parse_failed} onClick={() => toggleFilter('parse_failed')} /> + toggleFilter('constraints_failed')} + /> { const selectedFunction = useAtomValue(selectedFunctionAtom) const [showTests, setShowTests] = useAtom(showTestsAtom) const [showClientGraph, setClientGraph] = useAtom(showClientGraphAtom) + const [testSuiteSummary] = useAtom(testSuiteSummaryAtom) // reset the tab when switching funcs useEffect(() => { @@ -660,6 +674,22 @@ const TestResults: React.FC = () => { }, [selectedFunction?.name]) const isNextJs = (window as any).next?.version + let testSuiteIcon = + switch (testSuiteSummary) { + case 'fail': + testSuiteIcon = + break + case 'pass': + testSuiteIcon = + break + case 'warn': + testSuiteIcon = ⚠️ + break + case 'unknown': + testSuiteIcon = + break + } + return (
@@ -689,7 +719,9 @@ const TestResults: React.FC = () => { setClientGraph(false) }} > - Test Results +
+ Test Results {testSuiteIcon} +