Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tests for azure and failure scenarios #1250

Merged
merged 2 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions integ-tests/baml_src/test-files/providers/providers.baml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,21 @@ function TestAzure(input: string) -> string {
"#
}

client GPT35AzureFailed {
provider azure-openai
options {
resource_name "west-us-azure-baml-incorrect-suffix"
deployment_id "gpt-35-turbo-default"
api_key env.AZURE_OPENAI_API_KEY
}
}
function TestAzureFailure(input: string) -> string {
client GPT35AzureFailed
prompt #"
Write a nice haiku about {{ input }}
"#
}

function TestOllama(input: string) -> string {
client Ollama
prompt #"
Expand Down
53 changes: 53 additions & 0 deletions integ-tests/python/baml_client/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2005,6 +2005,29 @@ async def TestAzure(
)
return cast(str, raw.cast_to(types, types))

async def TestAzureFailure(
self,
input: str,
baml_options: BamlCallOptions = {},
) -> str:
__tb__ = baml_options.get("tb", None)
if __tb__ is not None:
tb = __tb__._tb # type: ignore (we know how to use this private attribute)
else:
tb = None
__cr__ = baml_options.get("client_registry", None)

raw = await self.__runtime.call_function(
"TestAzureFailure",
{
"input": input,
},
self.__ctx_manager.get(),
tb,
__cr__,
)
return cast(str, raw.cast_to(types, types))

async def TestCaching(
self,
input: str,not_cached: str,
Expand Down Expand Up @@ -5332,6 +5355,36 @@ def TestAzure(
self.__ctx_manager.get(),
)

def TestAzureFailure(
self,
input: str,
baml_options: BamlCallOptions = {},
) -> baml_py.BamlStream[Optional[str], str]:
__tb__ = baml_options.get("tb", None)
if __tb__ is not None:
tb = __tb__._tb # type: ignore (we know how to use this private attribute)
else:
tb = None
__cr__ = baml_options.get("client_registry", None)

raw = self.__runtime.stream_function(
"TestAzureFailure",
{
"input": input,
},
None,
self.__ctx_manager.get(),
tb,
__cr__,
)

return baml_py.BamlStream[Optional[str], str](
raw,
lambda x: cast(Optional[str], x.cast_to(types, partial_types)),
lambda x: cast(str, x.cast_to(types, types)),
self.__ctx_manager.get(),
)

def TestCaching(
self,
input: str,not_cached: str,
Expand Down
2 changes: 1 addition & 1 deletion integ-tests/python/baml_client/inlinedbaml.py

Large diffs are not rendered by default.

53 changes: 53 additions & 0 deletions integ-tests/python/baml_client/sync_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2002,6 +2002,29 @@ def TestAzure(
)
return cast(str, raw.cast_to(types, types))

def TestAzureFailure(
self,
input: str,
baml_options: BamlCallOptions = {},
) -> str:
__tb__ = baml_options.get("tb", None)
if __tb__ is not None:
tb = __tb__._tb # type: ignore (we know how to use this private attribute)
else:
tb = None
__cr__ = baml_options.get("client_registry", None)

raw = self.__runtime.call_function_sync(
"TestAzureFailure",
{
"input": input,
},
self.__ctx_manager.get(),
tb,
__cr__,
)
return cast(str, raw.cast_to(types, types))

def TestCaching(
self,
input: str,not_cached: str,
Expand Down Expand Up @@ -5330,6 +5353,36 @@ def TestAzure(
self.__ctx_manager.get(),
)

def TestAzureFailure(
self,
input: str,
baml_options: BamlCallOptions = {},
) -> baml_py.BamlSyncStream[Optional[str], str]:
__tb__ = baml_options.get("tb", None)
if __tb__ is not None:
tb = __tb__._tb # type: ignore (we know how to use this private attribute)
else:
tb = None
__cr__ = baml_options.get("client_registry", None)

raw = self.__runtime.stream_function_sync(
"TestAzureFailure",
{
"input": input,
},
None,
self.__ctx_manager.get(),
tb,
__cr__,
)

return baml_py.BamlSyncStream[Optional[str], str](
raw,
lambda x: cast(Optional[str], x.cast_to(types, partial_types)),
lambda x: cast(str, x.cast_to(types, types)),
self.__ctx_manager.get(),
)

def TestCaching(
self,
input: str,not_cached: str,
Expand Down
67 changes: 67 additions & 0 deletions integ-tests/ruby/baml_client/client.rb
Original file line number Diff line number Diff line change
Expand Up @@ -2770,6 +2770,38 @@ def TestAzure(
(raw.parsed_using_types(Baml::Types))
end

sig {
params(
varargs: T.untyped,
input: String,
baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)]
).returns(String)
}
def TestAzureFailure(
*varargs,
input:,
baml_options: {}
)
if varargs.any?

raise ArgumentError.new("TestAzureFailure may only be called with keyword arguments")
end
if (baml_options.keys - [:client_registry, :tb]).any?
raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}")
end

raw = @runtime.call_function(
"TestAzureFailure",
{
input: input,
},
@ctx_manager,
baml_options[:tb]&.instance_variable_get(:@registry),
baml_options[:client_registry],
)
(raw.parsed_using_types(Baml::Types))
end

sig {
params(
varargs: T.untyped,
Expand Down Expand Up @@ -6814,6 +6846,41 @@ def TestAzure(
)
end

sig {
params(
varargs: T.untyped,
input: String,
baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)]
).returns(Baml::BamlStream[String])
}
def TestAzureFailure(
*varargs,
input:,
baml_options: {}
)
if varargs.any?

raise ArgumentError.new("TestAzureFailure may only be called with keyword arguments")
end
if (baml_options.keys - [:client_registry, :tb]).any?
raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}")
end

raw = @runtime.stream_function(
"TestAzureFailure",
{
input: input,
},
@ctx_manager,
baml_options[:tb]&.instance_variable_get(:@registry),
baml_options[:client_registry],
)
Baml::BamlStream[T.nilable(String), String].new(
ffi_stream: raw,
ctx_manager: @ctx_manager
)
end

sig {
params(
varargs: T.untyped,
Expand Down
2 changes: 1 addition & 1 deletion integ-tests/ruby/baml_client/inlined.rb

Large diffs are not rendered by default.

58 changes: 58 additions & 0 deletions integ-tests/typescript/baml_client/async_client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2168,6 +2168,31 @@ export class BamlAsyncClient {
}
}

async TestAzureFailure(
input: string,
__baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry }
): Promise<string> {
try {
const raw = await this.runtime.callFunction(
"TestAzureFailure",
{
"input": input
},
this.ctx_manager.cloneContext(),
__baml_options__?.tb?.__tb(),
__baml_options__?.clientRegistry,
)
return raw.parsed() as string
} catch (error: any) {
const bamlError = createBamlValidationError(error);
if (bamlError instanceof BamlValidationError) {
throw bamlError;
} else {
throw error;
}
}
}

async TestCaching(
input: string,not_cached: string,
__baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry }
Expand Down Expand Up @@ -5804,6 +5829,39 @@ class BamlStreamClient {
}
}

TestAzureFailure(
input: string,
__baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry }
): BamlStream<RecursivePartialNull<string>, string> {
try {
const raw = this.runtime.streamFunction(
"TestAzureFailure",
{
"input": input
},
undefined,
this.ctx_manager.cloneContext(),
__baml_options__?.tb?.__tb(),
__baml_options__?.clientRegistry,
)
return new BamlStream<RecursivePartialNull<string>, string>(
raw,
(a): a is RecursivePartialNull<string> => a,
(a): a is string => a,
this.ctx_manager.cloneContext(),
__baml_options__?.tb?.__tb(),
)
} catch (error) {
if (error instanceof Error) {
const bamlError = createBamlValidationError(error);
if (bamlError instanceof BamlValidationError) {
throw bamlError;
}
}
throw error;
}
}

TestCaching(
input: string,not_cached: string,
__baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry }
Expand Down
Loading
Loading