Skip to content

Commit

Permalink
Fix recursive aliases
Browse files Browse the repository at this point in the history
  • Loading branch information
antoniosarosi committed Nov 27, 2024
1 parent dfc3f45 commit 273f13d
Show file tree
Hide file tree
Showing 5 changed files with 176 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,23 @@ pub(super) fn validate(ctx: &mut Context<'_>) {
ctx.db.ast()[component[0]].span().clone(),
));
}

// TODO: Extract this into some generic function.
eprintln!("Type aliases: {:?}", ctx.db.type_aliases());
for component in Tarjan::components(&ctx.db.type_aliases()) {
let cycle = component
.iter()
.map(|id| ctx.db.ast()[*id].name().to_string())
.collect::<Vec<_>>()
.join(" -> ");

// TODO: We can push an error for every sinlge class here (that's what
// Rust does), for now it's an error for every cycle found.
ctx.push_error(DatamodelError::new_validation_error(
&format!("These aliases form a dependency cycle: {}", cycle),
ctx.db.ast()[component[0]].span().clone(),
));
}
}

/// Inserts all the required dependencies of a field into the given set.
Expand All @@ -66,8 +83,33 @@ fn insert_required_deps(
) {
match field {
FieldType::Symbol(arity, ident, _) if arity.is_required() => {
if let Some(TypeWalker::Class(class)) = ctx.db.find_type_by_str(ident.name()) {
deps.insert(class.id);
match ctx.db.find_type_by_str(ident.name()) {
Some(TypeWalker::Class(class)) => {
deps.insert(class.id);
}
Some(TypeWalker::TypeAlias(alias)) => {
// TODO: By the time this code runs we would ideally want
// type aliases to be resolved but we can't do that because
// type alias cycles are not validated yet, we have to
// do that in this file. Take a look at the `validate`
// function at `baml-lib/baml-core/src/lib.rs`.
//
// First we run the `ParserDatabase::validate` function
// which creates the alias graph by visiting all aliases.
// Then we run the `validate::validate` which ends up
// running this code here. Finally we run the
// `ParserDatabase::finalize` which is the place where we
// can resolve type aliases since we've already validated
// that there are no cycles so we won't run into infinite
// recursion. Ideally we want this:
//
// insert_required_deps(id, alias.resolved(), ctx, deps);

// But we'll run this instead which will follow all the
// alias pointers again until it finds the resolved type.
insert_required_deps(id, alias.target(), ctx, deps);
}
_ => {}
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Simple alias that points to recursive type.
class Node {
value int
next Node?
}

type LinkedList = Node

// Mutual recursion. There is no "type" here at all.
type One = Two

type Two = One

// Cycle. Same as above but longer.
type A = B

type B = C

type C = A

// Recursive class with alias pointing to itself.
class Recursive {
value int
ptr RecAlias
}

type RecAlias = Recursive

// error: Error validating: These classes form a dependency cycle: Recursive
// --> class/recursive_type_aliases.baml:22
// |
// 21 | // Recursive class with alias pointing to itself.
// 22 | class Recursive {
// 23 | value int
// 24 | ptr RecAlias
// 25 | }
// |
// error: Error validating: These aliases form a dependency cycle: One -> Two
// --> class/recursive_type_aliases.baml:10
// |
// 9 | // Mutual recursion. There is no "type" here at all.
// 10 | type One = Two
// |
// error: Error validating: These aliases form a dependency cycle: A -> B -> C
// --> class/recursive_type_aliases.baml:15
// |
// 14 | // Cycle. Same as above but longer.
// 15 | type A = B
// |
18 changes: 18 additions & 0 deletions engine/baml-lib/parser-database/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ pub use coerce_expression::{coerce, coerce_array, coerce_opt};
pub use internal_baml_schema_ast::ast;
use internal_baml_schema_ast::ast::SchemaAst;
pub use tarjan::Tarjan;
use types::resolve_type_alias;
pub use types::{
Attributes, ContantDelayStrategy, ExponentialBackoffStrategy, PrinterType, PromptAst,
PromptVariable, RetryPolicy, RetryPolicyStrategy, StaticType,
Expand Down Expand Up @@ -173,6 +174,14 @@ impl ParserDatabase {
.map(|cycle| cycle.into_iter().collect())
.collect();

// Resolve type aliases.
// Cycles are already validated so this should not stack overflow and
// it should find the final type.
for alias_id in self.types.type_aliases.keys() {
let resolved = resolve_type_alias(&self.ast[*alias_id].value, &self);
self.types.resolved_type_aliases.insert(*alias_id, resolved);
}

// Additionally ensure the same thing for functions, but since we've
// already handled classes, this should be trivial.
let extends = self
Expand Down Expand Up @@ -226,6 +235,15 @@ impl ParserDatabase {
pub fn ast(&self) -> &ast::SchemaAst {
&self.ast
}

/// Returns the graph of type aliases.
///
/// Each vertex is a type alias and each edge is a reference to another type
/// alias.
pub fn type_aliases(&self) -> &HashMap<ast::TypeAliasId, HashSet<ast::TypeAliasId>> {
&self.types.type_aliases
}

/// The total number of enums in the schema. This is O(1).
pub fn enums_count(&self) -> usize {
self.types.enum_attributes.len()
Expand Down
23 changes: 14 additions & 9 deletions engine/baml-lib/parser-database/src/tarjan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
use std::{
cmp,
collections::{HashMap, HashSet},
hash::Hash,
};

use internal_baml_schema_ast::ast::TypeExpId;

/// Dependency graph represented as an adjacency list.
type Graph = HashMap<TypeExpId, HashSet<TypeExpId>>;
type Graph<V> = HashMap<V, HashSet<V>>;

/// State of each node for Tarjan's algorithm.
#[derive(Clone, Copy)]
Expand All @@ -35,20 +36,24 @@ struct NodeState {
/// This struct is simply bookkeeping for the algorithm, it can be implemented
/// with just function calls but the recursive one would need 6 parameters which
/// is pretty ugly.
pub struct Tarjan<'g> {
pub struct Tarjan<'g, V> {
/// Ref to the depdenency graph.
graph: &'g Graph,
graph: &'g Graph<V>,
/// Node number counter.
index: usize,
/// Nodes are placed on a stack in the order in which they are visited.
stack: Vec<TypeExpId>,
stack: Vec<V>,
/// State of each node.
state: HashMap<TypeExpId, NodeState>,
state: HashMap<V, NodeState>,
/// Strongly connected components.
components: Vec<Vec<TypeExpId>>,
components: Vec<Vec<V>>,
}

impl<'g> Tarjan<'g> {
// V is Copy because we mostly use opaque identifiers for class or alias IDs.
// In practice T ends up being a u32, but if for some reason this needs to
// be used with strings then we can make V Clone instead of Copy and refactor
// the code below.
impl<'g, V: Eq + Ord + Hash + Copy> Tarjan<'g, V> {
/// Unvisited node marker.
///
/// Technically we should use [`Option<usize>`] and [`None`] for
Expand All @@ -63,7 +68,7 @@ impl<'g> Tarjan<'g> {
/// Loops through all the nodes in the graph and visits them if they haven't
/// been visited already. When the algorithm is done, [`Self::components`]
/// will contain all the cycles in the graph.
pub fn components(graph: &'g Graph) -> Vec<Vec<TypeExpId>> {
pub fn components(graph: &'g Graph<V>) -> Vec<Vec<V>> {
let mut tarjans = Self {
graph,
index: 0,
Expand Down Expand Up @@ -105,7 +110,7 @@ impl<'g> Tarjan<'g> {
///
/// This is where the "algorithm" runs. Could be implemented iteratively if
/// needed at some point.
fn strong_connect(&mut self, node_id: TypeExpId) {
fn strong_connect(&mut self, node_id: V) {
// Initialize node state. This node has not yet been visited so we don't
// have to grab the state from the hash map. And if we did, then we'd
// have to fight the borrow checker by taking mut refs and read-only
Expand Down
68 changes: 51 additions & 17 deletions engine/baml-lib/parser-database/src/types/mod.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use std::collections::{HashMap, HashSet, VecDeque};
use std::hash::Hash;

use crate::coerce;
use crate::types::configurations::visit_test_case;
use crate::{coerce, ParserDatabase};
use crate::{context::Context, DatamodelError};

use baml_types::Constraint;
Expand Down Expand Up @@ -233,6 +233,11 @@ pub(super) struct Types {
pub(super) class_dependencies: HashMap<ast::TypeExpId, HashSet<String>>,
pub(super) enum_dependencies: HashMap<ast::TypeExpId, HashSet<String>>,

/// Graph of type aliases.
///
/// This graph is only used to detect infinite cycles in type aliases.
pub(crate) type_aliases: HashMap<ast::TypeAliasId, HashSet<ast::TypeAliasId>>,

/// Fully resolved type aliases.
///
/// A type alias con point to one or many other type aliases.
Expand Down Expand Up @@ -374,29 +379,29 @@ fn visit_class<'db>(
///
/// The type would resolve to `SomeClass | AnotherClass | int`, which is not
/// stored in the AST.
fn resolve_type_alias(field_type: &FieldType, ctx: &mut Context<'_>) -> FieldType {
pub fn resolve_type_alias(field_type: &FieldType, db: &ParserDatabase) -> FieldType {
match field_type {
// For symbols we need to check if we're dealing with aliases.
FieldType::Symbol(arity, ident, span) => {
let Some(string_id) = ctx.interner.lookup(ident.name()) else {
let Some(string_id) = db.interner.lookup(ident.name()) else {
unreachable!(
"Attempting to resolve alias `{ident}` that does not exist in the interner"
);
};

let Some(top_id) = ctx.names.tops.get(&string_id) else {
let Some(top_id) = db.names.tops.get(&string_id) else {
unreachable!("Alias name `{ident}` is not registered in the context");
};

match top_id {
ast::TopId::TypeAlias(alias_id) => {
// Check if we can avoid deeper recursion.
if let Some(resolved) = ctx.types.resolved_type_aliases.get(alias_id) {
if let Some(resolved) = db.types.resolved_type_aliases.get(alias_id) {
return resolved.to_owned();
}

// Recurse... TODO: Recursive types and infinite cycles :(
let resolved = resolve_type_alias(&ctx.ast[*alias_id].value, ctx);
let resolved = resolve_type_alias(&db.ast[*alias_id].value, db);

// Sync arity. Basically stuff like:
//
Expand All @@ -421,7 +426,7 @@ fn resolve_type_alias(field_type: &FieldType, ctx: &mut Context<'_>) -> FieldTyp
| FieldType::Tuple(arity, items, span, attrs) => {
let resolved = items
.iter()
.map(|item| resolve_type_alias(item, ctx))
.map(|item| resolve_type_alias(item, db))
.collect();

match field_type {
Expand All @@ -446,18 +451,47 @@ fn visit_type_alias<'db>(
assignment: &'db ast::Assignment,
ctx: &mut Context<'db>,
) {
// Maybe this can't even happen since we iterate over the vec of tops and
// just get IDs sequentially, but anyway check just in case.
if ctx.types.resolved_type_aliases.contains_key(&alias_id) {
return;
}
// Insert the entry as soon as we get here then if we find something we'll
// add edges to the graph. Otherwise no edges but we still need the Vertex
// in order for the cycles algorithm to work.
let alias_refs = ctx.types.type_aliases.entry(alias_id).or_default();

let mut stack = vec![&assignment.value];

while let Some(item) = stack.pop() {
match item {
FieldType::Symbol(_, ident, _) => {
let Some(string_id) = ctx.interner.lookup(ident.name()) else {
unreachable!("Visiting alias `{ident}` that does not exist in the interner");
};

// Now resolve the type.
let resolved = resolve_type_alias(&assignment.value, ctx);
let Some(top_id) = ctx.names.tops.get(&string_id) else {
unreachable!("Alias name `{ident}` is not registered in the context");
};

// TODO: Can we add types to the map recursively while solving them at
// the same time? It might speed up very long chains of aliases.
ctx.types.resolved_type_aliases.insert(alias_id, resolved);
// Add alias to the graph.
if let ast::TopId::TypeAlias(nested_alias_id) = top_id {
alias_refs.insert(*nested_alias_id);
}
}

FieldType::Union(_, items, ..) | FieldType::Tuple(_, items, ..) => {
stack.extend(items.iter());
}

FieldType::List(_, nested, ..) => {
stack.push(nested);
}

FieldType::Map(_, nested, ..) => {
let (key, value) = nested.as_ref();
stack.push(key);
stack.push(value);
}

_ => {}
}
}
}

fn visit_function<'db>(idx: ValExpId, function: &'db ast::ValueExprBlock, ctx: &mut Context<'db>) {
Expand Down

0 comments on commit 273f13d

Please sign in to comment.