Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow structural recursion in type aliases #1207

Merged
merged 26 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
e38e5e7
Allow structural recursion
antoniosarosi Dec 2, 2024
794b3f4
Pass structural cycles to IR
antoniosarosi Dec 2, 2024
9869707
Test structural recursion finder
antoniosarosi Dec 2, 2024
81f776a
Merge branch 'antonio/type-aliases' into antonio/type-aliases-with-cy…
antoniosarosi Dec 4, 2024
3582b18
Merge branch 'antonio/type-aliases' into antonio/type-aliases-with-cy…
antoniosarosi Dec 4, 2024
fab92f5
Merge `antonio/type-aliases`
antoniosarosi Dec 5, 2024
5093eb9
Merge branch 'antonio/type-aliases' into antonio/type-aliases-with-cy…
antoniosarosi Dec 5, 2024
c09997a
Merge `antonio/type-aliases`
antoniosarosi Dec 9, 2024
cd5e1f8
Merge branch 'antonio/type-aliases' into antonio/type-aliases-with-cy…
antoniosarosi Dec 11, 2024
ba8177a
Merge branch 'antonio/type-aliases' into antonio/type-aliases-with-cy…
antoniosarosi Dec 13, 2024
72ea8cb
Implement codegen for Python type aliases
antoniosarosi Dec 16, 2024
140b3dd
Integ test works! Yeah
antoniosarosi Dec 17, 2024
68b98b7
Fix structural cycles rendering
antoniosarosi Dec 17, 2024
d462e5c
Coerce is wonky
antoniosarosi Dec 17, 2024
e0ae448
Fix test `relevant_data_models`
antoniosarosi Dec 17, 2024
abb7430
`is_subtype_of` causing issues with aliases
antoniosarosi Dec 18, 2024
c4e8b85
Fixed `subtype`, `coerce` still doesn't work
antoniosarosi Dec 18, 2024
d6b1e9e
Add integ tests for TS
antoniosarosi Dec 18, 2024
c5267b5
Remove recursion debug limit
antoniosarosi Dec 18, 2024
cac1a16
Add more tests (doesn't work because of score function)
antoniosarosi Dec 18, 2024
39141cb
Add codegen for TS
antoniosarosi Dec 18, 2024
401a97d
Add docs for Ruby type alias
antoniosarosi Dec 18, 2024
fc25050
Fix OpenAPI map keys
antoniosarosi Dec 18, 2024
342fb5e
Fix score of `JsonToString` flag
antoniosarosi Dec 18, 2024
2e10579
Fix integ tests for json type cycle
antoniosarosi Dec 18, 2024
8ff0397
Fix scoring ranking whaterver
antoniosarosi Dec 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,24 @@ impl ArgCoercer {
Err(())
}
},
(FieldType::Alias { resolution, .. }, _) => {
self.coerce_arg(ir, &resolution, value, scope)
// TODO: Is this even possible?
(FieldType::RecursiveTypeAlias(name), _) => {
let mut maybe_coerced = None;
// TODO: Fix this O(n)
for cycle in ir.structural_recursive_alias_cycles().iter() {
if let Some(target) = cycle.get(name) {
maybe_coerced = Some(self.coerce_arg(ir, target, value, scope)?);
break;
}
}

match maybe_coerced {
Some(coerced) => Ok(coerced),
None => {
scope.push_error(format!("Recursive type alias {} not found", name));
Err(())
}
}
}
(FieldType::List(item), _) => match value {
BamlValue::List(arr) => {
Expand Down
4 changes: 3 additions & 1 deletion engine/baml-lib/baml-core/src/ir/json_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,12 @@ impl WithJsonSchema for FieldType {
FieldType::Class(name) | FieldType::Enum(name) => json!({
"$ref": format!("#/definitions/{}", name),
}),
FieldType::Alias { resolution, .. } => resolution.json_schema(),
FieldType::Literal(v) => json!({
"const": v.to_string(),
}),
FieldType::RecursiveTypeAlias(_) => json!({
"type": ["number", "string", "boolean", "object", "array", "null"]
}),
FieldType::Primitive(t) => match t {
TypeValue::String => json!({
"type": "string",
Expand Down
62 changes: 46 additions & 16 deletions engine/baml-lib/baml-core/src/ir/repr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ pub struct IntermediateRepr {
/// Strongly connected components of the dependency graph (finite cycles).
finite_recursive_cycles: Vec<IndexSet<String>>,

/// Type alias cycles introduced by lists and maps.
structural_recursive_alias_cycles: Vec<IndexMap<String, FieldType>>,

configuration: Configuration,
}

Expand All @@ -53,6 +56,7 @@ impl IntermediateRepr {
enums: vec![],
classes: vec![],
finite_recursive_cycles: vec![],
structural_recursive_alias_cycles: vec![],
functions: vec![],
clients: vec![],
retry_policies: vec![],
Expand Down Expand Up @@ -98,6 +102,10 @@ impl IntermediateRepr {
&self.finite_recursive_cycles
}

/// Returns the structural recursive alias cycles stored in this IR as a
/// slice. Each element maps alias names to a [`FieldType`].
pub fn structural_recursive_alias_cycles(&self) -> &[IndexMap<String, FieldType>] {
    self.structural_recursive_alias_cycles.as_slice()
}

/// Iterates over the enums held by this IR, wrapping each node in a
/// [`Walker`] whose `db` points back at this IR.
pub fn walk_enums(&self) -> impl ExactSizeIterator<Item = Walker<'_, &Node<Enum>>> {
    let enum_nodes = self.enums.iter();

    enum_nodes.map(|node| Walker {
        db: self,
        item: node,
    })
}
Expand All @@ -106,6 +114,14 @@ impl IntermediateRepr {
self.classes.iter().map(|e| Walker { db: self, item: e })
}

// TODO: Exact size Iterator + Node<>?
/// Iterates over every `(name, type)` entry of every structural recursive
/// alias cycle, in storage order, wrapping each entry in a [`Walker`] whose
/// `db` points back at this IR.
pub fn walk_alias_cycles(&self) -> impl Iterator<Item = Walker<'_, (&String, &FieldType)>> {
    let cycles = &self.structural_recursive_alias_cycles;

    cycles
        .iter()
        .flat_map(|cycle| cycle.iter())
        .map(|entry| Walker {
            db: self,
            item: entry,
        })
}

/// Iterates over the name of every function held by this IR.
pub fn function_names(&self) -> impl ExactSizeIterator<Item = &str> {
    let function_nodes = self.functions.iter();

    function_nodes.map(|node| node.elem.name())
}
Expand Down Expand Up @@ -168,6 +184,18 @@ impl IntermediateRepr {
.collect()
})
.collect(),
structural_recursive_alias_cycles: {
let mut recursive_aliases = vec![];
for cycle in db.structural_recursive_alias_cycles() {
let mut component = IndexMap::new();
for id in cycle {
let alias = &db.ast()[*id];
component.insert(alias.name().to_string(), alias.value.repr(db)?);
}
recursive_aliases.push(component);
}
recursive_aliases
},
functions: db
.walk_functions()
.map(|e| e.node(db))
Expand Down Expand Up @@ -419,11 +447,18 @@ impl WithRepr<FieldType> for ast::FieldType {
_ => base_type,
}
}
Some(TypeWalker::TypeAlias(alias_walker)) => FieldType::Alias {
name: alias_walker.name().to_owned(),
target: Box::new(alias_walker.target().repr(db)?),
resolution: Box::new(alias_walker.resolved().repr(db)?),
},
Some(TypeWalker::TypeAlias(alias_walker)) => {
if db
.structural_recursive_alias_cycles()
.iter()
.any(|cycle| cycle.contains(&alias_walker.id))
{
FieldType::RecursiveTypeAlias(alias_walker.name().to_string())
} else {
alias_walker.resolved().to_owned().repr(db)?
}
}

None => return Err(anyhow!("Field type uses unresolvable local identifier")),
},
arity,
Expand Down Expand Up @@ -1224,11 +1259,7 @@ mod tests {
let class = ir.find_class("Test").unwrap();
let alias = class.find_field("field").unwrap();

let FieldType::Alias { resolution, .. } = alias.r#type() else {
panic!("expected alias type, found {:?}", alias.r#type());
};

assert_eq!(**resolution, FieldType::Primitive(TypeValue::Int));
assert_eq!(*alias.r#type(), FieldType::Primitive(TypeValue::Int));
}

#[test]
Expand All @@ -1249,12 +1280,11 @@ mod tests {
let class = ir.find_class("Test").unwrap();
let alias = class.find_field("field").unwrap();

let FieldType::Alias { resolution, .. } = alias.r#type() else {
panic!("expected alias type, found {:?}", alias.r#type());
};

let FieldType::Constrained { base, constraints } = &**resolution else {
panic!("expected resolved constrained type, found {:?}", resolution);
let FieldType::Constrained { base, constraints } = alias.r#type() else {
panic!(
"expected resolved constrained type, found {:?}",
alias.r#type()
);
};

assert_eq!(constraints.len(), 3);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,40 @@ use crate::validate::validation_pipeline::context::Context;

/// Validates if the dependency graph contains one or more infinite cycles.
pub(super) fn validate(ctx: &mut Context<'_>) {
// Solve cycles first. We need that information in case a class points to
// an unresolvable type alias.
let type_aliases_components = report_infinite_cycles(
&ctx.db.type_alias_dependencies(),
// We'll check type alias cycles first. Just like Typescript, cycles are
// allowed only for maps and lists. We'll call such cycles "structural
// recursion". Anything else like nulls or unions won't terminate a cycle.
let structural_type_aliases = HashMap::from_iter(ctx.db.walk_type_aliases().map(|alias| {
let mut dependencies = HashSet::new();
insert_required_alias_deps(alias.target(), ctx, &mut dependencies);

(alias.id, dependencies)
}));

// Based on the graph we've built, which does not include the edges created
// by maps and lists, check the cycles and report them.
report_infinite_cycles(
&structural_type_aliases,
ctx,
"These aliases form a dependency cycle",
);

// Store this locally to pass refs to the insert_required_deps function.
let alias_cycles = type_aliases_components.iter().flatten().copied().collect();
// In order to avoid infinite recursion when resolving types for class
// dependencies below, we'll compute the cycles of aliases including maps
// and lists so that the recursion can be stopped before entering a cycle.
let complete_alias_cycles = Tarjan::components(ctx.db.type_alias_dependencies())
.iter()
.flatten()
.copied()
.collect();

// First, build a graph of all the "required" dependencies represented as an
// Now build a graph of all the "required" dependencies represented as an
// adjacency list. We're only going to consider type dependencies that can
// actually cause infinite recursion. Unions and optionals can stop the
// recursion at any point, so they don't have to be part of the "dependency"
// graph because technically an optional field doesn't "depend" on anything,
// it can just be null.
let dependency_graph = HashMap::from_iter(ctx.db.walk_classes().map(|class| {
let class_dependency_graph = HashMap::from_iter(ctx.db.walk_classes().map(|class| {
let expr_block = &ctx.db.ast()[class.id];

// TODO: There's already a hash set that returns "dependencies" in
Expand All @@ -47,15 +63,21 @@ pub(super) fn validate(ctx: &mut Context<'_>) {

for field in &expr_block.fields {
if let Some(field_type) = &field.expr {
insert_required_deps(class.id, field_type, ctx, &mut dependencies, &alias_cycles);
insert_required_class_deps(
class.id,
field_type,
ctx,
&mut dependencies,
&complete_alias_cycles,
);
}
}

(class.id, dependencies)
}));

report_infinite_cycles(
&dependency_graph,
&class_dependency_graph,
ctx,
"These classes form a dependency cycle",
);
Expand Down Expand Up @@ -103,7 +125,7 @@ where
/// it reaches stack overflows with large inputs.
///
/// TODO: Use a struct to keep all this state. Too many parameters already.
fn insert_required_deps(
fn insert_required_class_deps(
id: TypeExpId,
field: &FieldType,
ctx: &Context<'_>,
Expand Down Expand Up @@ -139,7 +161,7 @@ fn insert_required_deps(
// We also have to stop recursion if we know the alias is
// part of a cycle.
if !alias_cycles.contains(&alias.id) {
insert_required_deps(id, alias.target(), ctx, deps, alias_cycles)
insert_required_class_deps(id, alias.target(), ctx, deps, alias_cycles)
}
}
_ => {}
Expand All @@ -156,7 +178,7 @@ fn insert_required_deps(
let mut nested_deps = HashSet::new();

for f in field_types {
insert_required_deps(id, f, ctx, &mut nested_deps, alias_cycles);
insert_required_class_deps(id, f, ctx, &mut nested_deps, alias_cycles);

// No nested deps found on this component, this makes the
// union finite, so no need to go deeper.
Expand Down Expand Up @@ -186,3 +208,26 @@ fn insert_required_deps(
_ => {}
}
}

/// Collects into `required` the type aliases that `field_type` refers to
/// directly: either through a symbol that resolves to an alias, or through
/// the members of a union or tuple (visited recursively).
///
/// Implemented a la TS: every other kind of type, maps and lists included,
/// falls through to the catch-all arm and contributes no edges.
fn insert_required_alias_deps(
    field_type: &FieldType,
    ctx: &Context<'_>,
    required: &mut HashSet<TypeAliasId>,
) {
    match field_type {
        // Composite types: every member may introduce its own alias edges.
        FieldType::Union(_, members, ..) | FieldType::Tuple(_, members, ..) => {
            members
                .iter()
                .for_each(|member| insert_required_alias_deps(member, ctx, required));
        }

        // A bare identifier only counts when it resolves to a type alias.
        FieldType::Symbol(_, ident, _) => {
            let Some(TypeWalker::TypeAlias(alias)) = ctx.db.find_type_by_str(ident.name()) else {
                return;
            };

            required.insert(alias.id);
        }

        _ => {}
    }
}
22 changes: 10 additions & 12 deletions engine/baml-lib/baml-types/src/field_type/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,7 @@ pub enum FieldType {
Union(Vec<FieldType>),
Tuple(Vec<FieldType>),
Optional(Box<FieldType>),
Alias {
/// Name of the alias.
name: String,
/// Type that the alias points to.
target: Box<FieldType>,
/// Final resolved type (an alias can point to other aliases).
resolution: Box<FieldType>,
},
RecursiveTypeAlias(String),
Constrained {
base: Box<FieldType>,
constraints: Vec<Constraint>,
Expand All @@ -103,8 +96,9 @@ pub enum FieldType {
impl std::fmt::Display for FieldType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
FieldType::Enum(name) | FieldType::Class(name) => write!(f, "{name}"),
FieldType::Alias { name, .. } => write!(f, "{name}"),
FieldType::Enum(name)
| FieldType::Class(name)
| FieldType::RecursiveTypeAlias(name) => write!(f, "{name}"),
FieldType::Primitive(t) => write!(f, "{t}"),
FieldType::Literal(v) => write!(f, "{v}"),
FieldType::Union(choices) => {
Expand Down Expand Up @@ -187,8 +181,6 @@ impl FieldType {
}

match (self, other) {
(FieldType::Alias { resolution, .. }, _) => resolution.is_subtype_of(other),
(_, FieldType::Alias { resolution, .. }) => self.is_subtype_of(resolution),
(FieldType::Primitive(TypeValue::Null), FieldType::Optional(_)) => true,
(FieldType::Optional(self_item), FieldType::Optional(other_item)) => {
self_item.is_subtype_of(other_item)
Expand All @@ -207,6 +199,12 @@ impl FieldType {
}
(FieldType::Map(_, _), _) => false,

// TODO: is it necessary to check if the alias is part of the same
// cycle?
(FieldType::RecursiveTypeAlias(_), _) | (_, FieldType::RecursiveTypeAlias(_)) => {
self == other
}

(
FieldType::Constrained {
base: self_base,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,6 @@ type Map = map<string, Map>
// 50 |
// 51 | type EnterCycle = NoStop
// |
// error: Error validating: These aliases form a dependency cycle: Map
// --> class/recursive_type_aliases.baml:56
// |
// 55 | // RecursiveMap
// 56 | type Map = map<string, Map>
// |
// error: Error validating: These classes form a dependency cycle: Recursive
// --> class/recursive_type_aliases.baml:22
// |
Expand Down
Loading
Loading