Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow structural recursion in type aliases #1207

Merged
merged 26 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
e38e5e7
Allow structural recursion
antoniosarosi Dec 2, 2024
794b3f4
Pass structural cycles to IR
antoniosarosi Dec 2, 2024
9869707
Test structural recursion finder
antoniosarosi Dec 2, 2024
81f776a
Merge branch 'antonio/type-aliases' into antonio/type-aliases-with-cy…
antoniosarosi Dec 4, 2024
3582b18
Merge branch 'antonio/type-aliases' into antonio/type-aliases-with-cy…
antoniosarosi Dec 4, 2024
fab92f5
Merge `antonio/type-aliases`
antoniosarosi Dec 5, 2024
5093eb9
Merge branch 'antonio/type-aliases' into antonio/type-aliases-with-cy…
antoniosarosi Dec 5, 2024
c09997a
Merge `antonio/type-aliases`
antoniosarosi Dec 9, 2024
cd5e1f8
Merge branch 'antonio/type-aliases' into antonio/type-aliases-with-cy…
antoniosarosi Dec 11, 2024
ba8177a
Merge branch 'antonio/type-aliases' into antonio/type-aliases-with-cy…
antoniosarosi Dec 13, 2024
72ea8cb
Implement codegen for Python type aliases
antoniosarosi Dec 16, 2024
140b3dd
Integ test works! Yeah
antoniosarosi Dec 17, 2024
68b98b7
Fix structural cycles rendering
antoniosarosi Dec 17, 2024
d462e5c
Coerce is wonky
antoniosarosi Dec 17, 2024
e0ae448
Fix test `relevant_data_models`
antoniosarosi Dec 17, 2024
abb7430
`is_subtype_of` causing issues with aliases
antoniosarosi Dec 18, 2024
c4e8b85
Fixed `subtype`, `coerce` still doesn't work
antoniosarosi Dec 18, 2024
d6b1e9e
Add integ tests for TS
antoniosarosi Dec 18, 2024
c5267b5
Remove recursion debug limit
antoniosarosi Dec 18, 2024
cac1a16
Add more tests (doesn't work because of score function)
antoniosarosi Dec 18, 2024
39141cb
Add codegen for TS
antoniosarosi Dec 18, 2024
401a97d
Add docs for Ruby type alias
antoniosarosi Dec 18, 2024
fc25050
Fix OpenAPI map keys
antoniosarosi Dec 18, 2024
342fb5e
Fix score of `JsonToString` flag
antoniosarosi Dec 18, 2024
2e10579
Fix integ tests for json type cycle
antoniosarosi Dec 18, 2024
8ff0397
Fix scoring ranking whaterver
antoniosarosi Dec 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
274 changes: 248 additions & 26 deletions engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ pub trait IRHelper {
value: BamlValue,
field_type: FieldType,
) -> Result<BamlValueWithMeta<FieldType>>;
fn is_subtype(&self, base: &FieldType, other: &FieldType) -> bool;
fn distribute_constraints<'a>(
&'a self,
field_type: &'a FieldType,
Expand Down Expand Up @@ -203,6 +204,124 @@ impl IRHelper for IntermediateRepr {
}
}

/// BAML does not support class-based subtyping. Nonetheless some builtin
/// BAML types are subtypes of others, and we need to be able to test this
/// when checking the types of values.
///
/// For examples of pairs of types and their subtyping relationship, see
/// this module's test suite.
///
/// Consider renaming this to `is_assignable`.
fn is_subtype(&self, base: &FieldType, other: &FieldType) -> bool {
if base == other {
return true;
}

if let FieldType::Union(items) = other {
if items.iter().any(|item| self.is_subtype(base, item)) {
return true;
}
}

match (base, other) {
// TODO: O(n)
(FieldType::RecursiveTypeAlias(name), _) => self
.structural_recursive_alias_cycles()
.iter()
.any(|cycle| match cycle.get(name) {
Some(target) => self.is_subtype(target, other),
None => false,
}),
(_, FieldType::RecursiveTypeAlias(name)) => self
.structural_recursive_alias_cycles()
.iter()
.any(|cycle| match cycle.get(name) {
Some(target) => self.is_subtype(base, target),
None => false,
}),

(FieldType::Primitive(TypeValue::Null), FieldType::Optional(_)) => true,
(FieldType::Optional(base_item), FieldType::Optional(other_item)) => {
self.is_subtype(base_item, other_item)
}
(_, FieldType::Optional(t)) => self.is_subtype(base, t),
(FieldType::Optional(_), _) => false,

// Handle types that nest other types.
(FieldType::List(base_item), FieldType::List(other_item)) => {
self.is_subtype(&base_item, other_item)
}
(FieldType::List(_), _) => false,

(FieldType::Map(base_k, base_v), FieldType::Map(other_k, other_v)) => {
self.is_subtype(other_k, base_k) && self.is_subtype(&**base_v, other_v)
}
(FieldType::Map(_, _), _) => false,

(
FieldType::Constrained {
base: constrained_base,
constraints: base_constraints,
},
FieldType::Constrained {
base: other_base,
constraints: other_constraints,
},
) => {
self.is_subtype(constrained_base, other_base)
&& base_constraints == other_constraints
}
(
FieldType::Constrained {
base: contrained_base,
..
},
_,
) => self.is_subtype(contrained_base, other),
(
_,
FieldType::Constrained {
base: constrained_base,
..
},
) => self.is_subtype(base, constrained_base),

(FieldType::Literal(LiteralValue::Bool(_)), FieldType::Primitive(TypeValue::Bool)) => {
true
}
(FieldType::Literal(LiteralValue::Bool(_)), _) => {
self.is_subtype(base, &FieldType::Primitive(TypeValue::Bool))
}
(FieldType::Literal(LiteralValue::Int(_)), FieldType::Primitive(TypeValue::Int)) => {
true
}
(FieldType::Literal(LiteralValue::Int(_)), _) => {
self.is_subtype(base, &FieldType::Primitive(TypeValue::Int))
}
(
FieldType::Literal(LiteralValue::String(_)),
FieldType::Primitive(TypeValue::String),
) => true,
(FieldType::Literal(LiteralValue::String(_)), _) => {
self.is_subtype(base, &FieldType::Primitive(TypeValue::String))
}

(FieldType::Union(items), _) => items.iter().all(|item| self.is_subtype(item, other)),

(FieldType::Tuple(base_items), FieldType::Tuple(other_items)) => {
base_items.len() == other_items.len()
&& base_items
.iter()
.zip(other_items)
.all(|(base_item, other_item)| self.is_subtype(base_item, other_item))
}
(FieldType::Tuple(_), _) => false,
(FieldType::Primitive(_), _) => false,
(FieldType::Enum(_), _) => false,
(FieldType::Class(_), _) => false,
}
}

/// For some `BamlValue` with type `FieldType`, walk the structure of both the value
/// and the type simultaneously, associating each node in the `BamlValue` with its
/// `FieldType`.
Expand All @@ -216,48 +335,48 @@ impl IRHelper for IntermediateRepr {
let literal_type = FieldType::Literal(LiteralValue::String(s.clone()));
let primitive_type = FieldType::Primitive(TypeValue::String);

if literal_type.is_subtype_of(&field_type)
|| primitive_type.is_subtype_of(&field_type)
if self.is_subtype(&literal_type, &field_type)
|| self.is_subtype(&primitive_type, &field_type)
{
return Ok(BamlValueWithMeta::String(s, field_type));
}
anyhow::bail!("Could not unify String with {:?}", field_type)
}
BamlValue::Int(i)
if FieldType::Literal(LiteralValue::Int(i)).is_subtype_of(&field_type) =>
{
Ok(BamlValueWithMeta::Int(i, field_type))
}
BamlValue::Int(i)
if FieldType::Primitive(TypeValue::Int).is_subtype_of(&field_type) =>
{
Ok(BamlValueWithMeta::Int(i, field_type))
}
BamlValue::Int(i) => {
let literal_type = FieldType::Literal(LiteralValue::Int(i));
let primitive_type = FieldType::Primitive(TypeValue::Int);

if self.is_subtype(&literal_type, &field_type)
|| self.is_subtype(&primitive_type, &field_type)
{
return Ok(BamlValueWithMeta::Int(i, field_type));
}
anyhow::bail!("Could not unify Int with {:?}", field_type)
}

BamlValue::Float(f)
if FieldType::Primitive(TypeValue::Float).is_subtype_of(&field_type) =>
{
Ok(BamlValueWithMeta::Float(f, field_type))
BamlValue::Float(f) => {
if self.is_subtype(&FieldType::Primitive(TypeValue::Float), &field_type) {
return Ok(BamlValueWithMeta::Float(f, field_type));
}
anyhow::bail!("Could not unify Float with {:?}", field_type)
}
BamlValue::Float(_) => anyhow::bail!("Could not unify Float with {:?}", field_type),

BamlValue::Bool(b) => {
let literal_type = FieldType::Literal(LiteralValue::Bool(b));
let primitive_type = FieldType::Primitive(TypeValue::Bool);

if literal_type.is_subtype_of(&field_type)
|| primitive_type.is_subtype_of(&field_type)
if self.is_subtype(&literal_type, &field_type)
|| self.is_subtype(&primitive_type, &field_type)
{
Ok(BamlValueWithMeta::Bool(b, field_type))
} else {
anyhow::bail!("Could not unify Bool with {:?}", field_type)
}
}

BamlValue::Null if FieldType::Primitive(TypeValue::Null).is_subtype_of(&field_type) => {
BamlValue::Null
if self.is_subtype(&FieldType::Primitive(TypeValue::Null), &field_type) =>
{
Ok(BamlValueWithMeta::Null(field_type))
}
BamlValue::Null => anyhow::bail!("Could not unify Null with {:?}", field_type),
Expand Down Expand Up @@ -287,7 +406,7 @@ impl IRHelper for IntermediateRepr {
Box::new(item_type.clone()),
);

if !map_type.is_subtype_of(&field_type) {
if !self.is_subtype(&map_type, &field_type) {
anyhow::bail!("Could not unify {:?} with {:?}", map_type, field_type);
}

Expand Down Expand Up @@ -321,7 +440,7 @@ impl IRHelper for IntermediateRepr {
Some(item_type) => {
let list_type = FieldType::List(Box::new(item_type.clone()));

if !list_type.is_subtype_of(&field_type) {
if !self.is_subtype(&list_type, &field_type) {
anyhow::bail!("Could not unify {:?} with {:?}", list_type, field_type);
} else {
let mapped_items: Vec<BamlValueWithMeta<FieldType>> = items
Expand All @@ -335,23 +454,25 @@ impl IRHelper for IntermediateRepr {
}

BamlValue::Media(m)
if FieldType::Primitive(TypeValue::Media(m.media_type))
.is_subtype_of(&field_type) =>
if self.is_subtype(
&FieldType::Primitive(TypeValue::Media(m.media_type)),
&field_type,
) =>
{
Ok(BamlValueWithMeta::Media(m, field_type))
}
BamlValue::Media(_) => anyhow::bail!("Could not unify Media with {:?}", field_type),

BamlValue::Enum(name, val) => {
if FieldType::Enum(name.clone()).is_subtype_of(&field_type) {
if self.is_subtype(&FieldType::Enum(name.clone()), &field_type) {
Ok(BamlValueWithMeta::Enum(name, val, field_type))
} else {
anyhow::bail!("Could not unify Enum {} with {:?}", name, field_type)
}
}

BamlValue::Class(name, fields) => {
if !FieldType::Class(name.clone()).is_subtype_of(&field_type) {
if !self.is_subtype(&FieldType::Class(name.clone()), &field_type) {
anyhow::bail!("Could not unify Class {} with {:?}", name, field_type);
} else {
let class_type = &self.find_class(&name)?.item.elem;
Expand Down Expand Up @@ -794,3 +915,104 @@ mod tests {
assert_eq!(constraints, expected_constraints);
}
}

// TODO: Copy pasted from baml-lib/baml-types/src/field_type/mod.rs and poorly
// refactored to match the `is_subtype` changes. Do something with this.
#[cfg(test)]
mod subtype_tests {
use baml_types::BamlMediaType;
use repr::make_test_ir;

use super::*;

fn mk_int() -> FieldType {
FieldType::Primitive(TypeValue::Int)
}
fn mk_bool() -> FieldType {
FieldType::Primitive(TypeValue::Bool)
}
fn mk_str() -> FieldType {
FieldType::Primitive(TypeValue::String)
}

fn mk_optional(ft: FieldType) -> FieldType {
FieldType::Optional(Box::new(ft))
}

fn mk_list(ft: FieldType) -> FieldType {
FieldType::List(Box::new(ft))
}

fn mk_tuple(ft: Vec<FieldType>) -> FieldType {
FieldType::Tuple(ft)
}
fn mk_union(ft: Vec<FieldType>) -> FieldType {
FieldType::Union(ft)
}
fn mk_str_map(ft: FieldType) -> FieldType {
FieldType::Map(Box::new(mk_str()), Box::new(ft))
}

fn ir() -> IntermediateRepr {
make_test_ir("").unwrap()
}

#[test]
fn subtype_trivial() {
assert!(ir().is_subtype(&mk_int(), &mk_int()))
}

#[test]
fn subtype_union() {
let i = mk_int();
let u = mk_union(vec![mk_int(), mk_str()]);
assert!(ir().is_subtype(&i, &u));
assert!(!ir().is_subtype(&u, &i));

let u3 = mk_union(vec![mk_int(), mk_bool(), mk_str()]);
assert!(ir().is_subtype(&i, &u3));
assert!(ir().is_subtype(&u, &u3));
assert!(!ir().is_subtype(&u3, &u));
}

#[test]
fn subtype_optional() {
let i = mk_int();
let o = mk_optional(mk_int());
assert!(ir().is_subtype(&i, &o));
assert!(!ir().is_subtype(&o, &i));
}

#[test]
fn subtype_list() {
let l_i = mk_list(mk_int());
let l_o = mk_list(mk_optional(mk_int()));
assert!(ir().is_subtype(&l_i, &l_o));
assert!(!ir().is_subtype(&l_o, &l_i));
}

#[test]
fn subtype_tuple() {
let x = mk_tuple(vec![mk_int(), mk_optional(mk_int())]);
let y = mk_tuple(vec![mk_int(), mk_int()]);
assert!(ir().is_subtype(&y, &x));
assert!(!ir().is_subtype(&x, &y));
}

#[test]
fn subtype_map_of_list_of_unions() {
let x = mk_str_map(mk_list(FieldType::Class("Foo".to_string())));
let y = mk_str_map(mk_list(mk_union(vec![
mk_str(),
mk_int(),
FieldType::Class("Foo".to_string()),
])));
assert!(ir().is_subtype(&x, &y));
}

#[test]
fn subtype_media() {
let x = FieldType::Primitive(TypeValue::Media(BamlMediaType::Audio));
assert!(ir().is_subtype(&x, &x));
}
}
Loading
Loading