Skip to content

Commit

Permalink
Improve static analysis on jinja prompts (#1102)
Browse files Browse the repository at this point in the history
* Invalid prompt syntax now produces a compiler error
* Unknown variables in check / assert expressions will raise a warning

<!-- ELLIPSIS_HIDDEN -->

----

> [!IMPORTANT]
> Enhance static analysis for Jinja templates by adding compiler errors
> for invalid syntax and warnings for unknown variables, updating
> validation functions and tests accordingly.
> 
>   - **Behavior**:
>     - Invalid Jinja prompt syntax now triggers a compiler error.
>     - Unknown variables in `check`/`assert` expressions raise a warning.
>   - **Validation**:
>     - Updated `validate()` functions in `classes.rs`, `enums.rs`, and
>       `functions.rs` to use `PredefinedTypes::default(JinjaContext::Prompt)`.
>     - Added error handling for parsing errors in `validate_template()` and
>       `validate_expression()` in `lib.rs`.
>   - **Types**:
>     - Introduced `JinjaContext` enum in `types.rs` to differentiate
>       between `Prompt` and `Parsing` contexts.
>     - Updated `PredefinedTypes` to handle different contexts.
>   - **Tests**:
>     - Added tests for malformed expressions and for expressions that parse
>       but reference unknown variables, in `malformed_expression.baml` and
>       `valid_but_invalid_expressions.baml`.
>     - Updated test cases in `test_expr.rs` and `test_stmt.rs` to reflect
>       new validation behavior.
> 
> <sup>This description was created by </sup>[<img alt="Ellipsis"
src="https://img.shields.io/badge/Ellipsis-blue?color=175173">](https://www.ellipsis.dev?ref=BoundaryML%2Fbaml&utm_source=github&utm_medium=referral)<sup>
for 6e59805. It will automatically
update as commits are pushed.</sup>

<!-- ELLIPSIS_HIDDEN -->
  • Loading branch information
hellovai authored Oct 25, 2024
1 parent aa736ed commit 7ca8136
Show file tree
Hide file tree
Showing 16 changed files with 328 additions and 79 deletions.
6 changes: 4 additions & 2 deletions engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -489,10 +489,12 @@ mod tests {
#[test]
fn distribute_media() {
let ir = mk_ir();
let v = BamlValue::Media(BamlMedia{
let v = BamlValue::Media(BamlMedia {
media_type: BamlMediaType::Audio,
mime_type: None,
content: BamlMediaContent::Base64(MediaBase64{base64: "abcd=".to_string()}),
content: BamlMediaContent::Base64(MediaBase64 {
base64: "abcd=".to_string(),
}),
});
let t = FieldType::Primitive(TypeValue::Media(BamlMediaType::Audio));
let _value_with_meta = ir.distribute_type(v, t).unwrap();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ use crate::validate::validation_pipeline::context::Context;
use internal_baml_diagnostics::DatamodelError;

pub(super) fn validate(ctx: &mut Context<'_>) {
let mut defined_types = internal_baml_jinja_types::PredefinedTypes::default();
let mut defined_types = internal_baml_jinja_types::PredefinedTypes::default(
internal_baml_jinja_types::JinjaContext::Prompt,
);

for cls in ctx.db.walk_classes() {
for c in cls.static_fields() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ use internal_baml_diagnostics::DatamodelError;
use internal_baml_schema_ast::ast::{WithName, WithSpan};

pub(super) fn validate(ctx: &mut Context<'_>) {
let mut defined_types = internal_baml_jinja_types::PredefinedTypes::default();
let mut defined_types = internal_baml_jinja_types::PredefinedTypes::default(
internal_baml_jinja_types::JinjaContext::Prompt,
);
for enm in ctx.db.walk_enums() {
for args in enm.walk_input_args() {
let arg = args.ast_arg();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{validate::validation_pipeline::context::Context};
use crate::validate::validation_pipeline::context::Context;

use either::Either;
use internal_baml_diagnostics::{DatamodelError, DatamodelWarning, Span};
Expand All @@ -14,7 +14,9 @@ pub(super) fn validate(ctx: &mut Context<'_>) {
.map(|c| c.name().to_string())
.collect::<Vec<_>>();

let mut defined_types = internal_baml_jinja_types::PredefinedTypes::default();
let mut defined_types = internal_baml_jinja_types::PredefinedTypes::default(
internal_baml_jinja_types::JinjaContext::Prompt,
);
ctx.db.walk_classes().for_each(|t| {
t.add_to_types(&mut defined_types);
});
Expand Down Expand Up @@ -49,11 +51,31 @@ pub(super) fn validate(ctx: &mut Context<'_>) {
Ok(_) => {}
Err(e) => {
let pspan = prompt.span();
if let Some(_e) = e.parsing_errors {
// ctx.push_error(DatamodelError::new_validation_error(
// &format!("Error parsing jinja template: {}", e),
// e.line(),
// ))
if let Some(e) = e.parsing_errors {
let range = match e.range() {
Some(range) => range,
None => {
ctx.push_error(DatamodelError::new_validation_error(
&format!("Error parsing jinja template: {}", e),
pspan.clone(),
));
continue;
}
};

let start_offset = pspan.start + range.start;
let end_offset = pspan.start + range.end;

let span = Span::new(
pspan.file.clone(),
start_offset as usize,
end_offset as usize,
);

ctx.push_error(DatamodelError::new_validation_error(
&format!("Error parsing jinja template: {}", e),
span,
))
} else {
e.errors.iter().for_each(|t| {
let span = t.span();
Expand Down Expand Up @@ -83,9 +105,11 @@ pub(super) fn validate(ctx: &mut Context<'_>) {

let span = field_type.span().clone();
if has_checks_nested(ctx, field_type) {
ctx.push_error(DatamodelError::new_validation_error("Types with checks are not allowed as function parameters.", span));
ctx.push_error(DatamodelError::new_validation_error(
"Types with checks are not allowed as function parameters.",
span,
));
}

}

// Ensure the client is correct.
Expand Down Expand Up @@ -149,10 +173,30 @@ pub(super) fn validate(ctx: &mut Context<'_>) {
Err(e) => {
let pspan = prompt.span();
if let Some(e) = e.parsing_errors {
// ctx.push_error(DatamodelError::new_validation_error(
// &format!("Error parsing jinja template: {}", e),
// // e.,
// ))
let range = match e.range() {
Some(range) => range,
None => {
ctx.push_error(DatamodelError::new_validation_error(
&format!("Error parsing jinja template: {}", e),
pspan.clone(),
));
continue;
}
};

let start_offset = pspan.start + range.start;
let end_offset = pspan.start + range.end;

let span = Span::new(
pspan.file.clone(),
start_offset as usize,
end_offset as usize,
);

ctx.push_error(DatamodelError::new_validation_error(
&format!("Error parsing jinja template: {}", e),
span,
))
} else {
e.errors.iter().for_each(|t| {
let span = t.span();
Expand All @@ -179,23 +223,27 @@ fn has_checks_nested(ctx: &Context<'_>, field_type: &FieldType) -> bool {
}

match field_type {
FieldType::Symbol(_, id, ..) => {
match ctx.db.find_type(id) {
Some(Either::Left(class_walker)) => {
let mut fields = class_walker.static_fields();
fields.any(|field| field.ast_field().expr.as_ref().map_or(false, |ft| has_checks_nested(ctx, &ft)))
}
,
_ => false,
FieldType::Symbol(_, id, ..) => match ctx.db.find_type(id) {
Some(Either::Left(class_walker)) => {
let mut fields = class_walker.static_fields();
fields.any(|field| {
field
.ast_field()
.expr
.as_ref()
.map_or(false, |ft| has_checks_nested(ctx, &ft))
})
}
_ => false,
},

FieldType::Primitive(..) => false,
FieldType::Union(_, children, ..) => children.iter().any(|ft| has_checks_nested(ctx, ft)),
FieldType::Literal(..) => false,
FieldType::Tuple(_, children, ..) => children.iter().any(|ft| has_checks_nested(ctx, ft)),
FieldType::List(_, child, ..) => has_checks_nested(ctx, child),
FieldType::Map(_, kv, ..) =>
has_checks_nested(ctx, &kv.as_ref().0) || has_checks_nested(ctx, &kv.as_ref().1),
FieldType::Map(_, kv, ..) => {
has_checks_nested(ctx, &kv.as_ref().0) || has_checks_nested(ctx, &kv.as_ref().1)
}
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use baml_types::TypeValue;
use internal_baml_diagnostics::DatamodelError;
use internal_baml_schema_ast::ast::{Argument, Attribute, Expression, FieldArity, FieldType, Identifier, WithName, WithSpan};
use internal_baml_diagnostics::{DatamodelError, DatamodelWarning, Span};
use internal_baml_schema_ast::ast::{
Argument, Attribute, Expression, FieldArity, FieldType, Identifier, WithName, WithSpan,
};

use crate::validate::validation_pipeline::context::Context;

Expand Down Expand Up @@ -89,28 +91,139 @@ fn validate_type_allowed(ctx: &mut Context<'_>, field_type: &FieldType) {
}

fn validate_type_constraints(ctx: &mut Context<'_>, field_type: &FieldType) {
let constraint_attrs = field_type.attributes().iter().filter(|attr| ["assert", "check"].contains(&attr.name.name())).collect::<Vec<_>>();
for Attribute { arguments, span, name, .. } in constraint_attrs.iter() {
let arg_expressions = arguments.arguments.iter().map(|Argument{value,..}| value).collect::<Vec<_>>();

match arg_expressions.as_slice() {
[ Expression::Identifier(Identifier::Local(s,_)), Expression::JinjaExpressionValue(_, _)] => {
// Ok.
},
[Expression::JinjaExpressionValue(_, _)] => {
if name.to_string() == "check" {
ctx.push_error(DatamodelError::new_validation_error(
"Check constraints must have a name.",
span.clone()
))
let constraint_attrs = field_type
.attributes()
.iter()
.filter(|attr| ["assert", "check"].contains(&attr.name.name()))
.collect::<Vec<_>>();
for Attribute {
arguments,
span,
name,
..
} in constraint_attrs.iter()
{
let arg_expressions = arguments
.arguments
.iter()
.map(|Argument { value, .. }| value)
.collect::<Vec<_>>();

match arg_expressions.as_slice() {
[Expression::Identifier(Identifier::Local(s, _)), Expression::JinjaExpressionValue(expr, span)] =>
{
let mut defined_types = internal_baml_jinja_types::PredefinedTypes::default(
internal_baml_jinja_types::JinjaContext::Parsing,
);
defined_types.add_variable("this", internal_baml_jinja_types::Type::Unknown);
match internal_baml_jinja_types::validate_expression(&expr.0, &mut defined_types) {
Ok(_) => {}
Err(e) => {
if let Some(e) = e.parsing_errors {
let range = match e.range() {
Some(range) => range,
None => {
ctx.push_error(DatamodelError::new_validation_error(
&format!("Error parsing jinja template: {}", e),
span.clone(),
));
continue;
}
};

let start_offset = span.start + range.start;
let end_offset = span.start + range.end;

let span = Span::new(
span.file.clone(),
start_offset as usize,
end_offset as usize,
);

ctx.push_error(DatamodelError::new_validation_error(
&format!("Error parsing jinja template: {}", e),
span,
))
} else {
e.errors.iter().for_each(|t| {
let tspan = t.span();
let span = Span::new(
span.file.clone(),
span.start + tspan.start_offset as usize,
span.start + tspan.end_offset as usize,
);
ctx.push_warning(DatamodelWarning::new(
t.message().to_string(),
span,
))
})
}
}
},
_ => {
}
}
[Expression::JinjaExpressionValue(expr, span)] => {
let mut defined_types = internal_baml_jinja_types::PredefinedTypes::default(
internal_baml_jinja_types::JinjaContext::Parsing,
);
defined_types.add_variable("this", internal_baml_jinja_types::Type::Unknown);
match internal_baml_jinja_types::validate_expression(&expr.0, &mut defined_types) {
Ok(_) => {}
Err(e) => {
if let Some(e) = e.parsing_errors {
let range = match e.range() {
Some(range) => range,
None => {
ctx.push_error(DatamodelError::new_validation_error(
&format!("Error parsing jinja template: {}", e),
span.clone(),
));
continue;
}
};

let start_offset = span.start + range.start;
let end_offset = span.start + range.end;

let span = Span::new(
span.file.clone(),
start_offset as usize,
end_offset as usize,
);

ctx.push_error(DatamodelError::new_validation_error(
&format!("Error parsing jinja template: {}", e),
span,
))
} else {
e.errors.iter().for_each(|t| {
let tspan = t.span();
let span = Span::new(
span.file.clone(),
span.start + tspan.start_offset as usize,
span.start + tspan.end_offset as usize,
);
ctx.push_warning(DatamodelWarning::new(
t.message().to_string(),
span,
))
})
}
}
}

if name.to_string() == "check" {
ctx.push_error(DatamodelError::new_validation_error(
"Check constraints must have a name.",
span.clone(),
))
}
}
_ => {
ctx.push_error(DatamodelError::new_validation_error(
"A constraint must have one Jinja argument such as {{ expr }}, and optionally one String label",
span.clone()
));
}
}
}
}
}
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,3 +1,42 @@
class Foo {
bar string @check(bar_check, {{ ) }})
}


function FunctionName(arg:string) -> "foo" {
client "openai/gpt-4o"
prompt #"
Your prompt here in jinja format
{{ ) }}
"#
}


function FunctionName2(arg:string) -> "foo" {
client "openai/gpt-4o"
prompt #"
Your prompt here in jinja format
{{ if foo }}
{{ foo }}
{{ endif }}
"#
}

// error: Error validating: Error parsing jinja template: syntax error: unexpected `)` (in <expression>:1)
// --> constraints/malformed_expression.baml:2
// |
// 1 | class Foo {
// 2 | bar string @check(bar_check, {{ ) }})
// |
// error: Error validating: Error parsing jinja template: syntax error: unexpected `)` (in FunctionName:3)
// --> constraints/malformed_expression.baml:10
// |
// 9 | Your prompt here in jinja format
// 10 | {{ ) }}
// |
// error: Error validating: Error parsing jinja template: syntax error: unexpected identifier, expected end of variable block (in FunctionName2:3)
// --> constraints/malformed_expression.baml:19
// |
// 18 | Your prompt here in jinja format
// 19 | {{ if foo }}
// |
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
class Foo {
bar string @check(bar_check, {{ bar }})
}

// warning: Variable `bar` does not exist. Did you mean `this`?
// --> constraints/valid_but_invalid_expressions.baml:2
// |
// 1 | class Foo {
// 2 | bar string @check(bar_check, {{ bar }})
// |
Loading

0 comments on commit 7ca8136

Please sign in to comment.