Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow structural recursion in type aliases #1207

Merged
merged 26 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
e38e5e7
Allow structural recursion
antoniosarosi Dec 2, 2024
794b3f4
Pass structural cycles to IR
antoniosarosi Dec 2, 2024
9869707
Test structural recursion finder
antoniosarosi Dec 2, 2024
81f776a
Merge branch 'antonio/type-aliases' into antonio/type-aliases-with-cy…
antoniosarosi Dec 4, 2024
3582b18
Merge branch 'antonio/type-aliases' into antonio/type-aliases-with-cy…
antoniosarosi Dec 4, 2024
fab92f5
Merge `antonio/type-aliases`
antoniosarosi Dec 5, 2024
5093eb9
Merge branch 'antonio/type-aliases' into antonio/type-aliases-with-cy…
antoniosarosi Dec 5, 2024
c09997a
Merge `antonio/type-aliases`
antoniosarosi Dec 9, 2024
cd5e1f8
Merge branch 'antonio/type-aliases' into antonio/type-aliases-with-cy…
antoniosarosi Dec 11, 2024
ba8177a
Merge branch 'antonio/type-aliases' into antonio/type-aliases-with-cy…
antoniosarosi Dec 13, 2024
72ea8cb
Implement codegen for Python type aliases
antoniosarosi Dec 16, 2024
140b3dd
Integ test works! Yeah
antoniosarosi Dec 17, 2024
68b98b7
Fix structural cycles rendering
antoniosarosi Dec 17, 2024
d462e5c
Coerce is wonky
antoniosarosi Dec 17, 2024
e0ae448
Fix test `relevant_data_models`
antoniosarosi Dec 17, 2024
abb7430
`is_subtype_of` causing issues with aliases
antoniosarosi Dec 18, 2024
c4e8b85
Fixed `subtype`, `coerce` still doesn't work
antoniosarosi Dec 18, 2024
d6b1e9e
Add integ tests for TS
antoniosarosi Dec 18, 2024
c5267b5
Remove recursion debug limit
antoniosarosi Dec 18, 2024
cac1a16
Add more tests (doesn't work because of score function)
antoniosarosi Dec 18, 2024
39141cb
Add codegen for TS
antoniosarosi Dec 18, 2024
401a97d
Add docs for Ruby type alias
antoniosarosi Dec 18, 2024
fc25050
Fix OpenAPI map keys
antoniosarosi Dec 18, 2024
342fb5e
Fix score of `JsonToString` flag
antoniosarosi Dec 18, 2024
2e10579
Fix integ tests for json type cycle
antoniosarosi Dec 18, 2024
8ff0397
Fix scoring ranking whaterver
antoniosarosi Dec 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions engine/baml-lib/baml-core/src/ir/repr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ pub struct IntermediateRepr {
/// Strongly connected components of the dependency graph (finite cycles).
finite_recursive_cycles: Vec<IndexSet<String>>,

/// Type alias cycles introduced by lists and maps.
structural_recursive_alias_cycles: Vec<IndexSet<String>>,

configuration: Configuration,
}

Expand All @@ -53,6 +56,7 @@ impl IntermediateRepr {
enums: vec![],
classes: vec![],
finite_recursive_cycles: vec![],
structural_recursive_alias_cycles: vec![],
functions: vec![],
clients: vec![],
retry_policies: vec![],
Expand Down Expand Up @@ -174,6 +178,15 @@ impl IntermediateRepr {
.collect()
})
.collect(),
structural_recursive_alias_cycles: db
.structural_recursive_alias_cycles()
.iter()
.map(|ids| {
ids.iter()
.map(|id| db.ast()[*id].name().to_string())
.collect()
})
.collect(),
functions: db
.walk_functions()
.map(|e| e.node(db))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,40 @@ use crate::validate::validation_pipeline::context::Context;

/// Validates if the dependency graph contains one or more infinite cycles.
pub(super) fn validate(ctx: &mut Context<'_>) {
// Solve cycles first. We need that information in case a class points to
// an unresolveble type alias.
let alias_cycles = report_infinite_cycles(
&ctx.db.type_alias_dependencies(),
// We'll check type alias cycles first. Just like Typescript, cycles are
// allowed only for maps and lists. We'll call such cycles "structural
// recursion". Anything else like nulls or unions won't terminate a cycle.
let structural_type_aliases = HashMap::from_iter(ctx.db.walk_type_aliases().map(|alias| {
let mut dependencies = HashSet::new();
insert_required_alias_deps(alias.target(), ctx, &mut dependencies);

(alias.id, dependencies)
}));

// Based on the graph we've just built, which does not include the edges
// created by maps and lists, check the cycles and report them.
report_infinite_cycles(
&structural_type_aliases,
ctx,
"These aliases form a dependency cycle",
);

// First, build a graph of all the "required" dependencies represented as an
// In order to avoid infinite recursion when resolving types for class
// dependencies below, we'll compute the cycles of aliases including maps
// and lists so that the recursion can be stopped before entering a cycle.
let complete_alias_cycles = Tarjan::components(ctx.db.type_alias_dependencies())
.iter()
.flatten()
.copied()
.collect();

// Now build a graph of all the "required" dependencies represented as an
// adjacency list. We're only going to consider type dependencies that can
// actually cause infinite recursion. Unions and optionals can stop the
// recursion at any point, so they don't have to be part of the "dependency"
// graph because technically an optional field doesn't "depend" on anything,
// it can just be null.
let dependency_graph = HashMap::from_iter(ctx.db.walk_classes().map(|class| {
let class_dependency_graph = HashMap::from_iter(ctx.db.walk_classes().map(|class| {
let expr_block = &ctx.db.ast()[class.id];

// TODO: There's already a hash set that returns "dependencies" in
Expand All @@ -44,12 +63,12 @@ pub(super) fn validate(ctx: &mut Context<'_>) {

for field in &expr_block.fields {
if let Some(field_type) = &field.expr {
insert_required_deps(
insert_required_class_deps(
class.id,
field_type,
ctx,
&mut dependencies,
&alias_cycles.iter().flatten().copied().collect(),
&complete_alias_cycles,
);
}
}
Expand All @@ -58,7 +77,7 @@ pub(super) fn validate(ctx: &mut Context<'_>) {
}));

report_infinite_cycles(
&dependency_graph,
&class_dependency_graph,
ctx,
"These classes form a dependency cycle",
);
Expand Down Expand Up @@ -106,7 +125,7 @@ where
/// it reaches stack overflows with large inputs.
///
/// TODO: Use a struct to keep all this state. Too many parameters already.
fn insert_required_deps(
fn insert_required_class_deps(
id: TypeExpId,
field: &FieldType,
ctx: &Context<'_>,
Expand Down Expand Up @@ -142,7 +161,7 @@ fn insert_required_deps(
// We also have to stop recursion if we know the alias is
// part of a cycle.
if !alias_cycles.contains(&alias.id) {
insert_required_deps(id, alias.target(), ctx, deps, alias_cycles)
insert_required_class_deps(id, alias.target(), ctx, deps, alias_cycles)
}
}
_ => {}
Expand All @@ -159,7 +178,7 @@ fn insert_required_deps(
let mut nested_deps = HashSet::new();

for f in field_types {
insert_required_deps(id, f, ctx, &mut nested_deps, alias_cycles);
insert_required_class_deps(id, f, ctx, &mut nested_deps, alias_cycles);

// No nested deps found on this component, this makes the
// union finite, so no need to go deeper.
Expand Down Expand Up @@ -189,3 +208,26 @@ fn insert_required_deps(
_ => {}
}
}

/// Records the type aliases that `field_type` depends on without any
/// indirection.
///
/// Following TypeScript's approach, maps and lists do not contribute edges:
/// only symbols that resolve to a type alias are inserted into `required`,
/// and only unions and tuples are searched recursively. Every other kind of
/// field type is ignored.
fn insert_required_alias_deps(
    field_type: &FieldType,
    ctx: &Context<'_>,
    required: &mut HashSet<TypeAliasId>,
) {
    match field_type {
        // Each member of a union or tuple may itself name an alias.
        FieldType::Union(_, members, ..) | FieldType::Tuple(_, members, ..) => members
            .iter()
            .for_each(|member| insert_required_alias_deps(member, ctx, required)),

        FieldType::Symbol(_, ident, _) => {
            // Symbols that resolve to anything other than a type alias
            // (or to nothing at all) are not dependencies of interest here.
            let resolved = ctx.db.find_type_by_str(ident.name());
            if let Some(TypeWalker::TypeAlias(alias)) = resolved {
                required.insert(alias.id);
            }
        }

        _ => {}
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,6 @@ type Map = map<string, Map>
// 50 |
// 51 | type EnterCycle = NoStop
// |
// error: Error validating: These aliases form a dependency cycle: Map
// --> class/recursive_type_aliases.baml:56
// |
// 55 | // RecursiveMap
// 56 | type Map = map<string, Map>
// |
// error: Error validating: These classes form a dependency cycle: Recursive
// --> class/recursive_type_aliases.baml:22
// |
Expand Down
62 changes: 41 additions & 21 deletions engine/baml-lib/parser-database/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,11 @@ impl ParserDatabase {
self.types.resolved_type_aliases.insert(*alias_id, resolved);
}

// Cycles left here after cycle validation are allowed. Basically lists
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment here is misleading. It states that cycles left after validation are allowed, but the code actually rebuilds the cycles for rendering purposes. Consider updating the comment to reflect the actual behavior.

// and maps can introduce cycles.
self.types.structural_recursive_alias_cycles =
Tarjan::components(&self.types.type_alias_dependencies);

// NOTE: Class dependency cycles are already checked at
// baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs
//
Expand All @@ -159,27 +164,42 @@ impl ParserDatabase {
// instead of strings (class names). That requires less conversions when
// working with the graph. Once the work is done, IDs can be converted
// to names where needed.
let finite_cycles = Tarjan::components(&HashMap::from_iter(
self.types.class_dependencies.iter().map(|(id, deps)| {
let deps =
HashSet::from_iter(deps.iter().filter_map(
|dep| match self.find_type_by_str(dep) {
Some(TypeWalker::Class(cls)) => Some(cls.id),
Some(TypeWalker::Enum(_)) => None,
Some(TypeWalker::TypeAlias(_)) => None,
None => panic!("Unknown class `{dep}`"),
},
));
(*id, deps)
}),
));
let mut resolved_dependency_graph = HashMap::new();

for (id, deps) in self.types.class_dependencies.iter() {
let mut resolved_deps = HashSet::new();

for dep in deps {
match self.find_type_by_str(dep) {
Some(TypeWalker::Class(cls)) => {
resolved_deps.insert(cls.id);
}
Some(TypeWalker::Enum(_)) => {}
// Gotta resolve type aliases.
Some(TypeWalker::TypeAlias(alias)) => {
resolved_deps.extend(alias.resolved().flat_idns().iter().map(|ident| {
match self.find_type_by_str(ident.name()) {
Some(TypeWalker::Class(cls)) => cls.id,
Some(TypeWalker::Enum(_)) => {
panic!("Enums are not allowed in type aliases")
}
Some(TypeWalker::TypeAlias(alias)) => {
panic!("Alias should be resolved at this point")
}
None => panic!("Unknown class `{dep}`"),
}
}))
}
None => panic!("Unknown class `{dep}`"),
}
}

resolved_dependency_graph.insert(*id, resolved_deps);
}

// Inject finite cycles into parser DB. This will then be passed into
// the IR and then into the Jinja output format.
self.types.finite_recursive_cycles = finite_cycles
.into_iter()
.map(|cycle| cycle.into_iter().collect())
.collect();
// Find the cycles and inject them into parser DB. This will then be
// passed into the IR and then into the Jinja output format.
self.types.finite_recursive_cycles = Tarjan::components(&resolved_dependency_graph);

// Fully resolve function dependencies.
let extends = self
Expand Down Expand Up @@ -308,7 +328,7 @@ mod test {
}

fn assert_finite_cycles(baml: &'static str, expected: &[&[&str]]) -> Result<(), Diagnostics> {
let mut db = parse(baml)?;
let db = parse(baml)?;

assert_eq!(
db.finite_recursive_cycles()
Expand Down
7 changes: 7 additions & 0 deletions engine/baml-lib/parser-database/src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,13 @@ pub(super) struct Types {
/// Merge-Find Set or something like that.
pub(super) finite_recursive_cycles: Vec<Vec<ast::TypeExpId>>,

/// Contains recursive type aliases.
///
/// Recursive type aliases are a little bit trickier than recursive classes
/// because the termination condition is tied to lists and maps only. Nulls
/// and unions won't allow type alias cycles to be resolved.
pub(super) structural_recursive_alias_cycles: Vec<Vec<ast::TypeAliasId>>,

pub(super) function: HashMap<ast::ValExpId, FunctionType>,

pub(super) client_properties: HashMap<ast::ValExpId, ClientProperties>,
Expand Down
23 changes: 22 additions & 1 deletion engine/baml-lib/parser-database/src/walkers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ pub use configuration::*;
use either::Either;
pub use field::*;
pub use function::FunctionWalker;
use internal_baml_schema_ast::ast::{FieldType, Identifier, TopId, TypeExpId, WithName};
use internal_baml_schema_ast::ast::{
FieldType, Identifier, TopId, TypeAliasId, TypeExpId, WithName,
};
pub use r#class::*;
pub use r#enum::*;
pub use template_string::TemplateStringWalker;
Expand Down Expand Up @@ -142,6 +144,14 @@ impl<'db> crate::ParserDatabase {
&self.types.finite_recursive_cycles
}

/// Type alias cycles that go through a map or a list.
///
/// Maps and lists introduce one level of indirection, which is what makes
/// such a cycle "structural". Each inner `Vec` lists the alias IDs of one
/// cycle.
pub fn structural_recursive_alias_cycles(&self) -> &[Vec<TypeAliasId>] {
    self.types.structural_recursive_alias_cycles.as_slice()
}

/// Returns the resolved aliases map.
pub fn resolved_type_alias_by_name(&self, alias: &str) -> Option<&FieldType> {
match self.find_type_by_str(alias) {
Expand Down Expand Up @@ -209,6 +219,17 @@ impl<'db> crate::ParserDatabase {
})
}

/// Iterates over every type alias declared in the AST.
///
/// Top-level items that are not type aliases are skipped; each remaining
/// alias ID is wrapped in a walker that borrows this database.
pub fn walk_type_aliases(&self) -> impl Iterator<Item = TypeAliasWalker<'_>> {
    self.ast().iter_tops().filter_map(move |(top_id, _)| {
        top_id
            .as_type_alias_id()
            .map(move |id| Walker { db: self, id })
    })
}

/// Walk all template strings in the schema.
pub fn walk_templates(&self) -> impl Iterator<Item = TemplateStringWalker<'_>> {
self.ast()
Expand Down
8 changes: 8 additions & 0 deletions engine/baml-lib/schema-ast/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,14 @@ impl TopId {
}
}

/// Returns the inner [`TypeAliasId`] when this top is a type alias, and
/// `None` for every other kind of top-level item.
pub fn as_type_alias_id(self) -> Option<TypeAliasId> {
    if let TopId::TypeAlias(id) = self {
        Some(id)
    } else {
        None
    }
}

/// Try to interpret the top as a function.
pub fn as_function_id(self) -> Option<ValExpId> {
match self {
Expand Down
Loading