Skip to content

Commit

Permalink
Adding some code that prevents asserts from parsing at all + checks no…
Browse files Browse the repository at this point in the history
… longer impact parsing (#1101)

<!-- ELLIPSIS_HIDDEN -->


> [!IMPORTANT]
> Enhance BAML's handling of assertions and checks by updating parsing
logic, modifying related functions, and improving syntax highlighting.
> 
>   - **Behavior**:
> - `run_user_checks()` in `coercer/mod.rs` now returns `Vec<(String,
JinjaExpression, bool)>` instead of `Vec<(Constraint, bool)>`, handling
only checks with labels and rejecting asserts.
> - `ConstraintResults` in `deserialize_flags.rs` now stores
`Vec<(String, JinjaExpression, bool)>`.
> - `score.rs` no longer assigns scores to constraints in
`Flag::ConstraintResults`.
>   - **Functions**:
> - Updated `constraint_results()` in `deserialize_flags.rs` to return
`Vec<(String, JinjaExpression, bool)>`.
> - Updated `fmt()` for `Flag::ConstraintResults` in
`deserialize_flags.rs` to display only checks.
>   - **Misc**:
>     - Minor import reordering in `coercer/mod.rs`.
> - Updated tests in `test_constraints.rs` and `test_unions.rs` to
reflect changes in constraint handling.
> - Added `IntroToChecksDialog` component in `IntroToChecksDialog.tsx`
to introduce checks to users.
> - Updated `baml.tmLanguage.json` to improve syntax highlighting for
checks and asserts.
> 
> <sup>This description was created by </sup>[<img alt="Ellipsis"
src="https://img.shields.io/badge/Ellipsis-blue?color=175173">](https://www.ellipsis.dev?ref=BoundaryML%2Fbaml&utm_source=github&utm_medium=referral)<sup>
for 3c3d210. It will automatically
update as commits are pushed.</sup>


<!-- ELLIPSIS_HIDDEN -->

---------

Co-authored-by: Greg Hale <[email protected]>
  • Loading branch information
hellovai and imalsogreg authored Oct 26, 2024
1 parent e665346 commit 5ec89c9
Show file tree
Hide file tree
Showing 28 changed files with 679 additions and 152 deletions.
397 changes: 397 additions & 0 deletions docs/docs/calling-baml/validations.mdx

Large diffs are not rendered by default.

27 changes: 26 additions & 1 deletion engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ pub fn infer_type<'a>(value: &'a BamlValue) -> Option<FieldType> {
mod tests {
use super::*;
use baml_types::{
BamlMedia, BamlMediaContent, BamlMediaType, BamlValue, FieldType, MediaBase64, TypeValue,
BamlMedia, BamlMediaContent, BamlMediaType, BamlValue, Constraint, ConstraintLevel, FieldType, JinjaExpression, MediaBase64, TypeValue
};
use repr::make_test_ir;

Expand Down Expand Up @@ -665,4 +665,29 @@ mod tests {
let head = nodes.next().unwrap();
assert_eq!(head.meta(), &map_type);
}

/// Parameter checking is expected to return an error when the argument's type
/// carries an `@assert` (labelled `malformed`) whose expression calls
/// `this.length()` on an `int` value.
#[test]
fn test_malformed_check_in_argument() {
    let source = r##"
client<llm> GPT4 {
provider openai
options {
model gpt-4o
api_key env.OPENAI_API_KEY
}
}
function Foo(a: int @assert(malformed, {{ this.length() > 0 }})) -> int {
client GPT4
prompt #""#
}
"##;
    let test_ir = make_test_ir(source).unwrap();
    let foo = test_ir.find_function("Foo").unwrap();

    let coercer = ArgCoercer {
        span_path: None,
        allow_implicit_cast_to_string: true,
    };
    // Single argument `a = 1`; the map type is inferred from `check_function_params`.
    let args = std::iter::once(("a".to_string(), BamlValue::Int(1))).collect();

    let outcome = test_ir.check_function_params(&foo, &args, coercer);
    assert!(outcome.is_err());
}
}
104 changes: 77 additions & 27 deletions engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use baml_types::{
BamlMap, BamlValue, BamlValueWithMeta, Constraint, ConstraintLevel, FieldType, LiteralValue, TypeValue
BamlMap, BamlValue, BamlValueWithMeta, Constraint, ConstraintLevel, FieldType, LiteralValue,
TypeValue,
};
use core::result::Result;
use std::path::PathBuf;
Expand Down Expand Up @@ -311,12 +312,12 @@ impl ArgCoercer {
let result = self.coerce_arg(ir, option, value, &mut scope);
if !scope.has_errors() {
if first_good_result.is_err() {
first_good_result = result
first_good_result = result
}
}
}
}
if first_good_result.is_err(){
if first_good_result.is_err() {
scope.push_error(format!("Expected one of {:?}, got `{}`", options, value));
Err(())
} else {
Expand All @@ -342,18 +343,20 @@ impl ArgCoercer {
}
}?;


let search_for_failures_result = first_failing_assert_nested(ir, &value, field_type).map_err(|e| {
scope.push_error(format!("Failed to evaluate assert: {:?}", e));
()
})?;
let search_for_failures_result = first_failing_assert_nested(ir, &value, field_type)
.map_err(|e| {
scope.push_error(format!("Failed to evaluate assert: {:?}", e));
()
})?;
match search_for_failures_result {
Some(Constraint {label, expression, ..}) => {
Some(Constraint {
label, expression, ..
}) => {
let msg = label.as_ref().unwrap_or(&expression.0);
scope.push_error(format!("Failed assert: {msg}"));
Ok(value)
Err(())
}
None => Ok(value)
None => Ok(value),
}
}
}
Expand All @@ -363,31 +366,78 @@ impl ArgCoercer {
fn first_failing_assert_nested<'a>(
ir: &'a IntermediateRepr,
baml_value: &BamlValue,
field_type: &'a FieldType
field_type: &'a FieldType,
) -> anyhow::Result<Option<Constraint>> {
let value_with_types = ir.distribute_type(baml_value.clone(), field_type.clone());
let value_with_types = ir.distribute_type(baml_value.clone(), field_type.clone())?;
let first_failure = value_with_types
.iter()
.map(|value_node| {
let (_, constraints) = value_node.meta().distribute_constraints();
constraints.into_iter().filter_map(|c| {
let constraint = c.clone();
let baml_value: BamlValue = value_node.into();
let result = evaluate_predicate(&&baml_value, &c.expression).map_err(|e| {
anyhow::anyhow!(format!("Error evaluating constraint: {:?}", e))
});
match result {
Ok(false) => if c.level == ConstraintLevel::Assert {Some(Ok(constraint))} else { None },
Ok(true) => None,
Err(e) => Some(Err(e))

}
})
.collect::<Vec<_>>()
constraints
.into_iter()
.filter_map(|c| {
let constraint = c.clone();
let baml_value: BamlValue = value_node.into();
let result = evaluate_predicate(&&baml_value, &c.expression).map_err(|e| {
anyhow::anyhow!(format!("Error evaluating constraint: {:?}", e))
});
match result {
Ok(false) => {
if c.level == ConstraintLevel::Assert {
Some(Ok(constraint))
} else {
None
}
}
Ok(true) => None,
Err(e) => Some(Err(e)),
}
})
.collect::<Vec<_>>()
})
.map(|x| x.into_iter())
.flatten()
.next();
first_failure.transpose()
}

#[cfg(test)]
mod tests {
    use super::*;

    use crate::ir::repr::make_test_ir;
    use baml_types::JinjaExpression;

    /// Coercing an `int` argument against a constrained type whose only
    /// constraint is an `Assert` with expression `this.length() > 0` is
    /// expected to yield an error.
    #[test]
    fn test_malformed_check_in_argument() {
        let test_ir = make_test_ir(
            r##"
client<llm> GPT4 {
provider openai
options {
model gpt-4o
api_key env.OPENAI_API_KEY
}
}
function Foo(a: int @assert(malformed, {{ this.length() > 0 }})) -> int {
client GPT4
prompt #""#
}
"##,
        )
        .unwrap();

        // The constrained type is built by hand rather than looked up from `Foo`.
        let length_assert = Constraint {
            label: Some("foo".to_string()),
            expression: JinjaExpression("this.length() > 0".to_string()),
            level: ConstraintLevel::Assert,
        };
        let constrained_int = FieldType::Constrained {
            base: Box::new(FieldType::Primitive(TypeValue::Int)),
            constraints: vec![length_assert],
        };

        let coercer = ArgCoercer {
            span_path: None,
            allow_implicit_cast_to_string: true,
        };
        let mut scope = ScopeStack::new();
        let outcome = coercer.coerce_arg(&test_ir, &constrained_int, &BamlValue::Int(1), &mut scope);
        assert!(outcome.is_err());
    }
}
3 changes: 1 addition & 2 deletions engine/baml-lib/baml-core/src/ir/jinja_helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ pub fn render_expression(
// So producing the string `{{}}` requires writing the literal `"{{{{}}}}"`
let template = format!(r#"{{{{ {} }}}}"#, expression.0);
let args_dict = minijinja::Value::from_serialize(ctx);
eprintln!("{}", &template);
Ok(env.render_str(&template, &args_dict)?)
}

Expand All @@ -45,7 +44,7 @@ pub fn evaluate_predicate(
match render_expression(&predicate_expression, &ctx)?.as_ref() {
"true" => Ok(true),
"false" => Ok(false),
_ => Err(anyhow::anyhow!("TODO")),
_ => Err(anyhow::anyhow!("Predicate did not evaluate to a boolean")),
}
}

Expand Down
13 changes: 13 additions & 0 deletions engine/baml-lib/baml-types/src/constraint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,19 @@ pub struct Constraint {
pub label: Option<String>,
}

impl Constraint {
pub fn as_check(self) -> Option<(String, JinjaExpression)> {
match self.level {
ConstraintLevel::Check => Some((
self.label
.expect("Checks are guaranteed by the pest grammar to have a label."),
self.expression,
)),
ConstraintLevel::Assert => None,
}
}
}

#[derive(Clone, Debug, PartialEq, serde::Serialize)]
pub enum ConstraintLevel {
Check,
Expand Down
69 changes: 59 additions & 10 deletions engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use anyhow::Result;
use baml_types::BamlMap;
use baml_types::{BamlMap, Constraint, ConstraintLevel};
use internal_baml_core::{ir::FieldType, ir::TypeValue};

use crate::deserializer::{
Expand Down Expand Up @@ -86,22 +86,71 @@ impl TypeCoercer for FieldType {
FieldType::Tuple(_) => Err(ctx.error_internal("Tuple not supported")),
FieldType::Constrained { base, .. } => {
let mut coerced_value = base.coerce(ctx, base, value)?;
let constraint_results =
run_user_checks(&coerced_value.clone().into(), &self).map_err(
|e| ParsingError {
reason: format!("Failed to evaluate constraints: {:?}", e),
scope: ctx.scope.clone(),
causes: Vec::new(),
},
)?;
coerced_value.add_flag(Flag::ConstraintResults(constraint_results));
let constraint_results = run_user_checks(&coerced_value.clone().into(), &self)
.map_err(|e| ParsingError {
reason: format!("Failed to evaluate constraints: {:?}", e),
scope: ctx.scope.clone(),
causes: Vec::new(),
})?;
validate_asserts(&constraint_results)?;
let check_results = constraint_results
.into_iter()
.filter_map(|(maybe_check, result)| {
maybe_check
.as_check()
.map(|(label, expr)| (label, expr, result))
})
.collect();
coerced_value.add_flag(Flag::ConstraintResults(check_results));
Ok(coerced_value)
}
},
}
}
}

/// Rejects a coerced value when any `Assert`-level constraint evaluated to `false`.
///
/// `constraints` pairs each constraint with the boolean result of evaluating it.
/// Constraints that passed, and constraints that are not asserts, are ignored.
///
/// # Errors
///
/// Returns a `ParsingError` with reason `"Assertions failed."` when at least one
/// assert failed. Its `causes` holds one nested error per failing assert, with a
/// reason of the form `"Failed: <label> <expression>"` (the label is omitted when
/// the assert has none).
fn validate_asserts(constraints: &Vec<(Constraint, bool)>) -> Result<(), ParsingError> {
    let causes = constraints
        .iter()
        .filter_map(|(constraint, passed)| {
            if *passed || constraint.level != ConstraintLevel::Assert {
                return None;
            }
            let label_prefix = constraint
                .label
                .as_ref()
                .map_or_else(String::new, |l| format!("{} ", l));
            Some(ParsingError {
                causes: vec![],
                reason: format!("Failed: {}{}", label_prefix, constraint.expression.0),
                scope: vec![],
            })
        })
        .collect::<Vec<_>>();

    if causes.is_empty() {
        Ok(())
    } else {
        // Attach the individual failures so callers can see which asserts failed,
        // instead of dropping them and reporting only the generic summary.
        Err(ParsingError {
            causes,
            reason: "Assertions failed.".to_string(),
            scope: vec![],
        })
    }
}

impl DefaultValue for FieldType {
fn default_value(&self, error: Option<&ParsingError>) -> Option<BamlValueWithFlags> {
let get_flags = || {
Expand Down
13 changes: 7 additions & 6 deletions engine/baml-lib/jsonish/src/deserializer/coercer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ mod match_string;

use anyhow::Result;

use baml_types::{BamlValue, Constraint};
use baml_types::{BamlValue, Constraint, JinjaExpression};
use internal_baml_jinja::types::OutputFormatContent;

use internal_baml_core::ir::{FieldType, jinja_helpers::evaluate_predicate};
use internal_baml_core::ir::{jinja_helpers::evaluate_predicate, FieldType};

use super::types::BamlValueWithFlags;

Expand Down Expand Up @@ -233,12 +233,13 @@ pub fn run_user_checks(
type_: &FieldType,
) -> Result<Vec<(Constraint, bool)>> {
match type_ {
FieldType::Constrained { constraints, .. } => {
constraints.iter().map(|constraint| {
FieldType::Constrained { constraints, .. } => constraints
.iter()
.map(|constraint| {
let result = evaluate_predicate(baml_value, &constraint.expression)?;
Ok((constraint.clone(), result))
}).collect::<Result<Vec<_>>>()
}
})
.collect::<Result<Vec<_>>>(),
_ => Ok(vec![]),
}
}
30 changes: 18 additions & 12 deletions engine/baml-lib/jsonish/src/deserializer/deserialize_flags.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::{coercer::ParsingError, types::BamlValueWithFlags};
use baml_types::Constraint;
use baml_types::{Constraint, ConstraintLevel, JinjaExpression};

#[derive(Debug, Clone)]
pub enum Flag {
Expand Down Expand Up @@ -44,8 +44,8 @@ pub enum Flag {
// X -> Object conversions.
NoFields(Option<crate::jsonish::Value>),

// Constraint results.
ConstraintResults(Vec<(Constraint, bool)>),
/// Constraint results (only contains checks)
ConstraintResults(Vec<(String, JinjaExpression, bool)>),
}

#[derive(Clone)]
Expand Down Expand Up @@ -99,13 +99,16 @@ impl DeserializerConditions {
.collect::<Vec<_>>()
}

pub fn constraint_results(&self) -> Vec<(Constraint, bool)> {
self.flags.iter().filter_map(|flag| match flag {
Flag::ConstraintResults(cs) => Some(cs.clone()),
_ => None,
}).flatten().collect()
pub fn constraint_results(&self) -> Vec<(String, JinjaExpression, bool)> {
self.flags
.iter()
.filter_map(|flag| match flag {
Flag::ConstraintResults(cs) => Some(cs.clone()),
_ => None,
})
.flatten()
.collect()
}

}

impl std::fmt::Debug for DeserializerConditions {
Expand Down Expand Up @@ -243,10 +246,13 @@ impl std::fmt::Display for Flag {
}
}
Flag::ConstraintResults(cs) => {
for (Constraint{ label, level, expression }, succeeded) in cs.iter() {
let msg = label.as_ref().unwrap_or(&expression.0);
for (label, _, succeeded) in cs.iter() {
let f_result = if *succeeded { "Succeeded" } else { "Failed" };
writeln!(f, "{level:?} {msg} {f_result}")?;
writeln!(
f,
"{level:?} {label} {f_result}",
level = ConstraintLevel::Check
)?;
}
}
}
Expand Down
Loading

0 comments on commit 5ec89c9

Please sign in to comment.