Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Block-level constraints #1124

Merged
merged 4 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions engine/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

105 changes: 104 additions & 1 deletion engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::{
},
};
use anyhow::Result;
use baml_types::{BamlMap, BamlValue, BamlValueWithMeta, FieldType, LiteralValue, TypeValue};
use baml_types::{BamlMap, BamlValue, BamlValueWithMeta, Constraint, ConstraintLevel, FieldType, LiteralValue, TypeValue};
pub use to_baml_arg::ArgCoercer;

use super::repr;
Expand Down Expand Up @@ -51,6 +51,18 @@ pub trait IRHelper {
value: BamlValue,
field_type: FieldType,
) -> Result<BamlValueWithMeta<FieldType>>;
fn distribute_constraints<'a>(
&'a self,
field_type: &'a FieldType
) -> (&'a FieldType, Vec<Constraint>);
fn type_has_constraints(
&self,
field_type: &FieldType
) -> bool;
fn type_has_checks(
&self,
field_type: &FieldType
) -> bool;
}

impl IRHelper for IntermediateRepr {
Expand Down Expand Up @@ -365,6 +377,66 @@ impl IRHelper for IntermediateRepr {
}
}
}


/// Constraints may live in several places. A constrained base type stors its
/// constraints by wrapping itself in the `FieldType::Constrained` constructor.
/// Additionally, `FieldType::Class` may have constraints stored in its class node,
/// and `FieldType::Enum` can store constraints in its `Enum` node.
/// And the `FieldType::Constrained` constructor might wrap another
/// `FieldType::Constrained` constructor.
///
/// This function collects constraints for a given type from all these
/// possible sources. Whenever querying a type for its constraints, you
/// should do so with this function, instead of searching manually for all
/// the places that Constraints can live.
fn distribute_constraints<'a>(&'a self, field_type: &'a FieldType) -> (&'a FieldType, Vec<Constraint>) {
match field_type {
FieldType::Class(class_name) => {
match self.find_class(class_name) {
Err(_) => (field_type, Vec::new()),
Ok(class_node) => (field_type, class_node.item.attributes.constraints.clone())
}
}
FieldType::Enum(enum_name) => {
match self.find_enum(enum_name) {
Err(_) => (field_type, Vec::new()),
Ok(enum_node) => (field_type, enum_node.item.attributes.constraints.clone())
}
}
// Check the first level to see if it's constrained.
FieldType::Constrained { base, constraints } => {
match base.as_ref() {
// If so, we must check the second level to see if we need to combine
// constraints across levels.
// The recursion here means that arbitrarily nested `FieldType::Constrained`s
// will be collapsed before the function returns.
FieldType::Constrained { .. } => {
let (sub_base, sub_constraints) = self.distribute_constraints(base.as_ref());
let combined_constraints = vec![constraints.clone(), sub_constraints]
.into_iter()
.flatten()
.collect();
(sub_base, combined_constraints)
}
_ => (base, constraints.clone()),
}
}
_ => (field_type, Vec::new()),
}
}

fn type_has_constraints(&self, field_type: &FieldType) -> bool {
let (_, constraints) = self.distribute_constraints(field_type);
!constraints.is_empty()
}

fn type_has_checks(&self, field_type: &FieldType) -> bool {
let (_, constraints) = self.distribute_constraints(field_type);
constraints
.iter()
.any(|Constraint { level, .. }| *level == ConstraintLevel::Check)
}
}

const UNIT_TYPE: FieldType = FieldType::Tuple(vec![]);
Expand Down Expand Up @@ -686,4 +758,35 @@ mod tests {
let res = ir.check_function_params(&function, &params, arg_coercer);
assert!(res.is_err());
}

#[test]
fn test_nested_constraint_distribution() {
let ir = make_test_ir("").unwrap();
fn mk_constraint(s: &str) -> Constraint {
Constraint {
level: ConstraintLevel::Assert,
expression: JinjaExpression(s.to_string()),
label: Some(s.to_string()),
}
}

let input = FieldType::Constrained {
constraints: vec![mk_constraint("a")],
base: Box::new(FieldType::Constrained {
constraints: vec![mk_constraint("b")],
base: Box::new(FieldType::Constrained {
constraints: vec![mk_constraint("c")],
base: Box::new(FieldType::Primitive(TypeValue::Int)),
}),
}),
};

let expected_base = FieldType::Primitive(TypeValue::Int);
let expected_constraints = vec![mk_constraint("a"), mk_constraint("b"), mk_constraint("c")];

let (base, constraints) = ir.distribute_constraints(&input);

assert_eq!(base, &expected_base);
assert_eq!(constraints, expected_constraints);
}
}
4 changes: 2 additions & 2 deletions engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ impl ArgCoercer {
value: &BamlValue, // original value passed in by user
scope: &mut ScopeStack,
) -> Result<BamlValue, ()> {
let value = match field_type.distribute_constraints() {
let value = match ir.distribute_constraints(field_type) {
(FieldType::Primitive(t), _) => match t {
TypeValue::String if matches!(value, BamlValue::String(_)) => Ok(value.clone()),
TypeValue::String if self.allow_implicit_cast_to_string => match value {
Expand Down Expand Up @@ -372,7 +372,7 @@ fn first_failing_assert_nested<'a>(
let first_failure = value_with_types
.iter()
.map(|value_node| {
let (_, constraints) = value_node.meta().distribute_constraints();
let (_, constraints) = ir.distribute_constraints(value_node.meta());
constraints
.into_iter()
.filter_map(|c| {
Expand Down
22 changes: 19 additions & 3 deletions engine/baml-lib/baml-core/src/ir/repr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ pub struct NodeAttributes {
#[serde(with = "indexmap::map::serde_seq")]
meta: IndexMap<String, Expression>,

constraints: Vec<Constraint>,
pub constraints: Vec<Constraint>,

// Spans
#[serde(skip)]
Expand Down Expand Up @@ -371,10 +371,26 @@ impl WithRepr<FieldType> for ast::FieldType {
ast::FieldType::Symbol(arity, idn, ..) => type_with_arity(
match db.find_type(idn) {
Some(Either::Left(class_walker)) => {
FieldType::Class(class_walker.name().to_string())
let base_class = FieldType::Class(class_walker.name().to_string());
let maybe_constraints = class_walker.get_constraints(SubType::Class);
match maybe_constraints {
Some(constraints) if constraints.len() > 0 => FieldType::Constrained {
base: Box::new(base_class),
constraints,
},
_ => base_class
}
}
Some(Either::Right(enum_walker)) => {
FieldType::Enum(enum_walker.name().to_string())
let base_type = FieldType::Enum(enum_walker.name().to_string());
let maybe_constraints = enum_walker.get_constraints(SubType::Enum);
match maybe_constraints {
Some(constraints) if constraints.len() > 0 => FieldType::Constrained {
base: Box::new(base_type),
constraints
},
_ => base_type
}
}
None => return Err(anyhow!("Field type uses unresolvable local identifier")),
},
Expand Down
93 changes: 87 additions & 6 deletions engine/baml-lib/baml-types/src/baml_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -625,13 +625,20 @@ impl Serialize for BamlValueWithMeta<Vec<ResponseCheck>> {
BamlValueWithMeta::Media(v, cr) => serialize_with_checks(v, cr, serializer),
BamlValueWithMeta::Enum(_enum_name, v, cr) => serialize_with_checks(v, cr, serializer),
BamlValueWithMeta::Class(_class_name, v, cr) => {
let mut map = serializer.serialize_map(None)?;
for (key, value) in v {
map.serialize_entry(key, value)?;
if cr.is_empty() {
let mut map = serializer.serialize_map(None)?;
for (key, value) in v {
map.serialize_entry(key, value)?;
}
add_checks(&mut map, cr)?;
map.end()
} else {
let mut checked_value = serializer.serialize_map(Some(2))?;
checked_value.serialize_entry("value", &v)?;
add_checks(&mut checked_value, cr)?;
checked_value.end()
}
add_checks(&mut map, cr)?;
map.end()
}
},
BamlValueWithMeta::Null(cr) => serialize_with_checks(&(), cr, serializer),
}
}
Expand Down Expand Up @@ -707,4 +714,78 @@ mod tests {
assert!(serde_json::to_value(baml_value).is_ok());
assert!(serde_json::to_value(baml_value_2).is_ok());
}

#[test]
fn test_serialize_class_checks() {
let baml_value: BamlValueWithMeta<Vec<ResponseCheck>> =
BamlValueWithMeta::Class(
"Foo".to_string(),
vec![
("foo".to_string(), BamlValueWithMeta::Int(1, vec![])),
("bar".to_string(), BamlValueWithMeta::String("hi".to_string(), vec![])),
].into_iter().collect(),
vec![
ResponseCheck {
name: "bar_len_lt_foo".to_string(),
expression: "this.bar|length < this.foo".to_string(),
status: "failed".to_string()
}
]
);
let expected = serde_json::json!({
"value": {"foo": 1, "bar": "hi"},
"checks": {
"bar_len_lt_foo": {
"name": "bar_len_lt_foo",
"expression": "this.bar|length < this.foo",
"status": "failed"
}
}
});
let json = serde_json::to_value(baml_value).unwrap();
assert_eq!(json, expected);
}

#[test]
fn test_serialize_nested_class_checks() {

// Prepare an object for wrapping.
let foo: BamlValueWithMeta<Vec<ResponseCheck>> =
BamlValueWithMeta::Class(
"Foo".to_string(),
vec![
("foo".to_string(), BamlValueWithMeta::Int(1, vec![])),
("bar".to_string(), BamlValueWithMeta::String("hi".to_string(), vec![])),
].into_iter().collect(),
vec![
ResponseCheck {
name: "bar_len_lt_foo".to_string(),
expression: "this.bar|length < this.foo".to_string(),
status: "failed".to_string()
}
]
);

// Prepare the top-level value.
let baml_value = BamlValueWithMeta::Class(
"FooWrapper".to_string(),
vec![("foo".to_string(), foo)].into_iter().collect(),
vec![]
);
let expected = serde_json::json!({
"foo": {
"value": {"foo": 1, "bar": "hi"},
"checks": {
"bar_len_lt_foo": {
"name": "bar_len_lt_foo",
"expression": "this.bar|length < this.foo",
"status": "failed"
}
}
}
});
let json = serde_json::to_value(baml_value).unwrap();
assert_eq!(json, expected);
}

}
Loading
Loading