From c4e8b85934e88ffe0273bd985a7320f95e10408f Mon Sep 17 00:00:00 2001 From: Antonio Sarosi Date: Wed, 18 Dec 2024 03:49:07 +0100 Subject: [PATCH] Fixed `subtype`, `coerce` still doesn't work --- .../baml-core/src/ir/ir_helpers/mod.rs | 274 ++++++++++++++++-- .../src/ir/ir_helpers/to_baml_arg.rs | 6 +- .../baml-lib/baml-types/src/field_type/mod.rs | 181 ------------ .../jsonish/src/tests/test_aliases.rs | 2 +- .../functions/output/type-aliases.baml | 2 +- 5 files changed, 251 insertions(+), 214 deletions(-) diff --git a/engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs b/engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs index 40d932dce..6c5e71a9d 100644 --- a/engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs +++ b/engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs @@ -57,6 +57,7 @@ pub trait IRHelper { value: BamlValue, field_type: FieldType, ) -> Result>; + fn is_subtype(&self, base: &FieldType, other: &FieldType) -> bool; fn distribute_constraints<'a>( &'a self, field_type: &'a FieldType, @@ -203,6 +204,124 @@ impl IRHelper for IntermediateRepr { } } + /// BAML does not support class-based subtyping. Nonetheless some builtin + /// BAML types are subtypes of others, and we need to be able to test this + /// when checking the types of values. + /// + /// For examples of pairs of types and their subtyping relationship, see + /// this module's test suite. + /// + /// Consider renaming this to `is_assignable`. + fn is_subtype(&self, base: &FieldType, other: &FieldType) -> bool { + if base == other { + return true; + } + + if let FieldType::Union(items) = other { + if items.iter().any(|item| self.is_subtype(base, item)) { + return true; + } + } + + match (base, other) { + // TODO: O(n) + (FieldType::RecursiveTypeAlias(name), _) => self + .structural_recursive_alias_cycles() + .iter() + .any(|cycle| match cycle.get(name) { + Some(target) => self.is_subtype(target, other), + None => false, + }), + (_, FieldType::RecursiveTypeAlias(name)) => self + .structural_recursive_alias_cycles() + .iter() + .any(|cycle| match cycle.get(name) { + Some(target) => self.is_subtype(base, target), + None => false, + }), + + (FieldType::Primitive(TypeValue::Null), FieldType::Optional(_)) => true, + (FieldType::Optional(base_item), FieldType::Optional(other_item)) => { + self.is_subtype(base_item, other_item) + } + (_, FieldType::Optional(t)) => self.is_subtype(base, t), + (FieldType::Optional(_), _) => false, + + // Handle types that nest other types. + (FieldType::List(base_item), FieldType::List(other_item)) => { + self.is_subtype(&base_item, other_item) + } + (FieldType::List(_), _) => false, + + (FieldType::Map(base_k, base_v), FieldType::Map(other_k, other_v)) => { + self.is_subtype(other_k, base_k) && self.is_subtype(&**base_v, other_v) + } + (FieldType::Map(_, _), _) => false, + + ( + FieldType::Constrained { + base: constrained_base, + constraints: base_constraints, + }, + FieldType::Constrained { + base: other_base, + constraints: other_constraints, + }, + ) => { + self.is_subtype(constrained_base, other_base) + && base_constraints == other_constraints + } + ( + FieldType::Constrained { + base: contrained_base, + .. + }, + _, + ) => self.is_subtype(contrained_base, other), + ( + _, + FieldType::Constrained { + base: constrained_base, + .. + }, + ) => self.is_subtype(base, constrained_base), + + (FieldType::Literal(LiteralValue::Bool(_)), FieldType::Primitive(TypeValue::Bool)) => { + true + } + (FieldType::Literal(LiteralValue::Bool(_)), _) => { + self.is_subtype(base, &FieldType::Primitive(TypeValue::Bool)) + } + (FieldType::Literal(LiteralValue::Int(_)), FieldType::Primitive(TypeValue::Int)) => { + true + } + (FieldType::Literal(LiteralValue::Int(_)), _) => { + self.is_subtype(base, &FieldType::Primitive(TypeValue::Int)) + } + ( + FieldType::Literal(LiteralValue::String(_)), + FieldType::Primitive(TypeValue::String), + ) => true, + (FieldType::Literal(LiteralValue::String(_)), _) => { + self.is_subtype(base, &FieldType::Primitive(TypeValue::String)) + } + + (FieldType::Union(items), _) => items.iter().all(|item| self.is_subtype(item, other)), + + (FieldType::Tuple(base_items), FieldType::Tuple(other_items)) => { + base_items.len() == other_items.len() + && base_items + .iter() + .zip(other_items) + .all(|(base_item, other_item)| self.is_subtype(base_item, other_item)) + } + (FieldType::Tuple(_), _) => false, + (FieldType::Primitive(_), _) => false, + (FieldType::Enum(_), _) => false, + (FieldType::Class(_), _) => false, + } + } + /// For some `BamlValue` with type `FieldType`, walk the structure of both the value /// and the type simultaneously, associating each node in the `BamlValue` with its /// `FieldType`. @@ -216,40 +335,38 @@ impl IRHelper for IntermediateRepr { let literal_type = FieldType::Literal(LiteralValue::String(s.clone())); let primitive_type = FieldType::Primitive(TypeValue::String); - if literal_type.is_subtype_of(&field_type) - || primitive_type.is_subtype_of(&field_type) + if self.is_subtype(&literal_type, &field_type) + || self.is_subtype(&primitive_type, &field_type) { return Ok(BamlValueWithMeta::String(s, field_type)); } anyhow::bail!("Could not unify String with {:?}", field_type) } - BamlValue::Int(i) - if FieldType::Literal(LiteralValue::Int(i)).is_subtype_of(&field_type) => - { - Ok(BamlValueWithMeta::Int(i, field_type)) - } - BamlValue::Int(i) - if FieldType::Primitive(TypeValue::Int).is_subtype_of(&field_type) => - { - Ok(BamlValueWithMeta::Int(i, field_type)) - } BamlValue::Int(i) => { + let literal_type = FieldType::Literal(LiteralValue::Int(i)); + let primitive_type = FieldType::Primitive(TypeValue::Int); + + if self.is_subtype(&literal_type, &field_type) + || self.is_subtype(&primitive_type, &field_type) + { + return Ok(BamlValueWithMeta::Int(i, field_type)); + } anyhow::bail!("Could not unify Int with {:?}", field_type) } - BamlValue::Float(f) - if FieldType::Primitive(TypeValue::Float).is_subtype_of(&field_type) => - { - Ok(BamlValueWithMeta::Float(f, field_type)) + BamlValue::Float(f) => { + if self.is_subtype(&FieldType::Primitive(TypeValue::Float), &field_type) { + return Ok(BamlValueWithMeta::Float(f, field_type)); + } + anyhow::bail!("Could not unify Float with {:?}", field_type) } - BamlValue::Float(_) => anyhow::bail!("Could not unify Float with {:?}", field_type), BamlValue::Bool(b) => { let literal_type = FieldType::Literal(LiteralValue::Bool(b)); let primitive_type = FieldType::Primitive(TypeValue::Bool); - if literal_type.is_subtype_of(&field_type) - || primitive_type.is_subtype_of(&field_type) + if self.is_subtype(&literal_type, &field_type) + || self.is_subtype(&primitive_type, &field_type) { Ok(BamlValueWithMeta::Bool(b, field_type)) } else { @@ -257,7 +374,9 @@ impl IRHelper for IntermediateRepr { } } - BamlValue::Null if FieldType::Primitive(TypeValue::Null).is_subtype_of(&field_type) => { + BamlValue::Null + if self.is_subtype(&FieldType::Primitive(TypeValue::Null), &field_type) => + { Ok(BamlValueWithMeta::Null(field_type)) } BamlValue::Null => anyhow::bail!("Could not unify Null with {:?}", field_type), @@ -287,7 +406,7 @@ impl IRHelper for IntermediateRepr { Box::new(item_type.clone()), ); - if !map_type.is_subtype_of(&field_type) { + if !self.is_subtype(&map_type, &field_type) { anyhow::bail!("Could not unify {:?} with {:?}", map_type, field_type); } @@ -321,7 +440,7 @@ impl IRHelper for IntermediateRepr { Some(item_type) => { let list_type = FieldType::List(Box::new(item_type.clone())); - if !list_type.is_subtype_of(&field_type) { + if !self.is_subtype(&list_type, &field_type) { anyhow::bail!("Could not unify {:?} with {:?}", list_type, field_type); } else { let mapped_items: Vec> = items @@ -335,15 +454,17 @@ impl IRHelper for IntermediateRepr { } BamlValue::Media(m) - if FieldType::Primitive(TypeValue::Media(m.media_type)) - .is_subtype_of(&field_type) => + if self.is_subtype( + &FieldType::Primitive(TypeValue::Media(m.media_type)), + &field_type, + ) => { Ok(BamlValueWithMeta::Media(m, field_type)) } BamlValue::Media(_) => anyhow::bail!("Could not unify Media with {:?}", field_type), BamlValue::Enum(name, val) => { - if FieldType::Enum(name.clone()).is_subtype_of(&field_type) { + if self.is_subtype(&FieldType::Enum(name.clone()), &field_type) { Ok(BamlValueWithMeta::Enum(name, val, field_type)) } else { anyhow::bail!("Could not unify Enum {} with {:?}", name, field_type) @@ -351,7 +472,7 @@ impl IRHelper for IntermediateRepr { } BamlValue::Class(name, fields) => { - if !FieldType::Class(name.clone()).is_subtype_of(&field_type) { + if !self.is_subtype(&FieldType::Class(name.clone()), &field_type) { anyhow::bail!("Could not unify Class {} with {:?}", name, field_type); } else { let class_type = &self.find_class(&name)?.item.elem; @@ -794,3 +915,104 @@ mod tests { assert_eq!(constraints, expected_constraints); } } + +// TODO: Copy pasted from baml-lib/baml-types/src/field_type/mod.rs and poorly +// refactored to match the `is_subtype` changes. Do something with this. +#[cfg(test)] +mod subtype_tests { + use baml_types::BamlMediaType; + use repr::make_test_ir; + + use super::*; + + fn mk_int() -> FieldType { + FieldType::Primitive(TypeValue::Int) + } + fn mk_bool() -> FieldType { + FieldType::Primitive(TypeValue::Bool) + } + fn mk_str() -> FieldType { + FieldType::Primitive(TypeValue::String) + } + + fn mk_optional(ft: FieldType) -> FieldType { + FieldType::Optional(Box::new(ft)) + } + + fn mk_list(ft: FieldType) -> FieldType { + FieldType::List(Box::new(ft)) + } + + fn mk_tuple(ft: Vec) -> FieldType { + FieldType::Tuple(ft) + } + fn mk_union(ft: Vec) -> FieldType { + FieldType::Union(ft) + } + fn mk_str_map(ft: FieldType) -> FieldType { + FieldType::Map(Box::new(mk_str()), Box::new(ft)) + } + + fn ir() -> IntermediateRepr { + make_test_ir("").unwrap() + } + + #[test] + fn subtype_trivial() { + assert!(ir().is_subtype(&mk_int(), &mk_int())) + } + + #[test] + fn subtype_union() { + let i = mk_int(); + let u = mk_union(vec![mk_int(), mk_str()]); + assert!(ir().is_subtype(&i, &u)); + assert!(!ir().is_subtype(&u, &i)); + + let u3 = mk_union(vec![mk_int(), mk_bool(), mk_str()]); + assert!(ir().is_subtype(&i, &u3)); + assert!(ir().is_subtype(&u, &u3)); + assert!(!ir().is_subtype(&u3, &u)); + } + + #[test] + fn subtype_optional() { + let i = mk_int(); + let o = mk_optional(mk_int()); + assert!(ir().is_subtype(&i, &o)); + assert!(!ir().is_subtype(&o, &i)); + } + + #[test] + fn subtype_list() { + let l_i = mk_list(mk_int()); + let l_o = mk_list(mk_optional(mk_int())); + assert!(ir().is_subtype(&l_i, &l_o)); + assert!(!ir().is_subtype(&l_o, &l_i)); + } + + #[test] + fn subtype_tuple() { + let x = mk_tuple(vec![mk_int(), mk_optional(mk_int())]); + let y = mk_tuple(vec![mk_int(), mk_int()]); + assert!(ir().is_subtype(&y, &x)); + assert!(!ir().is_subtype(&x, &y)); + } + + #[test] + fn subtype_map_of_list_of_unions() { + let x = mk_str_map(mk_list(FieldType::Class("Foo".to_string()))); + let y = mk_str_map(mk_list(mk_union(vec![ + mk_str(), + mk_int(), + FieldType::Class("Foo".to_string()), + ]))); + assert!(ir().is_subtype(&x, &y)); + } + + #[test] + fn subtype_media() { + let x = FieldType::Primitive(TypeValue::Media(BamlMediaType::Audio)); + assert!(ir().is_subtype(&x, &x)); + } +} diff --git a/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs b/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs index 7f529cd72..a13c5cc49 100644 --- a/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs +++ b/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs @@ -43,9 +43,6 @@ impl ArgCoercer { value: &BamlValue, // original value passed in by user scope: &mut ScopeStack, ) -> Result { - eprintln!("coerce_arg: {value:?} -> {field_type:?}"); - eprintln!("scope: {scope}\n"); - let value = match ir.distribute_constraints(field_type) { (FieldType::Primitive(t), _) => match t { TypeValue::String if matches!(value, BamlValue::String(_)) => Ok(value.clone()), @@ -331,7 +328,6 @@ impl ArgCoercer { let mut scope = ScopeStack::new(); if first_good_result.is_err() { let result = self.coerce_arg(ir, option, value, &mut scope); - eprintln!("union inner scope scope: {scope}\n"); if !scope.has_errors() && first_good_result.is_err() { first_good_result = result } @@ -466,7 +462,7 @@ mod tests { fn test_mutually_recursive_aliases() { let ir = make_test_ir( r##" -type JsonValue = int | string | bool | float | JsonObject | JsonArray +type JsonValue = int | bool | float | string | JsonArray | JsonObject type JsonObject = map type JsonArray = JsonValue[] "##, diff --git a/engine/baml-lib/baml-types/src/field_type/mod.rs b/engine/baml-lib/baml-types/src/field_type/mod.rs index 7008f26a5..52f59fae0 100644 --- a/engine/baml-lib/baml-types/src/field_type/mod.rs +++ b/engine/baml-lib/baml-types/src/field_type/mod.rs @@ -160,185 +160,4 @@ impl FieldType { _ => false, } } - - /// BAML does not support class-based subtyping. Nonetheless some builtin - /// BAML types are subtypes of others, and we need to be able to test this - /// when checking the types of values. - /// - /// For examples of pairs of types and their subtyping relationship, see - /// this module's test suite. - /// - /// Consider renaming this to `is_assignable_to`. - pub fn is_subtype_of(&self, other: &FieldType) -> bool { - if self == other { - return true; - } - - if let FieldType::Union(items) = other { - if items.iter().any(|item| self.is_subtype_of(item)) { - return true; - } - } - - match (self, other) { - (FieldType::Primitive(TypeValue::Null), FieldType::Optional(_)) => true, - (FieldType::Optional(self_item), FieldType::Optional(other_item)) => { - self_item.is_subtype_of(other_item) - } - (_, FieldType::Optional(t)) => self.is_subtype_of(t), - (FieldType::Optional(_), _) => false, - - // Handle types that nest other types. - (FieldType::List(self_item), FieldType::List(other_item)) => { - self_item.is_subtype_of(other_item) - } - (FieldType::List(_), _) => false, - - (FieldType::Map(self_k, self_v), FieldType::Map(other_k, other_v)) => { - other_k.is_subtype_of(self_k) && (**self_v).is_subtype_of(other_v) - } - (FieldType::Map(_, _), _) => false, - - ( - FieldType::Constrained { - base: self_base, - constraints: self_cs, - }, - FieldType::Constrained { - base: other_base, - constraints: other_cs, - }, - ) => self_base.is_subtype_of(other_base) && self_cs == other_cs, - (FieldType::Constrained { base, .. }, _) => base.is_subtype_of(other), - (_, FieldType::Constrained { base, .. }) => self.is_subtype_of(base), - (FieldType::Literal(LiteralValue::Bool(_)), FieldType::Primitive(TypeValue::Bool)) => { - true - } - (FieldType::Literal(LiteralValue::Bool(_)), _) => { - self.is_subtype_of(&FieldType::Primitive(TypeValue::Bool)) - } - (FieldType::Literal(LiteralValue::Int(_)), FieldType::Primitive(TypeValue::Int)) => { - true - } - (FieldType::Literal(LiteralValue::Int(_)), _) => { - self.is_subtype_of(&FieldType::Primitive(TypeValue::Int)) - } - ( - FieldType::Literal(LiteralValue::String(_)), - FieldType::Primitive(TypeValue::String), - ) => true, - (FieldType::Literal(LiteralValue::String(_)), _) => { - self.is_subtype_of(&FieldType::Primitive(TypeValue::String)) - } - - (FieldType::Union(self_items), _) => self_items - .iter() - .all(|self_item| self_item.is_subtype_of(other)), - - (FieldType::Tuple(self_items), FieldType::Tuple(other_items)) => { - self_items.len() == other_items.len() - && self_items - .iter() - .zip(other_items) - .all(|(self_item, other_item)| self_item.is_subtype_of(other_item)) - } - (FieldType::Tuple(_), _) => false, - (FieldType::Primitive(_), _) => false, - (FieldType::Enum(_), _) => false, - (FieldType::Class(_), _) => false, - (FieldType::RecursiveTypeAlias(_), _) => false, - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - fn mk_int() -> FieldType { - FieldType::Primitive(TypeValue::Int) - } - fn mk_bool() -> FieldType { - FieldType::Primitive(TypeValue::Bool) - } - fn mk_str() -> FieldType { - FieldType::Primitive(TypeValue::String) - } - - fn mk_optional(ft: FieldType) -> FieldType { - FieldType::Optional(Box::new(ft)) - } - - fn mk_list(ft: FieldType) -> FieldType { - FieldType::List(Box::new(ft)) - } - - fn mk_tuple(ft: Vec) -> FieldType { - FieldType::Tuple(ft) - } - fn mk_union(ft: Vec) -> FieldType { - FieldType::Union(ft) - } - fn mk_str_map(ft: FieldType) -> FieldType { - FieldType::Map(Box::new(mk_str()), Box::new(ft)) - } - - #[test] - fn subtype_trivial() { - assert!(mk_int().is_subtype_of(&mk_int())) - } - - #[test] - fn subtype_union() { - let i = mk_int(); - let u = mk_union(vec![mk_int(), mk_str()]); - assert!(i.is_subtype_of(&u)); - assert!(!u.is_subtype_of(&i)); - - let u3 = mk_union(vec![mk_int(), mk_bool(), mk_str()]); - assert!(i.is_subtype_of(&u3)); - assert!(u.is_subtype_of(&u3)); - assert!(!u3.is_subtype_of(&u)); - } - - #[test] - fn subtype_optional() { - let i = mk_int(); - let o = mk_optional(mk_int()); - assert!(i.is_subtype_of(&o)); - assert!(!o.is_subtype_of(&i)); - } - - #[test] - fn subtype_list() { - let l_i = mk_list(mk_int()); - let l_o = mk_list(mk_optional(mk_int())); - assert!(l_i.is_subtype_of(&l_o)); - assert!(!l_o.is_subtype_of(&l_i)); - } - - #[test] - fn subtype_tuple() { - let x = mk_tuple(vec![mk_int(), mk_optional(mk_int())]); - let y = mk_tuple(vec![mk_int(), mk_int()]); - assert!(y.is_subtype_of(&x)); - assert!(!x.is_subtype_of(&y)); - } - - #[test] - fn subtype_map_of_list_of_unions() { - let x = mk_str_map(mk_list(FieldType::Class("Foo".to_string()))); - let y = mk_str_map(mk_list(mk_union(vec![ - mk_str(), - mk_int(), - FieldType::Class("Foo".to_string()), - ]))); - assert!(x.is_subtype_of(&y)); - } - - #[test] - fn subtype_media() { - let x = FieldType::Primitive(TypeValue::Media(BamlMediaType::Audio)); - assert!(x.is_subtype_of(&x)); - } } diff --git a/engine/baml-lib/jsonish/src/tests/test_aliases.rs b/engine/baml-lib/jsonish/src/tests/test_aliases.rs index 35baa352f..97f7f55ef 100644 --- a/engine/baml-lib/jsonish/src/tests/test_aliases.rs +++ b/engine/baml-lib/jsonish/src/tests/test_aliases.rs @@ -60,7 +60,7 @@ type JsonValue = int | string | bool | JsonValue[] | map test_deserializer!( test_complex_recursive_alias, r#" -type JsonValue = int | string | bool | JsonValue[] | map +type JsonValue = int | bool | string | JsonValue[] | map "#, r#" { diff --git a/integ-tests/baml_src/test-files/functions/output/type-aliases.baml b/integ-tests/baml_src/test-files/functions/output/type-aliases.baml index 216bfc2ec..2e407303d 100644 --- a/integ-tests/baml_src/test-files/functions/output/type-aliases.baml +++ b/integ-tests/baml_src/test-files/functions/output/type-aliases.baml @@ -133,4 +133,4 @@ function JsonTypeAliasCycle(input: JsonValue) -> JsonValue { {{ ctx.output_format }} "# -} \ No newline at end of file +}