Skip to content

Commit

Permalink
Fixed subtype, coerce still doesn't work
Browse files Browse the repository at this point in the history
  • Loading branch information
antoniosarosi committed Dec 18, 2024
1 parent abb7430 commit c4e8b85
Show file tree
Hide file tree
Showing 5 changed files with 251 additions and 214 deletions.
274 changes: 248 additions & 26 deletions engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ pub trait IRHelper {
value: BamlValue,
field_type: FieldType,
) -> Result<BamlValueWithMeta<FieldType>>;
fn is_subtype(&self, base: &FieldType, other: &FieldType) -> bool;
fn distribute_constraints<'a>(
&'a self,
field_type: &'a FieldType,
Expand Down Expand Up @@ -203,6 +204,124 @@ impl IRHelper for IntermediateRepr {
}
}

/// BAML does not support class-based subtyping. Nonetheless some builtin
/// BAML types are subtypes of others, and we need to be able to test this
/// when checking the types of values.
///
/// For examples of pairs of types and their subtyping relationship, see
/// this module's test suite.
///
/// Consider renaming this to `is_assignable`.
fn is_subtype(&self, base: &FieldType, other: &FieldType) -> bool {
if base == other {
return true;
}

if let FieldType::Union(items) = other {
if items.iter().any(|item| self.is_subtype(base, item)) {
return true;
}
}

match (base, other) {
// TODO: O(n)
(FieldType::RecursiveTypeAlias(name), _) => self
.structural_recursive_alias_cycles()
.iter()
.any(|cycle| match cycle.get(name) {
Some(target) => self.is_subtype(target, other),
None => false,
}),
(_, FieldType::RecursiveTypeAlias(name)) => self
.structural_recursive_alias_cycles()
.iter()
.any(|cycle| match cycle.get(name) {
Some(target) => self.is_subtype(base, target),
None => false,
}),

(FieldType::Primitive(TypeValue::Null), FieldType::Optional(_)) => true,
(FieldType::Optional(base_item), FieldType::Optional(other_item)) => {
self.is_subtype(base_item, other_item)
}
(_, FieldType::Optional(t)) => self.is_subtype(base, t),
(FieldType::Optional(_), _) => false,

// Handle types that nest other types.
(FieldType::List(base_item), FieldType::List(other_item)) => {
self.is_subtype(&base_item, other_item)
}
(FieldType::List(_), _) => false,

(FieldType::Map(base_k, base_v), FieldType::Map(other_k, other_v)) => {
self.is_subtype(other_k, base_k) && self.is_subtype(&**base_v, other_v)
}
(FieldType::Map(_, _), _) => false,

(
FieldType::Constrained {
base: constrained_base,
constraints: base_constraints,
},
FieldType::Constrained {
base: other_base,
constraints: other_constraints,
},
) => {
self.is_subtype(constrained_base, other_base)
&& base_constraints == other_constraints
}
(
FieldType::Constrained {
base: contrained_base,
..
},
_,
) => self.is_subtype(contrained_base, other),
(
_,
FieldType::Constrained {
base: constrained_base,
..
},
) => self.is_subtype(base, constrained_base),

(FieldType::Literal(LiteralValue::Bool(_)), FieldType::Primitive(TypeValue::Bool)) => {
true
}
(FieldType::Literal(LiteralValue::Bool(_)), _) => {
self.is_subtype(base, &FieldType::Primitive(TypeValue::Bool))
}
(FieldType::Literal(LiteralValue::Int(_)), FieldType::Primitive(TypeValue::Int)) => {
true
}
(FieldType::Literal(LiteralValue::Int(_)), _) => {
self.is_subtype(base, &FieldType::Primitive(TypeValue::Int))
}
(
FieldType::Literal(LiteralValue::String(_)),
FieldType::Primitive(TypeValue::String),
) => true,
(FieldType::Literal(LiteralValue::String(_)), _) => {
self.is_subtype(base, &FieldType::Primitive(TypeValue::String))
}

(FieldType::Union(items), _) => items.iter().all(|item| self.is_subtype(item, other)),

(FieldType::Tuple(base_items), FieldType::Tuple(other_items)) => {
base_items.len() == other_items.len()
&& base_items
.iter()
.zip(other_items)
.all(|(base_item, other_item)| self.is_subtype(base_item, other_item))
}
(FieldType::Tuple(_), _) => false,
(FieldType::Primitive(_), _) => false,
(FieldType::Enum(_), _) => false,
(FieldType::Class(_), _) => false,
}
}

/// For some `BamlValue` with type `FieldType`, walk the structure of both the value
/// and the type simultaneously, associating each node in the `BamlValue` with its
/// `FieldType`.
Expand All @@ -216,48 +335,48 @@ impl IRHelper for IntermediateRepr {
let literal_type = FieldType::Literal(LiteralValue::String(s.clone()));
let primitive_type = FieldType::Primitive(TypeValue::String);

if literal_type.is_subtype_of(&field_type)
|| primitive_type.is_subtype_of(&field_type)
if self.is_subtype(&literal_type, &field_type)
|| self.is_subtype(&primitive_type, &field_type)
{
return Ok(BamlValueWithMeta::String(s, field_type));
}
anyhow::bail!("Could not unify String with {:?}", field_type)
}
BamlValue::Int(i)
if FieldType::Literal(LiteralValue::Int(i)).is_subtype_of(&field_type) =>
{
Ok(BamlValueWithMeta::Int(i, field_type))
}
BamlValue::Int(i)
if FieldType::Primitive(TypeValue::Int).is_subtype_of(&field_type) =>
{
Ok(BamlValueWithMeta::Int(i, field_type))
}
BamlValue::Int(i) => {
let literal_type = FieldType::Literal(LiteralValue::Int(i));
let primitive_type = FieldType::Primitive(TypeValue::Int);

if self.is_subtype(&literal_type, &field_type)
|| self.is_subtype(&primitive_type, &field_type)
{
return Ok(BamlValueWithMeta::Int(i, field_type));
}
anyhow::bail!("Could not unify Int with {:?}", field_type)
}

BamlValue::Float(f)
if FieldType::Primitive(TypeValue::Float).is_subtype_of(&field_type) =>
{
Ok(BamlValueWithMeta::Float(f, field_type))
BamlValue::Float(f) => {
if self.is_subtype(&FieldType::Primitive(TypeValue::Float), &field_type) {
return Ok(BamlValueWithMeta::Float(f, field_type));
}
anyhow::bail!("Could not unify Float with {:?}", field_type)
}
BamlValue::Float(_) => anyhow::bail!("Could not unify Float with {:?}", field_type),

BamlValue::Bool(b) => {
let literal_type = FieldType::Literal(LiteralValue::Bool(b));
let primitive_type = FieldType::Primitive(TypeValue::Bool);

if literal_type.is_subtype_of(&field_type)
|| primitive_type.is_subtype_of(&field_type)
if self.is_subtype(&literal_type, &field_type)
|| self.is_subtype(&primitive_type, &field_type)
{
Ok(BamlValueWithMeta::Bool(b, field_type))
} else {
anyhow::bail!("Could not unify Bool with {:?}", field_type)
}
}

BamlValue::Null if FieldType::Primitive(TypeValue::Null).is_subtype_of(&field_type) => {
BamlValue::Null
if self.is_subtype(&FieldType::Primitive(TypeValue::Null), &field_type) =>
{
Ok(BamlValueWithMeta::Null(field_type))
}
BamlValue::Null => anyhow::bail!("Could not unify Null with {:?}", field_type),
Expand Down Expand Up @@ -287,7 +406,7 @@ impl IRHelper for IntermediateRepr {
Box::new(item_type.clone()),
);

if !map_type.is_subtype_of(&field_type) {
if !self.is_subtype(&map_type, &field_type) {
anyhow::bail!("Could not unify {:?} with {:?}", map_type, field_type);
}

Expand Down Expand Up @@ -321,7 +440,7 @@ impl IRHelper for IntermediateRepr {
Some(item_type) => {
let list_type = FieldType::List(Box::new(item_type.clone()));

if !list_type.is_subtype_of(&field_type) {
if !self.is_subtype(&list_type, &field_type) {
anyhow::bail!("Could not unify {:?} with {:?}", list_type, field_type);
} else {
let mapped_items: Vec<BamlValueWithMeta<FieldType>> = items
Expand All @@ -335,23 +454,25 @@ impl IRHelper for IntermediateRepr {
}

BamlValue::Media(m)
if FieldType::Primitive(TypeValue::Media(m.media_type))
.is_subtype_of(&field_type) =>
if self.is_subtype(
&FieldType::Primitive(TypeValue::Media(m.media_type)),
&field_type,
) =>
{
Ok(BamlValueWithMeta::Media(m, field_type))
}
BamlValue::Media(_) => anyhow::bail!("Could not unify Media with {:?}", field_type),

BamlValue::Enum(name, val) => {
if FieldType::Enum(name.clone()).is_subtype_of(&field_type) {
if self.is_subtype(&FieldType::Enum(name.clone()), &field_type) {
Ok(BamlValueWithMeta::Enum(name, val, field_type))
} else {
anyhow::bail!("Could not unify Enum {} with {:?}", name, field_type)
}
}

BamlValue::Class(name, fields) => {
if !FieldType::Class(name.clone()).is_subtype_of(&field_type) {
if !self.is_subtype(&FieldType::Class(name.clone()), &field_type) {
anyhow::bail!("Could not unify Class {} with {:?}", name, field_type);
} else {
let class_type = &self.find_class(&name)?.item.elem;
Expand Down Expand Up @@ -794,3 +915,104 @@ mod tests {
assert_eq!(constraints, expected_constraints);
}
}

// TODO: Copy pasted from baml-lib/baml-types/src/field_type/mod.rs and poorly
// refactored to match the `is_subtype` changes. Do something with this.
#[cfg(test)]
mod subtype_tests {
use baml_types::BamlMediaType;
use repr::make_test_ir;

use super::*;

fn mk_int() -> FieldType {
FieldType::Primitive(TypeValue::Int)
}
fn mk_bool() -> FieldType {
FieldType::Primitive(TypeValue::Bool)
}
fn mk_str() -> FieldType {
FieldType::Primitive(TypeValue::String)
}

fn mk_optional(ft: FieldType) -> FieldType {
FieldType::Optional(Box::new(ft))
}

fn mk_list(ft: FieldType) -> FieldType {
FieldType::List(Box::new(ft))
}

fn mk_tuple(ft: Vec<FieldType>) -> FieldType {
FieldType::Tuple(ft)
}
fn mk_union(ft: Vec<FieldType>) -> FieldType {
FieldType::Union(ft)
}
fn mk_str_map(ft: FieldType) -> FieldType {
FieldType::Map(Box::new(mk_str()), Box::new(ft))
}

fn ir() -> IntermediateRepr {
make_test_ir("").unwrap()
}

#[test]
fn subtype_trivial() {
assert!(ir().is_subtype(&mk_int(), &mk_int()))
}

#[test]
fn subtype_union() {
let i = mk_int();
let u = mk_union(vec![mk_int(), mk_str()]);
assert!(ir().is_subtype(&i, &u));
assert!(!ir().is_subtype(&u, &i));

let u3 = mk_union(vec![mk_int(), mk_bool(), mk_str()]);
assert!(ir().is_subtype(&i, &u3));
assert!(ir().is_subtype(&u, &u3));
assert!(!ir().is_subtype(&u3, &u));
}

#[test]
fn subtype_optional() {
let i = mk_int();
let o = mk_optional(mk_int());
assert!(ir().is_subtype(&i, &o));
assert!(!ir().is_subtype(&o, &i));
}

#[test]
fn subtype_list() {
let l_i = mk_list(mk_int());
let l_o = mk_list(mk_optional(mk_int()));
assert!(ir().is_subtype(&l_i, &l_o));
assert!(!ir().is_subtype(&l_o, &l_i));
}

#[test]
fn subtype_tuple() {
let x = mk_tuple(vec![mk_int(), mk_optional(mk_int())]);
let y = mk_tuple(vec![mk_int(), mk_int()]);
assert!(ir().is_subtype(&y, &x));
assert!(!ir().is_subtype(&x, &y));
}

#[test]
fn subtype_map_of_list_of_unions() {
let x = mk_str_map(mk_list(FieldType::Class("Foo".to_string())));
let y = mk_str_map(mk_list(mk_union(vec![
mk_str(),
mk_int(),
FieldType::Class("Foo".to_string()),
])));
assert!(ir().is_subtype(&x, &y));
}

#[test]
fn subtype_media() {
let x = FieldType::Primitive(TypeValue::Media(BamlMediaType::Audio));
assert!(ir().is_subtype(&x, &x));
}
}
6 changes: 1 addition & 5 deletions engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,6 @@ impl ArgCoercer {
value: &BamlValue, // original value passed in by user
scope: &mut ScopeStack,
) -> Result<BamlValue, ()> {
eprintln!("coerce_arg: {value:?} -> {field_type:?}");
eprintln!("scope: {scope}\n");

let value = match ir.distribute_constraints(field_type) {
(FieldType::Primitive(t), _) => match t {
TypeValue::String if matches!(value, BamlValue::String(_)) => Ok(value.clone()),
Expand Down Expand Up @@ -331,7 +328,6 @@ impl ArgCoercer {
let mut scope = ScopeStack::new();
if first_good_result.is_err() {
let result = self.coerce_arg(ir, option, value, &mut scope);
eprintln!("union inner scope scope: {scope}\n");
if !scope.has_errors() && first_good_result.is_err() {
first_good_result = result
}
Expand Down Expand Up @@ -466,7 +462,7 @@ mod tests {
fn test_mutually_recursive_aliases() {
let ir = make_test_ir(
r##"
type JsonValue = int | string | bool | float | JsonObject | JsonArray
type JsonValue = int | bool | float | string | JsonArray | JsonObject
type JsonObject = map<string, JsonValue>
type JsonArray = JsonValue[]
"##,
Expand Down
Loading

0 comments on commit c4e8b85

Please sign in to comment.