diff --git a/engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs b/engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs index 07cc068d8..6c5e71a9d 100644 --- a/engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs +++ b/engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs @@ -57,6 +57,7 @@ pub trait IRHelper { value: BamlValue, field_type: FieldType, ) -> Result>; + fn is_subtype(&self, base: &FieldType, other: &FieldType) -> bool; fn distribute_constraints<'a>( &'a self, field_type: &'a FieldType, @@ -203,6 +204,124 @@ impl IRHelper for IntermediateRepr { } } + /// BAML does not support class-based subtyping. Nonetheless some builtin + /// BAML types are subtypes of others, and we need to be able to test this + /// when checking the types of values. + /// + /// For examples of pairs of types and their subtyping relationship, see + /// this module's test suite. + /// + /// Consider renaming this to `is_assignable`. + fn is_subtype(&self, base: &FieldType, other: &FieldType) -> bool { + if base == other { + return true; + } + + if let FieldType::Union(items) = other { + if items.iter().any(|item| self.is_subtype(base, item)) { + return true; + } + } + + match (base, other) { + // TODO: O(n) + (FieldType::RecursiveTypeAlias(name), _) => self + .structural_recursive_alias_cycles() + .iter() + .any(|cycle| match cycle.get(name) { + Some(target) => self.is_subtype(target, other), + None => false, + }), + (_, FieldType::RecursiveTypeAlias(name)) => self + .structural_recursive_alias_cycles() + .iter() + .any(|cycle| match cycle.get(name) { + Some(target) => self.is_subtype(base, target), + None => false, + }), + + (FieldType::Primitive(TypeValue::Null), FieldType::Optional(_)) => true, + (FieldType::Optional(base_item), FieldType::Optional(other_item)) => { + self.is_subtype(base_item, other_item) + } + (_, FieldType::Optional(t)) => self.is_subtype(base, t), + (FieldType::Optional(_), _) => false, + + // Handle types that nest other types. + (FieldType::List(base_item), FieldType::List(other_item)) => { + self.is_subtype(&base_item, other_item) + } + (FieldType::List(_), _) => false, + + (FieldType::Map(base_k, base_v), FieldType::Map(other_k, other_v)) => { + self.is_subtype(other_k, base_k) && self.is_subtype(&**base_v, other_v) + } + (FieldType::Map(_, _), _) => false, + + ( + FieldType::Constrained { + base: constrained_base, + constraints: base_constraints, + }, + FieldType::Constrained { + base: other_base, + constraints: other_constraints, + }, + ) => { + self.is_subtype(constrained_base, other_base) + && base_constraints == other_constraints + } + ( + FieldType::Constrained { + base: contrained_base, + .. + }, + _, + ) => self.is_subtype(contrained_base, other), + ( + _, + FieldType::Constrained { + base: constrained_base, + .. + }, + ) => self.is_subtype(base, constrained_base), + + (FieldType::Literal(LiteralValue::Bool(_)), FieldType::Primitive(TypeValue::Bool)) => { + true + } + (FieldType::Literal(LiteralValue::Bool(_)), _) => { + self.is_subtype(base, &FieldType::Primitive(TypeValue::Bool)) + } + (FieldType::Literal(LiteralValue::Int(_)), FieldType::Primitive(TypeValue::Int)) => { + true + } + (FieldType::Literal(LiteralValue::Int(_)), _) => { + self.is_subtype(base, &FieldType::Primitive(TypeValue::Int)) + } + ( + FieldType::Literal(LiteralValue::String(_)), + FieldType::Primitive(TypeValue::String), + ) => true, + (FieldType::Literal(LiteralValue::String(_)), _) => { + self.is_subtype(base, &FieldType::Primitive(TypeValue::String)) + } + + (FieldType::Union(items), _) => items.iter().all(|item| self.is_subtype(item, other)), + + (FieldType::Tuple(base_items), FieldType::Tuple(other_items)) => { + base_items.len() == other_items.len() + && base_items + .iter() + .zip(other_items) + .all(|(base_item, other_item)| self.is_subtype(base_item, other_item)) + } + (FieldType::Tuple(_), _) => false, + (FieldType::Primitive(_), _) => false, + (FieldType::Enum(_), _) => false, + (FieldType::Class(_), _) => false, + } + } + /// For some `BamlValue` with type `FieldType`, walk the structure of both the value /// and the type simultaneously, associating each node in the `BamlValue` with its /// `FieldType`. @@ -216,40 +335,38 @@ impl IRHelper for IntermediateRepr { let literal_type = FieldType::Literal(LiteralValue::String(s.clone())); let primitive_type = FieldType::Primitive(TypeValue::String); - if literal_type.is_subtype_of(&field_type) - || primitive_type.is_subtype_of(&field_type) + if self.is_subtype(&literal_type, &field_type) + || self.is_subtype(&primitive_type, &field_type) { return Ok(BamlValueWithMeta::String(s, field_type)); } anyhow::bail!("Could not unify String with {:?}", field_type) } - BamlValue::Int(i) - if FieldType::Literal(LiteralValue::Int(i)).is_subtype_of(&field_type) => - { - Ok(BamlValueWithMeta::Int(i, field_type)) - } - BamlValue::Int(i) - if FieldType::Primitive(TypeValue::Int).is_subtype_of(&field_type) => - { - Ok(BamlValueWithMeta::Int(i, field_type)) - } BamlValue::Int(i) => { + let literal_type = FieldType::Literal(LiteralValue::Int(i)); + let primitive_type = FieldType::Primitive(TypeValue::Int); + + if self.is_subtype(&literal_type, &field_type) + || self.is_subtype(&primitive_type, &field_type) + { + return Ok(BamlValueWithMeta::Int(i, field_type)); + } anyhow::bail!("Could not unify Int with {:?}", field_type) } - BamlValue::Float(f) - if FieldType::Primitive(TypeValue::Float).is_subtype_of(&field_type) => - { - Ok(BamlValueWithMeta::Float(f, field_type)) + BamlValue::Float(f) => { + if self.is_subtype(&FieldType::Primitive(TypeValue::Float), &field_type) { + return Ok(BamlValueWithMeta::Float(f, field_type)); + } + anyhow::bail!("Could not unify Float with {:?}", field_type) } - BamlValue::Float(_) => anyhow::bail!("Could not unify Float with {:?}", field_type), BamlValue::Bool(b) => { let literal_type = FieldType::Literal(LiteralValue::Bool(b)); let primitive_type = FieldType::Primitive(TypeValue::Bool); - if literal_type.is_subtype_of(&field_type) - || primitive_type.is_subtype_of(&field_type) + if self.is_subtype(&literal_type, &field_type) + || self.is_subtype(&primitive_type, &field_type) { Ok(BamlValueWithMeta::Bool(b, field_type)) } else { @@ -257,7 +374,9 @@ impl IRHelper for IntermediateRepr { } } - BamlValue::Null if FieldType::Primitive(TypeValue::Null).is_subtype_of(&field_type) => { + BamlValue::Null + if self.is_subtype(&FieldType::Primitive(TypeValue::Null), &field_type) => + { Ok(BamlValueWithMeta::Null(field_type)) } BamlValue::Null => anyhow::bail!("Could not unify Null with {:?}", field_type), @@ -287,10 +406,11 @@ impl IRHelper for IntermediateRepr { Box::new(item_type.clone()), ); - if !map_type.is_subtype_of(&field_type) { + if !self.is_subtype(&map_type, &field_type) { anyhow::bail!("Could not unify {:?} with {:?}", map_type, field_type); - } else { - let mapped_fields: BamlMap> = + } + + let mapped_fields: BamlMap> = pairs .into_iter() .map(|(key, val)| { @@ -298,8 +418,7 @@ impl IRHelper for IntermediateRepr { Ok((key, sub_value)) }) .collect::>>>()?; - Ok(BamlValueWithMeta::Map(mapped_fields, field_type)) - } + Ok(BamlValueWithMeta::Map(mapped_fields, field_type)) } None => Ok(BamlValueWithMeta::Map(BamlMap::new(), field_type)), } @@ -321,7 +440,7 @@ impl IRHelper for IntermediateRepr { Some(item_type) => { let list_type = FieldType::List(Box::new(item_type.clone())); - if !list_type.is_subtype_of(&field_type) { + if !self.is_subtype(&list_type, &field_type) { anyhow::bail!("Could not unify {:?} with {:?}", list_type, field_type); } else { let mapped_items: Vec> = items @@ -335,15 +454,17 @@ impl IRHelper for IntermediateRepr { } BamlValue::Media(m) - if FieldType::Primitive(TypeValue::Media(m.media_type)) - .is_subtype_of(&field_type) => + if self.is_subtype( + &FieldType::Primitive(TypeValue::Media(m.media_type)), + &field_type, + ) => { Ok(BamlValueWithMeta::Media(m, field_type)) } BamlValue::Media(_) => anyhow::bail!("Could not unify Media with {:?}", field_type), BamlValue::Enum(name, val) => { - if FieldType::Enum(name.clone()).is_subtype_of(&field_type) { + if self.is_subtype(&FieldType::Enum(name.clone()), &field_type) { Ok(BamlValueWithMeta::Enum(name, val, field_type)) } else { anyhow::bail!("Could not unify Enum {} with {:?}", name, field_type) @@ -351,7 +472,7 @@ impl IRHelper for IntermediateRepr { } BamlValue::Class(name, fields) => { - if !FieldType::Class(name.clone()).is_subtype_of(&field_type) { + if !self.is_subtype(&FieldType::Class(name.clone()), &field_type) { anyhow::bail!("Could not unify Class {} with {:?}", name, field_type); } else { let class_type = &self.find_class(&name)?.item.elem; @@ -794,3 +915,104 @@ mod tests { assert_eq!(constraints, expected_constraints); } } + +// TODO: Copy pasted from baml-lib/baml-types/src/field_type/mod.rs and poorly +// refactored to match the `is_subtype` changes. Do something with this. +#[cfg(test)] +mod subtype_tests { + use baml_types::BamlMediaType; + use repr::make_test_ir; + + use super::*; + + fn mk_int() -> FieldType { + FieldType::Primitive(TypeValue::Int) + } + fn mk_bool() -> FieldType { + FieldType::Primitive(TypeValue::Bool) + } + fn mk_str() -> FieldType { + FieldType::Primitive(TypeValue::String) + } + + fn mk_optional(ft: FieldType) -> FieldType { + FieldType::Optional(Box::new(ft)) + } + + fn mk_list(ft: FieldType) -> FieldType { + FieldType::List(Box::new(ft)) + } + + fn mk_tuple(ft: Vec) -> FieldType { + FieldType::Tuple(ft) + } + fn mk_union(ft: Vec) -> FieldType { + FieldType::Union(ft) + } + fn mk_str_map(ft: FieldType) -> FieldType { + FieldType::Map(Box::new(mk_str()), Box::new(ft)) + } + + fn ir() -> IntermediateRepr { + make_test_ir("").unwrap() + } + + #[test] + fn subtype_trivial() { + assert!(ir().is_subtype(&mk_int(), &mk_int())) + } + + #[test] + fn subtype_union() { + let i = mk_int(); + let u = mk_union(vec![mk_int(), mk_str()]); + assert!(ir().is_subtype(&i, &u)); + assert!(!ir().is_subtype(&u, &i)); + + let u3 = mk_union(vec![mk_int(), mk_bool(), mk_str()]); + assert!(ir().is_subtype(&i, &u3)); + assert!(ir().is_subtype(&u, &u3)); + assert!(!ir().is_subtype(&u3, &u)); + } + + #[test] + fn subtype_optional() { + let i = mk_int(); + let o = mk_optional(mk_int()); + assert!(ir().is_subtype(&i, &o)); + assert!(!ir().is_subtype(&o, &i)); + } + + #[test] + fn subtype_list() { + let l_i = mk_list(mk_int()); + let l_o = mk_list(mk_optional(mk_int())); + assert!(ir().is_subtype(&l_i, &l_o)); + assert!(!ir().is_subtype(&l_o, &l_i)); + } + + #[test] + fn subtype_tuple() { + let x = mk_tuple(vec![mk_int(), mk_optional(mk_int())]); + let y = mk_tuple(vec![mk_int(), mk_int()]); + assert!(ir().is_subtype(&y, &x)); + assert!(!ir().is_subtype(&x, &y)); + } + + #[test] + fn subtype_map_of_list_of_unions() { + let x = mk_str_map(mk_list(FieldType::Class("Foo".to_string()))); + let y = mk_str_map(mk_list(mk_union(vec![ + mk_str(), + mk_int(), + FieldType::Class("Foo".to_string()), + ]))); + assert!(ir().is_subtype(&x, &y)); + } + + #[test] + fn subtype_media() { + let x = FieldType::Primitive(TypeValue::Media(BamlMediaType::Audio)); + assert!(ir().is_subtype(&x, &x)); + } +} diff --git a/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs b/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs index c314d9be9..a13c5cc49 100644 --- a/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs +++ b/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs @@ -263,6 +263,24 @@ impl ArgCoercer { Err(()) } }, + (FieldType::RecursiveTypeAlias(name), _) => { + let mut maybe_coerced = None; + // TODO: Fix this O(n) + for cycle in ir.structural_recursive_alias_cycles() { + if let Some(target) = cycle.get(name) { + maybe_coerced = Some(self.coerce_arg(ir, target, value, scope)?); + break; + } + } + + match maybe_coerced { + Some(coerced) => Ok(coerced), + None => { + scope.push_error(format!("Recursive type alias {} not found", name)); + Err(()) + } + } + } (FieldType::List(item), _) => match value { BamlValue::List(arr) => { let mut items = Vec::new(); @@ -439,4 +457,36 @@ mod tests { let res = arg_coercer.coerce_arg(&ir, &type_, &value, &mut ScopeStack::new()); assert!(res.is_err()); } + + #[test] + fn test_mutually_recursive_aliases() { + let ir = make_test_ir( + r##" +type JsonValue = int | bool | float | string | JsonArray | JsonObject +type JsonObject = map +type JsonArray = JsonValue[] + "##, + ) + .unwrap(); + + let arg_coercer = ArgCoercer { + span_path: None, + allow_implicit_cast_to_string: true, + }; + + let json = BamlValue::Map(BamlMap::from([ + ("number".to_string(), BamlValue::Int(1)), + ("string".to_string(), BamlValue::String("test".to_string())), + ("bool".to_string(), BamlValue::Bool(true)), + ])); + + let res = arg_coercer.coerce_arg( + &ir, + &FieldType::RecursiveTypeAlias("JsonValue".to_string()), + &json, + &mut ScopeStack::new(), + ); + + assert_eq!(res, Ok(json)); + } } diff --git a/engine/baml-lib/baml-core/src/ir/json_schema.rs b/engine/baml-lib/baml-core/src/ir/json_schema.rs index 56e492f3f..79c463ea1 100644 --- a/engine/baml-lib/baml-core/src/ir/json_schema.rs +++ b/engine/baml-lib/baml-core/src/ir/json_schema.rs @@ -159,6 +159,9 @@ impl WithJsonSchema for FieldType { FieldType::Literal(v) => json!({ "const": v.to_string(), }), + FieldType::RecursiveTypeAlias(_) => json!({ + "type": ["number", "string", "boolean", "object", "array", "null"] + }), FieldType::Primitive(t) => match t { TypeValue::String => json!({ "type": "string", diff --git a/engine/baml-lib/baml-core/src/ir/repr.rs b/engine/baml-lib/baml-core/src/ir/repr.rs index 76a223c2e..292e30922 100644 --- a/engine/baml-lib/baml-core/src/ir/repr.rs +++ b/engine/baml-lib/baml-core/src/ir/repr.rs @@ -1,20 +1,19 @@ use std::collections::HashSet; use anyhow::{anyhow, Result}; -use baml_types::{Constraint, ConstraintLevel, FieldType, StringOr, UnresolvedValue}; -use either::Either; +use baml_types::{ + Constraint, ConstraintLevel, FieldType, JinjaExpression, StringOr, UnresolvedValue, +}; use indexmap::{IndexMap, IndexSet}; use internal_baml_parser_database::{ walkers::{ ClassWalker, ClientWalker, ConfigurationWalker, EnumValueWalker, EnumWalker, FieldWalker, FunctionWalker, TemplateStringWalker, Walker as AstWalker, }, - Attributes, ParserDatabase, PromptAst, RetryPolicyStrategy, + Attributes, ParserDatabase, PromptAst, RetryPolicyStrategy, TypeWalker, }; -use internal_baml_schema_ast::ast::{SubType, ValExpId}; -use baml_types::JinjaExpression; -use internal_baml_schema_ast::ast::{self, FieldArity, WithName, WithSpan}; +use internal_baml_schema_ast::ast::{self, FieldArity, SubType, ValExpId, WithName, WithSpan}; use internal_llm_client::{ClientProvider, ClientSpec, UnresolvedClientProperty}; use serde::Serialize; @@ -28,13 +27,17 @@ use crate::Configuration; pub struct IntermediateRepr { enums: Vec>, classes: Vec>, - /// Strongly connected components of the dependency graph (finite cycles). - finite_recursive_cycles: Vec>, functions: Vec>, clients: Vec>, retry_policies: Vec>, template_strings: Vec>, + /// Strongly connected components of the dependency graph (finite cycles). + finite_recursive_cycles: Vec>, + + /// Type alias cycles introduced by lists and maps. + structural_recursive_alias_cycles: Vec>, + configuration: Configuration, } @@ -53,6 +56,7 @@ impl IntermediateRepr { enums: vec![], classes: vec![], finite_recursive_cycles: vec![], + structural_recursive_alias_cycles: vec![], functions: vec![], clients: vec![], retry_policies: vec![], @@ -98,6 +102,10 @@ impl IntermediateRepr { &self.finite_recursive_cycles } + pub fn structural_recursive_alias_cycles(&self) -> &[IndexMap] { + &self.structural_recursive_alias_cycles + } + pub fn walk_enums(&self) -> impl ExactSizeIterator>> { self.enums.iter().map(|e| Walker { db: self, item: e }) } @@ -106,6 +114,14 @@ impl IntermediateRepr { self.classes.iter().map(|e| Walker { db: self, item: e }) } + // TODO: Exact size Iterator + Node<>? + pub fn walk_alias_cycles(&self) -> impl Iterator> { + self.structural_recursive_alias_cycles + .iter() + .flatten() + .map(|e| Walker { db: self, item: e }) + } + pub fn function_names(&self) -> impl ExactSizeIterator { self.functions.iter().map(|f| f.elem.name()) } @@ -168,6 +184,18 @@ impl IntermediateRepr { .collect() }) .collect(), + structural_recursive_alias_cycles: { + let mut recursive_aliases = vec![]; + for cycle in db.structural_recursive_alias_cycles() { + let mut component = IndexMap::new(); + for id in cycle { + let alias = &db.ast()[*id]; + component.insert(alias.name().to_string(), alias.value.repr(db)?); + } + recursive_aliases.push(component); + } + recursive_aliases + }, functions: db .walk_functions() .map(|e| e.node(db)) @@ -395,10 +423,9 @@ impl WithRepr for ast::FieldType { } ast::FieldType::Symbol(arity, idn, ..) => type_with_arity( match db.find_type(idn) { - Some(Either::Left(class_walker)) => { + Some(TypeWalker::Class(class_walker)) => { let base_class = FieldType::Class(class_walker.name().to_string()); - let maybe_constraints = class_walker.get_constraints(SubType::Class); - match maybe_constraints { + match class_walker.get_constraints(SubType::Class) { Some(constraints) if !constraints.is_empty() => { FieldType::Constrained { base: Box::new(base_class), @@ -408,10 +435,9 @@ impl WithRepr for ast::FieldType { _ => base_class, } } - Some(Either::Right(enum_walker)) => { + Some(TypeWalker::Enum(enum_walker)) => { let base_type = FieldType::Enum(enum_walker.name().to_string()); - let maybe_constraints = enum_walker.get_constraints(SubType::Enum); - match maybe_constraints { + match enum_walker.get_constraints(SubType::Enum) { Some(constraints) if !constraints.is_empty() => { FieldType::Constrained { base: Box::new(base_type), @@ -421,6 +447,18 @@ impl WithRepr for ast::FieldType { _ => base_type, } } + Some(TypeWalker::TypeAlias(alias_walker)) => { + if db + .structural_recursive_alias_cycles() + .iter() + .any(|cycle| cycle.contains(&alias_walker.id)) + { + FieldType::RecursiveTypeAlias(alias_walker.name().to_string()) + } else { + alias_walker.resolved().to_owned().repr(db)? + } + } + None => return Err(anyhow!("Field type uses unresolvable local identifier")), }, arity, @@ -1109,7 +1147,7 @@ pub fn make_test_ir(source_code: &str) -> anyhow::Result { #[cfg(test)] mod tests { use super::*; - use crate::ir::ir_helpers::IRHelper; + use crate::ir::{ir_helpers::IRHelper, TypeValue}; #[test] fn test_docstrings() { @@ -1202,4 +1240,62 @@ mod tests { let walker = ir.find_test(&function, "Foo").unwrap(); assert_eq!(walker.item.1.elem.constraints.len(), 1); } + + #[test] + fn test_resolve_type_alias() { + let ir = make_test_ir( + r##" + type One = int + type Two = One + type Three = Two + + class Test { + field Three + } + "##, + ) + .unwrap(); + + let class = ir.find_class("Test").unwrap(); + let alias = class.find_field("field").unwrap(); + + assert_eq!(*alias.r#type(), FieldType::Primitive(TypeValue::Int)); + } + + #[test] + fn test_merge_type_alias_attributes() { + let ir = make_test_ir( + r##" + type One = int @check(gt_ten, {{ this > 10 }}) + type Two = One @check(lt_twenty, {{ this < 20 }}) + type Three = Two @assert({{ this != 15 }}) + + class Test { + field Three + } + "##, + ) + .unwrap(); + + let class = ir.find_class("Test").unwrap(); + let alias = class.find_field("field").unwrap(); + + let FieldType::Constrained { base, constraints } = alias.r#type() else { + panic!( + "expected resolved constrained type, found {:?}", + alias.r#type() + ); + }; + + assert_eq!(constraints.len(), 3); + + assert_eq!(constraints[0].level, ConstraintLevel::Assert); + assert_eq!(constraints[0].label, None); + + assert_eq!(constraints[1].level, ConstraintLevel::Check); + assert_eq!(constraints[1].label, Some("lt_twenty".to_string())); + + assert_eq!(constraints[2].level, ConstraintLevel::Check); + assert_eq!(constraints[2].label, Some("gt_ten".to_string())); + } } diff --git a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs index c177d1f8e..f7cf76ebb 100644 --- a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs +++ b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs @@ -1,21 +1,53 @@ -use std::collections::{HashMap, HashSet}; +use std::{ + collections::{HashMap, HashSet}, + hash::Hash, + ops::Index, +}; -use either::Either; use internal_baml_diagnostics::DatamodelError; -use internal_baml_parser_database::Tarjan; -use internal_baml_schema_ast::ast::{FieldType, TypeExpId, WithName, WithSpan}; +use internal_baml_parser_database::{Tarjan, TypeWalker}; +use internal_baml_schema_ast::ast::{ + FieldType, SchemaAst, TypeAliasId, TypeExpId, WithName, WithSpan, +}; use crate::validate::validation_pipeline::context::Context; /// Validates if the dependency graph contains one or more infinite cycles. pub(super) fn validate(ctx: &mut Context<'_>) { - // First, build a graph of all the "required" dependencies represented as an + // We'll check type alias cycles first. Just like Typescript, cycles are + // allowed only for maps and lists. We'll call such cycles "structural + // recursion". Anything else like nulls or unions won't terminate a cycle. + let structural_type_aliases = HashMap::from_iter(ctx.db.walk_type_aliases().map(|alias| { + let mut dependencies = HashSet::new(); + insert_required_alias_deps(alias.target(), ctx, &mut dependencies); + + (alias.id, dependencies) + })); + + // Based on the graph we've built with does not include the edges created + // by maps and lists, check the cycles and report them. + report_infinite_cycles( + &structural_type_aliases, + ctx, + "These aliases form a dependency cycle", + ); + + // In order to avoid infinite recursion when resolving types for class + // dependencies below, we'll compute the cycles of aliases including maps + // and lists so that the recursion can be stopped before entering a cycle. + let complete_alias_cycles = Tarjan::components(ctx.db.type_alias_dependencies()) + .iter() + .flatten() + .copied() + .collect(); + + // Now build a graph of all the "required" dependencies represented as an // adjacency list. We're only going to consider type dependencies that can // actually cause infinite recursion. Unions and optionals can stop the // recursion at any point, so they don't have to be part of the "dependency" // graph because technically an optional field doesn't "depend" on anything, // it can just be null. - let dependency_graph = HashMap::from_iter(ctx.db.walk_classes().map(|class| { + let class_dependency_graph = HashMap::from_iter(ctx.db.walk_classes().map(|class| { let expr_block = &ctx.db.ast()[class.id]; // TODO: There's already a hash set that returns "dependencies" in @@ -31,14 +63,44 @@ pub(super) fn validate(ctx: &mut Context<'_>) { for field in &expr_block.fields { if let Some(field_type) = &field.expr { - insert_required_deps(class.id, field_type, ctx, &mut dependencies); + insert_required_class_deps( + class.id, + field_type, + ctx, + &mut dependencies, + &complete_alias_cycles, + ); } } (class.id, dependencies) })); - for component in Tarjan::components(&dependency_graph) { + report_infinite_cycles( + &class_dependency_graph, + ctx, + "These classes form a dependency cycle", + ); +} + +/// Finds and reports all the infinite cycles in the given graph. +/// +/// It prints errors like this: +/// +/// "Error validating: These classes form a dependency cycle: A -> B -> C" +fn report_infinite_cycles( + graph: &HashMap>, + ctx: &mut Context<'_>, + message: &str, +) -> Vec> +where + SchemaAst: Index, + >::Output: WithName, + >::Output: WithSpan, +{ + let components = Tarjan::components(graph); + + for component in &components { let cycle = component .iter() .map(|id| ctx.db.ast()[*id].name().to_string()) @@ -48,10 +110,12 @@ pub(super) fn validate(ctx: &mut Context<'_>) { // TODO: We can push an error for every sinlge class here (that's what // Rust does), for now it's an error for every cycle found. ctx.push_error(DatamodelError::new_validation_error( - &format!("These classes form a dependency cycle: {}", cycle), + &format!("{message}: {cycle}"), ctx.db.ast()[component[0]].span().clone(), )); } + + components } /// Inserts all the required dependencies of a field into the given set. @@ -59,16 +123,48 @@ pub(super) fn validate(ctx: &mut Context<'_>) { /// Recursively deals with unions of unions. Can be implemented iteratively with /// a while loop and a stack/queue if this ends up being slow / inefficient or /// it reaches stack overflows with large inputs. -fn insert_required_deps( +/// +/// TODO: Use a struct to keep all this state. Too many parameters already. +fn insert_required_class_deps( id: TypeExpId, field: &FieldType, ctx: &Context<'_>, deps: &mut HashSet, + alias_cycles: &HashSet, ) { match field { FieldType::Symbol(arity, ident, _) if arity.is_required() => { - if let Some(Either::Left(class)) = ctx.db.find_type_by_str(ident.name()) { - deps.insert(class.id); + match ctx.db.find_type_by_str(ident.name()) { + Some(TypeWalker::Class(class)) => { + deps.insert(class.id); + } + Some(TypeWalker::TypeAlias(alias)) => { + // TODO: By the time this code runs we would ideally want + // type aliases to be resolved but we can't do that because + // type alias cycles are not validated yet, we have to + // do that in this file. Take a look at the `validate` + // function at `baml-lib/baml-core/src/lib.rs`. + // + // First we run the `ParserDatabase::validate` function + // which creates the alias graph by visiting all aliases. + // Then we run the `validate::validate` which ends up + // running this code here. Finally we run the + // `ParserDatabase::finalize` which is the place where we + // can resolve type aliases since we've already validated + // that there are no cycles so we won't run into infinite + // recursion. Ideally we want this: + // + // insert_required_deps(id, alias.resolved(), ctx, deps); + + // But we'll run this instead which will follow all the + // alias pointers again until it finds the resolved type. + // We also have to stop recursion if we know the alias is + // part of a cycle. + if !alias_cycles.contains(&alias.id) { + insert_required_class_deps(id, alias.target(), ctx, deps, alias_cycles) + } + } + _ => {} } } @@ -82,7 +178,7 @@ fn insert_required_deps( let mut nested_deps = HashSet::new(); for f in field_types { - insert_required_deps(id, f, ctx, &mut nested_deps); + insert_required_class_deps(id, f, ctx, &mut nested_deps, alias_cycles); // No nested deps found on this component, this makes the // union finite, so no need to go deeper. @@ -112,3 +208,26 @@ fn insert_required_deps( _ => {} } } + +/// Implemented a la TS, maps and lists are not included as edges. +fn insert_required_alias_deps( + field_type: &FieldType, + ctx: &Context<'_>, + required: &mut HashSet, +) { + match field_type { + FieldType::Symbol(_, ident, _) => { + if let Some(TypeWalker::TypeAlias(alias)) = ctx.db.find_type_by_str(ident.name()) { + required.insert(alias.id); + } + } + + FieldType::Union(_, field_types, ..) | FieldType::Tuple(_, field_types, ..) => { + for f in field_types { + insert_required_alias_deps(f, ctx, required); + } + } + + _ => {} + } +} diff --git a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/functions.rs b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/functions.rs index e657bb9f7..37b252949 100644 --- a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/functions.rs +++ b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/functions.rs @@ -2,9 +2,9 @@ use std::collections::HashSet; use crate::validate::validation_pipeline::context::Context; -use either::Either; use internal_baml_diagnostics::{DatamodelError, DatamodelWarning, Span}; +use internal_baml_parser_database::TypeWalker; use internal_baml_schema_ast::ast::{FieldType, TypeExpId, WithIdentifier, WithName, WithSpan}; use super::types::validate_type; @@ -247,7 +247,7 @@ impl<'c> NestedChecks<'c> { match field_type { FieldType::Symbol(_, id, ..) => match self.ctx.db.find_type(id) { - Some(Either::Left(class_walker)) => { + Some(TypeWalker::Class(class_walker)) => { // Stop recursion when dealing with recursive types. if !self.visited.insert(class_walker.id) { return false; diff --git a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/template_strings.rs b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/template_strings.rs index d8a33485c..68b6d252f 100644 --- a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/template_strings.rs +++ b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/template_strings.rs @@ -2,7 +2,6 @@ use std::collections::HashSet; use crate::validate::validation_pipeline::context::Context; -use either::Either; use internal_baml_diagnostics::{DatamodelError, DatamodelWarning, Span}; use internal_baml_schema_ast::ast::{FieldType, TypeExpId, WithIdentifier, WithName, WithSpan}; diff --git a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/types.rs b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/types.rs index 2d13bd8f6..639195c3a 100644 --- a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/types.rs +++ b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/types.rs @@ -3,6 +3,7 @@ use std::collections::VecDeque; use baml_types::{LiteralValue, TypeValue}; use either::Either; use internal_baml_diagnostics::{DatamodelError, DatamodelWarning, Span}; +use internal_baml_parser_database::TypeWalker; use internal_baml_schema_ast::ast::{ Argument, Attribute, Expression, FieldArity, FieldType, Identifier, WithName, WithSpan, }; @@ -62,7 +63,7 @@ fn validate_type_allowed(ctx: &mut Context<'_>, field_type: &FieldType) { if ctx .db .find_type(identifier) - .is_some_and(|t| matches!(t, Either::Right(_))) => {} + .is_some_and(|t| matches!(t, TypeWalker::Enum(_))) => {} // Literal string key. FieldType::Literal(FieldArity::Required, LiteralValue::String(_), ..) => {} diff --git a/engine/baml-lib/baml-types/src/field_type/mod.rs b/engine/baml-lib/baml-types/src/field_type/mod.rs index 7c9809871..52f59fae0 100644 --- a/engine/baml-lib/baml-types/src/field_type/mod.rs +++ b/engine/baml-lib/baml-types/src/field_type/mod.rs @@ -85,6 +85,7 @@ pub enum FieldType { Union(Vec), Tuple(Vec), Optional(Box), + RecursiveTypeAlias(String), Constrained { base: Box, constraints: Vec, @@ -95,9 +96,9 @@ pub enum FieldType { impl std::fmt::Display for FieldType { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - FieldType::Enum(name) | FieldType::Class(name) => { - write!(f, "{name}") - } + FieldType::Enum(name) + | FieldType::Class(name) + | FieldType::RecursiveTypeAlias(name) => write!(f, "{name}"), FieldType::Primitive(t) => write!(f, "{t}"), FieldType::Literal(v) => write!(f, "{v}"), FieldType::Union(choices) => { @@ -159,186 +160,4 @@ impl FieldType { _ => false, } } - - /// BAML does not support class-based subtyping. Nonetheless some builtin - /// BAML types are subtypes of others, and we need to be able to test this - /// when checking the types of values. - /// - /// For examples of pairs of types and their subtyping relationship, see - /// this module's test suite. - /// - /// Consider renaming this to `is_assignable_to`. - pub fn is_subtype_of(&self, other: &FieldType) -> bool { - if self == other { - true - } else { - if let FieldType::Union(items) = other { - if items.iter().any(|item| self.is_subtype_of(item)) { - return true; - } - } - match (self, other) { - (FieldType::Primitive(TypeValue::Null), FieldType::Optional(_)) => true, - (FieldType::Optional(self_item), FieldType::Optional(other_item)) => { - self_item.is_subtype_of(other_item) - } - (_, FieldType::Optional(t)) => self.is_subtype_of(t), - (FieldType::Optional(_), _) => false, - - // Handle types that nest other types. - (FieldType::List(self_item), FieldType::List(other_item)) => { - self_item.is_subtype_of(other_item) - } - (FieldType::List(_), _) => false, - - (FieldType::Map(self_k, self_v), FieldType::Map(other_k, other_v)) => { - other_k.is_subtype_of(self_k) && (**self_v).is_subtype_of(other_v) - } - (FieldType::Map(_, _), _) => false, - - ( - FieldType::Constrained { - base: self_base, - constraints: self_cs, - }, - FieldType::Constrained { - base: other_base, - constraints: other_cs, - }, - ) => self_base.is_subtype_of(other_base) && self_cs == other_cs, - (FieldType::Constrained { base, .. }, _) => base.is_subtype_of(other), - (_, FieldType::Constrained { base, .. }) => self.is_subtype_of(base), - ( - FieldType::Literal(LiteralValue::Bool(_)), - FieldType::Primitive(TypeValue::Bool), - ) => true, - (FieldType::Literal(LiteralValue::Bool(_)), _) => { - self.is_subtype_of(&FieldType::Primitive(TypeValue::Bool)) - } - ( - FieldType::Literal(LiteralValue::Int(_)), - FieldType::Primitive(TypeValue::Int), - ) => true, - (FieldType::Literal(LiteralValue::Int(_)), _) => { - self.is_subtype_of(&FieldType::Primitive(TypeValue::Int)) - } - ( - FieldType::Literal(LiteralValue::String(_)), - FieldType::Primitive(TypeValue::String), - ) => true, - (FieldType::Literal(LiteralValue::String(_)), _) => { - self.is_subtype_of(&FieldType::Primitive(TypeValue::String)) - } - - (FieldType::Union(self_items), _) => self_items - .iter() - .all(|self_item| self_item.is_subtype_of(other)), - - (FieldType::Tuple(self_items), FieldType::Tuple(other_items)) => { - self_items.len() == other_items.len() - && self_items - .iter() - .zip(other_items) - .all(|(self_item, other_item)| self_item.is_subtype_of(other_item)) - } - (FieldType::Tuple(_), _) => false, - - (FieldType::Primitive(_), _) => false, - (FieldType::Enum(_), _) => false, - (FieldType::Class(_), _) => false, - } - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - fn mk_int() -> FieldType { - FieldType::Primitive(TypeValue::Int) - } - fn mk_bool() -> FieldType { - FieldType::Primitive(TypeValue::Bool) - } - fn mk_str() -> FieldType { - FieldType::Primitive(TypeValue::String) - } - - fn mk_optional(ft: FieldType) -> FieldType { - FieldType::Optional(Box::new(ft)) - } - - fn mk_list(ft: FieldType) -> FieldType { - FieldType::List(Box::new(ft)) - } - - fn mk_tuple(ft: Vec) -> FieldType { - FieldType::Tuple(ft) - } - fn mk_union(ft: Vec) -> FieldType { - FieldType::Union(ft) - } - fn mk_str_map(ft: FieldType) -> FieldType { - FieldType::Map(Box::new(mk_str()), Box::new(ft)) - } - - #[test] - fn subtype_trivial() { - assert!(mk_int().is_subtype_of(&mk_int())) - } - - #[test] - fn subtype_union() { - let i = mk_int(); - let u = mk_union(vec![mk_int(), mk_str()]); - assert!(i.is_subtype_of(&u)); - assert!(!u.is_subtype_of(&i)); - - let u3 = mk_union(vec![mk_int(), mk_bool(), mk_str()]); - assert!(i.is_subtype_of(&u3)); - assert!(u.is_subtype_of(&u3)); - assert!(!u3.is_subtype_of(&u)); - } - - #[test] - fn subtype_optional() { - let i = mk_int(); - let o = mk_optional(mk_int()); - assert!(i.is_subtype_of(&o)); - assert!(!o.is_subtype_of(&i)); - } - - #[test] - fn subtype_list() { - let l_i = mk_list(mk_int()); - let l_o = mk_list(mk_optional(mk_int())); - assert!(l_i.is_subtype_of(&l_o)); - assert!(!l_o.is_subtype_of(&l_i)); - } - - #[test] - fn subtype_tuple() { - let x = mk_tuple(vec![mk_int(), mk_optional(mk_int())]); - let y = mk_tuple(vec![mk_int(), mk_int()]); - assert!(y.is_subtype_of(&x)); - assert!(!x.is_subtype_of(&y)); - } - - #[test] - fn subtype_map_of_list_of_unions() { - let x = mk_str_map(mk_list(FieldType::Class("Foo".to_string()))); - let y = mk_str_map(mk_list(mk_union(vec![ - mk_str(), - mk_int(), - FieldType::Class("Foo".to_string()), - ]))); - assert!(x.is_subtype_of(&y)); - } - - #[test] - fn subtype_media() { - let x = FieldType::Primitive(TypeValue::Media(BamlMediaType::Audio)); - assert!(x.is_subtype_of(&x)); - } } diff --git a/engine/baml-lib/baml/tests/validation_files/class/invalid_attrs_on_type_alias.baml b/engine/baml-lib/baml/tests/validation_files/class/invalid_attrs_on_type_alias.baml new file mode 100644 index 000000000..6cec9d633 --- /dev/null +++ b/engine/baml-lib/baml/tests/validation_files/class/invalid_attrs_on_type_alias.baml @@ -0,0 +1,32 @@ +type DescNotAllowed = string @description("This is not allowed") + +type AliasNotAllowed = float @alias("Alias not allowed") + +type SkipNotAllowed = float @skip + +type AttrNotFound = int @assert({{ this > 0 }}) @unknown + +// error: Error validating: type aliases may only have @check and @assert attributes +// --> class/invalid_attrs_on_type_alias.baml:1 +// | +// | +// 1 | type DescNotAllowed = string @description("This is not allowed") +// | +// error: Error validating: type aliases may only have @check and @assert attributes +// --> class/invalid_attrs_on_type_alias.baml:3 +// | +// 2 | +// 3 | type AliasNotAllowed = float @alias("Alias not allowed") +// | +// error: Error validating: type aliases may only have @check and @assert attributes +// --> class/invalid_attrs_on_type_alias.baml:5 +// | +// 4 | +// 5 | type SkipNotAllowed = float @skip +// | +// error: Attribute not known: "@unknown". +// --> class/invalid_attrs_on_type_alias.baml:7 +// | +// 6 | +// 7 | type AttrNotFound = int @assert({{ this > 0 }}) @unknown +// | diff --git a/engine/baml-lib/baml/tests/validation_files/class/invalid_type_aliases.baml b/engine/baml-lib/baml/tests/validation_files/class/invalid_type_aliases.baml new file mode 100644 index 000000000..e1d584c98 --- /dev/null +++ b/engine/baml-lib/baml/tests/validation_files/class/invalid_type_aliases.baml @@ -0,0 +1,40 @@ +class One { + f int +} + +// Already existing name. +type One = int + +// Unexpected keyword. +typpe Two = float + +// Unknown identifier. +type Three = i + +// Unknown identifier in union. +type Four = int | string | b + +// error: Error validating: Unexpected keyword used in assignment: typpe +// --> class/invalid_type_aliases.baml:9 +// | +// 8 | // Unexpected keyword. +// 9 | typpe Two = float +// | +// error: The type_alias "One" cannot be defined because a class with that name already exists. +// --> class/invalid_type_aliases.baml:6 +// | +// 5 | // Already existing name. +// 6 | type One = int +// | +// error: Error validating: Type alias points to unknown identifier `i` +// --> class/invalid_type_aliases.baml:12 +// | +// 11 | // Unknown identifier. +// 12 | type Three = i +// | +// error: Error validating: Type alias points to unknown identifier `b` +// --> class/invalid_type_aliases.baml:15 +// | +// 14 | // Unknown identifier in union. +// 15 | type Four = int | string | b +// | diff --git a/engine/baml-lib/baml/tests/validation_files/class/recursive_type_aliases.baml b/engine/baml-lib/baml/tests/validation_files/class/recursive_type_aliases.baml new file mode 100644 index 000000000..f7ff88ef0 --- /dev/null +++ b/engine/baml-lib/baml/tests/validation_files/class/recursive_type_aliases.baml @@ -0,0 +1,84 @@ +// Simple alias that points to recursive type. +class Node { + value int + next Node? +} + +type LinkedList = Node + +// Mutual recursion. There is no "type" here at all. +type One = Two + +type Two = One + +// Cycle. Same as above but longer. +type A = B + +type B = C + +type C = A + +// Recursive class with alias pointing to itself. +class Recursive { + value int + ptr RecAlias +} + +type RecAlias = Recursive + +// Same but finite. +class FiniteRecursive { + value int + ptr FiniteRecAlias? +} + +type FiniteRecAlias = FiniteRecursive + +// Move the "finite" condition to the alias itself. Should still work. +class RecursiveWithOptionalAlias { + value int + ptr RecOptionalAlias +} + +type RecOptionalAlias = RecursiveWithOptionalAlias? + +// Class that points to alias which enters infinite cycle. +class InfiniteCycle { + value int + ptr EnterCycle +} + +type EnterCycle = NoStop + +type NoStop = EnterCycle + +// RecursiveMap +type Map = map + +// error: Error validating: These aliases form a dependency cycle: One -> Two +// --> class/recursive_type_aliases.baml:10 +// | +// 9 | // Mutual recursion. There is no "type" here at all. +// 10 | type One = Two +// | +// error: Error validating: These aliases form a dependency cycle: A -> B -> C +// --> class/recursive_type_aliases.baml:15 +// | +// 14 | // Cycle. Same as above but longer. +// 15 | type A = B +// | +// error: Error validating: These aliases form a dependency cycle: EnterCycle -> NoStop +// --> class/recursive_type_aliases.baml:51 +// | +// 50 | +// 51 | type EnterCycle = NoStop +// | +// error: Error validating: These classes form a dependency cycle: Recursive +// --> class/recursive_type_aliases.baml:22 +// | +// 21 | // Recursive class with alias pointing to itself. +// 22 | class Recursive { +// 23 | value int +// 24 | ptr RecAlias +// 25 | } +// | diff --git a/engine/baml-lib/baml/tests/validation_files/class/type_aliases.baml b/engine/baml-lib/baml/tests/validation_files/class/type_aliases.baml new file mode 100644 index 000000000..27963ba76 --- /dev/null +++ b/engine/baml-lib/baml/tests/validation_files/class/type_aliases.baml @@ -0,0 +1,58 @@ +type Primitive = int | string | bool | float + +type List = string[] + +type Graph = map + +type Combination = Primitive | List | Graph + +// Alias with attrs. +type Currency = int @check(gt_ten, {{ this > 10 }}) +type Amount = Currency @assert({{ this > 0 }}) + +// Should be allowed. +type MultipleAttrs = int @assert({{ this > 0 }}) @check(gt_ten, {{ this > 10 }}) + +class MergeAttrs { + amount Amount @description("In USD") +} + +function PrimitiveAlias(p: Primitive) -> Primitive { + client "openai/gpt-4o" + prompt r#" + Return the given value back: {{ p }} + "# +} + +function MapAlias(m: Graph) -> Graph { + client "openai/gpt-4o" + prompt r#" + Return the given Graph back: + + {{ m }} + + {{ ctx.output_format }} + "# +} + +function NestedAlias(c: Combination) -> Combination { + client "openai/gpt-4o" + prompt r#" + Return the given value back: + + {{ c }} + + {{ ctx.output_format }} + "# +} + +function MergeAliasAttributes(money: int) -> MergeAttrs { + client "openai/gpt-4o" + prompt r#" + Return the given integer in the specified format: + + {{ money }} + + {{ ctx.output_format }} + "# +} diff --git a/engine/baml-lib/jinja-runtime/src/output_format/types.rs b/engine/baml-lib/jinja-runtime/src/output_format/types.rs index 404b153f9..188f37035 100644 --- a/engine/baml-lib/jinja-runtime/src/output_format/types.rs +++ b/engine/baml-lib/jinja-runtime/src/output_format/types.rs @@ -58,6 +58,7 @@ pub struct OutputFormatContent { pub enums: Arc>, pub classes: Arc>, recursive_classes: Arc>, + structural_recursive_aliases: Arc>, pub target: FieldType, } @@ -67,6 +68,8 @@ pub struct Builder { classes: Vec, /// Order matters for this one. recursive_classes: IndexSet, + /// Recursive aliases introduced maps and lists. + structural_recursive_aliases: IndexMap, target: FieldType, } @@ -76,6 +79,7 @@ impl Builder { enums: vec![], classes: vec![], recursive_classes: IndexSet::new(), + structural_recursive_aliases: IndexMap::new(), target, } } @@ -95,6 +99,14 @@ impl Builder { self } + pub fn structural_recursive_aliases( + mut self, + structural_recursive_aliases: IndexMap, + ) -> Self { + self.structural_recursive_aliases = structural_recursive_aliases; + self + } + pub fn target(mut self, target: FieldType) -> Self { self.target = target; self @@ -115,6 +127,9 @@ impl Builder { .collect(), ), recursive_classes: Arc::new(self.recursive_classes.into_iter().collect()), + structural_recursive_aliases: Arc::new( + self.structural_recursive_aliases.into_iter().collect(), + ), target: self.target, } } @@ -333,6 +348,14 @@ impl OutputFormatContent { Some(format!("Answer in JSON using this {type_prefix}:{end}")) } + FieldType::RecursiveTypeAlias(_) => { + let type_prefix = match &options.hoisted_class_prefix { + RenderSetting::Always(prefix) if !prefix.is_empty() => prefix, + _ => RenderOptions::DEFAULT_TYPE_PREFIX_IN_RENDER_MESSAGE, + }; + + Some(format!("Answer in JSON using this {type_prefix}: ")) + } FieldType::List(_) => Some(String::from( "Answer with a JSON Array using this schema:\n", )), @@ -423,9 +446,12 @@ impl OutputFormatContent { } }, FieldType::Literal(v) => v.to_string(), - FieldType::Constrained { base, .. } => { - self.inner_type_render(options, base, render_state, group_hoisted_literals)? - } + FieldType::Constrained { base, .. } => self.render_possibly_recursive_type( + options, + base, + render_state, + group_hoisted_literals, + )?, FieldType::Enum(e) => { let Some(enm) = self.enums.get(e) else { return Err(minijinja::Error::new( @@ -481,9 +507,13 @@ impl OutputFormatContent { } .to_string() } + FieldType::RecursiveTypeAlias(name) => name.to_owned(), FieldType::List(inner) => { let is_recursive = match inner.as_ref() { FieldType::Class(nested_class) => self.recursive_classes.contains(nested_class), + FieldType::RecursiveTypeAlias(name) => { + self.structural_recursive_aliases.contains_key(name) + } _ => false, }; @@ -527,9 +557,14 @@ impl OutputFormatContent { } FieldType::Map(key_type, value_type) => MapRender { style: &options.map_style, - // TODO: Key can't be recursive because we only support strings - // as keys. Change this if needed in the future. - key_type: self.inner_type_render(options, key_type, render_state, false)?, + // NOTE: Key can't be recursive because we only support strings + // as keys. + key_type: self.render_possibly_recursive_type( + options, + key_type, + render_state, + false, + )?, value_type: self.render_possibly_recursive_type( options, value_type, @@ -580,6 +615,7 @@ impl OutputFormatContent { })); let mut class_definitions = Vec::new(); + let mut type_alias_definitions = Vec::new(); // Hoist recursive classes. The render_state struct doesn't need to // contain these classes because we already know that we're gonna hoist @@ -601,6 +637,18 @@ impl OutputFormatContent { }); } + for (alias, target) in self.structural_recursive_aliases.iter() { + let recursive_pointer = + self.inner_type_render(&options, target, &mut render_state, false)?; + + type_alias_definitions.push(match &options.hoisted_class_prefix { + RenderSetting::Always(prefix) if !prefix.is_empty() => { + format!("{prefix} {alias} = {recursive_pointer}") + } + _ => format!("{alias} = {recursive_pointer}"), + }); + } + let mut output = String::new(); if !enum_definitions.is_empty() { @@ -613,6 +661,11 @@ impl OutputFormatContent { output.push_str("\n\n"); } + if !type_alias_definitions.is_empty() { + output.push_str(&type_alias_definitions.join("\n")); + output.push_str("\n\n"); + } + if let Some(p) = prefix { output.push_str(&p); } @@ -649,13 +702,19 @@ impl OutputFormatContent { pub fn find_enum(&self, name: &str) -> Result<&Enum> { self.enums .get(name) - .ok_or_else(|| anyhow::anyhow!("Enum {} not found", name)) + .ok_or_else(|| anyhow::anyhow!("Enum {name} not found")) } pub fn find_class(&self, name: &str) -> Result<&Class> { self.classes .get(name) - .ok_or_else(|| anyhow::anyhow!("Class {} not found", name)) + .ok_or_else(|| anyhow::anyhow!("Class {name} not found")) + } + + pub fn find_recursive_alias_target(&self, name: &str) -> Result<&FieldType> { + self.structural_recursive_aliases + .get(name) + .ok_or_else(|| anyhow::anyhow!("Recursive alias {name} not found")) } } @@ -2214,4 +2273,95 @@ Answer in JSON using this schema: )) ); } + + #[test] + fn render_simple_recursive_aliases() { + let content = OutputFormatContent::target(FieldType::RecursiveTypeAlias( + "RecursiveMapAlias".to_string(), + )) + .structural_recursive_aliases(IndexMap::from([( + "RecursiveMapAlias".to_string(), + FieldType::map( + FieldType::string(), + FieldType::RecursiveTypeAlias("RecursiveMapAlias".to_string()), + ), + )])) + .build(); + let rendered = content.render(RenderOptions::default()).unwrap(); + #[rustfmt::skip] + assert_eq!( + rendered, + Some(String::from( +r#"RecursiveMapAlias = map + +Answer in JSON using this schema: RecursiveMapAlias"# + )) + ); + } + + #[test] + fn render_recursive_alias_cycle() { + let content = OutputFormatContent::target(FieldType::RecursiveTypeAlias("A".to_string())) + .structural_recursive_aliases(IndexMap::from([ + ( + "A".to_string(), + FieldType::RecursiveTypeAlias("B".to_string()), + ), + ( + "B".to_string(), + FieldType::RecursiveTypeAlias("C".to_string()), + ), + ( + "C".to_string(), + FieldType::list(FieldType::RecursiveTypeAlias("A".to_string())), + ), + ])) + .build(); + let rendered = content.render(RenderOptions::default()).unwrap(); + #[rustfmt::skip] + assert_eq!( + rendered, + Some(String::from( +r#"A = B +B = C +C = A[] + +Answer in JSON using this schema: A"# + )) + ); + } + + #[test] + fn render_recursive_alias_cycle_with_hoist_prefix() { + let content = OutputFormatContent::target(FieldType::RecursiveTypeAlias("A".to_string())) + .structural_recursive_aliases(IndexMap::from([ + ( + "A".to_string(), + FieldType::RecursiveTypeAlias("B".to_string()), + ), + ( + "B".to_string(), + FieldType::RecursiveTypeAlias("C".to_string()), + ), + ( + "C".to_string(), + FieldType::list(FieldType::RecursiveTypeAlias("A".to_string())), + ), + ])) + .build(); + let rendered = content + .render(RenderOptions::with_hoisted_class_prefix("type")) + .unwrap(); + #[rustfmt::skip] + assert_eq!( + rendered, + Some(String::from( +r#"type A = B +type B = C +type C = A[] + +Answer in JSON using this type: A"# + )) + ); + } } diff --git a/engine/baml-lib/jinja/src/evaluate_type/expr.rs b/engine/baml-lib/jinja/src/evaluate_type/expr.rs index 4172ad03b..ed295d716 100644 --- a/engine/baml-lib/jinja/src/evaluate_type/expr.rs +++ b/engine/baml-lib/jinja/src/evaluate_type/expr.rs @@ -406,7 +406,7 @@ fn infer_const_type(v: &minijinja::value::Value) -> Type { acc.push(x); Some(Type::Union(acc)) } else { - unreachable!() + unreachable!("minijinja") } } Some(acc) => { diff --git a/engine/baml-lib/jsonish/src/deserializer/coercer/array_helper.rs b/engine/baml-lib/jsonish/src/deserializer/coercer/array_helper.rs index 0f22e5bbf..5c1deb6a9 100644 --- a/engine/baml-lib/jsonish/src/deserializer/coercer/array_helper.rs +++ b/engine/baml-lib/jsonish/src/deserializer/coercer/array_helper.rs @@ -14,7 +14,13 @@ pub fn coerce_array_to_singular( ) -> Result { let parsed = items.iter().map(|item| coercion(item)).collect::>(); - pick_best(ctx, target, &parsed) + let mut best = pick_best(ctx, target, &parsed); + + if let Ok(ref mut f) = best { + f.add_flag(Flag::FirstMatch(0, parsed.to_vec())) + } + + best } pub(super) fn pick_best( @@ -180,6 +186,29 @@ pub(super) fn pick_best( } } + // Devalue strings that were cast from objects. + if !a_val.is_composite() && b_val.is_composite() { + if a_val + .conditions() + .flags() + .iter() + .any(|f| matches!(f, Flag::JsonToString(..) | Flag::FirstMatch(_, _))) + { + return std::cmp::Ordering::Greater; + } + } + + if a_val.is_composite() && !b_val.is_composite() { + if b_val + .conditions() + .flags() + .iter() + .any(|f| matches!(f, Flag::JsonToString(..) | Flag::FirstMatch(_, _))) + { + return std::cmp::Ordering::Less; + } + } + match a_default.cmp(&b_default) { std::cmp::Ordering::Equal => match a_score.cmp(&b_score) { std::cmp::Ordering::Equal => a.cmp(&b), diff --git a/engine/baml-lib/jsonish/src/deserializer/coercer/coerce_array.rs b/engine/baml-lib/jsonish/src/deserializer/coercer/coerce_array.rs index d22496147..2fcb32b21 100644 --- a/engine/baml-lib/jsonish/src/deserializer/coercer/coerce_array.rs +++ b/engine/baml-lib/jsonish/src/deserializer/coercer/coerce_array.rs @@ -24,7 +24,7 @@ pub(super) fn coerce_array( let inner = match list_target { FieldType::List(inner) => inner, - _ => unreachable!(), + _ => unreachable!("coerce_array"), }; let mut items = vec![]; diff --git a/engine/baml-lib/jsonish/src/deserializer/coercer/coerce_optional.rs b/engine/baml-lib/jsonish/src/deserializer/coercer/coerce_optional.rs index 76778913b..4a1e16faa 100644 --- a/engine/baml-lib/jsonish/src/deserializer/coercer/coerce_optional.rs +++ b/engine/baml-lib/jsonish/src/deserializer/coercer/coerce_optional.rs @@ -23,7 +23,7 @@ pub(super) fn coerce_optional( let inner = match optional_target { FieldType::Optional(inner) => inner, - _ => unreachable!(), + _ => unreachable!("coerce_optional"), }; let mut flags = DeserializerConditions::new(); diff --git a/engine/baml-lib/jsonish/src/deserializer/coercer/coerce_primitive.rs b/engine/baml-lib/jsonish/src/deserializer/coercer/coerce_primitive.rs index 801ec9735..ebac33720 100644 --- a/engine/baml-lib/jsonish/src/deserializer/coercer/coerce_primitive.rs +++ b/engine/baml-lib/jsonish/src/deserializer/coercer/coerce_primitive.rs @@ -64,18 +64,16 @@ fn coerce_string( target: &FieldType, value: Option<&crate::jsonish::Value>, ) -> Result { - if let Some(value) = value { - match value { - crate::jsonish::Value::String(s) => { - Ok(BamlValueWithFlags::String(s.to_string().into())) - } - crate::jsonish::Value::Null => Err(ctx.error_unexpected_null(target)), - v => Ok(BamlValueWithFlags::String( - (v.to_string(), Flag::JsonToString(v.clone())).into(), - )), - } - } else { - Err(ctx.error_unexpected_null(target)) + let Some(value) = value else { + return Err(ctx.error_unexpected_null(target)); + }; + + match value { + crate::jsonish::Value::String(s) => Ok(BamlValueWithFlags::String(s.to_string().into())), + crate::jsonish::Value::Null => Err(ctx.error_unexpected_null(target)), + v => Ok(BamlValueWithFlags::String( + (v.to_string(), Flag::JsonToString(v.clone())).into(), + )), } } @@ -84,54 +82,54 @@ pub(super) fn coerce_int( target: &FieldType, value: Option<&crate::jsonish::Value>, ) -> Result { - if let Some(value) = value { - match value { - crate::jsonish::Value::Number(n) => { - if let Some(n) = n.as_i64() { - Ok(BamlValueWithFlags::Int(n.into())) - } else if let Some(n) = n.as_u64() { - Ok(BamlValueWithFlags::Int((n as i64).into())) - } else if let Some(n) = n.as_f64() { - Ok(BamlValueWithFlags::Int( - ((n.round() as i64), Flag::FloatToInt(n)).into(), - )) - } else { - Err(ctx.error_unexpected_type(target, value)) - } - } - crate::jsonish::Value::String(s) => { - let s = s.trim(); - // Trim trailing commas - let s = s.trim_end_matches(','); - if let Ok(n) = s.parse::() { - Ok(BamlValueWithFlags::Int(n.into())) - } else if let Ok(n) = s.parse::() { - Ok(BamlValueWithFlags::Int((n as i64).into())) - } else if let Ok(n) = s.parse::() { - Ok(BamlValueWithFlags::Int( - ((n.round() as i64), Flag::FloatToInt(n)).into(), - )) - } else if let Some(frac) = float_from_maybe_fraction(s) { - Ok(BamlValueWithFlags::Int( - ((frac.round() as i64), Flag::FloatToInt(frac)).into(), - )) - } else if let Some(frac) = float_from_comma_separated(s) { - Ok(BamlValueWithFlags::Int( - ((frac.round() as i64), Flag::FloatToInt(frac)).into(), - )) - } else { - Err(ctx.error_unexpected_type(target, value)) - } + let Some(value) = value else { + return Err(ctx.error_unexpected_null(target)); + }; + + match value { + crate::jsonish::Value::Number(n) => { + if let Some(n) = n.as_i64() { + Ok(BamlValueWithFlags::Int(n.into())) + } else if let Some(n) = n.as_u64() { + Ok(BamlValueWithFlags::Int((n as i64).into())) + } else if let Some(n) = n.as_f64() { + Ok(BamlValueWithFlags::Int( + ((n.round() as i64), Flag::FloatToInt(n)).into(), + )) + } else { + Err(ctx.error_unexpected_type(target, value)) } - crate::jsonish::Value::Array(items) => { - coerce_array_to_singular(ctx, target, &items.iter().collect::>(), &|value| { - coerce_int(ctx, target, Some(value)) - }) + } + crate::jsonish::Value::String(s) => { + let s = s.trim(); + // Trim trailing commas + let s = s.trim_end_matches(','); + if let Ok(n) = s.parse::() { + Ok(BamlValueWithFlags::Int(n.into())) + } else if let Ok(n) = s.parse::() { + Ok(BamlValueWithFlags::Int((n as i64).into())) + } else if let Ok(n) = s.parse::() { + Ok(BamlValueWithFlags::Int( + ((n.round() as i64), Flag::FloatToInt(n)).into(), + )) + } else if let Some(frac) = float_from_maybe_fraction(s) { + Ok(BamlValueWithFlags::Int( + ((frac.round() as i64), Flag::FloatToInt(frac)).into(), + )) + } else if let Some(frac) = float_from_comma_separated(s) { + Ok(BamlValueWithFlags::Int( + ((frac.round() as i64), Flag::FloatToInt(frac)).into(), + )) + } else { + Err(ctx.error_unexpected_type(target, value)) } - _ => Err(ctx.error_unexpected_type(target, value)), } - } else { - Err(ctx.error_unexpected_null(target)) + crate::jsonish::Value::Array(items) => { + coerce_array_to_singular(ctx, target, &items.iter().collect::>(), &|value| { + coerce_int(ctx, target, Some(value)) + }) + } + _ => Err(ctx.error_unexpected_type(target, value)), } } @@ -172,46 +170,53 @@ fn coerce_float( target: &FieldType, value: Option<&crate::jsonish::Value>, ) -> Result { - if let Some(value) = value { - match value { - crate::jsonish::Value::Number(n) => { - if let Some(n) = n.as_f64() { - Ok(BamlValueWithFlags::Float(n.into())) - } else if let Some(n) = n.as_i64() { - Ok(BamlValueWithFlags::Float((n as f64).into())) - } else if let Some(n) = n.as_u64() { - Ok(BamlValueWithFlags::Float((n as f64).into())) - } else { - Err(ctx.error_unexpected_type(target, value)) - } - } - crate::jsonish::Value::String(s) => { - let s = s.trim(); - // Trim trailing commas - let s = s.trim_end_matches(','); - if let Ok(n) = s.parse::() { - Ok(BamlValueWithFlags::Float(n.into())) - } else if let Ok(n) = s.parse::() { - Ok(BamlValueWithFlags::Float((n as f64).into())) - } else if let Ok(n) = s.parse::() { - Ok(BamlValueWithFlags::Float((n as f64).into())) - } else if let Some(frac) = float_from_maybe_fraction(s) { - Ok(BamlValueWithFlags::Float(frac.into())) - } else if let Some(frac) = float_from_comma_separated(s) { - Ok(BamlValueWithFlags::Float(frac.into())) - } else { - Err(ctx.error_unexpected_type(target, value)) - } + let Some(value) = value else { + return Err(ctx.error_unexpected_null(target)); + }; + + match value { + crate::jsonish::Value::Number(n) => { + if let Some(n) = n.as_f64() { + Ok(BamlValueWithFlags::Float(n.into())) + } else if let Some(n) = n.as_i64() { + Ok(BamlValueWithFlags::Float((n as f64).into())) + } else if let Some(n) = n.as_u64() { + Ok(BamlValueWithFlags::Float((n as f64).into())) + } else { + Err(ctx.error_unexpected_type(target, value)) } - crate::jsonish::Value::Array(items) => { - coerce_array_to_singular(ctx, target, &items.iter().collect::>(), &|value| { - coerce_float(ctx, target, Some(value)) - }) + } + crate::jsonish::Value::String(s) => { + let s = s.trim(); + // Trim trailing commas + let s = s.trim_end_matches(','); + if let Ok(n) = s.parse::() { + Ok(BamlValueWithFlags::Float(n.into())) + } else if let Ok(n) = s.parse::() { + Ok(BamlValueWithFlags::Float((n as f64).into())) + } else if let Ok(n) = s.parse::() { + Ok(BamlValueWithFlags::Float((n as f64).into())) + } else if let Some(frac) = float_from_maybe_fraction(s) { + Ok(BamlValueWithFlags::Float(frac.into())) + } else if let Some(frac) = float_from_comma_separated(s) { + let mut baml_value = BamlValueWithFlags::Float(frac.into()); + // Add flag here to penalize strings like + // "1 cup unsalted butter, room temperature". + // If we're trying to parse this to a float it should work + // anyway but unions like "float | string" should still coerce + // this to a string. + baml_value.add_flag(Flag::StringToFloat(s.to_string())); + Ok(baml_value) + } else { + Err(ctx.error_unexpected_type(target, value)) } - _ => Err(ctx.error_unexpected_type(target, value)), } - } else { - Err(ctx.error_unexpected_null(target)) + crate::jsonish::Value::Array(items) => { + coerce_array_to_singular(ctx, target, &items.iter().collect::>(), &|value| { + coerce_float(ctx, target, Some(value)) + }) + } + _ => Err(ctx.error_unexpected_type(target, value)), } } @@ -220,51 +225,51 @@ pub(super) fn coerce_bool( target: &FieldType, value: Option<&crate::jsonish::Value>, ) -> Result { - if let Some(value) = value { - match value { - crate::jsonish::Value::Boolean(b) => Ok(BamlValueWithFlags::Bool((*b).into())), - crate::jsonish::Value::String(s) => match s.to_lowercase().as_str() { - "true" => Ok(BamlValueWithFlags::Bool( - (true, Flag::StringToBool(s.clone())).into(), - )), - "false" => Ok(BamlValueWithFlags::Bool( - (false, Flag::StringToBool(s.clone())).into(), - )), - _ => { - match super::match_string::match_string( - ctx, - target, - Some(value), - &[ - ("true", vec!["true".into(), "True".into(), "TRUE".into()]), - ( - "false", - vec!["false".into(), "False".into(), "FALSE".into()], - ), - ], - ) { - Ok(val) => match val.value().as_str() { - "true" => Ok(BamlValueWithFlags::Bool( - (true, Flag::StringToBool(val.value().clone())).into(), - )), - "false" => Ok(BamlValueWithFlags::Bool( - (false, Flag::StringToBool(val.value().clone())).into(), - )), - _ => Err(ctx.error_unexpected_type(target, value)), - }, - Err(_) => Err(ctx.error_unexpected_type(target, value)), - } + let Some(value) = value else { + return Err(ctx.error_unexpected_null(target)); + }; + + match value { + crate::jsonish::Value::Boolean(b) => Ok(BamlValueWithFlags::Bool((*b).into())), + crate::jsonish::Value::String(s) => match s.to_lowercase().as_str() { + "true" => Ok(BamlValueWithFlags::Bool( + (true, Flag::StringToBool(s.clone())).into(), + )), + "false" => Ok(BamlValueWithFlags::Bool( + (false, Flag::StringToBool(s.clone())).into(), + )), + _ => { + match super::match_string::match_string( + ctx, + target, + Some(value), + &[ + ("true", vec!["true".into(), "True".into(), "TRUE".into()]), + ( + "false", + vec!["false".into(), "False".into(), "FALSE".into()], + ), + ], + ) { + Ok(val) => match val.value().as_str() { + "true" => Ok(BamlValueWithFlags::Bool( + (true, Flag::StringToBool(val.value().clone())).into(), + )), + "false" => Ok(BamlValueWithFlags::Bool( + (false, Flag::StringToBool(val.value().clone())).into(), + )), + _ => Err(ctx.error_unexpected_type(target, value)), + }, + Err(_) => Err(ctx.error_unexpected_type(target, value)), } - }, - crate::jsonish::Value::Array(items) => { - coerce_array_to_singular(ctx, target, &items.iter().collect::>(), &|value| { - coerce_bool(ctx, target, Some(value)) - }) } - _ => Err(ctx.error_unexpected_type(target, value)), + }, + crate::jsonish::Value::Array(items) => { + coerce_array_to_singular(ctx, target, &items.iter().collect::>(), &|value| { + coerce_bool(ctx, target, Some(value)) + }) } - } else { - Err(ctx.error_unexpected_null(target)) + _ => Err(ctx.error_unexpected_type(target, value)), } } @@ -337,7 +342,7 @@ mod tests { ("12,111,123,", Some(12111123.0)), ]; - for &(input, expected) in &test_cases { + for (input, expected) in test_cases { let result = float_from_comma_separated(input); assert_eq!( result, expected, diff --git a/engine/baml-lib/jsonish/src/deserializer/coercer/coerce_union.rs b/engine/baml-lib/jsonish/src/deserializer/coercer/coerce_union.rs index 27cf3acb4..8a2312a2a 100644 --- a/engine/baml-lib/jsonish/src/deserializer/coercer/coerce_union.rs +++ b/engine/baml-lib/jsonish/src/deserializer/coercer/coerce_union.rs @@ -20,7 +20,7 @@ pub(super) fn coerce_union( let options = match union_target { FieldType::Union(options) => options, - _ => unreachable!(), + _ => unreachable!("coerce_union"), }; let parsed = options diff --git a/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs b/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs index 59c347e3e..26d973dd3 100644 --- a/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs +++ b/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs @@ -9,9 +9,13 @@ use crate::deserializer::{ }; use super::{ - array_helper, coerce_array::coerce_array, coerce_map::coerce_map, - coerce_optional::coerce_optional, coerce_union::coerce_union, ir_ref::IrRef, ParsingContext, - ParsingError, + array_helper, + coerce_array::coerce_array, + coerce_map::coerce_map, + coerce_optional::coerce_optional, + coerce_union::coerce_union, + ir_ref::{coerce_alias::coerce_alias, IrRef}, + ParsingContext, ParsingError, }; impl TypeCoercer for FieldType { @@ -79,6 +83,7 @@ impl TypeCoercer for FieldType { FieldType::Enum(e) => IrRef::Enum(e).coerce(ctx, target, value), FieldType::Literal(l) => l.coerce(ctx, target, value), FieldType::Class(c) => IrRef::Class(c).coerce(ctx, target, value), + FieldType::RecursiveTypeAlias(name) => coerce_alias(ctx, self, value), FieldType::List(_) => coerce_array(ctx, self, value), FieldType::Union(_) => coerce_union(ctx, self, value), FieldType::Optional(_) => coerce_optional(ctx, self, value), @@ -88,7 +93,7 @@ impl TypeCoercer for FieldType { let mut coerced_value = base.coerce(ctx, base, value)?; let constraint_results = run_user_checks(&coerced_value.clone().into(), self) .map_err(|e| ParsingError { - reason: format!("Failed to evaluate constraints: {:?}", e), + reason: format!("Failed to evaluate constraints: {e:?}"), scope: ctx.scope.clone(), causes: Vec::new(), })?; @@ -164,6 +169,7 @@ impl DefaultValue for FieldType { FieldType::Enum(e) => None, FieldType::Literal(_) => None, FieldType::Class(_) => None, + FieldType::RecursiveTypeAlias(_) => None, FieldType::List(_) => Some(BamlValueWithFlags::List(get_flags(), Vec::new())), FieldType::Union(items) => items.iter().find_map(|i| i.default_value(error)), FieldType::Primitive(TypeValue::Null) | FieldType::Optional(_) => { diff --git a/engine/baml-lib/jsonish/src/deserializer/coercer/ir_ref/coerce_alias.rs b/engine/baml-lib/jsonish/src/deserializer/coercer/ir_ref/coerce_alias.rs new file mode 100644 index 000000000..352724bd8 --- /dev/null +++ b/engine/baml-lib/jsonish/src/deserializer/coercer/ir_ref/coerce_alias.rs @@ -0,0 +1,44 @@ +use anyhow::Result; +use internal_baml_core::ir::FieldType; + +use crate::deserializer::types::BamlValueWithFlags; + +use super::{ParsingContext, ParsingError, TypeCoercer}; + +pub fn coerce_alias( + ctx: &ParsingContext, + target: &FieldType, + value: Option<&crate::jsonish::Value>, +) -> Result { + assert!(matches!(target, FieldType::RecursiveTypeAlias(_))); + log::debug!( + "scope: {scope} :: coercing to: {name} (current: {current})", + name = target.to_string(), + scope = ctx.display_scope(), + current = value.map(|v| v.r#type()).unwrap_or("".into()) + ); + + let FieldType::RecursiveTypeAlias(alias) = target else { + unreachable!("coerce_alias"); + }; + + // See coerce_class.rs + let mut nested_ctx = None; + if let Some(v) = value { + let cls_value_pair = (alias.to_string(), v.to_owned()); + if ctx.visited.contains(&cls_value_pair) { + return Err(ctx.error_circular_reference(alias, v)); + } + nested_ctx = Some(ctx.visit_class_value_pair(cls_value_pair)); + } + let ctx = nested_ctx.as_ref().unwrap_or(ctx); + + ctx.of + .find_recursive_alias_target(alias) + .map_err(|e| ParsingError { + reason: format!("Failed to find recursive alias target: {e}"), + scope: ctx.scope.clone(), + causes: Vec::new(), + })? + .coerce(ctx, target, value) +} diff --git a/engine/baml-lib/jsonish/src/deserializer/coercer/ir_ref/mod.rs b/engine/baml-lib/jsonish/src/deserializer/coercer/ir_ref/mod.rs index a16b9ed1a..88a3cfde3 100644 --- a/engine/baml-lib/jsonish/src/deserializer/coercer/ir_ref/mod.rs +++ b/engine/baml-lib/jsonish/src/deserializer/coercer/ir_ref/mod.rs @@ -1,6 +1,9 @@ +pub mod coerce_alias; mod coerce_class; pub mod coerce_enum; +use core::panic; + use anyhow::Result; use internal_baml_core::ir::FieldType; @@ -11,6 +14,7 @@ use super::{ParsingContext, ParsingError}; pub(super) enum IrRef<'a> { Enum(&'a String), Class(&'a String), + RecursiveAlias(&'a String), } impl TypeCoercer for IrRef<'_> { @@ -29,6 +33,10 @@ impl TypeCoercer for IrRef<'_> { Ok(c) => c.coerce(ctx, target, value), Err(e) => Err(ctx.error_internal(e.to_string())), }, + IrRef::RecursiveAlias(a) => match ctx.of.find_recursive_alias_target(a.as_str()) { + Ok(a) => a.coerce(ctx, target, value), + Err(e) => Err(ctx.error_internal(e.to_string())), + }, } } } diff --git a/engine/baml-lib/jsonish/src/deserializer/deserialize_flags.rs b/engine/baml-lib/jsonish/src/deserializer/deserialize_flags.rs index 5aab40624..106981ebc 100644 --- a/engine/baml-lib/jsonish/src/deserializer/deserialize_flags.rs +++ b/engine/baml-lib/jsonish/src/deserializer/deserialize_flags.rs @@ -38,6 +38,7 @@ pub enum Flag { StringToBool(String), StringToNull(String), StringToChar(String), + StringToFloat(String), // Number -> X convertions. FloatToInt(f64), @@ -91,6 +92,7 @@ impl DeserializerConditions { Flag::StringToBool(_) => None, Flag::StringToNull(_) => None, Flag::StringToChar(_) => None, + Flag::StringToFloat(_) => None, Flag::FloatToInt(_) => None, Flag::NoFields(_) => None, Flag::UnionMatch(_idx, _) => None, @@ -235,6 +237,9 @@ impl std::fmt::Display for Flag { Flag::StringToChar(value) => { write!(f, "String to char: {}", value)?; } + Flag::StringToFloat(value) => { + write!(f, "String to float: {}", value)?; + } Flag::FloatToInt(value) => { write!(f, "Float to int: {}", value)?; } diff --git a/engine/baml-lib/jsonish/src/deserializer/score.rs b/engine/baml-lib/jsonish/src/deserializer/score.rs index f198702bd..7fbfefd8a 100644 --- a/engine/baml-lib/jsonish/src/deserializer/score.rs +++ b/engine/baml-lib/jsonish/src/deserializer/score.rs @@ -64,6 +64,7 @@ impl WithScore for Flag { Flag::StringToBool(_) => 1, Flag::StringToNull(_) => 1, Flag::StringToChar(_) => 1, + Flag::StringToFloat(_) => 1, Flag::FloatToInt(_) => 1, Flag::NoFields(_) => 1, // No scores for contraints diff --git a/engine/baml-lib/jsonish/src/deserializer/types.rs b/engine/baml-lib/jsonish/src/deserializer/types.rs index 1303ce7bc..b614b915e 100644 --- a/engine/baml-lib/jsonish/src/deserializer/types.rs +++ b/engine/baml-lib/jsonish/src/deserializer/types.rs @@ -33,6 +33,22 @@ pub enum BamlValueWithFlags { } impl BamlValueWithFlags { + pub fn is_composite(&self) -> bool { + match self { + BamlValueWithFlags::String(_) + | BamlValueWithFlags::Int(_) + | BamlValueWithFlags::Float(_) + | BamlValueWithFlags::Bool(_) + | BamlValueWithFlags::Null(_) + | BamlValueWithFlags::Enum(_, _) => false, + + BamlValueWithFlags::List(_, _) + | BamlValueWithFlags::Map(_, _) + | BamlValueWithFlags::Class(_, _, _) + | BamlValueWithFlags::Media(_) => true, + } + } + pub fn score(&self) -> i32 { match self { BamlValueWithFlags::String(f) => f.score(), diff --git a/engine/baml-lib/jsonish/src/tests/mod.rs b/engine/baml-lib/jsonish/src/tests/mod.rs index 991ec17f4..00da9d619 100644 --- a/engine/baml-lib/jsonish/src/tests/mod.rs +++ b/engine/baml-lib/jsonish/src/tests/mod.rs @@ -4,6 +4,7 @@ use internal_baml_jinja::types::{Class, Enum, Name, OutputFormatContent}; #[macro_use] pub mod macros; +mod test_aliases; mod test_basics; mod test_class; mod test_class_2; @@ -16,7 +17,7 @@ mod test_maps; mod test_partials; mod test_unions; -use indexmap::IndexSet; +use indexmap::{IndexMap, IndexSet}; use std::{ collections::{HashMap, HashSet}, path::PathBuf, @@ -24,6 +25,7 @@ use std::{ use baml_types::{BamlValue, EvaluationContext}; use internal_baml_core::{ + ast::Field, internal_baml_diagnostics::SourceFile, ir::{repr::IntermediateRepr, ClassWalker, EnumWalker, FieldType, IRHelper, TypeValue}, validate, @@ -55,12 +57,14 @@ fn render_output_format( output: &FieldType, env_values: &EvaluationContext<'_>, ) -> Result { - let (enums, classes, recursive_classes) = relevant_data_models(ir, output, env_values)?; + let (enums, classes, recursive_classes, structural_recursive_aliases) = + relevant_data_models(ir, output, env_values)?; Ok(OutputFormatContent::target(output.clone()) .enums(enums) .classes(classes) .recursive_classes(recursive_classes) + .structural_recursive_aliases(structural_recursive_aliases) .build()) } @@ -125,11 +129,17 @@ fn relevant_data_models<'a>( ir: &'a IntermediateRepr, output: &'a FieldType, env_values: &EvaluationContext<'_>, -) -> Result<(Vec, Vec, IndexSet)> { +) -> Result<( + Vec, + Vec, + IndexSet, + IndexMap, +)> { let mut checked_types: HashSet = HashSet::new(); let mut enums = Vec::new(); let mut classes: Vec = Vec::new(); let mut recursive_classes = IndexSet::new(); + let mut structural_recursive_aliases = IndexMap::new(); let mut start: Vec = vec![output.clone()]; while let Some(output) = start.pop() { @@ -229,6 +239,16 @@ fn relevant_data_models<'a>( }); } } + (FieldType::RecursiveTypeAlias(name), _) => { + // TODO: Same O(n) problem as above. + for cycle in ir.structural_recursive_alias_cycles() { + if cycle.contains_key(name) { + for (alias, target) in cycle.iter() { + structural_recursive_aliases.insert(alias.to_owned(), target.clone()); + } + } + } + } (FieldType::Literal(_), _) => {} (FieldType::Primitive(_), _constraints) => {} (FieldType::Constrained { .. }, _) => { @@ -237,7 +257,12 @@ fn relevant_data_models<'a>( } } - Ok((enums, classes, recursive_classes)) + Ok(( + enums, + classes, + recursive_classes, + structural_recursive_aliases, + )) } const EMPTY_FILE: &str = r#" diff --git a/engine/baml-lib/jsonish/src/tests/test_aliases.rs b/engine/baml-lib/jsonish/src/tests/test_aliases.rs new file mode 100644 index 000000000..9570be34f --- /dev/null +++ b/engine/baml-lib/jsonish/src/tests/test_aliases.rs @@ -0,0 +1,313 @@ +use baml_types::LiteralValue; + +use super::*; + +test_deserializer!( + test_simple_recursive_alias_list, + r#" +type A = A[] + "#, + "[[], [], [[]]]", + FieldType::RecursiveTypeAlias("A".into()), + [[], [], [[]]] +); + +test_deserializer!( + test_simple_recursive_alias_map, + r#" +type A = map + "#, + r#"{"one": {"two": {}}, "three": {"four": {}}}"#, + FieldType::RecursiveTypeAlias("A".into()), + { + "one": {"two": {}}, + "three": {"four": {}} + } +); + +test_deserializer!( + test_recursive_alias_cycle, + r#" +type A = B +type B = C +type C = A[] + "#, + "[[], [], [[]]]", + FieldType::RecursiveTypeAlias("A".into()), + [[], [], [[]]] +); + +test_deserializer!( + test_json_without_nested_objects, + r#" +type JsonValue = int | float | bool | string | null | JsonValue[] | map + "#, + r#" + { + "int": 1, + "float": 1.0, + "string": "test", + "bool": true + } + "#, + FieldType::RecursiveTypeAlias("JsonValue".into()), + { + "int": 1, + "float": 1.0, + "string": "test", + "bool": true + } +); + +test_deserializer!( + test_json_with_nested_list, + r#" +type JsonValue = int | float | bool | string | null | JsonValue[] | map + "#, + r#" + { + "number": 1, + "string": "test", + "bool": true, + "list": [1, 2, 3] + } + "#, + FieldType::RecursiveTypeAlias("JsonValue".into()), + { + "number": 1, + "string": "test", + "bool": true, + "list": [1, 2, 3] + } +); + +test_deserializer!( + test_json_with_nested_object, + r#" +type JsonValue = int | float | bool | string | null | JsonValue[] | map + "#, + r#" + { + "number": 1, + "string": "test", + "bool": true, + "json": { + "number": 1, + "string": "test", + "bool": true + } + } + "#, + FieldType::RecursiveTypeAlias("JsonValue".into()), + { + "number": 1, + "string": "test", + "bool": true, + "json": { + "number": 1, + "string": "test", + "bool": true + } + } +); + +test_deserializer!( + test_full_json_with_nested_objects, + r#" +type JsonValue = int | float | bool | string | null | JsonValue[] | map + "#, + r#" + { + "number": 1, + "string": "test", + "bool": true, + "list": [1, 2, 3], + "object": { + "number": 1, + "string": "test", + "bool": true, + "list": [1, 2, 3] + }, + "json": { + "number": 1, + "string": "test", + "bool": true, + "list": [1, 2, 3], + "object": { + "number": 1, + "string": "test", + "bool": true, + "list": [1, 2, 3] + } + } + } + "#, + FieldType::RecursiveTypeAlias("JsonValue".into()), + { + "number": 1, + "string": "test", + "bool": true, + "list": [1, 2, 3], + "object": { + "number": 1, + "string": "test", + "bool": true, + "list": [1, 2, 3] + }, + "json": { + "number": 1, + "string": "test", + "bool": true, + "list": [1, 2, 3], + "object": { + "number": 1, + "string": "test", + "bool": true, + "list": [1, 2, 3] + } + } + } +); + +test_deserializer!( + test_list_of_json_objects, + r#" +type JsonValue = int | float | bool | string | null | JsonValue[] | map + "#, + r#" + [ + { + "number": 1, + "string": "test", + "bool": true, + "list": [1, 2, 3] + }, + { + "number": 1, + "string": "test", + "bool": true, + "list": [1, 2, 3] + } + ] + "#, + FieldType::RecursiveTypeAlias("JsonValue".into()), + [ + { + "number": 1, + "string": "test", + "bool": true, + "list": [1, 2, 3] + }, + { + "number": 1, + "string": "test", + "bool": true, + "list": [1, 2, 3] + } + ] +); + +test_deserializer!( + test_nested_list, + r#" +type JsonValue = int | float | bool | string | null | JsonValue[] | map + "#, + r#" + [[42.1]] + "#, + FieldType::RecursiveTypeAlias("JsonValue".into()), + // [[[[[[[[[[[[[[[[[[[[42]]]]]]]]]]]]]]]]]]]] + [[42.1]] +); + +test_deserializer!( + test_json_defined_with_cycles, + r#" +type JsonValue = int | float | bool | string | null | JsonArray | JsonObject +type JsonArray = JsonValue[] +type JsonObject = map + "#, + r#" + { + "number": 1, + "string": "test", + "bool": true, + "json": { + "number": 1, + "string": "test", + "bool": true + } + } + "#, + FieldType::RecursiveTypeAlias("JsonValue".into()), + { + "number": 1, + "string": "test", + "bool": true, + "json": { + "number": 1, + "string": "test", + "bool": true + } + } +); + +test_deserializer!( + test_ambiguous_int_string_json_type, + r#" +type JsonValue = int | float | bool | string | null | JsonValue[] | map + "#, + r#" + { + "recipe": { + "name": "Chocolate Chip Cookies", + "servings": 24, + "ingredients": [ + "2 1/4 cups all-purpose flour", "1/2 teaspoon baking soda", + "1 cup unsalted butter, room temperature", + "1/2 cup granulated sugar", + "1 cup packed light-brown sugar", + "1 teaspoon salt", "2 teaspoons pure vanilla extract", + "2 large eggs", "2 cups semisweet and/or milk chocolate chips" + ], + "instructions": [ + "Preheat oven to 350°F (180°C).", + "In a small bowl, whisk together flour and baking soda; set aside.", + "In a large bowl, cream butter and sugars until light and fluffy.", + "Add salt, vanilla, and eggs; mix well.", + "Gradually stir in flour mixture.", + "Fold in chocolate chips.", + "Drop by rounded tablespoons onto ungreased baking sheets.", + "Bake for 10-12 minutes or until golden brown.", + "Cool on wire racks." + ] + } + } + "#, + FieldType::RecursiveTypeAlias("JsonValue".into()), + { + "recipe": { + "name": "Chocolate Chip Cookies", + "servings": 24, + "ingredients": [ + "2 1/4 cups all-purpose flour", "1/2 teaspoon baking soda", + "1 cup unsalted butter, room temperature", + "1/2 cup granulated sugar", + "1 cup packed light-brown sugar", + "1 teaspoon salt", "2 teaspoons pure vanilla extract", + "2 large eggs", "2 cups semisweet and/or milk chocolate chips" + ], + "instructions": [ + "Preheat oven to 350°F (180°C).", + "In a small bowl, whisk together flour and baking soda; set aside.", + "In a large bowl, cream butter and sugars until light and fluffy.", + "Add salt, vanilla, and eggs; mix well.", + "Gradually stir in flour mixture.", + "Fold in chocolate chips.", + "Drop by rounded tablespoons onto ungreased baking sheets.", + "Bake for 10-12 minutes or until golden brown.", + "Cool on wire racks." + ] + } + } +); diff --git a/engine/baml-lib/jsonish/src/tests/test_basics.rs b/engine/baml-lib/jsonish/src/tests/test_basics.rs index ea88ae476..a79698b18 100644 --- a/engine/baml-lib/jsonish/src/tests/test_basics.rs +++ b/engine/baml-lib/jsonish/src/tests/test_basics.rs @@ -146,6 +146,14 @@ test_deserializer!( [1., 2., 3.] ); +test_deserializer!( + test_string_to_float_from_comma_separated, + "", + "1 cup unsalted butter, room temperature", + FieldType::Primitive(TypeValue::Float), + 1.0 +); + test_deserializer!( test_object, r#" diff --git a/engine/baml-lib/jsonish/src/tests/test_constraints.rs b/engine/baml-lib/jsonish/src/tests/test_constraints.rs index f543cd115..728c1bf67 100644 --- a/engine/baml-lib/jsonish/src/tests/test_constraints.rs +++ b/engine/baml-lib/jsonish/src/tests/test_constraints.rs @@ -16,7 +16,7 @@ test_deserializer_with_expected_score!( CLASS_FOO_INT_STRING, r#"{"age": 11, "name": "Greg"}"#, FieldType::Class("Foo".to_string()), - 0 + 1 ); test_deserializer_with_expected_score!( @@ -24,7 +24,7 @@ test_deserializer_with_expected_score!( CLASS_FOO_INT_STRING, r#"{"age": 21, "name": "Grog"}"#, FieldType::Class("Foo".to_string()), - 0 + 1 ); test_failing_deserializer!( @@ -61,7 +61,7 @@ test_deserializer_with_expected_score!( UNION_WITH_CHECKS, r#"{"bar": 5, "things":[]}"#, FieldType::Class("Either".to_string()), - 2 + 3 ); test_deserializer_with_expected_score!( @@ -69,7 +69,7 @@ test_deserializer_with_expected_score!( UNION_WITH_CHECKS, r#"{"bar": 15, "things":[]}"#, FieldType::Class("Either".to_string()), - 2 + 3 ); test_failing_deserializer!( @@ -90,7 +90,7 @@ test_deserializer_with_expected_score!( MAP_WITH_CHECKS, r#"{"foo": {"hello": 10, "there":13}}"#, FieldType::Class("Foo".to_string()), - 1 + 2 ); test_deserializer_with_expected_score!( @@ -98,7 +98,7 @@ test_deserializer_with_expected_score!( MAP_WITH_CHECKS, r#"{"foo": {"hello": 11, "there":13}}"#, FieldType::Class("Foo".to_string()), - 1 + 2 ); const NESTED_CLASS_CONSTRAINTS: &str = r#" @@ -116,7 +116,7 @@ test_deserializer_with_expected_score!( NESTED_CLASS_CONSTRAINTS, r#"{"inner": {"value": 15}}"#, FieldType::Class("Outer".to_string()), - 0 + 1 ); const BLOCK_LEVEL: &str = r#" diff --git a/engine/baml-lib/jsonish/src/tests/test_maps.rs b/engine/baml-lib/jsonish/src/tests/test_maps.rs index 906e4a1c5..5c1b1d641 100644 --- a/engine/baml-lib/jsonish/src/tests/test_maps.rs +++ b/engine/baml-lib/jsonish/src/tests/test_maps.rs @@ -166,6 +166,7 @@ fn test_union_of_map_and_class() { assert!(result.is_ok(), "Failed to parse: {:?}", result); let value = result.unwrap(); + dbg!(&value); assert!(matches!(value, BamlValueWithFlags::Class(_, _, _))); log::trace!("Score: {}", value.score()); diff --git a/engine/baml-lib/jsonish/src/tests/test_unions.rs b/engine/baml-lib/jsonish/src/tests/test_unions.rs index c1f7173c9..dbdacbf7d 100644 --- a/engine/baml-lib/jsonish/src/tests/test_unions.rs +++ b/engine/baml-lib/jsonish/src/tests/test_unions.rs @@ -277,3 +277,25 @@ test_deserializer!( FieldType::Class("ContactInfo".to_string()), {"primary": {"value": "help@boundaryml.com"}} ); + +test_deserializer!( + test_ignore_float_in_string_if_string_in_union, + "", + "1 cup unsalted butter, room temperature", + FieldType::Union(vec![ + FieldType::Primitive(TypeValue::Float), + FieldType::Primitive(TypeValue::String), + ]), + "1 cup unsalted butter, room temperature" +); + +test_deserializer!( + test_ignore_int_if_string_in_union, + "", + "1 cup unsalted butter, room temperature", + FieldType::Union(vec![ + FieldType::Primitive(TypeValue::Int), + FieldType::Primitive(TypeValue::String), + ]), + "1 cup unsalted butter, room temperature" +); diff --git a/engine/baml-lib/parser-database/src/attributes/mod.rs b/engine/baml-lib/parser-database/src/attributes/mod.rs index ec7dcdd2b..92e8c58bf 100644 --- a/engine/baml-lib/parser-database/src/attributes/mod.rs +++ b/engine/baml-lib/parser-database/src/attributes/mod.rs @@ -1,5 +1,7 @@ -use internal_baml_diagnostics::Span; -use internal_baml_schema_ast::ast::{Top, TopId, TypeExpId, TypeExpressionBlock}; +use internal_baml_diagnostics::{DatamodelError, Span}; +use internal_baml_schema_ast::ast::{ + Assignment, Top, TopId, TypeAliasId, TypeExpId, TypeExpressionBlock, +}; mod alias; pub mod constraint; @@ -79,6 +81,9 @@ pub(super) fn resolve_attributes(ctx: &mut Context<'_>) { (TopId::Enum(enum_id), Top::Enum(ast_enum)) => { resolve_type_exp_block_attributes(enum_id, ast_enum, ctx, SubType::Enum) } + (TopId::TypeAlias(alias_id), Top::TypeAlias(assignment)) => { + resolve_type_alias_attributes(alias_id, assignment, ctx) + } _ => (), } } @@ -132,3 +137,53 @@ fn resolve_type_exp_block_attributes<'db>( _ => (), } } + +/// Quick hack to validate type alias attributes. +/// +/// Unlike classes and enums, type aliases only support checks and asserts. +/// Everything else is reported as an error. On top of that, checks and asserts +/// must be merged when aliases point to other aliases. We do this recursively +/// when resolving the type alias to its final "virtual" type at +/// [`crate::types::resolve_type_alias`]. +/// +/// Then checks and asserts are collected from the virtual type and stored in +/// the IR at `engine/baml-lib/baml-core/src/ir/repr.rs`, so there's no need to +/// store them in separate classes like [`ClassAttributes`] or similar, at least +/// for now. +fn resolve_type_alias_attributes<'db>( + alias_id: TypeAliasId, + assignment: &'db Assignment, + ctx: &mut Context<'db>, +) { + ctx.assert_all_attributes_processed(alias_id.into()); + + for _ in 0..assignment.value.attributes().len() { + // TODO: How does this thing work exactly, the code in the functions + // above for visiting class fields suggests that this returns "all" the + // attributes that it finds but it does not return repeated checks and + // asserts, they are left in the state machine and the Context panics. + // So we're gonna run this in a for loop so that the visit function + // calls visit_repeated_attr_from_names enough times to consume all the + // checks and asserts. + let type_alias_attributes = to_string_attribute::visit(ctx, assignment.value.span(), false); + + // Some additional specific validation for type alias attributes. + if let Some(attrs) = &type_alias_attributes { + if attrs.dynamic_type().is_some() + || attrs.alias().is_some() + || attrs.skip().is_some() + || attrs.description().is_some() + { + ctx.diagnostics + .push_error(DatamodelError::new_validation_error( + "type aliases may only have @check and @assert attributes", + assignment.span.clone(), + )); + } + } + } + + // Now this should be safe to call and it should not panic because there are + // checks and asserts left. + ctx.validate_visited_attributes(); +} diff --git a/engine/baml-lib/parser-database/src/context/mod.rs b/engine/baml-lib/parser-database/src/context/mod.rs index b9a46fb0a..6a417c15b 100644 --- a/engine/baml-lib/parser-database/src/context/mod.rs +++ b/engine/baml-lib/parser-database/src/context/mod.rs @@ -16,12 +16,13 @@ mod attributes; /// /// ## Attribute Validation /// -/// The Context also acts as a state machine for attribute validation. The goal is to avoid manual -/// work validating things that are valid for every attribute set, and every argument set inside an -/// attribute: multiple unnamed arguments are not valid, attributes we do not use in parser-database -/// are not valid, multiple arguments with the same name are not valid, etc. +/// The Context also acts as a state machine for attribute validation. The goal +/// is to avoid manual work validating things that are valid for every attribute +/// set, and every argument set inside an attribute: multiple unnamed arguments +/// are not valid, attributes we do not use in parser-database are not valid, +/// multiple arguments with the same name are not valid, etc. /// -/// See `visit_attributes()`. +/// See [`Self::assert_all_attributes_processed`]. pub(crate) struct Context<'db> { pub(crate) ast: &'db ast::SchemaAst, pub(crate) interner: &'db mut StringInterner, @@ -75,14 +76,17 @@ impl<'db> Context<'db> { self.diagnostics.push_warning(warning) } - /// All attribute validation should go through `visit_attributes()`. It lets - /// us enforce some rules, for example that certain attributes should not be - /// repeated, and make sure that _all_ attributes are visited during the - /// validation process, emitting unknown attribute errors when it is not - /// the case. + /// Attribute processing entry point. /// - /// - When you are done validating an attribute, you must call `discard_arguments()` or - /// `validate_visited_arguments()`. Otherwise, Context will helpfully panic. + /// All attribute validation should go through + /// [`Self::assert_all_attributes_processed`]. It lets us enforce some + /// rules, for example that certain attributes should not be repeated, and + /// make sure that _all_ attributes are visited during the validation + /// process, emitting unknown attribute errors when it is not the case. + /// + /// - When you are done validating an attribute, you must call + /// [`Self::discard_arguments()`] or [`Self::validate_visited_arguments()`]. + /// Otherwise, [`Context`] will helpfully panic. pub(super) fn assert_all_attributes_processed( &mut self, ast_attributes: ast::AttributeContainer, @@ -98,28 +102,6 @@ impl<'db> Context<'db> { self.attributes.set_attributes(ast_attributes, self.ast); } - /// Extract an attribute that can occur zero or more times. Example: @@index on models. - /// - /// Returns `true` as long as a next attribute is found. - pub(crate) fn _visit_repeated_attr(&mut self, name: &'static str) -> bool { - let mut has_valid_attribute = false; - - while !has_valid_attribute { - let first_attr = iter_attributes(self.attributes.attributes.as_ref(), self.ast) - .filter(|(_, attr)| attr.name.name() == name) - .find(|(attr_id, _)| self.attributes.unused_attributes.contains(attr_id)); - let (attr_id, attr) = if let Some(first_attr) = first_attr { - first_attr - } else { - break; - }; - self.attributes.unused_attributes.remove(&attr_id); - has_valid_attribute = self.set_attribute(attr_id, attr); - } - - has_valid_attribute - } - /// Extract an attribute that can occur zero or more times. Example: @assert on types. /// Argument is a list of names that are all valid for this attribute. /// diff --git a/engine/baml-lib/parser-database/src/lib.rs b/engine/baml-lib/parser-database/src/lib.rs index c6bb6ea11..c90c141ba 100644 --- a/engine/baml-lib/parser-database/src/lib.rs +++ b/engine/baml-lib/parser-database/src/lib.rs @@ -36,17 +36,18 @@ mod names; mod tarjan; mod types; -use std::collections::{HashMap, HashSet}; +use std::collections::{HashMap, HashSet, VecDeque}; pub use coerce_expression::{coerce, coerce_array, coerce_opt}; -use either::Either; pub use internal_baml_schema_ast::ast; -use internal_baml_schema_ast::ast::SchemaAst; +use internal_baml_schema_ast::ast::{FieldType, SchemaAst, WithName}; pub use tarjan::Tarjan; +use types::resolve_type_alias; pub use types::{ Attributes, ClientProperties, ContantDelayStrategy, ExponentialBackoffStrategy, PrinterType, PromptAst, PromptVariable, RetryPolicy, RetryPolicyStrategy, StaticType, }; +pub use walkers::TypeWalker; use self::{context::Context, interner::StringId, types::Types}; use internal_baml_diagnostics::{DatamodelError, Diagnostics}; @@ -113,8 +114,6 @@ impl ParserDatabase { // First pass: resolve names. names::resolve_names(&mut ctx); - // Return early on name resolution errors. - // Second pass: resolve top-level items and field types. types::resolve_types(&mut ctx); @@ -125,12 +124,25 @@ impl ParserDatabase { ctx.diagnostics.to_result() } - /// Updates the prompt + /// Last changes after validation. pub fn finalize(&mut self, diag: &mut Diagnostics) { self.finalize_dependencies(diag); } fn finalize_dependencies(&mut self, diag: &mut Diagnostics) { + // Cycles left here after cycle validation are allowed. Basically lists + // and maps can introduce cycles. + self.types.structural_recursive_alias_cycles = + Tarjan::components(&self.types.type_alias_dependencies); + + // Resolve type aliases. + // Cycles are already validated so this should not stack overflow and + // it should find the final type. + for alias_id in self.types.type_alias_dependencies.keys() { + let resolved = resolve_type_alias(&self.ast[*alias_id].value, &self); + self.types.resolved_type_aliases.insert(*alias_id, resolved); + } + // NOTE: Class dependency cycles are already checked at // baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs // @@ -152,56 +164,57 @@ impl ParserDatabase { // instead of strings (class names). That requires less conversions when // working with the graph. Once the work is done, IDs can be converted // to names where needed. - let cycles = Tarjan::components(&HashMap::from_iter( - self.types.class_dependencies.iter().map(|(id, deps)| { - let deps = - HashSet::from_iter(deps.iter().filter_map( - |dep| match self.find_type_by_str(dep) { - Some(Either::Left(cls)) => Some(cls.id), - Some(Either::Right(_)) => None, - None => panic!("Unknown class `{dep}`"), - }, - )); - (*id, deps) - }), - )); + let mut resolved_dependency_graph = HashMap::new(); + + for (id, deps) in self.types.class_dependencies.iter() { + let mut resolved_deps = HashSet::new(); + + for dep in deps { + match self.find_type_by_str(dep) { + Some(TypeWalker::Class(cls)) => { + resolved_deps.insert(cls.id); + } + Some(TypeWalker::Enum(_)) => {} + // Gotta resolve type aliases. + Some(TypeWalker::TypeAlias(alias)) => { + resolved_deps.extend(alias.resolved().flat_idns().iter().map(|ident| { + match self.find_type_by_str(ident.name()) { + Some(TypeWalker::Class(cls)) => cls.id, + Some(TypeWalker::Enum(_)) => { + panic!("Enums are not allowed in type aliases") + } + Some(TypeWalker::TypeAlias(alias)) => { + panic!("Alias should be resolved at this point") + } + None => panic!("Unknown class `{dep}`"), + } + })) + } + None => panic!("Unknown class `{dep}`"), + } + } + + resolved_dependency_graph.insert(*id, resolved_deps); + } - // Inject finite cycles into parser DB. This will then be passed into - // the IR and then into the Jinja output format. - self.types.finite_recursive_cycles = cycles - .into_iter() - .map(|cycle| cycle.into_iter().collect()) - .collect(); + // Find the cycles and inject them into parser DB. This will then be + // passed into the IR and then into the Jinja output format. + // + // TODO: Should we update `class_dependencies` to include resolved + // aliases or not? + self.types.finite_recursive_cycles = Tarjan::components(&resolved_dependency_graph); - // Additionally ensure the same thing for functions, but since we've already handled classes, - // this should be trivial. + // Fully resolve function dependencies. let extends = self .types .function .iter() - .map(|(&k, func)| { + .map(|(&id, func)| { let (input, output) = &func.dependencies; - let input_deps = input - .iter() - .filter_map(|f| match self.find_type_by_str(f) { - Some(Either::Left(walker)) => Some(walker.dependencies().iter().cloned()), - Some(Either::Right(_)) => None, - _ => panic!("Unknown class `{}`", f), - }) - .flatten() - .collect::>(); - - let output_deps = output - .iter() - .filter_map(|f| match self.find_type_by_str(f) { - Some(Either::Left(walker)) => Some(walker.dependencies().iter().cloned()), - Some(Either::Right(_)) => None, - _ => panic!("Unknown class `{}`", f), - }) - .flatten() - .collect::>(); - - (k, (input_deps, output_deps)) + let input_deps = self.collect_dependency_tree(input); + let output_deps = self.collect_dependency_tree(output); + + (id, (input_deps, output_deps)) }) .collect::>(); @@ -212,10 +225,71 @@ impl ParserDatabase { } } + /// Resolve the entire tree of dependencies for functions. + /// + /// Initial passes through the AST can only resolve one level of + /// dependencies for functions. This method will go through that first level + /// and collect all the dependencies of the dependencies. + fn collect_dependency_tree(&self, deps: &HashSet) -> HashSet { + let mut collected_deps = HashSet::new(); + let mut stack = Vec::from_iter(deps.iter().map(|dep| dep.as_str())); + + while let Some(dep) = stack.pop() { + match self.find_type_by_str(dep) { + // Add all the dependencies of the class. + Some(TypeWalker::Class(walker)) => { + for nested_dep in walker.dependencies() { + if collected_deps.insert(nested_dep.to_owned()) { + // Recurse if not already visited. + stack.push(nested_dep); + } + } + } + + // For aliases just get the resolved identifiers and + // push them into the stack. If we find resolved classes we'll + // add their dependencies as well. + Some(TypeWalker::TypeAlias(walker)) => { + stack.extend(walker.resolved().flat_idns().iter().filter_map(|ident| { + // Add the resolved name itself to the deps. + collected_deps.insert(ident.name().to_owned()); + // If the type is an alias then don't recurse. + if self + .structural_recursive_alias_cycles() + .iter() + .any(|cycle| cycle.contains(&walker.id)) + { + None + } else { + Some(ident.name()) + } + })) + } + + // Skip enums. + Some(TypeWalker::Enum(_)) => {} + + // This should not happen. + _ => panic!("Unknown class `{dep}`"), + } + } + + collected_deps + } + /// The parsed AST. pub fn ast(&self) -> &ast::SchemaAst { &self.ast } + + /// Returns the graph of type aliases. + /// + /// Each vertex is a type alias and each edge is a reference to another type + /// alias. + pub fn type_alias_dependencies(&self) -> &HashMap> { + &self.types.type_alias_dependencies + } + /// The total number of enums in the schema. This is O(1). pub fn enums_count(&self) -> usize { self.types.enum_attributes.len() @@ -241,10 +315,12 @@ mod test { use std::path::PathBuf; use super::*; + use ast::FieldArity; + use baml_types::TypeValue; use internal_baml_diagnostics::{Diagnostics, SourceFile}; use internal_baml_schema_ast::parse_schema; - fn assert_finite_cycles(baml: &'static str, expected: &[&[&str]]) -> Result<(), Diagnostics> { + fn parse(baml: &'static str) -> Result { let mut db = ParserDatabase::new(); let source = SourceFile::new_static(PathBuf::from("test.baml"), baml); let (ast, mut diag) = parse_schema(source.path_buf(), &source)?; @@ -253,6 +329,14 @@ mod test { db.validate(&mut diag)?; db.finalize(&mut diag); + diag.to_result()?; + + Ok(db) + } + + fn assert_finite_cycles(baml: &'static str, expected: &[&[&str]]) -> Result<(), Diagnostics> { + let db = parse(baml)?; + assert_eq!( db.finite_recursive_cycles() .iter() @@ -267,6 +351,26 @@ mod test { Ok(()) } + fn assert_structural_alias_cycles( + baml: &'static str, + expected: &[&[&str]], + ) -> Result<(), Diagnostics> { + let db = parse(baml)?; + + assert_eq!( + db.structural_recursive_alias_cycles() + .iter() + .map(|ids| Vec::from_iter(ids.iter().map(|id| db.ast()[*id].name().to_string()))) + .collect::>(), + expected + .iter() + .map(|cycle| Vec::from_iter(cycle.iter().map(ToString::to_string))) + .collect::>() + ); + + Ok(()) + } + #[test] fn find_simple_recursive_class() -> Result<(), Diagnostics> { assert_finite_cycles( @@ -479,4 +583,128 @@ mod test { &[&["RecMap"]], ) } + + #[test] + fn resolve_simple_alias() -> Result<(), Diagnostics> { + let db = parse("type Number = int")?; + + assert!(matches!( + db.resolved_type_alias_by_name("Number").unwrap(), + FieldType::Primitive(FieldArity::Required, TypeValue::Int, _, _) + )); + + Ok(()) + } + + #[test] + fn resolve_multiple_levels_of_aliases() -> Result<(), Diagnostics> { + #[rustfmt::skip] + let db = parse(r#" + type One = string + type Two = One + type Three = Two + type Four = Three + "#)?; + + assert!(matches!( + db.resolved_type_alias_by_name("Four").unwrap(), + FieldType::Primitive(FieldArity::Required, TypeValue::String, _, _) + )); + + Ok(()) + } + + #[test] + fn sync_alias_arity() -> Result<(), Diagnostics> { + #[rustfmt::skip] + let db = parse(r#" + type Required = float + type Optional = Required? + "#)?; + + assert!(matches!( + db.resolved_type_alias_by_name("Optional").unwrap(), + FieldType::Primitive(FieldArity::Optional, TypeValue::Float, _, _) + )); + + Ok(()) + } + + #[test] + fn find_basic_map_structural_cycle() -> Result<(), Diagnostics> { + assert_structural_alias_cycles( + "type RecursiveMap = map", + &[&["RecursiveMap"]], + ) + } + + #[test] + fn find_basic_list_structural_cycle() -> Result<(), Diagnostics> { + assert_structural_alias_cycles("type A = A[]", &[&["A"]]) + } + + #[test] + fn find_long_list_structural_cycle() -> Result<(), Diagnostics> { + assert_structural_alias_cycles( + r#" + type A = B + type B = C + type C = A[] + "#, + &[&["A", "B", "C"]], + ) + } + + #[test] + fn find_intricate_structural_cycle() -> Result<(), Diagnostics> { + assert_structural_alias_cycles( + r#" + type JsonValue = string | int | float | bool | null | JsonArray | JsonObject + type JsonArray = JsonValue[] + type JsonObject = map + "#, + &[&["JsonValue", "JsonArray", "JsonObject"]], + ) + } + + #[test] + fn merged_alias_attrs() -> Result<(), Diagnostics> { + #[rustfmt::skip] + let db = parse(r#" + type One = int @assert({{ this < 5 }}) + type Two = One @assert({{ this > 0 }}) + "#)?; + + let resolved = db.resolved_type_alias_by_name("Two").unwrap(); + + assert_eq!(resolved.attributes().len(), 2); + + Ok(()) + } + + // Resolution of aliases here at the parser database level doesn't matter + // as much because there's no notion of "classes" or "enums", it's just + // "symbols". But the resolve type function should not stack overflow + // anyway. + #[test] + fn resolve_simple_structural_recursive_alias() -> Result<(), Diagnostics> { + #[rustfmt::skip] + let db = parse(r#" + type A = A[] + "#)?; + + let resolved = db.resolved_type_alias_by_name("A").unwrap(); + + let FieldType::List(_, inner, ..) = resolved else { + panic!("expected a list type, got {resolved:?}"); + }; + + let FieldType::Symbol(_, ident, _) = &**inner else { + panic!("expected a symbol type, got {inner:?}"); + }; + + assert_eq!(ident.name(), "A"); + + Ok(()) + } } diff --git a/engine/baml-lib/parser-database/src/names/mod.rs b/engine/baml-lib/parser-database/src/names/mod.rs index fe85ca024..4d9446e27 100644 --- a/engine/baml-lib/parser-database/src/names/mod.rs +++ b/engine/baml-lib/parser-database/src/names/mod.rs @@ -90,6 +90,19 @@ pub(super) fn resolve_names(ctx: &mut Context<'_>) { (_, ast::Top::Class(_)) => { unreachable!("Encountered impossible class declaration during parsing") } + + (ast::TopId::TypeAlias(_), ast::Top::TypeAlias(type_alias)) => { + validate_type_alias_name(type_alias, ctx.diagnostics); + + ctx.interner.intern(type_alias.name()); + + Some(either::Left(&mut names.tops)) + } + + (_, ast::Top::TypeAlias(_)) => { + unreachable!("Encountered impossible type alias declaration during parsing") + } + (ast::TopId::TemplateString(_), ast::Top::TemplateString(template_string)) => { validate_template_string_name(template_string, ctx.diagnostics); validate_attribute_identifiers(template_string, ctx); diff --git a/engine/baml-lib/parser-database/src/names/validate_reserved_names.rs b/engine/baml-lib/parser-database/src/names/validate_reserved_names.rs index 781a838fc..6389c32de 100644 --- a/engine/baml-lib/parser-database/src/names/validate_reserved_names.rs +++ b/engine/baml-lib/parser-database/src/names/validate_reserved_names.rs @@ -44,6 +44,10 @@ pub(crate) fn validate_class_name( validate_name("class", ast_class.identifier(), diagnostics, true); } +pub(crate) fn validate_type_alias_name(ast_class: &ast::Assignment, diagnostics: &mut Diagnostics) { + validate_name("type alias", ast_class.identifier(), diagnostics, true); +} + pub(crate) fn validate_class_field_name( ast_class_field: &ast::Field, diagnostics: &mut Diagnostics, diff --git a/engine/baml-lib/parser-database/src/tarjan.rs b/engine/baml-lib/parser-database/src/tarjan.rs index 1ecdf6f1b..a80a6a52d 100644 --- a/engine/baml-lib/parser-database/src/tarjan.rs +++ b/engine/baml-lib/parser-database/src/tarjan.rs @@ -6,12 +6,14 @@ use std::{ cmp, collections::{HashMap, HashSet}, + fmt::Debug, + hash::Hash, }; use internal_baml_schema_ast::ast::TypeExpId; /// Dependency graph represented as an adjacency list. -type Graph = HashMap>; +type Graph = HashMap>; /// State of each node for Tarjan's algorithm. #[derive(Clone, Copy)] @@ -35,20 +37,24 @@ struct NodeState { /// This struct is simply bookkeeping for the algorithm, it can be implemented /// with just function calls but the recursive one would need 6 parameters which /// is pretty ugly. -pub struct Tarjan<'g> { +pub struct Tarjan<'g, V> { /// Ref to the depdenency graph. - graph: &'g Graph, + graph: &'g Graph, /// Node number counter. index: usize, /// Nodes are placed on a stack in the order in which they are visited. - stack: Vec, + stack: Vec, /// State of each node. - state: HashMap, + state: HashMap, /// Strongly connected components. - components: Vec>, + components: Vec>, } -impl<'g> Tarjan<'g> { +// V is Copy because we mostly use opaque identifiers for class or alias IDs. +// In practice V ends up being a u32, but if for some reason this needs to +// be used with strings then we can make V Clone instead of Copy and refactor +// the code below. +impl<'g, V: Eq + Ord + Hash + Copy> Tarjan<'g, V> { /// Unvisited node marker. /// /// Technically we should use [`Option`] and [`None`] for @@ -63,7 +69,7 @@ impl<'g> Tarjan<'g> { /// Loops through all the nodes in the graph and visits them if they haven't /// been visited already. When the algorithm is done, [`Self::components`] /// will contain all the cycles in the graph. - pub fn components(graph: &'g Graph) -> Vec> { + pub fn components(graph: &'g Graph) -> Vec> { let mut tarjans = Self { graph, index: 0, @@ -105,7 +111,7 @@ impl<'g> Tarjan<'g> { /// /// This is where the "algorithm" runs. Could be implemented iteratively if /// needed at some point. - fn strong_connect(&mut self, node_id: TypeExpId) { + fn strong_connect(&mut self, node_id: V) { // Initialize node state. This node has not yet been visited so we don't // have to grab the state from the hash map. And if we did, then we'd // have to fight the borrow checker by taking mut refs and read-only @@ -121,8 +127,15 @@ impl<'g> Tarjan<'g> { self.index += 1; self.stack.push(node_id); + // TODO: @antoniosarosi: HashSet is random, won't always iterate in the + // same order. Fix this with IndexSet or something, we really don't want + // to sort this every single time. Also order only matters for tests, we + // can do `if cfg!(test)` or something. + let mut successors = Vec::from_iter(&self.graph[&node_id]); + successors.sort(); + // Visit neighbors to find strongly connected components. - for successor_id in &self.graph[&node_id] { + for successor_id in successors { // Grab owned state to circumvent borrow checker. let mut successor = self.state[successor_id]; if successor.index == Self::UNVISITED { diff --git a/engine/baml-lib/parser-database/src/types/mod.rs b/engine/baml-lib/parser-database/src/types/mod.rs index 7bc12ccb9..59f2e73e7 100644 --- a/engine/baml-lib/parser-database/src/types/mod.rs +++ b/engine/baml-lib/parser-database/src/types/mod.rs @@ -1,10 +1,10 @@ use ouroboros::self_referencing; -use std::collections::{HashMap, HashSet}; +use std::collections::{HashMap, HashSet, VecDeque}; use std::hash::Hash; use std::ops::Deref; -use crate::coerce; use crate::types::configurations::visit_test_case; +use crate::{coerce, ParserDatabase}; use crate::{context::Context, DatamodelError}; use baml_types::Constraint; @@ -13,7 +13,7 @@ use indexmap::IndexMap; use internal_baml_diagnostics::{Diagnostics, Span}; use internal_baml_prompt_parser::ast::{ChatBlock, PrinterBlock, Variable}; use internal_baml_schema_ast::ast::{ - self, Expression, FieldId, RawString, ValExpId, WithIdentifier, WithName, WithSpan, + self, Expression, FieldId, FieldType, RawString, ValExpId, WithIdentifier, WithName, WithSpan, }; use internal_llm_client::{ClientProvider, PropertyHandler, UnresolvedClientProperty}; @@ -37,6 +37,12 @@ pub(super) fn resolve_types(ctx: &mut Context<'_>) { visit_class(idx, model, ctx); } (_, ast::Top::Class(_)) => unreachable!("Class misconfigured"), + + (ast::TopId::TypeAlias(idx), ast::Top::TypeAlias(assignment)) => { + visit_type_alias(idx, assignment, ctx); + } + (_, ast::Top::TypeAlias(assignment)) => unreachable!("Type alias misconfigured"), + (ast::TopId::TemplateString(idx), ast::Top::TemplateString(template_string)) => { visit_template_string(idx, template_string, ctx) } @@ -235,6 +241,29 @@ pub(super) struct Types { pub(super) class_dependencies: HashMap>, pub(super) enum_dependencies: HashMap>, + /// Graph of type aliases. + /// + /// This graph is only used to detect infinite cycles in type aliases. + pub(crate) type_alias_dependencies: HashMap>, + + /// Fully resolved type aliases. + /// + /// A type alias con point to one or many other type aliases. + /// + /// ```ignore + /// type AliasOne = SomeClass + /// type AliasTwo = AnotherClass + /// type AliasThree = AliasTwo + /// type AliasFour = AliasOne | AliasTwo + /// ``` + /// + /// In the above example, `AliasFour` would be resolved to the type + /// `SomeClass | AnotherClass`, which does not even exist in the AST. That's + /// why we need to store the resolution here. + /// + /// Contents would be `AliasThree -> SomeClass`. + pub(super) resolved_type_aliases: HashMap, + /// Strongly connected components of the dependency graph. /// /// Basically contains all the different cycles. This allows us to find a @@ -247,6 +276,13 @@ pub(super) struct Types { /// Merge-Find Set or something like that. pub(super) finite_recursive_cycles: Vec>, + /// Contains recursive type aliases. + /// + /// Recursive type aliases are a little bit trickier than recursive classes + /// because the termination condition is tied to lists and maps only. Nulls + /// and unions won't allow type alias cycles to be resolved. + pub(super) structural_recursive_alias_cycles: Vec>, + pub(super) function: HashMap, pub(super) client_properties: HashMap, @@ -346,6 +382,174 @@ fn visit_class<'db>( }); } +/// Returns a "virtual" type that represents the fully resolved alias. +/// +/// We call it "virtual" because it might not exist in the AST. Basic example: +/// +/// ```ignore +/// type AliasOne = SomeClass +/// type AliasTwo = AnotherClass +/// type AliasThree = AliasOne | AliasTwo | int +/// ``` +/// +/// The type would resolve to `SomeClass | AnotherClass | int`, which is not +/// stored in the AST. +/// +/// **Important**: This function can only be called once infinite cycles have +/// been detected! Otherwise it'll stack overflow. +pub fn resolve_type_alias(field_type: &FieldType, db: &ParserDatabase) -> FieldType { + match field_type { + // For symbols we need to check if we're dealing with aliases. + FieldType::Symbol(arity, ident, attrs) => { + let Some(string_id) = db.interner.lookup(ident.name()) else { + unreachable!( + "Attempting to resolve alias `{ident}` that does not exist in the interner" + ); + }; + + let Some(top_id) = db.names.tops.get(&string_id) else { + unreachable!("Alias name `{ident}` is not registered in the context"); + }; + + match top_id { + ast::TopId::TypeAlias(alias_id) => { + let mut resolved = match db.types.resolved_type_aliases.get(alias_id) { + // Check if we can avoid deeper recursion. + Some(already_resolved) => already_resolved.to_owned(), + + // No luck, check if the type is resolvable. + None => { + // TODO: O(n) + if db + .structural_recursive_alias_cycles() + .iter() + .any(|cycle| cycle.contains(alias_id)) + { + // Not resolvable, part of a cycle. + field_type.to_owned() + } else { + // Maybe resolvable, recurse deeper. + resolve_type_alias(&db.ast[*alias_id].value, db) + } + } + }; + + // Sync arity. Basically stuff like: + // + // type AliasOne = SomeClass? + // type AliasTwo = AliasOne + // + // AliasTwo resolves to an "optional" type. + // + // TODO: Add a `set_arity` function or something and avoid + // this clone. + resolved = if resolved.is_optional() || arity.is_optional() { + resolved.to_nullable() + } else { + resolved + }; + + // Merge attributes. + resolved.set_attributes({ + let mut merged_attrs = Vec::from(field_type.attributes()); + merged_attrs.extend(resolved.attributes().to_owned()); + + merged_attrs + }); + + resolved + } + + // Class or enum. Already "resolved", pop off the stack. + _ => field_type.to_owned(), + } + } + + // Recurse and resolve each type individually. + FieldType::Union(arity, items, span, attrs) + | FieldType::Tuple(arity, items, span, attrs) => { + let resolved = items + .iter() + .map(|item| resolve_type_alias(item, db)) + .collect(); + + match field_type { + FieldType::Union(..) => { + FieldType::Union(*arity, resolved, span.clone(), attrs.clone()) + } + FieldType::Tuple(..) => { + FieldType::Tuple(*arity, resolved, span.clone(), attrs.clone()) + } + _ => unreachable!("should only match tuples and unions"), + } + } + + // Base case, primitives or other types that are not aliases. No more + // "pointers" and graphs here. + _ => field_type.to_owned(), + } +} + +fn visit_type_alias<'db>( + alias_id: ast::TypeAliasId, + assignment: &'db ast::Assignment, + ctx: &mut Context<'db>, +) { + // Insert the entry as soon as we get here then if we find something we'll + // add edges to the graph. Otherwise no edges but we still need the Vertex + // in order for the cycles algorithm to work. + let alias_refs = ctx + .types + .type_alias_dependencies + .entry(alias_id) + .or_default(); + + let mut stack = vec![&assignment.value]; + + while let Some(item) = stack.pop() { + match item { + FieldType::Symbol(_, ident, _) => { + let Some(string_id) = ctx.interner.lookup(ident.name()) else { + ctx.push_error(DatamodelError::new_validation_error( + &format!("Type alias points to unknown identifier `{ident}`"), + item.span().clone(), + )); + return; + }; + + let Some(top_id) = ctx.names.tops.get(&string_id) else { + ctx.push_error(DatamodelError::new_validation_error( + &format!("Type alias points to unknown identifier `{ident}`"), + item.span().clone(), + )); + return; + }; + + // Add alias to the graph. + if let ast::TopId::TypeAlias(nested_alias_id) = top_id { + alias_refs.insert(*nested_alias_id); + } + } + + FieldType::Union(_, items, ..) | FieldType::Tuple(_, items, ..) => { + stack.extend(items.iter()); + } + + FieldType::List(_, nested, ..) => { + stack.push(nested); + } + + FieldType::Map(_, nested, ..) => { + let (key, value) = nested.as_ref(); + stack.push(key); + stack.push(value); + } + + _ => {} + } + } +} + fn visit_function<'db>(idx: ValExpId, function: &'db ast::ValueExprBlock, ctx: &mut Context<'db>) { let input_deps = function .input() diff --git a/engine/baml-lib/parser-database/src/walkers/alias.rs b/engine/baml-lib/parser-database/src/walkers/alias.rs new file mode 100644 index 000000000..8ba968dfc --- /dev/null +++ b/engine/baml-lib/parser-database/src/walkers/alias.rs @@ -0,0 +1,26 @@ +use std::collections::HashSet; + +use super::TypeWalker; +use internal_baml_schema_ast::ast::{self, FieldType, Identifier, WithName}; + +pub type TypeAliasWalker<'db> = super::Walker<'db, ast::TypeAliasId>; + +impl<'db> TypeAliasWalker<'db> { + /// Name of the type alias. + pub fn name(&self) -> &str { + &self.db.ast[self.id].identifier.name() + } + + /// Returns the field type that the alias points to. + pub fn target(&self) -> &'db FieldType { + &self.db.ast[self.id].value + } + + /// Returns a "virtual" type that represents the fully resolved alias. + /// + /// Since an alias can point to other aliases we might have to create a + /// type that does not exist in the AST. + pub fn resolved(&self) -> &'db FieldType { + &self.db.types.resolved_type_aliases[&self.id] + } +} diff --git a/engine/baml-lib/parser-database/src/walkers/class.rs b/engine/baml-lib/parser-database/src/walkers/class.rs index 9bc8098f1..16f68567c 100644 --- a/engine/baml-lib/parser-database/src/walkers/class.rs +++ b/engine/baml-lib/parser-database/src/walkers/class.rs @@ -1,5 +1,6 @@ use std::collections::HashSet; +use super::TypeWalker; use super::{field::FieldWalker, EnumWalker}; use crate::types::Attributes; use baml_types::Constraint; @@ -9,7 +10,7 @@ use internal_baml_schema_ast::ast::SubType; use internal_baml_schema_ast::ast::{self, ArgumentId, WithIdentifier, WithName, WithSpan}; use std::collections::HashMap; -/// A `class` declaration in the Prisma schema. +/// Class walker with some helper methods to extract info from the parser DB. pub type ClassWalker<'db> = super::Walker<'db, ast::TypeExpId>; impl<'db> ClassWalker<'db> { @@ -42,9 +43,8 @@ impl<'db> ClassWalker<'db> { self.db.types.class_dependencies[&self.class_id()] .iter() .filter_map(|f| match self.db.find_type_by_str(f) { - Some(Either::Left(_cls)) => None, - Some(Either::Right(walker)) => Some(walker), - None => None, + Some(TypeWalker::Enum(walker)) => Some(walker), + _ => None, }) } @@ -53,9 +53,8 @@ impl<'db> ClassWalker<'db> { self.db.types.class_dependencies[&self.class_id()] .iter() .filter_map(|f| match self.db.find_type_by_str(f) { - Some(Either::Left(walker)) => Some(walker), - Some(Either::Right(_enm)) => None, - None => None, + Some(TypeWalker::Class(walker)) => Some(walker), + _ => None, }) } @@ -175,9 +174,8 @@ impl<'db> ArgWalker<'db> { input .iter() .filter_map(|f| match self.db.find_type_by_str(f) { - Some(Either::Left(_cls)) => None, - Some(Either::Right(walker)) => Some(walker), - None => None, + Some(TypeWalker::Enum(walker)) => Some(walker), + _ => None, }) } @@ -187,9 +185,8 @@ impl<'db> ArgWalker<'db> { input .iter() .filter_map(|f| match self.db.find_type_by_str(f) { - Some(Either::Left(walker)) => Some(walker), - Some(Either::Right(_enm)) => None, - None => None, + Some(TypeWalker::Class(walker)) => Some(walker), + _ => None, }) } } diff --git a/engine/baml-lib/parser-database/src/walkers/function.rs b/engine/baml-lib/parser-database/src/walkers/function.rs index bebeb1977..d2074183e 100644 --- a/engine/baml-lib/parser-database/src/walkers/function.rs +++ b/engine/baml-lib/parser-database/src/walkers/function.rs @@ -8,7 +8,7 @@ use crate::{ types::FunctionType, }; -use super::{ClassWalker, ConfigurationWalker, EnumWalker, Walker}; +use super::{ClassWalker, ConfigurationWalker, EnumWalker, TypeWalker, Walker}; use std::iter::ExactSizeIterator; @@ -221,9 +221,8 @@ impl<'db> ArgWalker<'db> { if self.id.1 { input } else { output } .iter() .filter_map(|f| match self.db.find_type_by_str(f) { - Some(Either::Left(_cls)) => None, - Some(Either::Right(walker)) => Some(walker), - None => None, + Some(TypeWalker::Enum(walker)) => Some(walker), + _ => None, }) } @@ -233,9 +232,8 @@ impl<'db> ArgWalker<'db> { if self.id.1 { input } else { output } .iter() .filter_map(|f| match self.db.find_type_by_str(f) { - Some(Either::Left(walker)) => Some(walker), - Some(Either::Right(_enm)) => None, - None => None, + Some(TypeWalker::Class(walker)) => Some(walker), + _ => None, }) } } diff --git a/engine/baml-lib/parser-database/src/walkers/mod.rs b/engine/baml-lib/parser-database/src/walkers/mod.rs index e1a7d5960..a6154a1cf 100644 --- a/engine/baml-lib/parser-database/src/walkers/mod.rs +++ b/engine/baml-lib/parser-database/src/walkers/mod.rs @@ -6,6 +6,7 @@ //! - Know about relations. //! - Do not know anything about connectors, they are generic. +mod alias; mod r#class; mod client; mod configuration; @@ -14,13 +15,16 @@ mod field; mod function; mod template_string; +use alias::TypeAliasWalker; use baml_types::TypeValue; pub use client::*; pub use configuration::*; use either::Either; pub use field::*; pub use function::FunctionWalker; -use internal_baml_schema_ast::ast::{FieldType, Identifier, TopId, TypeExpId, WithName}; +use internal_baml_schema_ast::ast::{ + FieldType, Identifier, TopId, TypeAliasId, TypeExpId, WithName, +}; pub use r#class::*; pub use r#enum::*; pub use template_string::TemplateStringWalker; @@ -50,11 +54,21 @@ where } } +/// Walker kind. +pub enum TypeWalker<'db> { + /// Class walker. + Class(ClassWalker<'db>), + /// Enum walker. + Enum(EnumWalker<'db>), + /// Type alias walker. + TypeAlias(TypeAliasWalker<'db>), +} + impl<'db> crate::ParserDatabase { /// Find an enum by name. pub fn find_enum(&'db self, idn: &Identifier) -> Option> { self.find_type(idn).and_then(|either| match either { - Either::Right(class) => Some(class), + TypeWalker::Enum(enm) => Some(enm), _ => None, }) } @@ -66,22 +80,19 @@ impl<'db> crate::ParserDatabase { } /// Find a type by name. - pub fn find_type_by_str( - &'db self, - name: &str, - ) -> Option, EnumWalker<'db>>> { + pub fn find_type_by_str(&'db self, name: &str) -> Option> { self.find_top_by_str(name).and_then(|top_id| match top_id { - TopId::Class(class_id) => Some(Either::Left(self.walk(*class_id))), - TopId::Enum(enum_id) => Some(Either::Right(self.walk(*enum_id))), + TopId::Class(class_id) => Some(TypeWalker::Class(self.walk(*class_id))), + TopId::Enum(enum_id) => Some(TypeWalker::Enum(self.walk(*enum_id))), + TopId::TypeAlias(type_alias_id) => { + Some(TypeWalker::TypeAlias(self.walk(*type_alias_id))) + } _ => None, }) } /// Find a type by name. - pub fn find_type( - &'db self, - idn: &Identifier, - ) -> Option, EnumWalker<'db>>> { + pub fn find_type(&'db self, idn: &Identifier) -> Option> { match idn { Identifier::Local(local, _) => self.find_type_by_str(local), _ => None, @@ -91,7 +102,7 @@ impl<'db> crate::ParserDatabase { /// Find a model by name. pub fn find_class(&'db self, idn: &Identifier) -> Option> { self.find_type(idn).and_then(|either| match either { - Either::Left(class) => Some(class), + TypeWalker::Class(class) => Some(class), _ => None, }) } @@ -133,6 +144,22 @@ impl<'db> crate::ParserDatabase { &self.types.finite_recursive_cycles } + /// Set of all aliases that are part of a structural cycle. + /// + /// A structural cycle is created through a map or list, which introduce one + /// level of indirection. + pub fn structural_recursive_alias_cycles(&self) -> &[Vec] { + &self.types.structural_recursive_alias_cycles + } + + /// Returns the resolved aliases map. + pub fn resolved_type_alias_by_name(&self, alias: &str) -> Option<&FieldType> { + match self.find_type_by_str(alias) { + Some(TypeWalker::TypeAlias(walker)) => Some(walker.resolved()), + _ => None, + } + } + /// Traverse a schema element by id. pub fn walk(&self, id: I) -> Walker<'_, I> { Walker { db: self, id } @@ -192,6 +219,17 @@ impl<'db> crate::ParserDatabase { }) } + /// Walk all the type aliases in the AST. + pub fn walk_type_aliases(&self) -> impl Iterator> { + self.ast() + .iter_tops() + .filter_map(|(top_id, _)| top_id.as_type_alias_id()) + .map(move |top_id| Walker { + db: self, + id: top_id, + }) + } + /// Walk all template strings in the schema. pub fn walk_templates(&self) -> impl Iterator> { self.ast() @@ -255,8 +293,9 @@ impl<'db> crate::ParserDatabase { FieldType::Symbol(arity, idn, ..) => { let mut t = match self.find_type(idn) { None => Type::Undefined, - Some(Either::Left(_)) => Type::ClassRef(idn.to_string()), - Some(Either::Right(_)) => Type::String, + Some(TypeWalker::Class(_)) => Type::ClassRef(idn.to_string()), + Some(TypeWalker::Enum(_)) => Type::String, + Some(TypeWalker::TypeAlias(_)) => Type::String, }; if arity.is_optional() { t = Type::None | t; diff --git a/engine/baml-lib/schema-ast/src/ast.rs b/engine/baml-lib/schema-ast/src/ast.rs index 1c27b7900..d34b3ecad 100644 --- a/engine/baml-lib/schema-ast/src/ast.rs +++ b/engine/baml-lib/schema-ast/src/ast.rs @@ -1,4 +1,5 @@ mod argument; +mod assignment; mod attribute; mod comment; @@ -19,6 +20,7 @@ mod value_expression_block; pub(crate) use self::comment::Comment; pub use argument::{Argument, ArgumentId, ArgumentsList}; +pub use assignment::Assignment; pub use attribute::{Attribute, AttributeContainer, AttributeId}; pub use config::ConfigBlockProperty; pub use expression::{Expression, RawString}; @@ -35,13 +37,14 @@ pub use value_expression_block::{BlockArg, BlockArgs, ValueExprBlock, ValueExprB /// AST representation of a prisma schema. /// -/// This module is used internally to represent an AST. The AST's nodes can be used -/// during validation of a schema, especially when implementing custom attributes. +/// This module is used internally to represent an AST. The AST's nodes can be +/// used during validation of a schema, especially when implementing custom +/// attributes. /// -/// The AST is not validated, also fields and attributes are not resolved. Every node is -/// annotated with its location in the text representation. -/// Basically, the AST is an object oriented representation of the datamodel's text. -/// Schema = Datamodel + Generators + Datasources +/// The AST is not validated, also fields and attributes are not resolved. Every +/// node is annotated with its location in the text representation. +/// Basically, the AST is an object oriented representation of the datamodel's +/// text. Schema = Datamodel + Generators + Datasources #[derive(Debug)] pub struct SchemaAst { /// All models, enums, composite types, datasources, generators and type aliases. @@ -99,6 +102,20 @@ impl std::ops::Index for SchemaAst { } } +/// An opaque identifier for a type alias in a schema AST. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct TypeAliasId(u32); + +impl std::ops::Index for SchemaAst { + type Output = Assignment; + + fn index(&self, index: TypeAliasId) -> &Self::Output { + self.tops[index.0 as usize] + .as_type_alias_assignment() + .expect("expected type expression") + } +} + /// An opaque identifier for a model in a schema AST. Use the /// `schema[model_id]` syntax to resolve the id to an `ast::Model`. #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] @@ -130,25 +147,28 @@ impl std::ops::Index for SchemaAst { /// syntax to resolve the id to an `ast::Top`. #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum TopId { - /// An enum declaration + /// An enum declaration. Enum(TypeExpId), - // A class declaration + /// A class declaration. Class(TypeExpId), - // A function declaration + /// A function declaration. Function(ValExpId), - // A client declaration + /// A type alias declaration. + TypeAlias(TypeAliasId), + + /// A client declaration. Client(ValExpId), - // A generator declaration + /// A generator declaration. Generator(ValExpId), - // Template Strings + /// Template Strings. TemplateString(TemplateStringId), - // A config block + /// A config block. TestCase(ValExpId), RetryPolicy(ValExpId), @@ -171,6 +191,14 @@ impl TopId { } } + /// Try to interpret the top as a type alias. + pub fn as_type_alias_id(self) -> Option { + match self { + TopId::TypeAlias(id) => Some(id), + _ => None, + } + } + /// Try to interpret the top as a function. pub fn as_function_id(self) -> Option { match self { @@ -215,6 +243,7 @@ impl std::ops::Index for SchemaAst { let idx = match index { TopId::Enum(TypeExpId(idx)) => idx, TopId::Class(TypeExpId(idx)) => idx, + TopId::TypeAlias(TypeAliasId(idx)) => idx, TopId::Function(ValExpId(idx)) => idx, TopId::TemplateString(TemplateStringId(idx)) => idx, TopId::Client(ValExpId(idx)) => idx, @@ -232,6 +261,7 @@ fn top_idx_to_top_id(top_idx: usize, top: &Top) -> TopId { Top::Enum(_) => TopId::Enum(TypeExpId(top_idx as u32)), Top::Class(_) => TopId::Class(TypeExpId(top_idx as u32)), Top::Function(_) => TopId::Function(ValExpId(top_idx as u32)), + Top::TypeAlias(_) => TopId::TypeAlias(TypeAliasId(top_idx as u32)), Top::Client(_) => TopId::Client(ValExpId(top_idx as u32)), Top::TemplateString(_) => TopId::TemplateString(TemplateStringId(top_idx as u32)), Top::Generator(_) => TopId::Generator(ValExpId(top_idx as u32)), diff --git a/engine/baml-lib/schema-ast/src/ast/assignment.rs b/engine/baml-lib/schema-ast/src/ast/assignment.rs new file mode 100644 index 000000000..bfa675645 --- /dev/null +++ b/engine/baml-lib/schema-ast/src/ast/assignment.rs @@ -0,0 +1,41 @@ +//! Assignment expressions. +//! +//! As of right now the only supported "assignments" are type aliases. + +use super::{ + traits::WithAttributes, Attribute, BlockArgs, Comment, Field, FieldType, Identifier, Span, + WithDocumentation, WithIdentifier, WithSpan, +}; + +/// Assignment expression. `left = right`. +#[derive(Debug, Clone)] +pub struct Assignment { + /// Left side of the assignment. + /// + /// For now this can only be an identifier, but if we end up needing to + /// support stuff like destructuring then change it to some sort of + /// expression. + pub identifier: Identifier, + + /// Right side of the assignment. + /// + /// Since for now it's only used for type aliases then it's just a type. + pub value: FieldType, + + /// Span of the entire assignment. + pub span: Span, +} + +impl WithSpan for Assignment { + fn span(&self) -> &Span { + &self.span + } +} + +// TODO: Right now the left side is always an identifier, but if it ends up +// being an expression we'll have to refactor this somehow. +impl WithIdentifier for Assignment { + fn identifier(&self) -> &Identifier { + &self.identifier + } +} diff --git a/engine/baml-lib/schema-ast/src/ast/attribute.rs b/engine/baml-lib/schema-ast/src/ast/attribute.rs index 18c05be40..544b07656 100644 --- a/engine/baml-lib/schema-ast/src/ast/attribute.rs +++ b/engine/baml-lib/schema-ast/src/ast/attribute.rs @@ -65,6 +65,7 @@ pub enum AttributeContainer { ClassField(super::TypeExpId, super::FieldId), Enum(super::TypeExpId), EnumValue(super::TypeExpId, super::FieldId), + TypeAlias(super::TypeAliasId), } impl From for AttributeContainer { @@ -79,6 +80,12 @@ impl From<(super::TypeExpId, super::FieldId)> for AttributeContainer { } } +impl From for AttributeContainer { + fn from(v: super::TypeAliasId) -> Self { + Self::TypeAlias(v) + } +} + /// An attribute (@ or @@) node in the AST. #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] pub struct AttributeId(AttributeContainer, u32); @@ -102,6 +109,7 @@ impl Index for super::SchemaAst { AttributeContainer::EnumValue(enum_id, value_idx) => { &self[enum_id][value_idx].attributes } + AttributeContainer::TypeAlias(alias_id) => &self[alias_id].value.attributes(), } } } diff --git a/engine/baml-lib/schema-ast/src/ast/top.rs b/engine/baml-lib/schema-ast/src/ast/top.rs index 595e27e54..dfc67285b 100644 --- a/engine/baml-lib/schema-ast/src/ast/top.rs +++ b/engine/baml-lib/schema-ast/src/ast/top.rs @@ -1,24 +1,26 @@ use super::{ - traits::WithSpan, Identifier, Span, TemplateString, TypeExpressionBlock, ValueExprBlock, - WithIdentifier, + assignment::Assignment, traits::WithSpan, Identifier, Span, TemplateString, + TypeExpressionBlock, ValueExprBlock, WithIdentifier, }; /// Enum for distinguishing between top-level entries #[derive(Debug, Clone)] pub enum Top { - /// An enum declaration + /// An enum declaration. Enum(TypeExpressionBlock), - // A class declaration + /// A class declaration. Class(TypeExpressionBlock), - // A function declaration + /// A function declaration. Function(ValueExprBlock), + /// Type alias expression. + TypeAlias(Assignment), - // Clients to run + /// Clients to run. Client(ValueExprBlock), TemplateString(TemplateString), - // Generator + /// Generator. Generator(ValueExprBlock), TestCase(ValueExprBlock), @@ -34,6 +36,7 @@ impl Top { Top::Enum(_) => "enum", Top::Class(_) => "class", Top::Function(_) => "function", + Top::TypeAlias(_) => "type_alias", Top::Client(_) => "client", Top::TemplateString(_) => "template_string", Top::Generator(_) => "generator", @@ -62,6 +65,13 @@ impl Top { } } + pub fn as_type_alias_assignment(&self) -> Option<&Assignment> { + match self { + Top::TypeAlias(assignment) => Some(assignment), + _ => None, + } + } + pub fn as_template_string(&self) -> Option<&TemplateString> { match self { Top::TemplateString(t) => Some(t), @@ -78,6 +88,7 @@ impl WithIdentifier for Top { Top::Enum(x) => x.identifier(), Top::Class(x) => x.identifier(), Top::Function(x) => x.identifier(), + Top::TypeAlias(x) => x.identifier(), Top::Client(x) => x.identifier(), Top::TemplateString(x) => x.identifier(), Top::Generator(x) => x.identifier(), @@ -93,6 +104,7 @@ impl WithSpan for Top { Top::Enum(en) => en.span(), Top::Class(class) => class.span(), Top::Function(func) => func.span(), + Top::TypeAlias(alias) => alias.span(), Top::TemplateString(template) => template.span(), Top::Client(client) => client.span(), Top::Generator(gen) => gen.span(), diff --git a/engine/baml-lib/schema-ast/src/parser/datamodel.pest b/engine/baml-lib/schema-ast/src/parser/datamodel.pest index 3cc615534..7fe6ec60e 100644 --- a/engine/baml-lib/schema-ast/src/parser/datamodel.pest +++ b/engine/baml-lib/schema-ast/src/parser/datamodel.pest @@ -78,7 +78,7 @@ single_word = @{ ASCII_ALPHA ~ (ASCII_ALPHANUMERIC | "_" | "-")* } // ###################################### // Type Alias // ###################################### -type_alias = { TYPE_KEYWORD ~ identifier ~ base_type ~ (NEWLINE? ~ field_attribute)* } +type_alias = { identifier ~ identifier ~ assignment ~ field_type_with_attr } // ###################################### // Arguments @@ -177,7 +177,6 @@ BLOCK_CLOSE = { "}" } BLOCK_LEVEL_CATCH_ALL = { !BLOCK_CLOSE ~ CATCH_ALL } CATCH_ALL = { (!NEWLINE ~ ANY)+ ~ NEWLINE? } -TYPE_KEYWORD = { "type" } FUNCTION_KEYWORD = { "function" } TEMPLATE_KEYWORD = { "template_string" | "string_template" } TEST_KEYWORD = { "test" } diff --git a/engine/baml-lib/schema-ast/src/parser/mod.rs b/engine/baml-lib/schema-ast/src/parser/mod.rs index 82ec4c6a7..5f9bf74de 100644 --- a/engine/baml-lib/schema-ast/src/parser/mod.rs +++ b/engine/baml-lib/schema-ast/src/parser/mod.rs @@ -1,5 +1,6 @@ mod helpers; mod parse_arguments; +mod parse_assignment; mod parse_attribute; mod parse_comments; mod parse_expression; diff --git a/engine/baml-lib/schema-ast/src/parser/parse_assignment.rs b/engine/baml-lib/schema-ast/src/parser/parse_assignment.rs new file mode 100644 index 000000000..a0b52efc6 --- /dev/null +++ b/engine/baml-lib/schema-ast/src/parser/parse_assignment.rs @@ -0,0 +1,152 @@ +use super::{ + helpers::{parsing_catch_all, Pair}, + parse_identifier::parse_identifier, + parse_named_args_list::parse_named_argument_list, + Rule, +}; + +use crate::{ + assert_correct_parser, + ast::*, + parser::{parse_field::parse_field_type_with_attr, parse_types::parse_field_type}, +}; + +use internal_baml_diagnostics::{DatamodelError, Diagnostics}; + +/// Parses an assignment in the form of `keyword identifier = FieldType`. +/// +/// It only works with type aliases for now, it's not generic over all +/// expressions. +pub(crate) fn parse_assignment(pair: Pair<'_>, diagnostics: &mut Diagnostics) -> Assignment { + assert_correct_parser!(pair, Rule::type_alias); + + let span = pair.as_span(); + + let mut consumed_definition_keyword = false; + + let mut identifier: Option = None; + let mut field_type: Option = None; + + for current in pair.into_inner() { + match current.as_rule() { + Rule::identifier => { + if !consumed_definition_keyword { + consumed_definition_keyword = true; + match current.as_str() { + "type" => {} // Ok, type alias. + + other => diagnostics.push_error(DatamodelError::new_validation_error( + &format!("Unexpected keyword used in assignment: {other}"), + diagnostics.span(current.as_span()), + )), + } + } else { + // There are two identifiers, the second one is the name of + // the type alias. + identifier = Some(parse_identifier(current, diagnostics)); + } + } + + Rule::assignment => {} // Ok, equal sign. + + // TODO: We probably only need field_type_with_attr since that's how + // the PEST syntax is defined. + Rule::field_type => field_type = parse_field_type(current, diagnostics), + + Rule::field_type_with_attr => { + field_type = parse_field_type_with_attr(current, false, diagnostics) + } + + _ => parsing_catch_all(current, "type_alias"), + } + } + + match (identifier, field_type) { + (Some(identifier), Some(field_type)) => Assignment { + identifier, + value: field_type, + span: diagnostics.span(span), + }, + + _ => panic!("Encountered impossible type_alias declaration during parsing"), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::parser::{BAMLParser, Rule}; + use baml_types::TypeValue; + use internal_baml_diagnostics::{Diagnostics, SourceFile}; + use pest::{consumes_to, fails_with, parses_to, Parser}; + + fn parse_type_alias(input: &'static str) -> Assignment { + let path = "test.baml"; + let source = SourceFile::new_static(path.into(), input); + + let mut diagnostics = Diagnostics::new(path.into()); + diagnostics.set_source(&source); + + let pairs = BAMLParser::parse(Rule::type_alias, input) + .unwrap() + .next() + .unwrap(); + + let assignment = super::parse_assignment(pairs, &mut diagnostics); + + // (assignment, diagnostics) + assignment + } + + #[test] + fn parse_type_alias_assignment_tokens() { + parses_to! { + parser: BAMLParser, + input: "type Test = int", + rule: Rule::type_alias, + tokens: [ + type_alias(0, 15, [ + identifier(0, 4, [single_word(0, 4)]), + identifier(5, 9, [single_word(5, 9)]), + assignment(10, 11), + field_type_with_attr(12, 15, [ + field_type(12, 15, [ + non_union(12, 15, [ + identifier(12, 15, [single_word(12, 15)]) + ]), + ]), + ]), + ]), + ] + } + + // This is parsed as identifier ~ identifier because of how Pest handles + // whitespaces. + // https://github.com/pest-parser/pest/discussions/967 + fails_with! { + parser: BAMLParser, + input: "typeTest = int", + rule: Rule::type_alias, + positives: [Rule::identifier], + negatives: [], + pos: 9 + } + } + + #[test] + fn parse_union_type_alias() { + let assignment = parse_type_alias("type Test = int | string"); + + assert_eq!(assignment.identifier.to_string(), "Test"); + + let FieldType::Union(_, elements, _, _) = assignment.value else { + panic!("Expected union type, got: {:?}", assignment.value); + }; + + let [FieldType::Primitive(_, TypeValue::Int, _, _), FieldType::Primitive(_, TypeValue::String, _, _)] = + elements.as_slice() + else { + panic!("Expected int | string union, got: {:?}", elements); + }; + } +} diff --git a/engine/baml-lib/schema-ast/src/parser/parse_schema.rs b/engine/baml-lib/schema-ast/src/parser/parse_schema.rs index d6a0ac850..b63fab373 100644 --- a/engine/baml-lib/schema-ast/src/parser/parse_schema.rs +++ b/engine/baml-lib/schema-ast/src/parser/parse_schema.rs @@ -1,7 +1,7 @@ use std::path::{Path, PathBuf}; use super::{ - parse_template_string::parse_template_string, + parse_assignment::parse_assignment, parse_template_string::parse_template_string, parse_type_expression_block::parse_type_expression_block, parse_value_expression_block::parse_value_expression_block, BAMLParser, Rule, }; @@ -77,20 +77,20 @@ pub fn parse_schema( &mut diagnostics, ); match val_expr { - Ok(val) => { - if let Some(top) = match val.block_type { - ValueExprBlockType::Function => Some(Top::Function(val)), - ValueExprBlockType::Test => Some(Top::TestCase(val)), - ValueExprBlockType::Client => Some(Top::Client(val)), - ValueExprBlockType::RetryPolicy => Some(Top::RetryPolicy(val)), - ValueExprBlockType::Generator => Some(Top::Generator(val)), - } { - top_level_definitions.push(top); - } - } + Ok(val) => top_level_definitions.push(match val.block_type { + ValueExprBlockType::Function => Top::Function(val), + ValueExprBlockType::Test => Top::TestCase(val), + ValueExprBlockType::Client => Top::Client(val), + ValueExprBlockType::RetryPolicy => Top::RetryPolicy(val), + ValueExprBlockType::Generator => Top::Generator(val), + }), Err(e) => diagnostics.push_error(e), } } + Rule::type_alias => { + let assignment = parse_assignment(current, &mut diagnostics); + top_level_definitions.push(Top::TypeAlias(assignment)); + } Rule::template_declaration => { match parse_template_string( @@ -184,7 +184,9 @@ mod tests { use std::path::Path; use super::parse_schema; - use crate::ast::*; // Add this line to import the ast module + use crate::ast::*; + use baml_types::TypeValue; + // Add this line to import the ast module use internal_baml_diagnostics::SourceFile; #[test] @@ -356,4 +358,40 @@ mod tests { } } } + + #[test] + fn test_push_type_aliases() { + let input = "type One = int\ntype Two = string | One"; + + let path = "example_file.baml"; + let source = SourceFile::new_static(path.into(), input); + + let (ast, _) = parse_schema(&Path::new(path), &source).unwrap(); + + let [Top::TypeAlias(one), Top::TypeAlias(two)] = ast.tops.as_slice() else { + panic!( + "Expected two type aliases (type One, type Two), got: {:?}", + ast.tops + ); + }; + + assert_eq!(one.identifier.to_string(), "One"); + assert!(matches!( + one.value, + FieldType::Primitive(_, TypeValue::Int, _, _) + )); + + assert_eq!(two.identifier.to_string(), "Two"); + let FieldType::Union(_, elements, _, _) = &two.value else { + panic!("Expected union type (string | One), got: {:?}", two.value); + }; + + let [FieldType::Primitive(_, TypeValue::String, _, _), FieldType::Symbol(_, alias, _)] = + elements.as_slice() + else { + panic!("Expected union type (string | One), got: {:?}", two.value); + }; + + assert_eq!(alias.to_string(), "One"); + } } diff --git a/engine/baml-runtime/src/internal/prompt_renderer/render_output_format.rs b/engine/baml-runtime/src/internal/prompt_renderer/render_output_format.rs index 892e13a06..069ced74e 100644 --- a/engine/baml-runtime/src/internal/prompt_renderer/render_output_format.rs +++ b/engine/baml-runtime/src/internal/prompt_renderer/render_output_format.rs @@ -2,7 +2,7 @@ use std::collections::HashSet; use anyhow::Result; use baml_types::BamlValue; -use indexmap::IndexSet; +use indexmap::{IndexMap, IndexSet}; use internal_baml_core::ir::{ repr::IntermediateRepr, ClassWalker, EnumWalker, FieldType, IRHelper, }; @@ -18,12 +18,14 @@ pub fn render_output_format( ctx: &RuntimeContext, output: &FieldType, ) -> Result { - let (enums, classes, recursive_classes) = relevant_data_models(ir, output, ctx)?; + let (enums, classes, recursive_classes, structural_recursive_aliases) = + relevant_data_models(ir, output, ctx)?; Ok(OutputFormatContent::target(output.clone()) .enums(enums) .classes(classes) .recursive_classes(recursive_classes) + .structural_recursive_aliases(structural_recursive_aliases) .build()) } @@ -210,11 +212,17 @@ fn relevant_data_models<'a>( ir: &'a IntermediateRepr, output: &'a FieldType, ctx: &RuntimeContext, -) -> Result<(Vec, Vec, IndexSet)> { +) -> Result<( + Vec, + Vec, + IndexSet, + IndexMap, +)> { let mut checked_types = HashSet::new(); let mut enums = Vec::new(); let mut classes = Vec::new(); let mut recursive_classes = IndexSet::new(); + let mut structural_recursive_aliases = IndexMap::new(); let mut start: Vec = vec![output.clone()]; let eval_ctx = ctx.eval_ctx(false); @@ -365,9 +373,27 @@ fn relevant_data_models<'a>( constraints, }); } else { + // TODO: @antonio This one was nasty! If aliases are not + // resolved in the `ir.finite_recursive_cycles()` function + // then an alias that points to a recursive class will get + // resolved below and then this code will run, introducing + // a recursive class in the relevant data models that does + // not exist in the IR although it should!. Now it's been + // fixed so this should be safe to remove, it wasn't even + // a bug it was "why is this working when IT SHOULD NOT". recursive_classes.insert(cls.to_owned()); } } + (FieldType::RecursiveTypeAlias(name), _) => { + // TODO: Same O(n) problem as above. + for cycle in ir.structural_recursive_alias_cycles() { + if cycle.contains_key(name) { + for (alias, target) in cycle.iter() { + structural_recursive_aliases.insert(alias.to_owned(), target.clone()); + } + } + } + } (FieldType::Literal(_), _) => {} (FieldType::Primitive(_), _) => {} (FieldType::Constrained { .. }, _) => { @@ -376,7 +402,12 @@ fn relevant_data_models<'a>( } } - Ok((enums, classes, recursive_classes)) + Ok(( + enums, + classes, + recursive_classes, + structural_recursive_aliases, + )) } #[cfg(test)] diff --git a/engine/baml-runtime/tests/test_runtime.rs b/engine/baml-runtime/tests/test_runtime.rs index e10a957a4..97e9d07ac 100644 --- a/engine/baml-runtime/tests/test_runtime.rs +++ b/engine/baml-runtime/tests/test_runtime.rs @@ -498,4 +498,106 @@ test TestTree { Ok(()) } + + #[test] + fn test_constrained_type_alias() -> anyhow::Result<()> { + let runtime = make_test_runtime( + r##" +class Foo2 { + bar int + baz string + sub Subthing @assert( {{ this.bar == 10}} ) | null +} + +class Foo3 { + bar int + baz string + sub Foo3 | null +} + +type Subthing = Foo2 @assert( {{ this.bar == 10 }}) + +function RunFoo2(input: Foo3) -> Foo2 { + client "openai/gpt-4o" + prompt #"Generate a Foo2 wrapping 30. Use {{ input }}. + {{ ctx.output_format }} + "# +} + +test RunFoo2Test { + functions [RunFoo2] + args { + input { + bar 30 + baz "hello" + sub null + } + } +} + "##, + )?; + + let ctx = runtime + .create_ctx_manager(BamlValue::String("test".to_string()), None) + .create_ctx_with_default(); + + let function_name = "RunFoo2"; + let test_name = "RunFoo2Test"; + let params = runtime.get_test_params(function_name, test_name, &ctx, true)?; + let render_prompt_future = + runtime + .internal() + .render_prompt(function_name, &ctx, ¶ms, None); + let (prompt, scope, _) = runtime.async_runtime.block_on(render_prompt_future)?; + + Ok(()) + } + + #[test] + fn test_recursive_alias_cycle() -> anyhow::Result<()> { + let runtime = make_test_runtime( + r##" +type RecAliasOne = RecAliasTwo +type RecAliasTwo = RecAliasThree +type RecAliasThree = RecAliasOne[] + +function RecursiveAliasCycle(input: RecAliasOne) -> RecAliasOne { + client "openai/gpt-4o" + prompt r#" + Return the given value: + + {{ input }} + + {{ ctx.output_format }} + "# +} + +test RecursiveAliasCycle { + functions [RecursiveAliasCycle] + args { + input [ + [] + [] + [[], []] + ] + } +} + "##, + )?; + + let ctx = runtime + .create_ctx_manager(BamlValue::String("test".to_string()), None) + .create_ctx_with_default(); + + let function_name = "RecursiveAliasCycle"; + let test_name = "RecursiveAliasCycle"; + let params = runtime.get_test_params(function_name, test_name, &ctx, true)?; + let render_prompt_future = + runtime + .internal() + .render_prompt(function_name, &ctx, ¶ms, None); + let (prompt, scope, _) = runtime.async_runtime.block_on(render_prompt_future)?; + + Ok(()) + } } diff --git a/engine/baml-schema-wasm/src/runtime_wasm/mod.rs b/engine/baml-schema-wasm/src/runtime_wasm/mod.rs index edf84ffbc..614b19fbc 100644 --- a/engine/baml-schema-wasm/src/runtime_wasm/mod.rs +++ b/engine/baml-schema-wasm/src/runtime_wasm/mod.rs @@ -876,6 +876,7 @@ fn get_dummy_value( baml_runtime::FieldType::Literal(_) => None, baml_runtime::FieldType::Enum(_) => None, baml_runtime::FieldType::Class(_) => None, + baml_runtime::FieldType::RecursiveTypeAlias(_) => None, baml_runtime::FieldType::List(item) => { let dummy = get_dummy_value(indent + 1, allow_multiline, item); // Repeat it 2 times diff --git a/engine/baml-schema-wasm/tests/test_file_manager.rs b/engine/baml-schema-wasm/tests/test_file_manager.rs index 8de0e0743..6069e6091 100644 --- a/engine/baml-schema-wasm/tests/test_file_manager.rs +++ b/engine/baml-schema-wasm/tests/test_file_manager.rs @@ -229,4 +229,112 @@ test Two { assert!(diagnostics.errors().is_empty()); } + + #[wasm_bindgen_test] + fn test_type_alias_pointing_to_unknown_identifier() { + wasm_logger::init(wasm_logger::Config::new(log::Level::Info)); + let sample_baml_content = r##" + type Foo = i + "##; + let mut files = HashMap::new(); + files.insert("error.baml".to_string(), sample_baml_content.to_string()); + let files_js = to_value(&files).unwrap(); + let project = WasmProject::new("baml_src", files_js) + .map_err(JsValue::from) + .unwrap(); + + let env_vars = [("OPENAI_API_KEY", "12345")] + .iter() + .cloned() + .collect::>(); + let env_vars_js = to_value(&env_vars).unwrap(); + + let Err(js_error) = project.runtime(env_vars_js) else { + panic!("Expected error, got Ok"); + }; + + assert!(js_error.is_object()); + + // TODO: Don't know how to build Object + // assert_eq!( + // js_error, + // serde_wasm_bindgen::to_value::>>(&HashMap::from_iter([( + // "all_files".to_string(), + // vec!["error.baml".to_string()] + // )])) + // .unwrap() + // ); + } + + #[wasm_bindgen_test] + fn test_type_alias_pointing_to_union_with_unknown_identifier() { + wasm_logger::init(wasm_logger::Config::new(log::Level::Info)); + let sample_baml_content = r##" + type Foo = int | i + "##; + let mut files = HashMap::new(); + files.insert("error.baml".to_string(), sample_baml_content.to_string()); + let files_js = to_value(&files).unwrap(); + let project = WasmProject::new("baml_src", files_js) + .map_err(JsValue::from) + .unwrap(); + + let env_vars = [("OPENAI_API_KEY", "12345")] + .iter() + .cloned() + .collect::>(); + let env_vars_js = to_value(&env_vars).unwrap(); + + let Err(js_error) = project.runtime(env_vars_js) else { + panic!("Expected error, got Ok"); + }; + + assert!(js_error.is_object()); + + // TODO: Don't know how to build Object + // assert_eq!( + // js_error, + // serde_wasm_bindgen::to_value::>>(&HashMap::from_iter([( + // "all_files".to_string(), + // vec!["error.baml".to_string()] + // )])) + // .unwrap() + // ); + } + + #[wasm_bindgen_test] + fn test_type_alias_pointing_to_union_with_unknown_identifier_in_union() { + wasm_logger::init(wasm_logger::Config::new(log::Level::Info)); + let sample_baml_content = r##" + type Four = int | string | b + "##; + let mut files = HashMap::new(); + files.insert("error.baml".to_string(), sample_baml_content.to_string()); + let files_js = to_value(&files).unwrap(); + let project = WasmProject::new("baml_src", files_js) + .map_err(JsValue::from) + .unwrap(); + + let env_vars = [("OPENAI_API_KEY", "12345")] + .iter() + .cloned() + .collect::>(); + let env_vars_js = to_value(&env_vars).unwrap(); + + let Err(js_error) = project.runtime(env_vars_js) else { + panic!("Expected error, got Ok"); + }; + + assert!(js_error.is_object()); + + // TODO: Don't know how to build Object + // assert_eq!( + // js_error, + // serde_wasm_bindgen::to_value::>>(&HashMap::from_iter([( + // "all_files".to_string(), + // vec!["error.baml".to_string()] + // )])) + // .unwrap() + // ); + } } diff --git a/engine/language_client_codegen/src/openapi.rs b/engine/language_client_codegen/src/openapi.rs index c6f5e3ec6..9fb2cca91 100644 --- a/engine/language_client_codegen/src/openapi.rs +++ b/engine/language_client_codegen/src/openapi.rs @@ -539,6 +539,15 @@ impl<'ir> ToTypeReferenceInTypeDefinition<'ir> for FieldType { r#ref: format!("#/components/schemas/{}", name), }, }, + FieldType::RecursiveTypeAlias(_) => TypeSpecWithMeta { + meta: TypeMetadata { + title: None, + r#enum: None, + r#const: None, + nullable: false, + }, + type_spec: TypeSpec::AnyValue { any_of: vec![] }, + }, FieldType::Literal(v) => TypeSpecWithMeta { meta: TypeMetadata { title: None, @@ -564,8 +573,14 @@ impl<'ir> ToTypeReferenceInTypeDefinition<'ir> for FieldType { }), }, FieldType::Map(key, value) => { - if !matches!(**key, FieldType::Primitive(TypeValue::String)) { - anyhow::bail!("BAML<->OpenAPI only supports string keys in maps") + if !matches!( + **key, + FieldType::Primitive(TypeValue::String) + | FieldType::Enum(_) + | FieldType::Literal(LiteralValue::String(_)) + | FieldType::Union(_) + ) { + anyhow::bail!(format!("BAML<->OpenAPI only supports strings, enums and literal strings as map keys but got {key}")) } TypeSpecWithMeta { meta: TypeMetadata { @@ -704,6 +719,10 @@ enum TypeSpec { #[serde(rename = "oneOf", alias = "oneOf")] one_of: Vec, }, + AnyValue { + #[serde(rename = "anyOf", alias = "anyOf")] + any_of: Vec, + }, } #[derive(Clone, Debug, Serialize)] diff --git a/engine/language_client_codegen/src/python/generate_types.rs b/engine/language_client_codegen/src/python/generate_types.rs index 40cfb9c33..428b25d70 100644 --- a/engine/language_client_codegen/src/python/generate_types.rs +++ b/engine/language_client_codegen/src/python/generate_types.rs @@ -7,7 +7,7 @@ use crate::{field_type_attributes, type_check_attributes, TypeCheckAttributes}; use super::python_language_features::ToPython; use internal_baml_core::ir::{ - repr::{Docstring, IntermediateRepr}, + repr::{Docstring, IntermediateRepr, Walker}, ClassWalker, EnumWalker, FieldType, IRHelper, }; @@ -16,6 +16,7 @@ use internal_baml_core::ir::{ pub(crate) struct PythonTypes<'ir> { enums: Vec>, classes: Vec>, + structural_recursive_alias_cycles: Vec>, } #[derive(askama::Template)] @@ -41,6 +42,11 @@ struct PythonClass<'ir> { dynamic: bool, } +struct PythonTypeAlias<'ir> { + name: Cow<'ir, str>, + target: String, +} + #[derive(askama::Template)] #[template(path = "partial_types.py.j2", escape = "none")] pub(crate) struct PythonStreamTypes<'ir> { @@ -66,6 +72,10 @@ impl<'ir> TryFrom<(&'ir IntermediateRepr, &'_ crate::GeneratorArgs)> for PythonT Ok(PythonTypes { enums: ir.walk_enums().map(PythonEnum::from).collect::>(), classes: ir.walk_classes().map(PythonClass::from).collect::>(), + structural_recursive_alias_cycles: ir + .walk_alias_cycles() + .map(PythonTypeAlias::from) + .collect::>(), }) } } @@ -126,6 +136,21 @@ impl<'ir> From> for PythonClass<'ir> { } } +// TODO: Define AliasWalker to simplify type. +impl<'ir> From> for PythonTypeAlias<'ir> { + fn from( + Walker { + db, + item: (name, target), + }: Walker<(&'ir String, &'ir FieldType)>, + ) -> Self { + PythonTypeAlias { + name: Cow::Borrowed(name), + target: target.to_type_ref(db), + } + } +} + impl<'ir> TryFrom<(&'ir IntermediateRepr, &'_ crate::GeneratorArgs)> for PythonStreamTypes<'ir> { type Error = anyhow::Error; @@ -219,6 +244,7 @@ impl ToTypeReferenceInTypeDefinition for FieldType { format!("\"{name}\"") } } + FieldType::RecursiveTypeAlias(name) => format!("\"{name}\""), FieldType::Literal(value) => to_python_literal(value), FieldType::Class(name) => format!("\"{name}\""), FieldType::List(inner) => format!("List[{}]", inner.to_type_ref(ir)), @@ -274,6 +300,13 @@ impl ToTypeReferenceInTypeDefinition for FieldType { format!("Optional[types.{name}]") } } + FieldType::RecursiveTypeAlias(name) => { + if wrapped { + format!("\"{name}\"") + } else { + format!("Optional[\"{name}\"]") + } + } FieldType::Literal(value) => to_python_literal(value), FieldType::List(inner) => format!("List[{}]", inner.to_partial_type_ref(ir, true)), FieldType::Map(key, value) => { diff --git a/engine/language_client_codegen/src/python/mod.rs b/engine/language_client_codegen/src/python/mod.rs index 2eda9b71f..06fbbe7d8 100644 --- a/engine/language_client_codegen/src/python/mod.rs +++ b/engine/language_client_codegen/src/python/mod.rs @@ -201,6 +201,7 @@ impl ToTypeReferenceInClientDefinition for FieldType { } } FieldType::Literal(value) => to_python_literal(value), + FieldType::RecursiveTypeAlias(name) => format!("types.{name}"), FieldType::Class(name) => format!("types.{name}"), FieldType::List(inner) => format!("List[{}]", inner.to_type_ref(ir, _with_checked)), FieldType::Map(key, value) => { @@ -255,6 +256,7 @@ impl ToTypeReferenceInClientDefinition for FieldType { } } FieldType::Class(name) => format!("partial_types.{name}"), + FieldType::RecursiveTypeAlias(name) => format!("types.{name}"), FieldType::Literal(value) => to_python_literal(value), FieldType::List(inner) => { format!("List[{}]", inner.to_partial_type_ref(ir, with_checked)) diff --git a/engine/language_client_codegen/src/python/templates/types.py.j2 b/engine/language_client_codegen/src/python/templates/types.py.j2 index 86b776db3..bbc7a7e9a 100644 --- a/engine/language_client_codegen/src/python/templates/types.py.j2 +++ b/engine/language_client_codegen/src/python/templates/types.py.j2 @@ -2,7 +2,7 @@ import baml_py from enum import Enum from pydantic import BaseModel, ConfigDict -from typing import Dict, Generic, List, Literal, Optional, TypeVar, Union +from typing import Dict, Generic, List, Literal, Optional, TypeVar, Union, TypeAlias T = TypeVar('T') @@ -59,3 +59,8 @@ class {{cls.name}}(BaseModel): {%- endif %} {%- endfor %} {% endfor %} + +{#- Type Aliases -#} +{% for alias in structural_recursive_alias_cycles %} +{{alias.name}}: TypeAlias = {{alias.target}} +{% endfor %} diff --git a/engine/language_client_codegen/src/ruby/field_type.rs b/engine/language_client_codegen/src/ruby/field_type.rs index cb2115910..c6cdba590 100644 --- a/engine/language_client_codegen/src/ruby/field_type.rs +++ b/engine/language_client_codegen/src/ruby/field_type.rs @@ -9,6 +9,9 @@ impl ToRuby for FieldType { match self { FieldType::Class(name) => format!("Baml::Types::{}", name.clone()), FieldType::Enum(name) => format!("T.any(Baml::Types::{}, String)", name.clone()), + // Sorbet does not support recursive type aliases. + // https://sorbet.org/docs/type-aliases + FieldType::RecursiveTypeAlias(_name) => "T.anything".to_string(), // TODO: Temporary solution until we figure out Ruby literals. FieldType::Literal(value) => value.literal_base_type().to_ruby(), // https://sorbet.org/docs/stdlib-generics diff --git a/engine/language_client_codegen/src/ruby/generate_types.rs b/engine/language_client_codegen/src/ruby/generate_types.rs index 55419ba0f..495b7e924 100644 --- a/engine/language_client_codegen/src/ruby/generate_types.rs +++ b/engine/language_client_codegen/src/ruby/generate_types.rs @@ -168,6 +168,8 @@ impl ToTypeReferenceInTypeDefinition for FieldType { match self { FieldType::Class(name) => format!("Baml::PartialTypes::{}", name.clone()), FieldType::Enum(name) => format!("T.nilable(Baml::Types::{})", name.clone()), + // TODO: Can we define recursive aliases in Ruby with Sorbet? + FieldType::RecursiveTypeAlias(_name) => "T.anything".to_string(), // TODO: Temporary solution until we figure out Ruby literals. FieldType::Literal(value) => value.literal_base_type().to_partial_type_ref(), // https://sorbet.org/docs/stdlib-generics diff --git a/engine/language_client_codegen/src/typescript/generate_types.rs b/engine/language_client_codegen/src/typescript/generate_types.rs index 0022f110a..195764bc5 100644 --- a/engine/language_client_codegen/src/typescript/generate_types.rs +++ b/engine/language_client_codegen/src/typescript/generate_types.rs @@ -4,8 +4,8 @@ use anyhow::Result; use itertools::Itertools; use internal_baml_core::ir::{ - repr::{Docstring, IntermediateRepr}, - ClassWalker, EnumWalker, + repr::{Docstring, IntermediateRepr, Walker}, + ClassWalker, EnumWalker, FieldType, }; use crate::{type_check_attributes, GeneratorArgs, TypeCheckAttributes}; @@ -24,6 +24,7 @@ pub(crate) struct TypeBuilder<'ir> { pub(crate) struct TypescriptTypes<'ir> { enums: Vec>, classes: Vec>, + structural_recursive_alias_cycles: Vec>, } struct TypescriptEnum<'ir> { @@ -40,6 +41,11 @@ pub struct TypescriptClass<'ir> { pub docstring: Option, } +struct TypescriptTypeAlias<'ir> { + name: Cow<'ir, str>, + target: String, +} + impl<'ir> TryFrom<(&'ir IntermediateRepr, &'ir GeneratorArgs)> for TypescriptTypes<'ir> { type Error = anyhow::Error; @@ -55,6 +61,10 @@ impl<'ir> TryFrom<(&'ir IntermediateRepr, &'ir GeneratorArgs)> for TypescriptTyp .walk_classes() .map(|e| Into::::into(&e)) .collect::>(), + structural_recursive_alias_cycles: ir + .walk_alias_cycles() + .map(TypescriptTypeAlias::from) + .collect::>(), }) } } @@ -132,6 +142,21 @@ impl<'ir> From<&ClassWalker<'ir>> for TypescriptClass<'ir> { } } +// TODO: Define AliasWalker to simplify type. +impl<'ir> From> for TypescriptTypeAlias<'ir> { + fn from( + Walker { + db, + item: (name, target), + }: Walker<(&'ir String, &'ir FieldType)>, + ) -> Self { + Self { + name: Cow::Borrowed(name), + target: target.to_type_ref(db), + } + } +} + pub fn type_name_for_checks(checks: &TypeCheckAttributes) -> String { checks .0 diff --git a/engine/language_client_codegen/src/typescript/mod.rs b/engine/language_client_codegen/src/typescript/mod.rs index d4ffcb93e..e3178534f 100644 --- a/engine/language_client_codegen/src/typescript/mod.rs +++ b/engine/language_client_codegen/src/typescript/mod.rs @@ -267,6 +267,7 @@ impl ToTypeReferenceInClientDefinition for FieldType { } } FieldType::Class(name) => name.to_string(), + FieldType::RecursiveTypeAlias(name) => name.to_owned(), FieldType::List(inner) => match inner.as_ref() { FieldType::Union(_) | FieldType::Optional(_) => { format!("({})[]", inner.to_type_ref(ir)) diff --git a/engine/language_client_codegen/src/typescript/templates/types.ts.j2 b/engine/language_client_codegen/src/typescript/templates/types.ts.j2 index 91ab34165..308967604 100644 --- a/engine/language_client_codegen/src/typescript/templates/types.ts.j2 +++ b/engine/language_client_codegen/src/typescript/templates/types.ts.j2 @@ -52,3 +52,8 @@ export interface {{cls.name}} { {%- endif %} } {% endfor %} + +{#- Type Aliases -#} +{% for alias in structural_recursive_alias_cycles %} +type {{alias.name}} = {{alias.target}} +{% endfor %} diff --git a/fern/03-reference/baml/types.mdx b/fern/03-reference/baml/types.mdx index 6d97f1893..e4b4cf624 100644 --- a/fern/03-reference/baml/types.mdx +++ b/fern/03-reference/baml/types.mdx @@ -356,6 +356,22 @@ map<"A" | "B" | "C", string> - Not yet supported. Use a `class` instead. +## Type Aliases + +A _type alias_ is an alternative name for an existing type. It can be used to +simplify complex types or to give a more descriptive name to a type. Type +aliases are defined using the `type` keyword: + +```baml +type Graph = map +``` + +Type aliases can point to other aliases: + +```baml +type DataStructure = string[] | Graph +``` + ## Examples and Equivalents Here are some examples and what their equivalents are in different languages. diff --git a/integ-tests/baml_src/test-files/functions/output/literal-unions.baml b/integ-tests/baml_src/test-files/functions/output/literal-unions.baml index 0712b6f83..a3dbd0d0c 100644 --- a/integ-tests/baml_src/test-files/functions/output/literal-unions.baml +++ b/integ-tests/baml_src/test-files/functions/output/literal-unions.baml @@ -1,7 +1,7 @@ function LiteralUnionsTest(input: string) -> 1 | true | "string output" { client GPT35 prompt #" - Return one of these values: + Return one of these values without any additional context: {{ctx.output_format}} "# } diff --git a/integ-tests/baml_src/test-files/functions/output/recursive-type-aliases.baml b/integ-tests/baml_src/test-files/functions/output/recursive-type-aliases.baml new file mode 100644 index 000000000..d308bd449 --- /dev/null +++ b/integ-tests/baml_src/test-files/functions/output/recursive-type-aliases.baml @@ -0,0 +1,54 @@ +class LinkedListAliasNode { + value int + next LinkedListAliasNode? +} + +// Simple alias that points to recursive type. +type LinkedListAlias = LinkedListAliasNode + +function AliasThatPointsToRecursiveType(list: LinkedListAlias) -> LinkedListAlias { + client "openai/gpt-4o" + prompt r#" + Return the given linked list back: + + {{ list }} + + {{ ctx.output_format }} + "# +} + +// Class that points to an alias that points to a recursive type. +class ClassToRecAlias { + list LinkedListAlias +} + +function ClassThatPointsToRecursiveClassThroughAlias(cls: ClassToRecAlias) -> ClassToRecAlias { + client "openai/gpt-4o" + prompt r#" + Return the given object back: + + {{ cls }} + + {{ ctx.output_format }} + "# +} + +// This is tricky cause this class should be hoisted, but classes and aliases +// are two different types in the AST. This test will make sure they can interop. +class NodeWithAliasIndirection { + value int + next NodeIndirection? +} + +type NodeIndirection = NodeWithAliasIndirection + +function RecursiveClassWithAliasIndirection(cls: NodeWithAliasIndirection) -> NodeWithAliasIndirection { + client "openai/gpt-4o" + prompt r#" + Return the given object back: + + {{ cls }} + + {{ ctx.output_format }} + "# +} diff --git a/integ-tests/baml_src/test-files/functions/output/type-aliases.baml b/integ-tests/baml_src/test-files/functions/output/type-aliases.baml new file mode 100644 index 000000000..2e407303d --- /dev/null +++ b/integ-tests/baml_src/test-files/functions/output/type-aliases.baml @@ -0,0 +1,136 @@ +type Primitive = int | string | bool | float + +type List = string[] + +type Graph = map + +type Combination = Primitive | List | Graph + +function PrimitiveAlias(p: Primitive) -> Primitive { + client "openai/gpt-4o" + prompt r#" + Return the given value back: {{ p }} + "# +} + +function MapAlias(m: Graph) -> Graph { + client "openai/gpt-4o" + prompt r#" + Return the given Graph back: + + {{ m }} + + {{ ctx.output_format }} + "# +} + +function NestedAlias(c: Combination) -> Combination { + client "openai/gpt-4o" + prompt r#" + Return the given value back: + + {{ c }} + + {{ ctx.output_format }} + "# +} + +// Test attribute merging. +type Currency = int @check(gt_ten, {{ this > 10 }}) +type Amount = Currency @assert({{ this > 0 }}) + +class MergeAttrs { + amount Amount @description("In USD") +} + +// This should be allowed. +type MultipleAttrs = int @assert({{ this > 0 }}) @check(gt_ten, {{ this > 10 }}) + +function MergeAliasAttributes(money: int) -> MergeAttrs { + client "openai/gpt-4o" + prompt r#" + Return the given integer in the specified format: + + {{ money }} + + {{ ctx.output_format }} + "# +} + +function ReturnAliasWithMergedAttributes(money: Amount) -> Amount { + client "openai/gpt-4o" + prompt r#" + Return the given integer without additional context: + + {{ money }} + + {{ ctx.output_format }} + "# +} + +function AliasWithMultipleAttrs(money: MultipleAttrs) -> MultipleAttrs { + client "openai/gpt-4o" + prompt r#" + Return the given integer without additional context: + + {{ money }} + + {{ ctx.output_format }} + "# +} + +type RecursiveMapAlias = map + +function SimpleRecursiveMapAlias(input: RecursiveMapAlias) -> RecursiveMapAlias { + client "openai/gpt-4o" + prompt r#" + Return the given value: + + {{ input }} + + {{ ctx.output_format }} + "# +} + +type RecursiveListAlias = RecursiveListAlias[] + +function SimpleRecursiveListAlias(input: RecursiveListAlias) -> RecursiveListAlias { + client "openai/gpt-4o" + prompt r#" + Return the given JSON array: + + {{ input }} + + {{ ctx.output_format }} + "# +} + +type RecAliasOne = RecAliasTwo +type RecAliasTwo = RecAliasThree +type RecAliasThree = RecAliasOne[] + +function RecursiveAliasCycle(input: RecAliasOne) -> RecAliasOne { + client "openai/gpt-4o" + prompt r#" + Return the given JSON array: + + {{ input }} + + {{ ctx.output_format }} + "# +} + +type JsonValue = int | string | bool | float | JsonObject | JsonArray +type JsonObject = map +type JsonArray = JsonValue[] + +function JsonTypeAliasCycle(input: JsonValue) -> JsonValue { + client "openai/gpt-4o" + prompt r#" + Return the given input back: + + {{ input }} + + {{ ctx.output_format }} + "# +} diff --git a/integ-tests/python/baml_client/async_client.py b/integ-tests/python/baml_client/async_client.py index 86c8518d2..4228a3f81 100644 --- a/integ-tests/python/baml_client/async_client.py +++ b/integ-tests/python/baml_client/async_client.py @@ -73,6 +73,52 @@ async def AaaSamOutputFormat( ) return cast(types.Recipe, raw.cast_to(types, types)) + async def AliasThatPointsToRecursiveType( + self, + list: types.LinkedListAliasNode, + baml_options: BamlCallOptions = {}, + ) -> types.LinkedListAliasNode: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = await self.__runtime.call_function( + "AliasThatPointsToRecursiveType", + { + "list": list, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + return cast(types.LinkedListAliasNode, raw.cast_to(types, types)) + + async def AliasWithMultipleAttrs( + self, + money: Checked[int,types.Literal["gt_ten"]], + baml_options: BamlCallOptions = {}, + ) -> Checked[int,types.Literal["gt_ten"]]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = await self.__runtime.call_function( + "AliasWithMultipleAttrs", + { + "money": money, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + return cast(Checked[int,types.Literal["gt_ten"]], raw.cast_to(types, types)) + async def AliasedInputClass( self, input: types.InputClass, @@ -280,6 +326,29 @@ async def BuildTree( ) return cast(types.Tree, raw.cast_to(types, types)) + async def ClassThatPointsToRecursiveClassThroughAlias( + self, + cls: types.ClassToRecAlias, + baml_options: BamlCallOptions = {}, + ) -> types.ClassToRecAlias: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = await self.__runtime.call_function( + "ClassThatPointsToRecursiveClassThroughAlias", + { + "cls": cls, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + return cast(types.ClassToRecAlias, raw.cast_to(types, types)) + async def ClassifyDynEnumTwo( self, input: str, @@ -1407,6 +1476,29 @@ async def InOutSingleLiteralStringMapKey( ) return cast(Dict[Literal["key"], str], raw.cast_to(types, types)) + async def JsonTypeAliasCycle( + self, + input: types.JsonValue, + baml_options: BamlCallOptions = {}, + ) -> types.JsonValue: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = await self.__runtime.call_function( + "JsonTypeAliasCycle", + { + "input": input, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + return cast(types.JsonValue, raw.cast_to(types, types)) + async def LiteralUnionsTest( self, input: str, @@ -1476,6 +1568,52 @@ async def MakeNestedBlockConstraint( ) return cast(types.NestedBlockConstraint, raw.cast_to(types, types)) + async def MapAlias( + self, + m: Dict[str, List[str]], + baml_options: BamlCallOptions = {}, + ) -> Dict[str, List[str]]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = await self.__runtime.call_function( + "MapAlias", + { + "m": m, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + return cast(Dict[str, List[str]], raw.cast_to(types, types)) + + async def MergeAliasAttributes( + self, + money: int, + baml_options: BamlCallOptions = {}, + ) -> types.MergeAttrs: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = await self.__runtime.call_function( + "MergeAliasAttributes", + { + "money": money, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + return cast(types.MergeAttrs, raw.cast_to(types, types)) + async def MyFunc( self, input: str, @@ -1499,6 +1637,29 @@ async def MyFunc( ) return cast(types.DynamicOutput, raw.cast_to(types, types)) + async def NestedAlias( + self, + c: Union[Union[int, str, bool, float], List[str], Dict[str, List[str]]], + baml_options: BamlCallOptions = {}, + ) -> Union[Union[int, str, bool, float], List[str], Dict[str, List[str]]]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = await self.__runtime.call_function( + "NestedAlias", + { + "c": c, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + return cast(Union[Union[int, str, bool, float], List[str], Dict[str, List[str]]], raw.cast_to(types, types)) + async def OptionalTest_Function( self, input: str, @@ -1568,6 +1729,29 @@ async def PredictAgeBare( ) return cast(Checked[int,types.Literal["too_big"]], raw.cast_to(types, types)) + async def PrimitiveAlias( + self, + p: Union[int, str, bool, float], + baml_options: BamlCallOptions = {}, + ) -> Union[int, str, bool, float]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = await self.__runtime.call_function( + "PrimitiveAlias", + { + "p": p, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + return cast(Union[int, str, bool, float], raw.cast_to(types, types)) + async def PromptTestClaude( self, input: str, @@ -1729,6 +1913,75 @@ async def PromptTestStreaming( ) return cast(str, raw.cast_to(types, types)) + async def RecursiveAliasCycle( + self, + input: types.RecAliasOne, + baml_options: BamlCallOptions = {}, + ) -> types.RecAliasOne: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = await self.__runtime.call_function( + "RecursiveAliasCycle", + { + "input": input, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + return cast(types.RecAliasOne, raw.cast_to(types, types)) + + async def RecursiveClassWithAliasIndirection( + self, + cls: types.NodeWithAliasIndirection, + baml_options: BamlCallOptions = {}, + ) -> types.NodeWithAliasIndirection: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = await self.__runtime.call_function( + "RecursiveClassWithAliasIndirection", + { + "cls": cls, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + return cast(types.NodeWithAliasIndirection, raw.cast_to(types, types)) + + async def ReturnAliasWithMergedAttributes( + self, + money: Checked[int,types.Literal["gt_ten"]], + baml_options: BamlCallOptions = {}, + ) -> Checked[int,types.Literal["gt_ten"]]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = await self.__runtime.call_function( + "ReturnAliasWithMergedAttributes", + { + "money": money, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + return cast(Checked[int,types.Literal["gt_ten"]], raw.cast_to(types, types)) + async def ReturnFailingAssert( self, inp: int, @@ -1798,6 +2051,52 @@ async def SchemaDescriptions( ) return cast(types.Schema, raw.cast_to(types, types)) + async def SimpleRecursiveListAlias( + self, + input: types.RecursiveListAlias, + baml_options: BamlCallOptions = {}, + ) -> types.RecursiveListAlias: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = await self.__runtime.call_function( + "SimpleRecursiveListAlias", + { + "input": input, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + return cast(types.RecursiveListAlias, raw.cast_to(types, types)) + + async def SimpleRecursiveMapAlias( + self, + input: types.RecursiveMapAlias, + baml_options: BamlCallOptions = {}, + ) -> types.RecursiveMapAlias: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = await self.__runtime.call_function( + "SimpleRecursiveMapAlias", + { + "input": input, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + return cast(types.RecursiveMapAlias, raw.cast_to(types, types)) + async def StreamBigNumbers( self, digits: int, @@ -2851,11 +3150,11 @@ def AaaSamOutputFormat( self.__ctx_manager.get(), ) - def AliasedInputClass( + def AliasThatPointsToRecursiveType( self, - input: types.InputClass, + list: types.LinkedListAliasNode, baml_options: BamlCallOptions = {}, - ) -> baml_py.BamlStream[Optional[str], str]: + ) -> baml_py.BamlStream[partial_types.LinkedListAliasNode, types.LinkedListAliasNode]: __tb__ = baml_options.get("tb", None) if __tb__ is not None: tb = __tb__._tb # type: ignore (we know how to use this private attribute) @@ -2864,9 +3163,9 @@ def AliasedInputClass( __cr__ = baml_options.get("client_registry", None) raw = self.__runtime.stream_function( - "AliasedInputClass", + "AliasThatPointsToRecursiveType", { - "input": input, + "list": list, }, None, self.__ctx_manager.get(), @@ -2874,18 +3173,18 @@ def AliasedInputClass( __cr__, ) - return baml_py.BamlStream[Optional[str], str]( + return baml_py.BamlStream[partial_types.LinkedListAliasNode, types.LinkedListAliasNode]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(partial_types.LinkedListAliasNode, x.cast_to(types, partial_types)), + lambda x: cast(types.LinkedListAliasNode, x.cast_to(types, types)), self.__ctx_manager.get(), ) - def AliasedInputClass2( + def AliasWithMultipleAttrs( self, - input: types.InputClass, + money: Checked[int,types.Literal["gt_ten"]], baml_options: BamlCallOptions = {}, - ) -> baml_py.BamlStream[Optional[str], str]: + ) -> baml_py.BamlStream[Checked[Optional[int],types.Literal["gt_ten"]], Checked[int,types.Literal["gt_ten"]]]: __tb__ = baml_options.get("tb", None) if __tb__ is not None: tb = __tb__._tb # type: ignore (we know how to use this private attribute) @@ -2894,9 +3193,9 @@ def AliasedInputClass2( __cr__ = baml_options.get("client_registry", None) raw = self.__runtime.stream_function( - "AliasedInputClass2", + "AliasWithMultipleAttrs", { - "input": input, + "money": money, }, None, self.__ctx_manager.get(), @@ -2904,16 +3203,16 @@ def AliasedInputClass2( __cr__, ) - return baml_py.BamlStream[Optional[str], str]( + return baml_py.BamlStream[Checked[Optional[int],types.Literal["gt_ten"]], Checked[int,types.Literal["gt_ten"]]]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Checked[Optional[int],types.Literal["gt_ten"]], x.cast_to(types, partial_types)), + lambda x: cast(Checked[int,types.Literal["gt_ten"]], x.cast_to(types, types)), self.__ctx_manager.get(), ) - def AliasedInputClassNested( + def AliasedInputClass( self, - input: types.InputClassNested, + input: types.InputClass, baml_options: BamlCallOptions = {}, ) -> baml_py.BamlStream[Optional[str], str]: __tb__ = baml_options.get("tb", None) @@ -2924,7 +3223,7 @@ def AliasedInputClassNested( __cr__ = baml_options.get("client_registry", None) raw = self.__runtime.stream_function( - "AliasedInputClassNested", + "AliasedInputClass", { "input": input, }, @@ -2941,9 +3240,9 @@ def AliasedInputClassNested( self.__ctx_manager.get(), ) - def AliasedInputEnum( + def AliasedInputClass2( self, - input: types.AliasedEnum, + input: types.InputClass, baml_options: BamlCallOptions = {}, ) -> baml_py.BamlStream[Optional[str], str]: __tb__ = baml_options.get("tb", None) @@ -2954,7 +3253,7 @@ def AliasedInputEnum( __cr__ = baml_options.get("client_registry", None) raw = self.__runtime.stream_function( - "AliasedInputEnum", + "AliasedInputClass2", { "input": input, }, @@ -2971,7 +3270,67 @@ def AliasedInputEnum( self.__ctx_manager.get(), ) - def AliasedInputList( + def AliasedInputClassNested( + self, + input: types.InputClassNested, + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlStream[Optional[str], str]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function( + "AliasedInputClassNested", + { + "input": input, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + return baml_py.BamlStream[Optional[str], str]( + raw, + lambda x: cast(Optional[str], x.cast_to(types, partial_types)), + lambda x: cast(str, x.cast_to(types, types)), + self.__ctx_manager.get(), + ) + + def AliasedInputEnum( + self, + input: types.AliasedEnum, + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlStream[Optional[str], str]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function( + "AliasedInputEnum", + { + "input": input, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + return baml_py.BamlStream[Optional[str], str]( + raw, + lambda x: cast(Optional[str], x.cast_to(types, partial_types)), + lambda x: cast(str, x.cast_to(types, types)), + self.__ctx_manager.get(), + ) + + def AliasedInputList( self, input: List[types.AliasedEnum], baml_options: BamlCallOptions = {}, @@ -3121,6 +3480,36 @@ def BuildTree( self.__ctx_manager.get(), ) + def ClassThatPointsToRecursiveClassThroughAlias( + self, + cls: types.ClassToRecAlias, + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlStream[partial_types.ClassToRecAlias, types.ClassToRecAlias]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function( + "ClassThatPointsToRecursiveClassThroughAlias", + { + "cls": cls, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + return baml_py.BamlStream[partial_types.ClassToRecAlias, types.ClassToRecAlias]( + raw, + lambda x: cast(partial_types.ClassToRecAlias, x.cast_to(types, partial_types)), + lambda x: cast(types.ClassToRecAlias, x.cast_to(types, types)), + self.__ctx_manager.get(), + ) + def ClassifyDynEnumTwo( self, input: str, @@ -4598,6 +4987,36 @@ def InOutSingleLiteralStringMapKey( self.__ctx_manager.get(), ) + def JsonTypeAliasCycle( + self, + input: types.JsonValue, + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlStream[types.JsonValue, types.JsonValue]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function( + "JsonTypeAliasCycle", + { + "input": input, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + return baml_py.BamlStream[types.JsonValue, types.JsonValue]( + raw, + lambda x: cast(types.JsonValue, x.cast_to(types, partial_types)), + lambda x: cast(types.JsonValue, x.cast_to(types, types)), + self.__ctx_manager.get(), + ) + def LiteralUnionsTest( self, input: str, @@ -4686,6 +5105,66 @@ def MakeNestedBlockConstraint( self.__ctx_manager.get(), ) + def MapAlias( + self, + m: Dict[str, List[str]], + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlStream[Dict[str, List[Optional[str]]], Dict[str, List[str]]]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function( + "MapAlias", + { + "m": m, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + return baml_py.BamlStream[Dict[str, List[Optional[str]]], Dict[str, List[str]]]( + raw, + lambda x: cast(Dict[str, List[Optional[str]]], x.cast_to(types, partial_types)), + lambda x: cast(Dict[str, List[str]], x.cast_to(types, types)), + self.__ctx_manager.get(), + ) + + def MergeAliasAttributes( + self, + money: int, + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlStream[partial_types.MergeAttrs, types.MergeAttrs]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function( + "MergeAliasAttributes", + { + "money": money, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + return baml_py.BamlStream[partial_types.MergeAttrs, types.MergeAttrs]( + raw, + lambda x: cast(partial_types.MergeAttrs, x.cast_to(types, partial_types)), + lambda x: cast(types.MergeAttrs, x.cast_to(types, types)), + self.__ctx_manager.get(), + ) + def MyFunc( self, input: str, @@ -4716,6 +5195,36 @@ def MyFunc( self.__ctx_manager.get(), ) + def NestedAlias( + self, + c: Union[Union[int, str, bool, float], List[str], Dict[str, List[str]]], + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlStream[Optional[Union[Optional[Union[Optional[int], Optional[str], Optional[bool], Optional[float]]], List[Optional[str]], Dict[str, List[Optional[str]]]]], Union[Union[int, str, bool, float], List[str], Dict[str, List[str]]]]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function( + "NestedAlias", + { + "c": c, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + return baml_py.BamlStream[Optional[Union[Optional[Union[Optional[int], Optional[str], Optional[bool], Optional[float]]], List[Optional[str]], Dict[str, List[Optional[str]]]]], Union[Union[int, str, bool, float], List[str], Dict[str, List[str]]]]( + raw, + lambda x: cast(Optional[Union[Optional[Union[Optional[int], Optional[str], Optional[bool], Optional[float]]], List[Optional[str]], Dict[str, List[Optional[str]]]]], x.cast_to(types, partial_types)), + lambda x: cast(Union[Union[int, str, bool, float], List[str], Dict[str, List[str]]], x.cast_to(types, types)), + self.__ctx_manager.get(), + ) + def OptionalTest_Function( self, input: str, @@ -4806,6 +5315,36 @@ def PredictAgeBare( self.__ctx_manager.get(), ) + def PrimitiveAlias( + self, + p: Union[int, str, bool, float], + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlStream[Optional[Union[Optional[int], Optional[str], Optional[bool], Optional[float]]], Union[int, str, bool, float]]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function( + "PrimitiveAlias", + { + "p": p, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + return baml_py.BamlStream[Optional[Union[Optional[int], Optional[str], Optional[bool], Optional[float]]], Union[int, str, bool, float]]( + raw, + lambda x: cast(Optional[Union[Optional[int], Optional[str], Optional[bool], Optional[float]]], x.cast_to(types, partial_types)), + lambda x: cast(Union[int, str, bool, float], x.cast_to(types, types)), + self.__ctx_manager.get(), + ) + def PromptTestClaude( self, input: str, @@ -5016,6 +5555,96 @@ def PromptTestStreaming( self.__ctx_manager.get(), ) + def RecursiveAliasCycle( + self, + input: types.RecAliasOne, + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlStream[types.RecAliasOne, types.RecAliasOne]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function( + "RecursiveAliasCycle", + { + "input": input, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + return baml_py.BamlStream[types.RecAliasOne, types.RecAliasOne]( + raw, + lambda x: cast(types.RecAliasOne, x.cast_to(types, partial_types)), + lambda x: cast(types.RecAliasOne, x.cast_to(types, types)), + self.__ctx_manager.get(), + ) + + def RecursiveClassWithAliasIndirection( + self, + cls: types.NodeWithAliasIndirection, + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlStream[partial_types.NodeWithAliasIndirection, types.NodeWithAliasIndirection]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function( + "RecursiveClassWithAliasIndirection", + { + "cls": cls, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + return baml_py.BamlStream[partial_types.NodeWithAliasIndirection, types.NodeWithAliasIndirection]( + raw, + lambda x: cast(partial_types.NodeWithAliasIndirection, x.cast_to(types, partial_types)), + lambda x: cast(types.NodeWithAliasIndirection, x.cast_to(types, types)), + self.__ctx_manager.get(), + ) + + def ReturnAliasWithMergedAttributes( + self, + money: Checked[int,types.Literal["gt_ten"]], + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlStream[Checked[Optional[int],types.Literal["gt_ten"]], Checked[int,types.Literal["gt_ten"]]]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function( + "ReturnAliasWithMergedAttributes", + { + "money": money, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + return baml_py.BamlStream[Checked[Optional[int],types.Literal["gt_ten"]], Checked[int,types.Literal["gt_ten"]]]( + raw, + lambda x: cast(Checked[Optional[int],types.Literal["gt_ten"]], x.cast_to(types, partial_types)), + lambda x: cast(Checked[int,types.Literal["gt_ten"]], x.cast_to(types, types)), + self.__ctx_manager.get(), + ) + def ReturnFailingAssert( self, inp: int, @@ -5106,6 +5735,66 @@ def SchemaDescriptions( self.__ctx_manager.get(), ) + def SimpleRecursiveListAlias( + self, + input: types.RecursiveListAlias, + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlStream[types.RecursiveListAlias, types.RecursiveListAlias]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function( + "SimpleRecursiveListAlias", + { + "input": input, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + return baml_py.BamlStream[types.RecursiveListAlias, types.RecursiveListAlias]( + raw, + lambda x: cast(types.RecursiveListAlias, x.cast_to(types, partial_types)), + lambda x: cast(types.RecursiveListAlias, x.cast_to(types, types)), + self.__ctx_manager.get(), + ) + + def SimpleRecursiveMapAlias( + self, + input: types.RecursiveMapAlias, + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlStream[types.RecursiveMapAlias, types.RecursiveMapAlias]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function( + "SimpleRecursiveMapAlias", + { + "input": input, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + return baml_py.BamlStream[types.RecursiveMapAlias, types.RecursiveMapAlias]( + raw, + lambda x: cast(types.RecursiveMapAlias, x.cast_to(types, partial_types)), + lambda x: cast(types.RecursiveMapAlias, x.cast_to(types, types)), + self.__ctx_manager.get(), + ) + def StreamBigNumbers( self, digits: int, diff --git a/integ-tests/python/baml_client/inlinedbaml.py b/integ-tests/python/baml_client/inlinedbaml.py index c12ec8ca1..f306ae9a9 100644 --- a/integ-tests/python/baml_client/inlinedbaml.py +++ b/integ-tests/python/baml_client/inlinedbaml.py @@ -74,15 +74,17 @@ "test-files/functions/output/literal-boolean.baml": "function FnOutputLiteralBool(input: string) -> false {\n client GPT35\n prompt #\"\n Return a false: {{ ctx.output_format}}\n \"#\n}\n\ntest FnOutputLiteralBool {\n functions [FnOutputLiteralBool]\n args {\n input \"example input\"\n }\n}\n", "test-files/functions/output/literal-int.baml": "function FnOutputLiteralInt(input: string) -> 5 {\n client GPT35\n prompt #\"\n Return an integer: {{ ctx.output_format}}\n \"#\n}\n\ntest FnOutputLiteralInt {\n functions [FnOutputLiteralInt]\n args {\n input \"example input\"\n }\n}\n", "test-files/functions/output/literal-string.baml": "function FnOutputLiteralString(input: string) -> \"example output\" {\n client GPT35\n prompt #\"\n Return a string: {{ ctx.output_format}}\n \"#\n}\n\ntest FnOutputLiteralString {\n functions [FnOutputLiteralString]\n args {\n input \"example input\"\n }\n}\n", - "test-files/functions/output/literal-unions.baml": "function LiteralUnionsTest(input: string) -> 1 | true | \"string output\" {\n client GPT35\n prompt #\"\n Return one of these values: \n {{ctx.output_format}}\n \"#\n}\n\ntest LiteralUnionsTest {\n functions [LiteralUnionsTest]\n args {\n input \"example input\"\n }\n}\n", + "test-files/functions/output/literal-unions.baml": "function LiteralUnionsTest(input: string) -> 1 | true | \"string output\" {\n client GPT35\n prompt #\"\n Return one of these values without any additional context: \n {{ctx.output_format}}\n \"#\n}\n\ntest LiteralUnionsTest {\n functions [LiteralUnionsTest]\n args {\n input \"example input\"\n }\n}\n", "test-files/functions/output/map-enum-key.baml": "enum MapKey {\n A\n B\n C\n}\n\nfunction InOutEnumMapKey(i1: map, i2: map) -> map {\n client \"openai/gpt-4o\"\n prompt #\"\n Merge these: {{i1}} {{i2}}\n\n {{ ctx.output_format }}\n \"#\n}\n", "test-files/functions/output/map-literal-union-key.baml": "function InOutLiteralStringUnionMapKey(\n i1: map<\"one\" | \"two\" | (\"three\" | \"four\"), string>, \n i2: map<\"one\" | \"two\" | (\"three\" | \"four\"), string>\n) -> map<\"one\" | \"two\" | (\"three\" | \"four\"), string> {\n client \"openai/gpt-4o\"\n prompt #\"\n Merge these:\n \n {{i1}}\n \n {{i2}}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction InOutSingleLiteralStringMapKey(m: map<\"key\", string>) -> map<\"key\", string> {\n client \"openai/gpt-4o\"\n prompt #\"\n Return the same map you were given:\n \n {{m}}\n\n {{ ctx.output_format }}\n \"#\n}\n", "test-files/functions/output/mutually-recursive-classes.baml": "class Tree {\n data int\n children Forest\n}\n\nclass Forest {\n trees Tree[]\n}\n\nclass BinaryNode {\n data int\n left BinaryNode?\n right BinaryNode?\n}\n\nfunction BuildTree(input: BinaryNode) -> Tree {\n client GPT35\n prompt #\"\n Given the input binary tree, transform it into a generic tree using the given schema.\n\n INPUT:\n {{ input }}\n\n {{ ctx.output_format }} \n \"#\n}\n\ntest TestTree {\n functions [BuildTree]\n args {\n input {\n data 2\n left {\n data 1\n left null\n right null\n }\n right {\n data 3\n left null\n right null\n }\n }\n }\n}", "test-files/functions/output/optional-class.baml": "class ClassOptionalOutput {\n prop1 string\n prop2 string\n}\n\nfunction FnClassOptionalOutput(input: string) -> ClassOptionalOutput? {\n client GPT35\n prompt #\"\n Return a json blob for the following input:\n {{input}}\n\n {{ctx.output_format}}\n\n JSON:\n \"#\n}\n\n\nclass Blah {\n prop4 string?\n}\n\nclass ClassOptionalOutput2 {\n prop1 string?\n prop2 string?\n prop3 Blah?\n}\n\nfunction FnClassOptionalOutput2(input: string) -> ClassOptionalOutput2? {\n client GPT35\n prompt #\"\n Return a json blob for the following input:\n {{input}}\n\n {{ctx.output_format}}\n\n JSON:\n \"#\n}\n\ntest FnClassOptionalOutput2 {\n functions [FnClassOptionalOutput2, FnClassOptionalOutput]\n args {\n input \"example input\"\n }\n}\n", "test-files/functions/output/optional.baml": "class OptionalTest_Prop1 {\n omega_a string\n omega_b int\n}\n\nenum OptionalTest_CategoryType {\n Aleph\n Beta\n Gamma\n}\n \nclass OptionalTest_ReturnType {\n omega_1 OptionalTest_Prop1?\n omega_2 string?\n omega_3 (OptionalTest_CategoryType?)[]\n} \n \nfunction OptionalTest_Function(input: string) -> (OptionalTest_ReturnType?)[]\n{ \n client GPT35\n prompt #\"\n Return a JSON blob with this schema: \n {{ctx.output_format}}\n\n JSON:\n \"#\n}\n\ntest OptionalTest_Function {\n functions [OptionalTest_Function]\n args {\n input \"example input\"\n }\n}\n", "test-files/functions/output/recursive-class.baml": "class Node {\n data int\n next Node?\n}\n\nclass LinkedList {\n head Node?\n len int\n}\n\nclient O1 {\n provider \"openai\"\n options {\n model \"o1-mini\"\n default_role \"user\"\n }\n}\n\nfunction BuildLinkedList(input: int[]) -> LinkedList {\n client O1\n prompt #\"\n Build a linked list from the input array of integers.\n\n INPUT:\n {{ input }}\n\n {{ ctx.output_format }} \n \"#\n}\n\ntest TestLinkedList {\n functions [BuildLinkedList]\n args {\n input [1, 2, 3, 4, 5]\n }\n}\n", + "test-files/functions/output/recursive-type-aliases.baml": "class LinkedListAliasNode {\n value int\n next LinkedListAliasNode?\n}\n\n// Simple alias that points to recursive type.\ntype LinkedListAlias = LinkedListAliasNode\n\nfunction AliasThatPointsToRecursiveType(list: LinkedListAlias) -> LinkedListAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given linked list back:\n \n {{ list }}\n \n {{ ctx.output_format }}\n \"#\n}\n\n// Class that points to an alias that points to a recursive type.\nclass ClassToRecAlias {\n list LinkedListAlias\n}\n\nfunction ClassThatPointsToRecursiveClassThroughAlias(cls: ClassToRecAlias) -> ClassToRecAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given object back:\n \n {{ cls }}\n \n {{ ctx.output_format }}\n \"#\n}\n\n// This is tricky cause this class should be hoisted, but classes and aliases\n// are two different types in the AST. This test will make sure they can interop.\nclass NodeWithAliasIndirection {\n value int\n next NodeIndirection?\n}\n\ntype NodeIndirection = NodeWithAliasIndirection\n\nfunction RecursiveClassWithAliasIndirection(cls: NodeWithAliasIndirection) -> NodeWithAliasIndirection {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given object back:\n \n {{ cls }}\n \n {{ ctx.output_format }}\n \"#\n}\n", "test-files/functions/output/serialization-error.baml": "class DummyOutput {\n nonce string\n nonce2 string\n @@dynamic\n}\n\nfunction DummyOutputFunction(input: string) -> DummyOutput {\n client GPT35\n prompt #\"\n Say \"hello there\".\n \"#\n}", "test-files/functions/output/string-list.baml": "function FnOutputStringList(input: string) -> string[] {\n client GPT35\n prompt #\"\n Return a list of strings in json format like [\"string1\", \"string2\", \"string3\"].\n\n JSON:\n \"#\n}\n\ntest FnOutputStringList {\n functions [FnOutputStringList]\n args {\n input \"example input\"\n }\n}\n", + "test-files/functions/output/type-aliases.baml": "type Primitive = int | string | bool | float\n\ntype List = string[]\n\ntype Graph = map\n\ntype Combination = Primitive | List | Graph\n\nfunction PrimitiveAlias(p: Primitive) -> Primitive {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back: {{ p }}\n \"#\n}\n\nfunction MapAlias(m: Graph) -> Graph {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given Graph back:\n\n {{ m }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction NestedAlias(c: Combination) -> Combination {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back:\n\n {{ c }}\n\n {{ ctx.output_format }}\n \"#\n}\n\n// Test attribute merging.\ntype Currency = int @check(gt_ten, {{ this > 10 }})\ntype Amount = Currency @assert({{ this > 0 }})\n\nclass MergeAttrs {\n amount Amount @description(\"In USD\")\n}\n\n// This should be allowed.\ntype MultipleAttrs = int @assert({{ this > 0 }}) @check(gt_ten, {{ this > 10 }})\n\nfunction MergeAliasAttributes(money: int) -> MergeAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer in the specified format:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction ReturnAliasWithMergedAttributes(money: Amount) -> Amount {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction AliasWithMultipleAttrs(money: MultipleAttrs) -> MultipleAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecursiveMapAlias = map\n\nfunction SimpleRecursiveMapAlias(input: RecursiveMapAlias) -> RecursiveMapAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecursiveListAlias = RecursiveListAlias[]\n\nfunction SimpleRecursiveListAlias(input: RecursiveListAlias) -> RecursiveListAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given JSON array:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecAliasOne = RecAliasTwo\ntype RecAliasTwo = RecAliasThree\ntype RecAliasThree = RecAliasOne[]\n\nfunction RecursiveAliasCycle(input: RecAliasOne) -> RecAliasOne {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given JSON array:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype JsonValue = int | string | bool | float | JsonObject | JsonArray\ntype JsonObject = map\ntype JsonArray = JsonValue[]\n\nfunction JsonTypeAliasCycle(input: JsonValue) -> JsonValue {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given input back:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n", "test-files/functions/output/unions.baml": "class UnionTest_ReturnType {\n prop1 string | bool\n prop2 (float | bool)[]\n prop3 (bool[] | int[])\n}\n\nfunction UnionTest_Function(input: string | bool) -> UnionTest_ReturnType {\n client GPT35\n prompt #\"\n Return a JSON blob with this schema: \n {{ctx.output_format}}\n\n JSON:\n \"#\n}\n\ntest UnionTest_Function {\n functions [UnionTest_Function]\n args {\n input \"example input\"\n }\n}\n", "test-files/functions/prompts/no-chat-messages.baml": "\n\nfunction PromptTestClaude(input: string) -> string {\n client Sonnet\n prompt #\"\n Tell me a haiku about {{ input }}\n \"#\n}\n\n\nfunction PromptTestStreaming(input: string) -> string {\n client GPT35\n prompt #\"\n Tell me a short story about {{ input }}\n \"#\n}\n\ntest TestName {\n functions [PromptTestStreaming]\n args {\n input #\"\n hello world\n \"#\n }\n}\n", "test-files/functions/prompts/with-chat-messages.baml": "\nfunction PromptTestOpenAIChat(input: string) -> string {\n client GPT35\n prompt #\"\n {{ _.role(\"system\") }}\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestOpenAIChatNoSystem(input: string) -> string {\n client GPT35\n prompt #\"\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestClaudeChat(input: string) -> string {\n client Claude\n prompt #\"\n {{ _.role(\"system\") }}\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestClaudeChatNoSystem(input: string) -> string {\n client Claude\n prompt #\"\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\ntest TestSystemAndNonSystemChat1 {\n functions [PromptTestClaude, PromptTestOpenAI, PromptTestOpenAIChat, PromptTestOpenAIChatNoSystem, PromptTestClaudeChat, PromptTestClaudeChatNoSystem]\n args {\n input \"cats\"\n }\n}\n\ntest TestSystemAndNonSystemChat2 {\n functions [PromptTestClaude, PromptTestOpenAI, PromptTestOpenAIChat, PromptTestOpenAIChatNoSystem, PromptTestClaudeChat, PromptTestClaudeChatNoSystem]\n args {\n input \"lion\"\n }\n}", diff --git a/integ-tests/python/baml_client/partial_types.py b/integ-tests/python/baml_client/partial_types.py index 6c5c20e54..9ce8a02c0 100644 --- a/integ-tests/python/baml_client/partial_types.py +++ b/integ-tests/python/baml_client/partial_types.py @@ -64,6 +64,9 @@ class ClassOptionalOutput2(BaseModel): prop2: Optional[str] = None prop3: Optional["Blah"] = None +class ClassToRecAlias(BaseModel): + list: Optional["LinkedListAliasNode"] = None + class ClassWithImage(BaseModel): myImage: Optional[baml_py.Image] = None param2: Optional[str] = None @@ -173,6 +176,10 @@ class LinkedList(BaseModel): head: Optional["Node"] = None len: Optional[int] = None +class LinkedListAliasNode(BaseModel): + value: Optional[int] = None + next: Optional["LinkedListAliasNode"] = None + class LiteralClassHello(BaseModel): prop: Literal["hello"] @@ -195,6 +202,9 @@ class Martian(BaseModel): """The age of the Martian in Mars years. So many Mars years.""" +class MergeAttrs(BaseModel): + amount: Checked[Optional[int],Literal["gt_ten"]] + class NamedArgsSingleClass(BaseModel): key: Optional[str] = None key_two: Optional[bool] = None @@ -219,6 +229,10 @@ class Node(BaseModel): data: Optional[int] = None next: Optional["Node"] = None +class NodeWithAliasIndirection(BaseModel): + value: Optional[int] = None + next: Optional["NodeWithAliasIndirection"] = None + class OptionalListAndMap(BaseModel): p: List[Optional[str]] q: Dict[str, Optional[str]] diff --git a/integ-tests/python/baml_client/sync_client.py b/integ-tests/python/baml_client/sync_client.py index c758d3414..41137df90 100644 --- a/integ-tests/python/baml_client/sync_client.py +++ b/integ-tests/python/baml_client/sync_client.py @@ -70,6 +70,52 @@ def AaaSamOutputFormat( ) return cast(types.Recipe, raw.cast_to(types, types)) + def AliasThatPointsToRecursiveType( + self, + list: types.LinkedListAliasNode, + baml_options: BamlCallOptions = {}, + ) -> types.LinkedListAliasNode: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.call_function_sync( + "AliasThatPointsToRecursiveType", + { + "list": list, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + return cast(types.LinkedListAliasNode, raw.cast_to(types, types)) + + def AliasWithMultipleAttrs( + self, + money: Checked[int,types.Literal["gt_ten"]], + baml_options: BamlCallOptions = {}, + ) -> Checked[int,types.Literal["gt_ten"]]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.call_function_sync( + "AliasWithMultipleAttrs", + { + "money": money, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + return cast(Checked[int,types.Literal["gt_ten"]], raw.cast_to(types, types)) + def AliasedInputClass( self, input: types.InputClass, @@ -277,6 +323,29 @@ def BuildTree( ) return cast(types.Tree, raw.cast_to(types, types)) + def ClassThatPointsToRecursiveClassThroughAlias( + self, + cls: types.ClassToRecAlias, + baml_options: BamlCallOptions = {}, + ) -> types.ClassToRecAlias: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.call_function_sync( + "ClassThatPointsToRecursiveClassThroughAlias", + { + "cls": cls, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + return cast(types.ClassToRecAlias, raw.cast_to(types, types)) + def ClassifyDynEnumTwo( self, input: str, @@ -1404,6 +1473,29 @@ def InOutSingleLiteralStringMapKey( ) return cast(Dict[Literal["key"], str], raw.cast_to(types, types)) + def JsonTypeAliasCycle( + self, + input: types.JsonValue, + baml_options: BamlCallOptions = {}, + ) -> types.JsonValue: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.call_function_sync( + "JsonTypeAliasCycle", + { + "input": input, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + return cast(types.JsonValue, raw.cast_to(types, types)) + def LiteralUnionsTest( self, input: str, @@ -1473,6 +1565,52 @@ def MakeNestedBlockConstraint( ) return cast(types.NestedBlockConstraint, raw.cast_to(types, types)) + def MapAlias( + self, + m: Dict[str, List[str]], + baml_options: BamlCallOptions = {}, + ) -> Dict[str, List[str]]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.call_function_sync( + "MapAlias", + { + "m": m, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + return cast(Dict[str, List[str]], raw.cast_to(types, types)) + + def MergeAliasAttributes( + self, + money: int, + baml_options: BamlCallOptions = {}, + ) -> types.MergeAttrs: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.call_function_sync( + "MergeAliasAttributes", + { + "money": money, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + return cast(types.MergeAttrs, raw.cast_to(types, types)) + def MyFunc( self, input: str, @@ -1496,6 +1634,29 @@ def MyFunc( ) return cast(types.DynamicOutput, raw.cast_to(types, types)) + def NestedAlias( + self, + c: Union[Union[int, str, bool, float], List[str], Dict[str, List[str]]], + baml_options: BamlCallOptions = {}, + ) -> Union[Union[int, str, bool, float], List[str], Dict[str, List[str]]]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.call_function_sync( + "NestedAlias", + { + "c": c, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + return cast(Union[Union[int, str, bool, float], List[str], Dict[str, List[str]]], raw.cast_to(types, types)) + def OptionalTest_Function( self, input: str, @@ -1565,6 +1726,29 @@ def PredictAgeBare( ) return cast(Checked[int,types.Literal["too_big"]], raw.cast_to(types, types)) + def PrimitiveAlias( + self, + p: Union[int, str, bool, float], + baml_options: BamlCallOptions = {}, + ) -> Union[int, str, bool, float]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.call_function_sync( + "PrimitiveAlias", + { + "p": p, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + return cast(Union[int, str, bool, float], raw.cast_to(types, types)) + def PromptTestClaude( self, input: str, @@ -1726,6 +1910,75 @@ def PromptTestStreaming( ) return cast(str, raw.cast_to(types, types)) + def RecursiveAliasCycle( + self, + input: types.RecAliasOne, + baml_options: BamlCallOptions = {}, + ) -> types.RecAliasOne: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.call_function_sync( + "RecursiveAliasCycle", + { + "input": input, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + return cast(types.RecAliasOne, raw.cast_to(types, types)) + + def RecursiveClassWithAliasIndirection( + self, + cls: types.NodeWithAliasIndirection, + baml_options: BamlCallOptions = {}, + ) -> types.NodeWithAliasIndirection: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.call_function_sync( + "RecursiveClassWithAliasIndirection", + { + "cls": cls, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + return cast(types.NodeWithAliasIndirection, raw.cast_to(types, types)) + + def ReturnAliasWithMergedAttributes( + self, + money: Checked[int,types.Literal["gt_ten"]], + baml_options: BamlCallOptions = {}, + ) -> Checked[int,types.Literal["gt_ten"]]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.call_function_sync( + "ReturnAliasWithMergedAttributes", + { + "money": money, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + return cast(Checked[int,types.Literal["gt_ten"]], raw.cast_to(types, types)) + def ReturnFailingAssert( self, inp: int, @@ -1795,6 +2048,52 @@ def SchemaDescriptions( ) return cast(types.Schema, raw.cast_to(types, types)) + def SimpleRecursiveListAlias( + self, + input: types.RecursiveListAlias, + baml_options: BamlCallOptions = {}, + ) -> types.RecursiveListAlias: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.call_function_sync( + "SimpleRecursiveListAlias", + { + "input": input, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + return cast(types.RecursiveListAlias, raw.cast_to(types, types)) + + def SimpleRecursiveMapAlias( + self, + input: types.RecursiveMapAlias, + baml_options: BamlCallOptions = {}, + ) -> types.RecursiveMapAlias: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.call_function_sync( + "SimpleRecursiveMapAlias", + { + "input": input, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + return cast(types.RecursiveMapAlias, raw.cast_to(types, types)) + def StreamBigNumbers( self, digits: int, @@ -2849,11 +3148,11 @@ def AaaSamOutputFormat( self.__ctx_manager.get(), ) - def AliasedInputClass( + def AliasThatPointsToRecursiveType( self, - input: types.InputClass, + list: types.LinkedListAliasNode, baml_options: BamlCallOptions = {}, - ) -> baml_py.BamlSyncStream[Optional[str], str]: + ) -> baml_py.BamlSyncStream[partial_types.LinkedListAliasNode, types.LinkedListAliasNode]: __tb__ = baml_options.get("tb", None) if __tb__ is not None: tb = __tb__._tb # type: ignore (we know how to use this private attribute) @@ -2862,9 +3161,9 @@ def AliasedInputClass( __cr__ = baml_options.get("client_registry", None) raw = self.__runtime.stream_function_sync( - "AliasedInputClass", + "AliasThatPointsToRecursiveType", { - "input": input, + "list": list, }, None, self.__ctx_manager.get(), @@ -2872,18 +3171,18 @@ def AliasedInputClass( __cr__, ) - return baml_py.BamlSyncStream[Optional[str], str]( + return baml_py.BamlSyncStream[partial_types.LinkedListAliasNode, types.LinkedListAliasNode]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(partial_types.LinkedListAliasNode, x.cast_to(types, partial_types)), + lambda x: cast(types.LinkedListAliasNode, x.cast_to(types, types)), self.__ctx_manager.get(), ) - def AliasedInputClass2( + def AliasWithMultipleAttrs( self, - input: types.InputClass, + money: Checked[int,types.Literal["gt_ten"]], baml_options: BamlCallOptions = {}, - ) -> baml_py.BamlSyncStream[Optional[str], str]: + ) -> baml_py.BamlSyncStream[Checked[Optional[int],types.Literal["gt_ten"]], Checked[int,types.Literal["gt_ten"]]]: __tb__ = baml_options.get("tb", None) if __tb__ is not None: tb = __tb__._tb # type: ignore (we know how to use this private attribute) @@ -2892,9 +3191,9 @@ def AliasedInputClass2( __cr__ = baml_options.get("client_registry", None) raw = self.__runtime.stream_function_sync( - "AliasedInputClass2", + "AliasWithMultipleAttrs", { - "input": input, + "money": money, }, None, self.__ctx_manager.get(), @@ -2902,16 +3201,16 @@ def AliasedInputClass2( __cr__, ) - return baml_py.BamlSyncStream[Optional[str], str]( + return baml_py.BamlSyncStream[Checked[Optional[int],types.Literal["gt_ten"]], Checked[int,types.Literal["gt_ten"]]]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Checked[Optional[int],types.Literal["gt_ten"]], x.cast_to(types, partial_types)), + lambda x: cast(Checked[int,types.Literal["gt_ten"]], x.cast_to(types, types)), self.__ctx_manager.get(), ) - def AliasedInputClassNested( + def AliasedInputClass( self, - input: types.InputClassNested, + input: types.InputClass, baml_options: BamlCallOptions = {}, ) -> baml_py.BamlSyncStream[Optional[str], str]: __tb__ = baml_options.get("tb", None) @@ -2922,7 +3221,7 @@ def AliasedInputClassNested( __cr__ = baml_options.get("client_registry", None) raw = self.__runtime.stream_function_sync( - "AliasedInputClassNested", + "AliasedInputClass", { "input": input, }, @@ -2939,9 +3238,9 @@ def AliasedInputClassNested( self.__ctx_manager.get(), ) - def AliasedInputEnum( + def AliasedInputClass2( self, - input: types.AliasedEnum, + input: types.InputClass, baml_options: BamlCallOptions = {}, ) -> baml_py.BamlSyncStream[Optional[str], str]: __tb__ = baml_options.get("tb", None) @@ -2952,7 +3251,7 @@ def AliasedInputEnum( __cr__ = baml_options.get("client_registry", None) raw = self.__runtime.stream_function_sync( - "AliasedInputEnum", + "AliasedInputClass2", { "input": input, }, @@ -2969,7 +3268,67 @@ def AliasedInputEnum( self.__ctx_manager.get(), ) - def AliasedInputList( + def AliasedInputClassNested( + self, + input: types.InputClassNested, + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlSyncStream[Optional[str], str]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function_sync( + "AliasedInputClassNested", + { + "input": input, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + return baml_py.BamlSyncStream[Optional[str], str]( + raw, + lambda x: cast(Optional[str], x.cast_to(types, partial_types)), + lambda x: cast(str, x.cast_to(types, types)), + self.__ctx_manager.get(), + ) + + def AliasedInputEnum( + self, + input: types.AliasedEnum, + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlSyncStream[Optional[str], str]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function_sync( + "AliasedInputEnum", + { + "input": input, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + return baml_py.BamlSyncStream[Optional[str], str]( + raw, + lambda x: cast(Optional[str], x.cast_to(types, partial_types)), + lambda x: cast(str, x.cast_to(types, types)), + self.__ctx_manager.get(), + ) + + def AliasedInputList( self, input: List[types.AliasedEnum], baml_options: BamlCallOptions = {}, @@ -3119,6 +3478,36 @@ def BuildTree( self.__ctx_manager.get(), ) + def ClassThatPointsToRecursiveClassThroughAlias( + self, + cls: types.ClassToRecAlias, + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlSyncStream[partial_types.ClassToRecAlias, types.ClassToRecAlias]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function_sync( + "ClassThatPointsToRecursiveClassThroughAlias", + { + "cls": cls, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + return baml_py.BamlSyncStream[partial_types.ClassToRecAlias, types.ClassToRecAlias]( + raw, + lambda x: cast(partial_types.ClassToRecAlias, x.cast_to(types, partial_types)), + lambda x: cast(types.ClassToRecAlias, x.cast_to(types, types)), + self.__ctx_manager.get(), + ) + def ClassifyDynEnumTwo( self, input: str, @@ -4596,6 +4985,36 @@ def InOutSingleLiteralStringMapKey( self.__ctx_manager.get(), ) + def JsonTypeAliasCycle( + self, + input: types.JsonValue, + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlSyncStream[types.JsonValue, types.JsonValue]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function_sync( + "JsonTypeAliasCycle", + { + "input": input, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + return baml_py.BamlSyncStream[types.JsonValue, types.JsonValue]( + raw, + lambda x: cast(types.JsonValue, x.cast_to(types, partial_types)), + lambda x: cast(types.JsonValue, x.cast_to(types, types)), + self.__ctx_manager.get(), + ) + def LiteralUnionsTest( self, input: str, @@ -4684,6 +5103,66 @@ def MakeNestedBlockConstraint( self.__ctx_manager.get(), ) + def MapAlias( + self, + m: Dict[str, List[str]], + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlSyncStream[Dict[str, List[Optional[str]]], Dict[str, List[str]]]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function_sync( + "MapAlias", + { + "m": m, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + return baml_py.BamlSyncStream[Dict[str, List[Optional[str]]], Dict[str, List[str]]]( + raw, + lambda x: cast(Dict[str, List[Optional[str]]], x.cast_to(types, partial_types)), + lambda x: cast(Dict[str, List[str]], x.cast_to(types, types)), + self.__ctx_manager.get(), + ) + + def MergeAliasAttributes( + self, + money: int, + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlSyncStream[partial_types.MergeAttrs, types.MergeAttrs]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function_sync( + "MergeAliasAttributes", + { + "money": money, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + return baml_py.BamlSyncStream[partial_types.MergeAttrs, types.MergeAttrs]( + raw, + lambda x: cast(partial_types.MergeAttrs, x.cast_to(types, partial_types)), + lambda x: cast(types.MergeAttrs, x.cast_to(types, types)), + self.__ctx_manager.get(), + ) + def MyFunc( self, input: str, @@ -4714,6 +5193,36 @@ def MyFunc( self.__ctx_manager.get(), ) + def NestedAlias( + self, + c: Union[Union[int, str, bool, float], List[str], Dict[str, List[str]]], + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlSyncStream[Optional[Union[Optional[Union[Optional[int], Optional[str], Optional[bool], Optional[float]]], List[Optional[str]], Dict[str, List[Optional[str]]]]], Union[Union[int, str, bool, float], List[str], Dict[str, List[str]]]]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function_sync( + "NestedAlias", + { + "c": c, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + return baml_py.BamlSyncStream[Optional[Union[Optional[Union[Optional[int], Optional[str], Optional[bool], Optional[float]]], List[Optional[str]], Dict[str, List[Optional[str]]]]], Union[Union[int, str, bool, float], List[str], Dict[str, List[str]]]]( + raw, + lambda x: cast(Optional[Union[Optional[Union[Optional[int], Optional[str], Optional[bool], Optional[float]]], List[Optional[str]], Dict[str, List[Optional[str]]]]], x.cast_to(types, partial_types)), + lambda x: cast(Union[Union[int, str, bool, float], List[str], Dict[str, List[str]]], x.cast_to(types, types)), + self.__ctx_manager.get(), + ) + def OptionalTest_Function( self, input: str, @@ -4804,6 +5313,36 @@ def PredictAgeBare( self.__ctx_manager.get(), ) + def PrimitiveAlias( + self, + p: Union[int, str, bool, float], + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlSyncStream[Optional[Union[Optional[int], Optional[str], Optional[bool], Optional[float]]], Union[int, str, bool, float]]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function_sync( + "PrimitiveAlias", + { + "p": p, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + return baml_py.BamlSyncStream[Optional[Union[Optional[int], Optional[str], Optional[bool], Optional[float]]], Union[int, str, bool, float]]( + raw, + lambda x: cast(Optional[Union[Optional[int], Optional[str], Optional[bool], Optional[float]]], x.cast_to(types, partial_types)), + lambda x: cast(Union[int, str, bool, float], x.cast_to(types, types)), + self.__ctx_manager.get(), + ) + def PromptTestClaude( self, input: str, @@ -5014,6 +5553,96 @@ def PromptTestStreaming( self.__ctx_manager.get(), ) + def RecursiveAliasCycle( + self, + input: types.RecAliasOne, + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlSyncStream[types.RecAliasOne, types.RecAliasOne]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function_sync( + "RecursiveAliasCycle", + { + "input": input, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + return baml_py.BamlSyncStream[types.RecAliasOne, types.RecAliasOne]( + raw, + lambda x: cast(types.RecAliasOne, x.cast_to(types, partial_types)), + lambda x: cast(types.RecAliasOne, x.cast_to(types, types)), + self.__ctx_manager.get(), + ) + + def RecursiveClassWithAliasIndirection( + self, + cls: types.NodeWithAliasIndirection, + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlSyncStream[partial_types.NodeWithAliasIndirection, types.NodeWithAliasIndirection]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function_sync( + "RecursiveClassWithAliasIndirection", + { + "cls": cls, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + return baml_py.BamlSyncStream[partial_types.NodeWithAliasIndirection, types.NodeWithAliasIndirection]( + raw, + lambda x: cast(partial_types.NodeWithAliasIndirection, x.cast_to(types, partial_types)), + lambda x: cast(types.NodeWithAliasIndirection, x.cast_to(types, types)), + self.__ctx_manager.get(), + ) + + def ReturnAliasWithMergedAttributes( + self, + money: Checked[int,types.Literal["gt_ten"]], + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlSyncStream[Checked[Optional[int],types.Literal["gt_ten"]], Checked[int,types.Literal["gt_ten"]]]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function_sync( + "ReturnAliasWithMergedAttributes", + { + "money": money, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + return baml_py.BamlSyncStream[Checked[Optional[int],types.Literal["gt_ten"]], Checked[int,types.Literal["gt_ten"]]]( + raw, + lambda x: cast(Checked[Optional[int],types.Literal["gt_ten"]], x.cast_to(types, partial_types)), + lambda x: cast(Checked[int,types.Literal["gt_ten"]], x.cast_to(types, types)), + self.__ctx_manager.get(), + ) + def ReturnFailingAssert( self, inp: int, @@ -5104,6 +5733,66 @@ def SchemaDescriptions( self.__ctx_manager.get(), ) + def SimpleRecursiveListAlias( + self, + input: types.RecursiveListAlias, + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlSyncStream[types.RecursiveListAlias, types.RecursiveListAlias]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function_sync( + "SimpleRecursiveListAlias", + { + "input": input, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + return baml_py.BamlSyncStream[types.RecursiveListAlias, types.RecursiveListAlias]( + raw, + lambda x: cast(types.RecursiveListAlias, x.cast_to(types, partial_types)), + lambda x: cast(types.RecursiveListAlias, x.cast_to(types, types)), + self.__ctx_manager.get(), + ) + + def SimpleRecursiveMapAlias( + self, + input: types.RecursiveMapAlias, + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlSyncStream[types.RecursiveMapAlias, types.RecursiveMapAlias]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function_sync( + "SimpleRecursiveMapAlias", + { + "input": input, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + return baml_py.BamlSyncStream[types.RecursiveMapAlias, types.RecursiveMapAlias]( + raw, + lambda x: cast(types.RecursiveMapAlias, x.cast_to(types, partial_types)), + lambda x: cast(types.RecursiveMapAlias, x.cast_to(types, types)), + self.__ctx_manager.get(), + ) + def StreamBigNumbers( self, digits: int, diff --git a/integ-tests/python/baml_client/type_builder.py b/integ-tests/python/baml_client/type_builder.py index 16b69fe31..741999408 100644 --- a/integ-tests/python/baml_client/type_builder.py +++ b/integ-tests/python/baml_client/type_builder.py @@ -20,7 +20,7 @@ class TypeBuilder(_TypeBuilder): def __init__(self): super().__init__(classes=set( - ["BigNumbers","BinaryNode","Blah","BlockConstraint","BlockConstraintForParam","BookOrder","ClassOptionalOutput","ClassOptionalOutput2","ClassWithImage","CompoundBigNumbers","ContactInfo","CustomTaskResult","DummyOutput","DynInputOutput","DynamicClassOne","DynamicClassTwo","DynamicOutput","Earthling","Education","Email","EmailAddress","Event","FakeImage","FlightConfirmation","FooAny","Forest","GroceryReceipt","InnerClass","InnerClass2","InputClass","InputClassNested","LinkedList","LiteralClassHello","LiteralClassOne","LiteralClassTwo","MalformedConstraints","MalformedConstraints2","Martian","NamedArgsSingleClass","Nested","Nested2","NestedBlockConstraint","NestedBlockConstraintForParam","Node","OptionalListAndMap","OptionalTest_Prop1","OptionalTest_ReturnType","OrderInfo","OriginalA","OriginalB","Person","PhoneNumber","Quantity","RaysData","ReceiptInfo","ReceiptItem","Recipe","Resume","Schema","SearchParams","SomeClassNestedDynamic","StringToClassEntry","TestClassAlias","TestClassNested","TestClassWithEnum","TestOutputClass","Tree","TwoStoriesOneTitle","UnionTest_ReturnType","WithReasoning",] + ["BigNumbers","BinaryNode","Blah","BlockConstraint","BlockConstraintForParam","BookOrder","ClassOptionalOutput","ClassOptionalOutput2","ClassToRecAlias","ClassWithImage","CompoundBigNumbers","ContactInfo","CustomTaskResult","DummyOutput","DynInputOutput","DynamicClassOne","DynamicClassTwo","DynamicOutput","Earthling","Education","Email","EmailAddress","Event","FakeImage","FlightConfirmation","FooAny","Forest","GroceryReceipt","InnerClass","InnerClass2","InputClass","InputClassNested","LinkedList","LinkedListAliasNode","LiteralClassHello","LiteralClassOne","LiteralClassTwo","MalformedConstraints","MalformedConstraints2","Martian","MergeAttrs","NamedArgsSingleClass","Nested","Nested2","NestedBlockConstraint","NestedBlockConstraintForParam","Node","NodeWithAliasIndirection","OptionalListAndMap","OptionalTest_Prop1","OptionalTest_ReturnType","OrderInfo","OriginalA","OriginalB","Person","PhoneNumber","Quantity","RaysData","ReceiptInfo","ReceiptItem","Recipe","Resume","Schema","SearchParams","SomeClassNestedDynamic","StringToClassEntry","TestClassAlias","TestClassNested","TestClassWithEnum","TestOutputClass","Tree","TwoStoriesOneTitle","UnionTest_ReturnType","WithReasoning",] ), enums=set( ["AliasedEnum","Category","Category2","Category3","Color","DataType","DynEnumOne","DynEnumTwo","EnumInClass","EnumOutput","Hobby","MapKey","NamedArgsSingleEnum","NamedArgsSingleEnumList","OptionalTest_CategoryType","OrderStatus","Tag","TestEnum",] )) diff --git a/integ-tests/python/baml_client/types.py b/integ-tests/python/baml_client/types.py index d25a9cae5..6a68f8efd 100644 --- a/integ-tests/python/baml_client/types.py +++ b/integ-tests/python/baml_client/types.py @@ -16,7 +16,7 @@ import baml_py from enum import Enum from pydantic import BaseModel, ConfigDict -from typing import Dict, Generic, List, Literal, Optional, TypeVar, Union +from typing import Dict, Generic, List, Literal, Optional, TypeVar, Union, TypeAlias T = TypeVar('T') @@ -189,6 +189,9 @@ class ClassOptionalOutput2(BaseModel): prop2: Optional[str] = None prop3: Optional["Blah"] = None +class ClassToRecAlias(BaseModel): + list: "LinkedListAliasNode" + class ClassWithImage(BaseModel): myImage: baml_py.Image param2: str @@ -298,6 +301,10 @@ class LinkedList(BaseModel): head: Optional["Node"] = None len: int +class LinkedListAliasNode(BaseModel): + value: int + next: Optional["LinkedListAliasNode"] = None + class LiteralClassHello(BaseModel): prop: Literal["hello"] @@ -320,6 +327,9 @@ class Martian(BaseModel): """The age of the Martian in Mars years. So many Mars years.""" +class MergeAttrs(BaseModel): + amount: Checked[int,Literal["gt_ten"]] + class NamedArgsSingleClass(BaseModel): key: str key_two: bool @@ -344,6 +354,10 @@ class Node(BaseModel): data: int next: Optional["Node"] = None +class NodeWithAliasIndirection(BaseModel): + value: int + next: Optional["NodeWithAliasIndirection"] = None + class OptionalListAndMap(BaseModel): p: Optional[List[str]] = None q: Optional[Dict[str, str]] = None @@ -468,3 +482,19 @@ class UnionTest_ReturnType(BaseModel): class WithReasoning(BaseModel): value: str reasoning: str + +RecursiveMapAlias: TypeAlias = Dict[str, "RecursiveMapAlias"] + +RecursiveListAlias: TypeAlias = List["RecursiveListAlias"] + +RecAliasOne: TypeAlias = "RecAliasTwo" + +RecAliasTwo: TypeAlias = "RecAliasThree" + +RecAliasThree: TypeAlias = List["RecAliasOne"] + +JsonValue: TypeAlias = Union[int, str, bool, float, "JsonObject", "JsonArray"] + +JsonObject: TypeAlias = Dict[str, "JsonValue"] + +JsonArray: TypeAlias = List["JsonValue"] diff --git a/integ-tests/python/tests/test_functions.py b/integ-tests/python/tests/test_functions.py index b54076d9d..bf12b8674 100644 --- a/integ-tests/python/tests/test_functions.py +++ b/integ-tests/python/tests/test_functions.py @@ -40,7 +40,14 @@ BlockConstraintForParam, NestedBlockConstraintForParam, MapKey, +<<<<<<< HEAD + LinkedListAliasNode, + ClassToRecAlias, + NodeWithAliasIndirection, + MergeAttrs, +======= OptionalListAndMap, +>>>>>>> canary ) import baml_client.types as types from ..baml_client.tracing import trace, set_tags, flush, on_log_event @@ -258,6 +265,108 @@ async def test_single_literal_string_key_in_map(self): res = await b.InOutSingleLiteralStringMapKey({"key": "1"}) assert res["key"] == "1" + @pytest.mark.asyncio + async def test_primitive_union_alias(self): + res = await b.PrimitiveAlias("test") + assert res == "test" + + @pytest.mark.asyncio + async def test_map_alias(self): + res = await b.MapAlias({"A": ["B", "C"], "B": [], "C": []}) + assert res == {"A": ["B", "C"], "B": [], "C": []} + + @pytest.mark.asyncio + async def test_alias_union(self): + res = await b.NestedAlias("test") + assert res == "test" + + res = await b.NestedAlias({"A": ["B", "C"], "B": [], "C": []}) + assert res == {"A": ["B", "C"], "B": [], "C": []} + + @pytest.mark.asyncio + async def test_alias_pointing_to_recursive_class(self): + res = await b.AliasThatPointsToRecursiveType( + LinkedListAliasNode(value=1, next=None) + ) + assert res == LinkedListAliasNode(value=1, next=None) + + @pytest.mark.asyncio + async def test_class_pointing_to_alias_that_points_to_recursive_class(self): + res = await b.ClassThatPointsToRecursiveClassThroughAlias( + ClassToRecAlias(list=LinkedListAliasNode(value=1, next=None)) + ) + assert res == ClassToRecAlias(list=LinkedListAliasNode(value=1, next=None)) + + @pytest.mark.asyncio + async def test_recursive_class_with_alias_indirection(self): + res = await b.RecursiveClassWithAliasIndirection( + NodeWithAliasIndirection( + value=1, next=NodeWithAliasIndirection(value=2, next=None) + ) + ) + assert res == NodeWithAliasIndirection( + value=1, next=NodeWithAliasIndirection(value=2, next=None) + ) + + @pytest.mark.asyncio + async def test_merge_alias_attributes(self): + res = await b.MergeAliasAttributes(123) + assert res.amount.value == 123 + assert res.amount.checks["gt_ten"].status == "succeeded" + + @pytest.mark.asyncio + async def test_return_alias_with_merged_attrs(self): + res = await b.ReturnAliasWithMergedAttributes(123) + assert res.value == 123 + assert res.checks["gt_ten"].status == "succeeded" + + @pytest.mark.asyncio + async def test_alias_with_multiple_attrs(self): + res = await b.AliasWithMultipleAttrs(123) + assert res.value == 123 + assert res.checks["gt_ten"].status == "succeeded" + + @pytest.mark.asyncio + async def test_simple_recursive_map_alias(self): + res = await b.SimpleRecursiveMapAlias({"one": {"two": {"three": {}}}}) + assert res == {"one": {"two": {"three": {}}}} + + @pytest.mark.asyncio + async def test_simple_recursive_list_alias(self): + res = await b.SimpleRecursiveListAlias([[], [], [[]]]) + assert res == [[], [], [[]]] + + @pytest.mark.asyncio + async def test_recursive_alias_cycles(self): + res = await b.RecursiveAliasCycle([[], [], [[]]]) + assert res == [[], [], [[]]] + + @pytest.mark.asyncio + async def test_json_type_alias_cycle(self): + data = { + "number": 1, + "string": "test", + "bool": True, + "list": [1, 2, 3], + "object": {"number": 1, "string": "test", "bool": True, "list": [1, 2, 3]}, + "json": { + "number": 1, + "string": "test", + "bool": True, + "list": [1, 2, 3], + "object": { + "number": 1, + "string": "test", + "bool": True, + "list": [1, 2, 3], + }, + }, + } + + res = await b.JsonTypeAliasCycle(data) + assert res == data + assert res["json"]["object"]["list"] == [1, 2, 3] + class MyCustomClass(NamedArgsSingleClass): date: datetime.datetime diff --git a/integ-tests/ruby/baml_client/client.rb b/integ-tests/ruby/baml_client/client.rb index 837931984..b59ac567e 100644 --- a/integ-tests/ruby/baml_client/client.rb +++ b/integ-tests/ruby/baml_client/client.rb @@ -82,6 +82,70 @@ def AaaSamOutputFormat( (raw.parsed_using_types(Baml::Types)) end + sig { + params( + varargs: T.untyped, + list: Baml::Types::LinkedListAliasNode, + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(Baml::Types::LinkedListAliasNode) + } + def AliasThatPointsToRecursiveType( + *varargs, + list:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("AliasThatPointsToRecursiveType may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? + raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.call_function( + "AliasThatPointsToRecursiveType", + { + list: list, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + (raw.parsed_using_types(Baml::Types)) + end + + sig { + params( + varargs: T.untyped, + money: Baml::Checked[Integer], + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(Baml::Checked[Integer]) + } + def AliasWithMultipleAttrs( + *varargs, + money:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("AliasWithMultipleAttrs may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? + raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.call_function( + "AliasWithMultipleAttrs", + { + money: money, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + (raw.parsed_using_types(Baml::Types)) + end + sig { params( varargs: T.untyped, @@ -370,6 +434,38 @@ def BuildTree( (raw.parsed_using_types(Baml::Types)) end + sig { + params( + varargs: T.untyped, + cls: Baml::Types::ClassToRecAlias, + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(Baml::Types::ClassToRecAlias) + } + def ClassThatPointsToRecursiveClassThroughAlias( + *varargs, + cls:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("ClassThatPointsToRecursiveClassThroughAlias may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? + raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.call_function( + "ClassThatPointsToRecursiveClassThroughAlias", + { + cls: cls, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + (raw.parsed_using_types(Baml::Types)) + end + sig { params( varargs: T.untyped, @@ -1938,6 +2034,38 @@ def InOutSingleLiteralStringMapKey( (raw.parsed_using_types(Baml::Types)) end + sig { + params( + varargs: T.untyped, + input: T.anything, + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(T.anything) + } + def JsonTypeAliasCycle( + *varargs, + input:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("JsonTypeAliasCycle may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? + raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.call_function( + "JsonTypeAliasCycle", + { + input: input, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + (raw.parsed_using_types(Baml::Types)) + end + sig { params( varargs: T.untyped, @@ -2034,6 +2162,70 @@ def MakeNestedBlockConstraint( (raw.parsed_using_types(Baml::Types)) end + sig { + params( + varargs: T.untyped, + m: T::Hash[String, T::Array[String]], + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(T::Hash[String, T::Array[String]]) + } + def MapAlias( + *varargs, + m:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("MapAlias may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? + raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.call_function( + "MapAlias", + { + m: m, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + (raw.parsed_using_types(Baml::Types)) + end + + sig { + params( + varargs: T.untyped, + money: Integer, + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(Baml::Types::MergeAttrs) + } + def MergeAliasAttributes( + *varargs, + money:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("MergeAliasAttributes may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? + raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.call_function( + "MergeAliasAttributes", + { + money: money, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + (raw.parsed_using_types(Baml::Types)) + end + sig { params( varargs: T.untyped, @@ -2066,6 +2258,38 @@ def MyFunc( (raw.parsed_using_types(Baml::Types)) end + sig { + params( + varargs: T.untyped, + c: T.any(T.any(Integer, String, T::Boolean, Float), T::Array[String], T::Hash[String, T::Array[String]]), + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(T.any(T.any(Integer, String, T::Boolean, Float), T::Array[String], T::Hash[String, T::Array[String]])) + } + def NestedAlias( + *varargs, + c:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("NestedAlias may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? + raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.call_function( + "NestedAlias", + { + c: c, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + (raw.parsed_using_types(Baml::Types)) + end + sig { params( varargs: T.untyped, @@ -2162,6 +2386,38 @@ def PredictAgeBare( (raw.parsed_using_types(Baml::Types)) end + sig { + params( + varargs: T.untyped, + p: T.any(Integer, String, T::Boolean, Float), + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(T.any(Integer, String, T::Boolean, Float)) + } + def PrimitiveAlias( + *varargs, + p:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("PrimitiveAlias may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? + raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.call_function( + "PrimitiveAlias", + { + p: p, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + (raw.parsed_using_types(Baml::Types)) + end + sig { params( varargs: T.untyped, @@ -2389,27 +2645,27 @@ def PromptTestStreaming( sig { params( varargs: T.untyped, - inp: Integer, + input: T.anything, baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] - ).returns(Integer) + ).returns(T.anything) } - def ReturnFailingAssert( + def RecursiveAliasCycle( *varargs, - inp:, + input:, baml_options: {} ) if varargs.any? - raise ArgumentError.new("ReturnFailingAssert may only be called with keyword arguments") + raise ArgumentError.new("RecursiveAliasCycle may only be called with keyword arguments") end if (baml_options.keys - [:client_registry, :tb]).any? raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") end raw = @runtime.call_function( - "ReturnFailingAssert", + "RecursiveAliasCycle", { - inp: inp, + input: input, }, @ctx_manager, baml_options[:tb]&.instance_variable_get(:@registry), @@ -2421,27 +2677,27 @@ def ReturnFailingAssert( sig { params( varargs: T.untyped, - a: Integer, + cls: Baml::Types::NodeWithAliasIndirection, baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] - ).returns(Baml::Types::MalformedConstraints) + ).returns(Baml::Types::NodeWithAliasIndirection) } - def ReturnMalformedConstraints( + def RecursiveClassWithAliasIndirection( *varargs, - a:, + cls:, baml_options: {} ) if varargs.any? - raise ArgumentError.new("ReturnMalformedConstraints may only be called with keyword arguments") + raise ArgumentError.new("RecursiveClassWithAliasIndirection may only be called with keyword arguments") end if (baml_options.keys - [:client_registry, :tb]).any? raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") end raw = @runtime.call_function( - "ReturnMalformedConstraints", + "RecursiveClassWithAliasIndirection", { - a: a, + cls: cls, }, @ctx_manager, baml_options[:tb]&.instance_variable_get(:@registry), @@ -2453,27 +2709,27 @@ def ReturnMalformedConstraints( sig { params( varargs: T.untyped, - input: String, + money: Baml::Checked[Integer], baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] - ).returns(Baml::Types::Schema) + ).returns(Baml::Checked[Integer]) } - def SchemaDescriptions( + def ReturnAliasWithMergedAttributes( *varargs, - input:, + money:, baml_options: {} ) if varargs.any? - raise ArgumentError.new("SchemaDescriptions may only be called with keyword arguments") + raise ArgumentError.new("ReturnAliasWithMergedAttributes may only be called with keyword arguments") end if (baml_options.keys - [:client_registry, :tb]).any? raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") end raw = @runtime.call_function( - "SchemaDescriptions", + "ReturnAliasWithMergedAttributes", { - input: input, + money: money, }, @ctx_manager, baml_options[:tb]&.instance_variable_get(:@registry), @@ -2485,27 +2741,27 @@ def SchemaDescriptions( sig { params( varargs: T.untyped, - digits: Integer, + inp: Integer, baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] - ).returns(Baml::Types::BigNumbers) + ).returns(Integer) } - def StreamBigNumbers( + def ReturnFailingAssert( *varargs, - digits:, + inp:, baml_options: {} ) if varargs.any? - raise ArgumentError.new("StreamBigNumbers may only be called with keyword arguments") + raise ArgumentError.new("ReturnFailingAssert may only be called with keyword arguments") end if (baml_options.keys - [:client_registry, :tb]).any? raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") end raw = @runtime.call_function( - "StreamBigNumbers", + "ReturnFailingAssert", { - digits: digits, + inp: inp, }, @ctx_manager, baml_options[:tb]&.instance_variable_get(:@registry), @@ -2517,27 +2773,187 @@ def StreamBigNumbers( sig { params( varargs: T.untyped, - theme: String,length: Integer, + a: Integer, baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] - ).returns(Baml::Types::TwoStoriesOneTitle) + ).returns(Baml::Types::MalformedConstraints) } - def StreamFailingAssertion( + def ReturnMalformedConstraints( *varargs, - theme:,length:, + a:, baml_options: {} ) if varargs.any? - raise ArgumentError.new("StreamFailingAssertion may only be called with keyword arguments") + raise ArgumentError.new("ReturnMalformedConstraints may only be called with keyword arguments") end if (baml_options.keys - [:client_registry, :tb]).any? raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") end raw = @runtime.call_function( - "StreamFailingAssertion", + "ReturnMalformedConstraints", { - theme: theme,length: length, + a: a, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + (raw.parsed_using_types(Baml::Types)) + end + + sig { + params( + varargs: T.untyped, + input: String, + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(Baml::Types::Schema) + } + def SchemaDescriptions( + *varargs, + input:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("SchemaDescriptions may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? + raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.call_function( + "SchemaDescriptions", + { + input: input, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + (raw.parsed_using_types(Baml::Types)) + end + + sig { + params( + varargs: T.untyped, + input: T.anything, + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(T.anything) + } + def SimpleRecursiveListAlias( + *varargs, + input:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("SimpleRecursiveListAlias may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? + raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.call_function( + "SimpleRecursiveListAlias", + { + input: input, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + (raw.parsed_using_types(Baml::Types)) + end + + sig { + params( + varargs: T.untyped, + input: T.anything, + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(T.anything) + } + def SimpleRecursiveMapAlias( + *varargs, + input:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("SimpleRecursiveMapAlias may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? + raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.call_function( + "SimpleRecursiveMapAlias", + { + input: input, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + (raw.parsed_using_types(Baml::Types)) + end + + sig { + params( + varargs: T.untyped, + digits: Integer, + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(Baml::Types::BigNumbers) + } + def StreamBigNumbers( + *varargs, + digits:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("StreamBigNumbers may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? + raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.call_function( + "StreamBigNumbers", + { + digits: digits, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + (raw.parsed_using_types(Baml::Types)) + end + + sig { + params( + varargs: T.untyped, + theme: String,length: Integer, + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(Baml::Types::TwoStoriesOneTitle) + } + def StreamFailingAssertion( + *varargs, + theme:,length:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("StreamFailingAssertion may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? + raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.call_function( + "StreamFailingAssertion", + { + theme: theme,length: length, }, @ctx_manager, baml_options[:tb]&.instance_variable_get(:@registry), @@ -3938,6 +4354,76 @@ def AaaSamOutputFormat( ) end + sig { + params( + varargs: T.untyped, + list: Baml::Types::LinkedListAliasNode, + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(Baml::BamlStream[Baml::Types::LinkedListAliasNode]) + } + def AliasThatPointsToRecursiveType( + *varargs, + list:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("AliasThatPointsToRecursiveType may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? + raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.stream_function( + "AliasThatPointsToRecursiveType", + { + list: list, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + Baml::BamlStream[Baml::PartialTypes::LinkedListAliasNode, Baml::Types::LinkedListAliasNode].new( + ffi_stream: raw, + ctx_manager: @ctx_manager + ) + end + + sig { + params( + varargs: T.untyped, + money: Baml::Checked[Integer], + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(Baml::BamlStream[Baml::Checked[Integer]]) + } + def AliasWithMultipleAttrs( + *varargs, + money:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("AliasWithMultipleAttrs may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? + raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.stream_function( + "AliasWithMultipleAttrs", + { + money: money, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + Baml::BamlStream[Baml::Checked[T.nilable(Integer)], Baml::Checked[Integer]].new( + ffi_stream: raw, + ctx_manager: @ctx_manager + ) + end + sig { params( varargs: T.untyped, @@ -4253,6 +4739,41 @@ def BuildTree( ) end + sig { + params( + varargs: T.untyped, + cls: Baml::Types::ClassToRecAlias, + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(Baml::BamlStream[Baml::Types::ClassToRecAlias]) + } + def ClassThatPointsToRecursiveClassThroughAlias( + *varargs, + cls:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("ClassThatPointsToRecursiveClassThroughAlias may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? + raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.stream_function( + "ClassThatPointsToRecursiveClassThroughAlias", + { + cls: cls, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + Baml::BamlStream[Baml::PartialTypes::ClassToRecAlias, Baml::Types::ClassToRecAlias].new( + ffi_stream: raw, + ctx_manager: @ctx_manager + ) + end + sig { params( varargs: T.untyped, @@ -5971,25 +6492,25 @@ def InOutSingleLiteralStringMapKey( sig { params( varargs: T.untyped, - input: String, + input: T.anything, baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] - ).returns(Baml::BamlStream[T.any(Integer, T::Boolean, String)]) + ).returns(Baml::BamlStream[T.anything]) } - def LiteralUnionsTest( + def JsonTypeAliasCycle( *varargs, input:, baml_options: {} ) if varargs.any? - raise ArgumentError.new("LiteralUnionsTest may only be called with keyword arguments") + raise ArgumentError.new("JsonTypeAliasCycle may only be called with keyword arguments") end if (baml_options.keys - [:client_registry, :tb]).any? raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") end raw = @runtime.stream_function( - "LiteralUnionsTest", + "JsonTypeAliasCycle", { input: input, }, @@ -5997,7 +6518,7 @@ def LiteralUnionsTest( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[T.nilable(T.any(T.nilable(Integer), T.nilable(T::Boolean), T.nilable(String))), T.any(Integer, T::Boolean, String)].new( + Baml::BamlStream[T.anything, T.anything].new( ffi_stream: raw, ctx_manager: @ctx_manager ) @@ -6006,33 +6527,68 @@ def LiteralUnionsTest( sig { params( varargs: T.untyped, - + input: String, baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] - ).returns(Baml::BamlStream[Baml::Checked[Baml::Types::BlockConstraint]]) + ).returns(Baml::BamlStream[T.any(Integer, T::Boolean, String)]) } - def MakeBlockConstraint( + def LiteralUnionsTest( *varargs, - + input:, baml_options: {} ) if varargs.any? - raise ArgumentError.new("MakeBlockConstraint may only be called with keyword arguments") + raise ArgumentError.new("LiteralUnionsTest may only be called with keyword arguments") end if (baml_options.keys - [:client_registry, :tb]).any? raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") end raw = @runtime.stream_function( - "MakeBlockConstraint", + "LiteralUnionsTest", { - + input: input, }, @ctx_manager, baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[Baml::Checked[Baml::PartialTypes::BlockConstraint], Baml::Checked[Baml::Types::BlockConstraint]].new( + Baml::BamlStream[T.nilable(T.any(T.nilable(Integer), T.nilable(T::Boolean), T.nilable(String))), T.any(Integer, T::Boolean, String)].new( + ffi_stream: raw, + ctx_manager: @ctx_manager + ) + end + + sig { + params( + varargs: T.untyped, + + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(Baml::BamlStream[Baml::Checked[Baml::Types::BlockConstraint]]) + } + def MakeBlockConstraint( + *varargs, + + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("MakeBlockConstraint may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? + raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.stream_function( + "MakeBlockConstraint", + { + + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + Baml::BamlStream[Baml::Checked[Baml::PartialTypes::BlockConstraint], Baml::Checked[Baml::Types::BlockConstraint]].new( ffi_stream: raw, ctx_manager: @ctx_manager ) @@ -6073,6 +6629,76 @@ def MakeNestedBlockConstraint( ) end + sig { + params( + varargs: T.untyped, + m: T::Hash[String, T::Array[String]], + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(Baml::BamlStream[T::Hash[String, T::Array[String]]]) + } + def MapAlias( + *varargs, + m:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("MapAlias may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? + raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.stream_function( + "MapAlias", + { + m: m, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + Baml::BamlStream[T::Hash[String, T::Array[T.nilable(String)]], T::Hash[String, T::Array[String]]].new( + ffi_stream: raw, + ctx_manager: @ctx_manager + ) + end + + sig { + params( + varargs: T.untyped, + money: Integer, + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(Baml::BamlStream[Baml::Types::MergeAttrs]) + } + def MergeAliasAttributes( + *varargs, + money:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("MergeAliasAttributes may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? + raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.stream_function( + "MergeAliasAttributes", + { + money: money, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + Baml::BamlStream[Baml::PartialTypes::MergeAttrs, Baml::Types::MergeAttrs].new( + ffi_stream: raw, + ctx_manager: @ctx_manager + ) + end + sig { params( varargs: T.untyped, @@ -6108,6 +6734,41 @@ def MyFunc( ) end + sig { + params( + varargs: T.untyped, + c: T.any(T.any(Integer, String, T::Boolean, Float), T::Array[String], T::Hash[String, T::Array[String]]), + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(Baml::BamlStream[T.any(T.any(Integer, String, T::Boolean, Float), T::Array[String], T::Hash[String, T::Array[String]])]) + } + def NestedAlias( + *varargs, + c:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("NestedAlias may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? + raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.stream_function( + "NestedAlias", + { + c: c, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + Baml::BamlStream[T.nilable(T.any(T.nilable(T.any(T.nilable(Integer), T.nilable(String), T.nilable(T::Boolean), T.nilable(Float))), T::Array[T.nilable(String)], T::Hash[String, T::Array[T.nilable(String)]])), T.any(T.any(Integer, String, T::Boolean, Float), T::Array[String], T::Hash[String, T::Array[String]])].new( + ffi_stream: raw, + ctx_manager: @ctx_manager + ) + end + sig { params( varargs: T.untyped, @@ -6213,6 +6874,41 @@ def PredictAgeBare( ) end + sig { + params( + varargs: T.untyped, + p: T.any(Integer, String, T::Boolean, Float), + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(Baml::BamlStream[T.any(Integer, String, T::Boolean, Float)]) + } + def PrimitiveAlias( + *varargs, + p:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("PrimitiveAlias may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? + raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.stream_function( + "PrimitiveAlias", + { + p: p, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + Baml::BamlStream[T.nilable(T.any(T.nilable(Integer), T.nilable(String), T.nilable(T::Boolean), T.nilable(Float))), T.any(Integer, String, T::Boolean, Float)].new( + ffi_stream: raw, + ctx_manager: @ctx_manager + ) + end + sig { params( varargs: T.untyped, @@ -6458,6 +7154,111 @@ def PromptTestStreaming( ) end + sig { + params( + varargs: T.untyped, + input: T.anything, + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(Baml::BamlStream[T.anything]) + } + def RecursiveAliasCycle( + *varargs, + input:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("RecursiveAliasCycle may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? + raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.stream_function( + "RecursiveAliasCycle", + { + input: input, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + Baml::BamlStream[T.anything, T.anything].new( + ffi_stream: raw, + ctx_manager: @ctx_manager + ) + end + + sig { + params( + varargs: T.untyped, + cls: Baml::Types::NodeWithAliasIndirection, + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(Baml::BamlStream[Baml::Types::NodeWithAliasIndirection]) + } + def RecursiveClassWithAliasIndirection( + *varargs, + cls:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("RecursiveClassWithAliasIndirection may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? + raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.stream_function( + "RecursiveClassWithAliasIndirection", + { + cls: cls, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + Baml::BamlStream[Baml::PartialTypes::NodeWithAliasIndirection, Baml::Types::NodeWithAliasIndirection].new( + ffi_stream: raw, + ctx_manager: @ctx_manager + ) + end + + sig { + params( + varargs: T.untyped, + money: Baml::Checked[Integer], + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(Baml::BamlStream[Baml::Checked[Integer]]) + } + def ReturnAliasWithMergedAttributes( + *varargs, + money:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("ReturnAliasWithMergedAttributes may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? + raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.stream_function( + "ReturnAliasWithMergedAttributes", + { + money: money, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + Baml::BamlStream[Baml::Checked[T.nilable(Integer)], Baml::Checked[Integer]].new( + ffi_stream: raw, + ctx_manager: @ctx_manager + ) + end + sig { params( varargs: T.untyped, @@ -6563,6 +7364,76 @@ def SchemaDescriptions( ) end + sig { + params( + varargs: T.untyped, + input: T.anything, + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(Baml::BamlStream[T.anything]) + } + def SimpleRecursiveListAlias( + *varargs, + input:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("SimpleRecursiveListAlias may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? + raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.stream_function( + "SimpleRecursiveListAlias", + { + input: input, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + Baml::BamlStream[T.anything, T.anything].new( + ffi_stream: raw, + ctx_manager: @ctx_manager + ) + end + + sig { + params( + varargs: T.untyped, + input: T.anything, + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(Baml::BamlStream[T.anything]) + } + def SimpleRecursiveMapAlias( + *varargs, + input:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("SimpleRecursiveMapAlias may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? + raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.stream_function( + "SimpleRecursiveMapAlias", + { + input: input, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + Baml::BamlStream[T.anything, T.anything].new( + ffi_stream: raw, + ctx_manager: @ctx_manager + ) + end + sig { params( varargs: T.untyped, diff --git a/integ-tests/ruby/baml_client/inlined.rb b/integ-tests/ruby/baml_client/inlined.rb index 9ce81fe73..bbcec4073 100644 --- a/integ-tests/ruby/baml_client/inlined.rb +++ b/integ-tests/ruby/baml_client/inlined.rb @@ -74,15 +74,17 @@ module Inlined "test-files/functions/output/literal-boolean.baml" => "function FnOutputLiteralBool(input: string) -> false {\n client GPT35\n prompt #\"\n Return a false: {{ ctx.output_format}}\n \"#\n}\n\ntest FnOutputLiteralBool {\n functions [FnOutputLiteralBool]\n args {\n input \"example input\"\n }\n}\n", "test-files/functions/output/literal-int.baml" => "function FnOutputLiteralInt(input: string) -> 5 {\n client GPT35\n prompt #\"\n Return an integer: {{ ctx.output_format}}\n \"#\n}\n\ntest FnOutputLiteralInt {\n functions [FnOutputLiteralInt]\n args {\n input \"example input\"\n }\n}\n", "test-files/functions/output/literal-string.baml" => "function FnOutputLiteralString(input: string) -> \"example output\" {\n client GPT35\n prompt #\"\n Return a string: {{ ctx.output_format}}\n \"#\n}\n\ntest FnOutputLiteralString {\n functions [FnOutputLiteralString]\n args {\n input \"example input\"\n }\n}\n", - "test-files/functions/output/literal-unions.baml" => "function LiteralUnionsTest(input: string) -> 1 | true | \"string output\" {\n client GPT35\n prompt #\"\n Return one of these values: \n {{ctx.output_format}}\n \"#\n}\n\ntest LiteralUnionsTest {\n functions [LiteralUnionsTest]\n args {\n input \"example input\"\n }\n}\n", + "test-files/functions/output/literal-unions.baml" => "function LiteralUnionsTest(input: string) -> 1 | true | \"string output\" {\n client GPT35\n prompt #\"\n Return one of these values without any additional context: \n {{ctx.output_format}}\n \"#\n}\n\ntest LiteralUnionsTest {\n functions [LiteralUnionsTest]\n args {\n input \"example input\"\n }\n}\n", "test-files/functions/output/map-enum-key.baml" => "enum MapKey {\n A\n B\n C\n}\n\nfunction InOutEnumMapKey(i1: map, i2: map) -> map {\n client \"openai/gpt-4o\"\n prompt #\"\n Merge these: {{i1}} {{i2}}\n\n {{ ctx.output_format }}\n \"#\n}\n", "test-files/functions/output/map-literal-union-key.baml" => "function InOutLiteralStringUnionMapKey(\n i1: map<\"one\" | \"two\" | (\"three\" | \"four\"), string>, \n i2: map<\"one\" | \"two\" | (\"three\" | \"four\"), string>\n) -> map<\"one\" | \"two\" | (\"three\" | \"four\"), string> {\n client \"openai/gpt-4o\"\n prompt #\"\n Merge these:\n \n {{i1}}\n \n {{i2}}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction InOutSingleLiteralStringMapKey(m: map<\"key\", string>) -> map<\"key\", string> {\n client \"openai/gpt-4o\"\n prompt #\"\n Return the same map you were given:\n \n {{m}}\n\n {{ ctx.output_format }}\n \"#\n}\n", "test-files/functions/output/mutually-recursive-classes.baml" => "class Tree {\n data int\n children Forest\n}\n\nclass Forest {\n trees Tree[]\n}\n\nclass BinaryNode {\n data int\n left BinaryNode?\n right BinaryNode?\n}\n\nfunction BuildTree(input: BinaryNode) -> Tree {\n client GPT35\n prompt #\"\n Given the input binary tree, transform it into a generic tree using the given schema.\n\n INPUT:\n {{ input }}\n\n {{ ctx.output_format }} \n \"#\n}\n\ntest TestTree {\n functions [BuildTree]\n args {\n input {\n data 2\n left {\n data 1\n left null\n right null\n }\n right {\n data 3\n left null\n right null\n }\n }\n }\n}", "test-files/functions/output/optional-class.baml" => "class ClassOptionalOutput {\n prop1 string\n prop2 string\n}\n\nfunction FnClassOptionalOutput(input: string) -> ClassOptionalOutput? {\n client GPT35\n prompt #\"\n Return a json blob for the following input:\n {{input}}\n\n {{ctx.output_format}}\n\n JSON:\n \"#\n}\n\n\nclass Blah {\n prop4 string?\n}\n\nclass ClassOptionalOutput2 {\n prop1 string?\n prop2 string?\n prop3 Blah?\n}\n\nfunction FnClassOptionalOutput2(input: string) -> ClassOptionalOutput2? {\n client GPT35\n prompt #\"\n Return a json blob for the following input:\n {{input}}\n\n {{ctx.output_format}}\n\n JSON:\n \"#\n}\n\ntest FnClassOptionalOutput2 {\n functions [FnClassOptionalOutput2, FnClassOptionalOutput]\n args {\n input \"example input\"\n }\n}\n", "test-files/functions/output/optional.baml" => "class OptionalTest_Prop1 {\n omega_a string\n omega_b int\n}\n\nenum OptionalTest_CategoryType {\n Aleph\n Beta\n Gamma\n}\n \nclass OptionalTest_ReturnType {\n omega_1 OptionalTest_Prop1?\n omega_2 string?\n omega_3 (OptionalTest_CategoryType?)[]\n} \n \nfunction OptionalTest_Function(input: string) -> (OptionalTest_ReturnType?)[]\n{ \n client GPT35\n prompt #\"\n Return a JSON blob with this schema: \n {{ctx.output_format}}\n\n JSON:\n \"#\n}\n\ntest OptionalTest_Function {\n functions [OptionalTest_Function]\n args {\n input \"example input\"\n }\n}\n", "test-files/functions/output/recursive-class.baml" => "class Node {\n data int\n next Node?\n}\n\nclass LinkedList {\n head Node?\n len int\n}\n\nclient O1 {\n provider \"openai\"\n options {\n model \"o1-mini\"\n default_role \"user\"\n }\n}\n\nfunction BuildLinkedList(input: int[]) -> LinkedList {\n client O1\n prompt #\"\n Build a linked list from the input array of integers.\n\n INPUT:\n {{ input }}\n\n {{ ctx.output_format }} \n \"#\n}\n\ntest TestLinkedList {\n functions [BuildLinkedList]\n args {\n input [1, 2, 3, 4, 5]\n }\n}\n", + "test-files/functions/output/recursive-type-aliases.baml" => "class LinkedListAliasNode {\n value int\n next LinkedListAliasNode?\n}\n\n// Simple alias that points to recursive type.\ntype LinkedListAlias = LinkedListAliasNode\n\nfunction AliasThatPointsToRecursiveType(list: LinkedListAlias) -> LinkedListAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given linked list back:\n \n {{ list }}\n \n {{ ctx.output_format }}\n \"#\n}\n\n// Class that points to an alias that points to a recursive type.\nclass ClassToRecAlias {\n list LinkedListAlias\n}\n\nfunction ClassThatPointsToRecursiveClassThroughAlias(cls: ClassToRecAlias) -> ClassToRecAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given object back:\n \n {{ cls }}\n \n {{ ctx.output_format }}\n \"#\n}\n\n// This is tricky cause this class should be hoisted, but classes and aliases\n// are two different types in the AST. This test will make sure they can interop.\nclass NodeWithAliasIndirection {\n value int\n next NodeIndirection?\n}\n\ntype NodeIndirection = NodeWithAliasIndirection\n\nfunction RecursiveClassWithAliasIndirection(cls: NodeWithAliasIndirection) -> NodeWithAliasIndirection {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given object back:\n \n {{ cls }}\n \n {{ ctx.output_format }}\n \"#\n}\n", "test-files/functions/output/serialization-error.baml" => "class DummyOutput {\n nonce string\n nonce2 string\n @@dynamic\n}\n\nfunction DummyOutputFunction(input: string) -> DummyOutput {\n client GPT35\n prompt #\"\n Say \"hello there\".\n \"#\n}", "test-files/functions/output/string-list.baml" => "function FnOutputStringList(input: string) -> string[] {\n client GPT35\n prompt #\"\n Return a list of strings in json format like [\"string1\", \"string2\", \"string3\"].\n\n JSON:\n \"#\n}\n\ntest FnOutputStringList {\n functions [FnOutputStringList]\n args {\n input \"example input\"\n }\n}\n", + "test-files/functions/output/type-aliases.baml" => "type Primitive = int | string | bool | float\n\ntype List = string[]\n\ntype Graph = map\n\ntype Combination = Primitive | List | Graph\n\nfunction PrimitiveAlias(p: Primitive) -> Primitive {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back: {{ p }}\n \"#\n}\n\nfunction MapAlias(m: Graph) -> Graph {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given Graph back:\n\n {{ m }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction NestedAlias(c: Combination) -> Combination {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back:\n\n {{ c }}\n\n {{ ctx.output_format }}\n \"#\n}\n\n// Test attribute merging.\ntype Currency = int @check(gt_ten, {{ this > 10 }})\ntype Amount = Currency @assert({{ this > 0 }})\n\nclass MergeAttrs {\n amount Amount @description(\"In USD\")\n}\n\n// This should be allowed.\ntype MultipleAttrs = int @assert({{ this > 0 }}) @check(gt_ten, {{ this > 10 }})\n\nfunction MergeAliasAttributes(money: int) -> MergeAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer in the specified format:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction ReturnAliasWithMergedAttributes(money: Amount) -> Amount {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction AliasWithMultipleAttrs(money: MultipleAttrs) -> MultipleAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecursiveMapAlias = map\n\nfunction SimpleRecursiveMapAlias(input: RecursiveMapAlias) -> RecursiveMapAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecursiveListAlias = RecursiveListAlias[]\n\nfunction SimpleRecursiveListAlias(input: RecursiveListAlias) -> RecursiveListAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given JSON array:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecAliasOne = RecAliasTwo\ntype RecAliasTwo = RecAliasThree\ntype RecAliasThree = RecAliasOne[]\n\nfunction RecursiveAliasCycle(input: RecAliasOne) -> RecAliasOne {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given JSON array:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype JsonValue = int | string | bool | float | JsonObject | JsonArray\ntype JsonObject = map\ntype JsonArray = JsonValue[]\n\nfunction JsonTypeAliasCycle(input: JsonValue) -> JsonValue {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given input back:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n", "test-files/functions/output/unions.baml" => "class UnionTest_ReturnType {\n prop1 string | bool\n prop2 (float | bool)[]\n prop3 (bool[] | int[])\n}\n\nfunction UnionTest_Function(input: string | bool) -> UnionTest_ReturnType {\n client GPT35\n prompt #\"\n Return a JSON blob with this schema: \n {{ctx.output_format}}\n\n JSON:\n \"#\n}\n\ntest UnionTest_Function {\n functions [UnionTest_Function]\n args {\n input \"example input\"\n }\n}\n", "test-files/functions/prompts/no-chat-messages.baml" => "\n\nfunction PromptTestClaude(input: string) -> string {\n client Sonnet\n prompt #\"\n Tell me a haiku about {{ input }}\n \"#\n}\n\n\nfunction PromptTestStreaming(input: string) -> string {\n client GPT35\n prompt #\"\n Tell me a short story about {{ input }}\n \"#\n}\n\ntest TestName {\n functions [PromptTestStreaming]\n args {\n input #\"\n hello world\n \"#\n }\n}\n", "test-files/functions/prompts/with-chat-messages.baml" => "\nfunction PromptTestOpenAIChat(input: string) -> string {\n client GPT35\n prompt #\"\n {{ _.role(\"system\") }}\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestOpenAIChatNoSystem(input: string) -> string {\n client GPT35\n prompt #\"\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestClaudeChat(input: string) -> string {\n client Claude\n prompt #\"\n {{ _.role(\"system\") }}\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestClaudeChatNoSystem(input: string) -> string {\n client Claude\n prompt #\"\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\ntest TestSystemAndNonSystemChat1 {\n functions [PromptTestClaude, PromptTestOpenAI, PromptTestOpenAIChat, PromptTestOpenAIChatNoSystem, PromptTestClaudeChat, PromptTestClaudeChatNoSystem]\n args {\n input \"cats\"\n }\n}\n\ntest TestSystemAndNonSystemChat2 {\n functions [PromptTestClaude, PromptTestOpenAI, PromptTestOpenAIChat, PromptTestOpenAIChatNoSystem, PromptTestClaudeChat, PromptTestClaudeChatNoSystem]\n args {\n input \"lion\"\n }\n}", diff --git a/integ-tests/ruby/baml_client/partial-types.rb b/integ-tests/ruby/baml_client/partial-types.rb index 478205c1e..438593d7f 100644 --- a/integ-tests/ruby/baml_client/partial-types.rb +++ b/integ-tests/ruby/baml_client/partial-types.rb @@ -28,6 +28,7 @@ class BlockConstraintForParam < T::Struct; end class BookOrder < T::Struct; end class ClassOptionalOutput < T::Struct; end class ClassOptionalOutput2 < T::Struct; end + class ClassToRecAlias < T::Struct; end class ClassWithImage < T::Struct; end class CompoundBigNumbers < T::Struct; end class ContactInfo < T::Struct; end @@ -52,18 +53,21 @@ class InnerClass2 < T::Struct; end class InputClass < T::Struct; end class InputClassNested < T::Struct; end class LinkedList < T::Struct; end + class LinkedListAliasNode < T::Struct; end class LiteralClassHello < T::Struct; end class LiteralClassOne < T::Struct; end class LiteralClassTwo < T::Struct; end class MalformedConstraints < T::Struct; end class MalformedConstraints2 < T::Struct; end class Martian < T::Struct; end + class MergeAttrs < T::Struct; end class NamedArgsSingleClass < T::Struct; end class Nested < T::Struct; end class Nested2 < T::Struct; end class NestedBlockConstraint < T::Struct; end class NestedBlockConstraintForParam < T::Struct; end class Node < T::Struct; end + class NodeWithAliasIndirection < T::Struct; end class OptionalListAndMap < T::Struct; end class OptionalTest_Prop1 < T::Struct; end class OptionalTest_ReturnType < T::Struct; end @@ -208,6 +212,18 @@ def initialize(props) @props = props end end + class ClassToRecAlias < T::Struct + include Baml::Sorbet::Struct + const :list, Baml::PartialTypes::LinkedListAliasNode + + def initialize(props) + super( + list: props[:list], + ) + + @props = props + end + end class ClassWithImage < T::Struct include Baml::Sorbet::Struct const :myImage, T.nilable(Baml::Image) @@ -560,6 +576,20 @@ def initialize(props) @props = props end end + class LinkedListAliasNode < T::Struct + include Baml::Sorbet::Struct + const :value, T.nilable(Integer) + const :next, Baml::PartialTypes::LinkedListAliasNode + + def initialize(props) + super( + value: props[:value], + next: props[:next], + ) + + @props = props + end + end class LiteralClassHello < T::Struct include Baml::Sorbet::Struct const :prop, T.nilable(String) @@ -636,6 +666,18 @@ def initialize(props) @props = props end end + class MergeAttrs < T::Struct + include Baml::Sorbet::Struct + const :amount, Baml::Checked[T.nilable(Integer)] + + def initialize(props) + super( + amount: props[:amount], + ) + + @props = props + end + end class NamedArgsSingleClass < T::Struct include Baml::Sorbet::Struct const :key, T.nilable(String) @@ -720,6 +762,20 @@ def initialize(props) @props = props end end + class NodeWithAliasIndirection < T::Struct + include Baml::Sorbet::Struct + const :value, T.nilable(Integer) + const :next, Baml::PartialTypes::NodeWithAliasIndirection + + def initialize(props) + super( + value: props[:value], + next: props[:next], + ) + + @props = props + end + end class OptionalListAndMap < T::Struct include Baml::Sorbet::Struct const :p, T::Array[T.nilable(String)] diff --git a/integ-tests/ruby/baml_client/type-registry.rb b/integ-tests/ruby/baml_client/type-registry.rb index 037563564..3095f25e9 100644 --- a/integ-tests/ruby/baml_client/type-registry.rb +++ b/integ-tests/ruby/baml_client/type-registry.rb @@ -18,7 +18,7 @@ module Baml class TypeBuilder def initialize @registry = Baml::Ffi::TypeBuilder.new - @classes = Set[ "BigNumbers", "BinaryNode", "Blah", "BlockConstraint", "BlockConstraintForParam", "BookOrder", "ClassOptionalOutput", "ClassOptionalOutput2", "ClassWithImage", "CompoundBigNumbers", "ContactInfo", "CustomTaskResult", "DummyOutput", "DynInputOutput", "DynamicClassOne", "DynamicClassTwo", "DynamicOutput", "Earthling", "Education", "Email", "EmailAddress", "Event", "FakeImage", "FlightConfirmation", "FooAny", "Forest", "GroceryReceipt", "InnerClass", "InnerClass2", "InputClass", "InputClassNested", "LinkedList", "LiteralClassHello", "LiteralClassOne", "LiteralClassTwo", "MalformedConstraints", "MalformedConstraints2", "Martian", "NamedArgsSingleClass", "Nested", "Nested2", "NestedBlockConstraint", "NestedBlockConstraintForParam", "Node", "OptionalListAndMap", "OptionalTest_Prop1", "OptionalTest_ReturnType", "OrderInfo", "OriginalA", "OriginalB", "Person", "PhoneNumber", "Quantity", "RaysData", "ReceiptInfo", "ReceiptItem", "Recipe", "Resume", "Schema", "SearchParams", "SomeClassNestedDynamic", "StringToClassEntry", "TestClassAlias", "TestClassNested", "TestClassWithEnum", "TestOutputClass", "Tree", "TwoStoriesOneTitle", "UnionTest_ReturnType", "WithReasoning", ] + @classes = Set[ "BigNumbers", "BinaryNode", "Blah", "BlockConstraint", "BlockConstraintForParam", "BookOrder", "ClassOptionalOutput", "ClassOptionalOutput2", "ClassToRecAlias", "ClassWithImage", "CompoundBigNumbers", "ContactInfo", "CustomTaskResult", "DummyOutput", "DynInputOutput", "DynamicClassOne", "DynamicClassTwo", "DynamicOutput", "Earthling", "Education", "Email", "EmailAddress", "Event", "FakeImage", "FlightConfirmation", "FooAny", "Forest", "GroceryReceipt", "InnerClass", "InnerClass2", "InputClass", "InputClassNested", "LinkedList", "LinkedListAliasNode", "LiteralClassHello", "LiteralClassOne", "LiteralClassTwo", "MalformedConstraints", "MalformedConstraints2", "Martian", "MergeAttrs", "NamedArgsSingleClass", "Nested", "Nested2", "NestedBlockConstraint", "NestedBlockConstraintForParam", "Node", "NodeWithAliasIndirection", "OptionalListAndMap", "OptionalTest_Prop1", "OptionalTest_ReturnType", "OrderInfo", "OriginalA", "OriginalB", "Person", "PhoneNumber", "Quantity", "RaysData", "ReceiptInfo", "ReceiptItem", "Recipe", "Resume", "Schema", "SearchParams", "SomeClassNestedDynamic", "StringToClassEntry", "TestClassAlias", "TestClassNested", "TestClassWithEnum", "TestOutputClass", "Tree", "TwoStoriesOneTitle", "UnionTest_ReturnType", "WithReasoning", ] @enums = Set[ "AliasedEnum", "Category", "Category2", "Category3", "Color", "DataType", "DynEnumOne", "DynEnumTwo", "EnumInClass", "EnumOutput", "Hobby", "MapKey", "NamedArgsSingleEnum", "NamedArgsSingleEnumList", "OptionalTest_CategoryType", "OrderStatus", "Tag", "TestEnum", ] end diff --git a/integ-tests/ruby/baml_client/types.rb b/integ-tests/ruby/baml_client/types.rb index 34f80c77e..320f35062 100644 --- a/integ-tests/ruby/baml_client/types.rb +++ b/integ-tests/ruby/baml_client/types.rb @@ -153,6 +153,7 @@ class BlockConstraintForParam < T::Struct; end class BookOrder < T::Struct; end class ClassOptionalOutput < T::Struct; end class ClassOptionalOutput2 < T::Struct; end + class ClassToRecAlias < T::Struct; end class ClassWithImage < T::Struct; end class CompoundBigNumbers < T::Struct; end class ContactInfo < T::Struct; end @@ -177,18 +178,21 @@ class InnerClass2 < T::Struct; end class InputClass < T::Struct; end class InputClassNested < T::Struct; end class LinkedList < T::Struct; end + class LinkedListAliasNode < T::Struct; end class LiteralClassHello < T::Struct; end class LiteralClassOne < T::Struct; end class LiteralClassTwo < T::Struct; end class MalformedConstraints < T::Struct; end class MalformedConstraints2 < T::Struct; end class Martian < T::Struct; end + class MergeAttrs < T::Struct; end class NamedArgsSingleClass < T::Struct; end class Nested < T::Struct; end class Nested2 < T::Struct; end class NestedBlockConstraint < T::Struct; end class NestedBlockConstraintForParam < T::Struct; end class Node < T::Struct; end + class NodeWithAliasIndirection < T::Struct; end class OptionalListAndMap < T::Struct; end class OptionalTest_Prop1 < T::Struct; end class OptionalTest_ReturnType < T::Struct; end @@ -333,6 +337,18 @@ def initialize(props) @props = props end end + class ClassToRecAlias < T::Struct + include Baml::Sorbet::Struct + const :list, Baml::Types::LinkedListAliasNode + + def initialize(props) + super( + list: props[:list], + ) + + @props = props + end + end class ClassWithImage < T::Struct include Baml::Sorbet::Struct const :myImage, Baml::Image @@ -685,6 +701,20 @@ def initialize(props) @props = props end end + class LinkedListAliasNode < T::Struct + include Baml::Sorbet::Struct + const :value, Integer + const :next, T.nilable(Baml::Types::LinkedListAliasNode) + + def initialize(props) + super( + value: props[:value], + next: props[:next], + ) + + @props = props + end + end class LiteralClassHello < T::Struct include Baml::Sorbet::Struct const :prop, String @@ -761,6 +791,18 @@ def initialize(props) @props = props end end + class MergeAttrs < T::Struct + include Baml::Sorbet::Struct + const :amount, Baml::Checked[Integer] + + def initialize(props) + super( + amount: props[:amount], + ) + + @props = props + end + end class NamedArgsSingleClass < T::Struct include Baml::Sorbet::Struct const :key, String @@ -845,6 +887,20 @@ def initialize(props) @props = props end end + class NodeWithAliasIndirection < T::Struct + include Baml::Sorbet::Struct + const :value, Integer + const :next, T.nilable(Baml::Types::NodeWithAliasIndirection) + + def initialize(props) + super( + value: props[:value], + next: props[:next], + ) + + @props = props + end + end class OptionalListAndMap < T::Struct include Baml::Sorbet::Struct const :p, T.nilable(T::Array[String]) diff --git a/integ-tests/ruby/test_functions.rb b/integ-tests/ruby/test_functions.rb index d12af5413..988ae9865 100644 --- a/integ-tests/ruby/test_functions.rb +++ b/integ-tests/ruby/test_functions.rb @@ -78,6 +78,47 @@ res = b.InOutSingleLiteralStringMapKey(m: {"key" => "1"}) assert_equal res['key'], "1" + + res = b.PrimitiveAlias(p: "test") + assert_equal res, "test" + + res = b.MapAlias(m: {"A" => ["B", "C"], "B" => [], "C" => []}) + assert_equal res, {"A" => ["B", "C"], "B" => [], "C" => []} + + res = b.NestedAlias(c: "test") + assert_equal res, "test" + + res = b.NestedAlias(c: {"A" => ["B", "C"], "B" => [], "C" => []}) + assert_equal res, {"A" => ["B", "C"], "B" => [], "C" => []} + + res = b.AliasThatPointsToRecursiveType(list: Baml::Types::LinkedListAliasNode.new( + value: 1, + next: nil, + )) + # TODO: Doesn't implement equality + # assert_equal res, Baml::Types::LinkedListAliasNode.new( + # value: 1, + # next: nil, + # ) + + res = b.ClassThatPointsToRecursiveClassThroughAlias( + cls: Baml::Types::ClassToRecAlias.new( + list: Baml::Types::LinkedListAliasNode.new( + value: 1, + next: nil, + ) + ) + ) + + res = b.RecursiveClassWithAliasIndirection.new( + cls: Baml::Types::NodeWithAliasIndirection.new( + value: 1, + next: Baml::Types::NodeWithAliasIndirection.new( + value: 2, + next: nil, + ) + ) + ) end it "optional map and list" do diff --git a/integ-tests/typescript/baml_client/async_client.ts b/integ-tests/typescript/baml_client/async_client.ts index b739410de..4cc0f0e7c 100644 --- a/integ-tests/typescript/baml_client/async_client.ts +++ b/integ-tests/typescript/baml_client/async_client.ts @@ -17,7 +17,7 @@ $ pnpm add @boundaryml/baml // biome-ignore format: autogenerated code import { BamlRuntime, FunctionResult, BamlCtxManager, BamlStream, Image, ClientRegistry, BamlValidationError, createBamlValidationError } from "@boundaryml/baml" import { Checked, Check } from "./types" -import {BigNumbers, BinaryNode, Blah, BlockConstraint, BlockConstraintForParam, BookOrder, ClassOptionalOutput, ClassOptionalOutput2, ClassWithImage, CompoundBigNumbers, ContactInfo, CustomTaskResult, DummyOutput, DynInputOutput, DynamicClassOne, DynamicClassTwo, DynamicOutput, Earthling, Education, Email, EmailAddress, Event, FakeImage, FlightConfirmation, FooAny, Forest, GroceryReceipt, InnerClass, InnerClass2, InputClass, InputClassNested, LinkedList, LiteralClassHello, LiteralClassOne, LiteralClassTwo, MalformedConstraints, MalformedConstraints2, Martian, NamedArgsSingleClass, Nested, Nested2, NestedBlockConstraint, NestedBlockConstraintForParam, Node, OptionalListAndMap, OptionalTest_Prop1, OptionalTest_ReturnType, OrderInfo, OriginalA, OriginalB, Person, PhoneNumber, Quantity, RaysData, ReceiptInfo, ReceiptItem, Recipe, Resume, Schema, SearchParams, SomeClassNestedDynamic, StringToClassEntry, TestClassAlias, TestClassNested, TestClassWithEnum, TestOutputClass, Tree, TwoStoriesOneTitle, UnionTest_ReturnType, WithReasoning, AliasedEnum, Category, Category2, Category3, Color, DataType, DynEnumOne, DynEnumTwo, EnumInClass, EnumOutput, Hobby, MapKey, NamedArgsSingleEnum, NamedArgsSingleEnumList, OptionalTest_CategoryType, OrderStatus, Tag, TestEnum} from "./types" +import {BigNumbers, BinaryNode, Blah, BlockConstraint, BlockConstraintForParam, BookOrder, ClassOptionalOutput, ClassOptionalOutput2, ClassToRecAlias, ClassWithImage, CompoundBigNumbers, ContactInfo, CustomTaskResult, DummyOutput, DynInputOutput, DynamicClassOne, DynamicClassTwo, DynamicOutput, Earthling, Education, Email, EmailAddress, Event, FakeImage, FlightConfirmation, FooAny, Forest, GroceryReceipt, InnerClass, InnerClass2, InputClass, InputClassNested, LinkedList, LinkedListAliasNode, LiteralClassHello, LiteralClassOne, LiteralClassTwo, MalformedConstraints, MalformedConstraints2, Martian, MergeAttrs, NamedArgsSingleClass, Nested, Nested2, NestedBlockConstraint, NestedBlockConstraintForParam, Node, NodeWithAliasIndirection, OptionalListAndMap, OptionalTest_Prop1, OptionalTest_ReturnType, OrderInfo, OriginalA, OriginalB, Person, PhoneNumber, Quantity, RaysData, ReceiptInfo, ReceiptItem, Recipe, Resume, Schema, SearchParams, SomeClassNestedDynamic, StringToClassEntry, TestClassAlias, TestClassNested, TestClassWithEnum, TestOutputClass, Tree, TwoStoriesOneTitle, UnionTest_ReturnType, WithReasoning, AliasedEnum, Category, Category2, Category3, Color, DataType, DynEnumOne, DynEnumTwo, EnumInClass, EnumOutput, Hobby, MapKey, NamedArgsSingleEnum, NamedArgsSingleEnumList, OptionalTest_CategoryType, OrderStatus, Tag, TestEnum} from "./types" import TypeBuilder from "./type_builder" import { DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX, DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_RUNTIME } from "./globals" @@ -68,6 +68,56 @@ export class BamlAsyncClient { } } + async AliasThatPointsToRecursiveType( + list: LinkedListAliasNode, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): Promise { + try { + const raw = await this.runtime.callFunction( + "AliasThatPointsToRecursiveType", + { + "list": list + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as LinkedListAliasNode + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + + async AliasWithMultipleAttrs( + money: Checked, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): Promise> { + try { + const raw = await this.runtime.callFunction( + "AliasWithMultipleAttrs", + { + "money": money + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as Checked + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + async AliasedInputClass( input: InputClass, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } @@ -293,6 +343,31 @@ export class BamlAsyncClient { } } + async ClassThatPointsToRecursiveClassThroughAlias( + cls: ClassToRecAlias, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): Promise { + try { + const raw = await this.runtime.callFunction( + "ClassThatPointsToRecursiveClassThroughAlias", + { + "cls": cls + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as ClassToRecAlias + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + async ClassifyDynEnumTwo( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } @@ -1518,6 +1593,31 @@ export class BamlAsyncClient { } } + async JsonTypeAliasCycle( + input: JsonValue, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): Promise { + try { + const raw = await this.runtime.callFunction( + "JsonTypeAliasCycle", + { + "input": input + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as JsonValue + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + async LiteralUnionsTest( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } @@ -1593,6 +1693,56 @@ export class BamlAsyncClient { } } + async MapAlias( + m: Record, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): Promise> { + try { + const raw = await this.runtime.callFunction( + "MapAlias", + { + "m": m + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as Record + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + + async MergeAliasAttributes( + money: number, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): Promise { + try { + const raw = await this.runtime.callFunction( + "MergeAliasAttributes", + { + "money": money + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as MergeAttrs + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + async MyFunc( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } @@ -1618,6 +1768,31 @@ export class BamlAsyncClient { } } + async NestedAlias( + c: number | string | boolean | number | string[] | Record, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): Promise> { + try { + const raw = await this.runtime.callFunction( + "NestedAlias", + { + "c": c + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as number | string | boolean | number | string[] | Record + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + async OptionalTest_Function( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } @@ -1693,6 +1868,31 @@ export class BamlAsyncClient { } } + async PrimitiveAlias( + p: number | string | boolean | number, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): Promise { + try { + const raw = await this.runtime.callFunction( + "PrimitiveAlias", + { + "p": p + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as number | string | boolean | number + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + async PromptTestClaude( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } @@ -1868,6 +2068,81 @@ export class BamlAsyncClient { } } + async RecursiveAliasCycle( + input: RecAliasOne, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): Promise { + try { + const raw = await this.runtime.callFunction( + "RecursiveAliasCycle", + { + "input": input + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as RecAliasOne + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + + async RecursiveClassWithAliasIndirection( + cls: NodeWithAliasIndirection, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): Promise { + try { + const raw = await this.runtime.callFunction( + "RecursiveClassWithAliasIndirection", + { + "cls": cls + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as NodeWithAliasIndirection + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + + async ReturnAliasWithMergedAttributes( + money: Checked, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): Promise> { + try { + const raw = await this.runtime.callFunction( + "ReturnAliasWithMergedAttributes", + { + "money": money + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as Checked + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + async ReturnFailingAssert( inp: number, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } @@ -1943,6 +2218,56 @@ export class BamlAsyncClient { } } + async SimpleRecursiveListAlias( + input: RecursiveListAlias, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): Promise { + try { + const raw = await this.runtime.callFunction( + "SimpleRecursiveListAlias", + { + "input": input + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as RecursiveListAlias + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + + async SimpleRecursiveMapAlias( + input: RecursiveMapAlias, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): Promise { + try { + const raw = await this.runtime.callFunction( + "SimpleRecursiveMapAlias", + { + "input": input + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as RecursiveMapAlias + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + async StreamBigNumbers( digits: number, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } @@ -3082,13 +3407,79 @@ class BamlStreamClient { } } - AliasedInputClass( - input: InputClass, + AliasThatPointsToRecursiveType( + list: LinkedListAliasNode, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream, LinkedListAliasNode> { try { const raw = this.runtime.streamFunction( - "AliasedInputClass", + "AliasThatPointsToRecursiveType", + { + "list": list + }, + undefined, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return new BamlStream, LinkedListAliasNode>( + raw, + (a): a is RecursivePartialNull => a, + (a): a is LinkedListAliasNode => a, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + ) + } catch (error) { + if (error instanceof Error) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } + } + throw error; + } + } + + AliasWithMultipleAttrs( + money: Checked, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): BamlStream>, Checked> { + try { + const raw = this.runtime.streamFunction( + "AliasWithMultipleAttrs", + { + "money": money + }, + undefined, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return new BamlStream>, Checked>( + raw, + (a): a is RecursivePartialNull> => a, + (a): a is Checked => a, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + ) + } catch (error) { + if (error instanceof Error) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } + } + throw error; + } + } + + AliasedInputClass( + input: InputClass, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): BamlStream, string> { + try { + const raw = this.runtime.streamFunction( + "AliasedInputClass", { "input": input }, @@ -3379,6 +3770,39 @@ class BamlStreamClient { } } + ClassThatPointsToRecursiveClassThroughAlias( + cls: ClassToRecAlias, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): BamlStream, ClassToRecAlias> { + try { + const raw = this.runtime.streamFunction( + "ClassThatPointsToRecursiveClassThroughAlias", + { + "cls": cls + }, + undefined, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return new BamlStream, ClassToRecAlias>( + raw, + (a): a is RecursivePartialNull => a, + (a): a is ClassToRecAlias => a, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + ) + } catch (error) { + if (error instanceof Error) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } + } + throw error; + } + } + ClassifyDynEnumTwo( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } @@ -4996,6 +5420,39 @@ class BamlStreamClient { } } + JsonTypeAliasCycle( + input: JsonValue, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): BamlStream, JsonValue> { + try { + const raw = this.runtime.streamFunction( + "JsonTypeAliasCycle", + { + "input": input + }, + undefined, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return new BamlStream, JsonValue>( + raw, + (a): a is RecursivePartialNull => a, + (a): a is JsonValue => a, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + ) + } catch (error) { + if (error instanceof Error) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } + } + throw error; + } + } + LiteralUnionsTest( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } @@ -5095,6 +5552,72 @@ class BamlStreamClient { } } + MapAlias( + m: Record, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): BamlStream>, Record> { + try { + const raw = this.runtime.streamFunction( + "MapAlias", + { + "m": m + }, + undefined, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return new BamlStream>, Record>( + raw, + (a): a is RecursivePartialNull> => a, + (a): a is Record => a, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + ) + } catch (error) { + if (error instanceof Error) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } + } + throw error; + } + } + + MergeAliasAttributes( + money: number, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): BamlStream, MergeAttrs> { + try { + const raw = this.runtime.streamFunction( + "MergeAliasAttributes", + { + "money": money + }, + undefined, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return new BamlStream, MergeAttrs>( + raw, + (a): a is RecursivePartialNull => a, + (a): a is MergeAttrs => a, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + ) + } catch (error) { + if (error instanceof Error) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } + } + throw error; + } + } + MyFunc( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } @@ -5128,6 +5651,39 @@ class BamlStreamClient { } } + NestedAlias( + c: number | string | boolean | number | string[] | Record, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): BamlStream>, number | string | boolean | number | string[] | Record> { + try { + const raw = this.runtime.streamFunction( + "NestedAlias", + { + "c": c + }, + undefined, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return new BamlStream>, number | string | boolean | number | string[] | Record>( + raw, + (a): a is RecursivePartialNull> => a, + (a): a is number | string | boolean | number | string[] | Record => a, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + ) + } catch (error) { + if (error instanceof Error) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } + } + throw error; + } + } + OptionalTest_Function( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } @@ -5227,6 +5783,39 @@ class BamlStreamClient { } } + PrimitiveAlias( + p: number | string | boolean | number, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): BamlStream, number | string | boolean | number> { + try { + const raw = this.runtime.streamFunction( + "PrimitiveAlias", + { + "p": p + }, + undefined, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return new BamlStream, number | string | boolean | number>( + raw, + (a): a is RecursivePartialNull => a, + (a): a is number | string | boolean | number => a, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + ) + } catch (error) { + if (error instanceof Error) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } + } + throw error; + } + } + PromptTestClaude( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } @@ -5458,6 +6047,105 @@ class BamlStreamClient { } } + RecursiveAliasCycle( + input: RecAliasOne, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): BamlStream, RecAliasOne> { + try { + const raw = this.runtime.streamFunction( + "RecursiveAliasCycle", + { + "input": input + }, + undefined, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return new BamlStream, RecAliasOne>( + raw, + (a): a is RecursivePartialNull => a, + (a): a is RecAliasOne => a, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + ) + } catch (error) { + if (error instanceof Error) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } + } + throw error; + } + } + + RecursiveClassWithAliasIndirection( + cls: NodeWithAliasIndirection, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): BamlStream, NodeWithAliasIndirection> { + try { + const raw = this.runtime.streamFunction( + "RecursiveClassWithAliasIndirection", + { + "cls": cls + }, + undefined, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return new BamlStream, NodeWithAliasIndirection>( + raw, + (a): a is RecursivePartialNull => a, + (a): a is NodeWithAliasIndirection => a, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + ) + } catch (error) { + if (error instanceof Error) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } + } + throw error; + } + } + + ReturnAliasWithMergedAttributes( + money: Checked, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): BamlStream>, Checked> { + try { + const raw = this.runtime.streamFunction( + "ReturnAliasWithMergedAttributes", + { + "money": money + }, + undefined, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return new BamlStream>, Checked>( + raw, + (a): a is RecursivePartialNull> => a, + (a): a is Checked => a, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + ) + } catch (error) { + if (error instanceof Error) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } + } + throw error; + } + } + ReturnFailingAssert( inp: number, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } @@ -5557,6 +6245,72 @@ class BamlStreamClient { } } + SimpleRecursiveListAlias( + input: RecursiveListAlias, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): BamlStream, RecursiveListAlias> { + try { + const raw = this.runtime.streamFunction( + "SimpleRecursiveListAlias", + { + "input": input + }, + undefined, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return new BamlStream, RecursiveListAlias>( + raw, + (a): a is RecursivePartialNull => a, + (a): a is RecursiveListAlias => a, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + ) + } catch (error) { + if (error instanceof Error) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } + } + throw error; + } + } + + SimpleRecursiveMapAlias( + input: RecursiveMapAlias, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): BamlStream, RecursiveMapAlias> { + try { + const raw = this.runtime.streamFunction( + "SimpleRecursiveMapAlias", + { + "input": input + }, + undefined, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return new BamlStream, RecursiveMapAlias>( + raw, + (a): a is RecursivePartialNull => a, + (a): a is RecursiveMapAlias => a, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + ) + } catch (error) { + if (error instanceof Error) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } + } + throw error; + } + } + StreamBigNumbers( digits: number, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } diff --git a/integ-tests/typescript/baml_client/inlinedbaml.ts b/integ-tests/typescript/baml_client/inlinedbaml.ts index a517f92b8..c55e9b5e8 100644 --- a/integ-tests/typescript/baml_client/inlinedbaml.ts +++ b/integ-tests/typescript/baml_client/inlinedbaml.ts @@ -75,15 +75,17 @@ const fileMap = { "test-files/functions/output/literal-boolean.baml": "function FnOutputLiteralBool(input: string) -> false {\n client GPT35\n prompt #\"\n Return a false: {{ ctx.output_format}}\n \"#\n}\n\ntest FnOutputLiteralBool {\n functions [FnOutputLiteralBool]\n args {\n input \"example input\"\n }\n}\n", "test-files/functions/output/literal-int.baml": "function FnOutputLiteralInt(input: string) -> 5 {\n client GPT35\n prompt #\"\n Return an integer: {{ ctx.output_format}}\n \"#\n}\n\ntest FnOutputLiteralInt {\n functions [FnOutputLiteralInt]\n args {\n input \"example input\"\n }\n}\n", "test-files/functions/output/literal-string.baml": "function FnOutputLiteralString(input: string) -> \"example output\" {\n client GPT35\n prompt #\"\n Return a string: {{ ctx.output_format}}\n \"#\n}\n\ntest FnOutputLiteralString {\n functions [FnOutputLiteralString]\n args {\n input \"example input\"\n }\n}\n", - "test-files/functions/output/literal-unions.baml": "function LiteralUnionsTest(input: string) -> 1 | true | \"string output\" {\n client GPT35\n prompt #\"\n Return one of these values: \n {{ctx.output_format}}\n \"#\n}\n\ntest LiteralUnionsTest {\n functions [LiteralUnionsTest]\n args {\n input \"example input\"\n }\n}\n", + "test-files/functions/output/literal-unions.baml": "function LiteralUnionsTest(input: string) -> 1 | true | \"string output\" {\n client GPT35\n prompt #\"\n Return one of these values without any additional context: \n {{ctx.output_format}}\n \"#\n}\n\ntest LiteralUnionsTest {\n functions [LiteralUnionsTest]\n args {\n input \"example input\"\n }\n}\n", "test-files/functions/output/map-enum-key.baml": "enum MapKey {\n A\n B\n C\n}\n\nfunction InOutEnumMapKey(i1: map, i2: map) -> map {\n client \"openai/gpt-4o\"\n prompt #\"\n Merge these: {{i1}} {{i2}}\n\n {{ ctx.output_format }}\n \"#\n}\n", "test-files/functions/output/map-literal-union-key.baml": "function InOutLiteralStringUnionMapKey(\n i1: map<\"one\" | \"two\" | (\"three\" | \"four\"), string>, \n i2: map<\"one\" | \"two\" | (\"three\" | \"four\"), string>\n) -> map<\"one\" | \"two\" | (\"three\" | \"four\"), string> {\n client \"openai/gpt-4o\"\n prompt #\"\n Merge these:\n \n {{i1}}\n \n {{i2}}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction InOutSingleLiteralStringMapKey(m: map<\"key\", string>) -> map<\"key\", string> {\n client \"openai/gpt-4o\"\n prompt #\"\n Return the same map you were given:\n \n {{m}}\n\n {{ ctx.output_format }}\n \"#\n}\n", "test-files/functions/output/mutually-recursive-classes.baml": "class Tree {\n data int\n children Forest\n}\n\nclass Forest {\n trees Tree[]\n}\n\nclass BinaryNode {\n data int\n left BinaryNode?\n right BinaryNode?\n}\n\nfunction BuildTree(input: BinaryNode) -> Tree {\n client GPT35\n prompt #\"\n Given the input binary tree, transform it into a generic tree using the given schema.\n\n INPUT:\n {{ input }}\n\n {{ ctx.output_format }} \n \"#\n}\n\ntest TestTree {\n functions [BuildTree]\n args {\n input {\n data 2\n left {\n data 1\n left null\n right null\n }\n right {\n data 3\n left null\n right null\n }\n }\n }\n}", "test-files/functions/output/optional-class.baml": "class ClassOptionalOutput {\n prop1 string\n prop2 string\n}\n\nfunction FnClassOptionalOutput(input: string) -> ClassOptionalOutput? {\n client GPT35\n prompt #\"\n Return a json blob for the following input:\n {{input}}\n\n {{ctx.output_format}}\n\n JSON:\n \"#\n}\n\n\nclass Blah {\n prop4 string?\n}\n\nclass ClassOptionalOutput2 {\n prop1 string?\n prop2 string?\n prop3 Blah?\n}\n\nfunction FnClassOptionalOutput2(input: string) -> ClassOptionalOutput2? {\n client GPT35\n prompt #\"\n Return a json blob for the following input:\n {{input}}\n\n {{ctx.output_format}}\n\n JSON:\n \"#\n}\n\ntest FnClassOptionalOutput2 {\n functions [FnClassOptionalOutput2, FnClassOptionalOutput]\n args {\n input \"example input\"\n }\n}\n", "test-files/functions/output/optional.baml": "class OptionalTest_Prop1 {\n omega_a string\n omega_b int\n}\n\nenum OptionalTest_CategoryType {\n Aleph\n Beta\n Gamma\n}\n \nclass OptionalTest_ReturnType {\n omega_1 OptionalTest_Prop1?\n omega_2 string?\n omega_3 (OptionalTest_CategoryType?)[]\n} \n \nfunction OptionalTest_Function(input: string) -> (OptionalTest_ReturnType?)[]\n{ \n client GPT35\n prompt #\"\n Return a JSON blob with this schema: \n {{ctx.output_format}}\n\n JSON:\n \"#\n}\n\ntest OptionalTest_Function {\n functions [OptionalTest_Function]\n args {\n input \"example input\"\n }\n}\n", "test-files/functions/output/recursive-class.baml": "class Node {\n data int\n next Node?\n}\n\nclass LinkedList {\n head Node?\n len int\n}\n\nclient O1 {\n provider \"openai\"\n options {\n model \"o1-mini\"\n default_role \"user\"\n }\n}\n\nfunction BuildLinkedList(input: int[]) -> LinkedList {\n client O1\n prompt #\"\n Build a linked list from the input array of integers.\n\n INPUT:\n {{ input }}\n\n {{ ctx.output_format }} \n \"#\n}\n\ntest TestLinkedList {\n functions [BuildLinkedList]\n args {\n input [1, 2, 3, 4, 5]\n }\n}\n", + "test-files/functions/output/recursive-type-aliases.baml": "class LinkedListAliasNode {\n value int\n next LinkedListAliasNode?\n}\n\n// Simple alias that points to recursive type.\ntype LinkedListAlias = LinkedListAliasNode\n\nfunction AliasThatPointsToRecursiveType(list: LinkedListAlias) -> LinkedListAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given linked list back:\n \n {{ list }}\n \n {{ ctx.output_format }}\n \"#\n}\n\n// Class that points to an alias that points to a recursive type.\nclass ClassToRecAlias {\n list LinkedListAlias\n}\n\nfunction ClassThatPointsToRecursiveClassThroughAlias(cls: ClassToRecAlias) -> ClassToRecAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given object back:\n \n {{ cls }}\n \n {{ ctx.output_format }}\n \"#\n}\n\n// This is tricky cause this class should be hoisted, but classes and aliases\n// are two different types in the AST. This test will make sure they can interop.\nclass NodeWithAliasIndirection {\n value int\n next NodeIndirection?\n}\n\ntype NodeIndirection = NodeWithAliasIndirection\n\nfunction RecursiveClassWithAliasIndirection(cls: NodeWithAliasIndirection) -> NodeWithAliasIndirection {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given object back:\n \n {{ cls }}\n \n {{ ctx.output_format }}\n \"#\n}\n", "test-files/functions/output/serialization-error.baml": "class DummyOutput {\n nonce string\n nonce2 string\n @@dynamic\n}\n\nfunction DummyOutputFunction(input: string) -> DummyOutput {\n client GPT35\n prompt #\"\n Say \"hello there\".\n \"#\n}", "test-files/functions/output/string-list.baml": "function FnOutputStringList(input: string) -> string[] {\n client GPT35\n prompt #\"\n Return a list of strings in json format like [\"string1\", \"string2\", \"string3\"].\n\n JSON:\n \"#\n}\n\ntest FnOutputStringList {\n functions [FnOutputStringList]\n args {\n input \"example input\"\n }\n}\n", + "test-files/functions/output/type-aliases.baml": "type Primitive = int | string | bool | float\n\ntype List = string[]\n\ntype Graph = map\n\ntype Combination = Primitive | List | Graph\n\nfunction PrimitiveAlias(p: Primitive) -> Primitive {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back: {{ p }}\n \"#\n}\n\nfunction MapAlias(m: Graph) -> Graph {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given Graph back:\n\n {{ m }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction NestedAlias(c: Combination) -> Combination {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back:\n\n {{ c }}\n\n {{ ctx.output_format }}\n \"#\n}\n\n// Test attribute merging.\ntype Currency = int @check(gt_ten, {{ this > 10 }})\ntype Amount = Currency @assert({{ this > 0 }})\n\nclass MergeAttrs {\n amount Amount @description(\"In USD\")\n}\n\n// This should be allowed.\ntype MultipleAttrs = int @assert({{ this > 0 }}) @check(gt_ten, {{ this > 10 }})\n\nfunction MergeAliasAttributes(money: int) -> MergeAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer in the specified format:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction ReturnAliasWithMergedAttributes(money: Amount) -> Amount {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction AliasWithMultipleAttrs(money: MultipleAttrs) -> MultipleAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecursiveMapAlias = map\n\nfunction SimpleRecursiveMapAlias(input: RecursiveMapAlias) -> RecursiveMapAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecursiveListAlias = RecursiveListAlias[]\n\nfunction SimpleRecursiveListAlias(input: RecursiveListAlias) -> RecursiveListAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given JSON array:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecAliasOne = RecAliasTwo\ntype RecAliasTwo = RecAliasThree\ntype RecAliasThree = RecAliasOne[]\n\nfunction RecursiveAliasCycle(input: RecAliasOne) -> RecAliasOne {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given JSON array:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype JsonValue = int | string | bool | float | JsonObject | JsonArray\ntype JsonObject = map\ntype JsonArray = JsonValue[]\n\nfunction JsonTypeAliasCycle(input: JsonValue) -> JsonValue {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given input back:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n", "test-files/functions/output/unions.baml": "class UnionTest_ReturnType {\n prop1 string | bool\n prop2 (float | bool)[]\n prop3 (bool[] | int[])\n}\n\nfunction UnionTest_Function(input: string | bool) -> UnionTest_ReturnType {\n client GPT35\n prompt #\"\n Return a JSON blob with this schema: \n {{ctx.output_format}}\n\n JSON:\n \"#\n}\n\ntest UnionTest_Function {\n functions [UnionTest_Function]\n args {\n input \"example input\"\n }\n}\n", "test-files/functions/prompts/no-chat-messages.baml": "\n\nfunction PromptTestClaude(input: string) -> string {\n client Sonnet\n prompt #\"\n Tell me a haiku about {{ input }}\n \"#\n}\n\n\nfunction PromptTestStreaming(input: string) -> string {\n client GPT35\n prompt #\"\n Tell me a short story about {{ input }}\n \"#\n}\n\ntest TestName {\n functions [PromptTestStreaming]\n args {\n input #\"\n hello world\n \"#\n }\n}\n", "test-files/functions/prompts/with-chat-messages.baml": "\nfunction PromptTestOpenAIChat(input: string) -> string {\n client GPT35\n prompt #\"\n {{ _.role(\"system\") }}\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestOpenAIChatNoSystem(input: string) -> string {\n client GPT35\n prompt #\"\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestClaudeChat(input: string) -> string {\n client Claude\n prompt #\"\n {{ _.role(\"system\") }}\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestClaudeChatNoSystem(input: string) -> string {\n client Claude\n prompt #\"\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\ntest TestSystemAndNonSystemChat1 {\n functions [PromptTestClaude, PromptTestOpenAI, PromptTestOpenAIChat, PromptTestOpenAIChatNoSystem, PromptTestClaudeChat, PromptTestClaudeChatNoSystem]\n args {\n input \"cats\"\n }\n}\n\ntest TestSystemAndNonSystemChat2 {\n functions [PromptTestClaude, PromptTestOpenAI, PromptTestOpenAIChat, PromptTestOpenAIChatNoSystem, PromptTestClaudeChat, PromptTestClaudeChatNoSystem]\n args {\n input \"lion\"\n }\n}", diff --git a/integ-tests/typescript/baml_client/sync_client.ts b/integ-tests/typescript/baml_client/sync_client.ts index de791b8f1..62da62795 100644 --- a/integ-tests/typescript/baml_client/sync_client.ts +++ b/integ-tests/typescript/baml_client/sync_client.ts @@ -17,7 +17,7 @@ $ pnpm add @boundaryml/baml // biome-ignore format: autogenerated code import { BamlRuntime, FunctionResult, BamlCtxManager, BamlSyncStream, Image, ClientRegistry, createBamlValidationError, BamlValidationError } from "@boundaryml/baml" import { Checked, Check } from "./types" -import {BigNumbers, BinaryNode, Blah, BlockConstraint, BlockConstraintForParam, BookOrder, ClassOptionalOutput, ClassOptionalOutput2, ClassWithImage, CompoundBigNumbers, ContactInfo, CustomTaskResult, DummyOutput, DynInputOutput, DynamicClassOne, DynamicClassTwo, DynamicOutput, Earthling, Education, Email, EmailAddress, Event, FakeImage, FlightConfirmation, FooAny, Forest, GroceryReceipt, InnerClass, InnerClass2, InputClass, InputClassNested, LinkedList, LiteralClassHello, LiteralClassOne, LiteralClassTwo, MalformedConstraints, MalformedConstraints2, Martian, NamedArgsSingleClass, Nested, Nested2, NestedBlockConstraint, NestedBlockConstraintForParam, Node, OptionalListAndMap, OptionalTest_Prop1, OptionalTest_ReturnType, OrderInfo, OriginalA, OriginalB, Person, PhoneNumber, Quantity, RaysData, ReceiptInfo, ReceiptItem, Recipe, Resume, Schema, SearchParams, SomeClassNestedDynamic, StringToClassEntry, TestClassAlias, TestClassNested, TestClassWithEnum, TestOutputClass, Tree, TwoStoriesOneTitle, UnionTest_ReturnType, WithReasoning, AliasedEnum, Category, Category2, Category3, Color, DataType, DynEnumOne, DynEnumTwo, EnumInClass, EnumOutput, Hobby, MapKey, NamedArgsSingleEnum, NamedArgsSingleEnumList, OptionalTest_CategoryType, OrderStatus, Tag, TestEnum} from "./types" +import {BigNumbers, BinaryNode, Blah, BlockConstraint, BlockConstraintForParam, BookOrder, ClassOptionalOutput, ClassOptionalOutput2, ClassToRecAlias, ClassWithImage, CompoundBigNumbers, ContactInfo, CustomTaskResult, DummyOutput, DynInputOutput, DynamicClassOne, DynamicClassTwo, DynamicOutput, Earthling, Education, Email, EmailAddress, Event, FakeImage, FlightConfirmation, FooAny, Forest, GroceryReceipt, InnerClass, InnerClass2, InputClass, InputClassNested, LinkedList, LinkedListAliasNode, LiteralClassHello, LiteralClassOne, LiteralClassTwo, MalformedConstraints, MalformedConstraints2, Martian, MergeAttrs, NamedArgsSingleClass, Nested, Nested2, NestedBlockConstraint, NestedBlockConstraintForParam, Node, NodeWithAliasIndirection, OptionalListAndMap, OptionalTest_Prop1, OptionalTest_ReturnType, OrderInfo, OriginalA, OriginalB, Person, PhoneNumber, Quantity, RaysData, ReceiptInfo, ReceiptItem, Recipe, Resume, Schema, SearchParams, SomeClassNestedDynamic, StringToClassEntry, TestClassAlias, TestClassNested, TestClassWithEnum, TestOutputClass, Tree, TwoStoriesOneTitle, UnionTest_ReturnType, WithReasoning, AliasedEnum, Category, Category2, Category3, Color, DataType, DynEnumOne, DynEnumTwo, EnumInClass, EnumOutput, Hobby, MapKey, NamedArgsSingleEnum, NamedArgsSingleEnumList, OptionalTest_CategoryType, OrderStatus, Tag, TestEnum} from "./types" import TypeBuilder from "./type_builder" import { DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX, DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_RUNTIME } from "./globals" @@ -68,6 +68,56 @@ export class BamlSyncClient { } } + AliasThatPointsToRecursiveType( + list: LinkedListAliasNode, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): LinkedListAliasNode { + try { + const raw = this.runtime.callFunctionSync( + "AliasThatPointsToRecursiveType", + { + "list": list + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as LinkedListAliasNode + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + + AliasWithMultipleAttrs( + money: Checked, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): Checked { + try { + const raw = this.runtime.callFunctionSync( + "AliasWithMultipleAttrs", + { + "money": money + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as Checked + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + AliasedInputClass( input: InputClass, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } @@ -293,6 +343,31 @@ export class BamlSyncClient { } } + ClassThatPointsToRecursiveClassThroughAlias( + cls: ClassToRecAlias, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): ClassToRecAlias { + try { + const raw = this.runtime.callFunctionSync( + "ClassThatPointsToRecursiveClassThroughAlias", + { + "cls": cls + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as ClassToRecAlias + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + ClassifyDynEnumTwo( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } @@ -1518,6 +1593,31 @@ export class BamlSyncClient { } } + JsonTypeAliasCycle( + input: JsonValue, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): JsonValue { + try { + const raw = this.runtime.callFunctionSync( + "JsonTypeAliasCycle", + { + "input": input + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as JsonValue + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + LiteralUnionsTest( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } @@ -1593,6 +1693,56 @@ export class BamlSyncClient { } } + MapAlias( + m: Record, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): Record { + try { + const raw = this.runtime.callFunctionSync( + "MapAlias", + { + "m": m + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as Record + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + + MergeAliasAttributes( + money: number, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): MergeAttrs { + try { + const raw = this.runtime.callFunctionSync( + "MergeAliasAttributes", + { + "money": money + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as MergeAttrs + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + MyFunc( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } @@ -1618,6 +1768,31 @@ export class BamlSyncClient { } } + NestedAlias( + c: number | string | boolean | number | string[] | Record, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): number | string | boolean | number | string[] | Record { + try { + const raw = this.runtime.callFunctionSync( + "NestedAlias", + { + "c": c + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as number | string | boolean | number | string[] | Record + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + OptionalTest_Function( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } @@ -1693,6 +1868,31 @@ export class BamlSyncClient { } } + PrimitiveAlias( + p: number | string | boolean | number, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): number | string | boolean | number { + try { + const raw = this.runtime.callFunctionSync( + "PrimitiveAlias", + { + "p": p + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as number | string | boolean | number + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + PromptTestClaude( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } @@ -1868,6 +2068,81 @@ export class BamlSyncClient { } } + RecursiveAliasCycle( + input: RecAliasOne, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): RecAliasOne { + try { + const raw = this.runtime.callFunctionSync( + "RecursiveAliasCycle", + { + "input": input + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as RecAliasOne + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + + RecursiveClassWithAliasIndirection( + cls: NodeWithAliasIndirection, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): NodeWithAliasIndirection { + try { + const raw = this.runtime.callFunctionSync( + "RecursiveClassWithAliasIndirection", + { + "cls": cls + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as NodeWithAliasIndirection + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + + ReturnAliasWithMergedAttributes( + money: Checked, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): Checked { + try { + const raw = this.runtime.callFunctionSync( + "ReturnAliasWithMergedAttributes", + { + "money": money + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as Checked + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + ReturnFailingAssert( inp: number, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } @@ -1943,6 +2218,56 @@ export class BamlSyncClient { } } + SimpleRecursiveListAlias( + input: RecursiveListAlias, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): RecursiveListAlias { + try { + const raw = this.runtime.callFunctionSync( + "SimpleRecursiveListAlias", + { + "input": input + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as RecursiveListAlias + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + + SimpleRecursiveMapAlias( + input: RecursiveMapAlias, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): RecursiveMapAlias { + try { + const raw = this.runtime.callFunctionSync( + "SimpleRecursiveMapAlias", + { + "input": input + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as RecursiveMapAlias + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + StreamBigNumbers( digits: number, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } diff --git a/integ-tests/typescript/baml_client/type_builder.ts b/integ-tests/typescript/baml_client/type_builder.ts index 22e5b622b..2d40ca5d8 100644 --- a/integ-tests/typescript/baml_client/type_builder.ts +++ b/integ-tests/typescript/baml_client/type_builder.ts @@ -50,7 +50,7 @@ export default class TypeBuilder { constructor() { this.tb = new _TypeBuilder({ classes: new Set([ - "BigNumbers","BinaryNode","Blah","BlockConstraint","BlockConstraintForParam","BookOrder","ClassOptionalOutput","ClassOptionalOutput2","ClassWithImage","CompoundBigNumbers","ContactInfo","CustomTaskResult","DummyOutput","DynInputOutput","DynamicClassOne","DynamicClassTwo","DynamicOutput","Earthling","Education","Email","EmailAddress","Event","FakeImage","FlightConfirmation","FooAny","Forest","GroceryReceipt","InnerClass","InnerClass2","InputClass","InputClassNested","LinkedList","LiteralClassHello","LiteralClassOne","LiteralClassTwo","MalformedConstraints","MalformedConstraints2","Martian","NamedArgsSingleClass","Nested","Nested2","NestedBlockConstraint","NestedBlockConstraintForParam","Node","OptionalListAndMap","OptionalTest_Prop1","OptionalTest_ReturnType","OrderInfo","OriginalA","OriginalB","Person","PhoneNumber","Quantity","RaysData","ReceiptInfo","ReceiptItem","Recipe","Resume","Schema","SearchParams","SomeClassNestedDynamic","StringToClassEntry","TestClassAlias","TestClassNested","TestClassWithEnum","TestOutputClass","Tree","TwoStoriesOneTitle","UnionTest_ReturnType","WithReasoning", + "BigNumbers","BinaryNode","Blah","BlockConstraint","BlockConstraintForParam","BookOrder","ClassOptionalOutput","ClassOptionalOutput2","ClassToRecAlias","ClassWithImage","CompoundBigNumbers","ContactInfo","CustomTaskResult","DummyOutput","DynInputOutput","DynamicClassOne","DynamicClassTwo","DynamicOutput","Earthling","Education","Email","EmailAddress","Event","FakeImage","FlightConfirmation","FooAny","Forest","GroceryReceipt","InnerClass","InnerClass2","InputClass","InputClassNested","LinkedList","LinkedListAliasNode","LiteralClassHello","LiteralClassOne","LiteralClassTwo","MalformedConstraints","MalformedConstraints2","Martian","MergeAttrs","NamedArgsSingleClass","Nested","Nested2","NestedBlockConstraint","NestedBlockConstraintForParam","Node","NodeWithAliasIndirection","OptionalListAndMap","OptionalTest_Prop1","OptionalTest_ReturnType","OrderInfo","OriginalA","OriginalB","Person","PhoneNumber","Quantity","RaysData","ReceiptInfo","ReceiptItem","Recipe","Resume","Schema","SearchParams","SomeClassNestedDynamic","StringToClassEntry","TestClassAlias","TestClassNested","TestClassWithEnum","TestOutputClass","Tree","TwoStoriesOneTitle","UnionTest_ReturnType","WithReasoning", ]), enums: new Set([ "AliasedEnum","Category","Category2","Category3","Color","DataType","DynEnumOne","DynEnumTwo","EnumInClass","EnumOutput","Hobby","MapKey","NamedArgsSingleEnum","NamedArgsSingleEnumList","OptionalTest_CategoryType","OrderStatus","Tag","TestEnum", diff --git a/integ-tests/typescript/baml_client/types.ts b/integ-tests/typescript/baml_client/types.ts index c802f70e5..fae4be644 100644 --- a/integ-tests/typescript/baml_client/types.ts +++ b/integ-tests/typescript/baml_client/types.ts @@ -207,6 +207,11 @@ export interface ClassOptionalOutput2 { } +export interface ClassToRecAlias { + list: LinkedListAliasNode + +} + export interface ClassWithImage { myImage: Image param2: string @@ -364,6 +369,12 @@ export interface LinkedList { } +export interface LinkedListAliasNode { + value: number + next?: LinkedListAliasNode | null + +} + export interface LiteralClassHello { prop: "hello" @@ -402,6 +413,11 @@ export interface Martian { } +export interface MergeAttrs { + amount: Checked + +} + export interface NamedArgsSingleClass { key: string key_two: boolean @@ -438,6 +454,12 @@ export interface Node { } +export interface NodeWithAliasIndirection { + value: number + next?: NodeWithAliasIndirection | null + +} + export interface OptionalListAndMap { p?: string[] | null q?: Record | null @@ -614,3 +636,19 @@ export interface WithReasoning { reasoning: string } + +type RecursiveMapAlias = Record + +type RecursiveListAlias = RecursiveListAlias[] + +type RecAliasOne = RecAliasTwo + +type RecAliasTwo = RecAliasThree + +type RecAliasThree = RecAliasOne[] + +type JsonValue = number | string | boolean | number | JsonObject | JsonArray + +type JsonObject = Record + +type JsonArray = JsonValue[] diff --git a/integ-tests/typescript/test-report.html b/integ-tests/typescript/test-report.html index e6e835369..19102aad1 100644 --- a/integ-tests/typescript/test-report.html +++ b/integ-tests/typescript/test-report.html @@ -257,9 +257,53 @@ font-size: 1rem; padding: 0 0.5rem; } +<<<<<<< HEAD +

Test Report

Started: 2024-12-18 19:35:20
Suites (1)
0 passed
1 failed
0 pending
Tests (83)
79 passed
4 failed
0 pending
Integ tests > should work for all inputs
single bool
passed
0.361s
Integ tests > should work for all inputs
single string list
passed
0.444s
Integ tests > should work for all inputs
return literal union
passed
0.537s
Integ tests > should work for all inputs
single class
passed
0.442s
Integ tests > should work for all inputs
multiple classes
passed
0.685s
Integ tests > should work for all inputs
single enum list
passed
0.35s
Integ tests > should work for all inputs
single float
passed
0.339s
Integ tests > should work for all inputs
single int
passed
0.436s
Integ tests > should work for all inputs
single literal int
passed
0.511s
Integ tests > should work for all inputs
single literal bool
passed
0.387s
Integ tests > should work for all inputs
single literal string
passed
0.384s
Integ tests > should work for all inputs
single class with literal prop
passed
0.559s
Integ tests > should work for all inputs
single class with literal union prop
passed
0.78s
Integ tests > should work for all inputs
single optional string
passed
0.652s
Integ tests > should work for all inputs
single map string to string
passed
0.763s
Integ tests > should work for all inputs
single map string to class
passed
0.635s
Integ tests > should work for all inputs
single map string to map
passed
0.549s
Integ tests > should work for all inputs
enum key in map
passed
0.615s
Integ tests > should work for all inputs
literal string union key in map
passed
0.821s
Integ tests > should work for all inputs
single literal string key in map
passed
0.471s
Integ tests > should work for all inputs
primitive union alias
passed
0.553s
Integ tests > should work for all inputs
map alias
passed
1s
Integ tests > should work for all inputs
alias union
passed
1.357s
Integ tests > should work for all inputs
alias pointing to recursive class
passed
1.143s
Integ tests > should work for all inputs
class pointing to alias that points to recursive class
passed
1.105s
Integ tests > should work for all inputs
recursive class with alias indirection
passed
1.023s
Integ tests > should work for all inputs
merge alias attributes
passed
0.605s
Integ tests > should work for all inputs
return alias with merged attrs
passed
0.521s
Integ tests > should work for all inputs
alias with multiple attrs
passed
0.514s
Integ tests > should work for all inputs
simple recursive map alias
passed
1.231s
Integ tests > should work for all inputs
simple recursive map alias
passed
0.608s
Integ tests > should work for all inputs
recursive alias cycles
passed
0.821s
Integ tests > should work for all inputs
json type alias cycle
passed
5.814s
Integ tests
should work for all outputs
passed
5.424s
Integ tests
works with retries1
passed
1.155s
Integ tests
works with retries2
passed
2.451s
Integ tests
works with fallbacks
passed
1.945s
Integ tests
should work with image from url
passed
1.536s
Integ tests
should work with image from base 64
passed
1.436s
Integ tests
should work with audio base 64
passed
1.739s
Integ tests
should work with audio from url
passed
1.739s
Integ tests
should support streaming in OpenAI
passed
2.323s
Integ tests
should support streaming in Gemini
passed
7.924s
Integ tests
should support AWS
passed
1.583s
Integ tests
should support streaming in AWS
passed
1.614s
Integ tests
should allow overriding the region
passed
0.045s
Integ tests
should support OpenAI shorthand
passed
13.65s
Integ tests
should support OpenAI shorthand streaming
passed
9.065s
Integ tests
should support anthropic shorthand
passed
3.223s
Integ tests
should support anthropic shorthand streaming
passed
3.834s
Integ tests
should support streaming without iterating
passed
5.387s
Integ tests
should support streaming in Claude
passed
0.927s
Integ tests
should support azure
failed
1.087s
Error: expect(received).toContain(expected) // indexOf
+
+Expected substring: "donkey"
+Received string:    "barrel-tossing ape
+king of jungle, climbs with ease
+rescues damsel fair"
+    at Object.toContain (/workspaces/baml/integ-tests/typescript/tests/integ-tests.test.ts:440:31)
Integ tests
should support azure streaming
passed
1.005s
Integ tests
should fail if azure is not configured
failed
0.051s
Error: expect(received).rejects.toThrow(expected)
+
+Expected substring: "BamlClientError"
+Received message:   "BamlError: BamlInvalidArgumentError: BAML function TestAzureFailure does not exist in baml_src/ (did you typo it?): function `TestAzureFailure` not found. Did you mean one of: TestAzure, ExpectFailure, ExtractResume, NestedAlias, TestAnthropic?"
+
+      2508 |         __baml_options__?.clientRegistry,
+      2509 |       )
+    > 2510 |       return raw.parsed() as string
+           |                  ^
+      2511 |     } catch (error: any) {
+      2512 |       const bamlError = createBamlValidationError(error);
+      2513 |       if (bamlError instanceof BamlValidationError) {
+
+      at BamlAsyncClient.parsed [as TestAzureFailure] (baml_client/async_client.ts:2510:18)
+      at tests/integ-tests.test.ts:455:7
+      at Object.<anonymous> (tests/integ-tests.test.ts:454:5)
+    at Object.toThrow (/workspaces/baml/integ-tests/typescript/node_modules/.pnpm/expect@29.7.0/node_modules/expect/build/index.js:218:22)
+    at Object.toThrow (/workspaces/baml/integ-tests/typescript/tests/integ-tests.test.ts:456:16)
+    at Promise.then.completed (/workspaces/baml/integ-tests/typescript/node_modules/.pnpm/jest-circus@29.7.0/node_modules/jest-circus/build/utils.js:298:28)
+    at new Promise (<anonymous>)
+    at callAsyncCircusFn (/workspaces/baml/integ-tests/typescript/node_modules/.pnpm/jest-circus@29.7.0/node_modules/jest-circus/build/utils.js:231:10)
+    at _callCircusTest (/workspaces/baml/integ-tests/typescript/node_modules/.pnpm/jest-circus@29.7.0/node_modules/jest-circus/build/run.js:316:40)
+    at _runTest (/workspaces/baml/integ-tests/typescript/node_modules/.pnpm/jest-circus@29.7.0/node_modules/jest-circus/build/run.js:252:3)
+    at _runTestsForDescribeBlock (/workspaces/baml/integ-tests/typescript/node_modules/.pnpm/jest-circus@29.7.0/node_modules/jest-circus/build/run.js:126:9)
+    at _runTestsForDescribeBlock (/workspaces/baml/integ-tests/typescript/node_modules/.pnpm/jest-circus@29.7.0/node_modules/jest-circus/build/run.js:121:9)
+    at run (/workspaces/baml/integ-tests/typescript/node_modules/.pnpm/jest-circus@29.7.0/node_modules/jest-circus/build/run.js:71:3)
+    at runAndTransformResultsToJestFormat (/workspaces/baml/integ-tests/typescript/node_modules/.pnpm/jest-circus@29.7.0/node_modules/jest-circus/build/legacy-code-todo-rewrite/jestAdapterInit.js:122:21)
+    at jestAdapter (/workspaces/baml/integ-tests/typescript/node_modules/.pnpm/jest-circus@29.7.0/node_modules/jest-circus/build/legacy-code-todo-rewrite/jestAdapter.js:79:19)
+    at runTestInternal (/workspaces/baml/integ-tests/typescript/node_modules/.pnpm/jest-runner@29.7.0/node_modules/jest-runner/build/runTest.js:367:16)
+    at runTest (/workspaces/baml/integ-tests/typescript/node_modules/.pnpm/jest-runner@29.7.0/node_modules/jest-runner/build/runTest.js:444:34)
Integ tests
should support vertex
failed
0.001s
Error: BamlError: Failed to read service account file: 
+
+Caused by:
+    No such file or directory (os error 2)
Integ tests
supports tracing sync
passed
0.014s
Integ tests
supports tracing async
passed
6.818s
Integ tests
should work with dynamic types single
passed
1.551s
Integ tests
should work with dynamic types enum
passed
1.222s
Integ tests
should work with dynamic literals
passed
1.523s
Integ tests
should work with dynamic types class
passed
1.831s
Integ tests
should work with dynamic inputs class
passed
0.532s
Integ tests
should work with dynamic inputs list
passed
0.492s
Integ tests
should work with dynamic output map
passed
0.841s
Integ tests
should work with dynamic output union
passed
1.729s
Integ tests
should work with nested classes
failed
0.11s
Error: BamlError: BamlClientError: Something went wrong with the LLM client: reqwest::Error { kind: Request, url: Url { scheme: "http", cannot_be_a_base: false, username: "", password: None, host: Some(Domain("localhost")), port: Some(11434), path: "/v1/chat/completions", query: None, fragment: None }, source: hyper_util::client::legacy::Error(Connect, ConnectError("tcp connect error", Os { code: 111, kind: ConnectionRefused, message: "Connection refused" })) }
+    at BamlStream.parsed [as getFinalResponse] (/workspaces/baml/engine/language_client_typescript/stream.js:58:39)
+    at Object.<anonymous> (/workspaces/baml/integ-tests/typescript/tests/integ-tests.test.ts:723:19)
Integ tests
should work with dynamic client
passed
0.38s
Integ tests
should work with 'onLogEvent'
passed
1.649s
Integ tests
should work with a sync client
passed
0.417s
Integ tests
should raise an error when appropriate
passed
1.066s
Integ tests
should raise a BAMLValidationError
passed
0.576s
Integ tests
should reset environment variables correctly
passed
1.738s
Integ tests
should use aliases when serializing input objects - classes
passed
0.819s
Integ tests
should use aliases when serializing, but still have original keys in jinja
passed
1.127s
Integ tests
should use aliases when serializing input objects - enums
passed
0.508s
Integ tests
should use aliases when serializing input objects - lists
passed
0.385s
Integ tests
constraints: should handle checks in return types
passed
0.712s
Integ tests
constraints: should handle checks in returned unions
passed
0.771s
Integ tests
constraints: should handle block-level checks
passed
0.591s
Integ tests
constraints: should handle nested-block-level checks
passed
0.615s
Integ tests
simple recursive type
passed
2.564s
Integ tests
mutually recursive type
passed
2.154s
+=======

Test Report

Started: 2024-12-19 17:53:14
Suites (1)
0 passed
1 failed
0 pending
Tests (71)
69 passed
2 failed
0 pending
Integ tests > should work for all inputs
single bool
passed
0.4s
Integ tests > should work for all inputs
single string list
passed
0.513s
Integ tests > should work for all inputs
return literal union
passed
0.41s
Integ tests > should work for all inputs
optional list and map
passed
2.148s
Integ tests > should work for all inputs
single class
passed
0.614s
Integ tests > should work for all inputs
multiple classes
passed
0.513s
Integ tests > should work for all inputs
single enum list
passed
0.92s
Integ tests > should work for all inputs
single float
passed
0.403s
Integ tests > should work for all inputs
single int
passed
0.516s
Integ tests > should work for all inputs
single literal int
passed
0.41s
Integ tests > should work for all inputs
single literal bool
passed
0.369s
Integ tests > should work for all inputs
single literal string
passed
0.433s
Integ tests > should work for all inputs
single class with literal prop
passed
0.616s
Integ tests > should work for all inputs
single class with literal union prop
passed
0.877s
Integ tests > should work for all inputs
single optional string
passed
0.365s
Integ tests > should work for all inputs
single map string to string
passed
0.594s
Integ tests > should work for all inputs
single map string to class
passed
0.635s
Integ tests > should work for all inputs
single map string to map
passed
0.545s
Integ tests > should work for all inputs
enum key in map
passed
1.022s
Integ tests > should work for all inputs
literal string union key in map
passed
0.683s
Integ tests > should work for all inputs
single literal string key in map
passed
0.921s
Integ tests
should work for all outputs
passed
5.118s
Integ tests
works with retries1
passed
1.151s
Integ tests
works with retries2
passed
2.431s
Integ tests
works with fallbacks
passed
1.946s
Integ tests
should work with image from url
passed
4.301s
Integ tests
should work with image from base 64
passed
1.843s
Integ tests
should work with audio base 64
passed
2.046s
Integ tests
should work with audio from url
passed
1.845s
Integ tests
should support streaming in OpenAI
passed
2.228s
Integ tests
should support streaming in Gemini
passed
10.231s
Integ tests
should support AWS
passed
3.004s
Integ tests
should support streaming in AWS
passed
1.517s
Integ tests
should allow overriding the region
passed
0.077s
Integ tests
should support OpenAI shorthand
passed
16.429s
Integ tests
should support OpenAI shorthand streaming
passed
12.635s
Integ tests
should support anthropic shorthand
passed
6.513s
Integ tests
should support anthropic shorthand streaming
passed
2.57s
Integ tests
should support streaming without iterating
passed
5.048s
Integ tests
should support streaming in Claude
passed
1.126s
Integ tests
should support azure
passed
0.979s
Integ tests
should support azure streaming
passed
0.905s
Integ tests
should fail if azure is not configured
passed
0.047s
Integ tests
should support vertex
failed
0.001s
Error: BamlError: Failed to read service account file: 
 
 Caused by:
     No such file or directory (os error 2)
Integ tests
supports tracing sync
passed
0.013s
Integ tests
supports tracing async
passed
3.442s
Integ tests
should work with dynamic types single
passed
1.321s
Integ tests
should work with dynamic types enum
passed
1.027s
Integ tests
should work with dynamic literals
passed
1.079s
Integ tests
should work with dynamic types class
passed
1.461s
Integ tests
should work with dynamic inputs class
passed
0.532s
Integ tests
should work with dynamic inputs list
passed
0.605s
Integ tests
should work with dynamic output map
passed
0.849s
Integ tests
should work with dynamic output union
passed
1.919s
Integ tests
should work with nested classes
failed
0.107s
Error: BamlError: BamlClientError: Something went wrong with the LLM client: reqwest::Error { kind: Request, url: Url { scheme: "http", cannot_be_a_base: false, username: "", password: None, host: Some(Domain("localhost")), port: Some(11434), path: "/v1/chat/completions", query: None, fragment: None }, source: hyper_util::client::legacy::Error(Connect, ConnectError("tcp connect error", Os { code: 111, kind: ConnectionRefused, message: "Connection refused" })) }
     at BamlStream.parsed [as getFinalResponse] (/workspaces/baml/engine/language_client_typescript/stream.js:58:39)
-    at Object.<anonymous> (/workspaces/baml/integ-tests/typescript/tests/integ-tests.test.ts:645:19)
Integ tests
should work with dynamic client
passed
0.502s
Integ tests
should work with 'onLogEvent'
passed
1.996s
Integ tests
should work with a sync client
passed
0.666s
Integ tests
should raise an error when appropriate
passed
1.072s
Integ tests
should raise a BAMLValidationError
passed
0.462s
Integ tests
should reset environment variables correctly
passed
1.536s
Integ tests
should use aliases when serializing input objects - classes
passed
0.82s
Integ tests
should use aliases when serializing, but still have original keys in jinja
passed
0.781s
Integ tests
should use aliases when serializing input objects - enums
passed
0.649s
Integ tests
should use aliases when serializing input objects - lists
passed
0.411s
Integ tests
constraints: should handle checks in return types
passed
0.717s
Integ tests
constraints: should handle checks in returned unions
passed
0.824s
Integ tests
constraints: should handle block-level checks
passed
0.71s
Integ tests
constraints: should handle nested-block-level checks
passed
0.707s
Integ tests
simple recursive type
passed
2.164s
Integ tests
mutually recursive type
passed
2.253s
\ No newline at end of file + at Object.<anonymous> (/workspaces/baml/integ-tests/typescript/tests/integ-tests.test.ts:645:19)
Integ tests
should work with dynamic client
passed
0.502s
Integ tests
should work with 'onLogEvent'
passed
1.996s
Integ tests
should work with a sync client
passed
0.666s
Integ tests
should raise an error when appropriate
passed
1.072s
Integ tests
should raise a BAMLValidationError
passed
0.462s
Integ tests
should reset environment variables correctly
passed
1.536s
Integ tests
should use aliases when serializing input objects - classes
passed
0.82s
Integ tests
should use aliases when serializing, but still have original keys in jinja
passed
0.781s
Integ tests
should use aliases when serializing input objects - enums
passed
0.649s
Integ tests
should use aliases when serializing input objects - lists
passed
0.411s
Integ tests
constraints: should handle checks in return types
passed
0.717s
Integ tests
constraints: should handle checks in returned unions
passed
0.824s
Integ tests
constraints: should handle block-level checks
passed
0.71s
Integ tests
constraints: should handle nested-block-level checks
passed
0.707s
Integ tests
simple recursive type
passed
2.164s
Integ tests
mutually recursive type
passed
2.253s
+>>>>>>> canary diff --git a/integ-tests/typescript/tests/integ-tests.test.ts b/integ-tests/typescript/tests/integ-tests.test.ts index 003ab7683..4f5783322 100644 --- a/integ-tests/typescript/tests/integ-tests.test.ts +++ b/integ-tests/typescript/tests/integ-tests.test.ts @@ -159,6 +159,92 @@ describe('Integ tests', () => { const res = await b.InOutSingleLiteralStringMapKey({ key: '1' }) expect(res).toHaveProperty('key', '1') }) + + it('primitive union alias', async () => { + const res = await b.PrimitiveAlias('test') + expect(res).toEqual('test') + }) + + it('map alias', async () => { + const res = await b.MapAlias({ A: ['B', 'C'], B: [], C: [] }) + expect(res).toEqual({ A: ['B', 'C'], B: [], C: [] }) + }) + + it('alias union', async () => { + let res = await b.NestedAlias('test') + expect(res).toEqual('test') + + res = await b.NestedAlias({ A: ['B', 'C'], B: [], C: [] }) + expect(res).toEqual({ A: ['B', 'C'], B: [], C: [] }) + }) + + it('alias pointing to recursive class', async () => { + const res = await b.AliasThatPointsToRecursiveType({ value: 1, next: null }) + expect(res).toEqual({ value: 1, next: null }) + }) + + it('class pointing to alias that points to recursive class', async () => { + const res = await b.ClassThatPointsToRecursiveClassThroughAlias({ list: { value: 1, next: null } }) + expect(res).toEqual({ list: { value: 1, next: null } }) + }) + + it('recursive class with alias indirection', async () => { + const res = await b.RecursiveClassWithAliasIndirection({ value: 1, next: { value: 2, next: null } }) + expect(res).toEqual({ value: 1, next: { value: 2, next: null } }) + }) + + it('merge alias attributes', async () => { + const res = await b.MergeAliasAttributes(123) + expect(res.amount.value).toEqual(123) + expect(res.amount.checks['gt_ten'].status).toEqual('succeeded') + }) + + it('return alias with merged attrs', async () => { + const res = await b.ReturnAliasWithMergedAttributes(123) + expect(res.value).toEqual(123) + expect(res.checks['gt_ten'].status).toEqual('succeeded') + }) + + it('alias with multiple attrs', async () => { + const res = await b.AliasWithMultipleAttrs(123) + expect(res.value).toEqual(123) + expect(res.checks['gt_ten'].status).toEqual('succeeded') + }) + + it('simple recursive map alias', async () => { + const res = await b.SimpleRecursiveMapAlias({ one: { two: { three: {} } } }) + expect(res).toEqual({ one: { two: { three: {} } } }) + }) + + it('simple recursive map alias', async () => { + const res = await b.SimpleRecursiveListAlias([[], [], [[]]]) + expect(res).toEqual([[], [], [[]]]) + }) + + it('recursive alias cycles', async () => { + const res = await b.RecursiveAliasCycle([[], [], [[]]]) + expect(res).toEqual([[], [], [[]]]) + }) + + it('json type alias cycle', async () => { + const data = { + number: 1, + string: 'test', + bool: true, + list: [1, 2, 3], + object: { number: 1, string: 'test', bool: true, list: [1, 2, 3] }, + json: { + number: 1, + string: 'test', + bool: true, + list: [1, 2, 3], + object: { number: 1, string: 'test', bool: true, list: [1, 2, 3] }, + }, + } + const res = await b.JsonTypeAliasCycle(data) + expect(res).toEqual(data) + expect(res.json.object.list).toEqual([1, 2, 3]) + }) }) it('should work for all outputs', async () => { diff --git a/typescript/vscode-ext/packages/syntaxes/baml.tmLanguage.json b/typescript/vscode-ext/packages/syntaxes/baml.tmLanguage.json index 6bcb9f63a..5034ded68 100644 --- a/typescript/vscode-ext/packages/syntaxes/baml.tmLanguage.json +++ b/typescript/vscode-ext/packages/syntaxes/baml.tmLanguage.json @@ -720,12 +720,24 @@ "name": "constant.numeric" }, "type_alias": { - "begin": "(type)\\s+(\\w+)", + "begin": "(type)\\s+(\\w+)\\s*(=)", "beginCaptures": { "1": { "name": "storage.type.declaration" }, - "2": { "name": "entity.name.type" } + "2": { "name": "entity.name.type" }, + "3": { "name": "keyword.operator.assignment" } }, - "patterns": [{ "include": "#comment" }] + "end": "(?=$|\\n)", + "patterns": [ + { "include": "#comment" }, + { + "begin": "(?<=\\=)\\s*", + "end": "(?=//|$|\\n)", + "patterns": [ + { "include": "#type_definition" }, + { "include": "#block_attribute" } + ] + } + ] }, "invalid_assignment": { "name": "invalid.illegal",