Skip to content

Commit

Permalink
Fix dynamic enums which already are defined in BAML Fixes #1079 (#1080)
Browse files Browse the repository at this point in the history
  • Loading branch information
hellovai authored Oct 22, 2024
1 parent 71df0b7 commit 22d0f1c
Show file tree
Hide file tree
Showing 12 changed files with 317 additions and 14 deletions.
21 changes: 20 additions & 1 deletion engine/language_client_python/src/types/function_results.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,26 @@ fn pythonize_strict(
let enum_type = match enum_module.getattr(enum_name.as_str()) {
Ok(e) => e,
// This can be true in the case of dynamic types.
/*
tb = TypeBuilder()
tb.add_enum("Foo")
*/
Err(_) => return Ok(value.into_py(py)),
};

// Call the constructor with the value
let instance = enum_type.call1((value,))?;
let instance = match enum_type.call1((value,)) {
Ok(instance) => instance,
Err(_) => {
// This can happen if the enum value is dynamic
/*
enum Foo {
@@dynamic
}
*/
return Ok(value.into_py(py));
}
};
Ok(instance.into())
}
BamlValue::Class(class_name, index_map) => {
Expand All @@ -109,6 +124,10 @@ fn pythonize_strict(
let class_type = match cls_module.getattr(class_name.as_str()) {
Ok(class) => class,
// This can be true in the case of dynamic types.
/*
tb = TypeBuilder()
tb.add_class("Foo")
*/
Err(_) => return Ok(properties_dict.into()),
};
let instance = class_type.call_method("model_validate", (properties_dict,), None)?;
Expand Down
12 changes: 6 additions & 6 deletions integ-tests/baml_src/generators.baml
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ generator lang_ruby {
version "0.62.0"
}

generator openapi {
output_type rest/openapi
output_dir "../openapi"
version "0.62.0"
on_generate "rm .gitignore"
}
// generator openapi {
// output_type rest/openapi
// output_dir "../openapi"
// version "0.62.0"
// on_generate "rm .gitignore"
// }
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,16 @@ enum Hobby {

@@dynamic
}


function ExtractHobby(text: string) -> Hobby[] {
client GPT4
prompt #"
{{ _.role('system') }}
{# This is a special macro that prints out the output schema of the function #}
{{ ctx.output_format }}

{{ _.role('user') }}
{{text}}
"#
}
53 changes: 53 additions & 0 deletions integ-tests/python/baml_client/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,29 @@ async def ExpectFailure(
)
return cast(str, raw.cast_to(types, types))

async def ExtractHobby(
self,
text: str,
baml_options: BamlCallOptions = {},
) -> List[Union[types.Hobby, str]]:
__tb__ = baml_options.get("tb", None)
if __tb__ is not None:
tb = __tb__._tb
else:
tb = None
__cr__ = baml_options.get("client_registry", None)

raw = await self.__runtime.call_function(
"ExtractHobby",
{
"text": text,
},
self.__ctx_manager.get(),
tb,
__cr__,
)
return cast(List[Union[types.Hobby, str]], raw.cast_to(types, types))

async def ExtractNames(
self,
input: str,
Expand Down Expand Up @@ -2768,6 +2791,36 @@ def ExpectFailure(
self.__ctx_manager.get(),
)

def ExtractHobby(
self,
text: str,
baml_options: BamlCallOptions = {},
) -> baml_py.BamlStream[List[Optional[Union[types.Hobby, str]]], List[Union[types.Hobby, str]]]:
__tb__ = baml_options.get("tb", None)
if __tb__ is not None:
tb = __tb__._tb
else:
tb = None
__cr__ = baml_options.get("client_registry", None)

raw = self.__runtime.stream_function(
"ExtractHobby",
{
"text": text,
},
None,
self.__ctx_manager.get(),
tb,
__cr__,
)

return baml_py.BamlStream[List[Optional[Union[types.Hobby, str]]], List[Union[types.Hobby, str]]](
raw,
lambda x: cast(List[Optional[Union[types.Hobby, str]]], x.cast_to(types, partial_types)),
lambda x: cast(List[Union[types.Hobby, str]], x.cast_to(types, types)),
self.__ctx_manager.get(),
)

def ExtractNames(
self,
input: str,
Expand Down
4 changes: 2 additions & 2 deletions integ-tests/python/baml_client/inlinedbaml.py

Large diffs are not rendered by default.

53 changes: 53 additions & 0 deletions integ-tests/python/baml_client/sync_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,29 @@ def ExpectFailure(
)
return cast(str, raw.cast_to(types, types))

def ExtractHobby(
self,
text: str,
baml_options: BamlCallOptions = {},
) -> List[Union[types.Hobby, str]]:
__tb__ = baml_options.get("tb", None)
if __tb__ is not None:
tb = __tb__._tb
else:
tb = None
__cr__ = baml_options.get("client_registry", None)

raw = self.__runtime.call_function_sync(
"ExtractHobby",
{
"text": text,
},
self.__ctx_manager.get(),
tb,
__cr__,
)
return cast(List[Union[types.Hobby, str]], raw.cast_to(types, types))

def ExtractNames(
self,
input: str,
Expand Down Expand Up @@ -2766,6 +2789,36 @@ def ExpectFailure(
self.__ctx_manager.get(),
)

def ExtractHobby(
self,
text: str,
baml_options: BamlCallOptions = {},
) -> baml_py.BamlSyncStream[List[Optional[Union[types.Hobby, str]]], List[Union[types.Hobby, str]]]:
__tb__ = baml_options.get("tb", None)
if __tb__ is not None:
tb = __tb__._tb
else:
tb = None
__cr__ = baml_options.get("client_registry", None)

raw = self.__runtime.stream_function_sync(
"ExtractHobby",
{
"text": text,
},
None,
self.__ctx_manager.get(),
tb,
__cr__,
)

return baml_py.BamlSyncStream[List[Optional[Union[types.Hobby, str]]], List[Union[types.Hobby, str]]](
raw,
lambda x: cast(List[Optional[Union[types.Hobby, str]]], x.cast_to(types, partial_types)),
lambda x: cast(List[Union[types.Hobby, str]], x.cast_to(types, types)),
self.__ctx_manager.get(),
)

def ExtractNames(
self,
input: str,
Expand Down
17 changes: 16 additions & 1 deletion integ-tests/python/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from ..baml_client import partial_types
from ..baml_client.types import (
DynInputOutput,
Hobby,
NamedArgsSingleEnumList,
NamedArgsSingleClass,
Nested,
Expand Down Expand Up @@ -790,7 +791,7 @@ async def test_dynamic_inputs_list2():


@pytest.mark.asyncio
async def test_dynamic_types_enum():
async def test_dynamic_types_new_enum():
tb = TypeBuilder()
field_enum = tb.add_enum("Animal")
animals = ["giraffe", "elephant", "lion"]
Expand All @@ -804,6 +805,20 @@ async def test_dynamic_types_enum():
assert len(res) > 0
assert res[0].animalLiked == "GIRAFFE", res[0]


@pytest.mark.asyncio
async def test_dynamic_types_existing_enum():
tb = TypeBuilder()
tb.Hobby.add_value("Golfing")
res = await b.ExtractHobby(
"My name is Harrison. My hair is black and I'm 6 feet tall. golf and music are my favorite!.",
{"tb": tb},
)
assert len(res) > 0
assert "Golfing" in res, res
assert Hobby.MUSIC in res, res


@pytest.mark.asyncio
async def test_dynamic_literals():
tb = TypeBuilder()
Expand Down
67 changes: 67 additions & 0 deletions integ-tests/ruby/baml_client/client.rb
Original file line number Diff line number Diff line change
Expand Up @@ -754,6 +754,38 @@ def ExpectFailure(
(raw.parsed_using_types(Baml::Types))
end

sig {
params(
varargs: T.untyped,
text: String,
baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)]
).returns(T::Array[T.any(Baml::Types::Hobby, String)])
}
def ExtractHobby(
*varargs,
text:,
baml_options: {}
)
if varargs.any?

raise ArgumentError.new("ExtractHobby may only be called with keyword arguments")
end
if (baml_options.keys - [:client_registry, :tb]).any?
raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}")
end

raw = @runtime.call_function(
"ExtractHobby",
{
text: text,
},
@ctx_manager,
baml_options[:tb]&.instance_variable_get(:@registry),
baml_options[:client_registry],
)
(raw.parsed_using_types(Baml::Types))
end

sig {
params(
varargs: T.untyped,
Expand Down Expand Up @@ -3681,6 +3713,41 @@ def ExpectFailure(
)
end

sig {
params(
varargs: T.untyped,
text: String,
baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)]
).returns(Baml::BamlStream[T::Array[T.any(Baml::Types::Hobby, String)]])
}
def ExtractHobby(
*varargs,
text:,
baml_options: {}
)
if varargs.any?

raise ArgumentError.new("ExtractHobby may only be called with keyword arguments")
end
if (baml_options.keys - [:client_registry, :tb]).any?
raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}")
end

raw = @runtime.stream_function(
"ExtractHobby",
{
text: text,
},
@ctx_manager,
baml_options[:tb]&.instance_variable_get(:@registry),
baml_options[:client_registry],
)
Baml::BamlStream[T::Array[T.nilable(Baml::Types::Hobby)], T::Array[T.any(Baml::Types::Hobby, String)]].new(
ffi_stream: raw,
ctx_manager: @ctx_manager
)
end

sig {
params(
varargs: T.untyped,
Expand Down
4 changes: 2 additions & 2 deletions integ-tests/ruby/baml_client/inlined.rb

Large diffs are not rendered by default.

58 changes: 58 additions & 0 deletions integ-tests/typescript/baml_client/async_client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,31 @@ export class BamlAsyncClient {
}
}

async ExtractHobby(
text: string,
__baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry }
): Promise<(string | Hobby)[]> {
try {
const raw = await this.runtime.callFunction(
"ExtractHobby",
{
"text": text
},
this.ctx_manager.cloneContext(),
__baml_options__?.tb?.__tb(),
__baml_options__?.clientRegistry,
)
return raw.parsed() as (string | Hobby)[]
} catch (error: any) {
const bamlError = createBamlValidationError(error);
if (bamlError instanceof BamlValidationError) {
throw bamlError;
} else {
throw error;
}
}
}

async ExtractNames(
input: string,
__baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry }
Expand Down Expand Up @@ -2999,6 +3024,39 @@ class BamlStreamClient {
}
}

ExtractHobby(
text: string,
__baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry }
): BamlStream<RecursivePartialNull<(string | Hobby)[]>, (string | Hobby)[]> {
try {
const raw = this.runtime.streamFunction(
"ExtractHobby",
{
"text": text
},
undefined,
this.ctx_manager.cloneContext(),
__baml_options__?.tb?.__tb(),
__baml_options__?.clientRegistry,
)
return new BamlStream<RecursivePartialNull<(string | Hobby)[]>, (string | Hobby)[]>(
raw,
(a): a is RecursivePartialNull<(string | Hobby)[]> => a,
(a): a is (string | Hobby)[] => a,
this.ctx_manager.cloneContext(),
__baml_options__?.tb?.__tb(),
)
} catch (error) {
if (error instanceof Error) {
const bamlError = createBamlValidationError(error);
if (bamlError instanceof BamlValidationError) {
throw bamlError;
}
}
throw error;
}
}

ExtractNames(
input: string,
__baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry }
Expand Down
Loading

0 comments on commit 22d0f1c

Please sign in to comment.