Skip to content

Commit

Permalink
Add ability to validate types for template strings (#1161)
Browse files Browse the repository at this point in the history
<!-- ELLIPSIS_HIDDEN -->



> [!IMPORTANT]
> Add type validation for template strings in BAML engine with new
validation logic and test cases.
> 
>   - **Behavior**:
> - Add `template_strings` module to `validations.rs` and integrate it
into the validation pipeline.
> - Implement `validate()` in `template_strings.rs` to check template
string types and handle errors.
> - Update `functions.rs` to make `has_checks_nested()` public for use
in template validation.
>   - **Walkers**:
> - Add `walk_input_args()` to `TemplateStringWalker` in
`template_string.rs` to iterate over template arguments.
>     - Add `ArgWalker` type to handle template string arguments.
>   - **Tests**:
> - Add test cases in `bad_calls.baml`, `good_calls.baml`, and
`invalid.baml` to verify template string validation.
>   - **Misc**:
>     - Minor documentation updates in `class.rs` and `mod.rs`.
> 
> <sup>This description was created by </sup>[<img alt="Ellipsis"
src="https://img.shields.io/badge/Ellipsis-blue?color=175173">](https://www.ellipsis.dev?ref=BoundaryML%2Fbaml&utm_source=github&utm_medium=referral)<sup>
for 0881047. It will automatically
update as commits are pushed.</sup>

<!-- ELLIPSIS_HIDDEN -->
  • Loading branch information
hellovai authored Nov 11, 2024
1 parent f103d52 commit a578cc2
Show file tree
Hide file tree
Showing 10 changed files with 308 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ mod configurations;
mod cycle;
mod enums;
mod functions;
mod template_strings;
mod types;

use super::context::Context;
Expand All @@ -13,6 +14,7 @@ pub(super) fn validate(ctx: &mut Context<'_>) {
classes::validate(ctx);
functions::validate(ctx);
clients::validate(ctx);
template_strings::validate(ctx);
configurations::validate(ctx);

if !ctx.diagnostics.has_errors() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ pub(super) fn validate(ctx: &mut Context<'_>) {
/// Just syntactic sugar for the recursive check.
///
/// See [`NestedChecks::has_checks_nested`].
fn has_checks_nested(ctx: &Context<'_>, field_type: &FieldType) -> bool {
pub(super) fn has_checks_nested(ctx: &Context<'_>, field_type: &FieldType) -> bool {
NestedChecks::new(ctx).has_checks_nested(field_type)
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
use std::collections::HashSet;

use crate::validate::validation_pipeline::context::Context;

use either::Either;
use internal_baml_diagnostics::{DatamodelError, DatamodelWarning, Span};

use internal_baml_schema_ast::ast::{FieldType, TypeExpId, WithIdentifier, WithName, WithSpan};

use super::types::validate_type;

pub(super) fn validate(ctx: &mut Context<'_>) {
let mut defined_types = internal_baml_jinja_types::PredefinedTypes::default(
internal_baml_jinja_types::JinjaContext::Prompt,
);
ctx.db.walk_classes().for_each(|t| {
t.add_to_types(&mut defined_types);
});
ctx.db.walk_templates().for_each(|t| {
t.add_to_types(&mut defined_types);
});

for template in ctx.db.walk_templates() {
for args in template.walk_input_args() {
let arg = args.ast_arg();
validate_type(ctx, &arg.1.field_type);
}

for args in template.walk_input_args() {
let arg = args.ast_arg();
let field_type = &arg.1.field_type;

let span = field_type.span().clone();
if super::functions::has_checks_nested(ctx, field_type) {
ctx.push_error(DatamodelError::new_validation_error(
"Types with checks are not allowed as function parameters.",
span,
));
}
}

let prompt = match template.template_raw() {
Some(p) => p,
None => {
ctx.push_error(DatamodelError::new_validation_error(
"Template string must be a raw string literal like `template_string MyTemplate(myArg: string) #\"\n\n\"#`",
template.identifier().span().clone(),
));
continue;
}
};

defined_types.start_scope();

template.walk_input_args().for_each(|arg| {
let name = match arg.ast_arg().0 {
Some(arg) => arg.name().to_string(),
None => {
ctx.push_error(DatamodelError::new_validation_error(
"Argument name is missing.",
arg.ast_arg().1.span().clone(),
));
return;
}
};

let field_type = ctx.db.to_jinja_type(&arg.ast_arg().1.field_type);

defined_types.add_variable(&name, field_type);
});
match internal_baml_jinja_types::validate_template(
template.name(),
prompt.raw_value(),
&mut defined_types,
) {
Ok(_) => {}
Err(e) => {
let pspan = prompt.span();
if let Some(e) = e.parsing_errors {
let range = match e.range() {
Some(range) => range,
None => {
ctx.push_error(DatamodelError::new_validation_error(
&format!("Error parsing jinja template: {}", e),
pspan.clone(),
));
continue;
}
};

let start_offset = pspan.start + range.start;
let end_offset = pspan.start + range.end;

let span = Span::new(
pspan.file.clone(),
start_offset as usize,
end_offset as usize,
);

ctx.push_error(DatamodelError::new_validation_error(
&format!("Error parsing jinja template: {}", e),
span,
))
} else {
e.errors.iter().for_each(|t| {
let span = t.span();
let span = Span::new(
pspan.file.clone(),
pspan.start + span.start_offset as usize,
pspan.start + span.end_offset as usize,
);
ctx.push_warning(DatamodelWarning::new(t.message().to_string(), span))
})
}
}
}
defined_types.end_scope();
defined_types.errors_mut().clear();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,9 @@ function Bar1(a: string) -> int {
// 23 | prompt #"
// 24 | {{ Foo(a) }}
// |
// warning: Variable `b` does not exist. Did you mean one of these: `_`, `ctx`?
// --> functions_v2/prompt_errors/prompt1.baml:6
// |
// 5 | template_string Foo() #"
// 6 | This! {{ b}}
// |
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
template_string WithParams(a: int) #"
...
"#

template_string BadCall1 #"
{{ WithParams(a=2, b=2) }}
"#

template_string BadCall2 #"
{{ WithParams("a") }}
"#

template_string BadCall3 #"
{{ WithParams() }}
"#

template_string BadCall4 #"
{{ Random(2) }}
"#

// warning: Function 'WithParams' expects 1 arguments, but got 2
// --> template_string/bad_calls.baml:6
// |
// 5 | template_string BadCall1 #"
// 6 | {{ WithParams(a=2, b=2) }}
// |
// warning: Function 'WithParams' expects argument 'a' to be of type int, but got literal["a"]
// --> template_string/bad_calls.baml:10
// |
// 9 | template_string BadCall2 #"
// 10 | {{ WithParams("a") }}
// |
// warning: Function 'WithParams' expects 1 arguments, but got 0
// --> template_string/bad_calls.baml:14
// |
// 13 | template_string BadCall3 #"
// 14 | {{ WithParams() }}
// |
// warning: Variable `Random` does not exist. Did you mean one of these: `_`, `ctx`?
// --> template_string/bad_calls.baml:18
// |
// 17 | template_string BadCall4 #"
// 18 | {{ Random(2) }}
// |
// warning: 'Random' is undefined, expected function
// --> template_string/bad_calls.baml:18
// |
// 17 | template_string BadCall4 #"
// 18 | {{ Random(2) }}
// |
// warning: Function 'WithParams' expects 1 arguments, but got 2
// --> template_string/bad_calls.baml:6
// |
// 5 | template_string BadCall1 #"
// 6 | {{ WithParams(a=2, b=2) }}
// |
// warning: Function 'WithParams' expects argument 'a' to be of type int, but got literal["a"]
// --> template_string/bad_calls.baml:10
// |
// 9 | template_string BadCall2 #"
// 10 | {{ WithParams("a") }}
// |
// warning: Function 'WithParams' expects 1 arguments, but got 0
// --> template_string/bad_calls.baml:14
// |
// 13 | template_string BadCall3 #"
// 14 | {{ WithParams() }}
// |
// warning: Variable `Random` does not exist. Did you mean one of these: `_`, `ctx`?
// --> template_string/bad_calls.baml:18
// |
// 17 | template_string BadCall4 #"
// 18 | {{ Random(2) }}
// |
// warning: 'Random' is undefined, expected function
// --> template_string/bad_calls.baml:18
// |
// 17 | template_string BadCall4 #"
// 18 | {{ Random(2) }}
// |
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
template_string WithParams(a: int) #"
...
"#

template_string GoodCall1 #"
{{ WithParams(a=2) }}
"#

template_string GoodCall2 #"
{{ WithParams(2) }}
"#
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
template_string FunctionWithBadParams(
param: Unknown,
param2: Unknown2[],
param3: string
) #"
{{ param.foo }}
{{ param2[0].doc }}
{{ param3 }}
"#

// warning: 'param' is undefined, expected class
// --> template_string/invalid.baml:6
// |
// 5 | ) #"
// 6 | {{ param.foo }}
// |
// warning: 'param' is undefined, expected class
// --> template_string/invalid.baml:6
// |
// 5 | ) #"
// 6 | {{ param.foo }}
// |
// error: Type `Unknown` does not exist. Did you mean one of these: `int`, `float`, `bool`, `string`, `true`, `false`?
// --> template_string/invalid.baml:2
// |
// 1 | template_string FunctionWithBadParams(
// 2 | param: Unknown,
// |
// error: Type `Unknown2` does not exist. Did you mean one of these: `string`, `int`, `float`, `bool`, `true`, `false`?
// --> template_string/invalid.baml:3
// |
// 2 | param: Unknown,
// 3 | param2: Unknown2[],
// |
2 changes: 2 additions & 0 deletions engine/baml-lib/parser-database/src/walkers/class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ impl<'db> ClassWalker<'db> {
}
}
}

/// An argument of a function.
pub type ArgWalker<'db> = super::Walker<'db, (ast::TypeExpId, bool, ArgumentId)>;

impl<'db> ArgWalker<'db> {
Expand Down
5 changes: 2 additions & 3 deletions engine/baml-lib/parser-database/src/walkers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,12 @@ pub use client::*;
pub use configuration::*;
use either::Either;
pub use field::*;
pub use function::*;
pub use function::{FunctionWalker, ClientSpec};
pub use template_string::TemplateStringWalker;
use internal_baml_schema_ast::ast::{FieldType, Identifier, TopId, TypeExpId, WithName};
pub use r#class::*;
pub use r#enum::*;

pub use self::template_string::TemplateStringWalker;

/// A generic walker. Only walkers intantiated with a concrete ID type (`I`) are useful.
#[derive(Clone, Copy)]
pub struct Walker<'db, I> {
Expand Down
51 changes: 50 additions & 1 deletion engine/baml-lib/parser-database/src/walkers/template_string.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use either::Either;
use internal_baml_jinja_types::{PredefinedTypes, Type};
use internal_baml_schema_ast::ast::{self, BlockArgs, Span, WithIdentifier, WithName, WithSpan};
use internal_baml_schema_ast::ast::{self, ArgumentId, BlockArgs, Span, WithIdentifier, WithName, WithSpan};

use crate::types::TemplateStringProperties;

Expand Down Expand Up @@ -29,6 +29,23 @@ impl<'db> TemplateStringWalker<'db> {
&self.metadata().template
}

/// Walk the input arguments of the template string.
pub fn walk_input_args(self) -> impl ExactSizeIterator<Item = ArgWalker<'db>> {
match self.ast_node().input() {
Some(input) => {
let range_end = input.iter_args().len() as u32;
(0..range_end)
.map(move |f| ArgWalker {
db: self.db,
id: (self.id, ArgumentId(f)),
})
.collect::<Vec<_>>()
.into_iter()
}
None => Vec::new().into_iter(),
}
}

/// The name of the template string.
pub fn add_to_types(self, types: &mut PredefinedTypes) {
let name = self.name();
Expand Down Expand Up @@ -59,3 +76,35 @@ impl<'a> WithSpan for TemplateStringWalker<'a> {
self.ast_node().span()
}
}


pub type ArgWalker<'db> = super::Walker<'db, (ast::TemplateStringId, ArgumentId)>;

impl<'db> ArgWalker<'db> {
/// The ID of the function in the db
pub fn block_id(self) -> ast::TemplateStringId {
self.id.0
}

/// The AST node.
pub fn ast_type_block(self) -> &'db ast::TemplateString {
&self.db.ast[self.id.0]
}

/// The AST node.
pub fn ast_arg(self) -> (Option<&'db ast::Identifier>, &'db ast::BlockArg) {
let args = self.ast_type_block().input();
let res: &_ = &args.expect("Expected input args")[self.id.1];
(Some(&res.0), &res.1)
}

/// The name of the type.
pub fn field_type(self) -> &'db ast::FieldType {
&self.ast_arg().1.field_type
}

/// The name of the function.
pub fn is_optional(self) -> bool {
self.field_type().is_optional()
}
}

0 comments on commit a578cc2

Please sign in to comment.