-
Notifications
You must be signed in to change notification settings - Fork 61
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Sarosi's Recursive Cycles Detection Algorithm 💀💀💀 (#1065)
Progress so far: - Rewrite dependency cycle algorithm (detect all cycles in the graph and report all errors, not just one). - Allow recursive types when they don't enter infinite recursion. - Find all finite recursive cycles and pipeline them through DB -> IR -> Jinja. - Hoist recursive classes in prompt rendering. - Update jsonish parser / coercer to stop recursion when dealing with recursive types. <!-- ELLIPSIS_HIDDEN --> ---- > [!IMPORTANT] > Introduces a new algorithm for detecting recursive cycles, allowing finite recursive types, and updates the JSON parser and tests to support these changes. > > - **Behavior**: > - Rewrites dependency cycle detection to find all cycles and report all errors in `cycle.rs`. > - Allows finite recursive types and hoists recursive classes in prompt rendering in `repr.rs`. > - Updates JSON parser to handle recursion in `coerce_class.rs`. > - **Models**: > - Adds `finite_recursive_cycles` to `IntermediateRepr` in `repr.rs`. > - Updates `OutputFormatContent` in `types.rs` to support recursive classes. > - **Tests**: > - Adds tests for recursive types in `test_class.rs` and `integ-tests.test.ts`. > - Updates integration tests to cover new recursive behavior in `integ-tests.test.ts`. > - **Misc**: > - Updates `Cargo.lock` to use `indexmap 2.2.6`. > > <sup>This description was created by </sup>[<img alt="Ellipsis" src="https://img.shields.io/badge/Ellipsis-blue?color=175173">](https://www.ellipsis.dev?ref=BoundaryML%2Fbaml&utm_source=github&utm_medium=referral)<sup> for a10e003. It will automatically update as commits are pushed.</sup> <!-- ELLIPSIS_HIDDEN -->
- Loading branch information
1 parent
8c43e37
commit 8100df9
Showing
48 changed files
with
4,613 additions
and
898 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
175 changes: 95 additions & 80 deletions
175
engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,99 +1,114 @@ | ||
use std::collections::HashSet; | ||
use std::collections::{HashMap, HashSet}; | ||
|
||
use either::Either; | ||
use internal_baml_diagnostics::DatamodelError; | ||
use internal_baml_schema_ast::ast::{TypeExpId, WithIdentifier, WithName, WithSpan}; | ||
use internal_baml_parser_database::Tarjan; | ||
use internal_baml_schema_ast::ast::{FieldType, TypeExpId, WithName, WithSpan}; | ||
|
||
use crate::validate::validation_pipeline::context::Context; | ||
|
||
/// Validates if the dependency graph contains one or more infinite cycles. | ||
pub(super) fn validate(ctx: &mut Context<'_>) { | ||
// Validates if there's a cycle in any dependency graph. | ||
let mut deps_list = ctx | ||
.db | ||
.walk_classes() | ||
.map(|f| { | ||
( | ||
f.id, | ||
f.dependencies() | ||
.into_iter() | ||
.filter(|f| match ctx.db.find_type_by_str(f) { | ||
Some(either::Either::Left(_cls)) => true, | ||
// Don't worry about enum dependencies, they can't form cycles. | ||
Some(either::Either::Right(_enm)) => false, | ||
None => { | ||
panic!("Unknown class `{}`", f); | ||
} | ||
}) | ||
.collect::<HashSet<_>>(), | ||
) | ||
}) | ||
.collect::<Vec<_>>(); | ||
|
||
// Now we can check for cycles using topological sort. | ||
let mut stack: Vec<(TypeExpId, Vec<TypeExpId>)> = Vec::new(); // This stack now also keeps track of the path | ||
let mut visited = HashSet::new(); | ||
let mut in_stack = HashSet::new(); | ||
|
||
// Find all items with 0 dependencies | ||
for (id, deps) in &deps_list { | ||
if deps.is_empty() { | ||
stack.push((*id, vec![*id])); | ||
// First, build a graph of all the "required" dependencies represented as an | ||
// adjacency list. We're only going to consider type dependencies that can | ||
// actually cause infinite recursion. Unions and optionals can stop the | ||
// recursion at any point, so they don't have to be part of the "dependency" | ||
// graph because technically an optional field doesn't "depend" on anything, | ||
// it can just be null. | ||
let dependency_graph = HashMap::from_iter(ctx.db.walk_classes().map(|class| { | ||
let expr_block = &ctx.db.ast()[class.id]; | ||
|
||
// TODO: There's already a hash set that returns "dependencies" in | ||
// the DB, it shouldn't be necessary to traverse all the fields here | |
// again and build yet another graph, we need to refactor | ||
// .dependencies() or add a new method that returns not only the | ||
// dependency name but also field arity. The arity could be computed at | ||
// the same time as the dependencies hash set. Code is here: | ||
// | ||
// baml-lib/parser-database/src/types/mod.rs | ||
// fn visit_class() | ||
let mut dependencies = HashSet::new(); | ||
|
||
for field in &expr_block.fields { | ||
if let Some(field_type) = &field.expr { | ||
insert_required_deps(class.id, field_type, ctx, &mut dependencies); | ||
} | ||
} | ||
|
||
(class.id, dependencies) | ||
})); | ||
|
||
for component in Tarjan::components(&dependency_graph) { | ||
let cycle = component | ||
.iter() | ||
.map(|id| ctx.db.ast()[*id].name().to_string()) | ||
.collect::<Vec<_>>() | ||
.join(" -> "); | ||
|
||
// TODO: We can push an error for every single class here (that's what | |
// Rust does), for now it's an error for every cycle found. | ||
ctx.push_error(DatamodelError::new_validation_error( | ||
&format!("These classes form a dependency cycle: {}", cycle), | ||
ctx.db.ast()[component[0]].span().clone(), | ||
)); | ||
} | ||
} | ||
|
||
while let Some((current, path)) = stack.pop() { | ||
let name = ctx.db.ast()[current].name().to_string(); | ||
let span = ctx.db.ast()[current].span(); | ||
|
||
if in_stack.contains(&current) { | |
let cycle_start_index = match path.iter().position(|&x| x == current) { | ||
Some(index) => index, | ||
None => { | ||
ctx.push_error(DatamodelError::new_validation_error( | ||
"Cycle start index not found in the path.", | ||
span.clone(), | ||
)); | ||
return; | ||
} | ||
}; | ||
let cycle = path[cycle_start_index..] | ||
.iter() | ||
.map(|&x| ctx.db.ast()[x].name()) | ||
.collect::<Vec<_>>() | ||
.join(" -> "); | ||
ctx.push_error(DatamodelError::new_validation_error( | ||
&format!("These classes form a dependency cycle: {}", cycle), | ||
span.clone(), | ||
)); | ||
return; | ||
/// Inserts all the required dependencies of a field into the given set. | ||
/// | ||
/// Recursively deals with unions of unions. Can be implemented iteratively with | ||
/// a while loop and a stack/queue if this ends up being slow / inefficient or | ||
/// it reaches stack overflows with large inputs. | ||
fn insert_required_deps( | ||
id: TypeExpId, | ||
field: &FieldType, | ||
ctx: &Context<'_>, | ||
deps: &mut HashSet<TypeExpId>, | ||
) { | ||
match field { | ||
FieldType::Symbol(arity, ident, _) if arity.is_required() => { | ||
if let Some(Either::Left(class)) = ctx.db.find_type_by_str(ident.name()) { | ||
deps.insert(class.id); | ||
} | ||
} | ||
|
||
in_stack.insert(current); | ||
visited.insert(current); | ||
FieldType::Union(arity, field_types, _, _) if arity.is_required() => { | ||
// All the dependencies of the union. | ||
let mut union_deps = HashSet::new(); | ||
|
||
// All the dependencies of a single field in the union. This is | ||
// reused on every iteration of the loop below to avoid allocating | ||
// a new hash set every time. | ||
let mut nested_deps = HashSet::new(); | ||
|
||
for f in field_types { | ||
insert_required_deps(id, f, ctx, &mut nested_deps); | ||
|
||
deps_list.iter_mut().for_each(|(id, deps)| { | ||
if deps.remove(&name) { | ||
// If this item has now 0 dependencies, add it to the stack | ||
if deps.is_empty() { | ||
let mut new_path = path.clone(); | ||
new_path.push(*id); | ||
stack.push((*id, new_path)); | ||
// No nested deps found on this component, this makes the | ||
// union finite, so no need to go deeper. | ||
if nested_deps.is_empty() { | ||
return; | ||
} | ||
} | ||
}); | ||
|
||
in_stack.remove(&current); | |
} | ||
// Add the nested deps to the overall union deps and clear the | ||
// iteration hash set. | ||
union_deps.extend(nested_deps.drain()); | ||
} | ||
|
||
// If there are still items left in deps_list after the above steps, there's a cycle | ||
if visited.len() != deps_list.len() { | ||
for (id, _) in &deps_list { | ||
if !visited.contains(id) { | ||
let cls = &ctx.db.ast()[*id]; | ||
ctx.push_error(DatamodelError::new_validation_error( | ||
&format!("These classes form a dependency cycle: {}", cls.name()), | ||
cls.identifier().span().clone(), | ||
)); | ||
// A union does not depend on itself if the field can take other | ||
// values. However, if it only depends on itself, it means we have | ||
// something like this: | ||
// | ||
// class Example { | ||
// field: Example | Example | Example | ||
// } | ||
if union_deps.len() > 1 { | ||
union_deps.remove(&id); | ||
} | ||
|
||
deps.extend(union_deps); | ||
} | ||
|
||
_ => {} | ||
} | ||
} |
Oops, something went wrong.