Skip to content

Commit

Permalink
Add constraints to test blocks.
Browse files Browse the repository at this point in the history
  • Loading branch information
imalsogreg committed Nov 23, 2024
1 parent 93b393d commit cf7b582
Show file tree
Hide file tree
Showing 28 changed files with 1,035 additions and 95 deletions.
2 changes: 2 additions & 0 deletions engine/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

24 changes: 11 additions & 13 deletions engine/baml-lib/baml-core/src/ir/jinja_helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ fn sum_filter(value: Vec<Value>) -> Value {
/// E.g. `"a|length > 2"` with context `{"a": [1, 2, 3]}` will return `"true"`.
pub fn render_expression(
expression: &JinjaExpression,
ctx: &HashMap<String, BamlValue>,
ctx: &HashMap<String, minijinja::Value>,
) -> anyhow::Result<String> {
let env = get_env();
// In rust string literals, `{` is escaped as `{{`.
Expand All @@ -66,8 +66,8 @@ pub fn evaluate_predicate(
this: &BamlValue,
predicate_expression: &JinjaExpression,
) -> Result<bool, anyhow::Error> {
let ctx: HashMap<String, BamlValue> =
[("this".to_string(), this.clone())].into_iter().collect();
let ctx: HashMap<String, minijinja::Value> =
HashMap::from([("this".to_string(), minijinja::Value::from_serialize(this))]);
match render_expression(&predicate_expression, &ctx)?.as_ref() {
"true" => Ok(true),
"false" => Ok(false),
Expand All @@ -87,11 +87,12 @@ mod tests {
"a".to_string(),
BamlValue::List(
vec![BamlValue::Int(1), BamlValue::Int(2), BamlValue::Int(3)].into(),
),
)
.into(),
),
(
"b".to_string(),
BamlValue::String("(123)456-7890".to_string()),
BamlValue::String("(123)456-7890".to_string()).into(),
),
]
.into_iter()
Expand All @@ -118,11 +119,12 @@ mod tests {
"a".to_string(),
BamlValue::List(
vec![BamlValue::Int(1), BamlValue::Int(2), BamlValue::Int(3)].into(),
),
)
.into(),
),
(
"b".to_string(),
BamlValue::String("(123)456-7890".to_string()),
BamlValue::String("(123)456-7890".to_string()).into(),
),
]
.into_iter()
Expand Down Expand Up @@ -151,16 +153,12 @@ mod tests {
fn test_sum_filter() {
let ctx = vec![].into_iter().collect();
assert_eq!(
render_expression(&JinjaExpression(
r#"[1,2]|sum"#.to_string()
), &ctx).unwrap(),
render_expression(&JinjaExpression(r#"[1,2]|sum"#.to_string()), &ctx).unwrap(),
"3"
);

assert_eq!(
render_expression(&JinjaExpression(
r#"[1,2.5]|sum"#.to_string()
), &ctx).unwrap(),
render_expression(&JinjaExpression(r#"[1,2.5]|sum"#.to_string()), &ctx).unwrap(),
"3.5"
);
}
Expand Down
89 changes: 77 additions & 12 deletions engine/baml-lib/baml-core/src/ir/repr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@ use internal_baml_parser_database::{
walkers::{
ClassWalker, ClientSpec as AstClientSpec, ClientWalker, ConfigurationWalker,
EnumValueWalker, EnumWalker, FieldWalker, FunctionWalker, TemplateStringWalker,
Walker as AstWalker,
},
Attributes, ParserDatabase, PromptAst, RetryPolicyStrategy,
};
use internal_baml_schema_ast::ast::SubType;
use internal_baml_schema_ast::ast::{SubType, ValExpId};

use baml_types::JinjaExpression;
use internal_baml_schema_ast::ast::{self, FieldArity, WithName, WithSpan};
Expand Down Expand Up @@ -676,8 +677,14 @@ impl WithRepr<Enum> for EnumWalker<'_> {
fn repr(&self, db: &ParserDatabase) -> Result<Enum> {
Ok(Enum {
name: self.name().to_string(),
values: self.values().map(|w| (w.node(db).map(|v| (v, w.documentation().map(|s| Docstring(s.to_string())))))).collect::<Result<Vec<_>,_>>()?,
docstring: self.get_documentation().map(|s| Docstring(s))
values: self
.values()
.map(|w| {
w.node(db)
.map(|v| (v, w.documentation().map(|s| Docstring(s.to_string()))))
})
.collect::<Result<Vec<_>, _>>()?,
docstring: self.get_documentation().map(|s| Docstring(s)),
})
}
}
Expand Down Expand Up @@ -722,7 +729,6 @@ impl WithRepr<Field> for FieldWalker<'_> {
docstring: self.get_documentation().map(|s| Docstring(s)),
})
}

}

type ClassId = String;
Expand Down Expand Up @@ -774,7 +780,7 @@ impl WithRepr<Class> for ClassWalker<'_> {
.collect::<Result<Vec<_>>>()?,
None => Vec::new(),
},
docstring: self.get_documentation().map(|s| Docstring(s))
docstring: self.get_documentation().map(|s| Docstring(s)),
})
}
}
Expand Down Expand Up @@ -1110,14 +1116,23 @@ pub struct TestCase {
pub name: String,
pub functions: Vec<Node<TestCaseFunction>>,
pub args: IndexMap<String, Expression>,
pub constraints: Vec<Constraint>,
}

impl WithRepr<TestCaseFunction> for (&ConfigurationWalker<'_>, usize) {
fn attributes(&self, _db: &ParserDatabase) -> NodeAttributes {
let span = self.0.test_case().functions[self.1].1.clone();
let constraints = self
.0
.test_case()
.constraints
.iter()
.map(|(c, _, _)| c)
.cloned()
.collect();
NodeAttributes {
meta: IndexMap::new(),
constraints: Vec::new(),
constraints,
span: Some(span),
}
}
Expand All @@ -1131,10 +1146,17 @@ impl WithRepr<TestCaseFunction> for (&ConfigurationWalker<'_>, usize) {

impl WithRepr<TestCase> for ConfigurationWalker<'_> {
fn attributes(&self, _db: &ParserDatabase) -> NodeAttributes {
let constraints = self
.test_case()
.constraints
.iter()
.map(|(c, _, _)| c)
.cloned()
.collect();
NodeAttributes {
meta: IndexMap::new(),
span: Some(self.span().clone()),
constraints: Vec::new(),
constraints,
}
}

Expand All @@ -1151,6 +1173,12 @@ impl WithRepr<TestCase> for ConfigurationWalker<'_> {
.map(|(k, (_, v))| Ok((k.clone(), v.repr(db)?)))
.collect::<Result<IndexMap<_, _>>>()?,
functions,
constraints: <AstWalker<'_, (ValExpId, &str)> as WithRepr<TestCase>>::attributes(
self, db,
)
.constraints
.into_iter()
.collect::<Vec<_>>(),
})
}
}
Expand Down Expand Up @@ -1223,7 +1251,8 @@ mod tests {

#[test]
fn test_docstrings() {
let ir = make_test_ir(r#"
let ir = make_test_ir(
r#"
/// Foo class.
class Foo {
/// Bar field.
Expand All @@ -1243,7 +1272,9 @@ mod tests {
THIRD
}
"#).unwrap();
"#,
)
.unwrap();

// Test class docstrings
let foo = ir.find_class("Foo").as_ref().unwrap().clone().elem();
Expand All @@ -1252,15 +1283,18 @@ mod tests {
[field1, field2] => {
assert_eq!(field1.elem.docstring.as_ref().unwrap().0, "Bar field.");
assert_eq!(field2.elem.docstring.as_ref().unwrap().0, "Baz field.");
},
}
_ => {
panic!("Expected 2 fields");
}
}

// Test enum docstrings
let test_enum = ir.find_enum("TestEnum").as_ref().unwrap().clone().elem();
assert_eq!(test_enum.docstring.as_ref().unwrap().0.as_str(), "Test enum.");
assert_eq!(
test_enum.docstring.as_ref().unwrap().0.as_str(),
"Test enum."
);
match test_enum.values.as_slice() {
[val1, val2, val3] => {
assert_eq!(val1.0.elem.0, "FIRST");
Expand All @@ -1269,10 +1303,41 @@ mod tests {
assert_eq!(val2.1.as_ref().unwrap().0, "Second variant.");
assert_eq!(val3.0.elem.0, "THIRD");
assert!(val3.1.is_none());
},
}
_ => {
panic!("Expected 3 enum values");
}
}
}

#[test]
fn test_block_attributes() {
let ir = make_test_ir(
r##"
client<llm> GPT4 {
provider openai
options {
model gpt-4o
api_key env.OPENAI_API_KEY
}
}
function Foo(a: int) -> int {
client GPT4
prompt #"Double the number {{ a }}"#
}
test Foo() {
functions [Foo]
args {
a 10
}
@@assert( {{ result == 20 }} )
}
"##,
)
.unwrap();
let function = ir.find_function("Foo").unwrap();
let walker = ir.find_test(&function, "Foo").unwrap();
assert_eq!(walker.item.1.elem.constraints.len(), 1);
}
}
4 changes: 2 additions & 2 deletions engine/baml-lib/baml-core/src/ir/walker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,9 +260,9 @@ impl Expression {
}
Expression::JinjaExpression(expr) => {
// TODO: do not coerce all context values to strings.
let jinja_context: HashMap<String, BamlValue> = env_values
let jinja_context: HashMap<String, minijinja::Value> = env_values
.iter()
.map(|(k, v)| (k.clone(), BamlValue::String(v.clone())))
.map(|(k, v)| (k.clone(), v.clone().into()))
.collect();
let res_string = render_expression(&expr, &jinja_context)?;
Ok(BamlValue::String(res_string))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ mod cycle;
mod enums;
mod functions;
mod template_strings;
mod tests;
mod types;

use baml_types::GeneratorOutputType;
Expand All @@ -22,6 +23,7 @@ pub(super) fn validate(ctx: &mut Context<'_>) {
clients::validate(ctx);
template_strings::validate(ctx);
configurations::validate(ctx);
tests::validate(ctx);

let generators = load_generators_from_ast(ctx.db.ast(), ctx.diagnostics);
let codegen_targets: HashSet<GeneratorOutputType> = generators.into_iter().filter_map(|generator| match generator {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
use baml_types::{Constraint, ConstraintLevel};
use internal_baml_diagnostics::{DatamodelError, DatamodelWarning, Span};
use internal_baml_jinja_types::{validate_expression, JinjaContext, PredefinedTypes, Type};

use crate::validate::validation_pipeline::context::Context;

pub(super) fn validate(ctx: &mut Context<'_>) {
let tests = ctx.db.walk_test_cases().collect::<Vec<_>>();
tests.iter().for_each(|walker| {
let constraints = &walker.test_case().constraints;
let args = &walker.test_case().args;
let mut check_names: Vec<String> = Vec::new();
for (
Constraint {
label,
level,
expression,
},
constraint_span,
expr_span,
) in constraints.iter()
{
let mut defined_types = PredefinedTypes::default(JinjaContext::Parsing);
defined_types.add_variable("this", Type::Unknown);
defined_types.add_class(
"Checks",
check_names
.iter()
.map(|check_name| (check_name.clone(), Type::Unknown))
.collect(),
);
defined_types.add_class(
"_",
vec![
("checks".to_string(), Type::ClassRef("Checks".to_string())),
("result".to_string(), Type::Unknown),
("latency_ms".to_string(), Type::Number),
]
.into_iter()
.collect(),
);
defined_types.add_variable("_", Type::ClassRef("_".to_string()));
args.keys()
.for_each(|arg_name| defined_types.add_variable(arg_name, Type::Unknown));
match (level, label) {
(ConstraintLevel::Check, Some(check_name)) => {
check_names.push(check_name.to_string());
}
_ => {}
}
match validate_expression(expression.0.as_str(), &mut defined_types) {
Ok(_) => {}
Err(e) => {
if let Some(e) = e.parsing_errors {
let range = match e.range() {
Some(range) => range,
None => {
ctx.push_error(DatamodelError::new_validation_error(
&format!("Error parsing jinja template: {}", e),
expr_span.clone(),
));
continue;
}
};

let start_offset = expr_span.start + range.start;
let end_offset = expr_span.start + range.end;

let span = Span::new(
expr_span.file.clone(),
start_offset as usize,
end_offset as usize,
);

ctx.push_error(DatamodelError::new_validation_error(
&format!("Error parsing jinja template: {}", e),
span,
))
} else {
e.errors.iter().for_each(|t| {
let tspan = t.span();
let span = Span::new(
expr_span.file.clone(),
expr_span.start + tspan.start_offset as usize,
expr_span.start + tspan.end_offset as usize,
);
ctx.push_warning(DatamodelWarning::new(t.message().to_string(), span))
})
}
}
}
}
});
}
Loading

0 comments on commit cf7b582

Please sign in to comment.