From e38e5e7f6c51323e89166ce0f78959a77e2ad064 Mon Sep 17 00:00:00 2001 From: Antonio Sarosi Date: Mon, 2 Dec 2024 23:44:42 +0100 Subject: [PATCH 01/19] Allow structural recursion --- .../validation_pipeline/validations/cycle.rs | 66 +++++++++++++++---- .../class/recursive_type_aliases.baml | 6 -- engine/baml-lib/parser-database/src/lib.rs | 57 ++++++++++------ .../baml-lib/parser-database/src/types/mod.rs | 7 ++ .../parser-database/src/walkers/mod.rs | 11 ++++ engine/baml-lib/schema-ast/src/ast.rs | 8 +++ 6 files changed, 116 insertions(+), 39 deletions(-) diff --git a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs index 89e0abb40..f7cf76ebb 100644 --- a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs +++ b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs @@ -14,21 +14,40 @@ use crate::validate::validation_pipeline::context::Context; /// Validates if the dependency graph contains one or more infinite cycles. pub(super) fn validate(ctx: &mut Context<'_>) { - // Solve cycles first. We need that information in case a class points to - // an unresolveble type alias. - let alias_cycles = report_infinite_cycles( - &ctx.db.type_alias_dependencies(), + // We'll check type alias cycles first. Just like Typescript, cycles are + // allowed only for maps and lists. We'll call such cycles "structural + // recursion". Anything else like nulls or unions won't terminate a cycle. + let structural_type_aliases = HashMap::from_iter(ctx.db.walk_type_aliases().map(|alias| { + let mut dependencies = HashSet::new(); + insert_required_alias_deps(alias.target(), ctx, &mut dependencies); + + (alias.id, dependencies) + })); + + // Based on the graph we've built with does not include the edges created + // by maps and lists, check the cycles and report them. + report_infinite_cycles( + &structural_type_aliases, ctx, "These aliases form a dependency cycle", ); - // First, build a graph of all the "required" dependencies represented as an + // In order to avoid infinite recursion when resolving types for class + // dependencies below, we'll compute the cycles of aliases including maps + // and lists so that the recursion can be stopped before entering a cycle. + let complete_alias_cycles = Tarjan::components(ctx.db.type_alias_dependencies()) + .iter() + .flatten() + .copied() + .collect(); + + // Now build a graph of all the "required" dependencies represented as an // adjacency list. We're only going to consider type dependencies that can // actually cause infinite recursion. Unions and optionals can stop the // recursion at any point, so they don't have to be part of the "dependency" // graph because technically an optional field doesn't "depend" on anything, // it can just be null. - let dependency_graph = HashMap::from_iter(ctx.db.walk_classes().map(|class| { + let class_dependency_graph = HashMap::from_iter(ctx.db.walk_classes().map(|class| { let expr_block = &ctx.db.ast()[class.id]; // TODO: There's already a hash set that returns "dependencies" in @@ -44,12 +63,12 @@ pub(super) fn validate(ctx: &mut Context<'_>) { for field in &expr_block.fields { if let Some(field_type) = &field.expr { - insert_required_deps( + insert_required_class_deps( class.id, field_type, ctx, &mut dependencies, - &alias_cycles.iter().flatten().copied().collect(), + &complete_alias_cycles, ); } } @@ -58,7 +77,7 @@ pub(super) fn validate(ctx: &mut Context<'_>) { })); report_infinite_cycles( - &dependency_graph, + &class_dependency_graph, ctx, "These classes form a dependency cycle", ); @@ -106,7 +125,7 @@ where /// it reaches stack overflows with large inputs. /// /// TODO: Use a struct to keep all this state. Too many parameters already. -fn insert_required_deps( +fn insert_required_class_deps( id: TypeExpId, field: &FieldType, ctx: &Context<'_>, @@ -142,7 +161,7 @@ fn insert_required_deps( // We also have to stop recursion if we know the alias is // part of a cycle. if !alias_cycles.contains(&alias.id) { - insert_required_deps(id, alias.target(), ctx, deps, alias_cycles) + insert_required_class_deps(id, alias.target(), ctx, deps, alias_cycles) } } _ => {} @@ -159,7 +178,7 @@ fn insert_required_deps( let mut nested_deps = HashSet::new(); for f in field_types { - insert_required_deps(id, f, ctx, &mut nested_deps, alias_cycles); + insert_required_class_deps(id, f, ctx, &mut nested_deps, alias_cycles); // No nested deps found on this component, this makes the // union finite, so no need to go deeper. @@ -189,3 +208,26 @@ fn insert_required_deps( _ => {} } } + +/// Implemented a la TS, maps and lists are not included as edges. +fn insert_required_alias_deps( + field_type: &FieldType, + ctx: &Context<'_>, + required: &mut HashSet, +) { + match field_type { + FieldType::Symbol(_, ident, _) => { + if let Some(TypeWalker::TypeAlias(alias)) = ctx.db.find_type_by_str(ident.name()) { + required.insert(alias.id); + } + } + + FieldType::Union(_, field_types, ..) | FieldType::Tuple(_, field_types, ..) => { + for f in field_types { + insert_required_alias_deps(f, ctx, required); + } + } + + _ => {} + } +} diff --git a/engine/baml-lib/baml/tests/validation_files/class/recursive_type_aliases.baml b/engine/baml-lib/baml/tests/validation_files/class/recursive_type_aliases.baml index a200c3293..f7ff88ef0 100644 --- a/engine/baml-lib/baml/tests/validation_files/class/recursive_type_aliases.baml +++ b/engine/baml-lib/baml/tests/validation_files/class/recursive_type_aliases.baml @@ -73,12 +73,6 @@ type Map = map // 50 | // 51 | type EnterCycle = NoStop // | -// error: Error validating: These aliases form a dependency cycle: Map -// --> class/recursive_type_aliases.baml:56 -// | -// 55 | // RecursiveMap -// 56 | type Map = map -// | // error: Error validating: These classes form a dependency cycle: Recursive // --> class/recursive_type_aliases.baml:22 // | diff --git a/engine/baml-lib/parser-database/src/lib.rs b/engine/baml-lib/parser-database/src/lib.rs index 4666df2ca..cb5ef44cb 100644 --- a/engine/baml-lib/parser-database/src/lib.rs +++ b/engine/baml-lib/parser-database/src/lib.rs @@ -159,27 +159,42 @@ impl ParserDatabase { // instead of strings (class names). That requires less conversions when // working with the graph. Once the work is done, IDs can be converted // to names where needed. - let finite_cycles = Tarjan::components(&HashMap::from_iter( - self.types.class_dependencies.iter().map(|(id, deps)| { - let deps = - HashSet::from_iter(deps.iter().filter_map( - |dep| match self.find_type_by_str(dep) { - Some(TypeWalker::Class(cls)) => Some(cls.id), - Some(TypeWalker::Enum(_)) => None, - Some(TypeWalker::TypeAlias(_)) => None, - None => panic!("Unknown class `{dep}`"), - }, - )); - (*id, deps) - }), - )); + let mut resolved_dependency_graph = HashMap::new(); + + for (id, deps) in self.types.class_dependencies.iter() { + let mut resolved_deps = HashSet::new(); + + for dep in deps { + match self.find_type_by_str(dep) { + Some(TypeWalker::Class(cls)) => { + resolved_deps.insert(cls.id); + } + Some(TypeWalker::Enum(_)) => {} + // Gotta resolve type aliases. + Some(TypeWalker::TypeAlias(alias)) => { + resolved_deps.extend(alias.resolved().flat_idns().iter().map(|ident| { + match self.find_type_by_str(ident.name()) { + Some(TypeWalker::Class(cls)) => cls.id, + Some(TypeWalker::Enum(_)) => { + panic!("Enums are not allowed in type aliases") + } + Some(TypeWalker::TypeAlias(alias)) => { + panic!("Alias should be resolved at this point") + } + None => panic!("Unknown class `{dep}`"), + } + })) + } + None => panic!("Unknown class `{dep}`"), + } + } + + resolved_dependency_graph.insert(*id, resolved_deps); + } - // Inject finite cycles into parser DB. This will then be passed into - // the IR and then into the Jinja output format. - self.types.finite_recursive_cycles = finite_cycles - .into_iter() - .map(|cycle| cycle.into_iter().collect()) - .collect(); + // Find the cycles and inject them into parser DB. This will then be + // passed into the IR and then into the Jinja output format. + self.types.finite_recursive_cycles = Tarjan::components(&resolved_dependency_graph); // Fully resolve function dependencies. let extends = self @@ -308,7 +323,7 @@ mod test { } fn assert_finite_cycles(baml: &'static str, expected: &[&[&str]]) -> Result<(), Diagnostics> { - let mut db = parse(baml)?; + let db = parse(baml)?; assert_eq!( db.finite_recursive_cycles() diff --git a/engine/baml-lib/parser-database/src/types/mod.rs b/engine/baml-lib/parser-database/src/types/mod.rs index 4b9dd23f2..abc89e588 100644 --- a/engine/baml-lib/parser-database/src/types/mod.rs +++ b/engine/baml-lib/parser-database/src/types/mod.rs @@ -276,6 +276,13 @@ pub(super) struct Types { /// Merge-Find Set or something like that. pub(super) finite_recursive_cycles: Vec>, + /// Contains recursive type aliases. + /// + /// Recursive type aliases are a little bit trickier than recursive classes + /// because the termination condition is tied to lists and maps only. Nulls + /// and unions won't allow type alias cycles to be resolved. + pub(super) structural_recursive_type_aliases: Vec>, + pub(super) function: HashMap, pub(super) client_properties: HashMap, diff --git a/engine/baml-lib/parser-database/src/walkers/mod.rs b/engine/baml-lib/parser-database/src/walkers/mod.rs index ae705b111..59a31f821 100644 --- a/engine/baml-lib/parser-database/src/walkers/mod.rs +++ b/engine/baml-lib/parser-database/src/walkers/mod.rs @@ -209,6 +209,17 @@ impl<'db> crate::ParserDatabase { }) } + /// Walk all the type aliases in the AST. + pub fn walk_type_aliases(&self) -> impl Iterator> { + self.ast() + .iter_tops() + .filter_map(|(top_id, _)| top_id.as_type_alias_id()) + .map(move |top_id| Walker { + db: self, + id: top_id, + }) + } + /// Walk all template strings in the schema. pub fn walk_templates(&self) -> impl Iterator> { self.ast() diff --git a/engine/baml-lib/schema-ast/src/ast.rs b/engine/baml-lib/schema-ast/src/ast.rs index 24b0b8a4f..d34b3ecad 100644 --- a/engine/baml-lib/schema-ast/src/ast.rs +++ b/engine/baml-lib/schema-ast/src/ast.rs @@ -191,6 +191,14 @@ impl TopId { } } + /// Try to interpret the top as a type alias. + pub fn as_type_alias_id(self) -> Option { + match self { + TopId::TypeAlias(id) => Some(id), + _ => None, + } + } + /// Try to interpret the top as a function. pub fn as_function_id(self) -> Option { match self { From 794b3f42f2a1473aaddf13ee9ad37cdd064637b4 Mon Sep 17 00:00:00 2001 From: Antonio Sarosi Date: Tue, 3 Dec 2024 00:10:24 +0100 Subject: [PATCH 02/19] Pass structural cycles to IR --- engine/baml-lib/baml-core/src/ir/repr.rs | 13 +++++++++++++ engine/baml-lib/parser-database/src/lib.rs | 5 +++++ engine/baml-lib/parser-database/src/types/mod.rs | 2 +- engine/baml-lib/parser-database/src/walkers/mod.rs | 12 +++++++++++- 4 files changed, 30 insertions(+), 2 deletions(-) diff --git a/engine/baml-lib/baml-core/src/ir/repr.rs b/engine/baml-lib/baml-core/src/ir/repr.rs index 64431059c..30cc728e4 100644 --- a/engine/baml-lib/baml-core/src/ir/repr.rs +++ b/engine/baml-lib/baml-core/src/ir/repr.rs @@ -35,6 +35,9 @@ pub struct IntermediateRepr { /// Strongly connected components of the dependency graph (finite cycles). finite_recursive_cycles: Vec>, + /// Type alias cycles introduced by lists and maps. + structural_recursive_alias_cycles: Vec>, + configuration: Configuration, } @@ -53,6 +56,7 @@ impl IntermediateRepr { enums: vec![], classes: vec![], finite_recursive_cycles: vec![], + structural_recursive_alias_cycles: vec![], functions: vec![], clients: vec![], retry_policies: vec![], @@ -174,6 +178,15 @@ impl IntermediateRepr { .collect() }) .collect(), + structural_recursive_alias_cycles: db + .structural_recursive_alias_cycles() + .iter() + .map(|ids| { + ids.iter() + .map(|id| db.ast()[*id].name().to_string()) + .collect() + }) + .collect(), functions: db .walk_functions() .map(|e| e.node(db)) diff --git a/engine/baml-lib/parser-database/src/lib.rs b/engine/baml-lib/parser-database/src/lib.rs index cb5ef44cb..570f99956 100644 --- a/engine/baml-lib/parser-database/src/lib.rs +++ b/engine/baml-lib/parser-database/src/lib.rs @@ -138,6 +138,11 @@ impl ParserDatabase { self.types.resolved_type_aliases.insert(*alias_id, resolved); } + // Cycles left here after cycle validation are allowed. Basically lists + // and maps can introduce cycles. + self.types.structural_recursive_alias_cycles = + Tarjan::components(&self.types.type_alias_dependencies); + // NOTE: Class dependency cycles are already checked at // baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs // diff --git a/engine/baml-lib/parser-database/src/types/mod.rs b/engine/baml-lib/parser-database/src/types/mod.rs index abc89e588..683082581 100644 --- a/engine/baml-lib/parser-database/src/types/mod.rs +++ b/engine/baml-lib/parser-database/src/types/mod.rs @@ -281,7 +281,7 @@ pub(super) struct Types { /// Recursive type aliases are a little bit trickier than recursive classes /// because the termination condition is tied to lists and maps only. Nulls /// and unions won't allow type alias cycles to be resolved. - pub(super) structural_recursive_type_aliases: Vec>, + pub(super) structural_recursive_alias_cycles: Vec>, pub(super) function: HashMap, diff --git a/engine/baml-lib/parser-database/src/walkers/mod.rs b/engine/baml-lib/parser-database/src/walkers/mod.rs index 59a31f821..a6154a1cf 100644 --- a/engine/baml-lib/parser-database/src/walkers/mod.rs +++ b/engine/baml-lib/parser-database/src/walkers/mod.rs @@ -22,7 +22,9 @@ pub use configuration::*; use either::Either; pub use field::*; pub use function::FunctionWalker; -use internal_baml_schema_ast::ast::{FieldType, Identifier, TopId, TypeExpId, WithName}; +use internal_baml_schema_ast::ast::{ + FieldType, Identifier, TopId, TypeAliasId, TypeExpId, WithName, +}; pub use r#class::*; pub use r#enum::*; pub use template_string::TemplateStringWalker; @@ -142,6 +144,14 @@ impl<'db> crate::ParserDatabase { &self.types.finite_recursive_cycles } + /// Set of all aliases that are part of a structural cycle. + /// + /// A structural cycle is created through a map or list, which introduce one + /// level of indirection. + pub fn structural_recursive_alias_cycles(&self) -> &[Vec] { + &self.types.structural_recursive_alias_cycles + } + /// Returns the resolved aliases map. pub fn resolved_type_alias_by_name(&self, alias: &str) -> Option<&FieldType> { match self.find_type_by_str(alias) { From 98697072c033bf48c8490fc1a10ade17039059e6 Mon Sep 17 00:00:00 2001 From: Antonio Sarosi Date: Tue, 3 Dec 2024 00:50:14 +0100 Subject: [PATCH 03/19] Test structural recursion finder --- engine/baml-lib/parser-database/src/lib.rs | 57 +++++++++++++++++++ engine/baml-lib/parser-database/src/tarjan.rs | 9 ++- 2 files changed, 65 insertions(+), 1 deletion(-) diff --git a/engine/baml-lib/parser-database/src/lib.rs b/engine/baml-lib/parser-database/src/lib.rs index 570f99956..f938237ac 100644 --- a/engine/baml-lib/parser-database/src/lib.rs +++ b/engine/baml-lib/parser-database/src/lib.rs @@ -344,6 +344,26 @@ mod test { Ok(()) } + fn assert_structural_alias_cycles( + baml: &'static str, + expected: &[&[&str]], + ) -> Result<(), Diagnostics> { + let db = parse(baml)?; + + assert_eq!( + db.structural_recursive_alias_cycles() + .iter() + .map(|ids| Vec::from_iter(ids.iter().map(|id| db.ast()[*id].name().to_string()))) + .collect::>(), + expected + .iter() + .map(|cycle| Vec::from_iter(cycle.iter().map(ToString::to_string))) + .collect::>() + ); + + Ok(()) + } + #[test] fn find_simple_recursive_class() -> Result<(), Diagnostics> { assert_finite_cycles( @@ -602,4 +622,41 @@ mod test { Ok(()) } + + #[test] + fn find_basic_map_structural_cycle() -> Result<(), Diagnostics> { + assert_structural_alias_cycles( + "type RecursiveMap = map", + &[&["RecursiveMap"]], + ) + } + + #[test] + fn find_basic_list_structural_cycle() -> Result<(), Diagnostics> { + assert_structural_alias_cycles("type A = A[]", &[&["A"]]) + } + + #[test] + fn find_long_list_structural_cycle() -> Result<(), Diagnostics> { + assert_structural_alias_cycles( + r#" + type A = B + type B = C + type C = A[] + "#, + &[&["A", "B", "C"]], + ) + } + + #[test] + fn find_intricate_structural_cycle() -> Result<(), Diagnostics> { + assert_structural_alias_cycles( + r#" + type JsonValue = string | int | float | bool | null | JsonArray | JsonObject + type JsonArray = JsonValue[] + type JsonObject = map + "#, + &[&["JsonValue", "JsonArray", "JsonObject"]], + ) + } } diff --git a/engine/baml-lib/parser-database/src/tarjan.rs b/engine/baml-lib/parser-database/src/tarjan.rs index 55d5af18f..b7eccb39d 100644 --- a/engine/baml-lib/parser-database/src/tarjan.rs +++ b/engine/baml-lib/parser-database/src/tarjan.rs @@ -6,6 +6,7 @@ use std::{ cmp, collections::{HashMap, HashSet}, + fmt::Debug, hash::Hash, }; @@ -126,8 +127,14 @@ impl<'g, V: Eq + Ord + Hash + Copy> Tarjan<'g, V> { self.index += 1; self.stack.push(node_id); + // TODO: @antoniosarosi: HashSet is random, won't always iterate in the + // same order. Fix this with IndexSet or something, we really don't want + // to sort this every single time. + let mut successors = Vec::from_iter(&self.graph[&node_id]); + successors.sort(); + // Visit neighbors to find strongly connected components. - for successor_id in &self.graph[&node_id] { + for successor_id in successors { // Grab owned state to circumvent borrow checker. let mut successor = *&self.state[successor_id]; if successor.index == Self::UNVISITED { From 72ea8cbc5b8ea3f373a42b776eb977c350c2d916 Mon Sep 17 00:00:00 2001 From: Antonio Sarosi Date: Mon, 16 Dec 2024 16:23:07 +0100 Subject: [PATCH 04/19] Implement codegen for Python type aliases --- .../src/ir/ir_helpers/to_baml_arg.rs | 20 +++++- .../baml-lib/baml-core/src/ir/json_schema.rs | 4 +- engine/baml-lib/baml-core/src/ir/repr.rs | 69 ++++++++++++------- .../baml-lib/baml-types/src/field_type/mod.rs | 22 +++--- .../jinja-runtime/src/output_format/types.rs | 60 +++++++++++++--- .../src/deserializer/coercer/field_type.rs | 5 +- engine/baml-lib/jsonish/src/tests/mod.rs | 3 +- engine/baml-lib/parser-database/src/lib.rs | 59 ++++++++++++---- engine/baml-lib/parser-database/src/tarjan.rs | 3 +- .../baml-lib/parser-database/src/types/mod.rs | 18 ++++- .../prompt_renderer/render_output_format.rs | 24 +++++-- engine/language_client_codegen/src/openapi.rs | 14 +++- .../src/python/generate_types.rs | 37 +++++++++- .../language_client_codegen/src/python/mod.rs | 4 +- .../src/python/templates/types.py.j2 | 7 +- .../src/ruby/field_type.rs | 3 +- .../src/ruby/generate_types.rs | 3 +- .../src/typescript/mod.rs | 2 +- .../functions/output/type-aliases.baml | 13 ++++ .../python/baml_client/async_client.py | 53 ++++++++++++++ integ-tests/python/baml_client/inlinedbaml.py | 2 +- integ-tests/python/baml_client/sync_client.py | 53 ++++++++++++++ integ-tests/python/baml_client/types.py | 4 +- integ-tests/python/tests/test_functions.py | 5 ++ integ-tests/ruby/baml_client/client.rb | 67 ++++++++++++++++++ integ-tests/ruby/baml_client/inlined.rb | 2 +- .../typescript/baml_client/async_client.ts | 58 ++++++++++++++++ .../typescript/baml_client/inlinedbaml.ts | 2 +- .../typescript/baml_client/sync_client.ts | 25 +++++++ 29 files changed, 554 insertions(+), 87 deletions(-) diff --git a/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs b/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs index e2ae93083..4d8bb08d3 100644 --- a/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs +++ b/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs @@ -263,8 +263,24 @@ impl ArgCoercer { Err(()) } }, - (FieldType::Alias { resolution, .. }, _) => { - self.coerce_arg(ir, &resolution, value, scope) + // TODO: Is this even possible? + (FieldType::RecursiveTypeAlias(name), _) => { + let mut maybe_coerced = None; + // TODO: Fix this O(n) + for cycle in ir.structural_recursive_alias_cycles().iter() { + if let Some(target) = cycle.get(name) { + maybe_coerced = Some(self.coerce_arg(ir, target, value, scope)?); + break; + } + } + + match maybe_coerced { + Some(coerced) => Ok(coerced), + None => { + scope.push_error(format!("Recursive type alias {} not found", name)); + Err(()) + } + } } (FieldType::List(item), _) => match value { BamlValue::List(arr) => { diff --git a/engine/baml-lib/baml-core/src/ir/json_schema.rs b/engine/baml-lib/baml-core/src/ir/json_schema.rs index e0c7fa886..bccf278c6 100644 --- a/engine/baml-lib/baml-core/src/ir/json_schema.rs +++ b/engine/baml-lib/baml-core/src/ir/json_schema.rs @@ -156,10 +156,12 @@ impl WithJsonSchema for FieldType { FieldType::Class(name) | FieldType::Enum(name) => json!({ "$ref": format!("#/definitions/{}", name), }), - FieldType::Alias { resolution, .. } => resolution.json_schema(), FieldType::Literal(v) => json!({ "const": v.to_string(), }), + FieldType::RecursiveTypeAlias(_) => json!({ + "type": ["number", "string", "boolean", "object", "array", "null"] + }), FieldType::Primitive(t) => match t { TypeValue::String => json!({ "type": "string", diff --git a/engine/baml-lib/baml-core/src/ir/repr.rs b/engine/baml-lib/baml-core/src/ir/repr.rs index faeecf45f..292e30922 100644 --- a/engine/baml-lib/baml-core/src/ir/repr.rs +++ b/engine/baml-lib/baml-core/src/ir/repr.rs @@ -36,7 +36,7 @@ pub struct IntermediateRepr { finite_recursive_cycles: Vec>, /// Type alias cycles introduced by lists and maps. - structural_recursive_alias_cycles: Vec>, + structural_recursive_alias_cycles: Vec>, configuration: Configuration, } @@ -102,6 +102,10 @@ impl IntermediateRepr { &self.finite_recursive_cycles } + pub fn structural_recursive_alias_cycles(&self) -> &[IndexMap] { + &self.structural_recursive_alias_cycles + } + pub fn walk_enums(&self) -> impl ExactSizeIterator>> { self.enums.iter().map(|e| Walker { db: self, item: e }) } @@ -110,6 +114,14 @@ impl IntermediateRepr { self.classes.iter().map(|e| Walker { db: self, item: e }) } + // TODO: Exact size Iterator + Node<>? + pub fn walk_alias_cycles(&self) -> impl Iterator> { + self.structural_recursive_alias_cycles + .iter() + .flatten() + .map(|e| Walker { db: self, item: e }) + } + pub fn function_names(&self) -> impl ExactSizeIterator { self.functions.iter().map(|f| f.elem.name()) } @@ -172,15 +184,18 @@ impl IntermediateRepr { .collect() }) .collect(), - structural_recursive_alias_cycles: db - .structural_recursive_alias_cycles() - .iter() - .map(|ids| { - ids.iter() - .map(|id| db.ast()[*id].name().to_string()) - .collect() - }) - .collect(), + structural_recursive_alias_cycles: { + let mut recursive_aliases = vec![]; + for cycle in db.structural_recursive_alias_cycles() { + let mut component = IndexMap::new(); + for id in cycle { + let alias = &db.ast()[*id]; + component.insert(alias.name().to_string(), alias.value.repr(db)?); + } + recursive_aliases.push(component); + } + recursive_aliases + }, functions: db .walk_functions() .map(|e| e.node(db)) @@ -432,11 +447,18 @@ impl WithRepr for ast::FieldType { _ => base_type, } } - Some(TypeWalker::TypeAlias(alias_walker)) => FieldType::Alias { - name: alias_walker.name().to_owned(), - target: Box::new(alias_walker.target().repr(db)?), - resolution: Box::new(alias_walker.resolved().repr(db)?), - }, + Some(TypeWalker::TypeAlias(alias_walker)) => { + if db + .structural_recursive_alias_cycles() + .iter() + .any(|cycle| cycle.contains(&alias_walker.id)) + { + FieldType::RecursiveTypeAlias(alias_walker.name().to_string()) + } else { + alias_walker.resolved().to_owned().repr(db)? + } + } + None => return Err(anyhow!("Field type uses unresolvable local identifier")), }, arity, @@ -1237,11 +1259,7 @@ mod tests { let class = ir.find_class("Test").unwrap(); let alias = class.find_field("field").unwrap(); - let FieldType::Alias { resolution, .. } = alias.r#type() else { - panic!("expected alias type, found {:?}", alias.r#type()); - }; - - assert_eq!(**resolution, FieldType::Primitive(TypeValue::Int)); + assert_eq!(*alias.r#type(), FieldType::Primitive(TypeValue::Int)); } #[test] @@ -1262,12 +1280,11 @@ mod tests { let class = ir.find_class("Test").unwrap(); let alias = class.find_field("field").unwrap(); - let FieldType::Alias { resolution, .. } = alias.r#type() else { - panic!("expected alias type, found {:?}", alias.r#type()); - }; - - let FieldType::Constrained { base, constraints } = &**resolution else { - panic!("expected resolved constrained type, found {:?}", resolution); + let FieldType::Constrained { base, constraints } = alias.r#type() else { + panic!( + "expected resolved constrained type, found {:?}", + alias.r#type() + ); }; assert_eq!(constraints.len(), 3); diff --git a/engine/baml-lib/baml-types/src/field_type/mod.rs b/engine/baml-lib/baml-types/src/field_type/mod.rs index 909f20426..8952ba11f 100644 --- a/engine/baml-lib/baml-types/src/field_type/mod.rs +++ b/engine/baml-lib/baml-types/src/field_type/mod.rs @@ -85,14 +85,7 @@ pub enum FieldType { Union(Vec), Tuple(Vec), Optional(Box), - Alias { - /// Name of the alias. - name: String, - /// Type that the alias points to. - target: Box, - /// Final resolved type (an alias can point to other aliases). - resolution: Box, - }, + RecursiveTypeAlias(String), Constrained { base: Box, constraints: Vec, @@ -103,8 +96,9 @@ pub enum FieldType { impl std::fmt::Display for FieldType { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - FieldType::Enum(name) | FieldType::Class(name) => write!(f, "{name}"), - FieldType::Alias { name, .. } => write!(f, "{name}"), + FieldType::Enum(name) + | FieldType::Class(name) + | FieldType::RecursiveTypeAlias(name) => write!(f, "{name}"), FieldType::Primitive(t) => write!(f, "{t}"), FieldType::Literal(v) => write!(f, "{v}"), FieldType::Union(choices) => { @@ -187,8 +181,6 @@ impl FieldType { } match (self, other) { - (FieldType::Alias { resolution, .. }, _) => resolution.is_subtype_of(other), - (_, FieldType::Alias { resolution, .. }) => self.is_subtype_of(resolution), (FieldType::Primitive(TypeValue::Null), FieldType::Optional(_)) => true, (FieldType::Optional(self_item), FieldType::Optional(other_item)) => { self_item.is_subtype_of(other_item) @@ -207,6 +199,12 @@ impl FieldType { } (FieldType::Map(_, _), _) => false, + // TODO: is it necessary to check if the alias is part of the same + // cycle? + (FieldType::RecursiveTypeAlias(_), _) | (_, FieldType::RecursiveTypeAlias(_)) => { + self == other + } + ( FieldType::Constrained { base: self_base, diff --git a/engine/baml-lib/jinja-runtime/src/output_format/types.rs b/engine/baml-lib/jinja-runtime/src/output_format/types.rs index f3e357ba2..ade317e53 100644 --- a/engine/baml-lib/jinja-runtime/src/output_format/types.rs +++ b/engine/baml-lib/jinja-runtime/src/output_format/types.rs @@ -58,6 +58,7 @@ pub struct OutputFormatContent { pub enums: Arc>, pub classes: Arc>, recursive_classes: Arc>, + structural_recursive_aliases: Arc>, pub target: FieldType, } @@ -67,6 +68,8 @@ pub struct Builder { classes: Vec, /// Order matters for this one. recursive_classes: IndexSet, + /// Recursive aliases introduced maps and lists. + structural_recursive_aliases: IndexSet, target: FieldType, } @@ -76,6 +79,7 @@ impl Builder { enums: vec![], classes: vec![], recursive_classes: IndexSet::new(), + structural_recursive_aliases: IndexSet::new(), target, } } @@ -95,6 +99,14 @@ impl Builder { self } + pub fn structural_recursive_aliases( + mut self, + structural_recursive_aliases: IndexSet, + ) -> Self { + self.structural_recursive_aliases = structural_recursive_aliases; + self + } + pub fn target(mut self, target: FieldType) -> Self { self.target = target; self @@ -115,6 +127,9 @@ impl Builder { .collect(), ), recursive_classes: Arc::new(self.recursive_classes.into_iter().collect()), + structural_recursive_aliases: Arc::new( + self.structural_recursive_aliases.into_iter().collect(), + ), target: self.target, } } @@ -297,6 +312,7 @@ fn indefinite_article_a_or_an(word: &str) -> &str { struct RenderState { hoisted_enums: IndexSet, + hoisted_aliases: IndexMap, } impl OutputFormatContent { @@ -333,8 +349,13 @@ impl OutputFormatContent { Some(format!("Answer in JSON using this {type_prefix}:{end}")) } - FieldType::Alias { resolution, .. } => { - auto_prefix(&resolution, options, output_format_content) + FieldType::RecursiveTypeAlias(_) => { + let type_prefix = match &options.hoisted_class_prefix { + RenderSetting::Always(prefix) if !prefix.is_empty() => prefix, + _ => RenderOptions::DEFAULT_TYPE_PREFIX_IN_RENDER_MESSAGE, + }; + + Some(format!("Answer in JSON using this {type_prefix}: ")) } FieldType::List(_) => Some(String::from( "Answer with a JSON Array using this schema:\n", @@ -487,12 +508,7 @@ impl OutputFormatContent { } .to_string() } - FieldType::Alias { resolution, .. } => self.render_possibly_recursive_type( - options, - &resolution, - render_state, - group_hoisted_literals, - )?, + FieldType::RecursiveTypeAlias(name) => name.to_owned(), FieldType::List(inner) => { let is_recursive = match inner.as_ref() { FieldType::Class(nested_class) => self.recursive_classes.contains(nested_class), @@ -566,6 +582,7 @@ impl OutputFormatContent { let mut render_state = RenderState { hoisted_enums: IndexSet::new(), + hoisted_aliases: IndexMap::new(), }; let mut message = match &self.target { @@ -597,6 +614,7 @@ impl OutputFormatContent { })); let mut class_definitions = Vec::new(); + let mut type_alias_definitions = Vec::new(); // Hoist recursive classes. The render_state struct doesn't need to // contain these classes because we already know that we're gonna hoist @@ -618,6 +636,27 @@ impl OutputFormatContent { }); } + // Yeah gotta love the borrow checker. // TODO: @antonio + let hoisted_aliases = + std::mem::replace(&mut render_state.hoisted_aliases, Default::default()); + + for (alias, target) in hoisted_aliases.iter() { + let recursive_pointer = self.inner_type_render( + &options, + // TODO: @antonio This code sucks beyond measure. Fix this. + &target, + &mut render_state, + false, + )?; + + type_alias_definitions.push(match &options.hoisted_class_prefix { + RenderSetting::Always(prefix) if !prefix.is_empty() => { + format!("{prefix} {alias} = {recursive_pointer}") + } + _ => format!("{alias} = {recursive_pointer}"), + }); + } + let mut output = String::new(); if !enum_definitions.is_empty() { @@ -630,6 +669,11 @@ impl OutputFormatContent { output.push_str("\n\n"); } + if !type_alias_definitions.is_empty() { + output.push_str(&type_alias_definitions.join("\n\n")); + output.push_str("\n\n"); + } + if let Some(p) = prefix { output.push_str(&p); } diff --git a/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs b/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs index d984d00a2..d825dfd28 100644 --- a/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs +++ b/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs @@ -79,7 +79,8 @@ impl TypeCoercer for FieldType { FieldType::Enum(e) => IrRef::Enum(e).coerce(ctx, target, value), FieldType::Literal(l) => l.coerce(ctx, target, value), FieldType::Class(c) => IrRef::Class(c).coerce(ctx, target, value), - FieldType::Alias { resolution, .. } => resolution.coerce(ctx, target, value), + // TODO: How to handle this? + FieldType::RecursiveTypeAlias(_) => todo!("Alias with no resolution"), FieldType::List(_) => coerce_array(ctx, self, value), FieldType::Union(_) => coerce_union(ctx, self, value), FieldType::Optional(_) => coerce_optional(ctx, self, value), @@ -165,7 +166,7 @@ impl DefaultValue for FieldType { FieldType::Enum(e) => None, FieldType::Literal(_) => None, FieldType::Class(_) => None, - FieldType::Alias { resolution, .. } => resolution.default_value(error), + FieldType::RecursiveTypeAlias(_) => None, FieldType::List(_) => Some(BamlValueWithFlags::List(get_flags(), Vec::new())), FieldType::Union(items) => items.iter().find_map(|i| i.default_value(error)), FieldType::Primitive(TypeValue::Null) | FieldType::Optional(_) => { diff --git a/engine/baml-lib/jsonish/src/tests/mod.rs b/engine/baml-lib/jsonish/src/tests/mod.rs index 1e36255b1..41a22ef4c 100644 --- a/engine/baml-lib/jsonish/src/tests/mod.rs +++ b/engine/baml-lib/jsonish/src/tests/mod.rs @@ -230,7 +230,8 @@ fn relevant_data_models<'a>( }); } } - (FieldType::Alias { resolution, .. }, _) => start.push(*resolution.to_owned()), + // TODO: Add structural aliases here. + (FieldType::RecursiveTypeAlias(_), _) => {} (FieldType::Literal(_), _) => {} (FieldType::Primitive(_), _constraints) => {} (FieldType::Constrained { .. }, _) => { diff --git a/engine/baml-lib/parser-database/src/lib.rs b/engine/baml-lib/parser-database/src/lib.rs index 1d76d6504..c90c141ba 100644 --- a/engine/baml-lib/parser-database/src/lib.rs +++ b/engine/baml-lib/parser-database/src/lib.rs @@ -130,6 +130,11 @@ impl ParserDatabase { } fn finalize_dependencies(&mut self, diag: &mut Diagnostics) { + // Cycles left here after cycle validation are allowed. Basically lists + // and maps can introduce cycles. + self.types.structural_recursive_alias_cycles = + Tarjan::components(&self.types.type_alias_dependencies); + // Resolve type aliases. // Cycles are already validated so this should not stack overflow and // it should find the final type. @@ -138,11 +143,6 @@ impl ParserDatabase { self.types.resolved_type_aliases.insert(*alias_id, resolved); } - // Cycles left here after cycle validation are allowed. Basically lists - // and maps can introduce cycles. - self.types.structural_recursive_alias_cycles = - Tarjan::components(&self.types.type_alias_dependencies); - // NOTE: Class dependency cycles are already checked at // baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs // @@ -248,17 +248,21 @@ impl ParserDatabase { // For aliases just get the resolved identifiers and // push them into the stack. If we find resolved classes we'll - // add their dependencies as well. Note that this is not - // "recursive" per se because type aliases can never "resolve" - // to other type aliases. + // add their dependencies as well. Some(TypeWalker::TypeAlias(walker)) => { - stack.extend(walker.resolved().flat_idns().iter().map(|ident| { + stack.extend(walker.resolved().flat_idns().iter().filter_map(|ident| { // Add the resolved name itself to the deps. collected_deps.insert(ident.name().to_owned()); - // Push the resolved name into the stack in case - // it's a class, we'll have to add its deps as - // well. - ident.name() + // If the type is an alias then don't recurse. + if self + .structural_recursive_alias_cycles() + .iter() + .any(|cycle| cycle.contains(&walker.id)) + { + None + } else { + Some(ident.name()) + } })) } @@ -661,6 +665,9 @@ mod test { "#, &[&["JsonValue", "JsonArray", "JsonObject"]], ) + } + + #[test] fn merged_alias_attrs() -> Result<(), Diagnostics> { #[rustfmt::skip] let db = parse(r#" @@ -674,4 +681,30 @@ mod test { Ok(()) } + + // Resolution of aliases here at the parser database level doesn't matter + // as much because there's no notion of "classes" or "enums", it's just + // "symbols". But the resolve type function should not stack overflow + // anyway. + #[test] + fn resolve_simple_structural_recursive_alias() -> Result<(), Diagnostics> { + #[rustfmt::skip] + let db = parse(r#" + type A = A[] + "#)?; + + let resolved = db.resolved_type_alias_by_name("A").unwrap(); + + let FieldType::List(_, inner, ..) = resolved else { + panic!("expected a list type, got {resolved:?}"); + }; + + let FieldType::Symbol(_, ident, _) = &**inner else { + panic!("expected a symbol type, got {inner:?}"); + }; + + assert_eq!(ident.name(), "A"); + + Ok(()) + } } diff --git a/engine/baml-lib/parser-database/src/tarjan.rs b/engine/baml-lib/parser-database/src/tarjan.rs index 32a3bc4db..a80a6a52d 100644 --- a/engine/baml-lib/parser-database/src/tarjan.rs +++ b/engine/baml-lib/parser-database/src/tarjan.rs @@ -129,7 +129,8 @@ impl<'g, V: Eq + Ord + Hash + Copy> Tarjan<'g, V> { // TODO: @antoniosarosi: HashSet is random, won't always iterate in the // same order. Fix this with IndexSet or something, we really don't want - // to sort this every single time. + // to sort this every single time. Also order only matters for tests, we + // can do `if cfg!(test)` or something. let mut successors = Vec::from_iter(&self.graph[&node_id]); successors.sort(); diff --git a/engine/baml-lib/parser-database/src/types/mod.rs b/engine/baml-lib/parser-database/src/types/mod.rs index 58a6f1aa2..1978fd078 100644 --- a/engine/baml-lib/parser-database/src/types/mod.rs +++ b/engine/baml-lib/parser-database/src/types/mod.rs @@ -416,8 +416,22 @@ pub fn resolve_type_alias(field_type: &FieldType, db: &ParserDatabase) -> FieldT let mut resolved = match db.types.resolved_type_aliases.get(alias_id) { // Check if we can avoid deeper recursion. Some(already_resolved) => already_resolved.to_owned(), - // No luck, recurse. - None => resolve_type_alias(&db.ast[*alias_id].value, db), + + // No luck, check if the type is resolvable. + None => { + // TODO: O(n) + if db + .structural_recursive_alias_cycles() + .iter() + .any(|cycle| cycle.contains(alias_id)) + { + // Not resolvable, part of a cycle. + field_type.to_owned() + } else { + // Maybe resolvable, recurse deeper. + resolve_type_alias(&db.ast[*alias_id].value, db) + } + } }; // Sync arity. Basically stuff like: diff --git a/engine/baml-runtime/src/internal/prompt_renderer/render_output_format.rs b/engine/baml-runtime/src/internal/prompt_renderer/render_output_format.rs index 1e8e8c57f..05b7fbd17 100644 --- a/engine/baml-runtime/src/internal/prompt_renderer/render_output_format.rs +++ b/engine/baml-runtime/src/internal/prompt_renderer/render_output_format.rs @@ -18,12 +18,14 @@ pub fn render_output_format( ctx: &RuntimeContext, output: &FieldType, ) -> Result { - let (enums, classes, recursive_classes) = relevant_data_models(ir, output, ctx)?; + let (enums, classes, recursive_classes, structural_recursive_aliases) = + relevant_data_models(ir, output, ctx)?; Ok(OutputFormatContent::target(output.clone()) .enums(enums) .classes(classes) .recursive_classes(recursive_classes) + .structural_recursive_aliases(structural_recursive_aliases) .build()) } @@ -210,11 +212,12 @@ fn relevant_data_models<'a>( ir: &'a IntermediateRepr, output: &'a FieldType, ctx: &RuntimeContext, -) -> Result<(Vec, Vec, IndexSet)> { +) -> Result<(Vec, Vec, IndexSet, IndexSet)> { let mut checked_types = HashSet::new(); let mut enums = Vec::new(); let mut classes = Vec::new(); let mut recursive_classes = IndexSet::new(); + let mut structural_recursive_aliases = IndexSet::new(); let mut start: Vec = vec![output.clone()]; let eval_ctx = ctx.eval_ctx(false); @@ -376,8 +379,14 @@ fn relevant_data_models<'a>( recursive_classes.insert(cls.to_owned()); } } - (FieldType::Alias { resolution, .. }, _) => { - start.push(*resolution.clone()); + (FieldType::RecursiveTypeAlias(name), _) => { + // TODO: Same O(n) problem as above. + // TODO: Do we need the type information or just the name? + for cycle in ir.structural_recursive_alias_cycles() { + if cycle.contains_key(name) { + structural_recursive_aliases.extend(cycle.keys().map(ToOwned::to_owned)); + } + } } (FieldType::Literal(_), _) => {} (FieldType::Primitive(_), _) => {} @@ -387,7 +396,12 @@ fn relevant_data_models<'a>( } } - Ok((enums, classes, recursive_classes)) + Ok(( + enums, + classes, + recursive_classes, + structural_recursive_aliases, + )) } #[cfg(test)] diff --git a/engine/language_client_codegen/src/openapi.rs b/engine/language_client_codegen/src/openapi.rs index 4dc6009a1..33020edcd 100644 --- a/engine/language_client_codegen/src/openapi.rs +++ b/engine/language_client_codegen/src/openapi.rs @@ -539,7 +539,15 @@ impl<'ir> ToTypeReferenceInTypeDefinition<'ir> for FieldType { r#ref: format!("#/components/schemas/{}", name), }, }, - FieldType::Alias { resolution, .. } => resolution.to_type_spec(_ir)?, + FieldType::RecursiveTypeAlias(_) => TypeSpecWithMeta { + meta: TypeMetadata { + title: None, + r#enum: None, + r#const: None, + nullable: false, + }, + type_spec: TypeSpec::AnyValue { any_of: vec![] }, + }, FieldType::Literal(v) => TypeSpecWithMeta { meta: TypeMetadata { title: None, @@ -705,6 +713,10 @@ enum TypeSpec { #[serde(rename = "oneOf", alias = "oneOf")] one_of: Vec, }, + AnyValue { + #[serde(rename = "anyOf", alias = "anyOf")] + any_of: Vec, + }, } #[derive(Clone, Debug, Serialize)] diff --git a/engine/language_client_codegen/src/python/generate_types.rs b/engine/language_client_codegen/src/python/generate_types.rs index 882110a51..428b25d70 100644 --- a/engine/language_client_codegen/src/python/generate_types.rs +++ b/engine/language_client_codegen/src/python/generate_types.rs @@ -7,7 +7,7 @@ use crate::{field_type_attributes, type_check_attributes, TypeCheckAttributes}; use super::python_language_features::ToPython; use internal_baml_core::ir::{ - repr::{Docstring, IntermediateRepr}, + repr::{Docstring, IntermediateRepr, Walker}, ClassWalker, EnumWalker, FieldType, IRHelper, }; @@ -16,6 +16,7 @@ use internal_baml_core::ir::{ pub(crate) struct PythonTypes<'ir> { enums: Vec>, classes: Vec>, + structural_recursive_alias_cycles: Vec>, } #[derive(askama::Template)] @@ -41,6 +42,11 @@ struct PythonClass<'ir> { dynamic: bool, } +struct PythonTypeAlias<'ir> { + name: Cow<'ir, str>, + target: String, +} + #[derive(askama::Template)] #[template(path = "partial_types.py.j2", escape = "none")] pub(crate) struct PythonStreamTypes<'ir> { @@ -66,6 +72,10 @@ impl<'ir> TryFrom<(&'ir IntermediateRepr, &'_ crate::GeneratorArgs)> for PythonT Ok(PythonTypes { enums: ir.walk_enums().map(PythonEnum::from).collect::>(), classes: ir.walk_classes().map(PythonClass::from).collect::>(), + structural_recursive_alias_cycles: ir + .walk_alias_cycles() + .map(PythonTypeAlias::from) + .collect::>(), }) } } @@ -126,6 +136,21 @@ impl<'ir> From> for PythonClass<'ir> { } } +// TODO: Define AliasWalker to simplify type. +impl<'ir> From> for PythonTypeAlias<'ir> { + fn from( + Walker { + db, + item: (name, target), + }: Walker<(&'ir String, &'ir FieldType)>, + ) -> Self { + PythonTypeAlias { + name: Cow::Borrowed(name), + target: target.to_type_ref(db), + } + } +} + impl<'ir> TryFrom<(&'ir IntermediateRepr, &'_ crate::GeneratorArgs)> for PythonStreamTypes<'ir> { type Error = anyhow::Error; @@ -219,7 +244,7 @@ impl ToTypeReferenceInTypeDefinition for FieldType { format!("\"{name}\"") } } - FieldType::Alias { resolution, .. } => resolution.to_type_ref(ir), + FieldType::RecursiveTypeAlias(name) => format!("\"{name}\""), FieldType::Literal(value) => to_python_literal(value), FieldType::Class(name) => format!("\"{name}\""), FieldType::List(inner) => format!("List[{}]", inner.to_type_ref(ir)), @@ -275,7 +300,13 @@ impl ToTypeReferenceInTypeDefinition for FieldType { format!("Optional[types.{name}]") } } - FieldType::Alias { resolution, .. } => resolution.to_partial_type_ref(ir, wrapped), + FieldType::RecursiveTypeAlias(name) => { + if wrapped { + format!("\"{name}\"") + } else { + format!("Optional[\"{name}\"]") + } + } FieldType::Literal(value) => to_python_literal(value), FieldType::List(inner) => format!("List[{}]", inner.to_partial_type_ref(ir, true)), FieldType::Map(key, value) => { diff --git a/engine/language_client_codegen/src/python/mod.rs b/engine/language_client_codegen/src/python/mod.rs index 082fc5d3d..06fbbe7d8 100644 --- a/engine/language_client_codegen/src/python/mod.rs +++ b/engine/language_client_codegen/src/python/mod.rs @@ -201,7 +201,7 @@ impl ToTypeReferenceInClientDefinition for FieldType { } } FieldType::Literal(value) => to_python_literal(value), - FieldType::Alias { resolution, .. } => resolution.to_type_ref(ir, _with_checked), + FieldType::RecursiveTypeAlias(name) => format!("types.{name}"), FieldType::Class(name) => format!("types.{name}"), FieldType::List(inner) => format!("List[{}]", inner.to_type_ref(ir, _with_checked)), FieldType::Map(key, value) => { @@ -256,7 +256,7 @@ impl ToTypeReferenceInClientDefinition for FieldType { } } FieldType::Class(name) => format!("partial_types.{name}"), - FieldType::Alias { resolution, .. } => resolution.to_partial_type_ref(ir, with_checked), + FieldType::RecursiveTypeAlias(name) => format!("types.{name}"), FieldType::Literal(value) => to_python_literal(value), FieldType::List(inner) => { format!("List[{}]", inner.to_partial_type_ref(ir, with_checked)) diff --git a/engine/language_client_codegen/src/python/templates/types.py.j2 b/engine/language_client_codegen/src/python/templates/types.py.j2 index 86b776db3..bbc7a7e9a 100644 --- a/engine/language_client_codegen/src/python/templates/types.py.j2 +++ b/engine/language_client_codegen/src/python/templates/types.py.j2 @@ -2,7 +2,7 @@ import baml_py from enum import Enum from pydantic import BaseModel, ConfigDict -from typing import Dict, Generic, List, Literal, Optional, TypeVar, Union +from typing import Dict, Generic, List, Literal, Optional, TypeVar, Union, TypeAlias T = TypeVar('T') @@ -59,3 +59,8 @@ class {{cls.name}}(BaseModel): {%- endif %} {%- endfor %} {% endfor %} + +{#- Type Aliases -#} +{% for alias in structural_recursive_alias_cycles %} +{{alias.name}}: TypeAlias = {{alias.target}} +{% endfor %} diff --git a/engine/language_client_codegen/src/ruby/field_type.rs b/engine/language_client_codegen/src/ruby/field_type.rs index c17c7e538..91e37e33e 100644 --- a/engine/language_client_codegen/src/ruby/field_type.rs +++ b/engine/language_client_codegen/src/ruby/field_type.rs @@ -9,7 +9,8 @@ impl ToRuby for FieldType { match self { FieldType::Class(name) => format!("Baml::Types::{}", name.clone()), FieldType::Enum(name) => format!("T.any(Baml::Types::{}, String)", name.clone()), - FieldType::Alias { resolution, .. } => resolution.to_ruby(), + // TODO: Can we define recursive aliases in Ruby with Sorbet? + FieldType::RecursiveTypeAlias(_name) => "T.anything".to_string(), // TODO: Temporary solution until we figure out Ruby literals. FieldType::Literal(value) => value.literal_base_type().to_ruby(), // https://sorbet.org/docs/stdlib-generics diff --git a/engine/language_client_codegen/src/ruby/generate_types.rs b/engine/language_client_codegen/src/ruby/generate_types.rs index b69a44359..495b7e924 100644 --- a/engine/language_client_codegen/src/ruby/generate_types.rs +++ b/engine/language_client_codegen/src/ruby/generate_types.rs @@ -168,7 +168,8 @@ impl ToTypeReferenceInTypeDefinition for FieldType { match self { FieldType::Class(name) => format!("Baml::PartialTypes::{}", name.clone()), FieldType::Enum(name) => format!("T.nilable(Baml::Types::{})", name.clone()), - FieldType::Alias { resolution, .. } => resolution.to_partial_type_ref(), + // TODO: Can we define recursive aliases in Ruby with Sorbet? + FieldType::RecursiveTypeAlias(_name) => "T.anything".to_string(), // TODO: Temporary solution until we figure out Ruby literals. FieldType::Literal(value) => value.literal_base_type().to_partial_type_ref(), // https://sorbet.org/docs/stdlib-generics diff --git a/engine/language_client_codegen/src/typescript/mod.rs b/engine/language_client_codegen/src/typescript/mod.rs index 4f8d61a77..e3178534f 100644 --- a/engine/language_client_codegen/src/typescript/mod.rs +++ b/engine/language_client_codegen/src/typescript/mod.rs @@ -267,7 +267,7 @@ impl ToTypeReferenceInClientDefinition for FieldType { } } FieldType::Class(name) => name.to_string(), - FieldType::Alias { resolution, .. } => resolution.to_type_ref(ir), + FieldType::RecursiveTypeAlias(name) => name.to_owned(), FieldType::List(inner) => match inner.as_ref() { FieldType::Union(_) | FieldType::Optional(_) => { format!("({})[]", inner.to_type_ref(ir)) diff --git a/integ-tests/baml_src/test-files/functions/output/type-aliases.baml b/integ-tests/baml_src/test-files/functions/output/type-aliases.baml index 5da6d06ef..6d93fec8e 100644 --- a/integ-tests/baml_src/test-files/functions/output/type-aliases.baml +++ b/integ-tests/baml_src/test-files/functions/output/type-aliases.baml @@ -75,6 +75,19 @@ function AliasWithMultipleAttrs(money: MultipleAttrs) -> MultipleAttrs { {{ money }} + {{ ctx.output_format }} + "# +} + +type RecursiveMapAlias = map + +function SimpleRecursiveMapAlias(input: RecursiveMapAlias) -> RecursiveMapAlias { + client "openai/gpt-4o" + prompt r#" + Return the given value: + + {{ input }} + {{ ctx.output_format }} "# } \ No newline at end of file diff --git a/integ-tests/python/baml_client/async_client.py b/integ-tests/python/baml_client/async_client.py index 1bb0c3836..52da2b5de 100644 --- a/integ-tests/python/baml_client/async_client.py +++ b/integ-tests/python/baml_client/async_client.py @@ -1982,6 +1982,29 @@ async def SchemaDescriptions( ) return cast(types.Schema, raw.cast_to(types, types)) + async def SimpleRecursiveMapAlias( + self, + input: types.RecursiveMapAlias, + baml_options: BamlCallOptions = {}, + ) -> types.RecursiveMapAlias: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = await self.__runtime.call_function( + "SimpleRecursiveMapAlias", + { + "input": input, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + return cast(types.RecursiveMapAlias, raw.cast_to(types, types)) + async def StreamBigNumbers( self, digits: int, @@ -5507,6 +5530,36 @@ def SchemaDescriptions( self.__ctx_manager.get(), ) + def SimpleRecursiveMapAlias( + self, + input: types.RecursiveMapAlias, + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlStream[types.RecursiveMapAlias, types.RecursiveMapAlias]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function( + "SimpleRecursiveMapAlias", + { + "input": input, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + return baml_py.BamlStream[types.RecursiveMapAlias, types.RecursiveMapAlias]( + raw, + lambda x: cast(types.RecursiveMapAlias, x.cast_to(types, partial_types)), + lambda x: cast(types.RecursiveMapAlias, x.cast_to(types, types)), + self.__ctx_manager.get(), + ) + def StreamBigNumbers( self, digits: int, diff --git a/integ-tests/python/baml_client/inlinedbaml.py b/integ-tests/python/baml_client/inlinedbaml.py index 43fe97e37..dd4487bf2 100644 --- a/integ-tests/python/baml_client/inlinedbaml.py +++ b/integ-tests/python/baml_client/inlinedbaml.py @@ -83,7 +83,7 @@ "test-files/functions/output/recursive-type-aliases.baml": "class LinkedListAliasNode {\n value int\n next LinkedListAliasNode?\n}\n\n// Simple alias that points to recursive type.\ntype LinkedListAlias = LinkedListAliasNode\n\nfunction AliasThatPointsToRecursiveType(list: LinkedListAlias) -> LinkedListAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given linked list back:\n \n {{ list }}\n \n {{ ctx.output_format }}\n \"#\n}\n\n// Class that points to an alias that points to a recursive type.\nclass ClassToRecAlias {\n list LinkedListAlias\n}\n\nfunction ClassThatPointsToRecursiveClassThroughAlias(cls: ClassToRecAlias) -> ClassToRecAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given object back:\n \n {{ cls }}\n \n {{ ctx.output_format }}\n \"#\n}\n\n// This is tricky cause this class should be hoisted, but classes and aliases\n// are two different types in the AST. This test will make sure they can interop.\nclass NodeWithAliasIndirection {\n value int\n next NodeIndirection?\n}\n\ntype NodeIndirection = NodeWithAliasIndirection\n\nfunction RecursiveClassWithAliasIndirection(cls: NodeWithAliasIndirection) -> NodeWithAliasIndirection {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given object back:\n \n {{ cls }}\n \n {{ ctx.output_format }}\n \"#\n}\n", "test-files/functions/output/serialization-error.baml": "class DummyOutput {\n nonce string\n nonce2 string\n @@dynamic\n}\n\nfunction DummyOutputFunction(input: string) -> DummyOutput {\n client GPT35\n prompt #\"\n Say \"hello there\".\n \"#\n}", "test-files/functions/output/string-list.baml": "function FnOutputStringList(input: string) -> string[] {\n client GPT35\n prompt #\"\n Return a list of strings in json format like [\"string1\", \"string2\", \"string3\"].\n\n JSON:\n \"#\n}\n\ntest FnOutputStringList {\n functions [FnOutputStringList]\n args {\n input \"example input\"\n }\n}\n", - "test-files/functions/output/type-aliases.baml": "type Primitive = int | string | bool | float\n\ntype List = string[]\n\ntype Graph = map\n\ntype Combination = Primitive | List | Graph\n\nfunction PrimitiveAlias(p: Primitive) -> Primitive {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back: {{ p }}\n \"#\n}\n\nfunction MapAlias(m: Graph) -> Graph {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given Graph back:\n\n {{ m }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction NestedAlias(c: Combination) -> Combination {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back:\n\n {{ c }}\n\n {{ ctx.output_format }}\n \"#\n}\n\n// Test attribute merging.\ntype Currency = int @check(gt_ten, {{ this > 10 }})\ntype Amount = Currency @assert({{ this > 0 }})\n\nclass MergeAttrs {\n amount Amount @description(\"In USD\")\n}\n\n// This should be allowed.\ntype MultipleAttrs = int @assert({{ this > 0 }}) @check(gt_ten, {{ this > 10 }})\n\nfunction MergeAliasAttributes(money: int) -> MergeAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer in the specified format:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction ReturnAliasWithMergedAttributes(money: Amount) -> Amount {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction AliasWithMultipleAttrs(money: MultipleAttrs) -> MultipleAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}", + "test-files/functions/output/type-aliases.baml": "type Primitive = int | string | bool | float\n\ntype List = string[]\n\ntype Graph = map\n\ntype Combination = Primitive | List | Graph\n\nfunction PrimitiveAlias(p: Primitive) -> Primitive {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back: {{ p }}\n \"#\n}\n\nfunction MapAlias(m: Graph) -> Graph {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given Graph back:\n\n {{ m }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction NestedAlias(c: Combination) -> Combination {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back:\n\n {{ c }}\n\n {{ ctx.output_format }}\n \"#\n}\n\n// Test attribute merging.\ntype Currency = int @check(gt_ten, {{ this > 10 }})\ntype Amount = Currency @assert({{ this > 0 }})\n\nclass MergeAttrs {\n amount Amount @description(\"In USD\")\n}\n\n// This should be allowed.\ntype MultipleAttrs = int @assert({{ this > 0 }}) @check(gt_ten, {{ this > 10 }})\n\nfunction MergeAliasAttributes(money: int) -> MergeAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer in the specified format:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction ReturnAliasWithMergedAttributes(money: Amount) -> Amount {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction AliasWithMultipleAttrs(money: MultipleAttrs) -> MultipleAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecursiveMapAlias = map\n\nfunction SimpleRecursiveMapAlias(input: RecursiveMapAlias) -> RecursiveMapAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}", "test-files/functions/output/unions.baml": "class UnionTest_ReturnType {\n prop1 string | bool\n prop2 (float | bool)[]\n prop3 (bool[] | int[])\n}\n\nfunction UnionTest_Function(input: string | bool) -> UnionTest_ReturnType {\n client GPT35\n prompt #\"\n Return a JSON blob with this schema: \n {{ctx.output_format}}\n\n JSON:\n \"#\n}\n\ntest UnionTest_Function {\n functions [UnionTest_Function]\n args {\n input \"example input\"\n }\n}\n", "test-files/functions/prompts/no-chat-messages.baml": "\n\nfunction PromptTestClaude(input: string) -> string {\n client Sonnet\n prompt #\"\n Tell me a haiku about {{ input }}\n \"#\n}\n\n\nfunction PromptTestStreaming(input: string) -> string {\n client GPT35\n prompt #\"\n Tell me a short story about {{ input }}\n \"#\n}\n\ntest TestName {\n functions [PromptTestStreaming]\n args {\n input #\"\n hello world\n \"#\n }\n}\n", "test-files/functions/prompts/with-chat-messages.baml": "\nfunction PromptTestOpenAIChat(input: string) -> string {\n client GPT35\n prompt #\"\n {{ _.role(\"system\") }}\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestOpenAIChatNoSystem(input: string) -> string {\n client GPT35\n prompt #\"\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestClaudeChat(input: string) -> string {\n client Claude\n prompt #\"\n {{ _.role(\"system\") }}\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestClaudeChatNoSystem(input: string) -> string {\n client Claude\n prompt #\"\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\ntest TestSystemAndNonSystemChat1 {\n functions [PromptTestClaude, PromptTestOpenAI, PromptTestOpenAIChat, PromptTestOpenAIChatNoSystem, PromptTestClaudeChat, PromptTestClaudeChatNoSystem]\n args {\n input \"cats\"\n }\n}\n\ntest TestSystemAndNonSystemChat2 {\n functions [PromptTestClaude, PromptTestOpenAI, PromptTestOpenAIChat, PromptTestOpenAIChatNoSystem, PromptTestClaudeChat, PromptTestClaudeChatNoSystem]\n args {\n input \"lion\"\n }\n}", diff --git a/integ-tests/python/baml_client/sync_client.py b/integ-tests/python/baml_client/sync_client.py index 1202e8360..560bb8279 100644 --- a/integ-tests/python/baml_client/sync_client.py +++ b/integ-tests/python/baml_client/sync_client.py @@ -1979,6 +1979,29 @@ def SchemaDescriptions( ) return cast(types.Schema, raw.cast_to(types, types)) + def SimpleRecursiveMapAlias( + self, + input: types.RecursiveMapAlias, + baml_options: BamlCallOptions = {}, + ) -> types.RecursiveMapAlias: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.call_function_sync( + "SimpleRecursiveMapAlias", + { + "input": input, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + return cast(types.RecursiveMapAlias, raw.cast_to(types, types)) + def StreamBigNumbers( self, digits: int, @@ -5505,6 +5528,36 @@ def SchemaDescriptions( self.__ctx_manager.get(), ) + def SimpleRecursiveMapAlias( + self, + input: types.RecursiveMapAlias, + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlSyncStream[types.RecursiveMapAlias, types.RecursiveMapAlias]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function_sync( + "SimpleRecursiveMapAlias", + { + "input": input, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + return baml_py.BamlSyncStream[types.RecursiveMapAlias, types.RecursiveMapAlias]( + raw, + lambda x: cast(types.RecursiveMapAlias, x.cast_to(types, partial_types)), + lambda x: cast(types.RecursiveMapAlias, x.cast_to(types, types)), + self.__ctx_manager.get(), + ) + def StreamBigNumbers( self, digits: int, diff --git a/integ-tests/python/baml_client/types.py b/integ-tests/python/baml_client/types.py index 1b64063b6..4bdc54668 100644 --- a/integ-tests/python/baml_client/types.py +++ b/integ-tests/python/baml_client/types.py @@ -16,7 +16,7 @@ import baml_py from enum import Enum from pydantic import BaseModel, ConfigDict -from typing import Dict, Generic, List, Literal, Optional, TypeVar, Union +from typing import Dict, Generic, List, Literal, Optional, TypeVar, Union, TypeAlias T = TypeVar('T') @@ -478,3 +478,5 @@ class UnionTest_ReturnType(BaseModel): class WithReasoning(BaseModel): value: str reasoning: str + +RecursiveMapAlias: TypeAlias = Dict[str, "RecursiveMapAlias"] diff --git a/integ-tests/python/tests/test_functions.py b/integ-tests/python/tests/test_functions.py index 11d92c800..416647db2 100644 --- a/integ-tests/python/tests/test_functions.py +++ b/integ-tests/python/tests/test_functions.py @@ -314,6 +314,11 @@ async def test_alias_with_multiple_attrs(self): assert res.value == 123 assert res.checks["gt_ten"].status == "succeeded" + @pytest.mark.asyncio + async def test_simple_recursive_map_alias(self): + res = await b.SimpleRecursiveMapAlias({"one": {"two": {"three": {}}}}) + assert res == {"one": {"two": {"three": {}}}} + class MyCustomClass(NamedArgsSingleClass): date: datetime.datetime diff --git a/integ-tests/ruby/baml_client/client.rb b/integ-tests/ruby/baml_client/client.rb index e8cc900c6..51f5d200d 100644 --- a/integ-tests/ruby/baml_client/client.rb +++ b/integ-tests/ruby/baml_client/client.rb @@ -2738,6 +2738,38 @@ def SchemaDescriptions( (raw.parsed_using_types(Baml::Types)) end + sig { + params( + varargs: T.untyped, + input: T.anything, + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(T.anything) + } + def SimpleRecursiveMapAlias( + *varargs, + input:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("SimpleRecursiveMapAlias may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? + raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.call_function( + "SimpleRecursiveMapAlias", + { + input: input, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + (raw.parsed_using_types(Baml::Types)) + end + sig { params( varargs: T.untyped, @@ -7067,6 +7099,41 @@ def SchemaDescriptions( ) end + sig { + params( + varargs: T.untyped, + input: T.anything, + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(Baml::BamlStream[T.anything]) + } + def SimpleRecursiveMapAlias( + *varargs, + input:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("SimpleRecursiveMapAlias may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? + raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.stream_function( + "SimpleRecursiveMapAlias", + { + input: input, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + Baml::BamlStream[T.anything, T.anything].new( + ffi_stream: raw, + ctx_manager: @ctx_manager + ) + end + sig { params( varargs: T.untyped, diff --git a/integ-tests/ruby/baml_client/inlined.rb b/integ-tests/ruby/baml_client/inlined.rb index 0751e9f54..f12cc922b 100644 --- a/integ-tests/ruby/baml_client/inlined.rb +++ b/integ-tests/ruby/baml_client/inlined.rb @@ -83,7 +83,7 @@ module Inlined "test-files/functions/output/recursive-type-aliases.baml" => "class LinkedListAliasNode {\n value int\n next LinkedListAliasNode?\n}\n\n// Simple alias that points to recursive type.\ntype LinkedListAlias = LinkedListAliasNode\n\nfunction AliasThatPointsToRecursiveType(list: LinkedListAlias) -> LinkedListAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given linked list back:\n \n {{ list }}\n \n {{ ctx.output_format }}\n \"#\n}\n\n// Class that points to an alias that points to a recursive type.\nclass ClassToRecAlias {\n list LinkedListAlias\n}\n\nfunction ClassThatPointsToRecursiveClassThroughAlias(cls: ClassToRecAlias) -> ClassToRecAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given object back:\n \n {{ cls }}\n \n {{ ctx.output_format }}\n \"#\n}\n\n// This is tricky cause this class should be hoisted, but classes and aliases\n// are two different types in the AST. This test will make sure they can interop.\nclass NodeWithAliasIndirection {\n value int\n next NodeIndirection?\n}\n\ntype NodeIndirection = NodeWithAliasIndirection\n\nfunction RecursiveClassWithAliasIndirection(cls: NodeWithAliasIndirection) -> NodeWithAliasIndirection {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given object back:\n \n {{ cls }}\n \n {{ ctx.output_format }}\n \"#\n}\n", "test-files/functions/output/serialization-error.baml" => "class DummyOutput {\n nonce string\n nonce2 string\n @@dynamic\n}\n\nfunction DummyOutputFunction(input: string) -> DummyOutput {\n client GPT35\n prompt #\"\n Say \"hello there\".\n \"#\n}", "test-files/functions/output/string-list.baml" => "function FnOutputStringList(input: string) -> string[] {\n client GPT35\n prompt #\"\n Return a list of strings in json format like [\"string1\", \"string2\", \"string3\"].\n\n JSON:\n \"#\n}\n\ntest FnOutputStringList {\n functions [FnOutputStringList]\n args {\n input \"example input\"\n }\n}\n", - "test-files/functions/output/type-aliases.baml" => "type Primitive = int | string | bool | float\n\ntype List = string[]\n\ntype Graph = map\n\ntype Combination = Primitive | List | Graph\n\nfunction PrimitiveAlias(p: Primitive) -> Primitive {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back: {{ p }}\n \"#\n}\n\nfunction MapAlias(m: Graph) -> Graph {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given Graph back:\n\n {{ m }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction NestedAlias(c: Combination) -> Combination {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back:\n\n {{ c }}\n\n {{ ctx.output_format }}\n \"#\n}\n\n// Test attribute merging.\ntype Currency = int @check(gt_ten, {{ this > 10 }})\ntype Amount = Currency @assert({{ this > 0 }})\n\nclass MergeAttrs {\n amount Amount @description(\"In USD\")\n}\n\n// This should be allowed.\ntype MultipleAttrs = int @assert({{ this > 0 }}) @check(gt_ten, {{ this > 10 }})\n\nfunction MergeAliasAttributes(money: int) -> MergeAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer in the specified format:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction ReturnAliasWithMergedAttributes(money: Amount) -> Amount {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction AliasWithMultipleAttrs(money: MultipleAttrs) -> MultipleAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}", + "test-files/functions/output/type-aliases.baml" => "type Primitive = int | string | bool | float\n\ntype List = string[]\n\ntype Graph = map\n\ntype Combination = Primitive | List | Graph\n\nfunction PrimitiveAlias(p: Primitive) -> Primitive {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back: {{ p }}\n \"#\n}\n\nfunction MapAlias(m: Graph) -> Graph {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given Graph back:\n\n {{ m }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction NestedAlias(c: Combination) -> Combination {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back:\n\n {{ c }}\n\n {{ ctx.output_format }}\n \"#\n}\n\n// Test attribute merging.\ntype Currency = int @check(gt_ten, {{ this > 10 }})\ntype Amount = Currency @assert({{ this > 0 }})\n\nclass MergeAttrs {\n amount Amount @description(\"In USD\")\n}\n\n// This should be allowed.\ntype MultipleAttrs = int @assert({{ this > 0 }}) @check(gt_ten, {{ this > 10 }})\n\nfunction MergeAliasAttributes(money: int) -> MergeAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer in the specified format:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction ReturnAliasWithMergedAttributes(money: Amount) -> Amount {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction AliasWithMultipleAttrs(money: MultipleAttrs) -> MultipleAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecursiveMapAlias = map\n\nfunction SimpleRecursiveMapAlias(input: RecursiveMapAlias) -> RecursiveMapAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}", "test-files/functions/output/unions.baml" => "class UnionTest_ReturnType {\n prop1 string | bool\n prop2 (float | bool)[]\n prop3 (bool[] | int[])\n}\n\nfunction UnionTest_Function(input: string | bool) -> UnionTest_ReturnType {\n client GPT35\n prompt #\"\n Return a JSON blob with this schema: \n {{ctx.output_format}}\n\n JSON:\n \"#\n}\n\ntest UnionTest_Function {\n functions [UnionTest_Function]\n args {\n input \"example input\"\n }\n}\n", "test-files/functions/prompts/no-chat-messages.baml" => "\n\nfunction PromptTestClaude(input: string) -> string {\n client Sonnet\n prompt #\"\n Tell me a haiku about {{ input }}\n \"#\n}\n\n\nfunction PromptTestStreaming(input: string) -> string {\n client GPT35\n prompt #\"\n Tell me a short story about {{ input }}\n \"#\n}\n\ntest TestName {\n functions [PromptTestStreaming]\n args {\n input #\"\n hello world\n \"#\n }\n}\n", "test-files/functions/prompts/with-chat-messages.baml" => "\nfunction PromptTestOpenAIChat(input: string) -> string {\n client GPT35\n prompt #\"\n {{ _.role(\"system\") }}\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestOpenAIChatNoSystem(input: string) -> string {\n client GPT35\n prompt #\"\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestClaudeChat(input: string) -> string {\n client Claude\n prompt #\"\n {{ _.role(\"system\") }}\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestClaudeChatNoSystem(input: string) -> string {\n client Claude\n prompt #\"\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\ntest TestSystemAndNonSystemChat1 {\n functions [PromptTestClaude, PromptTestOpenAI, PromptTestOpenAIChat, PromptTestOpenAIChatNoSystem, PromptTestClaudeChat, PromptTestClaudeChatNoSystem]\n args {\n input \"cats\"\n }\n}\n\ntest TestSystemAndNonSystemChat2 {\n functions [PromptTestClaude, PromptTestOpenAI, PromptTestOpenAIChat, PromptTestOpenAIChatNoSystem, PromptTestClaudeChat, PromptTestClaudeChatNoSystem]\n args {\n input \"lion\"\n }\n}", diff --git a/integ-tests/typescript/baml_client/async_client.ts b/integ-tests/typescript/baml_client/async_client.ts index bd88336c3..792b5c3dd 100644 --- a/integ-tests/typescript/baml_client/async_client.ts +++ b/integ-tests/typescript/baml_client/async_client.ts @@ -2143,6 +2143,31 @@ export class BamlAsyncClient { } } + async SimpleRecursiveMapAlias( + input: RecursiveMapAlias, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): Promise { + try { + const raw = await this.runtime.callFunction( + "SimpleRecursiveMapAlias", + { + "input": input + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as RecursiveMapAlias + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + async StreamBigNumbers( digits: number, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } @@ -5996,6 +6021,39 @@ class BamlStreamClient { } } + SimpleRecursiveMapAlias( + input: RecursiveMapAlias, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): BamlStream, RecursiveMapAlias> { + try { + const raw = this.runtime.streamFunction( + "SimpleRecursiveMapAlias", + { + "input": input + }, + undefined, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return new BamlStream, RecursiveMapAlias>( + raw, + (a): a is RecursivePartialNull => a, + (a): a is RecursiveMapAlias => a, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + ) + } catch (error) { + if (error instanceof Error) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } + } + throw error; + } + } + StreamBigNumbers( digits: number, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } diff --git a/integ-tests/typescript/baml_client/inlinedbaml.ts b/integ-tests/typescript/baml_client/inlinedbaml.ts index f03520004..5c490a61b 100644 --- a/integ-tests/typescript/baml_client/inlinedbaml.ts +++ b/integ-tests/typescript/baml_client/inlinedbaml.ts @@ -84,7 +84,7 @@ const fileMap = { "test-files/functions/output/recursive-type-aliases.baml": "class LinkedListAliasNode {\n value int\n next LinkedListAliasNode?\n}\n\n// Simple alias that points to recursive type.\ntype LinkedListAlias = LinkedListAliasNode\n\nfunction AliasThatPointsToRecursiveType(list: LinkedListAlias) -> LinkedListAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given linked list back:\n \n {{ list }}\n \n {{ ctx.output_format }}\n \"#\n}\n\n// Class that points to an alias that points to a recursive type.\nclass ClassToRecAlias {\n list LinkedListAlias\n}\n\nfunction ClassThatPointsToRecursiveClassThroughAlias(cls: ClassToRecAlias) -> ClassToRecAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given object back:\n \n {{ cls }}\n \n {{ ctx.output_format }}\n \"#\n}\n\n// This is tricky cause this class should be hoisted, but classes and aliases\n// are two different types in the AST. This test will make sure they can interop.\nclass NodeWithAliasIndirection {\n value int\n next NodeIndirection?\n}\n\ntype NodeIndirection = NodeWithAliasIndirection\n\nfunction RecursiveClassWithAliasIndirection(cls: NodeWithAliasIndirection) -> NodeWithAliasIndirection {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given object back:\n \n {{ cls }}\n \n {{ ctx.output_format }}\n \"#\n}\n", "test-files/functions/output/serialization-error.baml": "class DummyOutput {\n nonce string\n nonce2 string\n @@dynamic\n}\n\nfunction DummyOutputFunction(input: string) -> DummyOutput {\n client GPT35\n prompt #\"\n Say \"hello there\".\n \"#\n}", "test-files/functions/output/string-list.baml": "function FnOutputStringList(input: string) -> string[] {\n client GPT35\n prompt #\"\n Return a list of strings in json format like [\"string1\", \"string2\", \"string3\"].\n\n JSON:\n \"#\n}\n\ntest FnOutputStringList {\n functions [FnOutputStringList]\n args {\n input \"example input\"\n }\n}\n", - "test-files/functions/output/type-aliases.baml": "type Primitive = int | string | bool | float\n\ntype List = string[]\n\ntype Graph = map\n\ntype Combination = Primitive | List | Graph\n\nfunction PrimitiveAlias(p: Primitive) -> Primitive {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back: {{ p }}\n \"#\n}\n\nfunction MapAlias(m: Graph) -> Graph {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given Graph back:\n\n {{ m }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction NestedAlias(c: Combination) -> Combination {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back:\n\n {{ c }}\n\n {{ ctx.output_format }}\n \"#\n}\n\n// Test attribute merging.\ntype Currency = int @check(gt_ten, {{ this > 10 }})\ntype Amount = Currency @assert({{ this > 0 }})\n\nclass MergeAttrs {\n amount Amount @description(\"In USD\")\n}\n\n// This should be allowed.\ntype MultipleAttrs = int @assert({{ this > 0 }}) @check(gt_ten, {{ this > 10 }})\n\nfunction MergeAliasAttributes(money: int) -> MergeAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer in the specified format:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction ReturnAliasWithMergedAttributes(money: Amount) -> Amount {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction AliasWithMultipleAttrs(money: MultipleAttrs) -> MultipleAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}", + "test-files/functions/output/type-aliases.baml": "type Primitive = int | string | bool | float\n\ntype List = string[]\n\ntype Graph = map\n\ntype Combination = Primitive | List | Graph\n\nfunction PrimitiveAlias(p: Primitive) -> Primitive {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back: {{ p }}\n \"#\n}\n\nfunction MapAlias(m: Graph) -> Graph {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given Graph back:\n\n {{ m }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction NestedAlias(c: Combination) -> Combination {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back:\n\n {{ c }}\n\n {{ ctx.output_format }}\n \"#\n}\n\n// Test attribute merging.\ntype Currency = int @check(gt_ten, {{ this > 10 }})\ntype Amount = Currency @assert({{ this > 0 }})\n\nclass MergeAttrs {\n amount Amount @description(\"In USD\")\n}\n\n// This should be allowed.\ntype MultipleAttrs = int @assert({{ this > 0 }}) @check(gt_ten, {{ this > 10 }})\n\nfunction MergeAliasAttributes(money: int) -> MergeAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer in the specified format:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction ReturnAliasWithMergedAttributes(money: Amount) -> Amount {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction AliasWithMultipleAttrs(money: MultipleAttrs) -> MultipleAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecursiveMapAlias = map\n\nfunction SimpleRecursiveMapAlias(input: RecursiveMapAlias) -> RecursiveMapAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}", "test-files/functions/output/unions.baml": "class UnionTest_ReturnType {\n prop1 string | bool\n prop2 (float | bool)[]\n prop3 (bool[] | int[])\n}\n\nfunction UnionTest_Function(input: string | bool) -> UnionTest_ReturnType {\n client GPT35\n prompt #\"\n Return a JSON blob with this schema: \n {{ctx.output_format}}\n\n JSON:\n \"#\n}\n\ntest UnionTest_Function {\n functions [UnionTest_Function]\n args {\n input \"example input\"\n }\n}\n", "test-files/functions/prompts/no-chat-messages.baml": "\n\nfunction PromptTestClaude(input: string) -> string {\n client Sonnet\n prompt #\"\n Tell me a haiku about {{ input }}\n \"#\n}\n\n\nfunction PromptTestStreaming(input: string) -> string {\n client GPT35\n prompt #\"\n Tell me a short story about {{ input }}\n \"#\n}\n\ntest TestName {\n functions [PromptTestStreaming]\n args {\n input #\"\n hello world\n \"#\n }\n}\n", "test-files/functions/prompts/with-chat-messages.baml": "\nfunction PromptTestOpenAIChat(input: string) -> string {\n client GPT35\n prompt #\"\n {{ _.role(\"system\") }}\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestOpenAIChatNoSystem(input: string) -> string {\n client GPT35\n prompt #\"\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestClaudeChat(input: string) -> string {\n client Claude\n prompt #\"\n {{ _.role(\"system\") }}\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestClaudeChatNoSystem(input: string) -> string {\n client Claude\n prompt #\"\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\ntest TestSystemAndNonSystemChat1 {\n functions [PromptTestClaude, PromptTestOpenAI, PromptTestOpenAIChat, PromptTestOpenAIChatNoSystem, PromptTestClaudeChat, PromptTestClaudeChatNoSystem]\n args {\n input \"cats\"\n }\n}\n\ntest TestSystemAndNonSystemChat2 {\n functions [PromptTestClaude, PromptTestOpenAI, PromptTestOpenAIChat, PromptTestOpenAIChatNoSystem, PromptTestClaudeChat, PromptTestClaudeChatNoSystem]\n args {\n input \"lion\"\n }\n}", diff --git a/integ-tests/typescript/baml_client/sync_client.ts b/integ-tests/typescript/baml_client/sync_client.ts index 35306dac1..f9af16e41 100644 --- a/integ-tests/typescript/baml_client/sync_client.ts +++ b/integ-tests/typescript/baml_client/sync_client.ts @@ -2143,6 +2143,31 @@ export class BamlSyncClient { } } + SimpleRecursiveMapAlias( + input: RecursiveMapAlias, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): RecursiveMapAlias { + try { + const raw = this.runtime.callFunctionSync( + "SimpleRecursiveMapAlias", + { + "input": input + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as RecursiveMapAlias + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + StreamBigNumbers( digits: number, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } From 140b3dd3426ca102fd21102f18225c330bc74833 Mon Sep 17 00:00:00 2001 From: Antonio Sarosi Date: Tue, 17 Dec 2024 01:12:00 +0100 Subject: [PATCH 05/19] Integ test works! Yeah --- .../jinja-runtime/src/output_format/types.rs | 18 ++++++++++++------ .../src/deserializer/coercer/field_type.rs | 15 ++++++++++++--- .../src/deserializer/coercer/ir_ref/mod.rs | 5 +++++ .../prompt_renderer/render_output_format.rs | 16 +++++++++++----- 4 files changed, 40 insertions(+), 14 deletions(-) diff --git a/engine/baml-lib/jinja-runtime/src/output_format/types.rs b/engine/baml-lib/jinja-runtime/src/output_format/types.rs index ade317e53..bf5bb6045 100644 --- a/engine/baml-lib/jinja-runtime/src/output_format/types.rs +++ b/engine/baml-lib/jinja-runtime/src/output_format/types.rs @@ -58,7 +58,7 @@ pub struct OutputFormatContent { pub enums: Arc>, pub classes: Arc>, recursive_classes: Arc>, - structural_recursive_aliases: Arc>, + structural_recursive_aliases: Arc>, pub target: FieldType, } @@ -69,7 +69,7 @@ pub struct Builder { /// Order matters for this one. recursive_classes: IndexSet, /// Recursive aliases introduced maps and lists. - structural_recursive_aliases: IndexSet, + structural_recursive_aliases: IndexMap, target: FieldType, } @@ -79,7 +79,7 @@ impl Builder { enums: vec![], classes: vec![], recursive_classes: IndexSet::new(), - structural_recursive_aliases: IndexSet::new(), + structural_recursive_aliases: IndexMap::new(), target, } } @@ -101,7 +101,7 @@ impl Builder { pub fn structural_recursive_aliases( mut self, - structural_recursive_aliases: IndexSet, + structural_recursive_aliases: IndexMap, ) -> Self { self.structural_recursive_aliases = structural_recursive_aliases; self @@ -710,13 +710,19 @@ impl OutputFormatContent { pub fn find_enum(&self, name: &str) -> Result<&Enum> { self.enums .get(name) - .ok_or_else(|| anyhow::anyhow!("Enum {} not found", name)) + .ok_or_else(|| anyhow::anyhow!("Enum {name} not found")) } pub fn find_class(&self, name: &str) -> Result<&Class> { self.classes .get(name) - .ok_or_else(|| anyhow::anyhow!("Class {} not found", name)) + .ok_or_else(|| anyhow::anyhow!("Class {name} not found")) + } + + pub fn find_recursive_alias_target(&self, name: &str) -> Result<&FieldType> { + self.structural_recursive_aliases + .get(name) + .ok_or_else(|| anyhow::anyhow!("Recursive alias {name} not found")) } } diff --git a/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs b/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs index d825dfd28..13ba6e9b2 100644 --- a/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs +++ b/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs @@ -79,8 +79,17 @@ impl TypeCoercer for FieldType { FieldType::Enum(e) => IrRef::Enum(e).coerce(ctx, target, value), FieldType::Literal(l) => l.coerce(ctx, target, value), FieldType::Class(c) => IrRef::Class(c).coerce(ctx, target, value), - // TODO: How to handle this? - FieldType::RecursiveTypeAlias(_) => todo!("Alias with no resolution"), + // TODO: This doesn't look too good compared to the rest of + // match arms. Should we make use of the context like this here? + FieldType::RecursiveTypeAlias(name) => ctx + .of + .find_recursive_alias_target(name) + .map_err(|e| ParsingError { + reason: format!("Failed to find recursive alias target: {e}"), + scope: ctx.scope.clone(), + causes: Vec::new(), + })? + .coerce(ctx, target, value), FieldType::List(_) => coerce_array(ctx, self, value), FieldType::Union(_) => coerce_union(ctx, self, value), FieldType::Optional(_) => coerce_optional(ctx, self, value), @@ -90,7 +99,7 @@ impl TypeCoercer for FieldType { let mut coerced_value = base.coerce(ctx, base, value)?; let constraint_results = run_user_checks(&coerced_value.clone().into(), self) .map_err(|e| ParsingError { - reason: format!("Failed to evaluate constraints: {:?}", e), + reason: format!("Failed to evaluate constraints: {e:?}"), scope: ctx.scope.clone(), causes: Vec::new(), })?; diff --git a/engine/baml-lib/jsonish/src/deserializer/coercer/ir_ref/mod.rs b/engine/baml-lib/jsonish/src/deserializer/coercer/ir_ref/mod.rs index a16b9ed1a..0f13e4f36 100644 --- a/engine/baml-lib/jsonish/src/deserializer/coercer/ir_ref/mod.rs +++ b/engine/baml-lib/jsonish/src/deserializer/coercer/ir_ref/mod.rs @@ -11,6 +11,7 @@ use super::{ParsingContext, ParsingError}; pub(super) enum IrRef<'a> { Enum(&'a String), Class(&'a String), + RecursiveAlias(&'a String), } impl TypeCoercer for IrRef<'_> { @@ -29,6 +30,10 @@ impl TypeCoercer for IrRef<'_> { Ok(c) => c.coerce(ctx, target, value), Err(e) => Err(ctx.error_internal(e.to_string())), }, + IrRef::RecursiveAlias(a) => match ctx.of.find_recursive_alias_target(a.as_str()) { + Ok(a) => a.coerce(ctx, target, value), + Err(e) => Err(ctx.error_internal(e.to_string())), + }, } } } diff --git a/engine/baml-runtime/src/internal/prompt_renderer/render_output_format.rs b/engine/baml-runtime/src/internal/prompt_renderer/render_output_format.rs index 05b7fbd17..069ced74e 100644 --- a/engine/baml-runtime/src/internal/prompt_renderer/render_output_format.rs +++ b/engine/baml-runtime/src/internal/prompt_renderer/render_output_format.rs @@ -2,7 +2,7 @@ use std::collections::HashSet; use anyhow::Result; use baml_types::BamlValue; -use indexmap::IndexSet; +use indexmap::{IndexMap, IndexSet}; use internal_baml_core::ir::{ repr::IntermediateRepr, ClassWalker, EnumWalker, FieldType, IRHelper, }; @@ -212,12 +212,17 @@ fn relevant_data_models<'a>( ir: &'a IntermediateRepr, output: &'a FieldType, ctx: &RuntimeContext, -) -> Result<(Vec, Vec, IndexSet, IndexSet)> { +) -> Result<( + Vec, + Vec, + IndexSet, + IndexMap, +)> { let mut checked_types = HashSet::new(); let mut enums = Vec::new(); let mut classes = Vec::new(); let mut recursive_classes = IndexSet::new(); - let mut structural_recursive_aliases = IndexSet::new(); + let mut structural_recursive_aliases = IndexMap::new(); let mut start: Vec = vec![output.clone()]; let eval_ctx = ctx.eval_ctx(false); @@ -381,10 +386,11 @@ fn relevant_data_models<'a>( } (FieldType::RecursiveTypeAlias(name), _) => { // TODO: Same O(n) problem as above. - // TODO: Do we need the type information or just the name? for cycle in ir.structural_recursive_alias_cycles() { if cycle.contains_key(name) { - structural_recursive_aliases.extend(cycle.keys().map(ToOwned::to_owned)); + for (alias, target) in cycle.iter() { + structural_recursive_aliases.insert(alias.to_owned(), target.clone()); + } } } } From 68b98b704cae52861acc743d7913621152bbf1d7 Mon Sep 17 00:00:00 2001 From: Antonio Sarosi Date: Tue, 17 Dec 2024 02:16:46 +0100 Subject: [PATCH 06/19] Fix structural cycles rendering --- .../jinja-runtime/src/output_format/types.rs | 113 +++++++++++++++--- .../baml-schema-wasm/src/runtime_wasm/mod.rs | 4 +- 2 files changed, 99 insertions(+), 18 deletions(-) diff --git a/engine/baml-lib/jinja-runtime/src/output_format/types.rs b/engine/baml-lib/jinja-runtime/src/output_format/types.rs index bf5bb6045..188f37035 100644 --- a/engine/baml-lib/jinja-runtime/src/output_format/types.rs +++ b/engine/baml-lib/jinja-runtime/src/output_format/types.rs @@ -312,7 +312,6 @@ fn indefinite_article_a_or_an(word: &str) -> &str { struct RenderState { hoisted_enums: IndexSet, - hoisted_aliases: IndexMap, } impl OutputFormatContent { @@ -512,6 +511,9 @@ impl OutputFormatContent { FieldType::List(inner) => { let is_recursive = match inner.as_ref() { FieldType::Class(nested_class) => self.recursive_classes.contains(nested_class), + FieldType::RecursiveTypeAlias(name) => { + self.structural_recursive_aliases.contains_key(name) + } _ => false, }; @@ -582,7 +584,6 @@ impl OutputFormatContent { let mut render_state = RenderState { hoisted_enums: IndexSet::new(), - hoisted_aliases: IndexMap::new(), }; let mut message = match &self.target { @@ -636,18 +637,9 @@ impl OutputFormatContent { }); } - // Yeah gotta love the borrow checker. // TODO: @antonio - let hoisted_aliases = - std::mem::replace(&mut render_state.hoisted_aliases, Default::default()); - - for (alias, target) in hoisted_aliases.iter() { - let recursive_pointer = self.inner_type_render( - &options, - // TODO: @antonio This code sucks beyond measure. Fix this. - &target, - &mut render_state, - false, - )?; + for (alias, target) in self.structural_recursive_aliases.iter() { + let recursive_pointer = + self.inner_type_render(&options, target, &mut render_state, false)?; type_alias_definitions.push(match &options.hoisted_class_prefix { RenderSetting::Always(prefix) if !prefix.is_empty() => { @@ -670,7 +662,7 @@ impl OutputFormatContent { } if !type_alias_definitions.is_empty() { - output.push_str(&type_alias_definitions.join("\n\n")); + output.push_str(&type_alias_definitions.join("\n")); output.push_str("\n\n"); } @@ -2281,4 +2273,95 @@ Answer in JSON using this schema: )) ); } + + #[test] + fn render_simple_recursive_aliases() { + let content = OutputFormatContent::target(FieldType::RecursiveTypeAlias( + "RecursiveMapAlias".to_string(), + )) + .structural_recursive_aliases(IndexMap::from([( + "RecursiveMapAlias".to_string(), + FieldType::map( + FieldType::string(), + FieldType::RecursiveTypeAlias("RecursiveMapAlias".to_string()), + ), + )])) + .build(); + let rendered = content.render(RenderOptions::default()).unwrap(); + #[rustfmt::skip] + assert_eq!( + rendered, + Some(String::from( +r#"RecursiveMapAlias = map + +Answer in JSON using this schema: RecursiveMapAlias"# + )) + ); + } + + #[test] + fn render_recursive_alias_cycle() { + let content = OutputFormatContent::target(FieldType::RecursiveTypeAlias("A".to_string())) + .structural_recursive_aliases(IndexMap::from([ + ( + "A".to_string(), + FieldType::RecursiveTypeAlias("B".to_string()), + ), + ( + "B".to_string(), + FieldType::RecursiveTypeAlias("C".to_string()), + ), + ( + "C".to_string(), + FieldType::list(FieldType::RecursiveTypeAlias("A".to_string())), + ), + ])) + .build(); + let rendered = content.render(RenderOptions::default()).unwrap(); + #[rustfmt::skip] + assert_eq!( + rendered, + Some(String::from( +r#"A = B +B = C +C = A[] + +Answer in JSON using this schema: A"# + )) + ); + } + + #[test] + fn render_recursive_alias_cycle_with_hoist_prefix() { + let content = OutputFormatContent::target(FieldType::RecursiveTypeAlias("A".to_string())) + .structural_recursive_aliases(IndexMap::from([ + ( + "A".to_string(), + FieldType::RecursiveTypeAlias("B".to_string()), + ), + ( + "B".to_string(), + FieldType::RecursiveTypeAlias("C".to_string()), + ), + ( + "C".to_string(), + FieldType::list(FieldType::RecursiveTypeAlias("A".to_string())), + ), + ])) + .build(); + let rendered = content + .render(RenderOptions::with_hoisted_class_prefix("type")) + .unwrap(); + #[rustfmt::skip] + assert_eq!( + rendered, + Some(String::from( +r#"type A = B +type B = C +type C = A[] + +Answer in JSON using this type: A"# + )) + ); + } } diff --git a/engine/baml-schema-wasm/src/runtime_wasm/mod.rs b/engine/baml-schema-wasm/src/runtime_wasm/mod.rs index f21641107..614b19fbc 100644 --- a/engine/baml-schema-wasm/src/runtime_wasm/mod.rs +++ b/engine/baml-schema-wasm/src/runtime_wasm/mod.rs @@ -876,9 +876,7 @@ fn get_dummy_value( baml_runtime::FieldType::Literal(_) => None, baml_runtime::FieldType::Enum(_) => None, baml_runtime::FieldType::Class(_) => None, - baml_runtime::FieldType::Alias { resolution, .. } => { - get_dummy_value(indent, allow_multiline, &resolution) - } + baml_runtime::FieldType::RecursiveTypeAlias(_) => None, baml_runtime::FieldType::List(item) => { let dummy = get_dummy_value(indent + 1, allow_multiline, item); // Repeat it 2 times From d462e5cb0bf43d3801d3dd7779b3b5f0257e5b2f Mon Sep 17 00:00:00 2001 From: Antonio Sarosi Date: Tue, 17 Dec 2024 03:31:29 +0100 Subject: [PATCH 07/19] Coerce is wonky --- .../src/deserializer/coercer/field_type.rs | 10 ++ .../src/deserializer/coercer/ir_ref/mod.rs | 2 + engine/baml-runtime/tests/test_runtime.rs | 48 +++++++ .../functions/output/type-aliases.baml | 28 ++++ .../python/baml_client/async_client.py | 106 ++++++++++++++ integ-tests/python/baml_client/inlinedbaml.py | 2 +- integ-tests/python/baml_client/sync_client.py | 106 ++++++++++++++ integ-tests/python/baml_client/types.py | 8 ++ integ-tests/python/tests/test_functions.py | 10 ++ integ-tests/ruby/baml_client/client.rb | 134 ++++++++++++++++++ integ-tests/ruby/baml_client/inlined.rb | 2 +- .../typescript/baml_client/async_client.ts | 116 +++++++++++++++ .../typescript/baml_client/inlinedbaml.ts | 2 +- .../typescript/baml_client/sync_client.ts | 50 +++++++ 14 files changed, 621 insertions(+), 3 deletions(-) diff --git a/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs b/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs index 13ba6e9b2..eed33ba0c 100644 --- a/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs +++ b/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs @@ -14,6 +14,8 @@ use super::{ ParsingError, }; +static mut count: u32 = 0; + impl TypeCoercer for FieldType { fn coerce( &self, @@ -21,6 +23,14 @@ impl TypeCoercer for FieldType { target: &FieldType, value: Option<&crate::jsonish::Value>, ) -> Result { + unsafe { + eprintln!("{self:?} -> {target:?} -> {value:?}"); + count += 1; + if count == 20 { + panic!("FUCK"); + } + } + match value { Some(crate::jsonish::Value::AnyOf(candidates, primitive)) => { log::debug!( diff --git a/engine/baml-lib/jsonish/src/deserializer/coercer/ir_ref/mod.rs b/engine/baml-lib/jsonish/src/deserializer/coercer/ir_ref/mod.rs index 0f13e4f36..b5f0df842 100644 --- a/engine/baml-lib/jsonish/src/deserializer/coercer/ir_ref/mod.rs +++ b/engine/baml-lib/jsonish/src/deserializer/coercer/ir_ref/mod.rs @@ -1,6 +1,8 @@ mod coerce_class; pub mod coerce_enum; +use core::panic; + use anyhow::Result; use internal_baml_core::ir::FieldType; diff --git a/engine/baml-runtime/tests/test_runtime.rs b/engine/baml-runtime/tests/test_runtime.rs index 79efe7151..97e9d07ac 100644 --- a/engine/baml-runtime/tests/test_runtime.rs +++ b/engine/baml-runtime/tests/test_runtime.rs @@ -552,4 +552,52 @@ test RunFoo2Test { Ok(()) } + + #[test] + fn test_recursive_alias_cycle() -> anyhow::Result<()> { + let runtime = make_test_runtime( + r##" +type RecAliasOne = RecAliasTwo +type RecAliasTwo = RecAliasThree +type RecAliasThree = RecAliasOne[] + +function RecursiveAliasCycle(input: RecAliasOne) -> RecAliasOne { + client "openai/gpt-4o" + prompt r#" + Return the given value: + + {{ input }} + + {{ ctx.output_format }} + "# +} + +test RecursiveAliasCycle { + functions [RecursiveAliasCycle] + args { + input [ + [] + [] + [[], []] + ] + } +} + "##, + )?; + + let ctx = runtime + .create_ctx_manager(BamlValue::String("test".to_string()), None) + .create_ctx_with_default(); + + let function_name = "RecursiveAliasCycle"; + let test_name = "RecursiveAliasCycle"; + let params = runtime.get_test_params(function_name, test_name, &ctx, true)?; + let render_prompt_future = + runtime + .internal() + .render_prompt(function_name, &ctx, ¶ms, None); + let (prompt, scope, _) = runtime.async_runtime.block_on(render_prompt_future)?; + + Ok(()) + } } diff --git a/integ-tests/baml_src/test-files/functions/output/type-aliases.baml b/integ-tests/baml_src/test-files/functions/output/type-aliases.baml index 6d93fec8e..dd1ee3a75 100644 --- a/integ-tests/baml_src/test-files/functions/output/type-aliases.baml +++ b/integ-tests/baml_src/test-files/functions/output/type-aliases.baml @@ -88,6 +88,34 @@ function SimpleRecursiveMapAlias(input: RecursiveMapAlias) -> RecursiveMapAlias {{ input }} + {{ ctx.output_format }} + "# +} + +type RecursiveListAlias = RecursiveListAlias[] + +function SimpleRecursiveListAlias(input: RecursiveListAlias) -> RecursiveListAlias { + client "openai/gpt-4o" + prompt r#" + Return the given JSON array: + + {{ input }} + + {{ ctx.output_format }} + "# +} + +type RecAliasOne = RecAliasTwo +type RecAliasTwo = RecAliasThree +type RecAliasThree = RecAliasOne[] + +function RecursiveAliasCycle(input: RecAliasOne) -> RecAliasOne { + client "openai/gpt-4o" + prompt r#" + Return the given JSON array: + + {{ input }} + {{ ctx.output_format }} "# } \ No newline at end of file diff --git a/integ-tests/python/baml_client/async_client.py b/integ-tests/python/baml_client/async_client.py index 52da2b5de..1b46052cd 100644 --- a/integ-tests/python/baml_client/async_client.py +++ b/integ-tests/python/baml_client/async_client.py @@ -1867,6 +1867,29 @@ async def PromptTestStreaming( ) return cast(str, raw.cast_to(types, types)) + async def RecursiveAliasCycle( + self, + input: types.RecAliasOne, + baml_options: BamlCallOptions = {}, + ) -> types.RecAliasOne: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = await self.__runtime.call_function( + "RecursiveAliasCycle", + { + "input": input, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + return cast(types.RecAliasOne, raw.cast_to(types, types)) + async def RecursiveClassWithAliasIndirection( self, cls: types.NodeWithAliasIndirection, @@ -1982,6 +2005,29 @@ async def SchemaDescriptions( ) return cast(types.Schema, raw.cast_to(types, types)) + async def SimpleRecursiveListAlias( + self, + input: types.RecursiveListAlias, + baml_options: BamlCallOptions = {}, + ) -> types.RecursiveListAlias: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = await self.__runtime.call_function( + "SimpleRecursiveListAlias", + { + "input": input, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + return cast(types.RecursiveListAlias, raw.cast_to(types, types)) + async def SimpleRecursiveMapAlias( self, input: types.RecursiveMapAlias, @@ -5380,6 +5426,36 @@ def PromptTestStreaming( self.__ctx_manager.get(), ) + def RecursiveAliasCycle( + self, + input: types.RecAliasOne, + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlStream[types.RecAliasOne, types.RecAliasOne]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function( + "RecursiveAliasCycle", + { + "input": input, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + return baml_py.BamlStream[types.RecAliasOne, types.RecAliasOne]( + raw, + lambda x: cast(types.RecAliasOne, x.cast_to(types, partial_types)), + lambda x: cast(types.RecAliasOne, x.cast_to(types, types)), + self.__ctx_manager.get(), + ) + def RecursiveClassWithAliasIndirection( self, cls: types.NodeWithAliasIndirection, @@ -5530,6 +5606,36 @@ def SchemaDescriptions( self.__ctx_manager.get(), ) + def SimpleRecursiveListAlias( + self, + input: types.RecursiveListAlias, + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlStream[types.RecursiveListAlias, types.RecursiveListAlias]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function( + "SimpleRecursiveListAlias", + { + "input": input, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + return baml_py.BamlStream[types.RecursiveListAlias, types.RecursiveListAlias]( + raw, + lambda x: cast(types.RecursiveListAlias, x.cast_to(types, partial_types)), + lambda x: cast(types.RecursiveListAlias, x.cast_to(types, types)), + self.__ctx_manager.get(), + ) + def SimpleRecursiveMapAlias( self, input: types.RecursiveMapAlias, diff --git a/integ-tests/python/baml_client/inlinedbaml.py b/integ-tests/python/baml_client/inlinedbaml.py index dd4487bf2..e4c73068f 100644 --- a/integ-tests/python/baml_client/inlinedbaml.py +++ b/integ-tests/python/baml_client/inlinedbaml.py @@ -83,7 +83,7 @@ "test-files/functions/output/recursive-type-aliases.baml": "class LinkedListAliasNode {\n value int\n next LinkedListAliasNode?\n}\n\n// Simple alias that points to recursive type.\ntype LinkedListAlias = LinkedListAliasNode\n\nfunction AliasThatPointsToRecursiveType(list: LinkedListAlias) -> LinkedListAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given linked list back:\n \n {{ list }}\n \n {{ ctx.output_format }}\n \"#\n}\n\n// Class that points to an alias that points to a recursive type.\nclass ClassToRecAlias {\n list LinkedListAlias\n}\n\nfunction ClassThatPointsToRecursiveClassThroughAlias(cls: ClassToRecAlias) -> ClassToRecAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given object back:\n \n {{ cls }}\n \n {{ ctx.output_format }}\n \"#\n}\n\n// This is tricky cause this class should be hoisted, but classes and aliases\n// are two different types in the AST. This test will make sure they can interop.\nclass NodeWithAliasIndirection {\n value int\n next NodeIndirection?\n}\n\ntype NodeIndirection = NodeWithAliasIndirection\n\nfunction RecursiveClassWithAliasIndirection(cls: NodeWithAliasIndirection) -> NodeWithAliasIndirection {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given object back:\n \n {{ cls }}\n \n {{ ctx.output_format }}\n \"#\n}\n", "test-files/functions/output/serialization-error.baml": "class DummyOutput {\n nonce string\n nonce2 string\n @@dynamic\n}\n\nfunction DummyOutputFunction(input: string) -> DummyOutput {\n client GPT35\n prompt #\"\n Say \"hello there\".\n \"#\n}", "test-files/functions/output/string-list.baml": "function FnOutputStringList(input: string) -> string[] {\n client GPT35\n prompt #\"\n Return a list of strings in json format like [\"string1\", \"string2\", \"string3\"].\n\n JSON:\n \"#\n}\n\ntest FnOutputStringList {\n functions [FnOutputStringList]\n args {\n input \"example input\"\n }\n}\n", - "test-files/functions/output/type-aliases.baml": "type Primitive = int | string | bool | float\n\ntype List = string[]\n\ntype Graph = map\n\ntype Combination = Primitive | List | Graph\n\nfunction PrimitiveAlias(p: Primitive) -> Primitive {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back: {{ p }}\n \"#\n}\n\nfunction MapAlias(m: Graph) -> Graph {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given Graph back:\n\n {{ m }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction NestedAlias(c: Combination) -> Combination {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back:\n\n {{ c }}\n\n {{ ctx.output_format }}\n \"#\n}\n\n// Test attribute merging.\ntype Currency = int @check(gt_ten, {{ this > 10 }})\ntype Amount = Currency @assert({{ this > 0 }})\n\nclass MergeAttrs {\n amount Amount @description(\"In USD\")\n}\n\n// This should be allowed.\ntype MultipleAttrs = int @assert({{ this > 0 }}) @check(gt_ten, {{ this > 10 }})\n\nfunction MergeAliasAttributes(money: int) -> MergeAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer in the specified format:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction ReturnAliasWithMergedAttributes(money: Amount) -> Amount {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction AliasWithMultipleAttrs(money: MultipleAttrs) -> MultipleAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecursiveMapAlias = map\n\nfunction SimpleRecursiveMapAlias(input: RecursiveMapAlias) -> RecursiveMapAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}", + "test-files/functions/output/type-aliases.baml": "type Primitive = int | string | bool | float\n\ntype List = string[]\n\ntype Graph = map\n\ntype Combination = Primitive | List | Graph\n\nfunction PrimitiveAlias(p: Primitive) -> Primitive {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back: {{ p }}\n \"#\n}\n\nfunction MapAlias(m: Graph) -> Graph {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given Graph back:\n\n {{ m }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction NestedAlias(c: Combination) -> Combination {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back:\n\n {{ c }}\n\n {{ ctx.output_format }}\n \"#\n}\n\n// Test attribute merging.\ntype Currency = int @check(gt_ten, {{ this > 10 }})\ntype Amount = Currency @assert({{ this > 0 }})\n\nclass MergeAttrs {\n amount Amount @description(\"In USD\")\n}\n\n// This should be allowed.\ntype MultipleAttrs = int @assert({{ this > 0 }}) @check(gt_ten, {{ this > 10 }})\n\nfunction MergeAliasAttributes(money: int) -> MergeAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer in the specified format:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction ReturnAliasWithMergedAttributes(money: Amount) -> Amount {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction AliasWithMultipleAttrs(money: MultipleAttrs) -> MultipleAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecursiveMapAlias = map\n\nfunction SimpleRecursiveMapAlias(input: RecursiveMapAlias) -> RecursiveMapAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecursiveListAlias = RecursiveListAlias[]\n\nfunction SimpleRecursiveListAlias(input: RecursiveListAlias) -> RecursiveListAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given JSON array:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecAliasOne = RecAliasTwo\ntype RecAliasTwo = RecAliasThree\ntype RecAliasThree = RecAliasOne[]\n\nfunction RecursiveAliasCycle(input: RecAliasOne) -> RecAliasOne {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given JSON array:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}", "test-files/functions/output/unions.baml": "class UnionTest_ReturnType {\n prop1 string | bool\n prop2 (float | bool)[]\n prop3 (bool[] | int[])\n}\n\nfunction UnionTest_Function(input: string | bool) -> UnionTest_ReturnType {\n client GPT35\n prompt #\"\n Return a JSON blob with this schema: \n {{ctx.output_format}}\n\n JSON:\n \"#\n}\n\ntest UnionTest_Function {\n functions [UnionTest_Function]\n args {\n input \"example input\"\n }\n}\n", "test-files/functions/prompts/no-chat-messages.baml": "\n\nfunction PromptTestClaude(input: string) -> string {\n client Sonnet\n prompt #\"\n Tell me a haiku about {{ input }}\n \"#\n}\n\n\nfunction PromptTestStreaming(input: string) -> string {\n client GPT35\n prompt #\"\n Tell me a short story about {{ input }}\n \"#\n}\n\ntest TestName {\n functions [PromptTestStreaming]\n args {\n input #\"\n hello world\n \"#\n }\n}\n", "test-files/functions/prompts/with-chat-messages.baml": "\nfunction PromptTestOpenAIChat(input: string) -> string {\n client GPT35\n prompt #\"\n {{ _.role(\"system\") }}\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestOpenAIChatNoSystem(input: string) -> string {\n client GPT35\n prompt #\"\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestClaudeChat(input: string) -> string {\n client Claude\n prompt #\"\n {{ _.role(\"system\") }}\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestClaudeChatNoSystem(input: string) -> string {\n client Claude\n prompt #\"\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\ntest TestSystemAndNonSystemChat1 {\n functions [PromptTestClaude, PromptTestOpenAI, PromptTestOpenAIChat, PromptTestOpenAIChatNoSystem, PromptTestClaudeChat, PromptTestClaudeChatNoSystem]\n args {\n input \"cats\"\n }\n}\n\ntest TestSystemAndNonSystemChat2 {\n functions [PromptTestClaude, PromptTestOpenAI, PromptTestOpenAIChat, PromptTestOpenAIChatNoSystem, PromptTestClaudeChat, PromptTestClaudeChatNoSystem]\n args {\n input \"lion\"\n }\n}", diff --git a/integ-tests/python/baml_client/sync_client.py b/integ-tests/python/baml_client/sync_client.py index 560bb8279..50bde725a 100644 --- a/integ-tests/python/baml_client/sync_client.py +++ b/integ-tests/python/baml_client/sync_client.py @@ -1864,6 +1864,29 @@ def PromptTestStreaming( ) return cast(str, raw.cast_to(types, types)) + def RecursiveAliasCycle( + self, + input: types.RecAliasOne, + baml_options: BamlCallOptions = {}, + ) -> types.RecAliasOne: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.call_function_sync( + "RecursiveAliasCycle", + { + "input": input, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + return cast(types.RecAliasOne, raw.cast_to(types, types)) + def RecursiveClassWithAliasIndirection( self, cls: types.NodeWithAliasIndirection, @@ -1979,6 +2002,29 @@ def SchemaDescriptions( ) return cast(types.Schema, raw.cast_to(types, types)) + def SimpleRecursiveListAlias( + self, + input: types.RecursiveListAlias, + baml_options: BamlCallOptions = {}, + ) -> types.RecursiveListAlias: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.call_function_sync( + "SimpleRecursiveListAlias", + { + "input": input, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + return cast(types.RecursiveListAlias, raw.cast_to(types, types)) + def SimpleRecursiveMapAlias( self, input: types.RecursiveMapAlias, @@ -5378,6 +5424,36 @@ def PromptTestStreaming( self.__ctx_manager.get(), ) + def RecursiveAliasCycle( + self, + input: types.RecAliasOne, + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlSyncStream[types.RecAliasOne, types.RecAliasOne]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function_sync( + "RecursiveAliasCycle", + { + "input": input, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + return baml_py.BamlSyncStream[types.RecAliasOne, types.RecAliasOne]( + raw, + lambda x: cast(types.RecAliasOne, x.cast_to(types, partial_types)), + lambda x: cast(types.RecAliasOne, x.cast_to(types, types)), + self.__ctx_manager.get(), + ) + def RecursiveClassWithAliasIndirection( self, cls: types.NodeWithAliasIndirection, @@ -5528,6 +5604,36 @@ def SchemaDescriptions( self.__ctx_manager.get(), ) + def SimpleRecursiveListAlias( + self, + input: types.RecursiveListAlias, + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlSyncStream[types.RecursiveListAlias, types.RecursiveListAlias]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function_sync( + "SimpleRecursiveListAlias", + { + "input": input, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + return baml_py.BamlSyncStream[types.RecursiveListAlias, types.RecursiveListAlias]( + raw, + lambda x: cast(types.RecursiveListAlias, x.cast_to(types, partial_types)), + lambda x: cast(types.RecursiveListAlias, x.cast_to(types, types)), + self.__ctx_manager.get(), + ) + def SimpleRecursiveMapAlias( self, input: types.RecursiveMapAlias, diff --git a/integ-tests/python/baml_client/types.py b/integ-tests/python/baml_client/types.py index 4bdc54668..743e918b3 100644 --- a/integ-tests/python/baml_client/types.py +++ b/integ-tests/python/baml_client/types.py @@ -480,3 +480,11 @@ class WithReasoning(BaseModel): reasoning: str RecursiveMapAlias: TypeAlias = Dict[str, "RecursiveMapAlias"] + +RecursiveListAlias: TypeAlias = List["RecursiveListAlias"] + +RecAliasOne: TypeAlias = "RecAliasTwo" + +RecAliasTwo: TypeAlias = "RecAliasThree" + +RecAliasThree: TypeAlias = List["RecAliasOne"] diff --git a/integ-tests/python/tests/test_functions.py b/integ-tests/python/tests/test_functions.py index 416647db2..a7813f046 100644 --- a/integ-tests/python/tests/test_functions.py +++ b/integ-tests/python/tests/test_functions.py @@ -319,6 +319,16 @@ async def test_simple_recursive_map_alias(self): res = await b.SimpleRecursiveMapAlias({"one": {"two": {"three": {}}}}) assert res == {"one": {"two": {"three": {}}}} + @pytest.mark.asyncio + async def test_simple_recursive_list_alias(self): + res = await b.SimpleRecursiveListAlias([[], [], [[]]]) + assert res == [[], [], [[]]] + + @pytest.mark.asyncio + async def test_recursive_alias_cycles(self): + res = await b.RecursiveAliasCycle([[], [], [[]]]) + assert res == [[], [], [[]]] + class MyCustomClass(NamedArgsSingleClass): date: datetime.datetime diff --git a/integ-tests/ruby/baml_client/client.rb b/integ-tests/ruby/baml_client/client.rb index 51f5d200d..f122341d7 100644 --- a/integ-tests/ruby/baml_client/client.rb +++ b/integ-tests/ruby/baml_client/client.rb @@ -2578,6 +2578,38 @@ def PromptTestStreaming( (raw.parsed_using_types(Baml::Types)) end + sig { + params( + varargs: T.untyped, + input: T.anything, + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(T.anything) + } + def RecursiveAliasCycle( + *varargs, + input:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("RecursiveAliasCycle may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? + raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.call_function( + "RecursiveAliasCycle", + { + input: input, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + (raw.parsed_using_types(Baml::Types)) + end + sig { params( varargs: T.untyped, @@ -2738,6 +2770,38 @@ def SchemaDescriptions( (raw.parsed_using_types(Baml::Types)) end + sig { + params( + varargs: T.untyped, + input: T.anything, + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(T.anything) + } + def SimpleRecursiveListAlias( + *varargs, + input:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("SimpleRecursiveListAlias may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? + raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.call_function( + "SimpleRecursiveListAlias", + { + input: input, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + (raw.parsed_using_types(Baml::Types)) + end + sig { params( varargs: T.untyped, @@ -6924,6 +6988,41 @@ def PromptTestStreaming( ) end + sig { + params( + varargs: T.untyped, + input: T.anything, + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(Baml::BamlStream[T.anything]) + } + def RecursiveAliasCycle( + *varargs, + input:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("RecursiveAliasCycle may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? + raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.stream_function( + "RecursiveAliasCycle", + { + input: input, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + Baml::BamlStream[T.anything, T.anything].new( + ffi_stream: raw, + ctx_manager: @ctx_manager + ) + end + sig { params( varargs: T.untyped, @@ -7099,6 +7198,41 @@ def SchemaDescriptions( ) end + sig { + params( + varargs: T.untyped, + input: T.anything, + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(Baml::BamlStream[T.anything]) + } + def SimpleRecursiveListAlias( + *varargs, + input:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("SimpleRecursiveListAlias may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? + raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.stream_function( + "SimpleRecursiveListAlias", + { + input: input, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + Baml::BamlStream[T.anything, T.anything].new( + ffi_stream: raw, + ctx_manager: @ctx_manager + ) + end + sig { params( varargs: T.untyped, diff --git a/integ-tests/ruby/baml_client/inlined.rb b/integ-tests/ruby/baml_client/inlined.rb index f12cc922b..daabe128e 100644 --- a/integ-tests/ruby/baml_client/inlined.rb +++ b/integ-tests/ruby/baml_client/inlined.rb @@ -83,7 +83,7 @@ module Inlined "test-files/functions/output/recursive-type-aliases.baml" => "class LinkedListAliasNode {\n value int\n next LinkedListAliasNode?\n}\n\n// Simple alias that points to recursive type.\ntype LinkedListAlias = LinkedListAliasNode\n\nfunction AliasThatPointsToRecursiveType(list: LinkedListAlias) -> LinkedListAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given linked list back:\n \n {{ list }}\n \n {{ ctx.output_format }}\n \"#\n}\n\n// Class that points to an alias that points to a recursive type.\nclass ClassToRecAlias {\n list LinkedListAlias\n}\n\nfunction ClassThatPointsToRecursiveClassThroughAlias(cls: ClassToRecAlias) -> ClassToRecAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given object back:\n \n {{ cls }}\n \n {{ ctx.output_format }}\n \"#\n}\n\n// This is tricky cause this class should be hoisted, but classes and aliases\n// are two different types in the AST. This test will make sure they can interop.\nclass NodeWithAliasIndirection {\n value int\n next NodeIndirection?\n}\n\ntype NodeIndirection = NodeWithAliasIndirection\n\nfunction RecursiveClassWithAliasIndirection(cls: NodeWithAliasIndirection) -> NodeWithAliasIndirection {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given object back:\n \n {{ cls }}\n \n {{ ctx.output_format }}\n \"#\n}\n", "test-files/functions/output/serialization-error.baml" => "class DummyOutput {\n nonce string\n nonce2 string\n @@dynamic\n}\n\nfunction DummyOutputFunction(input: string) -> DummyOutput {\n client GPT35\n prompt #\"\n Say \"hello there\".\n \"#\n}", "test-files/functions/output/string-list.baml" => "function FnOutputStringList(input: string) -> string[] {\n client GPT35\n prompt #\"\n Return a list of strings in json format like [\"string1\", \"string2\", \"string3\"].\n\n JSON:\n \"#\n}\n\ntest FnOutputStringList {\n functions [FnOutputStringList]\n args {\n input \"example input\"\n }\n}\n", - "test-files/functions/output/type-aliases.baml" => "type Primitive = int | string | bool | float\n\ntype List = string[]\n\ntype Graph = map\n\ntype Combination = Primitive | List | Graph\n\nfunction PrimitiveAlias(p: Primitive) -> Primitive {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back: {{ p }}\n \"#\n}\n\nfunction MapAlias(m: Graph) -> Graph {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given Graph back:\n\n {{ m }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction NestedAlias(c: Combination) -> Combination {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back:\n\n {{ c }}\n\n {{ ctx.output_format }}\n \"#\n}\n\n// Test attribute merging.\ntype Currency = int @check(gt_ten, {{ this > 10 }})\ntype Amount = Currency @assert({{ this > 0 }})\n\nclass MergeAttrs {\n amount Amount @description(\"In USD\")\n}\n\n// This should be allowed.\ntype MultipleAttrs = int @assert({{ this > 0 }}) @check(gt_ten, {{ this > 10 }})\n\nfunction MergeAliasAttributes(money: int) -> MergeAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer in the specified format:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction ReturnAliasWithMergedAttributes(money: Amount) -> Amount {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction AliasWithMultipleAttrs(money: MultipleAttrs) -> MultipleAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecursiveMapAlias = map\n\nfunction SimpleRecursiveMapAlias(input: RecursiveMapAlias) -> RecursiveMapAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}", + "test-files/functions/output/type-aliases.baml" => "type Primitive = int | string | bool | float\n\ntype List = string[]\n\ntype Graph = map\n\ntype Combination = Primitive | List | Graph\n\nfunction PrimitiveAlias(p: Primitive) -> Primitive {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back: {{ p }}\n \"#\n}\n\nfunction MapAlias(m: Graph) -> Graph {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given Graph back:\n\n {{ m }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction NestedAlias(c: Combination) -> Combination {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back:\n\n {{ c }}\n\n {{ ctx.output_format }}\n \"#\n}\n\n// Test attribute merging.\ntype Currency = int @check(gt_ten, {{ this > 10 }})\ntype Amount = Currency @assert({{ this > 0 }})\n\nclass MergeAttrs {\n amount Amount @description(\"In USD\")\n}\n\n// This should be allowed.\ntype MultipleAttrs = int @assert({{ this > 0 }}) @check(gt_ten, {{ this > 10 }})\n\nfunction MergeAliasAttributes(money: int) -> MergeAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer in the specified format:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction ReturnAliasWithMergedAttributes(money: Amount) -> Amount {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction AliasWithMultipleAttrs(money: MultipleAttrs) -> MultipleAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecursiveMapAlias = map\n\nfunction SimpleRecursiveMapAlias(input: RecursiveMapAlias) -> RecursiveMapAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecursiveListAlias = RecursiveListAlias[]\n\nfunction SimpleRecursiveListAlias(input: RecursiveListAlias) -> RecursiveListAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given JSON array:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecAliasOne = RecAliasTwo\ntype RecAliasTwo = RecAliasThree\ntype RecAliasThree = RecAliasOne[]\n\nfunction RecursiveAliasCycle(input: RecAliasOne) -> RecAliasOne {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given JSON array:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}", "test-files/functions/output/unions.baml" => "class UnionTest_ReturnType {\n prop1 string | bool\n prop2 (float | bool)[]\n prop3 (bool[] | int[])\n}\n\nfunction UnionTest_Function(input: string | bool) -> UnionTest_ReturnType {\n client GPT35\n prompt #\"\n Return a JSON blob with this schema: \n {{ctx.output_format}}\n\n JSON:\n \"#\n}\n\ntest UnionTest_Function {\n functions [UnionTest_Function]\n args {\n input \"example input\"\n }\n}\n", "test-files/functions/prompts/no-chat-messages.baml" => "\n\nfunction PromptTestClaude(input: string) -> string {\n client Sonnet\n prompt #\"\n Tell me a haiku about {{ input }}\n \"#\n}\n\n\nfunction PromptTestStreaming(input: string) -> string {\n client GPT35\n prompt #\"\n Tell me a short story about {{ input }}\n \"#\n}\n\ntest TestName {\n functions [PromptTestStreaming]\n args {\n input #\"\n hello world\n \"#\n }\n}\n", "test-files/functions/prompts/with-chat-messages.baml" => "\nfunction PromptTestOpenAIChat(input: string) -> string {\n client GPT35\n prompt #\"\n {{ _.role(\"system\") }}\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestOpenAIChatNoSystem(input: string) -> string {\n client GPT35\n prompt #\"\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestClaudeChat(input: string) -> string {\n client Claude\n prompt #\"\n {{ _.role(\"system\") }}\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestClaudeChatNoSystem(input: string) -> string {\n client Claude\n prompt #\"\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\ntest TestSystemAndNonSystemChat1 {\n functions [PromptTestClaude, PromptTestOpenAI, PromptTestOpenAIChat, PromptTestOpenAIChatNoSystem, PromptTestClaudeChat, PromptTestClaudeChatNoSystem]\n args {\n input \"cats\"\n }\n}\n\ntest TestSystemAndNonSystemChat2 {\n functions [PromptTestClaude, PromptTestOpenAI, PromptTestOpenAIChat, PromptTestOpenAIChatNoSystem, PromptTestClaudeChat, PromptTestClaudeChatNoSystem]\n args {\n input \"lion\"\n }\n}", diff --git a/integ-tests/typescript/baml_client/async_client.ts b/integ-tests/typescript/baml_client/async_client.ts index 792b5c3dd..78e513c62 100644 --- a/integ-tests/typescript/baml_client/async_client.ts +++ b/integ-tests/typescript/baml_client/async_client.ts @@ -2018,6 +2018,31 @@ export class BamlAsyncClient { } } + async RecursiveAliasCycle( + input: RecAliasOne, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): Promise { + try { + const raw = await this.runtime.callFunction( + "RecursiveAliasCycle", + { + "input": input + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as RecAliasOne + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + async RecursiveClassWithAliasIndirection( cls: NodeWithAliasIndirection, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } @@ -2143,6 +2168,31 @@ export class BamlAsyncClient { } } + async SimpleRecursiveListAlias( + input: RecursiveListAlias, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): Promise { + try { + const raw = await this.runtime.callFunction( + "SimpleRecursiveListAlias", + { + "input": input + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as RecursiveListAlias + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + async SimpleRecursiveMapAlias( input: RecursiveMapAlias, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } @@ -5856,6 +5906,39 @@ class BamlStreamClient { } } + RecursiveAliasCycle( + input: RecAliasOne, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): BamlStream, RecAliasOne> { + try { + const raw = this.runtime.streamFunction( + "RecursiveAliasCycle", + { + "input": input + }, + undefined, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return new BamlStream, RecAliasOne>( + raw, + (a): a is RecursivePartialNull => a, + (a): a is RecAliasOne => a, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + ) + } catch (error) { + if (error instanceof Error) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } + } + throw error; + } + } + RecursiveClassWithAliasIndirection( cls: NodeWithAliasIndirection, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } @@ -6021,6 +6104,39 @@ class BamlStreamClient { } } + SimpleRecursiveListAlias( + input: RecursiveListAlias, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): BamlStream, RecursiveListAlias> { + try { + const raw = this.runtime.streamFunction( + "SimpleRecursiveListAlias", + { + "input": input + }, + undefined, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return new BamlStream, RecursiveListAlias>( + raw, + (a): a is RecursivePartialNull => a, + (a): a is RecursiveListAlias => a, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + ) + } catch (error) { + if (error instanceof Error) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } + } + throw error; + } + } + SimpleRecursiveMapAlias( input: RecursiveMapAlias, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } diff --git a/integ-tests/typescript/baml_client/inlinedbaml.ts b/integ-tests/typescript/baml_client/inlinedbaml.ts index 5c490a61b..2c818f1c4 100644 --- a/integ-tests/typescript/baml_client/inlinedbaml.ts +++ b/integ-tests/typescript/baml_client/inlinedbaml.ts @@ -84,7 +84,7 @@ const fileMap = { "test-files/functions/output/recursive-type-aliases.baml": "class LinkedListAliasNode {\n value int\n next LinkedListAliasNode?\n}\n\n// Simple alias that points to recursive type.\ntype LinkedListAlias = LinkedListAliasNode\n\nfunction AliasThatPointsToRecursiveType(list: LinkedListAlias) -> LinkedListAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given linked list back:\n \n {{ list }}\n \n {{ ctx.output_format }}\n \"#\n}\n\n// Class that points to an alias that points to a recursive type.\nclass ClassToRecAlias {\n list LinkedListAlias\n}\n\nfunction ClassThatPointsToRecursiveClassThroughAlias(cls: ClassToRecAlias) -> ClassToRecAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given object back:\n \n {{ cls }}\n \n {{ ctx.output_format }}\n \"#\n}\n\n// This is tricky cause this class should be hoisted, but classes and aliases\n// are two different types in the AST. This test will make sure they can interop.\nclass NodeWithAliasIndirection {\n value int\n next NodeIndirection?\n}\n\ntype NodeIndirection = NodeWithAliasIndirection\n\nfunction RecursiveClassWithAliasIndirection(cls: NodeWithAliasIndirection) -> NodeWithAliasIndirection {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given object back:\n \n {{ cls }}\n \n {{ ctx.output_format }}\n \"#\n}\n", "test-files/functions/output/serialization-error.baml": "class DummyOutput {\n nonce string\n nonce2 string\n @@dynamic\n}\n\nfunction DummyOutputFunction(input: string) -> DummyOutput {\n client GPT35\n prompt #\"\n Say \"hello there\".\n \"#\n}", "test-files/functions/output/string-list.baml": "function FnOutputStringList(input: string) -> string[] {\n client GPT35\n prompt #\"\n Return a list of strings in json format like [\"string1\", \"string2\", \"string3\"].\n\n JSON:\n \"#\n}\n\ntest FnOutputStringList {\n functions [FnOutputStringList]\n args {\n input \"example input\"\n }\n}\n", - "test-files/functions/output/type-aliases.baml": "type Primitive = int | string | bool | float\n\ntype List = string[]\n\ntype Graph = map\n\ntype Combination = Primitive | List | Graph\n\nfunction PrimitiveAlias(p: Primitive) -> Primitive {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back: {{ p }}\n \"#\n}\n\nfunction MapAlias(m: Graph) -> Graph {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given Graph back:\n\n {{ m }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction NestedAlias(c: Combination) -> Combination {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back:\n\n {{ c }}\n\n {{ ctx.output_format }}\n \"#\n}\n\n// Test attribute merging.\ntype Currency = int @check(gt_ten, {{ this > 10 }})\ntype Amount = Currency @assert({{ this > 0 }})\n\nclass MergeAttrs {\n amount Amount @description(\"In USD\")\n}\n\n// This should be allowed.\ntype MultipleAttrs = int @assert({{ this > 0 }}) @check(gt_ten, {{ this > 10 }})\n\nfunction MergeAliasAttributes(money: int) -> MergeAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer in the specified format:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction ReturnAliasWithMergedAttributes(money: Amount) -> Amount {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction AliasWithMultipleAttrs(money: MultipleAttrs) -> MultipleAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecursiveMapAlias = map\n\nfunction SimpleRecursiveMapAlias(input: RecursiveMapAlias) -> RecursiveMapAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}", + "test-files/functions/output/type-aliases.baml": "type Primitive = int | string | bool | float\n\ntype List = string[]\n\ntype Graph = map\n\ntype Combination = Primitive | List | Graph\n\nfunction PrimitiveAlias(p: Primitive) -> Primitive {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back: {{ p }}\n \"#\n}\n\nfunction MapAlias(m: Graph) -> Graph {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given Graph back:\n\n {{ m }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction NestedAlias(c: Combination) -> Combination {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back:\n\n {{ c }}\n\n {{ ctx.output_format }}\n \"#\n}\n\n// Test attribute merging.\ntype Currency = int @check(gt_ten, {{ this > 10 }})\ntype Amount = Currency @assert({{ this > 0 }})\n\nclass MergeAttrs {\n amount Amount @description(\"In USD\")\n}\n\n// This should be allowed.\ntype MultipleAttrs = int @assert({{ this > 0 }}) @check(gt_ten, {{ this > 10 }})\n\nfunction MergeAliasAttributes(money: int) -> MergeAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer in the specified format:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction ReturnAliasWithMergedAttributes(money: Amount) -> Amount {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction AliasWithMultipleAttrs(money: MultipleAttrs) -> MultipleAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecursiveMapAlias = map\n\nfunction SimpleRecursiveMapAlias(input: RecursiveMapAlias) -> RecursiveMapAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecursiveListAlias = RecursiveListAlias[]\n\nfunction SimpleRecursiveListAlias(input: RecursiveListAlias) -> RecursiveListAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given JSON array:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecAliasOne = RecAliasTwo\ntype RecAliasTwo = RecAliasThree\ntype RecAliasThree = RecAliasOne[]\n\nfunction RecursiveAliasCycle(input: RecAliasOne) -> RecAliasOne {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given JSON array:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}", "test-files/functions/output/unions.baml": "class UnionTest_ReturnType {\n prop1 string | bool\n prop2 (float | bool)[]\n prop3 (bool[] | int[])\n}\n\nfunction UnionTest_Function(input: string | bool) -> UnionTest_ReturnType {\n client GPT35\n prompt #\"\n Return a JSON blob with this schema: \n {{ctx.output_format}}\n\n JSON:\n \"#\n}\n\ntest UnionTest_Function {\n functions [UnionTest_Function]\n args {\n input \"example input\"\n }\n}\n", "test-files/functions/prompts/no-chat-messages.baml": "\n\nfunction PromptTestClaude(input: string) -> string {\n client Sonnet\n prompt #\"\n Tell me a haiku about {{ input }}\n \"#\n}\n\n\nfunction PromptTestStreaming(input: string) -> string {\n client GPT35\n prompt #\"\n Tell me a short story about {{ input }}\n \"#\n}\n\ntest TestName {\n functions [PromptTestStreaming]\n args {\n input #\"\n hello world\n \"#\n }\n}\n", "test-files/functions/prompts/with-chat-messages.baml": "\nfunction PromptTestOpenAIChat(input: string) -> string {\n client GPT35\n prompt #\"\n {{ _.role(\"system\") }}\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestOpenAIChatNoSystem(input: string) -> string {\n client GPT35\n prompt #\"\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestClaudeChat(input: string) -> string {\n client Claude\n prompt #\"\n {{ _.role(\"system\") }}\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestClaudeChatNoSystem(input: string) -> string {\n client Claude\n prompt #\"\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\ntest TestSystemAndNonSystemChat1 {\n functions [PromptTestClaude, PromptTestOpenAI, PromptTestOpenAIChat, PromptTestOpenAIChatNoSystem, PromptTestClaudeChat, PromptTestClaudeChatNoSystem]\n args {\n input \"cats\"\n }\n}\n\ntest TestSystemAndNonSystemChat2 {\n functions [PromptTestClaude, PromptTestOpenAI, PromptTestOpenAIChat, PromptTestOpenAIChatNoSystem, PromptTestClaudeChat, PromptTestClaudeChatNoSystem]\n args {\n input \"lion\"\n }\n}", diff --git a/integ-tests/typescript/baml_client/sync_client.ts b/integ-tests/typescript/baml_client/sync_client.ts index f9af16e41..50c93b176 100644 --- a/integ-tests/typescript/baml_client/sync_client.ts +++ b/integ-tests/typescript/baml_client/sync_client.ts @@ -2018,6 +2018,31 @@ export class BamlSyncClient { } } + RecursiveAliasCycle( + input: RecAliasOne, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): RecAliasOne { + try { + const raw = this.runtime.callFunctionSync( + "RecursiveAliasCycle", + { + "input": input + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as RecAliasOne + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + RecursiveClassWithAliasIndirection( cls: NodeWithAliasIndirection, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } @@ -2143,6 +2168,31 @@ export class BamlSyncClient { } } + SimpleRecursiveListAlias( + input: RecursiveListAlias, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): RecursiveListAlias { + try { + const raw = this.runtime.callFunctionSync( + "SimpleRecursiveListAlias", + { + "input": input + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as RecursiveListAlias + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + SimpleRecursiveMapAlias( input: RecursiveMapAlias, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } From e0ae448f3bf412959483f9c784cdd38d7bcd6728 Mon Sep 17 00:00:00 2001 From: Antonio Sarosi Date: Tue, 17 Dec 2024 17:39:06 +0100 Subject: [PATCH 08/19] Fix test `relevant_data_models` --- .../src/deserializer/coercer/field_type.rs | 10 ------ engine/baml-lib/jsonish/src/tests/mod.rs | 34 +++++++++++++++---- .../jsonish/src/tests/test_aliases.rs | 25 ++++++++++++++ 3 files changed, 53 insertions(+), 16 deletions(-) create mode 100644 engine/baml-lib/jsonish/src/tests/test_aliases.rs diff --git a/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs b/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs index eed33ba0c..13ba6e9b2 100644 --- a/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs +++ b/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs @@ -14,8 +14,6 @@ use super::{ ParsingError, }; -static mut count: u32 = 0; - impl TypeCoercer for FieldType { fn coerce( &self, @@ -23,14 +21,6 @@ impl TypeCoercer for FieldType { target: &FieldType, value: Option<&crate::jsonish::Value>, ) -> Result { - unsafe { - eprintln!("{self:?} -> {target:?} -> {value:?}"); - count += 1; - if count == 20 { - panic!("FUCK"); - } - } - match value { Some(crate::jsonish::Value::AnyOf(candidates, primitive)) => { log::debug!( diff --git a/engine/baml-lib/jsonish/src/tests/mod.rs b/engine/baml-lib/jsonish/src/tests/mod.rs index 41a22ef4c..00da9d619 100644 --- a/engine/baml-lib/jsonish/src/tests/mod.rs +++ b/engine/baml-lib/jsonish/src/tests/mod.rs @@ -4,6 +4,7 @@ use internal_baml_jinja::types::{Class, Enum, Name, OutputFormatContent}; #[macro_use] pub mod macros; +mod test_aliases; mod test_basics; mod test_class; mod test_class_2; @@ -16,7 +17,7 @@ mod test_maps; mod test_partials; mod test_unions; -use indexmap::IndexSet; +use indexmap::{IndexMap, IndexSet}; use std::{ collections::{HashMap, HashSet}, path::PathBuf, @@ -56,12 +57,14 @@ fn render_output_format( output: &FieldType, env_values: &EvaluationContext<'_>, ) -> Result { - let (enums, classes, recursive_classes) = relevant_data_models(ir, output, env_values)?; + let (enums, classes, recursive_classes, structural_recursive_aliases) = + relevant_data_models(ir, output, env_values)?; Ok(OutputFormatContent::target(output.clone()) .enums(enums) .classes(classes) .recursive_classes(recursive_classes) + .structural_recursive_aliases(structural_recursive_aliases) .build()) } @@ -126,11 +129,17 @@ fn relevant_data_models<'a>( ir: &'a IntermediateRepr, output: &'a FieldType, env_values: &EvaluationContext<'_>, -) -> Result<(Vec, Vec, IndexSet)> { +) -> Result<( + Vec, + Vec, + IndexSet, + IndexMap, +)> { let mut checked_types: HashSet = HashSet::new(); let mut enums = Vec::new(); let mut classes: Vec = Vec::new(); let mut recursive_classes = IndexSet::new(); + let mut structural_recursive_aliases = IndexMap::new(); let mut start: Vec = vec![output.clone()]; while let Some(output) = start.pop() { @@ -230,8 +239,16 @@ fn relevant_data_models<'a>( }); } } - // TODO: Add structural aliases here. - (FieldType::RecursiveTypeAlias(_), _) => {} + (FieldType::RecursiveTypeAlias(name), _) => { + // TODO: Same O(n) problem as above. + for cycle in ir.structural_recursive_alias_cycles() { + if cycle.contains_key(name) { + for (alias, target) in cycle.iter() { + structural_recursive_aliases.insert(alias.to_owned(), target.clone()); + } + } + } + } (FieldType::Literal(_), _) => {} (FieldType::Primitive(_), _constraints) => {} (FieldType::Constrained { .. }, _) => { @@ -240,7 +257,12 @@ fn relevant_data_models<'a>( } } - Ok((enums, classes, recursive_classes)) + Ok(( + enums, + classes, + recursive_classes, + structural_recursive_aliases, + )) } const EMPTY_FILE: &str = r#" diff --git a/engine/baml-lib/jsonish/src/tests/test_aliases.rs b/engine/baml-lib/jsonish/src/tests/test_aliases.rs new file mode 100644 index 000000000..061aab3cb --- /dev/null +++ b/engine/baml-lib/jsonish/src/tests/test_aliases.rs @@ -0,0 +1,25 @@ +use baml_types::LiteralValue; + +use super::*; + +test_deserializer!( + test_simple_recursive_alias_list, + r#" +type A = A[] + "#, + "[[], [], [[]]]", + FieldType::RecursiveTypeAlias("A".into()), + [[], [], [[]]] +); + +test_deserializer!( + test_recursive_alias_cycle, + r#" +type A = B +type B = C +type C = A[] + "#, + "[[], [], [[]]]", + FieldType::RecursiveTypeAlias("A".into()), + [[], [], [[]]] +); From abb743012216ec2a8bf13c23b0b8f2c8a2e91aed Mon Sep 17 00:00:00 2001 From: Antonio Sarosi Date: Wed, 18 Dec 2024 01:06:32 +0100 Subject: [PATCH 09/19] `is_subtype_of` causing issues with aliases --- .../src/ir/ir_helpers/to_baml_arg.rs | 39 +++++++- .../baml-lib/baml-types/src/field_type/mod.rs | 7 +- .../src/deserializer/coercer/field_type.rs | 31 ++++--- .../coercer/ir_ref/coerce_alias.rs | 44 +++++++++ .../src/deserializer/coercer/ir_ref/mod.rs | 1 + .../jsonish/src/tests/test_aliases.rs | 93 ++++++++++++++++++- .../functions/output/type-aliases.baml | 15 +++ .../python/baml_client/async_client.py | 53 +++++++++++ integ-tests/python/baml_client/inlinedbaml.py | 2 +- integ-tests/python/baml_client/sync_client.py | 53 +++++++++++ integ-tests/python/baml_client/types.py | 6 ++ integ-tests/python/tests/test_functions.py | 25 +++++ integ-tests/ruby/baml_client/client.rb | 67 +++++++++++++ integ-tests/ruby/baml_client/inlined.rb | 2 +- .../typescript/baml_client/async_client.ts | 58 ++++++++++++ .../typescript/baml_client/inlinedbaml.ts | 2 +- .../typescript/baml_client/sync_client.ts | 25 +++++ 17 files changed, 497 insertions(+), 26 deletions(-) create mode 100644 engine/baml-lib/jsonish/src/deserializer/coercer/ir_ref/coerce_alias.rs diff --git a/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs b/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs index 4d8bb08d3..7f529cd72 100644 --- a/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs +++ b/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs @@ -43,6 +43,9 @@ impl ArgCoercer { value: &BamlValue, // original value passed in by user scope: &mut ScopeStack, ) -> Result { + eprintln!("coerce_arg: {value:?} -> {field_type:?}"); + eprintln!("scope: {scope}\n"); + let value = match ir.distribute_constraints(field_type) { (FieldType::Primitive(t), _) => match t { TypeValue::String if matches!(value, BamlValue::String(_)) => Ok(value.clone()), @@ -263,11 +266,10 @@ impl ArgCoercer { Err(()) } }, - // TODO: Is this even possible? (FieldType::RecursiveTypeAlias(name), _) => { let mut maybe_coerced = None; // TODO: Fix this O(n) - for cycle in ir.structural_recursive_alias_cycles().iter() { + for cycle in ir.structural_recursive_alias_cycles() { if let Some(target) = cycle.get(name) { maybe_coerced = Some(self.coerce_arg(ir, target, value, scope)?); break; @@ -329,6 +331,7 @@ impl ArgCoercer { let mut scope = ScopeStack::new(); if first_good_result.is_err() { let result = self.coerce_arg(ir, option, value, &mut scope); + eprintln!("union inner scope scope: {scope}\n"); if !scope.has_errors() && first_good_result.is_err() { first_good_result = result } @@ -458,4 +461,36 @@ mod tests { let res = arg_coercer.coerce_arg(&ir, &type_, &value, &mut ScopeStack::new()); assert!(res.is_err()); } + + #[test] + fn test_mutually_recursive_aliases() { + let ir = make_test_ir( + r##" +type JsonValue = int | string | bool | float | JsonObject | JsonArray +type JsonObject = map +type JsonArray = JsonValue[] + "##, + ) + .unwrap(); + + let arg_coercer = ArgCoercer { + span_path: None, + allow_implicit_cast_to_string: true, + }; + + let json = BamlValue::Map(BamlMap::from([ + ("number".to_string(), BamlValue::Int(1)), + ("string".to_string(), BamlValue::String("test".to_string())), + ("bool".to_string(), BamlValue::Bool(true)), + ])); + + let res = arg_coercer.coerce_arg( + &ir, + &FieldType::RecursiveTypeAlias("JsonValue".to_string()), + &json, + &mut ScopeStack::new(), + ); + + assert_eq!(res, Ok(json)); + } } diff --git a/engine/baml-lib/baml-types/src/field_type/mod.rs b/engine/baml-lib/baml-types/src/field_type/mod.rs index 8952ba11f..7008f26a5 100644 --- a/engine/baml-lib/baml-types/src/field_type/mod.rs +++ b/engine/baml-lib/baml-types/src/field_type/mod.rs @@ -199,12 +199,6 @@ impl FieldType { } (FieldType::Map(_, _), _) => false, - // TODO: is it necessary to check if the alias is part of the same - // cycle? - (FieldType::RecursiveTypeAlias(_), _) | (_, FieldType::RecursiveTypeAlias(_)) => { - self == other - } - ( FieldType::Constrained { base: self_base, @@ -252,6 +246,7 @@ impl FieldType { (FieldType::Primitive(_), _) => false, (FieldType::Enum(_), _) => false, (FieldType::Class(_), _) => false, + (FieldType::RecursiveTypeAlias(_), _) => false, } } } diff --git a/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs b/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs index 13ba6e9b2..ab4d64cda 100644 --- a/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs +++ b/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs @@ -9,11 +9,17 @@ use crate::deserializer::{ }; use super::{ - array_helper, coerce_array::coerce_array, coerce_map::coerce_map, - coerce_optional::coerce_optional, coerce_union::coerce_union, ir_ref::IrRef, ParsingContext, - ParsingError, + array_helper, + coerce_array::coerce_array, + coerce_map::coerce_map, + coerce_optional::coerce_optional, + coerce_union::coerce_union, + ir_ref::{coerce_alias::coerce_alias, IrRef}, + ParsingContext, ParsingError, }; +static mut LIMIT: usize = 0; + impl TypeCoercer for FieldType { fn coerce( &self, @@ -21,6 +27,13 @@ impl TypeCoercer for FieldType { target: &FieldType, value: Option<&crate::jsonish::Value>, ) -> Result { + unsafe { + LIMIT += 1; + if LIMIT > 500 { + panic!("Stack Overflow Bruh {}", LIMIT); + } + } + match value { Some(crate::jsonish::Value::AnyOf(candidates, primitive)) => { log::debug!( @@ -79,17 +92,7 @@ impl TypeCoercer for FieldType { FieldType::Enum(e) => IrRef::Enum(e).coerce(ctx, target, value), FieldType::Literal(l) => l.coerce(ctx, target, value), FieldType::Class(c) => IrRef::Class(c).coerce(ctx, target, value), - // TODO: This doesn't look too good compared to the rest of - // match arms. Should we make use of the context like this here? - FieldType::RecursiveTypeAlias(name) => ctx - .of - .find_recursive_alias_target(name) - .map_err(|e| ParsingError { - reason: format!("Failed to find recursive alias target: {e}"), - scope: ctx.scope.clone(), - causes: Vec::new(), - })? - .coerce(ctx, target, value), + FieldType::RecursiveTypeAlias(name) => coerce_alias(ctx, self, value), FieldType::List(_) => coerce_array(ctx, self, value), FieldType::Union(_) => coerce_union(ctx, self, value), FieldType::Optional(_) => coerce_optional(ctx, self, value), diff --git a/engine/baml-lib/jsonish/src/deserializer/coercer/ir_ref/coerce_alias.rs b/engine/baml-lib/jsonish/src/deserializer/coercer/ir_ref/coerce_alias.rs new file mode 100644 index 000000000..352724bd8 --- /dev/null +++ b/engine/baml-lib/jsonish/src/deserializer/coercer/ir_ref/coerce_alias.rs @@ -0,0 +1,44 @@ +use anyhow::Result; +use internal_baml_core::ir::FieldType; + +use crate::deserializer::types::BamlValueWithFlags; + +use super::{ParsingContext, ParsingError, TypeCoercer}; + +pub fn coerce_alias( + ctx: &ParsingContext, + target: &FieldType, + value: Option<&crate::jsonish::Value>, +) -> Result { + assert!(matches!(target, FieldType::RecursiveTypeAlias(_))); + log::debug!( + "scope: {scope} :: coercing to: {name} (current: {current})", + name = target.to_string(), + scope = ctx.display_scope(), + current = value.map(|v| v.r#type()).unwrap_or("".into()) + ); + + let FieldType::RecursiveTypeAlias(alias) = target else { + unreachable!("coerce_alias"); + }; + + // See coerce_class.rs + let mut nested_ctx = None; + if let Some(v) = value { + let cls_value_pair = (alias.to_string(), v.to_owned()); + if ctx.visited.contains(&cls_value_pair) { + return Err(ctx.error_circular_reference(alias, v)); + } + nested_ctx = Some(ctx.visit_class_value_pair(cls_value_pair)); + } + let ctx = nested_ctx.as_ref().unwrap_or(ctx); + + ctx.of + .find_recursive_alias_target(alias) + .map_err(|e| ParsingError { + reason: format!("Failed to find recursive alias target: {e}"), + scope: ctx.scope.clone(), + causes: Vec::new(), + })? + .coerce(ctx, target, value) +} diff --git a/engine/baml-lib/jsonish/src/deserializer/coercer/ir_ref/mod.rs b/engine/baml-lib/jsonish/src/deserializer/coercer/ir_ref/mod.rs index b5f0df842..88a3cfde3 100644 --- a/engine/baml-lib/jsonish/src/deserializer/coercer/ir_ref/mod.rs +++ b/engine/baml-lib/jsonish/src/deserializer/coercer/ir_ref/mod.rs @@ -1,3 +1,4 @@ +pub mod coerce_alias; mod coerce_class; pub mod coerce_enum; diff --git a/engine/baml-lib/jsonish/src/tests/test_aliases.rs b/engine/baml-lib/jsonish/src/tests/test_aliases.rs index 061aab3cb..35baa352f 100644 --- a/engine/baml-lib/jsonish/src/tests/test_aliases.rs +++ b/engine/baml-lib/jsonish/src/tests/test_aliases.rs @@ -5,13 +5,26 @@ use super::*; test_deserializer!( test_simple_recursive_alias_list, r#" -type A = A[] +type A = A[] "#, "[[], [], [[]]]", FieldType::RecursiveTypeAlias("A".into()), [[], [], [[]]] ); +test_deserializer!( + test_simple_recursive_alias_map, + r#" +type A = map + "#, + r#"{"one": {"two": {}}, "three": {"four": {}}}"#, + FieldType::RecursiveTypeAlias("A".into()), + { + "one": {"two": {}}, + "three": {"four": {}} + } +); + test_deserializer!( test_recursive_alias_cycle, r#" @@ -23,3 +36,81 @@ type C = A[] FieldType::RecursiveTypeAlias("A".into()), [[], [], [[]]] ); + +test_deserializer!( + test_recursive_alias_union, + r#" +type JsonValue = int | string | bool | JsonValue[] | map + "#, + r#" + { + "number": 1, + "string": "test", + "bool": true + } + "#, + FieldType::RecursiveTypeAlias("JsonValue".into()), + { + "number": 1, + "string": "test", + "bool": true + } +); + +test_deserializer!( + test_complex_recursive_alias, + r#" +type JsonValue = int | string | bool | JsonValue[] | map + "#, + r#" + { + "number": 1, + "string": "test", + "bool": true, + "list": [1, 2, 3], + "object": { + "number": 1, + "string": "test", + "bool": true, + "list": [1, 2, 3] + }, + "json": { + "number": 1, + "string": "test", + "bool": true, + "list": [1, 2, 3], + "object": { + "number": 1, + "string": "test", + "bool": true, + "list": [1, 2, 3] + } + } + } + "#, + FieldType::RecursiveTypeAlias("JsonValue".into()), + { + "number": 1, + "string": "test", + "bool": true, + "list": [1, 2, 3], + "object": { + "number": 1, + "string": "test", + "bool": true, + "list": [1, 2, 3] + }, + "json": { + "number": 1, + "string": "test", + "bool": true, + "list": [1, 2, 3], + "object": { + "number": 1, + "string": "test", + "bool": true, + "list": [1, 2, 3] + } + } + } +); diff --git a/integ-tests/baml_src/test-files/functions/output/type-aliases.baml b/integ-tests/baml_src/test-files/functions/output/type-aliases.baml index dd1ee3a75..216bfc2ec 100644 --- a/integ-tests/baml_src/test-files/functions/output/type-aliases.baml +++ b/integ-tests/baml_src/test-files/functions/output/type-aliases.baml @@ -116,6 +116,21 @@ function RecursiveAliasCycle(input: RecAliasOne) -> RecAliasOne { {{ input }} + {{ ctx.output_format }} + "# +} + +type JsonValue = int | string | bool | float | JsonObject | JsonArray +type JsonObject = map +type JsonArray = JsonValue[] + +function JsonTypeAliasCycle(input: JsonValue) -> JsonValue { + client "openai/gpt-4o" + prompt r#" + Return the given input back: + + {{ input }} + {{ ctx.output_format }} "# } \ No newline at end of file diff --git a/integ-tests/python/baml_client/async_client.py b/integ-tests/python/baml_client/async_client.py index 1b46052cd..f14ad456d 100644 --- a/integ-tests/python/baml_client/async_client.py +++ b/integ-tests/python/baml_client/async_client.py @@ -1453,6 +1453,29 @@ async def InOutSingleLiteralStringMapKey( ) return cast(Dict[Literal["key"], str], raw.cast_to(types, types)) + async def JsonTypeAliasCycle( + self, + input: types.JsonValue, + baml_options: BamlCallOptions = {}, + ) -> types.JsonValue: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = await self.__runtime.call_function( + "JsonTypeAliasCycle", + { + "input": input, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + return cast(types.JsonValue, raw.cast_to(types, types)) + async def LiteralUnionsTest( self, input: str, @@ -4888,6 +4911,36 @@ def InOutSingleLiteralStringMapKey( self.__ctx_manager.get(), ) + def JsonTypeAliasCycle( + self, + input: types.JsonValue, + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlStream[types.JsonValue, types.JsonValue]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function( + "JsonTypeAliasCycle", + { + "input": input, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + return baml_py.BamlStream[types.JsonValue, types.JsonValue]( + raw, + lambda x: cast(types.JsonValue, x.cast_to(types, partial_types)), + lambda x: cast(types.JsonValue, x.cast_to(types, types)), + self.__ctx_manager.get(), + ) + def LiteralUnionsTest( self, input: str, diff --git a/integ-tests/python/baml_client/inlinedbaml.py b/integ-tests/python/baml_client/inlinedbaml.py index e4c73068f..4d7de9465 100644 --- a/integ-tests/python/baml_client/inlinedbaml.py +++ b/integ-tests/python/baml_client/inlinedbaml.py @@ -83,7 +83,7 @@ "test-files/functions/output/recursive-type-aliases.baml": "class LinkedListAliasNode {\n value int\n next LinkedListAliasNode?\n}\n\n// Simple alias that points to recursive type.\ntype LinkedListAlias = LinkedListAliasNode\n\nfunction AliasThatPointsToRecursiveType(list: LinkedListAlias) -> LinkedListAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given linked list back:\n \n {{ list }}\n \n {{ ctx.output_format }}\n \"#\n}\n\n// Class that points to an alias that points to a recursive type.\nclass ClassToRecAlias {\n list LinkedListAlias\n}\n\nfunction ClassThatPointsToRecursiveClassThroughAlias(cls: ClassToRecAlias) -> ClassToRecAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given object back:\n \n {{ cls }}\n \n {{ ctx.output_format }}\n \"#\n}\n\n// This is tricky cause this class should be hoisted, but classes and aliases\n// are two different types in the AST. This test will make sure they can interop.\nclass NodeWithAliasIndirection {\n value int\n next NodeIndirection?\n}\n\ntype NodeIndirection = NodeWithAliasIndirection\n\nfunction RecursiveClassWithAliasIndirection(cls: NodeWithAliasIndirection) -> NodeWithAliasIndirection {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given object back:\n \n {{ cls }}\n \n {{ ctx.output_format }}\n \"#\n}\n", "test-files/functions/output/serialization-error.baml": "class DummyOutput {\n nonce string\n nonce2 string\n @@dynamic\n}\n\nfunction DummyOutputFunction(input: string) -> DummyOutput {\n client GPT35\n prompt #\"\n Say \"hello there\".\n \"#\n}", "test-files/functions/output/string-list.baml": "function FnOutputStringList(input: string) -> string[] {\n client GPT35\n prompt #\"\n Return a list of strings in json format like [\"string1\", \"string2\", \"string3\"].\n\n JSON:\n \"#\n}\n\ntest FnOutputStringList {\n functions [FnOutputStringList]\n args {\n input \"example input\"\n }\n}\n", - "test-files/functions/output/type-aliases.baml": "type Primitive = int | string | bool | float\n\ntype List = string[]\n\ntype Graph = map\n\ntype Combination = Primitive | List | Graph\n\nfunction PrimitiveAlias(p: Primitive) -> Primitive {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back: {{ p }}\n \"#\n}\n\nfunction MapAlias(m: Graph) -> Graph {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given Graph back:\n\n {{ m }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction NestedAlias(c: Combination) -> Combination {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back:\n\n {{ c }}\n\n {{ ctx.output_format }}\n \"#\n}\n\n// Test attribute merging.\ntype Currency = int @check(gt_ten, {{ this > 10 }})\ntype Amount = Currency @assert({{ this > 0 }})\n\nclass MergeAttrs {\n amount Amount @description(\"In USD\")\n}\n\n// This should be allowed.\ntype MultipleAttrs = int @assert({{ this > 0 }}) @check(gt_ten, {{ this > 10 }})\n\nfunction MergeAliasAttributes(money: int) -> MergeAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer in the specified format:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction ReturnAliasWithMergedAttributes(money: Amount) -> Amount {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction AliasWithMultipleAttrs(money: MultipleAttrs) -> MultipleAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecursiveMapAlias = map\n\nfunction SimpleRecursiveMapAlias(input: RecursiveMapAlias) -> RecursiveMapAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecursiveListAlias = RecursiveListAlias[]\n\nfunction SimpleRecursiveListAlias(input: RecursiveListAlias) -> RecursiveListAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given JSON array:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecAliasOne = RecAliasTwo\ntype RecAliasTwo = RecAliasThree\ntype RecAliasThree = RecAliasOne[]\n\nfunction RecursiveAliasCycle(input: RecAliasOne) -> RecAliasOne {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given JSON array:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}", + "test-files/functions/output/type-aliases.baml": "type Primitive = int | string | bool | float\n\ntype List = string[]\n\ntype Graph = map\n\ntype Combination = Primitive | List | Graph\n\nfunction PrimitiveAlias(p: Primitive) -> Primitive {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back: {{ p }}\n \"#\n}\n\nfunction MapAlias(m: Graph) -> Graph {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given Graph back:\n\n {{ m }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction NestedAlias(c: Combination) -> Combination {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back:\n\n {{ c }}\n\n {{ ctx.output_format }}\n \"#\n}\n\n// Test attribute merging.\ntype Currency = int @check(gt_ten, {{ this > 10 }})\ntype Amount = Currency @assert({{ this > 0 }})\n\nclass MergeAttrs {\n amount Amount @description(\"In USD\")\n}\n\n// This should be allowed.\ntype MultipleAttrs = int @assert({{ this > 0 }}) @check(gt_ten, {{ this > 10 }})\n\nfunction MergeAliasAttributes(money: int) -> MergeAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer in the specified format:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction ReturnAliasWithMergedAttributes(money: Amount) -> Amount {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction AliasWithMultipleAttrs(money: MultipleAttrs) -> MultipleAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecursiveMapAlias = map\n\nfunction SimpleRecursiveMapAlias(input: RecursiveMapAlias) -> RecursiveMapAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecursiveListAlias = RecursiveListAlias[]\n\nfunction SimpleRecursiveListAlias(input: RecursiveListAlias) -> RecursiveListAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given JSON array:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecAliasOne = RecAliasTwo\ntype RecAliasTwo = RecAliasThree\ntype RecAliasThree = RecAliasOne[]\n\nfunction RecursiveAliasCycle(input: RecAliasOne) -> RecAliasOne {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given JSON array:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype JsonValue = int | string | bool | float | JsonObject | JsonArray\ntype JsonObject = map\ntype JsonArray = JsonValue[]\n\nfunction JsonTypeAliasCycle(input: JsonValue) -> JsonValue {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given input back:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}", "test-files/functions/output/unions.baml": "class UnionTest_ReturnType {\n prop1 string | bool\n prop2 (float | bool)[]\n prop3 (bool[] | int[])\n}\n\nfunction UnionTest_Function(input: string | bool) -> UnionTest_ReturnType {\n client GPT35\n prompt #\"\n Return a JSON blob with this schema: \n {{ctx.output_format}}\n\n JSON:\n \"#\n}\n\ntest UnionTest_Function {\n functions [UnionTest_Function]\n args {\n input \"example input\"\n }\n}\n", "test-files/functions/prompts/no-chat-messages.baml": "\n\nfunction PromptTestClaude(input: string) -> string {\n client Sonnet\n prompt #\"\n Tell me a haiku about {{ input }}\n \"#\n}\n\n\nfunction PromptTestStreaming(input: string) -> string {\n client GPT35\n prompt #\"\n Tell me a short story about {{ input }}\n \"#\n}\n\ntest TestName {\n functions [PromptTestStreaming]\n args {\n input #\"\n hello world\n \"#\n }\n}\n", "test-files/functions/prompts/with-chat-messages.baml": "\nfunction PromptTestOpenAIChat(input: string) -> string {\n client GPT35\n prompt #\"\n {{ _.role(\"system\") }}\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestOpenAIChatNoSystem(input: string) -> string {\n client GPT35\n prompt #\"\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestClaudeChat(input: string) -> string {\n client Claude\n prompt #\"\n {{ _.role(\"system\") }}\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestClaudeChatNoSystem(input: string) -> string {\n client Claude\n prompt #\"\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\ntest TestSystemAndNonSystemChat1 {\n functions [PromptTestClaude, PromptTestOpenAI, PromptTestOpenAIChat, PromptTestOpenAIChatNoSystem, PromptTestClaudeChat, PromptTestClaudeChatNoSystem]\n args {\n input \"cats\"\n }\n}\n\ntest TestSystemAndNonSystemChat2 {\n functions [PromptTestClaude, PromptTestOpenAI, PromptTestOpenAIChat, PromptTestOpenAIChatNoSystem, PromptTestClaudeChat, PromptTestClaudeChatNoSystem]\n args {\n input \"lion\"\n }\n}", diff --git a/integ-tests/python/baml_client/sync_client.py b/integ-tests/python/baml_client/sync_client.py index 50bde725a..e60b62b97 100644 --- a/integ-tests/python/baml_client/sync_client.py +++ b/integ-tests/python/baml_client/sync_client.py @@ -1450,6 +1450,29 @@ def InOutSingleLiteralStringMapKey( ) return cast(Dict[Literal["key"], str], raw.cast_to(types, types)) + def JsonTypeAliasCycle( + self, + input: types.JsonValue, + baml_options: BamlCallOptions = {}, + ) -> types.JsonValue: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.call_function_sync( + "JsonTypeAliasCycle", + { + "input": input, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + return cast(types.JsonValue, raw.cast_to(types, types)) + def LiteralUnionsTest( self, input: str, @@ -4886,6 +4909,36 @@ def InOutSingleLiteralStringMapKey( self.__ctx_manager.get(), ) + def JsonTypeAliasCycle( + self, + input: types.JsonValue, + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlSyncStream[types.JsonValue, types.JsonValue]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function_sync( + "JsonTypeAliasCycle", + { + "input": input, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + return baml_py.BamlSyncStream[types.JsonValue, types.JsonValue]( + raw, + lambda x: cast(types.JsonValue, x.cast_to(types, partial_types)), + lambda x: cast(types.JsonValue, x.cast_to(types, types)), + self.__ctx_manager.get(), + ) + def LiteralUnionsTest( self, input: str, diff --git a/integ-tests/python/baml_client/types.py b/integ-tests/python/baml_client/types.py index 743e918b3..195eb2369 100644 --- a/integ-tests/python/baml_client/types.py +++ b/integ-tests/python/baml_client/types.py @@ -488,3 +488,9 @@ class WithReasoning(BaseModel): RecAliasTwo: TypeAlias = "RecAliasThree" RecAliasThree: TypeAlias = List["RecAliasOne"] + +JsonValue: TypeAlias = Union[int, str, bool, float, "JsonObject", "JsonArray"] + +JsonObject: TypeAlias = Dict[str, "JsonValue"] + +JsonArray: TypeAlias = List["JsonValue"] diff --git a/integ-tests/python/tests/test_functions.py b/integ-tests/python/tests/test_functions.py index a7813f046..8c96406e9 100644 --- a/integ-tests/python/tests/test_functions.py +++ b/integ-tests/python/tests/test_functions.py @@ -329,6 +329,31 @@ async def test_recursive_alias_cycles(self): res = await b.RecursiveAliasCycle([[], [], [[]]]) assert res == [[], [], [[]]] + @pytest.mark.asyncio + async def test_json_type_alias_cycle(self): + data = { + "number": 1, + "string": "test", + "bool": True, + "list": [1, 2, 3], + "object": {"number": 1, "string": "test", "bool": True, "list": [1, 2, 3]}, + "json": { + "number": 1, + "string": "test", + "bool": True, + "list": [1, 2, 3], + "object": { + "number": 1, + "string": "test", + "bool": True, + "list": [1, 2, 3], + }, + }, + } + + res = await b.JsonTypeAliasCycle(data) + assert res == data + class MyCustomClass(NamedArgsSingleClass): date: datetime.datetime diff --git a/integ-tests/ruby/baml_client/client.rb b/integ-tests/ruby/baml_client/client.rb index f122341d7..fb0bb086d 100644 --- a/integ-tests/ruby/baml_client/client.rb +++ b/integ-tests/ruby/baml_client/client.rb @@ -2002,6 +2002,38 @@ def InOutSingleLiteralStringMapKey( (raw.parsed_using_types(Baml::Types)) end + sig { + params( + varargs: T.untyped, + input: T.anything, + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(T.anything) + } + def JsonTypeAliasCycle( + *varargs, + input:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("JsonTypeAliasCycle may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? + raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.call_function( + "JsonTypeAliasCycle", + { + input: input, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + (raw.parsed_using_types(Baml::Types)) + end + sig { params( varargs: T.untyped, @@ -6358,6 +6390,41 @@ def InOutSingleLiteralStringMapKey( ) end + sig { + params( + varargs: T.untyped, + input: T.anything, + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(Baml::BamlStream[T.anything]) + } + def JsonTypeAliasCycle( + *varargs, + input:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("JsonTypeAliasCycle may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? + raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.stream_function( + "JsonTypeAliasCycle", + { + input: input, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + Baml::BamlStream[T.anything, T.anything].new( + ffi_stream: raw, + ctx_manager: @ctx_manager + ) + end + sig { params( varargs: T.untyped, diff --git a/integ-tests/ruby/baml_client/inlined.rb b/integ-tests/ruby/baml_client/inlined.rb index daabe128e..5ce3d619d 100644 --- a/integ-tests/ruby/baml_client/inlined.rb +++ b/integ-tests/ruby/baml_client/inlined.rb @@ -83,7 +83,7 @@ module Inlined "test-files/functions/output/recursive-type-aliases.baml" => "class LinkedListAliasNode {\n value int\n next LinkedListAliasNode?\n}\n\n// Simple alias that points to recursive type.\ntype LinkedListAlias = LinkedListAliasNode\n\nfunction AliasThatPointsToRecursiveType(list: LinkedListAlias) -> LinkedListAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given linked list back:\n \n {{ list }}\n \n {{ ctx.output_format }}\n \"#\n}\n\n// Class that points to an alias that points to a recursive type.\nclass ClassToRecAlias {\n list LinkedListAlias\n}\n\nfunction ClassThatPointsToRecursiveClassThroughAlias(cls: ClassToRecAlias) -> ClassToRecAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given object back:\n \n {{ cls }}\n \n {{ ctx.output_format }}\n \"#\n}\n\n// This is tricky cause this class should be hoisted, but classes and aliases\n// are two different types in the AST. This test will make sure they can interop.\nclass NodeWithAliasIndirection {\n value int\n next NodeIndirection?\n}\n\ntype NodeIndirection = NodeWithAliasIndirection\n\nfunction RecursiveClassWithAliasIndirection(cls: NodeWithAliasIndirection) -> NodeWithAliasIndirection {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given object back:\n \n {{ cls }}\n \n {{ ctx.output_format }}\n \"#\n}\n", "test-files/functions/output/serialization-error.baml" => "class DummyOutput {\n nonce string\n nonce2 string\n @@dynamic\n}\n\nfunction DummyOutputFunction(input: string) -> DummyOutput {\n client GPT35\n prompt #\"\n Say \"hello there\".\n \"#\n}", "test-files/functions/output/string-list.baml" => "function FnOutputStringList(input: string) -> string[] {\n client GPT35\n prompt #\"\n Return a list of strings in json format like [\"string1\", \"string2\", \"string3\"].\n\n JSON:\n \"#\n}\n\ntest FnOutputStringList {\n functions [FnOutputStringList]\n args {\n input \"example input\"\n }\n}\n", - "test-files/functions/output/type-aliases.baml" => "type Primitive = int | string | bool | float\n\ntype List = string[]\n\ntype Graph = map\n\ntype Combination = Primitive | List | Graph\n\nfunction PrimitiveAlias(p: Primitive) -> Primitive {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back: {{ p }}\n \"#\n}\n\nfunction MapAlias(m: Graph) -> Graph {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given Graph back:\n\n {{ m }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction NestedAlias(c: Combination) -> Combination {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back:\n\n {{ c }}\n\n {{ ctx.output_format }}\n \"#\n}\n\n// Test attribute merging.\ntype Currency = int @check(gt_ten, {{ this > 10 }})\ntype Amount = Currency @assert({{ this > 0 }})\n\nclass MergeAttrs {\n amount Amount @description(\"In USD\")\n}\n\n// This should be allowed.\ntype MultipleAttrs = int @assert({{ this > 0 }}) @check(gt_ten, {{ this > 10 }})\n\nfunction MergeAliasAttributes(money: int) -> MergeAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer in the specified format:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction ReturnAliasWithMergedAttributes(money: Amount) -> Amount {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction AliasWithMultipleAttrs(money: MultipleAttrs) -> MultipleAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecursiveMapAlias = map\n\nfunction SimpleRecursiveMapAlias(input: RecursiveMapAlias) -> RecursiveMapAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecursiveListAlias = RecursiveListAlias[]\n\nfunction SimpleRecursiveListAlias(input: RecursiveListAlias) -> RecursiveListAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given JSON array:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecAliasOne = RecAliasTwo\ntype RecAliasTwo = RecAliasThree\ntype RecAliasThree = RecAliasOne[]\n\nfunction RecursiveAliasCycle(input: RecAliasOne) -> RecAliasOne {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given JSON array:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}", + "test-files/functions/output/type-aliases.baml" => "type Primitive = int | string | bool | float\n\ntype List = string[]\n\ntype Graph = map\n\ntype Combination = Primitive | List | Graph\n\nfunction PrimitiveAlias(p: Primitive) -> Primitive {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back: {{ p }}\n \"#\n}\n\nfunction MapAlias(m: Graph) -> Graph {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given Graph back:\n\n {{ m }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction NestedAlias(c: Combination) -> Combination {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back:\n\n {{ c }}\n\n {{ ctx.output_format }}\n \"#\n}\n\n// Test attribute merging.\ntype Currency = int @check(gt_ten, {{ this > 10 }})\ntype Amount = Currency @assert({{ this > 0 }})\n\nclass MergeAttrs {\n amount Amount @description(\"In USD\")\n}\n\n// This should be allowed.\ntype MultipleAttrs = int @assert({{ this > 0 }}) @check(gt_ten, {{ this > 10 }})\n\nfunction MergeAliasAttributes(money: int) -> MergeAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer in the specified format:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction ReturnAliasWithMergedAttributes(money: Amount) -> Amount {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction AliasWithMultipleAttrs(money: MultipleAttrs) -> MultipleAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecursiveMapAlias = map\n\nfunction SimpleRecursiveMapAlias(input: RecursiveMapAlias) -> RecursiveMapAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecursiveListAlias = RecursiveListAlias[]\n\nfunction SimpleRecursiveListAlias(input: RecursiveListAlias) -> RecursiveListAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given JSON array:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecAliasOne = RecAliasTwo\ntype RecAliasTwo = RecAliasThree\ntype RecAliasThree = RecAliasOne[]\n\nfunction RecursiveAliasCycle(input: RecAliasOne) -> RecAliasOne {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given JSON array:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype JsonValue = int | string | bool | float | JsonObject | JsonArray\ntype JsonObject = map\ntype JsonArray = JsonValue[]\n\nfunction JsonTypeAliasCycle(input: JsonValue) -> JsonValue {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given input back:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}", "test-files/functions/output/unions.baml" => "class UnionTest_ReturnType {\n prop1 string | bool\n prop2 (float | bool)[]\n prop3 (bool[] | int[])\n}\n\nfunction UnionTest_Function(input: string | bool) -> UnionTest_ReturnType {\n client GPT35\n prompt #\"\n Return a JSON blob with this schema: \n {{ctx.output_format}}\n\n JSON:\n \"#\n}\n\ntest UnionTest_Function {\n functions [UnionTest_Function]\n args {\n input \"example input\"\n }\n}\n", "test-files/functions/prompts/no-chat-messages.baml" => "\n\nfunction PromptTestClaude(input: string) -> string {\n client Sonnet\n prompt #\"\n Tell me a haiku about {{ input }}\n \"#\n}\n\n\nfunction PromptTestStreaming(input: string) -> string {\n client GPT35\n prompt #\"\n Tell me a short story about {{ input }}\n \"#\n}\n\ntest TestName {\n functions [PromptTestStreaming]\n args {\n input #\"\n hello world\n \"#\n }\n}\n", "test-files/functions/prompts/with-chat-messages.baml" => "\nfunction PromptTestOpenAIChat(input: string) -> string {\n client GPT35\n prompt #\"\n {{ _.role(\"system\") }}\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestOpenAIChatNoSystem(input: string) -> string {\n client GPT35\n prompt #\"\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestClaudeChat(input: string) -> string {\n client Claude\n prompt #\"\n {{ _.role(\"system\") }}\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestClaudeChatNoSystem(input: string) -> string {\n client Claude\n prompt #\"\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\ntest TestSystemAndNonSystemChat1 {\n functions [PromptTestClaude, PromptTestOpenAI, PromptTestOpenAIChat, PromptTestOpenAIChatNoSystem, PromptTestClaudeChat, PromptTestClaudeChatNoSystem]\n args {\n input \"cats\"\n }\n}\n\ntest TestSystemAndNonSystemChat2 {\n functions [PromptTestClaude, PromptTestOpenAI, PromptTestOpenAIChat, PromptTestOpenAIChatNoSystem, PromptTestClaudeChat, PromptTestClaudeChatNoSystem]\n args {\n input \"lion\"\n }\n}", diff --git a/integ-tests/typescript/baml_client/async_client.ts b/integ-tests/typescript/baml_client/async_client.ts index 78e513c62..8fdcf1d81 100644 --- a/integ-tests/typescript/baml_client/async_client.ts +++ b/integ-tests/typescript/baml_client/async_client.ts @@ -1568,6 +1568,31 @@ export class BamlAsyncClient { } } + async JsonTypeAliasCycle( + input: JsonValue, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): Promise { + try { + const raw = await this.runtime.callFunction( + "JsonTypeAliasCycle", + { + "input": input + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as JsonValue + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + async LiteralUnionsTest( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } @@ -5312,6 +5337,39 @@ class BamlStreamClient { } } + JsonTypeAliasCycle( + input: JsonValue, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): BamlStream, JsonValue> { + try { + const raw = this.runtime.streamFunction( + "JsonTypeAliasCycle", + { + "input": input + }, + undefined, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return new BamlStream, JsonValue>( + raw, + (a): a is RecursivePartialNull => a, + (a): a is JsonValue => a, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + ) + } catch (error) { + if (error instanceof Error) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } + } + throw error; + } + } + LiteralUnionsTest( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } diff --git a/integ-tests/typescript/baml_client/inlinedbaml.ts b/integ-tests/typescript/baml_client/inlinedbaml.ts index 2c818f1c4..56ad031f6 100644 --- a/integ-tests/typescript/baml_client/inlinedbaml.ts +++ b/integ-tests/typescript/baml_client/inlinedbaml.ts @@ -84,7 +84,7 @@ const fileMap = { "test-files/functions/output/recursive-type-aliases.baml": "class LinkedListAliasNode {\n value int\n next LinkedListAliasNode?\n}\n\n// Simple alias that points to recursive type.\ntype LinkedListAlias = LinkedListAliasNode\n\nfunction AliasThatPointsToRecursiveType(list: LinkedListAlias) -> LinkedListAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given linked list back:\n \n {{ list }}\n \n {{ ctx.output_format }}\n \"#\n}\n\n// Class that points to an alias that points to a recursive type.\nclass ClassToRecAlias {\n list LinkedListAlias\n}\n\nfunction ClassThatPointsToRecursiveClassThroughAlias(cls: ClassToRecAlias) -> ClassToRecAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given object back:\n \n {{ cls }}\n \n {{ ctx.output_format }}\n \"#\n}\n\n// This is tricky cause this class should be hoisted, but classes and aliases\n// are two different types in the AST. This test will make sure they can interop.\nclass NodeWithAliasIndirection {\n value int\n next NodeIndirection?\n}\n\ntype NodeIndirection = NodeWithAliasIndirection\n\nfunction RecursiveClassWithAliasIndirection(cls: NodeWithAliasIndirection) -> NodeWithAliasIndirection {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given object back:\n \n {{ cls }}\n \n {{ ctx.output_format }}\n \"#\n}\n", "test-files/functions/output/serialization-error.baml": "class DummyOutput {\n nonce string\n nonce2 string\n @@dynamic\n}\n\nfunction DummyOutputFunction(input: string) -> DummyOutput {\n client GPT35\n prompt #\"\n Say \"hello there\".\n \"#\n}", "test-files/functions/output/string-list.baml": "function FnOutputStringList(input: string) -> string[] {\n client GPT35\n prompt #\"\n Return a list of strings in json format like [\"string1\", \"string2\", \"string3\"].\n\n JSON:\n \"#\n}\n\ntest FnOutputStringList {\n functions [FnOutputStringList]\n args {\n input \"example input\"\n }\n}\n", - "test-files/functions/output/type-aliases.baml": "type Primitive = int | string | bool | float\n\ntype List = string[]\n\ntype Graph = map\n\ntype Combination = Primitive | List | Graph\n\nfunction PrimitiveAlias(p: Primitive) -> Primitive {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back: {{ p }}\n \"#\n}\n\nfunction MapAlias(m: Graph) -> Graph {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given Graph back:\n\n {{ m }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction NestedAlias(c: Combination) -> Combination {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back:\n\n {{ c }}\n\n {{ ctx.output_format }}\n \"#\n}\n\n// Test attribute merging.\ntype Currency = int @check(gt_ten, {{ this > 10 }})\ntype Amount = Currency @assert({{ this > 0 }})\n\nclass MergeAttrs {\n amount Amount @description(\"In USD\")\n}\n\n// This should be allowed.\ntype MultipleAttrs = int @assert({{ this > 0 }}) @check(gt_ten, {{ this > 10 }})\n\nfunction MergeAliasAttributes(money: int) -> MergeAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer in the specified format:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction ReturnAliasWithMergedAttributes(money: Amount) -> Amount {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction AliasWithMultipleAttrs(money: MultipleAttrs) -> MultipleAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecursiveMapAlias = map\n\nfunction SimpleRecursiveMapAlias(input: RecursiveMapAlias) -> RecursiveMapAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecursiveListAlias = RecursiveListAlias[]\n\nfunction SimpleRecursiveListAlias(input: RecursiveListAlias) -> RecursiveListAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given JSON array:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecAliasOne = RecAliasTwo\ntype RecAliasTwo = RecAliasThree\ntype RecAliasThree = RecAliasOne[]\n\nfunction RecursiveAliasCycle(input: RecAliasOne) -> RecAliasOne {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given JSON array:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}", + "test-files/functions/output/type-aliases.baml": "type Primitive = int | string | bool | float\n\ntype List = string[]\n\ntype Graph = map\n\ntype Combination = Primitive | List | Graph\n\nfunction PrimitiveAlias(p: Primitive) -> Primitive {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back: {{ p }}\n \"#\n}\n\nfunction MapAlias(m: Graph) -> Graph {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given Graph back:\n\n {{ m }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction NestedAlias(c: Combination) -> Combination {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back:\n\n {{ c }}\n\n {{ ctx.output_format }}\n \"#\n}\n\n// Test attribute merging.\ntype Currency = int @check(gt_ten, {{ this > 10 }})\ntype Amount = Currency @assert({{ this > 0 }})\n\nclass MergeAttrs {\n amount Amount @description(\"In USD\")\n}\n\n// This should be allowed.\ntype MultipleAttrs = int @assert({{ this > 0 }}) @check(gt_ten, {{ this > 10 }})\n\nfunction MergeAliasAttributes(money: int) -> MergeAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer in the specified format:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction ReturnAliasWithMergedAttributes(money: Amount) -> Amount {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction AliasWithMultipleAttrs(money: MultipleAttrs) -> MultipleAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecursiveMapAlias = map\n\nfunction SimpleRecursiveMapAlias(input: RecursiveMapAlias) -> RecursiveMapAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecursiveListAlias = RecursiveListAlias[]\n\nfunction SimpleRecursiveListAlias(input: RecursiveListAlias) -> RecursiveListAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given JSON array:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecAliasOne = RecAliasTwo\ntype RecAliasTwo = RecAliasThree\ntype RecAliasThree = RecAliasOne[]\n\nfunction RecursiveAliasCycle(input: RecAliasOne) -> RecAliasOne {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given JSON array:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype JsonValue = int | string | bool | float | JsonObject | JsonArray\ntype JsonObject = map\ntype JsonArray = JsonValue[]\n\nfunction JsonTypeAliasCycle(input: JsonValue) -> JsonValue {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given input back:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}", "test-files/functions/output/unions.baml": "class UnionTest_ReturnType {\n prop1 string | bool\n prop2 (float | bool)[]\n prop3 (bool[] | int[])\n}\n\nfunction UnionTest_Function(input: string | bool) -> UnionTest_ReturnType {\n client GPT35\n prompt #\"\n Return a JSON blob with this schema: \n {{ctx.output_format}}\n\n JSON:\n \"#\n}\n\ntest UnionTest_Function {\n functions [UnionTest_Function]\n args {\n input \"example input\"\n }\n}\n", "test-files/functions/prompts/no-chat-messages.baml": "\n\nfunction PromptTestClaude(input: string) -> string {\n client Sonnet\n prompt #\"\n Tell me a haiku about {{ input }}\n \"#\n}\n\n\nfunction PromptTestStreaming(input: string) -> string {\n client GPT35\n prompt #\"\n Tell me a short story about {{ input }}\n \"#\n}\n\ntest TestName {\n functions [PromptTestStreaming]\n args {\n input #\"\n hello world\n \"#\n }\n}\n", "test-files/functions/prompts/with-chat-messages.baml": "\nfunction PromptTestOpenAIChat(input: string) -> string {\n client GPT35\n prompt #\"\n {{ _.role(\"system\") }}\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestOpenAIChatNoSystem(input: string) -> string {\n client GPT35\n prompt #\"\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestClaudeChat(input: string) -> string {\n client Claude\n prompt #\"\n {{ _.role(\"system\") }}\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestClaudeChatNoSystem(input: string) -> string {\n client Claude\n prompt #\"\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\ntest TestSystemAndNonSystemChat1 {\n functions [PromptTestClaude, PromptTestOpenAI, PromptTestOpenAIChat, PromptTestOpenAIChatNoSystem, PromptTestClaudeChat, PromptTestClaudeChatNoSystem]\n args {\n input \"cats\"\n }\n}\n\ntest TestSystemAndNonSystemChat2 {\n functions [PromptTestClaude, PromptTestOpenAI, PromptTestOpenAIChat, PromptTestOpenAIChatNoSystem, PromptTestClaudeChat, PromptTestClaudeChatNoSystem]\n args {\n input \"lion\"\n }\n}", diff --git a/integ-tests/typescript/baml_client/sync_client.ts b/integ-tests/typescript/baml_client/sync_client.ts index 50c93b176..b5fa48d63 100644 --- a/integ-tests/typescript/baml_client/sync_client.ts +++ b/integ-tests/typescript/baml_client/sync_client.ts @@ -1568,6 +1568,31 @@ export class BamlSyncClient { } } + JsonTypeAliasCycle( + input: JsonValue, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): JsonValue { + try { + const raw = this.runtime.callFunctionSync( + "JsonTypeAliasCycle", + { + "input": input + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as JsonValue + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + LiteralUnionsTest( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } From c4e8b85934e88ffe0273bd985a7320f95e10408f Mon Sep 17 00:00:00 2001 From: Antonio Sarosi Date: Wed, 18 Dec 2024 03:49:07 +0100 Subject: [PATCH 10/19] Fixed `subtype`, `coerce` still doesn't work --- .../baml-core/src/ir/ir_helpers/mod.rs | 274 ++++++++++++++++-- .../src/ir/ir_helpers/to_baml_arg.rs | 6 +- .../baml-lib/baml-types/src/field_type/mod.rs | 181 ------------ .../jsonish/src/tests/test_aliases.rs | 2 +- .../functions/output/type-aliases.baml | 2 +- 5 files changed, 251 insertions(+), 214 deletions(-) diff --git a/engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs b/engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs index 40d932dce..6c5e71a9d 100644 --- a/engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs +++ b/engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs @@ -57,6 +57,7 @@ pub trait IRHelper { value: BamlValue, field_type: FieldType, ) -> Result>; + fn is_subtype(&self, base: &FieldType, other: &FieldType) -> bool; fn distribute_constraints<'a>( &'a self, field_type: &'a FieldType, @@ -203,6 +204,124 @@ impl IRHelper for IntermediateRepr { } } + /// BAML does not support class-based subtyping. Nonetheless some builtin + /// BAML types are subtypes of others, and we need to be able to test this + /// when checking the types of values. + /// + /// For examples of pairs of types and their subtyping relationship, see + /// this module's test suite. + /// + /// Consider renaming this to `is_assignable`. + fn is_subtype(&self, base: &FieldType, other: &FieldType) -> bool { + if base == other { + return true; + } + + if let FieldType::Union(items) = other { + if items.iter().any(|item| self.is_subtype(base, item)) { + return true; + } + } + + match (base, other) { + // TODO: O(n) + (FieldType::RecursiveTypeAlias(name), _) => self + .structural_recursive_alias_cycles() + .iter() + .any(|cycle| match cycle.get(name) { + Some(target) => self.is_subtype(target, other), + None => false, + }), + (_, FieldType::RecursiveTypeAlias(name)) => self + .structural_recursive_alias_cycles() + .iter() + .any(|cycle| match cycle.get(name) { + Some(target) => self.is_subtype(base, target), + None => false, + }), + + (FieldType::Primitive(TypeValue::Null), FieldType::Optional(_)) => true, + (FieldType::Optional(base_item), FieldType::Optional(other_item)) => { + self.is_subtype(base_item, other_item) + } + (_, FieldType::Optional(t)) => self.is_subtype(base, t), + (FieldType::Optional(_), _) => false, + + // Handle types that nest other types. + (FieldType::List(base_item), FieldType::List(other_item)) => { + self.is_subtype(&base_item, other_item) + } + (FieldType::List(_), _) => false, + + (FieldType::Map(base_k, base_v), FieldType::Map(other_k, other_v)) => { + self.is_subtype(other_k, base_k) && self.is_subtype(&**base_v, other_v) + } + (FieldType::Map(_, _), _) => false, + + ( + FieldType::Constrained { + base: constrained_base, + constraints: base_constraints, + }, + FieldType::Constrained { + base: other_base, + constraints: other_constraints, + }, + ) => { + self.is_subtype(constrained_base, other_base) + && base_constraints == other_constraints + } + ( + FieldType::Constrained { + base: contrained_base, + .. + }, + _, + ) => self.is_subtype(contrained_base, other), + ( + _, + FieldType::Constrained { + base: constrained_base, + .. + }, + ) => self.is_subtype(base, constrained_base), + + (FieldType::Literal(LiteralValue::Bool(_)), FieldType::Primitive(TypeValue::Bool)) => { + true + } + (FieldType::Literal(LiteralValue::Bool(_)), _) => { + self.is_subtype(base, &FieldType::Primitive(TypeValue::Bool)) + } + (FieldType::Literal(LiteralValue::Int(_)), FieldType::Primitive(TypeValue::Int)) => { + true + } + (FieldType::Literal(LiteralValue::Int(_)), _) => { + self.is_subtype(base, &FieldType::Primitive(TypeValue::Int)) + } + ( + FieldType::Literal(LiteralValue::String(_)), + FieldType::Primitive(TypeValue::String), + ) => true, + (FieldType::Literal(LiteralValue::String(_)), _) => { + self.is_subtype(base, &FieldType::Primitive(TypeValue::String)) + } + + (FieldType::Union(items), _) => items.iter().all(|item| self.is_subtype(item, other)), + + (FieldType::Tuple(base_items), FieldType::Tuple(other_items)) => { + base_items.len() == other_items.len() + && base_items + .iter() + .zip(other_items) + .all(|(base_item, other_item)| self.is_subtype(base_item, other_item)) + } + (FieldType::Tuple(_), _) => false, + (FieldType::Primitive(_), _) => false, + (FieldType::Enum(_), _) => false, + (FieldType::Class(_), _) => false, + } + } + /// For some `BamlValue` with type `FieldType`, walk the structure of both the value /// and the type simultaneously, associating each node in the `BamlValue` with its /// `FieldType`. @@ -216,40 +335,38 @@ impl IRHelper for IntermediateRepr { let literal_type = FieldType::Literal(LiteralValue::String(s.clone())); let primitive_type = FieldType::Primitive(TypeValue::String); - if literal_type.is_subtype_of(&field_type) - || primitive_type.is_subtype_of(&field_type) + if self.is_subtype(&literal_type, &field_type) + || self.is_subtype(&primitive_type, &field_type) { return Ok(BamlValueWithMeta::String(s, field_type)); } anyhow::bail!("Could not unify String with {:?}", field_type) } - BamlValue::Int(i) - if FieldType::Literal(LiteralValue::Int(i)).is_subtype_of(&field_type) => - { - Ok(BamlValueWithMeta::Int(i, field_type)) - } - BamlValue::Int(i) - if FieldType::Primitive(TypeValue::Int).is_subtype_of(&field_type) => - { - Ok(BamlValueWithMeta::Int(i, field_type)) - } BamlValue::Int(i) => { + let literal_type = FieldType::Literal(LiteralValue::Int(i)); + let primitive_type = FieldType::Primitive(TypeValue::Int); + + if self.is_subtype(&literal_type, &field_type) + || self.is_subtype(&primitive_type, &field_type) + { + return Ok(BamlValueWithMeta::Int(i, field_type)); + } anyhow::bail!("Could not unify Int with {:?}", field_type) } - BamlValue::Float(f) - if FieldType::Primitive(TypeValue::Float).is_subtype_of(&field_type) => - { - Ok(BamlValueWithMeta::Float(f, field_type)) + BamlValue::Float(f) => { + if self.is_subtype(&FieldType::Primitive(TypeValue::Float), &field_type) { + return Ok(BamlValueWithMeta::Float(f, field_type)); + } + anyhow::bail!("Could not unify Float with {:?}", field_type) } - BamlValue::Float(_) => anyhow::bail!("Could not unify Float with {:?}", field_type), BamlValue::Bool(b) => { let literal_type = FieldType::Literal(LiteralValue::Bool(b)); let primitive_type = FieldType::Primitive(TypeValue::Bool); - if literal_type.is_subtype_of(&field_type) - || primitive_type.is_subtype_of(&field_type) + if self.is_subtype(&literal_type, &field_type) + || self.is_subtype(&primitive_type, &field_type) { Ok(BamlValueWithMeta::Bool(b, field_type)) } else { @@ -257,7 +374,9 @@ impl IRHelper for IntermediateRepr { } } - BamlValue::Null if FieldType::Primitive(TypeValue::Null).is_subtype_of(&field_type) => { + BamlValue::Null + if self.is_subtype(&FieldType::Primitive(TypeValue::Null), &field_type) => + { Ok(BamlValueWithMeta::Null(field_type)) } BamlValue::Null => anyhow::bail!("Could not unify Null with {:?}", field_type), @@ -287,7 +406,7 @@ impl IRHelper for IntermediateRepr { Box::new(item_type.clone()), ); - if !map_type.is_subtype_of(&field_type) { + if !self.is_subtype(&map_type, &field_type) { anyhow::bail!("Could not unify {:?} with {:?}", map_type, field_type); } @@ -321,7 +440,7 @@ impl IRHelper for IntermediateRepr { Some(item_type) => { let list_type = FieldType::List(Box::new(item_type.clone())); - if !list_type.is_subtype_of(&field_type) { + if !self.is_subtype(&list_type, &field_type) { anyhow::bail!("Could not unify {:?} with {:?}", list_type, field_type); } else { let mapped_items: Vec> = items @@ -335,15 +454,17 @@ impl IRHelper for IntermediateRepr { } BamlValue::Media(m) - if FieldType::Primitive(TypeValue::Media(m.media_type)) - .is_subtype_of(&field_type) => + if self.is_subtype( + &FieldType::Primitive(TypeValue::Media(m.media_type)), + &field_type, + ) => { Ok(BamlValueWithMeta::Media(m, field_type)) } BamlValue::Media(_) => anyhow::bail!("Could not unify Media with {:?}", field_type), BamlValue::Enum(name, val) => { - if FieldType::Enum(name.clone()).is_subtype_of(&field_type) { + if self.is_subtype(&FieldType::Enum(name.clone()), &field_type) { Ok(BamlValueWithMeta::Enum(name, val, field_type)) } else { anyhow::bail!("Could not unify Enum {} with {:?}", name, field_type) @@ -351,7 +472,7 @@ impl IRHelper for IntermediateRepr { } BamlValue::Class(name, fields) => { - if !FieldType::Class(name.clone()).is_subtype_of(&field_type) { + if !self.is_subtype(&FieldType::Class(name.clone()), &field_type) { anyhow::bail!("Could not unify Class {} with {:?}", name, field_type); } else { let class_type = &self.find_class(&name)?.item.elem; @@ -794,3 +915,104 @@ mod tests { assert_eq!(constraints, expected_constraints); } } + +// TODO: Copy pasted from baml-lib/baml-types/src/field_type/mod.rs and poorly +// refactored to match the `is_subtype` changes. Do something with this. +#[cfg(test)] +mod subtype_tests { + use baml_types::BamlMediaType; + use repr::make_test_ir; + + use super::*; + + fn mk_int() -> FieldType { + FieldType::Primitive(TypeValue::Int) + } + fn mk_bool() -> FieldType { + FieldType::Primitive(TypeValue::Bool) + } + fn mk_str() -> FieldType { + FieldType::Primitive(TypeValue::String) + } + + fn mk_optional(ft: FieldType) -> FieldType { + FieldType::Optional(Box::new(ft)) + } + + fn mk_list(ft: FieldType) -> FieldType { + FieldType::List(Box::new(ft)) + } + + fn mk_tuple(ft: Vec) -> FieldType { + FieldType::Tuple(ft) + } + fn mk_union(ft: Vec) -> FieldType { + FieldType::Union(ft) + } + fn mk_str_map(ft: FieldType) -> FieldType { + FieldType::Map(Box::new(mk_str()), Box::new(ft)) + } + + fn ir() -> IntermediateRepr { + make_test_ir("").unwrap() + } + + #[test] + fn subtype_trivial() { + assert!(ir().is_subtype(&mk_int(), &mk_int())) + } + + #[test] + fn subtype_union() { + let i = mk_int(); + let u = mk_union(vec![mk_int(), mk_str()]); + assert!(ir().is_subtype(&i, &u)); + assert!(!ir().is_subtype(&u, &i)); + + let u3 = mk_union(vec![mk_int(), mk_bool(), mk_str()]); + assert!(ir().is_subtype(&i, &u3)); + assert!(ir().is_subtype(&u, &u3)); + assert!(!ir().is_subtype(&u3, &u)); + } + + #[test] + fn subtype_optional() { + let i = mk_int(); + let o = mk_optional(mk_int()); + assert!(ir().is_subtype(&i, &o)); + assert!(!ir().is_subtype(&o, &i)); + } + + #[test] + fn subtype_list() { + let l_i = mk_list(mk_int()); + let l_o = mk_list(mk_optional(mk_int())); + assert!(ir().is_subtype(&l_i, &l_o)); + assert!(!ir().is_subtype(&l_o, &l_i)); + } + + #[test] + fn subtype_tuple() { + let x = mk_tuple(vec![mk_int(), mk_optional(mk_int())]); + let y = mk_tuple(vec![mk_int(), mk_int()]); + assert!(ir().is_subtype(&y, &x)); + assert!(!ir().is_subtype(&x, &y)); + } + + #[test] + fn subtype_map_of_list_of_unions() { + let x = mk_str_map(mk_list(FieldType::Class("Foo".to_string()))); + let y = mk_str_map(mk_list(mk_union(vec![ + mk_str(), + mk_int(), + FieldType::Class("Foo".to_string()), + ]))); + assert!(ir().is_subtype(&x, &y)); + } + + #[test] + fn subtype_media() { + let x = FieldType::Primitive(TypeValue::Media(BamlMediaType::Audio)); + assert!(ir().is_subtype(&x, &x)); + } +} diff --git a/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs b/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs index 7f529cd72..a13c5cc49 100644 --- a/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs +++ b/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs @@ -43,9 +43,6 @@ impl ArgCoercer { value: &BamlValue, // original value passed in by user scope: &mut ScopeStack, ) -> Result { - eprintln!("coerce_arg: {value:?} -> {field_type:?}"); - eprintln!("scope: {scope}\n"); - let value = match ir.distribute_constraints(field_type) { (FieldType::Primitive(t), _) => match t { TypeValue::String if matches!(value, BamlValue::String(_)) => Ok(value.clone()), @@ -331,7 +328,6 @@ impl ArgCoercer { let mut scope = ScopeStack::new(); if first_good_result.is_err() { let result = self.coerce_arg(ir, option, value, &mut scope); - eprintln!("union inner scope scope: {scope}\n"); if !scope.has_errors() && first_good_result.is_err() { first_good_result = result } @@ -466,7 +462,7 @@ mod tests { fn test_mutually_recursive_aliases() { let ir = make_test_ir( r##" -type JsonValue = int | string | bool | float | JsonObject | JsonArray +type JsonValue = int | bool | float | string | JsonArray | JsonObject type JsonObject = map type JsonArray = JsonValue[] "##, diff --git a/engine/baml-lib/baml-types/src/field_type/mod.rs b/engine/baml-lib/baml-types/src/field_type/mod.rs index 7008f26a5..52f59fae0 100644 --- a/engine/baml-lib/baml-types/src/field_type/mod.rs +++ b/engine/baml-lib/baml-types/src/field_type/mod.rs @@ -160,185 +160,4 @@ impl FieldType { _ => false, } } - - /// BAML does not support class-based subtyping. Nonetheless some builtin - /// BAML types are subtypes of others, and we need to be able to test this - /// when checking the types of values. - /// - /// For examples of pairs of types and their subtyping relationship, see - /// this module's test suite. - /// - /// Consider renaming this to `is_assignable_to`. - pub fn is_subtype_of(&self, other: &FieldType) -> bool { - if self == other { - return true; - } - - if let FieldType::Union(items) = other { - if items.iter().any(|item| self.is_subtype_of(item)) { - return true; - } - } - - match (self, other) { - (FieldType::Primitive(TypeValue::Null), FieldType::Optional(_)) => true, - (FieldType::Optional(self_item), FieldType::Optional(other_item)) => { - self_item.is_subtype_of(other_item) - } - (_, FieldType::Optional(t)) => self.is_subtype_of(t), - (FieldType::Optional(_), _) => false, - - // Handle types that nest other types. - (FieldType::List(self_item), FieldType::List(other_item)) => { - self_item.is_subtype_of(other_item) - } - (FieldType::List(_), _) => false, - - (FieldType::Map(self_k, self_v), FieldType::Map(other_k, other_v)) => { - other_k.is_subtype_of(self_k) && (**self_v).is_subtype_of(other_v) - } - (FieldType::Map(_, _), _) => false, - - ( - FieldType::Constrained { - base: self_base, - constraints: self_cs, - }, - FieldType::Constrained { - base: other_base, - constraints: other_cs, - }, - ) => self_base.is_subtype_of(other_base) && self_cs == other_cs, - (FieldType::Constrained { base, .. }, _) => base.is_subtype_of(other), - (_, FieldType::Constrained { base, .. }) => self.is_subtype_of(base), - (FieldType::Literal(LiteralValue::Bool(_)), FieldType::Primitive(TypeValue::Bool)) => { - true - } - (FieldType::Literal(LiteralValue::Bool(_)), _) => { - self.is_subtype_of(&FieldType::Primitive(TypeValue::Bool)) - } - (FieldType::Literal(LiteralValue::Int(_)), FieldType::Primitive(TypeValue::Int)) => { - true - } - (FieldType::Literal(LiteralValue::Int(_)), _) => { - self.is_subtype_of(&FieldType::Primitive(TypeValue::Int)) - } - ( - FieldType::Literal(LiteralValue::String(_)), - FieldType::Primitive(TypeValue::String), - ) => true, - (FieldType::Literal(LiteralValue::String(_)), _) => { - self.is_subtype_of(&FieldType::Primitive(TypeValue::String)) - } - - (FieldType::Union(self_items), _) => self_items - .iter() - .all(|self_item| self_item.is_subtype_of(other)), - - (FieldType::Tuple(self_items), FieldType::Tuple(other_items)) => { - self_items.len() == other_items.len() - && self_items - .iter() - .zip(other_items) - .all(|(self_item, other_item)| self_item.is_subtype_of(other_item)) - } - (FieldType::Tuple(_), _) => false, - (FieldType::Primitive(_), _) => false, - (FieldType::Enum(_), _) => false, - (FieldType::Class(_), _) => false, - (FieldType::RecursiveTypeAlias(_), _) => false, - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - fn mk_int() -> FieldType { - FieldType::Primitive(TypeValue::Int) - } - fn mk_bool() -> FieldType { - FieldType::Primitive(TypeValue::Bool) - } - fn mk_str() -> FieldType { - FieldType::Primitive(TypeValue::String) - } - - fn mk_optional(ft: FieldType) -> FieldType { - FieldType::Optional(Box::new(ft)) - } - - fn mk_list(ft: FieldType) -> FieldType { - FieldType::List(Box::new(ft)) - } - - fn mk_tuple(ft: Vec) -> FieldType { - FieldType::Tuple(ft) - } - fn mk_union(ft: Vec) -> FieldType { - FieldType::Union(ft) - } - fn mk_str_map(ft: FieldType) -> FieldType { - FieldType::Map(Box::new(mk_str()), Box::new(ft)) - } - - #[test] - fn subtype_trivial() { - assert!(mk_int().is_subtype_of(&mk_int())) - } - - #[test] - fn subtype_union() { - let i = mk_int(); - let u = mk_union(vec![mk_int(), mk_str()]); - assert!(i.is_subtype_of(&u)); - assert!(!u.is_subtype_of(&i)); - - let u3 = mk_union(vec![mk_int(), mk_bool(), mk_str()]); - assert!(i.is_subtype_of(&u3)); - assert!(u.is_subtype_of(&u3)); - assert!(!u3.is_subtype_of(&u)); - } - - #[test] - fn subtype_optional() { - let i = mk_int(); - let o = mk_optional(mk_int()); - assert!(i.is_subtype_of(&o)); - assert!(!o.is_subtype_of(&i)); - } - - #[test] - fn subtype_list() { - let l_i = mk_list(mk_int()); - let l_o = mk_list(mk_optional(mk_int())); - assert!(l_i.is_subtype_of(&l_o)); - assert!(!l_o.is_subtype_of(&l_i)); - } - - #[test] - fn subtype_tuple() { - let x = mk_tuple(vec![mk_int(), mk_optional(mk_int())]); - let y = mk_tuple(vec![mk_int(), mk_int()]); - assert!(y.is_subtype_of(&x)); - assert!(!x.is_subtype_of(&y)); - } - - #[test] - fn subtype_map_of_list_of_unions() { - let x = mk_str_map(mk_list(FieldType::Class("Foo".to_string()))); - let y = mk_str_map(mk_list(mk_union(vec![ - mk_str(), - mk_int(), - FieldType::Class("Foo".to_string()), - ]))); - assert!(x.is_subtype_of(&y)); - } - - #[test] - fn subtype_media() { - let x = FieldType::Primitive(TypeValue::Media(BamlMediaType::Audio)); - assert!(x.is_subtype_of(&x)); - } } diff --git a/engine/baml-lib/jsonish/src/tests/test_aliases.rs b/engine/baml-lib/jsonish/src/tests/test_aliases.rs index 35baa352f..97f7f55ef 100644 --- a/engine/baml-lib/jsonish/src/tests/test_aliases.rs +++ b/engine/baml-lib/jsonish/src/tests/test_aliases.rs @@ -60,7 +60,7 @@ type JsonValue = int | string | bool | JsonValue[] | map test_deserializer!( test_complex_recursive_alias, r#" -type JsonValue = int | string | bool | JsonValue[] | map +type JsonValue = int | bool | string | JsonValue[] | map "#, r#" { diff --git a/integ-tests/baml_src/test-files/functions/output/type-aliases.baml b/integ-tests/baml_src/test-files/functions/output/type-aliases.baml index 216bfc2ec..2e407303d 100644 --- a/integ-tests/baml_src/test-files/functions/output/type-aliases.baml +++ b/integ-tests/baml_src/test-files/functions/output/type-aliases.baml @@ -133,4 +133,4 @@ function JsonTypeAliasCycle(input: JsonValue) -> JsonValue { {{ ctx.output_format }} "# -} \ No newline at end of file +} From d6b1e9e1ff6bbf3d8148927487789bfd0e547e74 Mon Sep 17 00:00:00 2001 From: Antonio Sarosi Date: Wed, 18 Dec 2024 14:01:10 +0100 Subject: [PATCH 11/19] Add integ tests for TS --- .../typescript/tests/integ-tests.test.ts | 52 +++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/integ-tests/typescript/tests/integ-tests.test.ts b/integ-tests/typescript/tests/integ-tests.test.ts index 8eaeb690b..ab43a50c3 100644 --- a/integ-tests/typescript/tests/integ-tests.test.ts +++ b/integ-tests/typescript/tests/integ-tests.test.ts @@ -181,6 +181,58 @@ describe('Integ tests', () => { const res = await b.RecursiveClassWithAliasIndirection({ value: 1, next: { value: 2, next: null } }) expect(res).toEqual({ value: 1, next: { value: 2, next: null } }) }) + + it('merge alias attributes', async () => { + const res = await b.MergeAliasAttributes(123) + expect(res.amount.value).toEqual(123) + expect(res.amount.checks['gt_ten'].status).toEqual('succeeded') + }) + + it('return alias with merged attrs', async () => { + const res = await b.ReturnAliasWithMergedAttributes(123) + expect(res.value).toEqual(123) + expect(res.checks['gt_ten'].status).toEqual('succeeded') + }) + + it('alias with multiple attrs', async () => { + const res = await b.AliasWithMultipleAttrs(123) + expect(res.value).toEqual(123) + expect(res.checks['gt_ten'].status).toEqual('succeeded') + }) + + it('simple recursive map alias', async () => { + const res = await b.SimpleRecursiveMapAlias({ one: { two: { three: {} } } }) + expect(res).toEqual({ one: { two: { three: {} } } }) + }) + + it('simple recursive map alias', async () => { + const res = await b.SimpleRecursiveListAlias([[], [], [[]]]) + expect(res).toEqual([[], [], [[]]]) + }) + + it('recursive alias cycles', async () => { + const res = await b.RecursiveAliasCycle([[], [], [[]]]) + expect(res).toEqual([[], [], [[]]]) + }) + + it('json type alias cycle', async () => { + const data = { + number: 1, + string: 'test', + bool: true, + list: [1, 2, 3], + object: { number: 1, string: 'test', bool: true, list: [1, 2, 3] }, + json: { + number: 1, + string: 'test', + bool: true, + list: [1, 2, 3], + object: { number: 1, string: 'test', bool: true, list: [1, 2, 3] }, + }, + } + const res = await b.JsonTypeAliasCycle(data) + expect(res).toEqual(data) + }) }) it('should work for all outputs', async () => { From c5267b50a7fe8a0ed7143ab8b51d5359e438eab0 Mon Sep 17 00:00:00 2001 From: Antonio Sarosi Date: Wed, 18 Dec 2024 14:40:29 +0100 Subject: [PATCH 12/19] Remove recursion debug limit --- .../jsonish/src/deserializer/coercer/field_type.rs | 7 ------- integ-tests/typescript/test-report.html | 10 +++++++--- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs b/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs index ab4d64cda..d0f984d1e 100644 --- a/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs +++ b/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs @@ -27,13 +27,6 @@ impl TypeCoercer for FieldType { target: &FieldType, value: Option<&crate::jsonish::Value>, ) -> Result { - unsafe { - LIMIT += 1; - if LIMIT > 500 { - panic!("Stack Overflow Bruh {}", LIMIT); - } - } - match value { Some(crate::jsonish::Value::AnyOf(candidates, primitive)) => { log::debug!( diff --git a/integ-tests/typescript/test-report.html b/integ-tests/typescript/test-report.html index ea171ca6e..38bc00df5 100644 --- a/integ-tests/typescript/test-report.html +++ b/integ-tests/typescript/test-report.html @@ -257,9 +257,13 @@ font-size: 1rem; padding: 0 0.5rem; } -

Test Report

Started: 2024-12-09 16:43:58
Suites (1)
0 passed
1 failed
0 pending
Tests (73)
71 passed
2 failed
0 pending
Integ tests > should work for all inputs
single bool
passed
0.64s
Integ tests > should work for all inputs
single string list
passed
0.718s
Integ tests > should work for all inputs
return literal union
passed
0.462s
Integ tests > should work for all inputs
single class
passed
0.469s
Integ tests > should work for all inputs
multiple classes
passed
0.464s
Integ tests > should work for all inputs
single enum list
passed
0.443s
Integ tests > should work for all inputs
single float
passed
0.408s
Integ tests > should work for all inputs
single int
passed
0.398s
Integ tests > should work for all inputs
single literal int
passed
0.428s
Integ tests > should work for all inputs
single literal bool
passed
0.401s
Integ tests > should work for all inputs
single literal string
passed
0.33s
Integ tests > should work for all inputs
single class with literal prop
passed
0.489s
Integ tests > should work for all inputs
single class with literal union prop
passed
1.094s
Integ tests > should work for all inputs
single optional string
passed
0.378s
Integ tests > should work for all inputs
single map string to string
passed
0.676s
Integ tests > should work for all inputs
single map string to class
passed
0.819s
Integ tests > should work for all inputs
single map string to map
passed
0.505s
Integ tests > should work for all inputs
enum key in map
passed
0.775s
Integ tests > should work for all inputs
literal string union key in map
passed
0.519s
Integ tests > should work for all inputs
single literal string key in map
passed
0.603s
Integ tests > should work for all inputs
primitive union alias
passed
0.413s
Integ tests > should work for all inputs
map alias
passed
0.869s
Integ tests > should work for all inputs
alias union
passed
1.331s
Integ tests > should work for all inputs
alias pointing to recursive class
passed
0.609s
Integ tests > should work for all inputs
class pointing to alias that points to recursive class
passed
0.932s
Integ tests > should work for all inputs
recursive class with alias indirection
passed
1.223s
Integ tests
should work for all outputs
passed
5.646s
Integ tests
works with retries1
passed
1.425s
Integ tests
works with retries2
passed
2.348s
Integ tests
works with fallbacks
passed
2.297s
Integ tests
should work with image from url
passed
1.386s
Integ tests
should work with image from base 64
passed
1.239s
Integ tests
should work with audio base 64
passed
1.734s
Integ tests
should work with audio from url
passed
1.844s
Integ tests
should support streaming in OpenAI
passed
2.423s
Integ tests
should support streaming in Gemini
passed
6.634s
Integ tests
should support AWS
passed
1.69s
Integ tests
should support streaming in AWS
passed
1.518s
Integ tests
should allow overriding the region
passed
0.062s
Integ tests
should support OpenAI shorthand
passed
8.762s
Integ tests
should support OpenAI shorthand streaming
passed
19.248s
Integ tests
should support anthropic shorthand
passed
2.567s
Integ tests
should support anthropic shorthand streaming
passed
4.385s
Integ tests
should support streaming without iterating
passed
2.164s
Integ tests
should support streaming in Claude
passed
2.722s
Integ tests
should support vertex
failed
0.003s
Error: BamlError: Failed to read service account file: 
+

Test Report

Started: 2024-12-18 13:37:55
Suites (1)
0 passed
1 failed
0 pending
Tests (80)
77 passed
3 failed
0 pending
Integ tests > should work for all inputs
single bool
passed
0.468s
Integ tests > should work for all inputs
single string list
passed
0.409s
Integ tests > should work for all inputs
return literal union
passed
0.303s
Integ tests > should work for all inputs
single class
passed
0.515s
Integ tests > should work for all inputs
multiple classes
passed
0.512s
Integ tests > should work for all inputs
single enum list
passed
0.421s
Integ tests > should work for all inputs
single float
passed
0.397s
Integ tests > should work for all inputs
single int
passed
0.512s
Integ tests > should work for all inputs
single literal int
passed
0.41s
Integ tests > should work for all inputs
single literal bool
passed
0.611s
Integ tests > should work for all inputs
single literal string
passed
0.412s
Integ tests > should work for all inputs
single class with literal prop
passed
0.512s
Integ tests > should work for all inputs
single class with literal union prop
passed
0.512s
Integ tests > should work for all inputs
single optional string
passed
0.304s
Integ tests > should work for all inputs
single map string to string
passed
0.516s
Integ tests > should work for all inputs
single map string to class
passed
0.713s
Integ tests > should work for all inputs
single map string to map
passed
0.614s
Integ tests > should work for all inputs
enum key in map
passed
0.821s
Integ tests > should work for all inputs
literal string union key in map
passed
0.819s
Integ tests > should work for all inputs
single literal string key in map
passed
0.92s
Integ tests > should work for all inputs
primitive union alias
passed
0.516s
Integ tests > should work for all inputs
map alias
passed
1.122s
Integ tests > should work for all inputs
alias union
passed
1.433s
Integ tests > should work for all inputs
alias pointing to recursive class
passed
0.817s
Integ tests > should work for all inputs
class pointing to alias that points to recursive class
passed
0.905s
Integ tests > should work for all inputs
recursive class with alias indirection
passed
0.735s
Integ tests > should work for all inputs
merge alias attributes
passed
0.669s
Integ tests > should work for all inputs
return alias with merged attrs
passed
0.422s
Integ tests > should work for all inputs
alias with multiple attrs
passed
0.472s
Integ tests > should work for all inputs
simple recursive map alias
passed
1.508s
Integ tests > should work for all inputs
simple recursive map alias
passed
0.499s
Integ tests > should work for all inputs
recursive alias cycles
passed
0.525s
Integ tests > should work for all inputs
json type alias cycle
failed
2.478s
Error: expect(received).toEqual(expected) // deep equality
+
+Expected: {"bool": true, "json": {"bool": true, "list": [1, 2, 3], "number": 1, "object": {"bool": true, "list": [1, 2, 3], "number": 1, "string": "test"}, "string": "test"}, "list": [1, 2, 3], "number": 1, "object": {"bool": true, "list": [1, 2, 3], "number": 1, "string": "test"}, "string": "test"}
+Received: "{number: 1, string: test, bool: true, list: [1, 2, 3], object: {number: 1, string: test, bool: true, list: [1, 2, 3]}, json: {number: 1, string: test, bool: true, list: [1, 2, 3], object: {number: 1, string: test, bool: true, list: [1, 2, 3]}}}"
+    at Object.toEqual (/workspaces/baml/integ-tests/typescript/tests/integ-tests.test.ts:234:19)
Integ tests
should work for all outputs
passed
4.88s
Integ tests
works with retries1
passed
1.131s
Integ tests
works with retries2
passed
2.247s
Integ tests
works with fallbacks
passed
2.048s
Integ tests
should work with image from url
passed
1.741s
Integ tests
should work with image from base 64
passed
1.571s
Integ tests
should work with audio base 64
passed
1.164s
Integ tests
should work with audio from url
passed
1.708s
Integ tests
should support streaming in OpenAI
passed
2.114s
Integ tests
should support streaming in Gemini
passed
11.139s
Integ tests
should support AWS
passed
1.553s
Integ tests
should support streaming in AWS
passed
1.608s
Integ tests
should allow overriding the region
passed
0.037s
Integ tests
should support OpenAI shorthand
passed
11.054s
Integ tests
should support OpenAI shorthand streaming
passed
10.68s
Integ tests
should support anthropic shorthand
passed
3.964s
Integ tests
should support anthropic shorthand streaming
passed
3.493s
Integ tests
should support streaming without iterating
passed
2.541s
Integ tests
should support streaming in Claude
passed
1.412s
Integ tests
should support vertex
failed
0.001s
Error: BamlError: Failed to read service account file: 
 
 Caused by:
-    No such file or directory (os error 2)
Integ tests
supports tracing sync
passed
0.023s
Integ tests
supports tracing async
passed
4.933s
Integ tests
should work with dynamic types single
passed
1.024s
Integ tests
should work with dynamic types enum
passed
1.11s
Integ tests
should work with dynamic literals
passed
0.932s
Integ tests
should work with dynamic types class
passed
1.335s
Integ tests
should work with dynamic inputs class
passed
0.515s
Integ tests
should work with dynamic inputs list
passed
0.608s
Integ tests
should work with dynamic output map
passed
1.073s
Integ tests
should work with dynamic output union
passed
1.795s
Integ tests
should work with nested classes
failed
0.109s
Error: BamlError: BamlClientError: Something went wrong with the LLM client: reqwest::Error { kind: Request, url: Url { scheme: "http", cannot_be_a_base: false, username: "", password: None, host: Some(Domain("localhost")), port: Some(11434), path: "/v1/chat/completions", query: None, fragment: None }, source: hyper_util::client::legacy::Error(Connect, ConnectError("tcp connect error", Os { code: 111, kind: ConnectionRefused, message: "Connection refused" })) }
+    No such file or directory (os error 2)
Integ tests
supports tracing sync
passed
0.017s
Integ tests
supports tracing async
passed
2.316s
Integ tests
should work with dynamic types single
passed
1.481s
Integ tests
should work with dynamic types enum
passed
0.99s
Integ tests
should work with dynamic literals
passed
0.954s
Integ tests
should work with dynamic types class
passed
1.434s
Integ tests
should work with dynamic inputs class
passed
0.511s
Integ tests
should work with dynamic inputs list
passed
0.52s
Integ tests
should work with dynamic output map
passed
0.66s
Integ tests
should work with dynamic output union
passed
1.791s
Integ tests
should work with nested classes
failed
0.104s
Error: BamlError: BamlClientError: Something went wrong with the LLM client: reqwest::Error { kind: Request, url: Url { scheme: "http", cannot_be_a_base: false, username: "", password: None, host: Some(Domain("localhost")), port: Some(11434), path: "/v1/chat/completions", query: None, fragment: None }, source: hyper_util::client::legacy::Error(Connect, ConnectError("tcp connect error", Os { code: 111, kind: ConnectionRefused, message: "Connection refused" })) }
     at BamlStream.parsed [as getFinalResponse] (/workspaces/baml/engine/language_client_typescript/stream.js:58:39)
-    at Object.<anonymous> (/workspaces/baml/integ-tests/typescript/tests/integ-tests.test.ts:635:19)
Integ tests
should work with dynamic client
passed
0.493s
Integ tests
should work with 'onLogEvent'
passed
2.269s
Integ tests
should work with a sync client
passed
0.701s
Integ tests
should raise an error when appropriate
passed
1.082s
Integ tests
should raise a BAMLValidationError
passed
0.66s
Integ tests
should reset environment variables correctly
passed
2.453s
Integ tests
should use aliases when serializing input objects - classes
passed
0.94s
Integ tests
should use aliases when serializing, but still have original keys in jinja
passed
0.802s
Integ tests
should use aliases when serializing input objects - enums
passed
0.408s
Integ tests
should use aliases when serializing input objects - lists
passed
0.513s
Integ tests
constraints: should handle checks in return types
passed
0.638s
Integ tests
constraints: should handle checks in returned unions
passed
0.709s
Integ tests
constraints: should handle block-level checks
passed
2.75s
Integ tests
constraints: should handle nested-block-level checks
passed
0.612s
Integ tests
simple recursive type
passed
2.458s
Integ tests
mutually recursive type
passed
2.048s
\ No newline at end of file + at Object.<anonymous> (/workspaces/baml/integ-tests/typescript/tests/integ-tests.test.ts:687:19)
Integ tests
should work with dynamic client
passed
0.34s
Integ tests
should work with 'onLogEvent'
passed
1.79s
Integ tests
should work with a sync client
passed
0.417s
Integ tests
should raise an error when appropriate
passed
1.026s
Integ tests
should raise a BAMLValidationError
passed
0.512s
Integ tests
should reset environment variables correctly
passed
5.938s
Integ tests
should use aliases when serializing input objects - classes
passed
0.813s
Integ tests
should use aliases when serializing, but still have original keys in jinja
passed
0.824s
Integ tests
should use aliases when serializing input objects - enums
passed
0.719s
Integ tests
should use aliases when serializing input objects - lists
passed
0.407s
Integ tests
constraints: should handle checks in return types
passed
0.676s
Integ tests
constraints: should handle checks in returned unions
passed
1.07s
Integ tests
constraints: should handle block-level checks
passed
0.598s
Integ tests
constraints: should handle nested-block-level checks
passed
0.625s
Integ tests
simple recursive type
passed
3.175s
Integ tests
mutually recursive type
passed
2.458s
\ No newline at end of file From cac1a1699827b90c288b3ce462daf21ec1c78111 Mon Sep 17 00:00:00 2001 From: Antonio Sarosi Date: Wed, 18 Dec 2024 16:04:15 +0100 Subject: [PATCH 13/19] Add more tests (doesn't work because of score function) --- .../jsonish/src/tests/test_aliases.rs | 58 ++++++++++++++++++- 1 file changed, 55 insertions(+), 3 deletions(-) diff --git a/engine/baml-lib/jsonish/src/tests/test_aliases.rs b/engine/baml-lib/jsonish/src/tests/test_aliases.rs index 97f7f55ef..7d819646b 100644 --- a/engine/baml-lib/jsonish/src/tests/test_aliases.rs +++ b/engine/baml-lib/jsonish/src/tests/test_aliases.rs @@ -38,7 +38,7 @@ type C = A[] ); test_deserializer!( - test_recursive_alias_union, + test_json_without_nested_objects, r#" type JsonValue = int | string | bool | JsonValue[] | map "#, @@ -58,9 +58,61 @@ type JsonValue = int | string | bool | JsonValue[] | map ); test_deserializer!( - test_complex_recursive_alias, + test_json_with_nested_list, r#" -type JsonValue = int | bool | string | JsonValue[] | map +type JsonValue = int | string | bool | JsonValue[] | map + "#, + r#" + { + "number": 1, + "string": "test", + "bool": true, + "list": [1, 2, 3] + } + "#, + FieldType::RecursiveTypeAlias("JsonValue".into()), + { + "number": 1, + "string": "test", + "bool": true, + "list": [1, 2, 3] + } +); + +test_deserializer!( + test_json_with_nested_object, + r#" +type JsonValue = int | bool | JsonValue[] | map | string + "#, + r#" + { + "number": 1, + "string": "test", + "bool": true, + "json": { + "number": 1, + "string": "test", + "bool": true + } + } + "#, + FieldType::RecursiveTypeAlias("JsonValue".into()), + { + "number": 1, + "string": "test", + "bool": true, + "json": { + "number": 1, + "string": "test", + "bool": true + } + } +); + +test_deserializer!( + test_full_json_with_nested_objects, + r#" +type JsonValue = int | bool | JsonValue[] | map | string "#, r#" { From 39141cb3fd6320c385c9efcc3c8f20f14182348d Mon Sep 17 00:00:00 2001 From: Antonio Sarosi Date: Wed, 18 Dec 2024 16:36:48 +0100 Subject: [PATCH 14/19] Add codegen for TS --- .../src/typescript/generate_types.rs | 29 +++++++++++++++++-- .../src/typescript/templates/types.ts.j2 | 5 ++++ integ-tests/python/baml_client/inlinedbaml.py | 2 +- integ-tests/ruby/baml_client/inlined.rb | 2 +- .../typescript/baml_client/inlinedbaml.ts | 2 +- integ-tests/typescript/baml_client/types.ts | 16 ++++++++++ 6 files changed, 51 insertions(+), 5 deletions(-) diff --git a/engine/language_client_codegen/src/typescript/generate_types.rs b/engine/language_client_codegen/src/typescript/generate_types.rs index 0022f110a..195764bc5 100644 --- a/engine/language_client_codegen/src/typescript/generate_types.rs +++ b/engine/language_client_codegen/src/typescript/generate_types.rs @@ -4,8 +4,8 @@ use anyhow::Result; use itertools::Itertools; use internal_baml_core::ir::{ - repr::{Docstring, IntermediateRepr}, - ClassWalker, EnumWalker, + repr::{Docstring, IntermediateRepr, Walker}, + ClassWalker, EnumWalker, FieldType, }; use crate::{type_check_attributes, GeneratorArgs, TypeCheckAttributes}; @@ -24,6 +24,7 @@ pub(crate) struct TypeBuilder<'ir> { pub(crate) struct TypescriptTypes<'ir> { enums: Vec>, classes: Vec>, + structural_recursive_alias_cycles: Vec>, } struct TypescriptEnum<'ir> { @@ -40,6 +41,11 @@ pub struct TypescriptClass<'ir> { pub docstring: Option, } +struct TypescriptTypeAlias<'ir> { + name: Cow<'ir, str>, + target: String, +} + impl<'ir> TryFrom<(&'ir IntermediateRepr, &'ir GeneratorArgs)> for TypescriptTypes<'ir> { type Error = anyhow::Error; @@ -55,6 +61,10 @@ impl<'ir> TryFrom<(&'ir IntermediateRepr, &'ir GeneratorArgs)> for TypescriptTyp .walk_classes() .map(|e| Into::::into(&e)) .collect::>(), + structural_recursive_alias_cycles: ir + .walk_alias_cycles() + .map(TypescriptTypeAlias::from) + .collect::>(), }) } } @@ -132,6 +142,21 @@ impl<'ir> From<&ClassWalker<'ir>> for TypescriptClass<'ir> { } } +// TODO: Define AliasWalker to simplify type. +impl<'ir> From> for TypescriptTypeAlias<'ir> { + fn from( + Walker { + db, + item: (name, target), + }: Walker<(&'ir String, &'ir FieldType)>, + ) -> Self { + Self { + name: Cow::Borrowed(name), + target: target.to_type_ref(db), + } + } +} + pub fn type_name_for_checks(checks: &TypeCheckAttributes) -> String { checks .0 diff --git a/engine/language_client_codegen/src/typescript/templates/types.ts.j2 b/engine/language_client_codegen/src/typescript/templates/types.ts.j2 index 91ab34165..308967604 100644 --- a/engine/language_client_codegen/src/typescript/templates/types.ts.j2 +++ b/engine/language_client_codegen/src/typescript/templates/types.ts.j2 @@ -52,3 +52,8 @@ export interface {{cls.name}} { {%- endif %} } {% endfor %} + +{#- Type Aliases -#} +{% for alias in structural_recursive_alias_cycles %} +type {{alias.name}} = {{alias.target}} +{% endfor %} diff --git a/integ-tests/python/baml_client/inlinedbaml.py b/integ-tests/python/baml_client/inlinedbaml.py index 4d7de9465..fcb867aef 100644 --- a/integ-tests/python/baml_client/inlinedbaml.py +++ b/integ-tests/python/baml_client/inlinedbaml.py @@ -83,7 +83,7 @@ "test-files/functions/output/recursive-type-aliases.baml": "class LinkedListAliasNode {\n value int\n next LinkedListAliasNode?\n}\n\n// Simple alias that points to recursive type.\ntype LinkedListAlias = LinkedListAliasNode\n\nfunction AliasThatPointsToRecursiveType(list: LinkedListAlias) -> LinkedListAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given linked list back:\n \n {{ list }}\n \n {{ ctx.output_format }}\n \"#\n}\n\n// Class that points to an alias that points to a recursive type.\nclass ClassToRecAlias {\n list LinkedListAlias\n}\n\nfunction ClassThatPointsToRecursiveClassThroughAlias(cls: ClassToRecAlias) -> ClassToRecAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given object back:\n \n {{ cls }}\n \n {{ ctx.output_format }}\n \"#\n}\n\n// This is tricky cause this class should be hoisted, but classes and aliases\n// are two different types in the AST. This test will make sure they can interop.\nclass NodeWithAliasIndirection {\n value int\n next NodeIndirection?\n}\n\ntype NodeIndirection = NodeWithAliasIndirection\n\nfunction RecursiveClassWithAliasIndirection(cls: NodeWithAliasIndirection) -> NodeWithAliasIndirection {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given object back:\n \n {{ cls }}\n \n {{ ctx.output_format }}\n \"#\n}\n", "test-files/functions/output/serialization-error.baml": "class DummyOutput {\n nonce string\n nonce2 string\n @@dynamic\n}\n\nfunction DummyOutputFunction(input: string) -> DummyOutput {\n client GPT35\n prompt #\"\n Say \"hello there\".\n \"#\n}", "test-files/functions/output/string-list.baml": "function FnOutputStringList(input: string) -> string[] {\n client GPT35\n prompt #\"\n Return a list of strings in json format like [\"string1\", \"string2\", \"string3\"].\n\n JSON:\n \"#\n}\n\ntest FnOutputStringList {\n functions [FnOutputStringList]\n args {\n input \"example input\"\n }\n}\n", - "test-files/functions/output/type-aliases.baml": "type Primitive = int | string | bool | float\n\ntype List = string[]\n\ntype Graph = map\n\ntype Combination = Primitive | List | Graph\n\nfunction PrimitiveAlias(p: Primitive) -> Primitive {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back: {{ p }}\n \"#\n}\n\nfunction MapAlias(m: Graph) -> Graph {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given Graph back:\n\n {{ m }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction NestedAlias(c: Combination) -> Combination {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back:\n\n {{ c }}\n\n {{ ctx.output_format }}\n \"#\n}\n\n// Test attribute merging.\ntype Currency = int @check(gt_ten, {{ this > 10 }})\ntype Amount = Currency @assert({{ this > 0 }})\n\nclass MergeAttrs {\n amount Amount @description(\"In USD\")\n}\n\n// This should be allowed.\ntype MultipleAttrs = int @assert({{ this > 0 }}) @check(gt_ten, {{ this > 10 }})\n\nfunction MergeAliasAttributes(money: int) -> MergeAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer in the specified format:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction ReturnAliasWithMergedAttributes(money: Amount) -> Amount {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction AliasWithMultipleAttrs(money: MultipleAttrs) -> MultipleAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecursiveMapAlias = map\n\nfunction SimpleRecursiveMapAlias(input: RecursiveMapAlias) -> RecursiveMapAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecursiveListAlias = RecursiveListAlias[]\n\nfunction SimpleRecursiveListAlias(input: RecursiveListAlias) -> RecursiveListAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given JSON array:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecAliasOne = RecAliasTwo\ntype RecAliasTwo = RecAliasThree\ntype RecAliasThree = RecAliasOne[]\n\nfunction RecursiveAliasCycle(input: RecAliasOne) -> RecAliasOne {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given JSON array:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype JsonValue = int | string | bool | float | JsonObject | JsonArray\ntype JsonObject = map\ntype JsonArray = JsonValue[]\n\nfunction JsonTypeAliasCycle(input: JsonValue) -> JsonValue {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given input back:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}", + "test-files/functions/output/type-aliases.baml": "type Primitive = int | string | bool | float\n\ntype List = string[]\n\ntype Graph = map\n\ntype Combination = Primitive | List | Graph\n\nfunction PrimitiveAlias(p: Primitive) -> Primitive {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back: {{ p }}\n \"#\n}\n\nfunction MapAlias(m: Graph) -> Graph {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given Graph back:\n\n {{ m }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction NestedAlias(c: Combination) -> Combination {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back:\n\n {{ c }}\n\n {{ ctx.output_format }}\n \"#\n}\n\n// Test attribute merging.\ntype Currency = int @check(gt_ten, {{ this > 10 }})\ntype Amount = Currency @assert({{ this > 0 }})\n\nclass MergeAttrs {\n amount Amount @description(\"In USD\")\n}\n\n// This should be allowed.\ntype MultipleAttrs = int @assert({{ this > 0 }}) @check(gt_ten, {{ this > 10 }})\n\nfunction MergeAliasAttributes(money: int) -> MergeAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer in the specified format:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction ReturnAliasWithMergedAttributes(money: Amount) -> Amount {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction AliasWithMultipleAttrs(money: MultipleAttrs) -> MultipleAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecursiveMapAlias = map\n\nfunction SimpleRecursiveMapAlias(input: RecursiveMapAlias) -> RecursiveMapAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecursiveListAlias = RecursiveListAlias[]\n\nfunction SimpleRecursiveListAlias(input: RecursiveListAlias) -> RecursiveListAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given JSON array:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecAliasOne = RecAliasTwo\ntype RecAliasTwo = RecAliasThree\ntype RecAliasThree = RecAliasOne[]\n\nfunction RecursiveAliasCycle(input: RecAliasOne) -> RecAliasOne {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given JSON array:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype JsonValue = int | string | bool | float | JsonObject | JsonArray\ntype JsonObject = map\ntype JsonArray = JsonValue[]\n\nfunction JsonTypeAliasCycle(input: JsonValue) -> JsonValue {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given input back:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n", "test-files/functions/output/unions.baml": "class UnionTest_ReturnType {\n prop1 string | bool\n prop2 (float | bool)[]\n prop3 (bool[] | int[])\n}\n\nfunction UnionTest_Function(input: string | bool) -> UnionTest_ReturnType {\n client GPT35\n prompt #\"\n Return a JSON blob with this schema: \n {{ctx.output_format}}\n\n JSON:\n \"#\n}\n\ntest UnionTest_Function {\n functions [UnionTest_Function]\n args {\n input \"example input\"\n }\n}\n", "test-files/functions/prompts/no-chat-messages.baml": "\n\nfunction PromptTestClaude(input: string) -> string {\n client Sonnet\n prompt #\"\n Tell me a haiku about {{ input }}\n \"#\n}\n\n\nfunction PromptTestStreaming(input: string) -> string {\n client GPT35\n prompt #\"\n Tell me a short story about {{ input }}\n \"#\n}\n\ntest TestName {\n functions [PromptTestStreaming]\n args {\n input #\"\n hello world\n \"#\n }\n}\n", "test-files/functions/prompts/with-chat-messages.baml": "\nfunction PromptTestOpenAIChat(input: string) -> string {\n client GPT35\n prompt #\"\n {{ _.role(\"system\") }}\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestOpenAIChatNoSystem(input: string) -> string {\n client GPT35\n prompt #\"\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestClaudeChat(input: string) -> string {\n client Claude\n prompt #\"\n {{ _.role(\"system\") }}\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestClaudeChatNoSystem(input: string) -> string {\n client Claude\n prompt #\"\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\ntest TestSystemAndNonSystemChat1 {\n functions [PromptTestClaude, PromptTestOpenAI, PromptTestOpenAIChat, PromptTestOpenAIChatNoSystem, PromptTestClaudeChat, PromptTestClaudeChatNoSystem]\n args {\n input \"cats\"\n }\n}\n\ntest TestSystemAndNonSystemChat2 {\n functions [PromptTestClaude, PromptTestOpenAI, PromptTestOpenAIChat, PromptTestOpenAIChatNoSystem, PromptTestClaudeChat, PromptTestClaudeChatNoSystem]\n args {\n input \"lion\"\n }\n}", diff --git a/integ-tests/ruby/baml_client/inlined.rb b/integ-tests/ruby/baml_client/inlined.rb index 5ce3d619d..ca979d424 100644 --- a/integ-tests/ruby/baml_client/inlined.rb +++ b/integ-tests/ruby/baml_client/inlined.rb @@ -83,7 +83,7 @@ module Inlined "test-files/functions/output/recursive-type-aliases.baml" => "class LinkedListAliasNode {\n value int\n next LinkedListAliasNode?\n}\n\n// Simple alias that points to recursive type.\ntype LinkedListAlias = LinkedListAliasNode\n\nfunction AliasThatPointsToRecursiveType(list: LinkedListAlias) -> LinkedListAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given linked list back:\n \n {{ list }}\n \n {{ ctx.output_format }}\n \"#\n}\n\n// Class that points to an alias that points to a recursive type.\nclass ClassToRecAlias {\n list LinkedListAlias\n}\n\nfunction ClassThatPointsToRecursiveClassThroughAlias(cls: ClassToRecAlias) -> ClassToRecAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given object back:\n \n {{ cls }}\n \n {{ ctx.output_format }}\n \"#\n}\n\n// This is tricky cause this class should be hoisted, but classes and aliases\n// are two different types in the AST. This test will make sure they can interop.\nclass NodeWithAliasIndirection {\n value int\n next NodeIndirection?\n}\n\ntype NodeIndirection = NodeWithAliasIndirection\n\nfunction RecursiveClassWithAliasIndirection(cls: NodeWithAliasIndirection) -> NodeWithAliasIndirection {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given object back:\n \n {{ cls }}\n \n {{ ctx.output_format }}\n \"#\n}\n", "test-files/functions/output/serialization-error.baml" => "class DummyOutput {\n nonce string\n nonce2 string\n @@dynamic\n}\n\nfunction DummyOutputFunction(input: string) -> DummyOutput {\n client GPT35\n prompt #\"\n Say \"hello there\".\n \"#\n}", "test-files/functions/output/string-list.baml" => "function FnOutputStringList(input: string) -> string[] {\n client GPT35\n prompt #\"\n Return a list of strings in json format like [\"string1\", \"string2\", \"string3\"].\n\n JSON:\n \"#\n}\n\ntest FnOutputStringList {\n functions [FnOutputStringList]\n args {\n input \"example input\"\n }\n}\n", - "test-files/functions/output/type-aliases.baml" => "type Primitive = int | string | bool | float\n\ntype List = string[]\n\ntype Graph = map\n\ntype Combination = Primitive | List | Graph\n\nfunction PrimitiveAlias(p: Primitive) -> Primitive {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back: {{ p }}\n \"#\n}\n\nfunction MapAlias(m: Graph) -> Graph {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given Graph back:\n\n {{ m }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction NestedAlias(c: Combination) -> Combination {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back:\n\n {{ c }}\n\n {{ ctx.output_format }}\n \"#\n}\n\n// Test attribute merging.\ntype Currency = int @check(gt_ten, {{ this > 10 }})\ntype Amount = Currency @assert({{ this > 0 }})\n\nclass MergeAttrs {\n amount Amount @description(\"In USD\")\n}\n\n// This should be allowed.\ntype MultipleAttrs = int @assert({{ this > 0 }}) @check(gt_ten, {{ this > 10 }})\n\nfunction MergeAliasAttributes(money: int) -> MergeAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer in the specified format:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction ReturnAliasWithMergedAttributes(money: Amount) -> Amount {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction AliasWithMultipleAttrs(money: MultipleAttrs) -> MultipleAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecursiveMapAlias = map\n\nfunction SimpleRecursiveMapAlias(input: RecursiveMapAlias) -> RecursiveMapAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecursiveListAlias = RecursiveListAlias[]\n\nfunction SimpleRecursiveListAlias(input: RecursiveListAlias) -> RecursiveListAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given JSON array:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecAliasOne = RecAliasTwo\ntype RecAliasTwo = RecAliasThree\ntype RecAliasThree = RecAliasOne[]\n\nfunction RecursiveAliasCycle(input: RecAliasOne) -> RecAliasOne {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given JSON array:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype JsonValue = int | string | bool | float | JsonObject | JsonArray\ntype JsonObject = map\ntype JsonArray = JsonValue[]\n\nfunction JsonTypeAliasCycle(input: JsonValue) -> JsonValue {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given input back:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}", + "test-files/functions/output/type-aliases.baml" => "type Primitive = int | string | bool | float\n\ntype List = string[]\n\ntype Graph = map\n\ntype Combination = Primitive | List | Graph\n\nfunction PrimitiveAlias(p: Primitive) -> Primitive {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back: {{ p }}\n \"#\n}\n\nfunction MapAlias(m: Graph) -> Graph {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given Graph back:\n\n {{ m }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction NestedAlias(c: Combination) -> Combination {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back:\n\n {{ c }}\n\n {{ ctx.output_format }}\n \"#\n}\n\n// Test attribute merging.\ntype Currency = int @check(gt_ten, {{ this > 10 }})\ntype Amount = Currency @assert({{ this > 0 }})\n\nclass MergeAttrs {\n amount Amount @description(\"In USD\")\n}\n\n// This should be allowed.\ntype MultipleAttrs = int @assert({{ this > 0 }}) @check(gt_ten, {{ this > 10 }})\n\nfunction MergeAliasAttributes(money: int) -> MergeAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer in the specified format:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction ReturnAliasWithMergedAttributes(money: Amount) -> Amount {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction AliasWithMultipleAttrs(money: MultipleAttrs) -> MultipleAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecursiveMapAlias = map\n\nfunction SimpleRecursiveMapAlias(input: RecursiveMapAlias) -> RecursiveMapAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecursiveListAlias = RecursiveListAlias[]\n\nfunction SimpleRecursiveListAlias(input: RecursiveListAlias) -> RecursiveListAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given JSON array:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecAliasOne = RecAliasTwo\ntype RecAliasTwo = RecAliasThree\ntype RecAliasThree = RecAliasOne[]\n\nfunction RecursiveAliasCycle(input: RecAliasOne) -> RecAliasOne {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given JSON array:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype JsonValue = int | string | bool | float | JsonObject | JsonArray\ntype JsonObject = map\ntype JsonArray = JsonValue[]\n\nfunction JsonTypeAliasCycle(input: JsonValue) -> JsonValue {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given input back:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n", "test-files/functions/output/unions.baml" => "class UnionTest_ReturnType {\n prop1 string | bool\n prop2 (float | bool)[]\n prop3 (bool[] | int[])\n}\n\nfunction UnionTest_Function(input: string | bool) -> UnionTest_ReturnType {\n client GPT35\n prompt #\"\n Return a JSON blob with this schema: \n {{ctx.output_format}}\n\n JSON:\n \"#\n}\n\ntest UnionTest_Function {\n functions [UnionTest_Function]\n args {\n input \"example input\"\n }\n}\n", "test-files/functions/prompts/no-chat-messages.baml" => "\n\nfunction PromptTestClaude(input: string) -> string {\n client Sonnet\n prompt #\"\n Tell me a haiku about {{ input }}\n \"#\n}\n\n\nfunction PromptTestStreaming(input: string) -> string {\n client GPT35\n prompt #\"\n Tell me a short story about {{ input }}\n \"#\n}\n\ntest TestName {\n functions [PromptTestStreaming]\n args {\n input #\"\n hello world\n \"#\n }\n}\n", "test-files/functions/prompts/with-chat-messages.baml" => "\nfunction PromptTestOpenAIChat(input: string) -> string {\n client GPT35\n prompt #\"\n {{ _.role(\"system\") }}\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestOpenAIChatNoSystem(input: string) -> string {\n client GPT35\n prompt #\"\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestClaudeChat(input: string) -> string {\n client Claude\n prompt #\"\n {{ _.role(\"system\") }}\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestClaudeChatNoSystem(input: string) -> string {\n client Claude\n prompt #\"\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\ntest TestSystemAndNonSystemChat1 {\n functions [PromptTestClaude, PromptTestOpenAI, PromptTestOpenAIChat, PromptTestOpenAIChatNoSystem, PromptTestClaudeChat, PromptTestClaudeChatNoSystem]\n args {\n input \"cats\"\n }\n}\n\ntest TestSystemAndNonSystemChat2 {\n functions [PromptTestClaude, PromptTestOpenAI, PromptTestOpenAIChat, PromptTestOpenAIChatNoSystem, PromptTestClaudeChat, PromptTestClaudeChatNoSystem]\n args {\n input \"lion\"\n }\n}", diff --git a/integ-tests/typescript/baml_client/inlinedbaml.ts b/integ-tests/typescript/baml_client/inlinedbaml.ts index 56ad031f6..c01a4def0 100644 --- a/integ-tests/typescript/baml_client/inlinedbaml.ts +++ b/integ-tests/typescript/baml_client/inlinedbaml.ts @@ -84,7 +84,7 @@ const fileMap = { "test-files/functions/output/recursive-type-aliases.baml": "class LinkedListAliasNode {\n value int\n next LinkedListAliasNode?\n}\n\n// Simple alias that points to recursive type.\ntype LinkedListAlias = LinkedListAliasNode\n\nfunction AliasThatPointsToRecursiveType(list: LinkedListAlias) -> LinkedListAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given linked list back:\n \n {{ list }}\n \n {{ ctx.output_format }}\n \"#\n}\n\n// Class that points to an alias that points to a recursive type.\nclass ClassToRecAlias {\n list LinkedListAlias\n}\n\nfunction ClassThatPointsToRecursiveClassThroughAlias(cls: ClassToRecAlias) -> ClassToRecAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given object back:\n \n {{ cls }}\n \n {{ ctx.output_format }}\n \"#\n}\n\n// This is tricky cause this class should be hoisted, but classes and aliases\n// are two different types in the AST. This test will make sure they can interop.\nclass NodeWithAliasIndirection {\n value int\n next NodeIndirection?\n}\n\ntype NodeIndirection = NodeWithAliasIndirection\n\nfunction RecursiveClassWithAliasIndirection(cls: NodeWithAliasIndirection) -> NodeWithAliasIndirection {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given object back:\n \n {{ cls }}\n \n {{ ctx.output_format }}\n \"#\n}\n", "test-files/functions/output/serialization-error.baml": "class DummyOutput {\n nonce string\n nonce2 string\n @@dynamic\n}\n\nfunction DummyOutputFunction(input: string) -> DummyOutput {\n client GPT35\n prompt #\"\n Say \"hello there\".\n \"#\n}", "test-files/functions/output/string-list.baml": "function FnOutputStringList(input: string) -> string[] {\n client GPT35\n prompt #\"\n Return a list of strings in json format like [\"string1\", \"string2\", \"string3\"].\n\n JSON:\n \"#\n}\n\ntest FnOutputStringList {\n functions [FnOutputStringList]\n args {\n input \"example input\"\n }\n}\n", - "test-files/functions/output/type-aliases.baml": "type Primitive = int | string | bool | float\n\ntype List = string[]\n\ntype Graph = map\n\ntype Combination = Primitive | List | Graph\n\nfunction PrimitiveAlias(p: Primitive) -> Primitive {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back: {{ p }}\n \"#\n}\n\nfunction MapAlias(m: Graph) -> Graph {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given Graph back:\n\n {{ m }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction NestedAlias(c: Combination) -> Combination {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back:\n\n {{ c }}\n\n {{ ctx.output_format }}\n \"#\n}\n\n// Test attribute merging.\ntype Currency = int @check(gt_ten, {{ this > 10 }})\ntype Amount = Currency @assert({{ this > 0 }})\n\nclass MergeAttrs {\n amount Amount @description(\"In USD\")\n}\n\n// This should be allowed.\ntype MultipleAttrs = int @assert({{ this > 0 }}) @check(gt_ten, {{ this > 10 }})\n\nfunction MergeAliasAttributes(money: int) -> MergeAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer in the specified format:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction ReturnAliasWithMergedAttributes(money: Amount) -> Amount {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction AliasWithMultipleAttrs(money: MultipleAttrs) -> MultipleAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecursiveMapAlias = map\n\nfunction SimpleRecursiveMapAlias(input: RecursiveMapAlias) -> RecursiveMapAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecursiveListAlias = RecursiveListAlias[]\n\nfunction SimpleRecursiveListAlias(input: RecursiveListAlias) -> RecursiveListAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given JSON array:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecAliasOne = RecAliasTwo\ntype RecAliasTwo = RecAliasThree\ntype RecAliasThree = RecAliasOne[]\n\nfunction RecursiveAliasCycle(input: RecAliasOne) -> RecAliasOne {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given JSON array:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype JsonValue = int | string | bool | float | JsonObject | JsonArray\ntype JsonObject = map\ntype JsonArray = JsonValue[]\n\nfunction JsonTypeAliasCycle(input: JsonValue) -> JsonValue {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given input back:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}", + "test-files/functions/output/type-aliases.baml": "type Primitive = int | string | bool | float\n\ntype List = string[]\n\ntype Graph = map\n\ntype Combination = Primitive | List | Graph\n\nfunction PrimitiveAlias(p: Primitive) -> Primitive {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back: {{ p }}\n \"#\n}\n\nfunction MapAlias(m: Graph) -> Graph {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given Graph back:\n\n {{ m }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction NestedAlias(c: Combination) -> Combination {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back:\n\n {{ c }}\n\n {{ ctx.output_format }}\n \"#\n}\n\n// Test attribute merging.\ntype Currency = int @check(gt_ten, {{ this > 10 }})\ntype Amount = Currency @assert({{ this > 0 }})\n\nclass MergeAttrs {\n amount Amount @description(\"In USD\")\n}\n\n// This should be allowed.\ntype MultipleAttrs = int @assert({{ this > 0 }}) @check(gt_ten, {{ this > 10 }})\n\nfunction MergeAliasAttributes(money: int) -> MergeAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer in the specified format:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction ReturnAliasWithMergedAttributes(money: Amount) -> Amount {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction AliasWithMultipleAttrs(money: MultipleAttrs) -> MultipleAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecursiveMapAlias = map\n\nfunction SimpleRecursiveMapAlias(input: RecursiveMapAlias) -> RecursiveMapAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecursiveListAlias = RecursiveListAlias[]\n\nfunction SimpleRecursiveListAlias(input: RecursiveListAlias) -> RecursiveListAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given JSON array:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecAliasOne = RecAliasTwo\ntype RecAliasTwo = RecAliasThree\ntype RecAliasThree = RecAliasOne[]\n\nfunction RecursiveAliasCycle(input: RecAliasOne) -> RecAliasOne {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given JSON array:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype JsonValue = int | string | bool | float | JsonObject | JsonArray\ntype JsonObject = map\ntype JsonArray = JsonValue[]\n\nfunction JsonTypeAliasCycle(input: JsonValue) -> JsonValue {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given input back:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n", "test-files/functions/output/unions.baml": "class UnionTest_ReturnType {\n prop1 string | bool\n prop2 (float | bool)[]\n prop3 (bool[] | int[])\n}\n\nfunction UnionTest_Function(input: string | bool) -> UnionTest_ReturnType {\n client GPT35\n prompt #\"\n Return a JSON blob with this schema: \n {{ctx.output_format}}\n\n JSON:\n \"#\n}\n\ntest UnionTest_Function {\n functions [UnionTest_Function]\n args {\n input \"example input\"\n }\n}\n", "test-files/functions/prompts/no-chat-messages.baml": "\n\nfunction PromptTestClaude(input: string) -> string {\n client Sonnet\n prompt #\"\n Tell me a haiku about {{ input }}\n \"#\n}\n\n\nfunction PromptTestStreaming(input: string) -> string {\n client GPT35\n prompt #\"\n Tell me a short story about {{ input }}\n \"#\n}\n\ntest TestName {\n functions [PromptTestStreaming]\n args {\n input #\"\n hello world\n \"#\n }\n}\n", "test-files/functions/prompts/with-chat-messages.baml": "\nfunction PromptTestOpenAIChat(input: string) -> string {\n client GPT35\n prompt #\"\n {{ _.role(\"system\") }}\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestOpenAIChatNoSystem(input: string) -> string {\n client GPT35\n prompt #\"\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestClaudeChat(input: string) -> string {\n client Claude\n prompt #\"\n {{ _.role(\"system\") }}\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestClaudeChatNoSystem(input: string) -> string {\n client Claude\n prompt #\"\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\ntest TestSystemAndNonSystemChat1 {\n functions [PromptTestClaude, PromptTestOpenAI, PromptTestOpenAIChat, PromptTestOpenAIChatNoSystem, PromptTestClaudeChat, PromptTestClaudeChatNoSystem]\n args {\n input \"cats\"\n }\n}\n\ntest TestSystemAndNonSystemChat2 {\n functions [PromptTestClaude, PromptTestOpenAI, PromptTestOpenAIChat, PromptTestOpenAIChatNoSystem, PromptTestClaudeChat, PromptTestClaudeChatNoSystem]\n args {\n input \"lion\"\n }\n}", diff --git a/integ-tests/typescript/baml_client/types.ts b/integ-tests/typescript/baml_client/types.ts index 17a6fb1d9..8f9bb995a 100644 --- a/integ-tests/typescript/baml_client/types.ts +++ b/integ-tests/typescript/baml_client/types.ts @@ -630,3 +630,19 @@ export interface WithReasoning { reasoning: string } + +type RecursiveMapAlias = Record + +type RecursiveListAlias = RecursiveListAlias[] + +type RecAliasOne = RecAliasTwo + +type RecAliasTwo = RecAliasThree + +type RecAliasThree = RecAliasOne[] + +type JsonValue = number | string | boolean | number | JsonObject | JsonArray + +type JsonObject = Record + +type JsonArray = JsonValue[] From 401a97d178c7408353dbf2b61a78e263d881973e Mon Sep 17 00:00:00 2001 From: Antonio Sarosi Date: Wed, 18 Dec 2024 17:08:31 +0100 Subject: [PATCH 15/19] Add docs for Ruby type alias --- engine/language_client_codegen/src/ruby/field_type.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/engine/language_client_codegen/src/ruby/field_type.rs b/engine/language_client_codegen/src/ruby/field_type.rs index 91e37e33e..c6cdba590 100644 --- a/engine/language_client_codegen/src/ruby/field_type.rs +++ b/engine/language_client_codegen/src/ruby/field_type.rs @@ -9,7 +9,8 @@ impl ToRuby for FieldType { match self { FieldType::Class(name) => format!("Baml::Types::{}", name.clone()), FieldType::Enum(name) => format!("T.any(Baml::Types::{}, String)", name.clone()), - // TODO: Can we define recursive aliases in Ruby with Sorbet? + // Sorbet does not support recursive type aliases. + // https://sorbet.org/docs/type-aliases FieldType::RecursiveTypeAlias(_name) => "T.anything".to_string(), // TODO: Temporary solution until we figure out Ruby literals. FieldType::Literal(value) => value.literal_base_type().to_ruby(), From fc25050f0383d00374dda6ff2192ea2a831c45c4 Mon Sep 17 00:00:00 2001 From: Antonio Sarosi Date: Wed, 18 Dec 2024 17:42:32 +0100 Subject: [PATCH 16/19] Fix OpenAPI map keys --- engine/language_client_codegen/src/openapi.rs | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/engine/language_client_codegen/src/openapi.rs b/engine/language_client_codegen/src/openapi.rs index 33020edcd..9fb2cca91 100644 --- a/engine/language_client_codegen/src/openapi.rs +++ b/engine/language_client_codegen/src/openapi.rs @@ -573,8 +573,14 @@ impl<'ir> ToTypeReferenceInTypeDefinition<'ir> for FieldType { }), }, FieldType::Map(key, value) => { - if !matches!(**key, FieldType::Primitive(TypeValue::String)) { - anyhow::bail!("BAML<->OpenAPI only supports string keys in maps") + if !matches!( + **key, + FieldType::Primitive(TypeValue::String) + | FieldType::Enum(_) + | FieldType::Literal(LiteralValue::String(_)) + | FieldType::Union(_) + ) { + anyhow::bail!(format!("BAML<->OpenAPI only supports strings, enums and literal strings as map keys but got {key}")) } TypeSpecWithMeta { meta: TypeMetadata { From 342fb5e437b0d785f69534587c887868c2de81de Mon Sep 17 00:00:00 2001 From: Antonio Sarosi Date: Wed, 18 Dec 2024 18:15:57 +0100 Subject: [PATCH 17/19] Fix score of `JsonToString` flag --- engine/baml-lib/jsonish/src/deserializer/score.rs | 2 +- engine/baml-lib/jsonish/src/tests/test_aliases.rs | 12 +++++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/engine/baml-lib/jsonish/src/deserializer/score.rs b/engine/baml-lib/jsonish/src/deserializer/score.rs index f198702bd..17065f137 100644 --- a/engine/baml-lib/jsonish/src/deserializer/score.rs +++ b/engine/baml-lib/jsonish/src/deserializer/score.rs @@ -48,7 +48,7 @@ impl WithScore for Flag { Flag::StrippedNonAlphaNumeric(_) => 3, Flag::SubstringMatch(_) => 2, Flag::ImpliedKey(_) => 2, - Flag::JsonToString(_) => 2, + Flag::JsonToString(_) => 5, Flag::SingleToArray => 1, // Parsing errors are bad. Flag::ArrayItemParseError(x, _) => 1 + (*x as i32), diff --git a/engine/baml-lib/jsonish/src/tests/test_aliases.rs b/engine/baml-lib/jsonish/src/tests/test_aliases.rs index 7d819646b..0b4a944eb 100644 --- a/engine/baml-lib/jsonish/src/tests/test_aliases.rs +++ b/engine/baml-lib/jsonish/src/tests/test_aliases.rs @@ -40,18 +40,20 @@ type C = A[] test_deserializer!( test_json_without_nested_objects, r#" -type JsonValue = int | string | bool | JsonValue[] | map +type JsonValue = int | float | string | bool | JsonValue[] | map "#, r#" { - "number": 1, + "int": 1, + "float": 1.0, "string": "test", "bool": true } "#, FieldType::RecursiveTypeAlias("JsonValue".into()), { - "number": 1, + "int": 1, + "float": 1.0, "string": "test", "bool": true } @@ -82,7 +84,7 @@ type JsonValue = int | string | bool | JsonValue[] | map test_deserializer!( test_json_with_nested_object, r#" -type JsonValue = int | bool | JsonValue[] | map | string +type JsonValue = int | bool | string | JsonValue[] | map "#, r#" { @@ -112,7 +114,7 @@ type JsonValue = int | bool | JsonValue[] | map | string test_deserializer!( test_full_json_with_nested_objects, r#" -type JsonValue = int | bool | JsonValue[] | map | string +type JsonValue = int | bool | string JsonValue[] | map "#, r#" { From 2e10579e32f5b46725499c4fd741ff5bbe42c894 Mon Sep 17 00:00:00 2001 From: Antonio Sarosi Date: Wed, 18 Dec 2024 18:23:21 +0100 Subject: [PATCH 18/19] Fix integ tests for json type cycle --- integ-tests/python/tests/test_functions.py | 1 + integ-tests/typescript/test-report.html | 10 +++------- integ-tests/typescript/tests/integ-tests.test.ts | 1 + 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/integ-tests/python/tests/test_functions.py b/integ-tests/python/tests/test_functions.py index 8c96406e9..ead15c024 100644 --- a/integ-tests/python/tests/test_functions.py +++ b/integ-tests/python/tests/test_functions.py @@ -353,6 +353,7 @@ async def test_json_type_alias_cycle(self): res = await b.JsonTypeAliasCycle(data) assert res == data + assert res["json"]["object"]["list"] == [1, 2, 3] class MyCustomClass(NamedArgsSingleClass): diff --git a/integ-tests/typescript/test-report.html b/integ-tests/typescript/test-report.html index 38bc00df5..87108c133 100644 --- a/integ-tests/typescript/test-report.html +++ b/integ-tests/typescript/test-report.html @@ -257,13 +257,9 @@ font-size: 1rem; padding: 0 0.5rem; } -

Test Report

Started: 2024-12-18 13:37:55
Suites (1)
0 passed
1 failed
0 pending
Tests (80)
77 passed
3 failed
0 pending
Integ tests > should work for all inputs
single bool
passed
0.468s
Integ tests > should work for all inputs
single string list
passed
0.409s
Integ tests > should work for all inputs
return literal union
passed
0.303s
Integ tests > should work for all inputs
single class
passed
0.515s
Integ tests > should work for all inputs
multiple classes
passed
0.512s
Integ tests > should work for all inputs
single enum list
passed
0.421s
Integ tests > should work for all inputs
single float
passed
0.397s
Integ tests > should work for all inputs
single int
passed
0.512s
Integ tests > should work for all inputs
single literal int
passed
0.41s
Integ tests > should work for all inputs
single literal bool
passed
0.611s
Integ tests > should work for all inputs
single literal string
passed
0.412s
Integ tests > should work for all inputs
single class with literal prop
passed
0.512s
Integ tests > should work for all inputs
single class with literal union prop
passed
0.512s
Integ tests > should work for all inputs
single optional string
passed
0.304s
Integ tests > should work for all inputs
single map string to string
passed
0.516s
Integ tests > should work for all inputs
single map string to class
passed
0.713s
Integ tests > should work for all inputs
single map string to map
passed
0.614s
Integ tests > should work for all inputs
enum key in map
passed
0.821s
Integ tests > should work for all inputs
literal string union key in map
passed
0.819s
Integ tests > should work for all inputs
single literal string key in map
passed
0.92s
Integ tests > should work for all inputs
primitive union alias
passed
0.516s
Integ tests > should work for all inputs
map alias
passed
1.122s
Integ tests > should work for all inputs
alias union
passed
1.433s
Integ tests > should work for all inputs
alias pointing to recursive class
passed
0.817s
Integ tests > should work for all inputs
class pointing to alias that points to recursive class
passed
0.905s
Integ tests > should work for all inputs
recursive class with alias indirection
passed
0.735s
Integ tests > should work for all inputs
merge alias attributes
passed
0.669s
Integ tests > should work for all inputs
return alias with merged attrs
passed
0.422s
Integ tests > should work for all inputs
alias with multiple attrs
passed
0.472s
Integ tests > should work for all inputs
simple recursive map alias
passed
1.508s
Integ tests > should work for all inputs
simple recursive map alias
passed
0.499s
Integ tests > should work for all inputs
recursive alias cycles
passed
0.525s
Integ tests > should work for all inputs
json type alias cycle
failed
2.478s
Error: expect(received).toEqual(expected) // deep equality
-
-Expected: {"bool": true, "json": {"bool": true, "list": [1, 2, 3], "number": 1, "object": {"bool": true, "list": [1, 2, 3], "number": 1, "string": "test"}, "string": "test"}, "list": [1, 2, 3], "number": 1, "object": {"bool": true, "list": [1, 2, 3], "number": 1, "string": "test"}, "string": "test"}
-Received: "{number: 1, string: test, bool: true, list: [1, 2, 3], object: {number: 1, string: test, bool: true, list: [1, 2, 3]}, json: {number: 1, string: test, bool: true, list: [1, 2, 3], object: {number: 1, string: test, bool: true, list: [1, 2, 3]}}}"
-    at Object.toEqual (/workspaces/baml/integ-tests/typescript/tests/integ-tests.test.ts:234:19)
Integ tests
should work for all outputs
passed
4.88s
Integ tests
works with retries1
passed
1.131s
Integ tests
works with retries2
passed
2.247s
Integ tests
works with fallbacks
passed
2.048s
Integ tests
should work with image from url
passed
1.741s
Integ tests
should work with image from base 64
passed
1.571s
Integ tests
should work with audio base 64
passed
1.164s
Integ tests
should work with audio from url
passed
1.708s
Integ tests
should support streaming in OpenAI
passed
2.114s
Integ tests
should support streaming in Gemini
passed
11.139s
Integ tests
should support AWS
passed
1.553s
Integ tests
should support streaming in AWS
passed
1.608s
Integ tests
should allow overriding the region
passed
0.037s
Integ tests
should support OpenAI shorthand
passed
11.054s
Integ tests
should support OpenAI shorthand streaming
passed
10.68s
Integ tests
should support anthropic shorthand
passed
3.964s
Integ tests
should support anthropic shorthand streaming
passed
3.493s
Integ tests
should support streaming without iterating
passed
2.541s
Integ tests
should support streaming in Claude
passed
1.412s
Integ tests
should support vertex
failed
0.001s
Error: BamlError: Failed to read service account file: 
+

Test Report

Started: 2024-12-18 17:20:38
Suites (1)
0 passed
1 failed
0 pending
Tests (80)
78 passed
2 failed
0 pending
Integ tests > should work for all inputs
single bool
passed
0.39s
Integ tests > should work for all inputs
single string list
passed
0.484s
Integ tests > should work for all inputs
return literal union
passed
0.354s
Integ tests > should work for all inputs
single class
passed
0.41s
Integ tests > should work for all inputs
multiple classes
passed
0.399s
Integ tests > should work for all inputs
single enum list
passed
0.392s
Integ tests > should work for all inputs
single float
passed
0.356s
Integ tests > should work for all inputs
single int
passed
0.336s
Integ tests > should work for all inputs
single literal int
passed
0.314s
Integ tests > should work for all inputs
single literal bool
passed
0.388s
Integ tests > should work for all inputs
single literal string
passed
0.403s
Integ tests > should work for all inputs
single class with literal prop
passed
0.764s
Integ tests > should work for all inputs
single class with literal union prop
passed
0.512s
Integ tests > should work for all inputs
single optional string
passed
0.93s
Integ tests > should work for all inputs
single map string to string
passed
0.644s
Integ tests > should work for all inputs
single map string to class
passed
0.687s
Integ tests > should work for all inputs
single map string to map
passed
0.529s
Integ tests > should work for all inputs
enum key in map
passed
0.911s
Integ tests > should work for all inputs
literal string union key in map
passed
0.788s
Integ tests > should work for all inputs
single literal string key in map
passed
0.611s
Integ tests > should work for all inputs
primitive union alias
passed
0.657s
Integ tests > should work for all inputs
map alias
passed
1.155s
Integ tests > should work for all inputs
alias union
passed
1.45s
Integ tests > should work for all inputs
alias pointing to recursive class
passed
0.753s
Integ tests > should work for all inputs
class pointing to alias that points to recursive class
passed
1.114s
Integ tests > should work for all inputs
recursive class with alias indirection
passed
0.865s
Integ tests > should work for all inputs
merge alias attributes
passed
0.558s
Integ tests > should work for all inputs
return alias with merged attrs
passed
0.749s
Integ tests > should work for all inputs
alias with multiple attrs
passed
0.673s
Integ tests > should work for all inputs
simple recursive map alias
passed
0.701s
Integ tests > should work for all inputs
simple recursive map alias
passed
0.745s
Integ tests > should work for all inputs
recursive alias cycles
passed
0.498s
Integ tests > should work for all inputs
json type alias cycle
passed
2.621s
Integ tests
should work for all outputs
passed
4.553s
Integ tests
works with retries1
passed
1.172s
Integ tests
works with retries2
passed
2.265s
Integ tests
works with fallbacks
passed
1.893s
Integ tests
should work with image from url
passed
1.86s
Integ tests
should work with image from base 64
passed
1.766s
Integ tests
should work with audio base 64
passed
1.889s
Integ tests
should work with audio from url
passed
2.357s
Integ tests
should support streaming in OpenAI
passed
2.328s
Integ tests
should support streaming in Gemini
passed
8.829s
Integ tests
should support AWS
passed
1.549s
Integ tests
should support streaming in AWS
passed
1.623s
Integ tests
should allow overriding the region
passed
0.047s
Integ tests
should support OpenAI shorthand
passed
18.752s
Integ tests
should support OpenAI shorthand streaming
passed
13.463s
Integ tests
should support anthropic shorthand
passed
3.58s
Integ tests
should support anthropic shorthand streaming
passed
3.213s
Integ tests
should support streaming without iterating
passed
4.054s
Integ tests
should support streaming in Claude
passed
1.013s
Integ tests
should support vertex
failed
0.002s
Error: BamlError: Failed to read service account file: 
 
 Caused by:
-    No such file or directory (os error 2)
Integ tests
supports tracing sync
passed
0.017s
Integ tests
supports tracing async
passed
2.316s
Integ tests
should work with dynamic types single
passed
1.481s
Integ tests
should work with dynamic types enum
passed
0.99s
Integ tests
should work with dynamic literals
passed
0.954s
Integ tests
should work with dynamic types class
passed
1.434s
Integ tests
should work with dynamic inputs class
passed
0.511s
Integ tests
should work with dynamic inputs list
passed
0.52s
Integ tests
should work with dynamic output map
passed
0.66s
Integ tests
should work with dynamic output union
passed
1.791s
Integ tests
should work with nested classes
failed
0.104s
Error: BamlError: BamlClientError: Something went wrong with the LLM client: reqwest::Error { kind: Request, url: Url { scheme: "http", cannot_be_a_base: false, username: "", password: None, host: Some(Domain("localhost")), port: Some(11434), path: "/v1/chat/completions", query: None, fragment: None }, source: hyper_util::client::legacy::Error(Connect, ConnectError("tcp connect error", Os { code: 111, kind: ConnectionRefused, message: "Connection refused" })) }
+    No such file or directory (os error 2)
Integ tests
supports tracing sync
passed
0.015s
Integ tests
supports tracing async
passed
3.008s
Integ tests
should work with dynamic types single
passed
1.334s
Integ tests
should work with dynamic types enum
passed
0.981s
Integ tests
should work with dynamic literals
passed
1.422s
Integ tests
should work with dynamic types class
passed
0.984s
Integ tests
should work with dynamic inputs class
passed
0.553s
Integ tests
should work with dynamic inputs list
passed
0.654s
Integ tests
should work with dynamic output map
passed
0.672s
Integ tests
should work with dynamic output union
passed
1.875s
Integ tests
should work with nested classes
failed
0.107s
Error: BamlError: BamlClientError: Something went wrong with the LLM client: reqwest::Error { kind: Request, url: Url { scheme: "http", cannot_be_a_base: false, username: "", password: None, host: Some(Domain("localhost")), port: Some(11434), path: "/v1/chat/completions", query: None, fragment: None }, source: hyper_util::client::legacy::Error(Connect, ConnectError("tcp connect error", Os { code: 111, kind: ConnectionRefused, message: "Connection refused" })) }
     at BamlStream.parsed [as getFinalResponse] (/workspaces/baml/engine/language_client_typescript/stream.js:58:39)
-    at Object.<anonymous> (/workspaces/baml/integ-tests/typescript/tests/integ-tests.test.ts:687:19)
Integ tests
should work with dynamic client
passed
0.34s
Integ tests
should work with 'onLogEvent'
passed
1.79s
Integ tests
should work with a sync client
passed
0.417s
Integ tests
should raise an error when appropriate
passed
1.026s
Integ tests
should raise a BAMLValidationError
passed
0.512s
Integ tests
should reset environment variables correctly
passed
5.938s
Integ tests
should use aliases when serializing input objects - classes
passed
0.813s
Integ tests
should use aliases when serializing, but still have original keys in jinja
passed
0.824s
Integ tests
should use aliases when serializing input objects - enums
passed
0.719s
Integ tests
should use aliases when serializing input objects - lists
passed
0.407s
Integ tests
constraints: should handle checks in return types
passed
0.676s
Integ tests
constraints: should handle checks in returned unions
passed
1.07s
Integ tests
constraints: should handle block-level checks
passed
0.598s
Integ tests
constraints: should handle nested-block-level checks
passed
0.625s
Integ tests
simple recursive type
passed
3.175s
Integ tests
mutually recursive type
passed
2.458s
\ No newline at end of file + at Object.<anonymous> (/workspaces/baml/integ-tests/typescript/tests/integ-tests.test.ts:688:19)
Integ tests
should work with dynamic client
passed
0.39s
Integ tests
should work with 'onLogEvent'
passed
1.623s
Integ tests
should work with a sync client
passed
0.539s
Integ tests
should raise an error when appropriate
passed
0.971s
Integ tests
should raise a BAMLValidationError
passed
0.419s
Integ tests
should reset environment variables correctly
passed
1.657s
Integ tests
should use aliases when serializing input objects - classes
passed
1.005s
Integ tests
should use aliases when serializing, but still have original keys in jinja
passed
0.851s
Integ tests
should use aliases when serializing input objects - enums
passed
0.369s
Integ tests
should use aliases when serializing input objects - lists
passed
0.374s
Integ tests
constraints: should handle checks in return types
passed
0.737s
Integ tests
constraints: should handle checks in returned unions
passed
0.788s
Integ tests
constraints: should handle block-level checks
passed
0.531s
Integ tests
constraints: should handle nested-block-level checks
passed
0.605s
Integ tests
simple recursive type
passed
2.579s
Integ tests
mutually recursive type
passed
2.009s
\ No newline at end of file diff --git a/integ-tests/typescript/tests/integ-tests.test.ts b/integ-tests/typescript/tests/integ-tests.test.ts index ab43a50c3..3ba5b9745 100644 --- a/integ-tests/typescript/tests/integ-tests.test.ts +++ b/integ-tests/typescript/tests/integ-tests.test.ts @@ -232,6 +232,7 @@ describe('Integ tests', () => { } const res = await b.JsonTypeAliasCycle(data) expect(res).toEqual(data) + expect(res.json.object.list).toEqual([1, 2, 3]) }) }) From 8ff0397931021fcc5d45d4704e10e44fa44f1947 Mon Sep 17 00:00:00 2001 From: Antonio Sarosi Date: Wed, 18 Dec 2024 20:22:14 +0100 Subject: [PATCH 19/19] Fix scoring ranking whaterver --- .../src/deserializer/coercer/array_helper.rs | 31 ++++++++++- .../src/deserializer/coercer/field_type.rs | 15 ++++- .../jsonish/src/deserializer/score.rs | 2 +- .../jsonish/src/deserializer/types.rs | 16 ++++++ .../jsonish/src/tests/test_aliases.rs | 55 ++++++++++++++++++- .../jsonish/src/tests/test_constraints.rs | 14 ++--- .../baml-lib/jsonish/src/tests/test_maps.rs | 1 + 7 files changed, 122 insertions(+), 12 deletions(-) diff --git a/engine/baml-lib/jsonish/src/deserializer/coercer/array_helper.rs b/engine/baml-lib/jsonish/src/deserializer/coercer/array_helper.rs index 0f22e5bbf..5c1deb6a9 100644 --- a/engine/baml-lib/jsonish/src/deserializer/coercer/array_helper.rs +++ b/engine/baml-lib/jsonish/src/deserializer/coercer/array_helper.rs @@ -14,7 +14,13 @@ pub fn coerce_array_to_singular( ) -> Result { let parsed = items.iter().map(|item| coercion(item)).collect::>(); - pick_best(ctx, target, &parsed) + let mut best = pick_best(ctx, target, &parsed); + + if let Ok(ref mut f) = best { + f.add_flag(Flag::FirstMatch(0, parsed.to_vec())) + } + + best } pub(super) fn pick_best( @@ -180,6 +186,29 @@ pub(super) fn pick_best( } } + // Devalue strings that were cast from objects. + if !a_val.is_composite() && b_val.is_composite() { + if a_val + .conditions() + .flags() + .iter() + .any(|f| matches!(f, Flag::JsonToString(..) | Flag::FirstMatch(_, _))) + { + return std::cmp::Ordering::Greater; + } + } + + if a_val.is_composite() && !b_val.is_composite() { + if b_val + .conditions() + .flags() + .iter() + .any(|f| matches!(f, Flag::JsonToString(..) | Flag::FirstMatch(_, _))) + { + return std::cmp::Ordering::Less; + } + } + match a_default.cmp(&b_default) { std::cmp::Ordering::Equal => match a_score.cmp(&b_score) { std::cmp::Ordering::Equal => a.cmp(&b), diff --git a/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs b/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs index d0f984d1e..df7cf8048 100644 --- a/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs +++ b/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs @@ -27,7 +27,12 @@ impl TypeCoercer for FieldType { target: &FieldType, value: Option<&crate::jsonish::Value>, ) -> Result { - match value { + unsafe { + LIMIT += 1; + eprintln!("LIMIT: {}", LIMIT); + } + + let ret_v = match value { Some(crate::jsonish::Value::AnyOf(candidates, primitive)) => { log::debug!( "scope: {scope} :: coercing to: {name} (current: {current})", @@ -112,7 +117,15 @@ impl TypeCoercer for FieldType { Ok(coerced_value) } }, + }; + + unsafe { + LIMIT -= 1; } + + eprintln!("ret_v: {:?}", ret_v); + + ret_v } } diff --git a/engine/baml-lib/jsonish/src/deserializer/score.rs b/engine/baml-lib/jsonish/src/deserializer/score.rs index 17065f137..f198702bd 100644 --- a/engine/baml-lib/jsonish/src/deserializer/score.rs +++ b/engine/baml-lib/jsonish/src/deserializer/score.rs @@ -48,7 +48,7 @@ impl WithScore for Flag { Flag::StrippedNonAlphaNumeric(_) => 3, Flag::SubstringMatch(_) => 2, Flag::ImpliedKey(_) => 2, - Flag::JsonToString(_) => 5, + Flag::JsonToString(_) => 2, Flag::SingleToArray => 1, // Parsing errors are bad. Flag::ArrayItemParseError(x, _) => 1 + (*x as i32), diff --git a/engine/baml-lib/jsonish/src/deserializer/types.rs b/engine/baml-lib/jsonish/src/deserializer/types.rs index 1303ce7bc..fae0ee606 100644 --- a/engine/baml-lib/jsonish/src/deserializer/types.rs +++ b/engine/baml-lib/jsonish/src/deserializer/types.rs @@ -33,6 +33,22 @@ pub enum BamlValueWithFlags { } impl BamlValueWithFlags { + pub fn is_composite(&self) -> bool { + match self { + BamlValueWithFlags::String(_) => false, + BamlValueWithFlags::Int(_) => false, + BamlValueWithFlags::Float(_) => false, + BamlValueWithFlags::Bool(_) => false, + BamlValueWithFlags::Null(_) => false, + BamlValueWithFlags::Enum(_, value_with_flags) => false, + + BamlValueWithFlags::List(deserializer_conditions, vec) => true, + BamlValueWithFlags::Map(deserializer_conditions, index_map) => true, + BamlValueWithFlags::Class(_, deserializer_conditions, index_map) => true, + BamlValueWithFlags::Media(value_with_flags) => true, + } + } + pub fn score(&self) -> i32 { match self { BamlValueWithFlags::String(f) => f.score(), diff --git a/engine/baml-lib/jsonish/src/tests/test_aliases.rs b/engine/baml-lib/jsonish/src/tests/test_aliases.rs index 0b4a944eb..639bf949b 100644 --- a/engine/baml-lib/jsonish/src/tests/test_aliases.rs +++ b/engine/baml-lib/jsonish/src/tests/test_aliases.rs @@ -30,7 +30,7 @@ test_deserializer!( r#" type A = B type B = C -type C = A[] +type C = A[] "#, "[[], [], [[]]]", FieldType::RecursiveTypeAlias("A".into()), @@ -114,7 +114,7 @@ type JsonValue = int | bool | string | JsonValue[] | map test_deserializer!( test_full_json_with_nested_objects, r#" -type JsonValue = int | bool | string JsonValue[] | map +type JsonValue = JsonValue[] | map | int | bool | string "#, r#" { @@ -168,3 +168,54 @@ type JsonValue = int | bool | string JsonValue[] | map } } ); + +test_deserializer!( + test_list_of_json_objects, + r#" +type JsonValue = int | string | bool | JsonValue[] | map + "#, + r#" + [ + { + "number": 1, + "string": "test", + "bool": true, + "list": [1, 2, 3] + }, + { + "number": 1, + "string": "test", + "bool": true, + "list": [1, 2, 3] + } + ] + "#, + FieldType::RecursiveTypeAlias("JsonValue".into()), + [ + { + "number": 1, + "string": "test", + "bool": true, + "list": [1, 2, 3] + }, + { + "number": 1, + "string": "test", + "bool": true, + "list": [1, 2, 3] + } + ] +); + +test_deserializer!( + test_nested_list, + r#" +type JsonValue = int | float | bool | string | JsonValue[] | map + "#, + r#" + [[42.1]] + "#, + FieldType::RecursiveTypeAlias("JsonValue".into()), + // [[[[[[[[[[[[[[[[[[[[42]]]]]]]]]]]]]]]]]]]] + [[42.1]] +); diff --git a/engine/baml-lib/jsonish/src/tests/test_constraints.rs b/engine/baml-lib/jsonish/src/tests/test_constraints.rs index f543cd115..728c1bf67 100644 --- a/engine/baml-lib/jsonish/src/tests/test_constraints.rs +++ b/engine/baml-lib/jsonish/src/tests/test_constraints.rs @@ -16,7 +16,7 @@ test_deserializer_with_expected_score!( CLASS_FOO_INT_STRING, r#"{"age": 11, "name": "Greg"}"#, FieldType::Class("Foo".to_string()), - 0 + 1 ); test_deserializer_with_expected_score!( @@ -24,7 +24,7 @@ test_deserializer_with_expected_score!( CLASS_FOO_INT_STRING, r#"{"age": 21, "name": "Grog"}"#, FieldType::Class("Foo".to_string()), - 0 + 1 ); test_failing_deserializer!( @@ -61,7 +61,7 @@ test_deserializer_with_expected_score!( UNION_WITH_CHECKS, r#"{"bar": 5, "things":[]}"#, FieldType::Class("Either".to_string()), - 2 + 3 ); test_deserializer_with_expected_score!( @@ -69,7 +69,7 @@ test_deserializer_with_expected_score!( UNION_WITH_CHECKS, r#"{"bar": 15, "things":[]}"#, FieldType::Class("Either".to_string()), - 2 + 3 ); test_failing_deserializer!( @@ -90,7 +90,7 @@ test_deserializer_with_expected_score!( MAP_WITH_CHECKS, r#"{"foo": {"hello": 10, "there":13}}"#, FieldType::Class("Foo".to_string()), - 1 + 2 ); test_deserializer_with_expected_score!( @@ -98,7 +98,7 @@ test_deserializer_with_expected_score!( MAP_WITH_CHECKS, r#"{"foo": {"hello": 11, "there":13}}"#, FieldType::Class("Foo".to_string()), - 1 + 2 ); const NESTED_CLASS_CONSTRAINTS: &str = r#" @@ -116,7 +116,7 @@ test_deserializer_with_expected_score!( NESTED_CLASS_CONSTRAINTS, r#"{"inner": {"value": 15}}"#, FieldType::Class("Outer".to_string()), - 0 + 1 ); const BLOCK_LEVEL: &str = r#" diff --git a/engine/baml-lib/jsonish/src/tests/test_maps.rs b/engine/baml-lib/jsonish/src/tests/test_maps.rs index 906e4a1c5..5c1b1d641 100644 --- a/engine/baml-lib/jsonish/src/tests/test_maps.rs +++ b/engine/baml-lib/jsonish/src/tests/test_maps.rs @@ -166,6 +166,7 @@ fn test_union_of_map_and_class() { assert!(result.is_ok(), "Failed to parse: {:?}", result); let value = result.unwrap(); + dbg!(&value); assert!(matches!(value, BamlValueWithFlags::Class(_, _, _))); log::trace!("Score: {}", value.score());