Skip to content

Commit

Permalink
example stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
imalsogreg committed Dec 12, 2024
1 parent 8d94b0c commit 3ba6b88
Show file tree
Hide file tree
Showing 23 changed files with 443 additions and 17 deletions.
2 changes: 1 addition & 1 deletion engine/baml-lib/jinja/src/evaluate_type/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ fn infer_const_type(v: &minijinja::value::Value) -> Type {
acc.push(x);
Some(Type::Union(acc))
} else {
unreachable!()
unreachable!("minijinja")
}
}
Some(acc) => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pub(super) fn coerce_array(

let inner = match list_target {
FieldType::List(inner) => inner,
_ => unreachable!(),
_ => unreachable!("coerce_array"),
};

let mut items = vec![];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub(super) fn coerce_optional(

let inner = match optional_target {
FieldType::Optional(inner) => inner,
_ => unreachable!(),
_ => unreachable!("coerce_optional"),
};

let mut flags = DeserializerConditions::new();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ pub(super) fn coerce_union(

let options = match union_target {
FieldType::Union(options) => options,
_ => unreachable!(),
_ => unreachable!("coerce_union"),
};

let parsed = options
Expand Down
55 changes: 47 additions & 8 deletions engine/baml-schema-wasm/tests/test_file_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,13 +255,52 @@ test Two {

assert!(js_error.is_object());

assert_eq!(
js_error,
serde_wasm_bindgen::to_value::<HashMap<String, Vec<String>>>(&HashMap::from_iter([(
"all_files".to_string(),
vec!["error.baml".to_string()]
)]))
.unwrap()
);
// assert_eq!(
// js_error,
// serde_wasm_bindgen::to_value::<HashMap<String, Vec<String>>>(&HashMap::from_iter([(
// "all_files".to_string(),
// vec!["error.baml".to_string()]
// )]))
// .unwrap()
// );
}

#[wasm_bindgen_test]
fn test_type_alias_with_assert() {
wasm_logger::init(wasm_logger::Config::new(log::Level::Info));
let sample_baml_content = r##"
class Foo {
foo int
}
type Bar = Foo
"##;
let mut files = HashMap::new();
files.insert("error.baml".to_string(), sample_baml_content.to_string());
let files_js = to_value(&files).unwrap();
let project = WasmProject::new("baml_src", files_js)
.map_err(JsValue::from)
.unwrap();

let env_vars = [("OPENAI_API_KEY", "12345")]
.iter()
.cloned()
.collect::<HashMap<_, _>>();
let env_vars_js = to_value(&env_vars).unwrap();

let Err(js_error) = project.runtime(env_vars_js) else {
panic!("Expected error, got Ok");
};

assert!(js_error.is_object());

// assert_eq!(
// js_error,
// serde_wasm_bindgen::to_value::<HashMap<String, Vec<String>>>(&HashMap::from_iter([(
// "all_files".to_string(),
// vec!["error.baml".to_string()]
// )]))
// .unwrap()
// );
}

}
20 changes: 20 additions & 0 deletions integ-tests/baml_src/test-files/aliases/antonio.baml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
class Foo2 {
bar int
baz string
sub Subthing @assert( {{ this.bar == 10}} ) | null
}

class Foo3 {
bar int
baz string
sub Foo3 | null
}

type Subthing = Foo2 @assert( {{ this.bar == 10 }})

function RunFoo2(input: Foo3) -> Foo2 {
client Claude
prompt #"Generate a Foo2 wrapping 30. Use {{ input }}.
{{ ctx.output_format }}
"#
}
53 changes: 53 additions & 0 deletions integ-tests/python/baml_client/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1959,6 +1959,29 @@ async def ReturnMalformedConstraints(
)
return cast(types.MalformedConstraints, raw.cast_to(types, types))

async def RunFoo2(
self,
input: types.Foo3,
baml_options: BamlCallOptions = {},
) -> types.Foo2:
__tb__ = baml_options.get("tb", None)
if __tb__ is not None:
tb = __tb__._tb # type: ignore (we know how to use this private attribute)
else:
tb = None
__cr__ = baml_options.get("client_registry", None)

raw = await self.__runtime.call_function(
"RunFoo2",
{
"input": input,
},
self.__ctx_manager.get(),
tb,
__cr__,
)
return cast(types.Foo2, raw.cast_to(types, types))

async def SchemaDescriptions(
self,
input: str,
Expand Down Expand Up @@ -5477,6 +5500,36 @@ def ReturnMalformedConstraints(
self.__ctx_manager.get(),
)

def RunFoo2(
self,
input: types.Foo3,
baml_options: BamlCallOptions = {},
) -> baml_py.BamlStream[partial_types.Foo2, types.Foo2]:
__tb__ = baml_options.get("tb", None)
if __tb__ is not None:
tb = __tb__._tb # type: ignore (we know how to use this private attribute)
else:
tb = None
__cr__ = baml_options.get("client_registry", None)

raw = self.__runtime.stream_function(
"RunFoo2",
{
"input": input,
},
None,
self.__ctx_manager.get(),
tb,
__cr__,
)

return baml_py.BamlStream[partial_types.Foo2, types.Foo2](
raw,
lambda x: cast(partial_types.Foo2, x.cast_to(types, partial_types)),
lambda x: cast(types.Foo2, x.cast_to(types, types)),
self.__ctx_manager.get(),
)

def SchemaDescriptions(
self,
input: str,
Expand Down
1 change: 1 addition & 0 deletions integ-tests/python/baml_client/inlinedbaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
"fiddle-examples/symbol-tuning.baml": "enum Category3 {\n Refund @alias(\"k1\")\n @description(\"Customer wants to refund a product\")\n\n CancelOrder @alias(\"k2\")\n @description(\"Customer wants to cancel an order\")\n\n TechnicalSupport @alias(\"k3\")\n @description(\"Customer needs help with a technical issue unrelated to account creation or login\")\n\n AccountIssue @alias(\"k4\")\n @description(\"Specifically relates to account-login or account-creation\")\n\n Question @alias(\"k5\")\n @description(\"Customer has a question\")\n}\n\nfunction ClassifyMessage3(input: string) -> Category {\n client GPT4\n\n prompt #\"\n Classify the following INPUT into ONE\n of the following categories:\n\n INPUT: {{ input }}\n\n {{ ctx.output_format }}\n\n Response:\n \"#\n}",
"generators.baml": "generator lang_python {\n output_type python/pydantic\n output_dir \"../python\"\n version \"0.70.1\"\n}\n\ngenerator lang_typescript {\n output_type typescript\n output_dir \"../typescript\"\n version \"0.70.1\"\n}\n\ngenerator lang_ruby {\n output_type ruby/sorbet\n output_dir \"../ruby\"\n version \"0.70.1\"\n}\n\n// generator openapi {\n// output_type rest/openapi\n// output_dir \"../openapi\"\n// version \"0.70.1\"\n// on_generate \"rm .gitignore\"\n// }\n",
"test-files/aliases/aliased-inputs.baml": "\nclass InputClass {\n key string @alias(\"color\")\n key2 string\n}\n\n\nclass InputClassNested {\n key string\n nested InputClass @alias(\"interesting-key\")\n}\n \n\nfunction AliasedInputClass(input: InputClass) -> string {\n client GPT35\n prompt #\"\n\n {{input}}\n\n This is a test. What's the name of the first json key above? Remember, tell me the key, not value.\n \"#\n}\n \nfunction AliasedInputClass2(input: InputClass) -> string {\n client GPT35\n prompt #\"\n\n {# making sure we can still access the original key #}\n {%if input.key == \"tiger\"%}\n Repeat this value back to me, and nothing else: {{input.key}}\n {%endif%}\n \"#\n}\n \n function AliasedInputClassNested(input: InputClassNested) -> string {\n client GPT35\n prompt #\"\n {{ _.role(\"user\")}}\n\n {{input}}\n\n This is a test. What's the name of the second json key above? Remember, tell me the key, not value.\n \"#\n }\n\n\nenum AliasedEnum {\n KEY_ONE @alias(\"tiger\")\n KEY_TWO\n}\n\nfunction AliasedInputEnum(input: AliasedEnum) -> string {\n client GPT4o\n prompt #\"\n {{ _.role(\"user\")}}\n\n\n Write out this word only in your response, in lowercase:\n ---\n {{input}}\n ---\n Answer:\n \"#\n}\n\n\nfunction AliasedInputList(input: AliasedEnum[]) -> string {\n client GPT35\n prompt #\"\n {{ _.role(\"user\")}}\n Given this array:\n ---\n {{input}}\n ---\n\n Return the first element in the array:\n \"#\n}\n\n",
"test-files/aliases/antonio.baml": "class Foo2 {\n bar int\n baz string\n sub Subthing @assert( {{ this.bar == 10}} ) | null\n}\n\nclass Foo3 {\n bar int\n baz string\n sub Foo3 | null\n}\n\ntype Subthing = Foo2 @assert( {{ this.bar == 10 }})\n\nfunction RunFoo2(input: Foo3) -> Foo2 {\n client Claude\n prompt #\"Generate a Foo2 wrapping 30. Use {{ input }}.\n {{ ctx.output_format }}\n \"#\n}",
"test-files/aliases/classes.baml": "class TestClassAlias {\n key string @alias(\"key-dash\") @description(#\"\n This is a description for key\n af asdf\n \"#)\n key2 string @alias(\"key21\")\n key3 string @alias(\"key with space\")\n key4 string //unaliased\n key5 string @alias(\"key.with.punctuation/123\")\n}\n\nfunction FnTestClassAlias(input: string) -> TestClassAlias {\n client GPT35\n prompt #\"\n {{ctx.output_format}}\n \"#\n}\n\ntest FnTestClassAlias {\n functions [FnTestClassAlias]\n args {\n input \"example input\"\n }\n}\n",
"test-files/aliases/enums.baml": "enum TestEnum {\n A @alias(\"k1\") @description(#\"\n User is angry\n \"#)\n B @alias(\"k22\") @description(#\"\n User is happy\n \"#)\n // tests whether k1 doesnt incorrectly get matched with k11\n C @alias(\"k11\") @description(#\"\n User is sad\n \"#)\n D @alias(\"k44\") @description(\n User is confused\n )\n E @description(\n User is excited\n )\n F @alias(\"k5\") // only alias\n \n G @alias(\"k6\") @description(#\"\n User is bored\n With a long description\n \"#)\n \n @@alias(\"Category\")\n}\n\nfunction FnTestAliasedEnumOutput(input: string) -> TestEnum {\n client GPT35\n prompt #\"\n Classify the user input into the following category\n \n {{ ctx.output_format }}\n\n {{ _.role('user') }}\n {{input}}\n\n {{ _.role('assistant') }}\n Category ID:\n \"#\n}\n\ntest FnTestAliasedEnumOutput {\n functions [FnTestAliasedEnumOutput]\n args {\n input \"mehhhhh\"\n }\n}",
"test-files/comments/comments.baml": "// add some functions, classes, enums etc with comments all over.",
Expand Down
10 changes: 10 additions & 0 deletions integ-tests/python/baml_client/partial_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,16 @@ class FlightConfirmation(BaseModel):
arrivalTime: Optional[str] = None
seatNumber: Optional[str] = None

class Foo2(BaseModel):
bar: Optional[int] = None
baz: Optional[str] = None
sub: Optional[Union[Optional["Foo2"], Optional[None]]] = None

class Foo3(BaseModel):
bar: Optional[int] = None
baz: Optional[str] = None
sub: Optional[Union["Foo3", Optional[None]]] = None

class FooAny(BaseModel):
planetary_age: Optional[Union["Martian", "Earthling"]] = None
certainty: Checked[Optional[int],Literal["unreasonably_certain"]]
Expand Down
53 changes: 53 additions & 0 deletions integ-tests/python/baml_client/sync_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1956,6 +1956,29 @@ def ReturnMalformedConstraints(
)
return cast(types.MalformedConstraints, raw.cast_to(types, types))

def RunFoo2(
self,
input: types.Foo3,
baml_options: BamlCallOptions = {},
) -> types.Foo2:
__tb__ = baml_options.get("tb", None)
if __tb__ is not None:
tb = __tb__._tb # type: ignore (we know how to use this private attribute)
else:
tb = None
__cr__ = baml_options.get("client_registry", None)

raw = self.__runtime.call_function_sync(
"RunFoo2",
{
"input": input,
},
self.__ctx_manager.get(),
tb,
__cr__,
)
return cast(types.Foo2, raw.cast_to(types, types))

def SchemaDescriptions(
self,
input: str,
Expand Down Expand Up @@ -5475,6 +5498,36 @@ def ReturnMalformedConstraints(
self.__ctx_manager.get(),
)

def RunFoo2(
self,
input: types.Foo3,
baml_options: BamlCallOptions = {},
) -> baml_py.BamlSyncStream[partial_types.Foo2, types.Foo2]:
__tb__ = baml_options.get("tb", None)
if __tb__ is not None:
tb = __tb__._tb # type: ignore (we know how to use this private attribute)
else:
tb = None
__cr__ = baml_options.get("client_registry", None)

raw = self.__runtime.stream_function_sync(
"RunFoo2",
{
"input": input,
},
None,
self.__ctx_manager.get(),
tb,
__cr__,
)

return baml_py.BamlSyncStream[partial_types.Foo2, types.Foo2](
raw,
lambda x: cast(partial_types.Foo2, x.cast_to(types, partial_types)),
lambda x: cast(types.Foo2, x.cast_to(types, types)),
self.__ctx_manager.get(),
)

def SchemaDescriptions(
self,
input: str,
Expand Down
2 changes: 1 addition & 1 deletion integ-tests/python/baml_client/type_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
class TypeBuilder(_TypeBuilder):
def __init__(self):
super().__init__(classes=set(
["BigNumbers","BinaryNode","Blah","BlockConstraint","BlockConstraintForParam","BookOrder","ClassOptionalOutput","ClassOptionalOutput2","ClassToRecAlias","ClassWithImage","CompoundBigNumbers","ContactInfo","CustomTaskResult","DummyOutput","DynInputOutput","DynamicClassOne","DynamicClassTwo","DynamicOutput","Earthling","Education","Email","EmailAddress","Event","FakeImage","FlightConfirmation","FooAny","Forest","GroceryReceipt","InnerClass","InnerClass2","InputClass","InputClassNested","LinkedList","LinkedListAliasNode","LiteralClassHello","LiteralClassOne","LiteralClassTwo","MalformedConstraints","MalformedConstraints2","Martian","MergeAttrs","NamedArgsSingleClass","Nested","Nested2","NestedBlockConstraint","NestedBlockConstraintForParam","Node","NodeWithAliasIndirection","OptionalTest_Prop1","OptionalTest_ReturnType","OrderInfo","OriginalA","OriginalB","Person","PhoneNumber","Quantity","RaysData","ReceiptInfo","ReceiptItem","Recipe","Resume","Schema","SearchParams","SomeClassNestedDynamic","StringToClassEntry","TestClassAlias","TestClassNested","TestClassWithEnum","TestOutputClass","Tree","TwoStoriesOneTitle","UnionTest_ReturnType","WithReasoning",]
["BigNumbers","BinaryNode","Blah","BlockConstraint","BlockConstraintForParam","BookOrder","ClassOptionalOutput","ClassOptionalOutput2","ClassToRecAlias","ClassWithImage","CompoundBigNumbers","ContactInfo","CustomTaskResult","DummyOutput","DynInputOutput","DynamicClassOne","DynamicClassTwo","DynamicOutput","Earthling","Education","Email","EmailAddress","Event","FakeImage","FlightConfirmation","Foo2","Foo3","FooAny","Forest","GroceryReceipt","InnerClass","InnerClass2","InputClass","InputClassNested","LinkedList","LinkedListAliasNode","LiteralClassHello","LiteralClassOne","LiteralClassTwo","MalformedConstraints","MalformedConstraints2","Martian","MergeAttrs","NamedArgsSingleClass","Nested","Nested2","NestedBlockConstraint","NestedBlockConstraintForParam","Node","NodeWithAliasIndirection","OptionalTest_Prop1","OptionalTest_ReturnType","OrderInfo","OriginalA","OriginalB","Person","PhoneNumber","Quantity","RaysData","ReceiptInfo","ReceiptItem","Recipe","Resume","Schema","SearchParams","SomeClassNestedDynamic","StringToClassEntry","TestClassAlias","TestClassNested","TestClassWithEnum","TestOutputClass","Tree","TwoStoriesOneTitle","UnionTest_ReturnType","WithReasoning",]
), enums=set(
["AliasedEnum","Category","Category2","Category3","Color","DataType","DynEnumOne","DynEnumTwo","EnumInClass","EnumOutput","Hobby","MapKey","NamedArgsSingleEnum","NamedArgsSingleEnumList","OptionalTest_CategoryType","OrderStatus","Tag","TestEnum",]
))
Expand Down
10 changes: 10 additions & 0 deletions integ-tests/python/baml_client/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,16 @@ class FlightConfirmation(BaseModel):
arrivalTime: str
seatNumber: str

class Foo2(BaseModel):
bar: int
baz: str
sub: Union["Foo2", Optional[None]]

class Foo3(BaseModel):
bar: int
baz: str
sub: Union["Foo3", Optional[None]]

class FooAny(BaseModel):
planetary_age: Union["Martian", "Earthling"]
certainty: Checked[int,Literal["unreasonably_certain"]]
Expand Down
6 changes: 6 additions & 0 deletions integ-tests/python/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
ClassToRecAlias,
NodeWithAliasIndirection,
MergeAttrs,
Foo3
)
import baml_client.types as types
from ..baml_client.tracing import trace, set_tags, flush, on_log_event
Expand Down Expand Up @@ -1587,3 +1588,8 @@ async def test_block_constraint_arguments():
nested_block_constraint = NestedBlockConstraintForParam(nbcfp=block_constraint)
await b.UseNestedBlockConstraint(nested_block_constraint)
assert "Failed assert: hi" in str(e)

@pytest.mark.asyncio
async def test_alias_bug():
res = await b.RunFoo2(input=Foo3(bar=10, baz="hi", sub=None))
assert True
Loading

0 comments on commit 3ba6b88

Please sign in to comment.