Skip to content

Commit

Permalink
Allow structural recursion in type aliases (#1207)
Browse files Browse the repository at this point in the history
Allow structural recursion (like TS):

```baml
type RecursiveMap = map<string, RecursiveMap>

type A = A[]
```

Follow up on #1163

<!-- ELLIPSIS_HIDDEN -->


----

> [!IMPORTANT]
> Enable structural recursion in type aliases by updating cycle
validation logic, adding tests, and supporting recursive type aliases in
Python, Ruby, and TypeScript clients.
> 
>   - **Behavior**:
> - Allow structural recursion for type alias cycles involving maps and
lists in `cycle.rs`.
> - Update `validate()` in `cycle.rs` to handle structural recursion and
report cycles.
> - Modify `ParserDatabase::finalize_dependencies()` in `lib.rs` to
resolve type aliases and handle cycles.
>   - **Functions**:
> - Add `insert_required_alias_deps()` in `cycle.rs` to handle alias
dependencies.
> - Rename `insert_required_deps()` to `insert_required_class_deps()` in
`cycle.rs`.
>   - **Models**:
>     - Add `structural_recursive_alias_cycles` to `Types` in `mod.rs`.
>   - **Tests**:
> - Update `recursive_type_aliases.baml` to reflect new cycle validation
behavior.
>     - Add tests in `lib.rs` to verify structural recursion handling.
>   - **Language Support**:
> - Update Python, Ruby, and TypeScript clients to handle recursive type
aliases.
>   - **Misc**:
> - Add `is_subtype()` to `IRHelper` in `mod.rs` to support subtyping
checks.
> - Add `coerce_alias()` in `coerce_alias.rs` to handle alias coercion.
> - Add integration tests for recursive alias cycles in
`integ-tests/typescript/tests/integ-tests.test.ts`.
> 
> <sup>This description was created by </sup>[<img alt="Ellipsis"
src="https://img.shields.io/badge/Ellipsis-blue?color=175173">](https://www.ellipsis.dev?ref=BoundaryML%2Fbaml&utm_source=github&utm_medium=referral)<sup>
for fc25050. It will automatically
update as commits are pushed.</sup>


<!-- ELLIPSIS_HIDDEN -->
  • Loading branch information
antoniosarosi authored Dec 18, 2024
1 parent 3d95e20 commit 4eb543a
Show file tree
Hide file tree
Showing 48 changed files with 2,484 additions and 329 deletions.
274 changes: 248 additions & 26 deletions engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ pub trait IRHelper {
value: BamlValue,
field_type: FieldType,
) -> Result<BamlValueWithMeta<FieldType>>;
fn is_subtype(&self, base: &FieldType, other: &FieldType) -> bool;
fn distribute_constraints<'a>(
&'a self,
field_type: &'a FieldType,
Expand Down Expand Up @@ -203,6 +204,124 @@ impl IRHelper for IntermediateRepr {
}
}

/// BAML does not support class-based subtyping. Nonetheless some builtin
/// BAML types are subtypes of others, and we need to be able to test this
/// when checking the types of values.
///
/// For examples of pairs of types and their subtyping relationship, see
/// this module's test suite.
///
/// Consider renaming this to `is_assignable`.
fn is_subtype(&self, base: &FieldType, other: &FieldType) -> bool {
if base == other {
return true;
}

if let FieldType::Union(items) = other {
if items.iter().any(|item| self.is_subtype(base, item)) {
return true;
}
}

match (base, other) {
// TODO: O(n)
(FieldType::RecursiveTypeAlias(name), _) => self
.structural_recursive_alias_cycles()
.iter()
.any(|cycle| match cycle.get(name) {
Some(target) => self.is_subtype(target, other),
None => false,
}),
(_, FieldType::RecursiveTypeAlias(name)) => self
.structural_recursive_alias_cycles()
.iter()
.any(|cycle| match cycle.get(name) {
Some(target) => self.is_subtype(base, target),
None => false,
}),

(FieldType::Primitive(TypeValue::Null), FieldType::Optional(_)) => true,
(FieldType::Optional(base_item), FieldType::Optional(other_item)) => {
self.is_subtype(base_item, other_item)
}
(_, FieldType::Optional(t)) => self.is_subtype(base, t),
(FieldType::Optional(_), _) => false,

// Handle types that nest other types.
(FieldType::List(base_item), FieldType::List(other_item)) => {
self.is_subtype(&base_item, other_item)
}
(FieldType::List(_), _) => false,

(FieldType::Map(base_k, base_v), FieldType::Map(other_k, other_v)) => {
self.is_subtype(other_k, base_k) && self.is_subtype(&**base_v, other_v)
}
(FieldType::Map(_, _), _) => false,

(
FieldType::Constrained {
base: constrained_base,
constraints: base_constraints,
},
FieldType::Constrained {
base: other_base,
constraints: other_constraints,
},
) => {
self.is_subtype(constrained_base, other_base)
&& base_constraints == other_constraints
}
(
FieldType::Constrained {
base: contrained_base,
..
},
_,
) => self.is_subtype(contrained_base, other),
(
_,
FieldType::Constrained {
base: constrained_base,
..
},
) => self.is_subtype(base, constrained_base),

(FieldType::Literal(LiteralValue::Bool(_)), FieldType::Primitive(TypeValue::Bool)) => {
true
}
(FieldType::Literal(LiteralValue::Bool(_)), _) => {
self.is_subtype(base, &FieldType::Primitive(TypeValue::Bool))
}
(FieldType::Literal(LiteralValue::Int(_)), FieldType::Primitive(TypeValue::Int)) => {
true
}
(FieldType::Literal(LiteralValue::Int(_)), _) => {
self.is_subtype(base, &FieldType::Primitive(TypeValue::Int))
}
(
FieldType::Literal(LiteralValue::String(_)),
FieldType::Primitive(TypeValue::String),
) => true,
(FieldType::Literal(LiteralValue::String(_)), _) => {
self.is_subtype(base, &FieldType::Primitive(TypeValue::String))
}

(FieldType::Union(items), _) => items.iter().all(|item| self.is_subtype(item, other)),

(FieldType::Tuple(base_items), FieldType::Tuple(other_items)) => {
base_items.len() == other_items.len()
&& base_items
.iter()
.zip(other_items)
.all(|(base_item, other_item)| self.is_subtype(base_item, other_item))
}
(FieldType::Tuple(_), _) => false,
(FieldType::Primitive(_), _) => false,
(FieldType::Enum(_), _) => false,
(FieldType::Class(_), _) => false,
}
}

/// For some `BamlValue` with type `FieldType`, walk the structure of both the value
/// and the type simultaneously, associating each node in the `BamlValue` with its
/// `FieldType`.
Expand All @@ -216,48 +335,48 @@ impl IRHelper for IntermediateRepr {
let literal_type = FieldType::Literal(LiteralValue::String(s.clone()));
let primitive_type = FieldType::Primitive(TypeValue::String);

if literal_type.is_subtype_of(&field_type)
|| primitive_type.is_subtype_of(&field_type)
if self.is_subtype(&literal_type, &field_type)
|| self.is_subtype(&primitive_type, &field_type)
{
return Ok(BamlValueWithMeta::String(s, field_type));
}
anyhow::bail!("Could not unify String with {:?}", field_type)
}
BamlValue::Int(i)
if FieldType::Literal(LiteralValue::Int(i)).is_subtype_of(&field_type) =>
{
Ok(BamlValueWithMeta::Int(i, field_type))
}
BamlValue::Int(i)
if FieldType::Primitive(TypeValue::Int).is_subtype_of(&field_type) =>
{
Ok(BamlValueWithMeta::Int(i, field_type))
}
BamlValue::Int(i) => {
let literal_type = FieldType::Literal(LiteralValue::Int(i));
let primitive_type = FieldType::Primitive(TypeValue::Int);

if self.is_subtype(&literal_type, &field_type)
|| self.is_subtype(&primitive_type, &field_type)
{
return Ok(BamlValueWithMeta::Int(i, field_type));
}
anyhow::bail!("Could not unify Int with {:?}", field_type)
}

BamlValue::Float(f)
if FieldType::Primitive(TypeValue::Float).is_subtype_of(&field_type) =>
{
Ok(BamlValueWithMeta::Float(f, field_type))
BamlValue::Float(f) => {
if self.is_subtype(&FieldType::Primitive(TypeValue::Float), &field_type) {
return Ok(BamlValueWithMeta::Float(f, field_type));
}
anyhow::bail!("Could not unify Float with {:?}", field_type)
}
BamlValue::Float(_) => anyhow::bail!("Could not unify Float with {:?}", field_type),

BamlValue::Bool(b) => {
let literal_type = FieldType::Literal(LiteralValue::Bool(b));
let primitive_type = FieldType::Primitive(TypeValue::Bool);

if literal_type.is_subtype_of(&field_type)
|| primitive_type.is_subtype_of(&field_type)
if self.is_subtype(&literal_type, &field_type)
|| self.is_subtype(&primitive_type, &field_type)
{
Ok(BamlValueWithMeta::Bool(b, field_type))
} else {
anyhow::bail!("Could not unify Bool with {:?}", field_type)
}
}

BamlValue::Null if FieldType::Primitive(TypeValue::Null).is_subtype_of(&field_type) => {
BamlValue::Null
if self.is_subtype(&FieldType::Primitive(TypeValue::Null), &field_type) =>
{
Ok(BamlValueWithMeta::Null(field_type))
}
BamlValue::Null => anyhow::bail!("Could not unify Null with {:?}", field_type),
Expand Down Expand Up @@ -287,7 +406,7 @@ impl IRHelper for IntermediateRepr {
Box::new(item_type.clone()),
);

if !map_type.is_subtype_of(&field_type) {
if !self.is_subtype(&map_type, &field_type) {
anyhow::bail!("Could not unify {:?} with {:?}", map_type, field_type);
}

Expand Down Expand Up @@ -321,7 +440,7 @@ impl IRHelper for IntermediateRepr {
Some(item_type) => {
let list_type = FieldType::List(Box::new(item_type.clone()));

if !list_type.is_subtype_of(&field_type) {
if !self.is_subtype(&list_type, &field_type) {
anyhow::bail!("Could not unify {:?} with {:?}", list_type, field_type);
} else {
let mapped_items: Vec<BamlValueWithMeta<FieldType>> = items
Expand All @@ -335,23 +454,25 @@ impl IRHelper for IntermediateRepr {
}

BamlValue::Media(m)
if FieldType::Primitive(TypeValue::Media(m.media_type))
.is_subtype_of(&field_type) =>
if self.is_subtype(
&FieldType::Primitive(TypeValue::Media(m.media_type)),
&field_type,
) =>
{
Ok(BamlValueWithMeta::Media(m, field_type))
}
BamlValue::Media(_) => anyhow::bail!("Could not unify Media with {:?}", field_type),

BamlValue::Enum(name, val) => {
if FieldType::Enum(name.clone()).is_subtype_of(&field_type) {
if self.is_subtype(&FieldType::Enum(name.clone()), &field_type) {
Ok(BamlValueWithMeta::Enum(name, val, field_type))
} else {
anyhow::bail!("Could not unify Enum {} with {:?}", name, field_type)
}
}

BamlValue::Class(name, fields) => {
if !FieldType::Class(name.clone()).is_subtype_of(&field_type) {
if !self.is_subtype(&FieldType::Class(name.clone()), &field_type) {
anyhow::bail!("Could not unify Class {} with {:?}", name, field_type);
} else {
let class_type = &self.find_class(&name)?.item.elem;
Expand Down Expand Up @@ -794,3 +915,104 @@ mod tests {
assert_eq!(constraints, expected_constraints);
}
}

// TODO: Copy pasted from baml-lib/baml-types/src/field_type/mod.rs and poorly
// refactored to match the `is_subtype` changes. Do something with this.
#[cfg(test)]
mod subtype_tests {
use baml_types::BamlMediaType;
use repr::make_test_ir;

use super::*;

fn mk_int() -> FieldType {
FieldType::Primitive(TypeValue::Int)
}
fn mk_bool() -> FieldType {
FieldType::Primitive(TypeValue::Bool)
}
fn mk_str() -> FieldType {
FieldType::Primitive(TypeValue::String)
}

fn mk_optional(ft: FieldType) -> FieldType {
FieldType::Optional(Box::new(ft))
}

fn mk_list(ft: FieldType) -> FieldType {
FieldType::List(Box::new(ft))
}

fn mk_tuple(ft: Vec<FieldType>) -> FieldType {
FieldType::Tuple(ft)
}
fn mk_union(ft: Vec<FieldType>) -> FieldType {
FieldType::Union(ft)
}
fn mk_str_map(ft: FieldType) -> FieldType {
FieldType::Map(Box::new(mk_str()), Box::new(ft))
}

fn ir() -> IntermediateRepr {
make_test_ir("").unwrap()
}

#[test]
fn subtype_trivial() {
assert!(ir().is_subtype(&mk_int(), &mk_int()))
}

#[test]
fn subtype_union() {
let i = mk_int();
let u = mk_union(vec![mk_int(), mk_str()]);
assert!(ir().is_subtype(&i, &u));
assert!(!ir().is_subtype(&u, &i));

let u3 = mk_union(vec![mk_int(), mk_bool(), mk_str()]);
assert!(ir().is_subtype(&i, &u3));
assert!(ir().is_subtype(&u, &u3));
assert!(!ir().is_subtype(&u3, &u));
}

#[test]
fn subtype_optional() {
let i = mk_int();
let o = mk_optional(mk_int());
assert!(ir().is_subtype(&i, &o));
assert!(!ir().is_subtype(&o, &i));
}

#[test]
fn subtype_list() {
let l_i = mk_list(mk_int());
let l_o = mk_list(mk_optional(mk_int()));
assert!(ir().is_subtype(&l_i, &l_o));
assert!(!ir().is_subtype(&l_o, &l_i));
}

#[test]
fn subtype_tuple() {
let x = mk_tuple(vec![mk_int(), mk_optional(mk_int())]);
let y = mk_tuple(vec![mk_int(), mk_int()]);
assert!(ir().is_subtype(&y, &x));
assert!(!ir().is_subtype(&x, &y));
}

#[test]
fn subtype_map_of_list_of_unions() {
let x = mk_str_map(mk_list(FieldType::Class("Foo".to_string())));
let y = mk_str_map(mk_list(mk_union(vec![
mk_str(),
mk_int(),
FieldType::Class("Foo".to_string()),
])));
assert!(ir().is_subtype(&x, &y));
}

#[test]
fn subtype_media() {
let x = FieldType::Primitive(TypeValue::Media(BamlMediaType::Audio));
assert!(ir().is_subtype(&x, &x));
}
}
Loading

0 comments on commit 4eb543a

Please sign in to comment.