Skip to content

Commit

Permalink
Block-level constraints (#1124)
Browse files Browse the repository at this point in the history
Adds the ability to specify constraints on types at the block level,
like this:

```
class Foo {
  length_bound int
  baz string
  @@Assert( valid_bound, {{ this.baz|length < this.length_bound }} )
}
```

TODO:
- [x] Integration tests for functions returning top-level block
constraints
- [x] Integration tests for functions returning classes with block-level
constraints on nested classes
- [x] Integration tests for function parameters with block-level
constraints
- [x] Integration tests for function parameters with nested block-level
constraint fields
 - [x] Documentation (Validations.mdx page and Reference)
<!-- ELLIPSIS_HIDDEN -->


----

> [!IMPORTANT]
> Adds block-level constraints for types, updates constraint handling
logic, and includes integration tests for the new functionality.
> 
>   - **Behavior**:
> - Adds block-level constraints for types, allowing constraints at the
class level.
> - Updates constraint handling logic in `ir_helpers/mod.rs` and
`coercer/field_type.rs`.
> - Adds functions `distribute_constraints`, `type_has_constraints`, and
`type_has_checks` in `ir_helpers/mod.rs`.
>   - **Tests**:
> - Adds integration tests for block-level constraints in
`integ-tests/typescript/tests/integ-tests.test.ts`.
> - Tests include handling of nested block-level constraints and
function parameters with block-level constraints.
>   - **Misc**:
> - Updates `Cargo.toml` and `Cargo.lock` to include `itertools`
dependency.
> - Minor updates to `BamlValueWithMeta` serialization in
`baml_value.rs`.
> 
> <sup>This description was created by </sup>[<img alt="Ellipsis"
src="https://img.shields.io/badge/Ellipsis-blue?color=175173">](https://www.ellipsis.dev?ref=BoundaryML%2Fbaml&utm_source=github&utm_medium=referral)<sup>
for d21f1aa. It will automatically
update as commits are pushed.</sup>


<!-- ELLIPSIS_HIDDEN -->
  • Loading branch information
imalsogreg authored Nov 1, 2024
1 parent b8a221f commit e931acb
Show file tree
Hide file tree
Showing 49 changed files with 1,839 additions and 1,844 deletions.
2 changes: 2 additions & 0 deletions engine/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

105 changes: 104 additions & 1 deletion engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::{
},
};
use anyhow::Result;
use baml_types::{BamlMap, BamlValue, BamlValueWithMeta, FieldType, LiteralValue, TypeValue};
use baml_types::{BamlMap, BamlValue, BamlValueWithMeta, Constraint, ConstraintLevel, FieldType, LiteralValue, TypeValue};
pub use to_baml_arg::ArgCoercer;

use super::repr;
Expand Down Expand Up @@ -51,6 +51,18 @@ pub trait IRHelper {
value: BamlValue,
field_type: FieldType,
) -> Result<BamlValueWithMeta<FieldType>>;
/// Gathers the constraints attached to `field_type` and returns them
/// together with a type reference.
///
/// For a `FieldType::Constrained` the returned reference is the wrapped
/// base type; for other variants it is `field_type` itself. The returned
/// vector is empty when no constraints are found.
fn distribute_constraints<'a>(
&'a self,
field_type: &'a FieldType
) -> (&'a FieldType, Vec<Constraint>);
/// Returns `true` when `distribute_constraints` yields at least one
/// constraint for `field_type`.
fn type_has_constraints(
&self,
field_type: &FieldType
) -> bool;
/// Returns `true` when at least one constraint yielded by
/// `distribute_constraints` for `field_type` has level
/// `ConstraintLevel::Check`.
fn type_has_checks(
&self,
field_type: &FieldType
) -> bool;
}

impl IRHelper for IntermediateRepr {
Expand Down Expand Up @@ -365,6 +377,66 @@ impl IRHelper for IntermediateRepr {
}
}
}


/// Collects every constraint that applies to `field_type`, wherever it is stored.
///
/// Constraints can be found in several places:
/// - a `FieldType::Constrained` wrapper stores them next to its base type,
///   and such wrappers may themselves be nested inside one another;
/// - a `FieldType::Class` may have constraints stored on its class node;
/// - a `FieldType::Enum` may have constraints stored on its enum node.
///
/// Returns the type with any directly nested `FieldType::Constrained`
/// wrappers stripped away, paired with the collected constraints (outermost
/// wrapper first). A class or enum name that cannot be resolved yields an
/// empty constraint list rather than an error.
///
/// Prefer this function over inspecting the individual storage locations
/// by hand whenever you need the constraints of a type.
fn distribute_constraints<'a>(&'a self, field_type: &'a FieldType) -> (&'a FieldType, Vec<Constraint>) {
    match field_type {
        FieldType::Class(class_name) => {
            // Class-level constraints live on the class node; a failed lookup means "none".
            let from_class = self
                .find_class(class_name)
                .map(|class_node| class_node.item.attributes.constraints.clone())
                .unwrap_or_default();
            (field_type, from_class)
        }
        FieldType::Enum(enum_name) => {
            // Same treatment for enums.
            let from_enum = self
                .find_enum(enum_name)
                .map(|enum_node| enum_node.item.attributes.constraints.clone())
                .unwrap_or_default();
            (field_type, from_enum)
        }
        FieldType::Constrained { base, constraints } => {
            let inner = base.as_ref();
            if matches!(inner, FieldType::Constrained { .. }) {
                // The base is itself constrained: recurse so that arbitrarily deep
                // nesting is flattened, keeping the outer constraints first.
                let (innermost, nested) = self.distribute_constraints(inner);
                let mut merged = constraints.clone();
                merged.extend(nested);
                (innermost, merged)
            } else {
                (inner, constraints.clone())
            }
        }
        _ => (field_type, Vec::new()),
    }
}

/// Reports whether `distribute_constraints` finds any constraint at all
/// for `field_type`.
fn type_has_constraints(&self, field_type: &FieldType) -> bool {
    let collected = self.distribute_constraints(field_type).1;
    !collected.is_empty()
}

/// Reports whether any constraint found by `distribute_constraints` for
/// `field_type` has level `ConstraintLevel::Check`.
fn type_has_checks(&self, field_type: &FieldType) -> bool {
    let collected = self.distribute_constraints(field_type).1;
    collected
        .iter()
        .any(|constraint| constraint.level == ConstraintLevel::Check)
}
}

/// The unit type, represented as a tuple with no elements.
const UNIT_TYPE: FieldType = FieldType::Tuple(vec![]);
Expand Down Expand Up @@ -686,4 +758,35 @@ mod tests {
let res = ir.check_function_params(&function, &params, arg_coercer);
assert!(res.is_err());
}

/// Three `FieldType::Constrained` wrappers nested around an `int` must
/// collapse to the bare `int` plus all three constraints, outermost first.
#[test]
fn test_nested_constraint_distribution() {
    // Builds an assert-level constraint whose expression and label are both `s`.
    fn mk_constraint(s: &str) -> Constraint {
        Constraint {
            level: ConstraintLevel::Assert,
            expression: JinjaExpression(s.to_string()),
            label: Some(s.to_string()),
        }
    }

    let ir = make_test_ir("").unwrap();

    // Wrap from the inside out: "c" is the innermost wrapper, "a" the outermost.
    let input = ["c", "b", "a"].iter().fold(
        FieldType::Primitive(TypeValue::Int),
        |wrapped, label| FieldType::Constrained {
            constraints: vec![mk_constraint(label)],
            base: Box::new(wrapped),
        },
    );

    let (base, constraints) = ir.distribute_constraints(&input);

    let expected_base = FieldType::Primitive(TypeValue::Int);
    let expected_constraints: Vec<Constraint> =
        ["a", "b", "c"].iter().map(|label| mk_constraint(label)).collect();

    assert_eq!(base, &expected_base);
    assert_eq!(constraints, expected_constraints);
}
}
4 changes: 2 additions & 2 deletions engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ impl ArgCoercer {
value: &BamlValue, // original value passed in by user
scope: &mut ScopeStack,
) -> Result<BamlValue, ()> {
let value = match field_type.distribute_constraints() {
let value = match ir.distribute_constraints(field_type) {
(FieldType::Primitive(t), _) => match t {
TypeValue::String if matches!(value, BamlValue::String(_)) => Ok(value.clone()),
TypeValue::String if self.allow_implicit_cast_to_string => match value {
Expand Down Expand Up @@ -372,7 +372,7 @@ fn first_failing_assert_nested<'a>(
let first_failure = value_with_types
.iter()
.map(|value_node| {
let (_, constraints) = value_node.meta().distribute_constraints();
let (_, constraints) = ir.distribute_constraints(value_node.meta());
constraints
.into_iter()
.filter_map(|c| {
Expand Down
22 changes: 19 additions & 3 deletions engine/baml-lib/baml-core/src/ir/repr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ pub struct NodeAttributes {
#[serde(with = "indexmap::map::serde_seq")]
meta: IndexMap<String, Expression>,

constraints: Vec<Constraint>,
pub constraints: Vec<Constraint>,

// Spans
#[serde(skip)]
Expand Down Expand Up @@ -371,10 +371,26 @@ impl WithRepr<FieldType> for ast::FieldType {
ast::FieldType::Symbol(arity, idn, ..) => type_with_arity(
match db.find_type(idn) {
Some(Either::Left(class_walker)) => {
FieldType::Class(class_walker.name().to_string())
let base_class = FieldType::Class(class_walker.name().to_string());
let maybe_constraints = class_walker.get_constraints(SubType::Class);
match maybe_constraints {
Some(constraints) if constraints.len() > 0 => FieldType::Constrained {
base: Box::new(base_class),
constraints,
},
_ => base_class
}
}
Some(Either::Right(enum_walker)) => {
FieldType::Enum(enum_walker.name().to_string())
let base_type = FieldType::Enum(enum_walker.name().to_string());
let maybe_constraints = enum_walker.get_constraints(SubType::Enum);
match maybe_constraints {
Some(constraints) if constraints.len() > 0 => FieldType::Constrained {
base: Box::new(base_type),
constraints
},
_ => base_type
}
}
None => return Err(anyhow!("Field type uses unresolvable local identifier")),
},
Expand Down
93 changes: 87 additions & 6 deletions engine/baml-lib/baml-types/src/baml_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -625,13 +625,20 @@ impl Serialize for BamlValueWithMeta<Vec<ResponseCheck>> {
BamlValueWithMeta::Media(v, cr) => serialize_with_checks(v, cr, serializer),
BamlValueWithMeta::Enum(_enum_name, v, cr) => serialize_with_checks(v, cr, serializer),
BamlValueWithMeta::Class(_class_name, v, cr) => {
let mut map = serializer.serialize_map(None)?;
for (key, value) in v {
map.serialize_entry(key, value)?;
if cr.is_empty() {
let mut map = serializer.serialize_map(None)?;
for (key, value) in v {
map.serialize_entry(key, value)?;
}
add_checks(&mut map, cr)?;
map.end()
} else {
let mut checked_value = serializer.serialize_map(Some(2))?;
checked_value.serialize_entry("value", &v)?;
add_checks(&mut checked_value, cr)?;
checked_value.end()
}
add_checks(&mut map, cr)?;
map.end()
}
},
BamlValueWithMeta::Null(cr) => serialize_with_checks(&(), cr, serializer),
}
}
Expand Down Expand Up @@ -707,4 +714,78 @@ mod tests {
assert!(serde_json::to_value(baml_value).is_ok());
assert!(serde_json::to_value(baml_value_2).is_ok());
}

/// A class value carrying checks must serialize as an object with a
/// `"value"` entry (the class fields) and a `"checks"` entry keyed by
/// check name.
#[test]
fn test_serialize_class_checks() {
    // The single (failed) class-level check attached to `Foo`.
    let failed_check = ResponseCheck {
        name: "bar_len_lt_foo".to_string(),
        expression: "this.bar|length < this.foo".to_string(),
        status: "failed".to_string(),
    };

    // Fields of `Foo`; neither field has checks of its own.
    let fields = vec![
        ("foo".to_string(), BamlValueWithMeta::Int(1, vec![])),
        ("bar".to_string(), BamlValueWithMeta::String("hi".to_string(), vec![])),
    ];

    let baml_value: BamlValueWithMeta<Vec<ResponseCheck>> = BamlValueWithMeta::Class(
        "Foo".to_string(),
        fields.into_iter().collect(),
        vec![failed_check],
    );

    let actual = serde_json::to_value(baml_value).unwrap();
    let expected = serde_json::json!({
        "value": {"foo": 1, "bar": "hi"},
        "checks": {
            "bar_len_lt_foo": {
                "name": "bar_len_lt_foo",
                "expression": "this.bar|length < this.foo",
                "status": "failed"
            }
        }
    });
    assert_eq!(actual, expected);
}

/// A check-free wrapper class must serialize as a plain map, while the
/// checked class nested inside it still serializes in the
/// `"value"` / `"checks"` form.
#[test]
fn test_serialize_nested_class_checks() {
    // Inner object: a `Foo` with one failed class-level check.
    let inner_check = ResponseCheck {
        name: "bar_len_lt_foo".to_string(),
        expression: "this.bar|length < this.foo".to_string(),
        status: "failed".to_string(),
    };
    let inner_fields = vec![
        ("foo".to_string(), BamlValueWithMeta::Int(1, vec![])),
        ("bar".to_string(), BamlValueWithMeta::String("hi".to_string(), vec![])),
    ];
    let foo: BamlValueWithMeta<Vec<ResponseCheck>> = BamlValueWithMeta::Class(
        "Foo".to_string(),
        inner_fields.into_iter().collect(),
        vec![inner_check],
    );

    // Outer object: a `FooWrapper` with no checks, holding `foo` in its only field.
    let wrapper_fields = vec![("foo".to_string(), foo)];
    let baml_value = BamlValueWithMeta::Class(
        "FooWrapper".to_string(),
        wrapper_fields.into_iter().collect(),
        vec![],
    );

    let actual = serde_json::to_value(baml_value).unwrap();
    let expected = serde_json::json!({
        "foo": {
            "value": {"foo": 1, "bar": "hi"},
            "checks": {
                "bar_len_lt_foo": {
                    "name": "bar_len_lt_foo",
                    "expression": "this.bar|length < this.foo",
                    "status": "failed"
                }
            }
        }
    });
    assert_eq!(actual, expected);
}

}
Loading

0 comments on commit e931acb

Please sign in to comment.