Skip to content

Commit

Permalink
Coerce is wonky
Browse files Browse the repository at this point in the history
  • Loading branch information
antoniosarosi committed Dec 17, 2024
1 parent 68b98b7 commit d462e5c
Show file tree
Hide file tree
Showing 14 changed files with 621 additions and 3 deletions.
10 changes: 10 additions & 0 deletions engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,23 @@ use super::{
ParsingError,
};

static mut count: u32 = 0;

impl TypeCoercer for FieldType {
fn coerce(
&self,
ctx: &ParsingContext,
target: &FieldType,
value: Option<&crate::jsonish::Value>,
) -> Result<BamlValueWithFlags, ParsingError> {
unsafe {
eprintln!("{self:?} -> {target:?} -> {value:?}");
count += 1;
if count == 20 {
panic!("FUCK");
}
}

match value {
Some(crate::jsonish::Value::AnyOf(candidates, primitive)) => {
log::debug!(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
mod coerce_class;
pub mod coerce_enum;

use core::panic;

use anyhow::Result;
use internal_baml_core::ir::FieldType;

Expand Down
48 changes: 48 additions & 0 deletions engine/baml-runtime/tests/test_runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -552,4 +552,52 @@ test RunFoo2Test {

Ok(())
}

#[test]
fn test_recursive_alias_cycle() -> anyhow::Result<()> {
let runtime = make_test_runtime(
r##"
type RecAliasOne = RecAliasTwo
type RecAliasTwo = RecAliasThree
type RecAliasThree = RecAliasOne[]
function RecursiveAliasCycle(input: RecAliasOne) -> RecAliasOne {
client "openai/gpt-4o"
prompt r#"
Return the given value:
{{ input }}
{{ ctx.output_format }}
"#
}
test RecursiveAliasCycle {
functions [RecursiveAliasCycle]
args {
input [
[]
[]
[[], []]
]
}
}
"##,
)?;

let ctx = runtime
.create_ctx_manager(BamlValue::String("test".to_string()), None)
.create_ctx_with_default();

let function_name = "RecursiveAliasCycle";
let test_name = "RecursiveAliasCycle";
let params = runtime.get_test_params(function_name, test_name, &ctx, true)?;
let render_prompt_future =
runtime
.internal()
.render_prompt(function_name, &ctx, &params, None);
let (prompt, scope, _) = runtime.async_runtime.block_on(render_prompt_future)?;

Ok(())
}
}
28 changes: 28 additions & 0 deletions integ-tests/baml_src/test-files/functions/output/type-aliases.baml
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,34 @@ function SimpleRecursiveMapAlias(input: RecursiveMapAlias) -> RecursiveMapAlias

{{ input }}

{{ ctx.output_format }}
"#
}

type RecursiveListAlias = RecursiveListAlias[]

function SimpleRecursiveListAlias(input: RecursiveListAlias) -> RecursiveListAlias {
client "openai/gpt-4o"
prompt r#"
Return the given JSON array:

{{ input }}

{{ ctx.output_format }}
"#
}

type RecAliasOne = RecAliasTwo
type RecAliasTwo = RecAliasThree
type RecAliasThree = RecAliasOne[]

function RecursiveAliasCycle(input: RecAliasOne) -> RecAliasOne {
client "openai/gpt-4o"
prompt r#"
Return the given JSON array:

{{ input }}

{{ ctx.output_format }}
"#
}
106 changes: 106 additions & 0 deletions integ-tests/python/baml_client/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1867,6 +1867,29 @@ async def PromptTestStreaming(
)
return cast(str, raw.cast_to(types, types))

async def RecursiveAliasCycle(
self,
input: types.RecAliasOne,
baml_options: BamlCallOptions = {},
) -> types.RecAliasOne:
__tb__ = baml_options.get("tb", None)
if __tb__ is not None:
tb = __tb__._tb # type: ignore (we know how to use this private attribute)
else:
tb = None
__cr__ = baml_options.get("client_registry", None)

raw = await self.__runtime.call_function(
"RecursiveAliasCycle",
{
"input": input,
},
self.__ctx_manager.get(),
tb,
__cr__,
)
return cast(types.RecAliasOne, raw.cast_to(types, types))

async def RecursiveClassWithAliasIndirection(
self,
cls: types.NodeWithAliasIndirection,
Expand Down Expand Up @@ -1982,6 +2005,29 @@ async def SchemaDescriptions(
)
return cast(types.Schema, raw.cast_to(types, types))

async def SimpleRecursiveListAlias(
self,
input: types.RecursiveListAlias,
baml_options: BamlCallOptions = {},
) -> types.RecursiveListAlias:
__tb__ = baml_options.get("tb", None)
if __tb__ is not None:
tb = __tb__._tb # type: ignore (we know how to use this private attribute)
else:
tb = None
__cr__ = baml_options.get("client_registry", None)

raw = await self.__runtime.call_function(
"SimpleRecursiveListAlias",
{
"input": input,
},
self.__ctx_manager.get(),
tb,
__cr__,
)
return cast(types.RecursiveListAlias, raw.cast_to(types, types))

async def SimpleRecursiveMapAlias(
self,
input: types.RecursiveMapAlias,
Expand Down Expand Up @@ -5380,6 +5426,36 @@ def PromptTestStreaming(
self.__ctx_manager.get(),
)

def RecursiveAliasCycle(
self,
input: types.RecAliasOne,
baml_options: BamlCallOptions = {},
) -> baml_py.BamlStream[types.RecAliasOne, types.RecAliasOne]:
__tb__ = baml_options.get("tb", None)
if __tb__ is not None:
tb = __tb__._tb # type: ignore (we know how to use this private attribute)
else:
tb = None
__cr__ = baml_options.get("client_registry", None)

raw = self.__runtime.stream_function(
"RecursiveAliasCycle",
{
"input": input,
},
None,
self.__ctx_manager.get(),
tb,
__cr__,
)

return baml_py.BamlStream[types.RecAliasOne, types.RecAliasOne](
raw,
lambda x: cast(types.RecAliasOne, x.cast_to(types, partial_types)),
lambda x: cast(types.RecAliasOne, x.cast_to(types, types)),
self.__ctx_manager.get(),
)

def RecursiveClassWithAliasIndirection(
self,
cls: types.NodeWithAliasIndirection,
Expand Down Expand Up @@ -5530,6 +5606,36 @@ def SchemaDescriptions(
self.__ctx_manager.get(),
)

def SimpleRecursiveListAlias(
self,
input: types.RecursiveListAlias,
baml_options: BamlCallOptions = {},
) -> baml_py.BamlStream[types.RecursiveListAlias, types.RecursiveListAlias]:
__tb__ = baml_options.get("tb", None)
if __tb__ is not None:
tb = __tb__._tb # type: ignore (we know how to use this private attribute)
else:
tb = None
__cr__ = baml_options.get("client_registry", None)

raw = self.__runtime.stream_function(
"SimpleRecursiveListAlias",
{
"input": input,
},
None,
self.__ctx_manager.get(),
tb,
__cr__,
)

return baml_py.BamlStream[types.RecursiveListAlias, types.RecursiveListAlias](
raw,
lambda x: cast(types.RecursiveListAlias, x.cast_to(types, partial_types)),
lambda x: cast(types.RecursiveListAlias, x.cast_to(types, types)),
self.__ctx_manager.get(),
)

def SimpleRecursiveMapAlias(
self,
input: types.RecursiveMapAlias,
Expand Down
2 changes: 1 addition & 1 deletion integ-tests/python/baml_client/inlinedbaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@
"test-files/functions/output/recursive-type-aliases.baml": "class LinkedListAliasNode {\n value int\n next LinkedListAliasNode?\n}\n\n// Simple alias that points to recursive type.\ntype LinkedListAlias = LinkedListAliasNode\n\nfunction AliasThatPointsToRecursiveType(list: LinkedListAlias) -> LinkedListAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given linked list back:\n \n {{ list }}\n \n {{ ctx.output_format }}\n \"#\n}\n\n// Class that points to an alias that points to a recursive type.\nclass ClassToRecAlias {\n list LinkedListAlias\n}\n\nfunction ClassThatPointsToRecursiveClassThroughAlias(cls: ClassToRecAlias) -> ClassToRecAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given object back:\n \n {{ cls }}\n \n {{ ctx.output_format }}\n \"#\n}\n\n// This is tricky cause this class should be hoisted, but classes and aliases\n// are two different types in the AST. This test will make sure they can interop.\nclass NodeWithAliasIndirection {\n value int\n next NodeIndirection?\n}\n\ntype NodeIndirection = NodeWithAliasIndirection\n\nfunction RecursiveClassWithAliasIndirection(cls: NodeWithAliasIndirection) -> NodeWithAliasIndirection {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given object back:\n \n {{ cls }}\n \n {{ ctx.output_format }}\n \"#\n}\n",
"test-files/functions/output/serialization-error.baml": "class DummyOutput {\n nonce string\n nonce2 string\n @@dynamic\n}\n\nfunction DummyOutputFunction(input: string) -> DummyOutput {\n client GPT35\n prompt #\"\n Say \"hello there\".\n \"#\n}",
"test-files/functions/output/string-list.baml": "function FnOutputStringList(input: string) -> string[] {\n client GPT35\n prompt #\"\n Return a list of strings in json format like [\"string1\", \"string2\", \"string3\"].\n\n JSON:\n \"#\n}\n\ntest FnOutputStringList {\n functions [FnOutputStringList]\n args {\n input \"example input\"\n }\n}\n",
"test-files/functions/output/type-aliases.baml": "type Primitive = int | string | bool | float\n\ntype List = string[]\n\ntype Graph = map<string, string[]>\n\ntype Combination = Primitive | List | Graph\n\nfunction PrimitiveAlias(p: Primitive) -> Primitive {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back: {{ p }}\n \"#\n}\n\nfunction MapAlias(m: Graph) -> Graph {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given Graph back:\n\n {{ m }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction NestedAlias(c: Combination) -> Combination {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back:\n\n {{ c }}\n\n {{ ctx.output_format }}\n \"#\n}\n\n// Test attribute merging.\ntype Currency = int @check(gt_ten, {{ this > 10 }})\ntype Amount = Currency @assert({{ this > 0 }})\n\nclass MergeAttrs {\n amount Amount @description(\"In USD\")\n}\n\n// This should be allowed.\ntype MultipleAttrs = int @assert({{ this > 0 }}) @check(gt_ten, {{ this > 10 }})\n\nfunction MergeAliasAttributes(money: int) -> MergeAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer in the specified format:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction ReturnAliasWithMergedAttributes(money: Amount) -> Amount {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction AliasWithMultipleAttrs(money: MultipleAttrs) -> MultipleAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecursiveMapAlias = map<string, RecursiveMapAlias>\n\nfunction SimpleRecursiveMapAlias(input: RecursiveMapAlias) -> RecursiveMapAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}",
"test-files/functions/output/type-aliases.baml": "type Primitive = int | string | bool | float\n\ntype List = string[]\n\ntype Graph = map<string, string[]>\n\ntype Combination = Primitive | List | Graph\n\nfunction PrimitiveAlias(p: Primitive) -> Primitive {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back: {{ p }}\n \"#\n}\n\nfunction MapAlias(m: Graph) -> Graph {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given Graph back:\n\n {{ m }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction NestedAlias(c: Combination) -> Combination {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back:\n\n {{ c }}\n\n {{ ctx.output_format }}\n \"#\n}\n\n// Test attribute merging.\ntype Currency = int @check(gt_ten, {{ this > 10 }})\ntype Amount = Currency @assert({{ this > 0 }})\n\nclass MergeAttrs {\n amount Amount @description(\"In USD\")\n}\n\n// This should be allowed.\ntype MultipleAttrs = int @assert({{ this > 0 }}) @check(gt_ten, {{ this > 10 }})\n\nfunction MergeAliasAttributes(money: int) -> MergeAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer in the specified format:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction ReturnAliasWithMergedAttributes(money: Amount) -> Amount {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction AliasWithMultipleAttrs(money: MultipleAttrs) -> MultipleAttrs {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given integer without additional context:\n\n {{ money }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecursiveMapAlias = map<string, RecursiveMapAlias>\n\nfunction SimpleRecursiveMapAlias(input: RecursiveMapAlias) -> RecursiveMapAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecursiveListAlias = RecursiveListAlias[]\n\nfunction SimpleRecursiveListAlias(input: RecursiveListAlias) -> RecursiveListAlias {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given JSON array:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\ntype RecAliasOne = RecAliasTwo\ntype RecAliasTwo = RecAliasThree\ntype RecAliasThree = RecAliasOne[]\n\nfunction RecursiveAliasCycle(input: RecAliasOne) -> RecAliasOne {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given JSON array:\n\n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}",
"test-files/functions/output/unions.baml": "class UnionTest_ReturnType {\n prop1 string | bool\n prop2 (float | bool)[]\n prop3 (bool[] | int[])\n}\n\nfunction UnionTest_Function(input: string | bool) -> UnionTest_ReturnType {\n client GPT35\n prompt #\"\n Return a JSON blob with this schema: \n {{ctx.output_format}}\n\n JSON:\n \"#\n}\n\ntest UnionTest_Function {\n functions [UnionTest_Function]\n args {\n input \"example input\"\n }\n}\n",
"test-files/functions/prompts/no-chat-messages.baml": "\n\nfunction PromptTestClaude(input: string) -> string {\n client Sonnet\n prompt #\"\n Tell me a haiku about {{ input }}\n \"#\n}\n\n\nfunction PromptTestStreaming(input: string) -> string {\n client GPT35\n prompt #\"\n Tell me a short story about {{ input }}\n \"#\n}\n\ntest TestName {\n functions [PromptTestStreaming]\n args {\n input #\"\n hello world\n \"#\n }\n}\n",
"test-files/functions/prompts/with-chat-messages.baml": "\nfunction PromptTestOpenAIChat(input: string) -> string {\n client GPT35\n prompt #\"\n {{ _.role(\"system\") }}\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestOpenAIChatNoSystem(input: string) -> string {\n client GPT35\n prompt #\"\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestClaudeChat(input: string) -> string {\n client Claude\n prompt #\"\n {{ _.role(\"system\") }}\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestClaudeChatNoSystem(input: string) -> string {\n client Claude\n prompt #\"\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\ntest TestSystemAndNonSystemChat1 {\n functions [PromptTestClaude, PromptTestOpenAI, PromptTestOpenAIChat, PromptTestOpenAIChatNoSystem, PromptTestClaudeChat, PromptTestClaudeChatNoSystem]\n args {\n input \"cats\"\n }\n}\n\ntest TestSystemAndNonSystemChat2 {\n functions [PromptTestClaude, PromptTestOpenAI, PromptTestOpenAIChat, PromptTestOpenAIChatNoSystem, PromptTestClaudeChat, PromptTestClaudeChatNoSystem]\n args {\n input \"lion\"\n }\n}",
Expand Down
Loading

0 comments on commit d462e5c

Please sign in to comment.