From 68b98b704cae52861acc743d7913621152bbf1d7 Mon Sep 17 00:00:00 2001 From: Antonio Sarosi Date: Tue, 17 Dec 2024 02:16:46 +0100 Subject: [PATCH] Fix structural cycles rendering --- .../jinja-runtime/src/output_format/types.rs | 113 +++++++++++++++--- .../baml-schema-wasm/src/runtime_wasm/mod.rs | 4 +- 2 files changed, 99 insertions(+), 18 deletions(-) diff --git a/engine/baml-lib/jinja-runtime/src/output_format/types.rs b/engine/baml-lib/jinja-runtime/src/output_format/types.rs index bf5bb6045..188f37035 100644 --- a/engine/baml-lib/jinja-runtime/src/output_format/types.rs +++ b/engine/baml-lib/jinja-runtime/src/output_format/types.rs @@ -312,7 +312,6 @@ fn indefinite_article_a_or_an(word: &str) -> &str { struct RenderState { hoisted_enums: IndexSet, - hoisted_aliases: IndexMap, } impl OutputFormatContent { @@ -512,6 +511,9 @@ impl OutputFormatContent { FieldType::List(inner) => { let is_recursive = match inner.as_ref() { FieldType::Class(nested_class) => self.recursive_classes.contains(nested_class), + FieldType::RecursiveTypeAlias(name) => { + self.structural_recursive_aliases.contains_key(name) + } _ => false, }; @@ -582,7 +584,6 @@ impl OutputFormatContent { let mut render_state = RenderState { hoisted_enums: IndexSet::new(), - hoisted_aliases: IndexMap::new(), }; let mut message = match &self.target { @@ -636,18 +637,9 @@ impl OutputFormatContent { }); } - // Yeah gotta love the borrow checker. // TODO: @antonio - let hoisted_aliases = - std::mem::replace(&mut render_state.hoisted_aliases, Default::default()); - - for (alias, target) in hoisted_aliases.iter() { - let recursive_pointer = self.inner_type_render( - &options, - // TODO: @antonio This code sucks beyond measure. Fix this. - &target, - &mut render_state, - false, - )?; + for (alias, target) in self.structural_recursive_aliases.iter() { + let recursive_pointer = + self.inner_type_render(&options, target, &mut render_state, false)?; type_alias_definitions.push(match &options.hoisted_class_prefix { RenderSetting::Always(prefix) if !prefix.is_empty() => { @@ -670,7 +662,7 @@ impl OutputFormatContent { } if !type_alias_definitions.is_empty() { - output.push_str(&type_alias_definitions.join("\n\n")); + output.push_str(&type_alias_definitions.join("\n")); output.push_str("\n\n"); } @@ -2281,4 +2273,95 @@ Answer in JSON using this schema: )) ); } + + #[test] + fn render_simple_recursive_aliases() { + let content = OutputFormatContent::target(FieldType::RecursiveTypeAlias( + "RecursiveMapAlias".to_string(), + )) + .structural_recursive_aliases(IndexMap::from([( + "RecursiveMapAlias".to_string(), + FieldType::map( + FieldType::string(), + FieldType::RecursiveTypeAlias("RecursiveMapAlias".to_string()), + ), + )])) + .build(); + let rendered = content.render(RenderOptions::default()).unwrap(); + #[rustfmt::skip] + assert_eq!( + rendered, + Some(String::from( +r#"RecursiveMapAlias = map + +Answer in JSON using this schema: RecursiveMapAlias"# + )) + ); + } + + #[test] + fn render_recursive_alias_cycle() { + let content = OutputFormatContent::target(FieldType::RecursiveTypeAlias("A".to_string())) + .structural_recursive_aliases(IndexMap::from([ + ( + "A".to_string(), + FieldType::RecursiveTypeAlias("B".to_string()), + ), + ( + "B".to_string(), + FieldType::RecursiveTypeAlias("C".to_string()), + ), + ( + "C".to_string(), + FieldType::list(FieldType::RecursiveTypeAlias("A".to_string())), + ), + ])) + .build(); + let rendered = content.render(RenderOptions::default()).unwrap(); + #[rustfmt::skip] + assert_eq!( + rendered, + Some(String::from( +r#"A = B +B = C +C = A[] + +Answer in JSON using this schema: A"# + )) + ); + } + + #[test] + fn render_recursive_alias_cycle_with_hoist_prefix() { + let content = OutputFormatContent::target(FieldType::RecursiveTypeAlias("A".to_string())) + .structural_recursive_aliases(IndexMap::from([ + ( + "A".to_string(), + FieldType::RecursiveTypeAlias("B".to_string()), + ), + ( + "B".to_string(), + FieldType::RecursiveTypeAlias("C".to_string()), + ), + ( + "C".to_string(), + FieldType::list(FieldType::RecursiveTypeAlias("A".to_string())), + ), + ])) + .build(); + let rendered = content + .render(RenderOptions::with_hoisted_class_prefix("type")) + .unwrap(); + #[rustfmt::skip] + assert_eq!( + rendered, + Some(String::from( +r#"type A = B +type B = C +type C = A[] + +Answer in JSON using this type: A"# + )) + ); + } } diff --git a/engine/baml-schema-wasm/src/runtime_wasm/mod.rs b/engine/baml-schema-wasm/src/runtime_wasm/mod.rs index f21641107..614b19fbc 100644 --- a/engine/baml-schema-wasm/src/runtime_wasm/mod.rs +++ b/engine/baml-schema-wasm/src/runtime_wasm/mod.rs @@ -876,9 +876,7 @@ fn get_dummy_value( baml_runtime::FieldType::Literal(_) => None, baml_runtime::FieldType::Enum(_) => None, baml_runtime::FieldType::Class(_) => None, - baml_runtime::FieldType::Alias { resolution, .. } => { - get_dummy_value(indent, allow_multiline, &resolution) - } + baml_runtime::FieldType::RecursiveTypeAlias(_) => None, baml_runtime::FieldType::List(item) => { let dummy = get_dummy_value(indent + 1, allow_multiline, item); // Repeat it 2 times