From a578cc287abbd9c23697adc4c83bcf0979916fcf Mon Sep 17 00:00:00 2001 From: hellovai Date: Mon, 11 Nov 2024 14:40:19 -0800 Subject: [PATCH] Add ability to validate types for template strings (#1161) > [!IMPORTANT] > Add type validation for template strings in BAML engine with new validation logic and test cases. > > - **Behavior**: > - Add `template_strings` module to `validations.rs` and integrate it into the validation pipeline. > - Implement `validate()` in `template_strings.rs` to check template string types and handle errors. > - Update `functions.rs` to make `has_checks_nested()` public for use in template validation. > - **Walkers**: > - Add `walk_input_args()` to `TemplateStringWalker` in `template_string.rs` to iterate over template arguments. > - Add `ArgWalker` type to handle template string arguments. > - **Tests**: > - Add test cases in `bad_calls.baml`, `good_calls.baml`, and `invalid.baml` to verify template string validation. > - **Misc**: > - Minor documentation updates in `class.rs` and `mod.rs`. > > This description was created by [Ellipsis](https://www.ellipsis.dev?ref=BoundaryML%2Fbaml&utm_source=github&utm_medium=referral) for 0881047aa4d51b7e0438dfeb3dfe056182fd0fc2. It will automatically update as commits are pushed. 
--- .../validation_pipeline/validations.rs | 2 + .../validations/functions.rs | 2 +- .../validations/template_strings.rs | 120 ++++++++++++++++++ .../functions_v2/prompt_errors/prompt1.baml | 6 + .../template_string/bad_calls.baml | 80 ++++++++++++ .../template_string/good_calls.baml | 11 ++ .../template_string/invalid.baml | 34 +++++ .../parser-database/src/walkers/class.rs | 2 + .../parser-database/src/walkers/mod.rs | 5 +- .../src/walkers/template_string.rs | 51 +++++++- 10 files changed, 308 insertions(+), 5 deletions(-) create mode 100644 engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/template_strings.rs create mode 100644 engine/baml-lib/baml/tests/validation_files/template_string/bad_calls.baml create mode 100644 engine/baml-lib/baml/tests/validation_files/template_string/good_calls.baml create mode 100644 engine/baml-lib/baml/tests/validation_files/template_string/invalid.baml diff --git a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations.rs b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations.rs index 72c9f11e3..b595c7f96 100644 --- a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations.rs +++ b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations.rs @@ -4,6 +4,7 @@ mod configurations; mod cycle; mod enums; mod functions; +mod template_strings; mod types; use super::context::Context; @@ -13,6 +14,7 @@ pub(super) fn validate(ctx: &mut Context<'_>) { classes::validate(ctx); functions::validate(ctx); clients::validate(ctx); + template_strings::validate(ctx); configurations::validate(ctx); if !ctx.diagnostics.has_errors() { diff --git a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/functions.rs b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/functions.rs index de7322718..2296e4191 100644 --- a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/functions.rs +++ 
b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/functions.rs @@ -220,7 +220,7 @@ pub(super) fn validate(ctx: &mut Context<'_>) { /// Just syntactic sugar for the recursive check. /// /// See [`NestedChecks::has_checks_nested`]. -fn has_checks_nested(ctx: &Context<'_>, field_type: &FieldType) -> bool { +pub(super) fn has_checks_nested(ctx: &Context<'_>, field_type: &FieldType) -> bool { NestedChecks::new(ctx).has_checks_nested(field_type) } diff --git a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/template_strings.rs b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/template_strings.rs new file mode 100644 index 000000000..d8a33485c --- /dev/null +++ b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/template_strings.rs @@ -0,0 +1,120 @@ +use std::collections::HashSet; + +use crate::validate::validation_pipeline::context::Context; + +use either::Either; +use internal_baml_diagnostics::{DatamodelError, DatamodelWarning, Span}; + +use internal_baml_schema_ast::ast::{FieldType, TypeExpId, WithIdentifier, WithName, WithSpan}; + +use super::types::validate_type; + +pub(super) fn validate(ctx: &mut Context<'_>) { + let mut defined_types = internal_baml_jinja_types::PredefinedTypes::default( + internal_baml_jinja_types::JinjaContext::Prompt, + ); + ctx.db.walk_classes().for_each(|t| { + t.add_to_types(&mut defined_types); + }); + ctx.db.walk_templates().for_each(|t| { + t.add_to_types(&mut defined_types); + }); + + for template in ctx.db.walk_templates() { + for args in template.walk_input_args() { + let arg = args.ast_arg(); + validate_type(ctx, &arg.1.field_type); + } + + for args in template.walk_input_args() { + let arg = args.ast_arg(); + let field_type = &arg.1.field_type; + + let span = field_type.span().clone(); + if super::functions::has_checks_nested(ctx, field_type) { + ctx.push_error(DatamodelError::new_validation_error( + "Types with checks are not allowed as 
function parameters.", + span, + )); + } + } + + let prompt = match template.template_raw() { + Some(p) => p, + None => { + ctx.push_error(DatamodelError::new_validation_error( + "Template string must be a raw string literal like `template_string MyTemplate(myArg: string) #\"\n\n\"#`", + template.identifier().span().clone(), + )); + continue; + } + }; + + defined_types.start_scope(); + + template.walk_input_args().for_each(|arg| { + let name = match arg.ast_arg().0 { + Some(arg) => arg.name().to_string(), + None => { + ctx.push_error(DatamodelError::new_validation_error( + "Argument name is missing.", + arg.ast_arg().1.span().clone(), + )); + return; + } + }; + + let field_type = ctx.db.to_jinja_type(&arg.ast_arg().1.field_type); + + defined_types.add_variable(&name, field_type); + }); + match internal_baml_jinja_types::validate_template( + template.name(), + prompt.raw_value(), + &mut defined_types, + ) { + Ok(_) => {} + Err(e) => { + let pspan = prompt.span(); + if let Some(e) = e.parsing_errors { + let range = match e.range() { + Some(range) => range, + None => { + ctx.push_error(DatamodelError::new_validation_error( + &format!("Error parsing jinja template: {}", e), + pspan.clone(), + )); + continue; + } + }; + + let start_offset = pspan.start + range.start; + let end_offset = pspan.start + range.end; + + let span = Span::new( + pspan.file.clone(), + start_offset as usize, + end_offset as usize, + ); + + ctx.push_error(DatamodelError::new_validation_error( + &format!("Error parsing jinja template: {}", e), + span, + )) + } else { + e.errors.iter().for_each(|t| { + let span = t.span(); + let span = Span::new( + pspan.file.clone(), + pspan.start + span.start_offset as usize, + pspan.start + span.end_offset as usize, + ); + ctx.push_warning(DatamodelWarning::new(t.message().to_string(), span)) + }) + } + } + } + defined_types.end_scope(); + defined_types.errors_mut().clear(); + } +} diff --git 
a/engine/baml-lib/baml/tests/validation_files/functions_v2/prompt_errors/prompt1.baml b/engine/baml-lib/baml/tests/validation_files/functions_v2/prompt_errors/prompt1.baml index 76396e2ce..bf2e51b4a 100644 --- a/engine/baml-lib/baml/tests/validation_files/functions_v2/prompt_errors/prompt1.baml +++ b/engine/baml-lib/baml/tests/validation_files/functions_v2/prompt_errors/prompt1.baml @@ -50,3 +50,9 @@ function Bar1(a: string) -> int { // 23 | prompt #" // 24 | {{ Foo(a) }} // | +// warning: Variable `b` does not exist. Did you mean one of these: `_`, `ctx`? +// --> functions_v2/prompt_errors/prompt1.baml:6 +// | +// 5 | template_string Foo() #" +// 6 | This! {{ b}} +// | diff --git a/engine/baml-lib/baml/tests/validation_files/template_string/bad_calls.baml b/engine/baml-lib/baml/tests/validation_files/template_string/bad_calls.baml new file mode 100644 index 000000000..51fedd4ef --- /dev/null +++ b/engine/baml-lib/baml/tests/validation_files/template_string/bad_calls.baml @@ -0,0 +1,80 @@ +template_string WithParams(a: int) #" + ... +"# + +template_string BadCall1 #" + {{ WithParams(a=2, b=2) }} +"# + +template_string BadCall2 #" + {{ WithParams("a") }} +"# + +template_string BadCall3 #" + {{ WithParams() }} +"# + +template_string BadCall4 #" + {{ Random(2) }} +"# + +// warning: Function 'WithParams' expects 1 arguments, but got 2 +// --> template_string/bad_calls.baml:6 +// | +// 5 | template_string BadCall1 #" +// 6 | {{ WithParams(a=2, b=2) }} +// | +// warning: Function 'WithParams' expects argument 'a' to be of type int, but got literal["a"] +// --> template_string/bad_calls.baml:10 +// | +// 9 | template_string BadCall2 #" +// 10 | {{ WithParams("a") }} +// | +// warning: Function 'WithParams' expects 1 arguments, but got 0 +// --> template_string/bad_calls.baml:14 +// | +// 13 | template_string BadCall3 #" +// 14 | {{ WithParams() }} +// | +// warning: Variable `Random` does not exist. Did you mean one of these: `_`, `ctx`? 
+// --> template_string/bad_calls.baml:18 +// | +// 17 | template_string BadCall4 #" +// 18 | {{ Random(2) }} +// | +// warning: 'Random' is undefined, expected function +// --> template_string/bad_calls.baml:18 +// | +// 17 | template_string BadCall4 #" +// 18 | {{ Random(2) }} +// | +// warning: Function 'WithParams' expects 1 arguments, but got 2 +// --> template_string/bad_calls.baml:6 +// | +// 5 | template_string BadCall1 #" +// 6 | {{ WithParams(a=2, b=2) }} +// | +// warning: Function 'WithParams' expects argument 'a' to be of type int, but got literal["a"] +// --> template_string/bad_calls.baml:10 +// | +// 9 | template_string BadCall2 #" +// 10 | {{ WithParams("a") }} +// | +// warning: Function 'WithParams' expects 1 arguments, but got 0 +// --> template_string/bad_calls.baml:14 +// | +// 13 | template_string BadCall3 #" +// 14 | {{ WithParams() }} +// | +// warning: Variable `Random` does not exist. Did you mean one of these: `_`, `ctx`? +// --> template_string/bad_calls.baml:18 +// | +// 17 | template_string BadCall4 #" +// 18 | {{ Random(2) }} +// | +// warning: 'Random' is undefined, expected function +// --> template_string/bad_calls.baml:18 +// | +// 17 | template_string BadCall4 #" +// 18 | {{ Random(2) }} +// | diff --git a/engine/baml-lib/baml/tests/validation_files/template_string/good_calls.baml b/engine/baml-lib/baml/tests/validation_files/template_string/good_calls.baml new file mode 100644 index 000000000..97e7fbf9c --- /dev/null +++ b/engine/baml-lib/baml/tests/validation_files/template_string/good_calls.baml @@ -0,0 +1,11 @@ +template_string WithParams(a: int) #" + ... 
+"# + +template_string GoodCall1 #" + {{ WithParams(a=2) }} +"# + +template_string GoodCall2 #" + {{ WithParams(2) }} +"# diff --git a/engine/baml-lib/baml/tests/validation_files/template_string/invalid.baml b/engine/baml-lib/baml/tests/validation_files/template_string/invalid.baml new file mode 100644 index 000000000..242d4f266 --- /dev/null +++ b/engine/baml-lib/baml/tests/validation_files/template_string/invalid.baml @@ -0,0 +1,34 @@ +template_string FunctionWithBadParams( + param: Unknown, + param2: Unknown2[], + param3: string +) #" + {{ param.foo }} + {{ param2[0].doc }} + {{ param3 }} +"# + +// warning: 'param' is undefined, expected class +// --> template_string/invalid.baml:6 +// | +// 5 | ) #" +// 6 | {{ param.foo }} +// | +// warning: 'param' is undefined, expected class +// --> template_string/invalid.baml:6 +// | +// 5 | ) #" +// 6 | {{ param.foo }} +// | +// error: Type `Unknown` does not exist. Did you mean one of these: `int`, `float`, `bool`, `string`, `true`, `false`? +// --> template_string/invalid.baml:2 +// | +// 1 | template_string FunctionWithBadParams( +// 2 | param: Unknown, +// | +// error: Type `Unknown2` does not exist. Did you mean one of these: `string`, `int`, `float`, `bool`, `true`, `false`? +// --> template_string/invalid.baml:3 +// | +// 2 | param: Unknown, +// 3 | param2: Unknown2[], +// | diff --git a/engine/baml-lib/parser-database/src/walkers/class.rs b/engine/baml-lib/parser-database/src/walkers/class.rs index b3d7cc072..5dd2500b9 100644 --- a/engine/baml-lib/parser-database/src/walkers/class.rs +++ b/engine/baml-lib/parser-database/src/walkers/class.rs @@ -128,6 +128,8 @@ impl<'db> ClassWalker<'db> { } } } + +/// An argument of a function. 
pub type ArgWalker<'db> = super::Walker<'db, (ast::TypeExpId, bool, ArgumentId)>; impl<'db> ArgWalker<'db> { diff --git a/engine/baml-lib/parser-database/src/walkers/mod.rs b/engine/baml-lib/parser-database/src/walkers/mod.rs index b9f2143bd..9d1fd2a92 100644 --- a/engine/baml-lib/parser-database/src/walkers/mod.rs +++ b/engine/baml-lib/parser-database/src/walkers/mod.rs @@ -19,13 +19,12 @@ pub use client::*; pub use configuration::*; use either::Either; pub use field::*; -pub use function::*; +pub use function::{FunctionWalker, ClientSpec}; +pub use template_string::TemplateStringWalker; use internal_baml_schema_ast::ast::{FieldType, Identifier, TopId, TypeExpId, WithName}; pub use r#class::*; pub use r#enum::*; -pub use self::template_string::TemplateStringWalker; - /// A generic walker. Only walkers intantiated with a concrete ID type (`I`) are useful. #[derive(Clone, Copy)] pub struct Walker<'db, I> { diff --git a/engine/baml-lib/parser-database/src/walkers/template_string.rs b/engine/baml-lib/parser-database/src/walkers/template_string.rs index 6ac6e6d44..25e6aae5a 100644 --- a/engine/baml-lib/parser-database/src/walkers/template_string.rs +++ b/engine/baml-lib/parser-database/src/walkers/template_string.rs @@ -1,6 +1,6 @@ use either::Either; use internal_baml_jinja_types::{PredefinedTypes, Type}; -use internal_baml_schema_ast::ast::{self, BlockArgs, Span, WithIdentifier, WithName, WithSpan}; +use internal_baml_schema_ast::ast::{self, ArgumentId, BlockArgs, Span, WithIdentifier, WithName, WithSpan}; use crate::types::TemplateStringProperties; @@ -29,6 +29,23 @@ impl<'db> TemplateStringWalker<'db> { &self.metadata().template } + /// Walk the input arguments of the template string. 
+ pub fn walk_input_args(self) -> impl ExactSizeIterator<Item = ArgWalker<'db>> { + match self.ast_node().input() { + Some(input) => { + let range_end = input.iter_args().len() as u32; + (0..range_end) + .map(move |f| ArgWalker { + db: self.db, + id: (self.id, ArgumentId(f)), + }) + .collect::<Vec<_>>() + .into_iter() + } + None => Vec::new().into_iter(), + } + } + /// The name of the template string. pub fn add_to_types(self, types: &mut PredefinedTypes) { let name = self.name(); @@ -59,3 +76,35 @@ impl<'a> WithSpan for TemplateStringWalker<'a> { self.ast_node().span() } } + + +pub type ArgWalker<'db> = super::Walker<'db, (ast::TemplateStringId, ArgumentId)>; + +impl<'db> ArgWalker<'db> { + /// The ID of the template string that owns this argument. + pub fn block_id(self) -> ast::TemplateStringId { + self.id.0 + } + + /// The AST node of the template string that owns this argument. + pub fn ast_type_block(self) -> &'db ast::TemplateString { + &self.db.ast[self.id.0] + } + + /// This argument's AST node as `(name, arg)`; panics if the template string has no input args. + pub fn ast_arg(self) -> (Option<&'db ast::Identifier>, &'db ast::BlockArg) { + let args = self.ast_type_block().input(); + let res: &_ = &args.expect("Expected input args")[self.id.1]; + (Some(&res.0), &res.1) + } + + /// The declared type of this argument. + pub fn field_type(self) -> &'db ast::FieldType { + &self.ast_arg().1.field_type + } + + /// Whether this argument's declared type is optional. + pub fn is_optional(self) -> bool { + self.field_type().is_optional() + } +} \ No newline at end of file