Skip to content

Commit

Permalink
Resolve type aliases to final type
Browse files Browse the repository at this point in the history
  • Loading branch information
antoniosarosi committed Nov 20, 2024
1 parent 8786e3a commit 1f1d777
Show file tree
Hide file tree
Showing 17 changed files with 272 additions and 144 deletions.
1 change: 1 addition & 0 deletions engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ impl ArgCoercer {
Err(())
}
},
(FieldType::Alias(name, target), _) => todo!(),
(FieldType::List(item), _) => match value {
BamlValue::List(arr) => {
let mut items = Vec::new();
Expand Down
1 change: 1 addition & 0 deletions engine/baml-lib/baml-core/src/ir/json_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ impl<'db> WithJsonSchema for FieldType {
FieldType::Class(name) | FieldType::Enum(name) => json!({
"$ref": format!("#/definitions/{}", name),
}),
FieldType::Alias(_, target) => todo!(),
FieldType::Literal(v) => json!({
"const": v.to_string(),
}),
Expand Down
162 changes: 82 additions & 80 deletions engine/baml-lib/baml-types/src/field_type/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ pub enum FieldType {
Union(Vec<FieldType>),
Tuple(Vec<FieldType>),
Optional(Box<FieldType>),
Alias(String, Box<FieldType>),
Constrained {
base: Box<FieldType>,
constraints: Vec<Constraint>,
Expand All @@ -92,11 +93,10 @@ pub enum FieldType {
impl std::fmt::Display for FieldType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
FieldType::Enum(name) | FieldType::Class(name) => {
write!(f, "{}", name)
}
FieldType::Primitive(t) => write!(f, "{}", t),
FieldType::Literal(v) => write!(f, "{}", v),
FieldType::Enum(name) | FieldType::Class(name) => write!(f, "{name}"),
FieldType::Alias(name, _) => write!(f, "{name}"),
FieldType::Primitive(t) => write!(f, "{t}"),
FieldType::Literal(v) => write!(f, "{v}"),
FieldType::Union(choices) => {
write!(
f,
Expand Down Expand Up @@ -167,83 +167,85 @@ impl FieldType {
/// Consider renaming this to `is_assignable_to`.
pub fn is_subtype_of(&self, other: &FieldType) -> bool {
if self == other {
true
} else {
if let FieldType::Union(items) = other {
if items.iter().any(|item| self.is_subtype_of(item)) {
return true;
}
return true;
}

if let FieldType::Union(items) = other {
if items.iter().any(|item| self.is_subtype_of(item)) {
return true;
}
}

match (self, other) {
(FieldType::Primitive(TypeValue::Null), FieldType::Optional(_)) => true,
(FieldType::Optional(self_item), FieldType::Optional(other_item)) => {
self_item.is_subtype_of(other_item)
}
(_, FieldType::Optional(t)) => self.is_subtype_of(t),
(FieldType::Optional(_), _) => false,

// Handle types that nest other types.
(FieldType::List(self_item), FieldType::List(other_item)) => {
self_item.is_subtype_of(other_item)
}
(FieldType::List(_), _) => false,

(FieldType::Map(self_k, self_v), FieldType::Map(other_k, other_v)) => {
other_k.is_subtype_of(self_k) && (**self_v).is_subtype_of(other_v)
}
match (self, other) {
(FieldType::Primitive(TypeValue::Null), FieldType::Optional(_)) => true,
(FieldType::Optional(self_item), FieldType::Optional(other_item)) => {
self_item.is_subtype_of(other_item)
}
(_, FieldType::Optional(t)) => self.is_subtype_of(t),
(FieldType::Optional(_), _) => false,

// Handle types that nest other types.
(FieldType::List(self_item), FieldType::List(other_item)) => {
self_item.is_subtype_of(other_item)
}
(FieldType::List(_), _) => false,

(FieldType::Map(self_k, self_v), FieldType::Map(other_k, other_v)) => {
other_k.is_subtype_of(self_k) && (**self_v).is_subtype_of(other_v)
}
(FieldType::Map(_, _), _) => false,

(
FieldType::Constrained {
base: self_base,
constraints: self_cs,
},
FieldType::Constrained {
base: other_base,
constraints: other_cs,
},
) => self_base.is_subtype_of(other_base) && self_cs == other_cs,
(FieldType::Constrained { base, .. }, _) => base.is_subtype_of(other),
(_, FieldType::Constrained { base, .. }) => self.is_subtype_of(base),
(
FieldType::Literal(LiteralValue::Bool(_)),
FieldType::Primitive(TypeValue::Bool),
) => true,
(FieldType::Literal(LiteralValue::Bool(_)), _) => {
self.is_subtype_of(&FieldType::Primitive(TypeValue::Bool))
}
(
FieldType::Literal(LiteralValue::Int(_)),
FieldType::Primitive(TypeValue::Int),
) => true,
(FieldType::Literal(LiteralValue::Int(_)), _) => {
self.is_subtype_of(&FieldType::Primitive(TypeValue::Int))
}
(
FieldType::Literal(LiteralValue::String(_)),
FieldType::Primitive(TypeValue::String),
) => true,
(FieldType::Literal(LiteralValue::String(_)), _) => {
self.is_subtype_of(&FieldType::Primitive(TypeValue::String))
}

(FieldType::Union(self_items), _) => self_items
.iter()
.all(|self_item| self_item.is_subtype_of(other)),

(FieldType::Tuple(self_items), FieldType::Tuple(other_items)) => {
self_items.len() == other_items.len()
&& self_items
.iter()
.zip(other_items)
.all(|(self_item, other_item)| self_item.is_subtype_of(other_item))
}
(FieldType::Tuple(_), _) => false,

(FieldType::Primitive(_), _) => false,
(FieldType::Enum(_), _) => false,
(FieldType::Class(_), _) => false,
(FieldType::Map(_, _), _) => false,

(
FieldType::Constrained {
base: self_base,
constraints: self_cs,
},
FieldType::Constrained {
base: other_base,
constraints: other_cs,
},
) => self_base.is_subtype_of(other_base) && self_cs == other_cs,
(FieldType::Constrained { base, .. }, _) => base.is_subtype_of(other),
(_, FieldType::Constrained { base, .. }) => self.is_subtype_of(base),
(FieldType::Literal(LiteralValue::Bool(_)), FieldType::Primitive(TypeValue::Bool)) => {
true
}
(FieldType::Literal(LiteralValue::Bool(_)), _) => {
self.is_subtype_of(&FieldType::Primitive(TypeValue::Bool))
}
(FieldType::Literal(LiteralValue::Int(_)), FieldType::Primitive(TypeValue::Int)) => {
true
}
(FieldType::Literal(LiteralValue::Int(_)), _) => {
self.is_subtype_of(&FieldType::Primitive(TypeValue::Int))
}
(
FieldType::Literal(LiteralValue::String(_)),
FieldType::Primitive(TypeValue::String),
) => true,
(FieldType::Literal(LiteralValue::String(_)), _) => {
self.is_subtype_of(&FieldType::Primitive(TypeValue::String))
}

(FieldType::Union(self_items), _) => self_items
.iter()
.all(|self_item| self_item.is_subtype_of(other)),

(FieldType::Tuple(self_items), FieldType::Tuple(other_items)) => {
self_items.len() == other_items.len()
&& self_items
.iter()
.zip(other_items)
.all(|(self_item, other_item)| self_item.is_subtype_of(other_item))
}
// TODO: Can this cause infinite recursion?
// Should the final resolved type (following all the aliases) be
// included in the variant so that we skip recursion?
(FieldType::Alias(_, target), _) => target.is_subtype_of(other),
(FieldType::Tuple(_), _) => false,
(FieldType::Primitive(_), _) => false,
(FieldType::Enum(_), _) => false,
(FieldType::Class(_), _) => false,
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions engine/baml-lib/jinja-runtime/src/output_format/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,7 @@ impl OutputFormatContent {

Some(format!("Answer in JSON using this {type_prefix}:{end}"))
}
FieldType::Alias(_, _) => todo!(),
FieldType::List(_) => Some(String::from(
"Answer with a JSON Array using this schema:\n",
)),
Expand Down Expand Up @@ -481,6 +482,7 @@ impl OutputFormatContent {
}
.to_string()
}
FieldType::Alias(_, _) => todo!(),
FieldType::List(inner) => {
let is_recursive = match inner.as_ref() {
FieldType::Class(nested_class) => self.recursive_classes.contains(nested_class),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ impl TypeCoercer for FieldType {
FieldType::Enum(e) => IrRef::Enum(e).coerce(ctx, target, value),
FieldType::Literal(l) => l.coerce(ctx, target, value),
FieldType::Class(c) => IrRef::Class(c).coerce(ctx, target, value),
FieldType::Alias(_, _) => todo!(),
FieldType::List(_) => coerce_array(ctx, self, value),
FieldType::Union(_) => coerce_union(ctx, self, value),
FieldType::Optional(_) => coerce_optional(ctx, self, value),
Expand Down Expand Up @@ -139,7 +140,8 @@ pub fn validate_asserts(constraints: &Vec<(Constraint, bool)>) -> Result<(), Par
expr.0
),
scope: vec![],
}).collect::<Vec<_>>();
})
.collect::<Vec<_>>();
if causes.len() > 0 {
Err(ParsingError {
causes: vec![],
Expand All @@ -163,6 +165,7 @@ impl DefaultValue for FieldType {
FieldType::Enum(e) => None,
FieldType::Literal(_) => None,
FieldType::Class(_) => None,
FieldType::Alias(_, _) => todo!(),
FieldType::List(_) => Some(BamlValueWithFlags::List(get_flags(), Vec::new())),
FieldType::Union(items) => items.iter().find_map(|i| i.default_value(error)),
FieldType::Primitive(TypeValue::Null) | FieldType::Optional(_) => {
Expand Down
2 changes: 2 additions & 0 deletions engine/baml-lib/jsonish/src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use std::{

use baml_types::BamlValue;
use internal_baml_core::{
ast::Field,
internal_baml_diagnostics::SourceFile,
ir::{repr::IntermediateRepr, ClassWalker, EnumWalker, FieldType, IRHelper, TypeValue},
validate,
Expand Down Expand Up @@ -231,6 +232,7 @@ fn relevant_data_models<'a>(
});
}
}
(FieldType::Alias(_, _), _) => todo!(),
(FieldType::Literal(_), _) => {}
(FieldType::Primitive(_), _constraints) => {}
(FieldType::Constrained { .. }, _) => {
Expand Down
35 changes: 33 additions & 2 deletions engine/baml-lib/parser-database/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ mod names;
mod tarjan;
mod types;

use std::collections::{HashMap, HashSet};
use std::collections::{HashMap, HashSet, VecDeque};

pub use coerce_expression::{coerce, coerce_array, coerce_opt};
pub use internal_baml_schema_ast::ast;
Expand Down Expand Up @@ -131,6 +131,35 @@ impl ParserDatabase {
}

fn finalize_dependencies(&mut self, diag: &mut Diagnostics) {
// Fully resolve type aliases.
for (id, targets) in self.types.type_aliases.iter() {
let mut resolved = HashSet::new();
let mut queue = VecDeque::from_iter(targets.iter());

while let Some(target) = queue.pop_front() {
match self.find_type_by_str(target) {
Some(TypeWalker::Class(_) | TypeWalker::Enum(_)) => {
resolved.insert(target.to_owned());
}
// TODO: Cycles and recursive stuff.
Some(TypeWalker::TypeAlias(alias)) => {
let alias_id = alias.id;

if let Some(already_resolved) =
self.types.resolved_type_aliases.get_mut(&alias_id)
{
resolved.extend(already_resolved.iter().cloned());
} else {
queue.extend(&self.types.type_aliases[&alias_id])
}
}
None => panic!("Type alias pointing to invalid type `{target}`"),
};
}

self.types.resolved_type_aliases.insert(*id, resolved);
}

// NOTE: Class dependency cycles are already checked at
// baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs
//
Expand Down Expand Up @@ -203,7 +232,9 @@ impl ParserDatabase {
Some(walker.dependencies().iter().cloned())
}
Some(TypeWalker::Enum(_)) => None,
Some(TypeWalker::TypeAlias(_)) => todo!(),
Some(TypeWalker::TypeAlias(walker)) => {
Some(self.types.resolved_type_aliases[&walker.id].iter().cloned())
}
_ => panic!("Unknown class `{}`", f),
})
.flatten()
Expand Down
3 changes: 3 additions & 0 deletions engine/baml-lib/parser-database/src/names/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ pub(super) struct Names {
/// Tests have their own namespace.
pub(super) tests: HashMap<StringId, HashMap<StringId, TopId>>,
pub(super) model_fields: HashMap<(ast::TypeExpId, StringId), ast::FieldId>,
pub(super) type_aliases: HashMap<ast::TypeExpId, Option<StringId>>,
// pub(super) composite_type_fields: HashMap<(ast::CompositeTypeId, StringId), ast::FieldId>,
}

Expand Down Expand Up @@ -94,6 +95,8 @@ pub(super) fn resolve_names(ctx: &mut Context<'_>) {
(ast::TopId::TypeAlias(_), ast::Top::TypeAlias(type_alias)) => {
validate_type_alias_name(type_alias, ctx.diagnostics);

let type_alias_id = ctx.interner.intern(type_alias.name());

Some(either::Left(&mut names.tops))
}

Expand Down
Loading

0 comments on commit 1f1d777

Please sign in to comment.