Skip to content

Commit

Permalink
Sarosi's Recursive Cycles Detection Algorithm 💀💀💀 (#1065)
Browse files Browse the repository at this point in the history
Progress so far:

- Rewrite dependency cycle algorithm (detect all cycles in the graph and
report all errors, not just one).
- Allow recursive types when they don't enter infinite recursion.
- Find all finite recursive cycles and pipeline them through DB -> IR ->
Jinja.
- Hoist recursive classes in prompt rendering.
- Update jsonish parser / coercer to stop recursion when dealing with
recursive types.
<!-- ELLIPSIS_HIDDEN -->


----

> [!IMPORTANT]
> Introduces a new algorithm for detecting recursive cycles, allowing
finite recursive types, and updates the JSON parser and tests to support
these changes.
> 
>   - **Behavior**:
> - Rewrites dependency cycle detection to find all cycles and report
all errors in `cycle.rs`.
> - Allows finite recursive types and hoists recursive classes in prompt
rendering in `repr.rs`.
>     - Updates JSON parser to handle recursion in `coerce_class.rs`.
>   - **Models**:
> - Adds `finite_recursive_cycles` to `IntermediateRepr` in `repr.rs`.
> - Updates `OutputFormatContent` in `types.rs` to support recursive
classes.
>   - **Tests**:
> - Adds tests for recursive types in `test_class.rs` and
`integ-tests.test.ts`.
> - Updates integration tests to cover new recursive behavior in
`integ-tests.test.ts`.
>   - **Misc**:
>     - Updates `Cargo.lock` to use `indexmap 2.2.6`.
> 
> <sup>This description was created by </sup>[<img alt="Ellipsis"
src="https://img.shields.io/badge/Ellipsis-blue?color=175173">](https://www.ellipsis.dev?ref=BoundaryML%2Fbaml&utm_source=github&utm_medium=referral)<sup>
for a10e003. It will automatically
update as commits are pushed.</sup>


<!-- ELLIPSIS_HIDDEN -->
  • Loading branch information
antoniosarosi authored Nov 10, 2024
1 parent 8c43e37 commit 8100df9
Show file tree
Hide file tree
Showing 48 changed files with 4,613 additions and 898 deletions.
3 changes: 2 additions & 1 deletion engine/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

67 changes: 52 additions & 15 deletions engine/baml-lib/baml-core/src/ir/repr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::collections::HashSet;
use anyhow::{anyhow, Result};
use baml_types::{Constraint, ConstraintLevel, FieldType};
use either::Either;
use indexmap::IndexMap;
use indexmap::{IndexMap, IndexSet};
use internal_baml_parser_database::{
walkers::{
ClassWalker, ClientSpec as AstClientSpec, ClientWalker, ConfigurationWalker,
Expand All @@ -27,6 +27,8 @@ use crate::Configuration;
pub struct IntermediateRepr {
enums: Vec<Node<Enum>>,
classes: Vec<Node<Class>>,
/// Strongly connected components of the dependency graph (finite cycles).
finite_recursive_cycles: Vec<IndexSet<String>>,
functions: Vec<Node<Function>>,
clients: Vec<Node<Client>>,
retry_policies: Vec<Node<RetryPolicy>>,
Expand All @@ -50,6 +52,7 @@ impl IntermediateRepr {
IntermediateRepr {
enums: vec![],
classes: vec![],
finite_recursive_cycles: vec![],
functions: vec![],
clients: vec![],
retry_policies: vec![],
Expand All @@ -72,6 +75,14 @@ impl IntermediateRepr {
.collect::<HashSet<&str>>()
}

/// Returns a list of all the recursive cycles in the IR.
///
/// Each cycle is represented as a set of strings, where each string is the
/// name of a class.
pub fn finite_recursive_cycles(&self) -> &[IndexSet<String>] {
&self.finite_recursive_cycles
}

pub fn walk_enums<'a>(&'a self) -> impl ExactSizeIterator<Item = Walker<'a, &'a Node<Enum>>> {
self.enums.iter().map(|e| Walker { db: self, item: e })
}
Expand Down Expand Up @@ -139,6 +150,15 @@ impl IntermediateRepr {
.walk_classes()
.map(|e| e.node(db))
.collect::<Result<Vec<_>>>()?,
finite_recursive_cycles: db
.finite_recursive_cycles()
.iter()
.map(|ids| {
ids.iter()
.map(|id| db.ast()[*id].name().to_string())
.collect()
})
.collect(),
functions: db
.walk_functions()
.map(|e| e.node(db))
Expand Down Expand Up @@ -312,7 +332,6 @@ fn type_with_arity(t: FieldType, arity: &FieldArity) -> FieldType {
}

impl WithRepr<FieldType> for ast::FieldType {

// TODO: (Greg) This code only extracts constraints, and ignores any
// other types of attributes attached to the type directly.
fn attributes(&self, _db: &ParserDatabase) -> NodeAttributes {
Expand All @@ -323,20 +342,29 @@ impl WithRepr<FieldType> for ast::FieldType {
let level = match attr.name.to_string().as_str() {
"assert" => Some(ConstraintLevel::Assert),
"check" => Some(ConstraintLevel::Check),
_ => None
_ => None,
}?;
let (label, expression) = match attr.arguments.arguments.as_slice() {
[arg1, arg2] => match (arg1.clone().value, arg2.clone().value) {
(ast::Expression::Identifier(ast::Identifier::Local(s, _)), ast::Expression::JinjaExpressionValue(j,_)) => Some((Some(s), j)),
_ => None
(
ast::Expression::Identifier(ast::Identifier::Local(s, _)),
ast::Expression::JinjaExpressionValue(j, _),
) => Some((Some(s), j)),
_ => None,
},
[arg1] => match arg1.clone().value {
ast::Expression::JinjaExpressionValue(JinjaExpression(j),_) => Some((None, JinjaExpression(j.clone()))),
_ => None
}
ast::Expression::JinjaExpressionValue(JinjaExpression(j), _) => {
Some((None, JinjaExpression(j.clone())))
}
_ => None,
},
_ => None,
}?;
Some(Constraint{ level, expression, label })
Some(Constraint {
level,
expression,
label,
})
})
.collect::<Vec<Constraint>>();
let attributes = NodeAttributes {
Expand Down Expand Up @@ -438,7 +466,10 @@ impl WithRepr<FieldType> for ast::FieldType {
};

let with_constraints = if has_constraints {
FieldType::Constrained { base: Box::new(base.clone()), constraints }
FieldType::Constrained {
base: Box::new(base.clone()),
constraints,
}
} else {
base
};
Expand Down Expand Up @@ -1128,18 +1159,24 @@ impl WithRepr<Prompt> for PromptAst<'_> {
/// Generate an IntermediateRepr from a single block of BAML source code.
/// This is useful for generating IR test fixtures.
pub fn make_test_ir(source_code: &str) -> anyhow::Result<IntermediateRepr> {
use std::path::PathBuf;
use internal_baml_diagnostics::SourceFile;
use crate::ValidatedSchema;
use crate::validate;
use crate::ValidatedSchema;
use internal_baml_diagnostics::SourceFile;
use std::path::PathBuf;

let path: PathBuf = "fake_file.baml".into();
let source_file: SourceFile = (path.clone(), source_code).into();
let validated_schema: ValidatedSchema = validate(&path, vec![source_file]);
let diagnostics = &validated_schema.diagnostics;
if diagnostics.has_errors() {
return Err(anyhow::anyhow!("Source code was invalid: \n{:?}", diagnostics.errors()))
return Err(anyhow::anyhow!(
"Source code was invalid: \n{:?}",
diagnostics.errors()
));
}
let ir = IntermediateRepr::from_parser_database(&validated_schema.db, validated_schema.configuration)?;
let ir = IntermediateRepr::from_parser_database(
&validated_schema.db,
validated_schema.configuration,
)?;
Ok(ir)
}
Original file line number Diff line number Diff line change
@@ -1,99 +1,114 @@
use std::collections::HashSet;
use std::collections::{HashMap, HashSet};

use either::Either;
use internal_baml_diagnostics::DatamodelError;
use internal_baml_schema_ast::ast::{TypeExpId, WithIdentifier, WithName, WithSpan};
use internal_baml_parser_database::Tarjan;
use internal_baml_schema_ast::ast::{FieldType, TypeExpId, WithName, WithSpan};

use crate::validate::validation_pipeline::context::Context;

/// Validates if the dependency graph contains one or more infinite cycles.
pub(super) fn validate(ctx: &mut Context<'_>) {
// Validates if there's a cycle in any dependency graph.
let mut deps_list = ctx
.db
.walk_classes()
.map(|f| {
(
f.id,
f.dependencies()
.into_iter()
.filter(|f| match ctx.db.find_type_by_str(f) {
Some(either::Either::Left(_cls)) => true,
// Don't worry about enum dependencies, they can't form cycles.
Some(either::Either::Right(_enm)) => false,
None => {
panic!("Unknown class `{}`", f);
}
})
.collect::<HashSet<_>>(),
)
})
.collect::<Vec<_>>();

// Now we can check for cycles using topological sort.
let mut stack: Vec<(TypeExpId, Vec<TypeExpId>)> = Vec::new(); // This stack now also keeps track of the path
let mut visited = HashSet::new();
let mut in_stack = HashSet::new();

// Find all items with 0 dependencies
for (id, deps) in &deps_list {
if deps.is_empty() {
stack.push((*id, vec![*id]));
// First, build a graph of all the "required" dependencies represented as an
// adjacency list. We're only going to consider type dependencies that can
// actually cause infinite recursion. Unions and optionals can stop the
// recursion at any point, so they don't have to be part of the "dependency"
// graph because technically an optional field doesn't "depend" on anything,
// it can just be null.
let dependency_graph = HashMap::from_iter(ctx.db.walk_classes().map(|class| {
let expr_block = &ctx.db.ast()[class.id];

// TODO: There's already a hash set that returns "dependencies" in
the DB, it shouldn't be necessary to traverse all the fields here
// again and build yet another graph, we need to refactor
// .dependencies() or add a new method that returns not only the
// dependency name but also field arity. The arity could be computed at
// the same time as the dependencies hash set. Code is here:
//
// baml-lib/parser-database/src/types/mod.rs
// fn visit_class()
let mut dependencies = HashSet::new();

for field in &expr_block.fields {
if let Some(field_type) = &field.expr {
insert_required_deps(class.id, field_type, ctx, &mut dependencies);
}
}

(class.id, dependencies)
}));

for component in Tarjan::components(&dependency_graph) {
let cycle = component
.iter()
.map(|id| ctx.db.ast()[*id].name().to_string())
.collect::<Vec<_>>()
.join(" -> ");

// TODO: We can push an error for every single class here (that's what
// Rust does), for now it's an error for every cycle found.
ctx.push_error(DatamodelError::new_validation_error(
&format!("These classes form a dependency cycle: {}", cycle),
ctx.db.ast()[component[0]].span().clone(),
));
}
}

while let Some((current, path)) = stack.pop() {
let name = ctx.db.ast()[current].name().to_string();
let span = ctx.db.ast()[current].span();

if in_stack.contains(&current) {
let cycle_start_index = match path.iter().position(|&x| x == current) {
Some(index) => index,
None => {
ctx.push_error(DatamodelError::new_validation_error(
"Cycle start index not found in the path.",
span.clone(),
));
return;
}
};
let cycle = path[cycle_start_index..]
.iter()
.map(|&x| ctx.db.ast()[x].name())
.collect::<Vec<_>>()
.join(" -> ");
ctx.push_error(DatamodelError::new_validation_error(
&format!("These classes form a dependency cycle: {}", cycle),
span.clone(),
));
return;
/// Inserts all the required dependencies of a field into the given set.
///
/// Recursively deals with unions of unions. Can be implemented iteratively with
/// a while loop and a stack/queue if this ends up being slow / inefficient or
/// it reaches stack overflows with large inputs.
fn insert_required_deps(
id: TypeExpId,
field: &FieldType,
ctx: &Context<'_>,
deps: &mut HashSet<TypeExpId>,
) {
match field {
FieldType::Symbol(arity, ident, _) if arity.is_required() => {
if let Some(Either::Left(class)) = ctx.db.find_type_by_str(ident.name()) {
deps.insert(class.id);
}
}

in_stack.insert(current);
visited.insert(current);
FieldType::Union(arity, field_types, _, _) if arity.is_required() => {
// All the dependencies of the union.
let mut union_deps = HashSet::new();

// All the dependencies of a single field in the union. This is
// reused on every iteration of the loop below to avoid allocating
// a new hash set every time.
let mut nested_deps = HashSet::new();

for f in field_types {
insert_required_deps(id, f, ctx, &mut nested_deps);

deps_list.iter_mut().for_each(|(id, deps)| {
if deps.remove(&name) {
// If this item has now 0 dependencies, add it to the stack
if deps.is_empty() {
let mut new_path = path.clone();
new_path.push(*id);
stack.push((*id, new_path));
// No nested deps found on this component, this makes the
// union finite, so no need to go deeper.
if nested_deps.is_empty() {
return;
}
}
});

in_stack.remove(&current);
}
// Add the nested deps to the overall union deps and clear the
// iteration hash set.
union_deps.extend(nested_deps.drain());
}

// If there are still items left in deps_list after the above steps, there's a cycle
if visited.len() != deps_list.len() {
for (id, _) in &deps_list {
if !visited.contains(id) {
let cls = &ctx.db.ast()[*id];
ctx.push_error(DatamodelError::new_validation_error(
&format!("These classes form a dependency cycle: {}", cls.name()),
cls.identifier().span().clone(),
));
// A union does not depend on itself if the field can take other
// values. However, if it only depends on itself, it means we have
// something like this:
//
// class Example {
// field: Example | Example | Example
// }
if union_deps.len() > 1 {
union_deps.remove(&id);
}

deps.extend(union_deps);
}

_ => {}
}
}
Loading

0 comments on commit 8100df9

Please sign in to comment.