From e0ae448f3bf412959483f9c784cdd38d7bcd6728 Mon Sep 17 00:00:00 2001 From: Antonio Sarosi Date: Tue, 17 Dec 2024 17:39:06 +0100 Subject: [PATCH] Fix test `relevant_data_models` --- .../src/deserializer/coercer/field_type.rs | 10 ------ engine/baml-lib/jsonish/src/tests/mod.rs | 34 +++++++++++++++---- .../jsonish/src/tests/test_aliases.rs | 25 ++++++++++++++ 3 files changed, 53 insertions(+), 16 deletions(-) create mode 100644 engine/baml-lib/jsonish/src/tests/test_aliases.rs diff --git a/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs b/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs index eed33ba0c..13ba6e9b2 100644 --- a/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs +++ b/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs @@ -14,8 +14,6 @@ use super::{ ParsingError, }; -static mut count: u32 = 0; - impl TypeCoercer for FieldType { fn coerce( &self, @@ -23,14 +21,6 @@ impl TypeCoercer for FieldType { target: &FieldType, value: Option<&crate::jsonish::Value>, ) -> Result { - unsafe { - eprintln!("{self:?} -> {target:?} -> {value:?}"); - count += 1; - if count == 20 { - panic!("FUCK"); - } - } - match value { Some(crate::jsonish::Value::AnyOf(candidates, primitive)) => { log::debug!( diff --git a/engine/baml-lib/jsonish/src/tests/mod.rs b/engine/baml-lib/jsonish/src/tests/mod.rs index 41a22ef4c..00da9d619 100644 --- a/engine/baml-lib/jsonish/src/tests/mod.rs +++ b/engine/baml-lib/jsonish/src/tests/mod.rs @@ -4,6 +4,7 @@ use internal_baml_jinja::types::{Class, Enum, Name, OutputFormatContent}; #[macro_use] pub mod macros; +mod test_aliases; mod test_basics; mod test_class; mod test_class_2; @@ -16,7 +17,7 @@ mod test_maps; mod test_partials; mod test_unions; -use indexmap::IndexSet; +use indexmap::{IndexMap, IndexSet}; use std::{ collections::{HashMap, HashSet}, path::PathBuf, @@ -56,12 +57,14 @@ fn render_output_format( output: &FieldType, env_values: &EvaluationContext<'_>, ) -> Result { - let (enums, classes, recursive_classes) = relevant_data_models(ir, output, env_values)?; + let (enums, classes, recursive_classes, structural_recursive_aliases) = + relevant_data_models(ir, output, env_values)?; Ok(OutputFormatContent::target(output.clone()) .enums(enums) .classes(classes) .recursive_classes(recursive_classes) + .structural_recursive_aliases(structural_recursive_aliases) .build()) } @@ -126,11 +129,17 @@ fn relevant_data_models<'a>( ir: &'a IntermediateRepr, output: &'a FieldType, env_values: &EvaluationContext<'_>, -) -> Result<(Vec, Vec, IndexSet)> { +) -> Result<( + Vec, + Vec, + IndexSet, + IndexMap, +)> { let mut checked_types: HashSet = HashSet::new(); let mut enums = Vec::new(); let mut classes: Vec = Vec::new(); let mut recursive_classes = IndexSet::new(); + let mut structural_recursive_aliases = IndexMap::new(); let mut start: Vec = vec![output.clone()]; while let Some(output) = start.pop() { @@ -230,8 +239,16 @@ fn relevant_data_models<'a>( }); } } - // TODO: Add structural aliases here. - (FieldType::RecursiveTypeAlias(_), _) => {} + (FieldType::RecursiveTypeAlias(name), _) => { + // TODO: Same O(n) problem as above. + for cycle in ir.structural_recursive_alias_cycles() { + if cycle.contains_key(name) { + for (alias, target) in cycle.iter() { + structural_recursive_aliases.insert(alias.to_owned(), target.clone()); + } + } + } + } (FieldType::Literal(_), _) => {} (FieldType::Primitive(_), _constraints) => {} (FieldType::Constrained { .. }, _) => { @@ -240,7 +257,12 @@ fn relevant_data_models<'a>( } } - Ok((enums, classes, recursive_classes)) + Ok(( + enums, + classes, + recursive_classes, + structural_recursive_aliases, + )) } const EMPTY_FILE: &str = r#" diff --git a/engine/baml-lib/jsonish/src/tests/test_aliases.rs b/engine/baml-lib/jsonish/src/tests/test_aliases.rs new file mode 100644 index 000000000..061aab3cb --- /dev/null +++ b/engine/baml-lib/jsonish/src/tests/test_aliases.rs @@ -0,0 +1,25 @@ +use baml_types::LiteralValue; + +use super::*; + +test_deserializer!( + test_simple_recursive_alias_list, + r#" +type A = A[] + "#, + "[[], [], [[]]]", + FieldType::RecursiveTypeAlias("A".into()), + [[], [], [[]]] +); + +test_deserializer!( + test_recursive_alias_cycle, + r#" +type A = B +type B = C +type C = A[] + "#, + "[[], [], [[]]]", + FieldType::RecursiveTypeAlias("A".into()), + [[], [], [[]]] +);