Skip to content

Commit

Permalink
Codegen works! (probably not)
Browse files Browse the repository at this point in the history
  • Loading branch information
antoniosarosi committed Nov 22, 2024
1 parent 28907de commit 29bed6b
Show file tree
Hide file tree
Showing 16 changed files with 818 additions and 13 deletions.
4 changes: 3 additions & 1 deletion engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,9 @@ impl ArgCoercer {
Err(())
}
},
(FieldType::Alias { .. }, _) => todo!(),
(FieldType::Alias { resolution, .. }, _) => {
self.coerce_arg(ir, &resolution, value, scope)
}
(FieldType::List(item), _) => match value {
BamlValue::List(arr) => {
let mut items = Vec::new();
Expand Down
2 changes: 1 addition & 1 deletion engine/baml-lib/baml-core/src/ir/repr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ impl WithRepr<FieldType> for ast::FieldType {
Some(TypeWalker::TypeAlias(alias_walker)) => FieldType::Alias {
name: alias_walker.name().to_owned(),
target: Box::new(alias_walker.target().repr(db)?),
resolution: Box::new(FieldType::int()), // TODO
resolution: Box::new(alias_walker.resolved().repr(db)?),
},
None => return Err(anyhow!("Field type uses unresolvable local identifier")),
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ type Graph = map<string, string[]>

type Combination = Primitive | List | Graph

function AliasPrimitive(p: Primitive) -> Primitive {
function PrimitiveAlias(p: Primitive) -> Primitive {
client "openai/gpt-4o"
prompt r#"
Return the given value back: {{ p }}
Expand Down
3 changes: 1 addition & 2 deletions engine/baml-lib/jsonish/src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,7 @@ fn relevant_data_models<'a>(
let mut recursive_classes = IndexSet::new();
let mut start: Vec<baml_types::FieldType> = vec![output.clone()];

while !start.is_empty() {
let output = start.pop().unwrap();
while let Some(output) = start.pop() {
match ir.distribute_constraints(&output) {
(FieldType::Enum(enm), constraints) => {
if checked_types.insert(output.to_string()) {
Expand Down
3 changes: 1 addition & 2 deletions engine/baml-lib/parser-database/src/names/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ pub(super) struct Names {
/// Tests have their own namespace.
pub(super) tests: HashMap<StringId, HashMap<StringId, TopId>>,
pub(super) model_fields: HashMap<(ast::TypeExpId, StringId), ast::FieldId>,
pub(super) type_aliases: HashMap<ast::TypeExpId, Option<StringId>>,
// pub(super) composite_type_fields: HashMap<(ast::CompositeTypeId, StringId), ast::FieldId>,
}

Expand Down Expand Up @@ -95,7 +94,7 @@ pub(super) fn resolve_names(ctx: &mut Context<'_>) {
(ast::TopId::TypeAlias(_), ast::Top::TypeAlias(type_alias)) => {
validate_type_alias_name(type_alias, ctx.diagnostics);

let type_alias_id = ctx.interner.intern(type_alias.name());
ctx.interner.intern(type_alias.name());

Some(either::Left(&mut names.tops))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,9 @@ fn relevant_data_models<'a>(
recursive_classes.insert(cls.to_owned());
}
}
(FieldType::Alias { .. }, _) => todo!(),
(FieldType::Alias { resolution, .. }, _) => {
start.push(*resolution.clone());
}
(FieldType::Literal(_), _) => {}
(FieldType::Primitive(_), _) => {}
(FieldType::Constrained { .. }, _) => {
Expand Down
6 changes: 1 addition & 5 deletions engine/baml-schema-wasm/src/runtime_wasm/runtime_prompt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,12 @@ use baml_runtime::{
},
ChatMessagePart, RenderedPrompt,
};
use serde::Serialize;
use serde_json::json;

use crate::runtime_wasm::ToJsValue;
use baml_types::{BamlMedia, BamlMediaContent, BamlMediaType, MediaBase64};
use baml_types::{BamlMediaContent, BamlMediaType, MediaBase64};
use serde_wasm_bindgen::to_value;
use wasm_bindgen::prelude::*;

use super::WasmFunction;

#[wasm_bindgen(getter_with_clone)]
pub struct WasmScope {
scope: OrchestrationScope,
Expand Down
36 changes: 36 additions & 0 deletions integ-tests/baml_src/test-files/functions/output/type-aliases.baml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
type Primitive = int | string | bool | float

type List = string[]

type Graph = map<string, string[]>

type Combination = Primitive | List | Graph

function PrimitiveAlias(p: Primitive) -> Primitive {
client "openai/gpt-4o"
prompt r#"
Return the given value back: {{ p }}
"#
}

function MapAlias(m: Graph) -> Graph {
client "openai/gpt-4o"
prompt r#"
Return the given Graph back:

{{ m }}

{{ ctx.output_format }}
"#
}

function NestedAlias(c: Combination) -> Combination {
client "openai/gpt-4o"
prompt r#"
Return the given value back:

{{ c }}

{{ ctx.output_format }}
"#
}
159 changes: 159 additions & 0 deletions integ-tests/python/baml_client/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1453,6 +1453,29 @@ async def MakeNestedBlockConstraint(
)
return cast(types.NestedBlockConstraint, raw.cast_to(types, types))

async def MapAlias(
self,
m: Dict[str, List[str]],
baml_options: BamlCallOptions = {},
) -> Dict[str, List[str]]:
__tb__ = baml_options.get("tb", None)
if __tb__ is not None:
tb = __tb__._tb # type: ignore (we know how to use this private attribute)
else:
tb = None
__cr__ = baml_options.get("client_registry", None)

raw = await self.__runtime.call_function(
"MapAlias",
{
"m": m,
},
self.__ctx_manager.get(),
tb,
__cr__,
)
return cast(Dict[str, List[str]], raw.cast_to(types, types))

async def MyFunc(
self,
input: str,
Expand All @@ -1476,6 +1499,29 @@ async def MyFunc(
)
return cast(types.DynamicOutput, raw.cast_to(types, types))

async def NestedAlias(
self,
c: Union[Union[int, str, bool, float], List[str], Dict[str, List[str]]],
baml_options: BamlCallOptions = {},
) -> Union[Union[int, str, bool, float], List[str], Dict[str, List[str]]]:
__tb__ = baml_options.get("tb", None)
if __tb__ is not None:
tb = __tb__._tb # type: ignore (we know how to use this private attribute)
else:
tb = None
__cr__ = baml_options.get("client_registry", None)

raw = await self.__runtime.call_function(
"NestedAlias",
{
"c": c,
},
self.__ctx_manager.get(),
tb,
__cr__,
)
return cast(Union[Union[int, str, bool, float], List[str], Dict[str, List[str]]], raw.cast_to(types, types))

async def OptionalTest_Function(
self,
input: str,
Expand Down Expand Up @@ -1545,6 +1591,29 @@ async def PredictAgeBare(
)
return cast(Checked[int,types.Literal["too_big"]], raw.cast_to(types, types))

async def PrimitiveAlias(
self,
p: Union[int, str, bool, float],
baml_options: BamlCallOptions = {},
) -> Union[int, str, bool, float]:
__tb__ = baml_options.get("tb", None)
if __tb__ is not None:
tb = __tb__._tb # type: ignore (we know how to use this private attribute)
else:
tb = None
__cr__ = baml_options.get("client_registry", None)

raw = await self.__runtime.call_function(
"PrimitiveAlias",
{
"p": p,
},
self.__ctx_manager.get(),
tb,
__cr__,
)
return cast(Union[int, str, bool, float], raw.cast_to(types, types))

async def PromptTestClaude(
self,
input: str,
Expand Down Expand Up @@ -4610,6 +4679,36 @@ def MakeNestedBlockConstraint(
self.__ctx_manager.get(),
)

def MapAlias(
self,
m: Dict[str, List[str]],
baml_options: BamlCallOptions = {},
) -> baml_py.BamlStream[Dict[str, List[Optional[str]]], Dict[str, List[str]]]:
__tb__ = baml_options.get("tb", None)
if __tb__ is not None:
tb = __tb__._tb # type: ignore (we know how to use this private attribute)
else:
tb = None
__cr__ = baml_options.get("client_registry", None)

raw = self.__runtime.stream_function(
"MapAlias",
{
"m": m,
},
None,
self.__ctx_manager.get(),
tb,
__cr__,
)

return baml_py.BamlStream[Dict[str, List[Optional[str]]], Dict[str, List[str]]](
raw,
lambda x: cast(Dict[str, List[Optional[str]]], x.cast_to(types, partial_types)),
lambda x: cast(Dict[str, List[str]], x.cast_to(types, types)),
self.__ctx_manager.get(),
)

def MyFunc(
self,
input: str,
Expand Down Expand Up @@ -4640,6 +4739,36 @@ def MyFunc(
self.__ctx_manager.get(),
)

def NestedAlias(
self,
c: Union[Union[int, str, bool, float], List[str], Dict[str, List[str]]],
baml_options: BamlCallOptions = {},
) -> baml_py.BamlStream[Optional[Union[Optional[Union[Optional[int], Optional[str], Optional[bool], Optional[float]]], List[Optional[str]], Dict[str, List[Optional[str]]]]], Union[Union[int, str, bool, float], List[str], Dict[str, List[str]]]]:
__tb__ = baml_options.get("tb", None)
if __tb__ is not None:
tb = __tb__._tb # type: ignore (we know how to use this private attribute)
else:
tb = None
__cr__ = baml_options.get("client_registry", None)

raw = self.__runtime.stream_function(
"NestedAlias",
{
"c": c,
},
None,
self.__ctx_manager.get(),
tb,
__cr__,
)

return baml_py.BamlStream[Optional[Union[Optional[Union[Optional[int], Optional[str], Optional[bool], Optional[float]]], List[Optional[str]], Dict[str, List[Optional[str]]]]], Union[Union[int, str, bool, float], List[str], Dict[str, List[str]]]](
raw,
lambda x: cast(Optional[Union[Optional[Union[Optional[int], Optional[str], Optional[bool], Optional[float]]], List[Optional[str]], Dict[str, List[Optional[str]]]]], x.cast_to(types, partial_types)),
lambda x: cast(Union[Union[int, str, bool, float], List[str], Dict[str, List[str]]], x.cast_to(types, types)),
self.__ctx_manager.get(),
)

def OptionalTest_Function(
self,
input: str,
Expand Down Expand Up @@ -4730,6 +4859,36 @@ def PredictAgeBare(
self.__ctx_manager.get(),
)

def PrimitiveAlias(
self,
p: Union[int, str, bool, float],
baml_options: BamlCallOptions = {},
) -> baml_py.BamlStream[Optional[Union[Optional[int], Optional[str], Optional[bool], Optional[float]]], Union[int, str, bool, float]]:
__tb__ = baml_options.get("tb", None)
if __tb__ is not None:
tb = __tb__._tb # type: ignore (we know how to use this private attribute)
else:
tb = None
__cr__ = baml_options.get("client_registry", None)

raw = self.__runtime.stream_function(
"PrimitiveAlias",
{
"p": p,
},
None,
self.__ctx_manager.get(),
tb,
__cr__,
)

return baml_py.BamlStream[Optional[Union[Optional[int], Optional[str], Optional[bool], Optional[float]]], Union[int, str, bool, float]](
raw,
lambda x: cast(Optional[Union[Optional[int], Optional[str], Optional[bool], Optional[float]]], x.cast_to(types, partial_types)),
lambda x: cast(Union[int, str, bool, float], x.cast_to(types, types)),
self.__ctx_manager.get(),
)

def PromptTestClaude(
self,
input: str,
Expand Down
1 change: 1 addition & 0 deletions integ-tests/python/baml_client/inlinedbaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
"test-files/functions/output/recursive-class.baml": "class Node {\n data int\n next Node?\n}\n\nclass LinkedList {\n head Node?\n len int\n}\n\nclient<llm> O1 {\n provider \"openai\"\n options {\n model \"o1-mini\"\n default_role \"user\"\n }\n}\n\nfunction BuildLinkedList(input: int[]) -> LinkedList {\n client O1\n prompt #\"\n Build a linked list from the input array of integers.\n\n INPUT:\n {{ input }}\n\n {{ ctx.output_format }} \n \"#\n}\n\ntest TestLinkedList {\n functions [BuildLinkedList]\n args {\n input [1, 2, 3, 4, 5]\n }\n}\n",
"test-files/functions/output/serialization-error.baml": "class DummyOutput {\n nonce string\n nonce2 string\n @@dynamic\n}\n\nfunction DummyOutputFunction(input: string) -> DummyOutput {\n client GPT35\n prompt #\"\n Say \"hello there\".\n \"#\n}",
"test-files/functions/output/string-list.baml": "function FnOutputStringList(input: string) -> string[] {\n client GPT35\n prompt #\"\n Return a list of strings in json format like [\"string1\", \"string2\", \"string3\"].\n\n JSON:\n \"#\n}\n\ntest FnOutputStringList {\n functions [FnOutputStringList]\n args {\n input \"example input\"\n }\n}\n",
"test-files/functions/output/type-aliases.baml": "type Primitive = int | string | bool | float\n\ntype List = string[]\n\ntype Graph = map<string, string[]>\n\ntype Combination = Primitive | List | Graph\n\nfunction PrimitiveAlias(p: Primitive) -> Primitive {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back: {{ p }}\n \"#\n}\n\nfunction MapAlias(m: Graph) -> Graph {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given Graph back:\n\n {{ m }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction NestedAlias(c: Combination) -> Combination {\n client \"openai/gpt-4o\"\n prompt r#\"\n Return the given value back:\n\n {{ c }}\n\n {{ ctx.output_format }}\n \"#\n}\n",
"test-files/functions/output/unions.baml": "class UnionTest_ReturnType {\n prop1 string | bool\n prop2 (float | bool)[]\n prop3 (bool[] | int[])\n}\n\nfunction UnionTest_Function(input: string | bool) -> UnionTest_ReturnType {\n client GPT35\n prompt #\"\n Return a JSON blob with this schema: \n {{ctx.output_format}}\n\n JSON:\n \"#\n}\n\ntest UnionTest_Function {\n functions [UnionTest_Function]\n args {\n input \"example input\"\n }\n}\n",
"test-files/functions/prompts/no-chat-messages.baml": "\n\nfunction PromptTestClaude(input: string) -> string {\n client Sonnet\n prompt #\"\n Tell me a haiku about {{ input }}\n \"#\n}\n\n\nfunction PromptTestStreaming(input: string) -> string {\n client GPT35\n prompt #\"\n Tell me a short story about {{ input }}\n \"#\n}\n\ntest TestName {\n functions [PromptTestStreaming]\n args {\n input #\"\n hello world\n \"#\n }\n}\n",
"test-files/functions/prompts/with-chat-messages.baml": "\nfunction PromptTestOpenAIChat(input: string) -> string {\n client GPT35\n prompt #\"\n {{ _.role(\"system\") }}\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestOpenAIChatNoSystem(input: string) -> string {\n client GPT35\n prompt #\"\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestClaudeChat(input: string) -> string {\n client Claude\n prompt #\"\n {{ _.role(\"system\") }}\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestClaudeChatNoSystem(input: string) -> string {\n client Claude\n prompt #\"\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\ntest TestSystemAndNonSystemChat1 {\n functions [PromptTestClaude, PromptTestOpenAI, PromptTestOpenAIChat, PromptTestOpenAIChatNoSystem, PromptTestClaudeChat, PromptTestClaudeChatNoSystem]\n args {\n input \"cats\"\n }\n}\n\ntest TestSystemAndNonSystemChat2 {\n functions [PromptTestClaude, PromptTestOpenAI, PromptTestOpenAIChat, PromptTestOpenAIChatNoSystem, PromptTestClaudeChat, PromptTestClaudeChatNoSystem]\n args {\n input \"lion\"\n }\n}",
Expand Down
Loading

0 comments on commit 29bed6b

Please sign in to comment.