From 9bd01b8ec16d8eccf380100f52a358b097ddf678 Mon Sep 17 00:00:00 2001 From: Antonio Sarosi Date: Thu, 12 Dec 2024 21:06:40 +0100 Subject: [PATCH] Fix `FieldType::Contrained` prompt rendering with aliases --- .../jinja-runtime/src/output_format/types.rs | 20 ++++--- engine/baml-runtime/tests/test_runtime.rs | 54 +++++++++++++++++++ 2 files changed, 68 insertions(+), 6 deletions(-) diff --git a/engine/baml-lib/jinja-runtime/src/output_format/types.rs b/engine/baml-lib/jinja-runtime/src/output_format/types.rs index d08b929c5..f3e357ba2 100644 --- a/engine/baml-lib/jinja-runtime/src/output_format/types.rs +++ b/engine/baml-lib/jinja-runtime/src/output_format/types.rs @@ -426,9 +426,12 @@ impl OutputFormatContent { } }, FieldType::Literal(v) => v.to_string(), - FieldType::Constrained { base, .. } => { - self.inner_type_render(options, base, render_state, group_hoisted_literals)? - } + FieldType::Constrained { base, .. } => self.render_possibly_recursive_type( + options, + base, + render_state, + group_hoisted_literals, + )?, FieldType::Enum(e) => { let Some(enm) = self.enums.get(e) else { return Err(minijinja::Error::new( @@ -536,9 +539,14 @@ impl OutputFormatContent { } FieldType::Map(key_type, value_type) => MapRender { style: &options.map_style, - // TODO: Key can't be recursive because we only support strings - // as keys. Change this if needed in the future. - key_type: self.inner_type_render(options, key_type, render_state, false)?, + // NOTE: Key can't be recursive because we only support strings + // as keys. + key_type: self.render_possibly_recursive_type( + options, + key_type, + render_state, + false, + )?, value_type: self.render_possibly_recursive_type( options, value_type, diff --git a/engine/baml-runtime/tests/test_runtime.rs b/engine/baml-runtime/tests/test_runtime.rs index e10a957a4..79efe7151 100644 --- a/engine/baml-runtime/tests/test_runtime.rs +++ b/engine/baml-runtime/tests/test_runtime.rs @@ -498,4 +498,58 @@ test TestTree { Ok(()) } + + #[test] + fn test_constrained_type_alias() -> anyhow::Result<()> { + let runtime = make_test_runtime( + r##" +class Foo2 { + bar int + baz string + sub Subthing @assert( {{ this.bar == 10}} ) | null +} + +class Foo3 { + bar int + baz string + sub Foo3 | null +} + +type Subthing = Foo2 @assert( {{ this.bar == 10 }}) + +function RunFoo2(input: Foo3) -> Foo2 { + client "openai/gpt-4o" + prompt #"Generate a Foo2 wrapping 30. Use {{ input }}. + {{ ctx.output_format }} + "# +} + +test RunFoo2Test { + functions [RunFoo2] + args { + input { + bar 30 + baz "hello" + sub null + } + } +} + "##, + )?; + + let ctx = runtime + .create_ctx_manager(BamlValue::String("test".to_string()), None) + .create_ctx_with_default(); + + let function_name = "RunFoo2"; + let test_name = "RunFoo2Test"; + let params = runtime.get_test_params(function_name, test_name, &ctx, true)?; + let render_prompt_future = + runtime + .internal() + .render_prompt(function_name, &ctx, ¶ms, None); + let (prompt, scope, _) = runtime.async_runtime.block_on(render_prompt_future)?; + + Ok(()) + } }