From 9b5739565b684c2179ac2ab24cabaa441a6269a7 Mon Sep 17 00:00:00 2001 From: hellovai Date: Tue, 3 Dec 2024 22:16:43 -0800 Subject: [PATCH] Fix azure client Add new client parameters: allowed_roles, default_role, finish_reason_allow_list, finish_reason_deny_list (#1209) TODO: add more tests, but screenshots show it working ---- > [!IMPORTANT] > This PR adds role selection and finish reason filtering to clients, updates error handling, and modifies tests for new parameters. > > - **Behavior**: > - Adds `role_selection` and `finish_reason_filter` parameters to clients in `anthropic.rs`, `aws_bedrock.rs`, and `google_ai.rs`. > - Implements `allowed_roles()` and `default_role()` methods for role selection. > - Introduces `finish_reason_filter` to manage finish reasons. > - **Error Handling**: > - Adds `FinishReasonError` in `errors.rs` for Python and TypeScript clients. > - Updates error handling in `internal_monkeypatch.py` and `errors.rs` for Python. > - Updates error handling in `errors.rs` for TypeScript. > - **Testing**: > - Updates tests in `test_cli.rs`, `test_runtime.rs`, and `test_file_manager.rs` for new parameters. > - Modifies test utilities in `harness.rs` for new client parameters. > - **Misc**: > - Adds `log-once` dependency in `Cargo.toml` and `Cargo.lock`. > - Minor updates to `README.md` and logging in `app.py`. > > This description was created by [Ellipsis](https://www.ellipsis.dev?ref=BoundaryML%2Fbaml&utm_source=github&utm_medium=referral) for 86521793b6867b8144fba186e0c5f3d2ec60dfbe. It will automatically update as commits are pushed. 
--------- Co-authored-by: Aaron Villalpando --- engine/Cargo.lock | 10 ++ engine/baml-lib/jinja-runtime/src/lib.rs | 120 ++++++++++++- .../llm-client/src/clients/anthropic.rs | 81 ++++----- .../llm-client/src/clients/aws_bedrock.rs | 84 +++++----- .../llm-client/src/clients/google_ai.rs | 83 +++++---- .../llm-client/src/clients/helpers.rs | 81 +++++++-- .../baml-lib/llm-client/src/clients/openai.rs | 111 +++++++------ .../baml-lib/llm-client/src/clients/vertex.rs | 81 ++++----- engine/baml-lib/llm-client/src/clientspec.rs | 157 ++++++++++++++++++ engine/baml-runtime/Cargo.toml | 3 +- engine/baml-runtime/src/cli/serve/error.rs | 19 +++ engine/baml-runtime/src/cli/serve/mod.rs | 2 +- engine/baml-runtime/src/errors.rs | 21 +++ .../internal/llm_client/orchestrator/call.rs | 18 +- .../internal/llm_client/orchestrator/mod.rs | 22 +++ .../llm_client/orchestrator/stream.rs | 18 +- .../primitive/anthropic/anthropic_client.rs | 26 +-- .../llm_client/primitive/aws/aws_client.rs | 26 +-- .../primitive/google/googleai_client.rs | 45 +++-- .../llm_client/primitive/google/types.rs | 2 +- .../src/internal/llm_client/primitive/mod.rs | 9 + .../primitive/openai/openai_client.rs | 69 +++----- .../llm_client/primitive/openai/types.rs | 21 +-- .../primitive/vertex/vertex_client.rs | 24 +-- .../src/internal/llm_client/traits/chat.rs | 19 ++- .../src/internal/llm_client/traits/mod.rs | 8 +- engine/baml-runtime/src/types/response.rs | 14 +- engine/baml-runtime/tests/harness.rs | 2 +- engine/baml-runtime/tests/test_cli.rs | 5 +- engine/baml-runtime/tests/test_runtime.rs | 10 +- engine/baml-schema-wasm/src/lib.rs | 2 + .../baml-schema-wasm/src/runtime_wasm/mod.rs | 23 ++- .../tests/test_file_manager.rs | 1 + .../baml_py/internal_monkeypatch.py | 16 +- engine/language_client_python/src/errors.rs | 19 +++ engine/language_client_typescript/index.d.ts | 17 +- .../language_client_typescript/index.d.ts.map | 2 +- engine/language_client_typescript/index.js | 67 ++++++-- 
.../language_client_typescript/src/errors.rs | 22 +++ .../typescript_src/index.ts | 113 +++++++++---- integ-tests/typescript/test-report.html | 7 +- .../src/baml_wasm_web/test_uis/testHooks.ts | 35 +++- .../baml_wasm_web/test_uis/test_result.tsx | 33 +++- 43 files changed, 1098 insertions(+), 450 deletions(-) diff --git a/engine/Cargo.lock b/engine/Cargo.lock index f973e31a9..ff8167eb5 100644 --- a/engine/Cargo.lock +++ b/engine/Cargo.lock @@ -961,6 +961,7 @@ dependencies = [ "jsonish", "jsonwebtoken", "log", + "log-once", "mime", "mime_guess", "minijinja", @@ -2971,6 +2972,15 @@ dependencies = [ "value-bag", ] +[[package]] +name = "log-once" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d8a05e3879b317b1b6dbf353e5bba7062bedcc59815267bb23eaa0c576cebf0" +dependencies = [ + "log", +] + [[package]] name = "magnus" version = "0.7.1" diff --git a/engine/baml-lib/jinja-runtime/src/lib.rs b/engine/baml-lib/jinja-runtime/src/lib.rs index 82ffa7eea..2c2928abb 100644 --- a/engine/baml-lib/jinja-runtime/src/lib.rs +++ b/engine/baml-lib/jinja-runtime/src/lib.rs @@ -25,6 +25,7 @@ pub struct RenderContext_Client { pub name: String, pub provider: String, pub default_role: String, + pub allowed_roles: Vec, } #[derive(Debug)] @@ -49,6 +50,7 @@ fn render_minijinja( mut ctx: RenderContext, template_string_macros: &[TemplateStringMacro], default_role: String, + allowed_roles: Vec, ) -> Result { let mut env = get_env(); @@ -240,7 +242,11 @@ fn render_minijinja( // Only add the message if it contains meaningful content if !parts.is_empty() { chat_messages.push(RenderedChatMessage { - role: role.as_ref().unwrap_or(&default_role).to_string(), + role: match role.as_ref() { + Some(r) if allowed_roles.contains(r) => r.clone(), + Some(_) => default_role.clone(), + None => default_role.clone(), + }, allow_duplicate_role, parts, }); @@ -410,12 +416,14 @@ pub fn render_prompt( let eval_ctx = EvaluationContext::new(env_vars, false); let 
minijinja_args: minijinja::Value = args.clone().to_minijinja_value(ir, &eval_ctx); let default_role = ctx.client.default_role.clone(); + let allowed_roles = ctx.client.allowed_roles.clone(); let rendered = render_minijinja( template, &minijinja_args, ctx, template_string_macros, default_role, + allowed_roles, ); match rendered { @@ -505,6 +513,7 @@ mod render_tests { name: "gpt4".to_string(), provider: "openai".to_string(), default_role: "system".to_string(), + allowed_roles: vec!["system".to_string()], }, output_format: OutputFormatContent::new_string(), tags: HashMap::from([("ROLE".to_string(), BamlValue::String("john doe".into()))]), @@ -565,6 +574,7 @@ mod render_tests { name: "gpt4".to_string(), provider: "openai".to_string(), default_role: "system".to_string(), + allowed_roles: vec!["system".to_string()], }, output_format: OutputFormatContent::new_string(), tags: HashMap::from([("ROLE".to_string(), BamlValue::String("john doe".into()))]), @@ -623,6 +633,7 @@ mod render_tests { name: "gpt4".to_string(), provider: "openai".to_string(), default_role: "system".to_string(), + allowed_roles: vec!["system".to_string()], }, output_format: OutputFormatContent::new_string(), tags: HashMap::from([("ROLE".to_string(), BamlValue::String("john doe".into()))]), @@ -690,6 +701,7 @@ mod render_tests { name: "gpt4".to_string(), provider: "openai".to_string(), default_role: "system".to_string(), + allowed_roles: vec!["system".to_string(), "john doe".to_string()], }, output_format: OutputFormatContent::new_string(), tags: HashMap::from([("ROLE".to_string(), BamlValue::String("john doe".into()))]), @@ -767,6 +779,7 @@ mod render_tests { name: "gpt4".to_string(), provider: "openai".to_string(), default_role: "system".to_string(), + allowed_roles: vec!["system".to_string()], }, output_format: OutputFormatContent::new_string(), tags: HashMap::from([("ROLE".to_string(), BamlValue::String("john doe".into()))]), @@ -817,6 +830,7 @@ mod render_tests { name: "gpt4".to_string(), provider: 
"openai".to_string(), default_role: "system".to_string(), + allowed_roles: vec!["system".to_string()], }, output_format: OutputFormatContent::new_string(), tags: HashMap::from([("ROLE".to_string(), BamlValue::String("john doe".into()))]), @@ -856,6 +870,7 @@ mod render_tests { name: "gpt4".to_string(), provider: "openai".to_string(), default_role: "system".to_string(), + allowed_roles: vec!["system".to_string()], }, output_format: OutputFormatContent::new_string(), tags: HashMap::from([("ROLE".to_string(), BamlValue::String("john doe".into()))]), @@ -895,6 +910,7 @@ mod render_tests { name: "gpt4".to_string(), provider: "openai".to_string(), default_role: "system".to_string(), + allowed_roles: vec!["system".to_string()], }, output_format: OutputFormatContent::new_string(), tags: HashMap::from([("ROLE".to_string(), BamlValue::String("john doe".into()))]), @@ -934,6 +950,7 @@ mod render_tests { name: "gpt4".to_string(), provider: "openai".to_string(), default_role: "system".to_string(), + allowed_roles: vec!["system".to_string()], }, output_format: OutputFormatContent::new_string(), tags: HashMap::from([("ROLE".to_string(), BamlValue::String("john doe".into()))]), @@ -995,6 +1012,7 @@ mod render_tests { name: "gpt4".to_string(), provider: "openai".to_string(), default_role: "system".to_string(), + allowed_roles: vec!["system".to_string()], }, output_format: OutputFormatContent::new_string(), tags: HashMap::from([("ROLE".to_string(), BamlValue::String("john doe".into()))]), @@ -1054,6 +1072,7 @@ mod render_tests { name: "gpt4".to_string(), provider: "openai".to_string(), default_role: "system".to_string(), + allowed_roles: vec!["system".to_string(), "john doe".to_string()], }, output_format: OutputFormatContent::new_string(), tags: HashMap::from([("ROLE".to_string(), BamlValue::String("john doe".into()))]), @@ -1099,6 +1118,91 @@ mod render_tests { Ok(()) } + + #[test] + fn render_with_kwargs_default_role() -> anyhow::Result<()> { + setup_logging(); + + let args: 
BamlValue = BamlValue::Map(BamlMap::from([( + "haiku_subject".to_string(), + BamlValue::String("sakura".to_string()), + )])); + + let ir = make_test_ir( + " + class C { + + } + ", + )?; + + let rendered = render_prompt( + r#" + + You are an assistant that always responds + in a very excited way with emojis + and also outputs this word 4 times + after giving a response: {{ haiku_subject }} + + {{ _.chat(role=ctx.tags.ROLE) }} + + Tell me a haiku about {{ haiku_subject }}. {{ ctx.output_format }} + + {{ _.chat("user") }} + End the haiku with a line about your maker, {{ ctx.client.provider }}. + "#, + &args, + RenderContext { + client: RenderContext_Client { + name: "gpt4".to_string(), + provider: "openai".to_string(), + default_role: "system".to_string(), + allowed_roles: vec!["system".to_string(), "user".to_string()], + }, + output_format: OutputFormatContent::new_string(), + tags: HashMap::from([("ROLE".to_string(), BamlValue::String("john doe".into()))]), + }, + &[], + &ir, + &HashMap::new(), + )?; + + assert_eq!( + rendered, + RenderedPrompt::Chat(vec![ + RenderedChatMessage { + role: "system".to_string(), + allow_duplicate_role: false, + parts: vec![ChatMessagePart::Text( + [ + "You are an assistant that always responds", + "in a very excited way with emojis", + "and also outputs this word 4 times", + "after giving a response: sakura" + ] + .join("\n") + )] + }, + RenderedChatMessage { + role: "system".to_string(), + allow_duplicate_role: false, + parts: vec![ChatMessagePart::Text( + "Tell me a haiku about sakura.".to_string() + )] + }, + RenderedChatMessage { + role: "user".to_string(), + allow_duplicate_role: false, + parts: vec![ChatMessagePart::Text( + "End the haiku with a line about your maker, openai.".to_string() + )] + } + ]) + ); + + Ok(()) + } + #[test] fn render_chat_starts_with_system() -> anyhow::Result<()> { setup_logging(); @@ -1131,6 +1235,7 @@ mod render_tests { name: "gpt4".to_string(), provider: "openai".to_string(), default_role: 
"system".to_string(), + allowed_roles: vec!["system".to_string()], }, output_format: OutputFormatContent::new_string(), tags: HashMap::from([("ROLE".to_string(), BamlValue::String("john doe".into()))]), @@ -1185,6 +1290,7 @@ mod render_tests { name: "gpt4".to_string(), provider: "openai".to_string(), default_role: "system".to_string(), + allowed_roles: vec!["system".to_string()], }, output_format: OutputFormatContent::new_string(), tags: HashMap::new(), @@ -1235,6 +1341,7 @@ mod render_tests { name: "gpt4".to_string(), provider: "openai".to_string(), default_role: "system".to_string(), + allowed_roles: vec!["system".to_string()], }, output_format: OutputFormatContent::new_string(), tags: HashMap::new(), @@ -1281,6 +1388,7 @@ mod render_tests { name: "gpt4".to_string(), provider: "openai".to_string(), default_role: "system".to_string(), + allowed_roles: vec!["system".to_string()], }, output_format: OutputFormatContent::new_string(), tags: HashMap::new(), @@ -1300,6 +1408,7 @@ mod render_tests { name: "gpt4".to_string(), provider: "openai".to_string(), default_role: "system".to_string(), + allowed_roles: vec!["system".to_string()], }, output_format: OutputFormatContent::new_string(), tags: HashMap::new(), @@ -1342,6 +1451,7 @@ mod render_tests { name: "gpt4".to_string(), provider: "openai".to_string(), default_role: "system".to_string(), + allowed_roles: vec!["system".to_string()], }, output_format: OutputFormatContent::new_string(), tags: HashMap::new(), @@ -1361,6 +1471,7 @@ mod render_tests { name: "gpt4".to_string(), provider: "openai".to_string(), default_role: "system".to_string(), + allowed_roles: vec!["system".to_string()], }, output_format: OutputFormatContent::new_string(), tags: HashMap::new(), @@ -1403,6 +1514,7 @@ mod render_tests { name: "gpt4".to_string(), provider: "openai".to_string(), default_role: "system".to_string(), + allowed_roles: vec!["system".to_string()], }, output_format: OutputFormatContent::new_string(), tags: HashMap::new(), @@ -1476,6 
+1588,7 @@ mod render_tests { name: "gpt4".to_string(), provider: "openai".to_string(), default_role: "system".to_string(), + allowed_roles: vec!["system".to_string()], }, output_format: OutputFormatContent::new_string(), tags: HashMap::new(), @@ -1569,6 +1682,7 @@ mod render_tests { name: "gpt4".to_string(), provider: "openai".to_string(), default_role: "system".to_string(), + allowed_roles: vec!["system".to_string()], }, output_format: OutputFormatContent::new_string(), tags: HashMap::new(), @@ -1674,6 +1788,7 @@ mod render_tests { name: "gpt4".to_string(), provider: "openai".to_string(), default_role: "system".to_string(), + allowed_roles: vec!["system".to_string()], }, output_format: OutputFormatContent::new_string(), tags: HashMap::new(), @@ -1750,6 +1865,7 @@ mod render_tests { name: "gpt4".to_string(), provider: "openai".to_string(), default_role: "system".to_string(), + allowed_roles: vec!["system".to_string()], }, output_format: OutputFormatContent::new_string(), tags: HashMap::new(), @@ -1809,6 +1925,7 @@ mod render_tests { name: "gpt4".to_string(), provider: "openai".to_string(), default_role: "system".to_string(), + allowed_roles: vec!["system".to_string()], }, output_format: OutputFormatContent::new_string(), tags: HashMap::new(), @@ -1913,6 +2030,7 @@ mod render_tests { name: "gpt4".to_string(), provider: "openai".to_string(), default_role: "system".to_string(), + allowed_roles: vec!["system".to_string()], }, output_format: OutputFormatContent::new_string(), tags: HashMap::new(), diff --git a/engine/baml-lib/llm-client/src/clients/anthropic.rs b/engine/baml-lib/llm-client/src/clients/anthropic.rs index 1c5de6b79..fc9a7403a 100644 --- a/engine/baml-lib/llm-client/src/clients/anthropic.rs +++ b/engine/baml-lib/llm-client/src/clients/anthropic.rs @@ -1,6 +1,6 @@ use std::collections::HashSet; -use crate::{AllowedRoleMetadata, SupportedRequestModes, UnresolvedAllowedRoleMetadata}; +use crate::{AllowedRoleMetadata, FinishReasonFilter, RolesSelection, 
SupportedRequestModes, UnresolvedAllowedRoleMetadata, UnresolvedFinishReasonFilter, UnresolvedRolesSelection}; use anyhow::Result; use baml_types::{EvaluationContext, StringOr, UnresolvedValue}; @@ -12,12 +12,12 @@ use super::helpers::{Error, PropertyHandler, UnresolvedUrl}; pub struct UnresolvedAnthropic { base_url: UnresolvedUrl, api_key: StringOr, - allowed_roles: Vec, - default_role: Option, + role_selection: UnresolvedRolesSelection, allowed_metadata: UnresolvedAllowedRoleMetadata, supported_request_modes: SupportedRequestModes, headers: IndexMap, properties: IndexMap)>, + finish_reason_filter: UnresolvedFinishReasonFilter, } impl UnresolvedAnthropic { @@ -25,8 +25,7 @@ impl UnresolvedAnthropic { UnresolvedAnthropic { base_url: self.base_url.clone(), api_key: self.api_key.clone(), - allowed_roles: self.allowed_roles.clone(), - default_role: self.default_role.clone(), + role_selection: self.role_selection.clone(), allowed_metadata: self.allowed_metadata.clone(), supported_request_modes: self.supported_request_modes.clone(), headers: self @@ -39,6 +38,7 @@ impl UnresolvedAnthropic { .iter() .map(|(k, (_, v))| (k.clone(), ((), v.without_meta()))) .collect(), + finish_reason_filter: self.finish_reason_filter.clone(), } } } @@ -46,28 +46,42 @@ impl UnresolvedAnthropic { pub struct ResolvedAnthropic { pub base_url: String, pub api_key: String, - pub allowed_roles: Vec, - pub default_role: String, + role_selection: RolesSelection, pub allowed_metadata: AllowedRoleMetadata, pub supported_request_modes: SupportedRequestModes, pub headers: IndexMap, pub properties: IndexMap, pub proxy_url: Option, + pub finish_reason_filter: FinishReasonFilter, } +impl ResolvedAnthropic { + pub fn allowed_roles(&self) -> Vec { + self.role_selection.allowed_or_else(|| { + vec!["system".to_string(), "user".to_string(), "assistant".to_string()] + }) + } + + pub fn default_role(&self) -> String { + self.role_selection + .default_or_else(|| { + let allowed_roles = self.allowed_roles(); + if 
allowed_roles.contains(&"user".to_string()) { + "user".to_string() + } else { + allowed_roles.first().cloned().unwrap_or_else(|| "user".to_string()) + } + }) + } +} + + impl UnresolvedAnthropic { pub fn required_env_vars(&self) -> HashSet { let mut env_vars = HashSet::new(); env_vars.extend(self.base_url.required_env_vars()); env_vars.extend(self.api_key.required_env_vars()); - env_vars.extend( - self.allowed_roles - .iter() - .flat_map(|r| r.required_env_vars()), - ); - if let Some(r) = self.default_role.as_ref() { - env_vars.extend(r.required_env_vars()) - } + env_vars.extend(self.role_selection.required_env_vars()); env_vars.extend(self.allowed_metadata.required_env_vars()); env_vars.extend(self.supported_request_modes.required_env_vars()); env_vars.extend(self.headers.values().flat_map(|v| v.required_env_vars())); @@ -81,25 +95,6 @@ impl UnresolvedAnthropic { } pub fn resolve(&self, ctx: &EvaluationContext<'_>) -> Result { - let allowed_roles = self - .allowed_roles - .iter() - .map(|role| role.resolve(ctx)) - .collect::>>()?; - - let Some(default_role) = self.default_role.as_ref() else { - return Err(anyhow::anyhow!("default_role must be provided")); - }; - let default_role = default_role.resolve(ctx)?; - - if !allowed_roles.contains(&default_role) { - return Err(anyhow::anyhow!( - "default_role must be in allowed_roles: {} not in {:?}", - default_role, - allowed_roles - )); - } - let base_url = self.base_url.resolve(ctx)?; let mut headers = self @@ -130,13 +125,13 @@ impl UnresolvedAnthropic { Ok(ResolvedAnthropic { base_url, api_key: self.api_key.resolve(ctx)?, - allowed_roles, - default_role, + role_selection: self.role_selection.resolve(ctx)?, allowed_metadata: self.allowed_metadata.resolve(ctx)?, supported_request_modes: self.supported_request_modes.clone(), headers, properties, proxy_url: super::helpers::get_proxy_url(ctx), + finish_reason_filter: self.finish_reason_filter.resolve(ctx)?, }) } @@ -148,17 +143,11 @@ impl UnresolvedAnthropic { .map(|(_, v, 
_)| v.clone()) .unwrap_or(StringOr::EnvVar("ANTHROPIC_API_KEY".to_string())); - let allowed_roles = properties.ensure_allowed_roles().unwrap_or(vec![ - StringOr::Value("system".to_string()), - StringOr::Value("user".to_string()), - StringOr::Value("assistant".to_string()), - ]); - - let default_role = properties.ensure_default_role(allowed_roles.as_slice(), 1); + let role_selection = properties.ensure_roles_selection(); let allowed_metadata = properties.ensure_allowed_metadata(); let supported_request_modes = properties.ensure_supported_request_modes(); let headers = properties.ensure_headers().unwrap_or_default(); - + let finish_reason_filter = properties.ensure_finish_reason_filter(); let (properties, errors) = properties.finalize(); if !errors.is_empty() { return Err(errors); @@ -167,12 +156,12 @@ impl UnresolvedAnthropic { Ok(Self { base_url, api_key, - allowed_roles, - default_role, + role_selection, allowed_metadata, supported_request_modes, headers, properties, + finish_reason_filter, }) } } diff --git a/engine/baml-lib/llm-client/src/clients/aws_bedrock.rs b/engine/baml-lib/llm-client/src/clients/aws_bedrock.rs index 359b23db4..766092b8a 100644 --- a/engine/baml-lib/llm-client/src/clients/aws_bedrock.rs +++ b/engine/baml-lib/llm-client/src/clients/aws_bedrock.rs @@ -1,6 +1,6 @@ use std::collections::HashSet; -use crate::{AllowedRoleMetadata, SupportedRequestModes, UnresolvedAllowedRoleMetadata}; +use crate::{AllowedRoleMetadata, FinishReasonFilter, RolesSelection, SupportedRequestModes, UnresolvedAllowedRoleMetadata, UnresolvedFinishReasonFilter, UnresolvedRolesSelection}; use anyhow::Result; use baml_types::{EvaluationContext, StringOr}; @@ -13,11 +13,11 @@ pub struct UnresolvedAwsBedrock { region: StringOr, access_key_id: StringOr, secret_access_key: StringOr, - allowed_roles: Vec, - default_role: Option, + role_selection: UnresolvedRolesSelection, allowed_role_metadata: UnresolvedAllowedRoleMetadata, supported_request_modes: SupportedRequestModes, 
inference_config: Option, + finish_reason_filter: UnresolvedFinishReasonFilter, } #[derive(Debug, Clone)] @@ -64,10 +64,30 @@ pub struct ResolvedAwsBedrock { pub access_key_id: Option, pub secret_access_key: Option, pub inference_config: Option, - pub allowed_roles: Vec, - pub default_role: String, + role_selection: RolesSelection, pub allowed_role_metadata: AllowedRoleMetadata, pub supported_request_modes: SupportedRequestModes, + pub finish_reason_filter: FinishReasonFilter, +} + +impl ResolvedAwsBedrock { + pub fn allowed_roles(&self) -> Vec { + self.role_selection.allowed_or_else(|| { + vec!["system".to_string(), "user".to_string(), "assistant".to_string()] + }) + } + + pub fn default_role(&self) -> String { + self.role_selection + .default_or_else(|| { + let allowed_roles = self.allowed_roles(); + if allowed_roles.contains(&"user".to_string()) { + "user".to_string() + } else { + allowed_roles.first().cloned().unwrap_or_else(|| "user".to_string()) + } + }) + } } impl UnresolvedAwsBedrock { @@ -79,14 +99,7 @@ impl UnresolvedAwsBedrock { env_vars.extend(self.region.required_env_vars()); env_vars.extend(self.access_key_id.required_env_vars()); env_vars.extend(self.secret_access_key.required_env_vars()); - env_vars.extend( - self.allowed_roles - .iter() - .flat_map(|r| r.required_env_vars()), - ); - if let Some(r) = self.default_role.as_ref() { - env_vars.extend(r.required_env_vars()) - } + env_vars.extend(self.role_selection.required_env_vars()); env_vars.extend(self.allowed_role_metadata.required_env_vars()); env_vars.extend(self.supported_request_modes.required_env_vars()); if let Some(c) = self.inference_config.as_ref() { @@ -100,32 +113,14 @@ impl UnresolvedAwsBedrock { return Err(anyhow::anyhow!("model must be provided")); }; - let allowed_roles = self - .allowed_roles - .iter() - .map(|role| role.resolve(ctx)) - .collect::>>()?; - - let Some(default_role) = self.default_role.as_ref() else { - return Err(anyhow::anyhow!("default_role must be provided")); - }; 
- let default_role = default_role.resolve(ctx)?; - - if !allowed_roles.contains(&default_role) { - return Err(anyhow::anyhow!( - "default_role must be in allowed_roles: {} not in {:?}", - default_role, - allowed_roles - )); - } + let role_selection = self.role_selection.resolve(ctx)?; Ok(ResolvedAwsBedrock { model: model.resolve(ctx)?, region: self.region.resolve(ctx).ok(), access_key_id: self.access_key_id.resolve(ctx).ok(), secret_access_key: self.secret_access_key.resolve(ctx).ok(), - allowed_roles, - default_role, + role_selection, allowed_role_metadata: self.allowed_role_metadata.resolve(ctx)?, supported_request_modes: self.supported_request_modes.clone(), inference_config: self @@ -133,6 +128,7 @@ impl UnresolvedAwsBedrock { .as_ref() .map(|c| c.resolve(ctx)) .transpose()?, + finish_reason_filter: self.finish_reason_filter.resolve(ctx)?, }) } @@ -175,12 +171,7 @@ impl UnresolvedAwsBedrock { .map(|(_, v, _)| v.clone()) .unwrap_or_else(|| baml_types::StringOr::EnvVar("AWS_SECRET_ACCESS_KEY".to_string())); - let allowed_roles = properties.ensure_allowed_roles().unwrap_or(vec![ - StringOr::Value("system".to_string()), - StringOr::Value("user".to_string()), - StringOr::Value("assistant".to_string()), - ]); - let default_role = properties.ensure_default_role(allowed_roles.as_slice(), 1); + let role_selection = properties.ensure_roles_selection(); let allowed_metadata = properties.ensure_allowed_metadata(); let supported_request_modes = properties.ensure_supported_request_modes(); @@ -218,11 +209,11 @@ impl UnresolvedAwsBedrock { }), "stop_sequences" => inference_config.stop_sequences = match v.into_array() { Ok((stop_sequences, _)) => Some(stop_sequences.into_iter().filter_map(|s| match s.into_str() { - Ok((s, _)) => Some(s), - Err(e) => { - properties.push_error(format!("stop_sequences values must be a string: got {}", e.r#type()), e.meta().clone()); - None - } + Ok((s, _)) => Some(s), + Err(e) => { + properties.push_error(format!("stop_sequences values must be a 
string: got {}", e.r#type()), e.meta().clone()); + None + } }) .collect::>()), Err(e) => { @@ -241,6 +232,7 @@ impl UnresolvedAwsBedrock { } Some(inference_config) }; + let finish_reason_filter = properties.ensure_finish_reason_filter(); // TODO: Handle inference_configuration let errors = properties.finalize_empty(); @@ -253,11 +245,11 @@ impl UnresolvedAwsBedrock { region, access_key_id, secret_access_key, - allowed_roles, - default_role, + role_selection, allowed_role_metadata: allowed_metadata, supported_request_modes, inference_config, + finish_reason_filter, }) } } diff --git a/engine/baml-lib/llm-client/src/clients/google_ai.rs b/engine/baml-lib/llm-client/src/clients/google_ai.rs index 7a9829b69..3f36fd61a 100644 --- a/engine/baml-lib/llm-client/src/clients/google_ai.rs +++ b/engine/baml-lib/llm-client/src/clients/google_ai.rs @@ -2,6 +2,9 @@ use std::collections::HashSet; use crate::{AllowedRoleMetadata, SupportedRequestModes, UnresolvedAllowedRoleMetadata}; use anyhow::Result; +use crate::{ + FinishReasonFilter, RolesSelection, UnresolvedFinishReasonFilter, UnresolvedRolesSelection +}; use baml_types::{EvaluationContext, StringOr, UnresolvedValue}; use indexmap::IndexMap; @@ -13,19 +16,18 @@ pub struct UnresolvedGoogleAI { api_key: StringOr, base_url: UnresolvedUrl, headers: IndexMap, - allowed_roles: Vec, - default_role: Option, + role_selection: UnresolvedRolesSelection, model: Option, allowed_metadata: UnresolvedAllowedRoleMetadata, supported_request_modes: SupportedRequestModes, + finish_reason_filter: UnresolvedFinishReasonFilter, properties: IndexMap)>, } impl UnresolvedGoogleAI { pub fn without_meta(&self) -> UnresolvedGoogleAI<()> { UnresolvedGoogleAI { - allowed_roles: self.allowed_roles.clone(), - default_role: self.default_role.clone(), + role_selection: self.role_selection.clone(), api_key: self.api_key.clone(), model: self.model.clone(), base_url: self.base_url.clone(), @@ -41,13 +43,13 @@ impl UnresolvedGoogleAI { .iter() .map(|(k, (_, v))| 
(k.clone(), ((), v.without_meta()))) .collect::>(), + finish_reason_filter: self.finish_reason_filter.clone(), } } } pub struct ResolvedGoogleAI { - pub allowed_roles: Vec, - pub default_role: String, + role_selection: RolesSelection, pub api_key: String, pub model: String, pub base_url: String, @@ -56,6 +58,26 @@ pub struct ResolvedGoogleAI { pub supported_request_modes: SupportedRequestModes, pub properties: IndexMap, pub proxy_url: Option, + pub finish_reason_filter: FinishReasonFilter, +} + +impl ResolvedGoogleAI { + pub fn allowed_roles(&self) -> Vec { + self.role_selection.allowed_or_else(|| { + vec!["system".to_string(), "user".to_string(), "assistant".to_string()] + }) + } + + pub fn default_role(&self) -> String { + self.role_selection.default_or_else(|| { + let allowed_roles = self.allowed_roles(); + if allowed_roles.contains(&"user".to_string()) { + "user".to_string() + } else { + allowed_roles.first().cloned().unwrap_or_else(|| "user".to_string()) + } + }) + } } impl UnresolvedGoogleAI { @@ -64,17 +86,10 @@ impl UnresolvedGoogleAI { env_vars.extend(self.api_key.required_env_vars()); env_vars.extend(self.base_url.required_env_vars()); env_vars.extend(self.headers.values().flat_map(StringOr::required_env_vars)); - env_vars.extend( - self.allowed_roles - .iter() - .flat_map(|r| r.required_env_vars()), - ); - if let Some(r) = self.default_role.as_ref() { - env_vars.extend(r.required_env_vars()) - } if let Some(m) = self.model.as_ref() { env_vars.extend(m.required_env_vars()) } + env_vars.extend(self.role_selection.required_env_vars()); env_vars.extend(self.allowed_metadata.required_env_vars()); env_vars.extend(self.supported_request_modes.required_env_vars()); env_vars.extend( @@ -87,23 +102,7 @@ impl UnresolvedGoogleAI { pub fn resolve(&self, ctx: &EvaluationContext<'_>) -> Result { let api_key = self.api_key.resolve(ctx)?; - let Some(default_role) = self.default_role.as_ref() else { - return Err(anyhow::anyhow!("default_role must be provided")); - }; - 
let default_role = default_role.resolve(ctx)?; - - let allowed_roles = self - .allowed_roles - .iter() - .map(|r| r.resolve(ctx)) - .collect::>>()?; - if !allowed_roles.contains(&default_role) { - return Err(anyhow::anyhow!( - "default_role must be in allowed_roles: {} not in {:?}", - default_role, - allowed_roles - )); - } + let role_selection = self.role_selection.resolve(ctx)?; let model = self .model @@ -121,12 +120,11 @@ impl UnresolvedGoogleAI { .collect::>>()?; Ok(ResolvedGoogleAI { - default_role, + role_selection, api_key, model, base_url, headers, - allowed_roles, allowed_metadata: self.allowed_metadata.resolve(ctx)?, supported_request_modes: self.supported_request_modes.clone(), properties: self @@ -135,20 +133,13 @@ impl UnresolvedGoogleAI { .map(|(k, (_, v))| Ok((k.clone(), v.resolve_serde::(ctx)?))) .collect::>>()?, proxy_url: super::helpers::get_proxy_url(ctx), + finish_reason_filter: self.finish_reason_filter.resolve(ctx)?, }) } pub fn create_from(mut properties: PropertyHandler) -> Result>> { - let allowed_roles = properties.ensure_allowed_roles().unwrap_or(vec![ - StringOr::Value("system".to_string()), - StringOr::Value("user".to_string()), - StringOr::Value("assistant".to_string()), - ]); - let default_role = properties.ensure_default_role(allowed_roles.as_slice(), 1); - - let api_key = properties - .ensure_api_key() - .unwrap_or(StringOr::EnvVar("GOOGLE_API_KEY".to_string())); + let role_selection = properties.ensure_roles_selection(); + let api_key = properties.ensure_api_key().map(|v| v.clone()).unwrap_or(StringOr::EnvVar("GOOGLE_API_KEY".to_string())); let model = properties .ensure_string("model", false) @@ -161,7 +152,7 @@ impl UnresolvedGoogleAI { let allowed_metadata = properties.ensure_allowed_metadata(); let supported_request_modes = properties.ensure_supported_request_modes(); let headers = properties.ensure_headers().unwrap_or_default(); - + let finish_reason_filter = properties.ensure_finish_reason_filter(); let (properties, errors) 
= properties.finalize(); if !errors.is_empty() { @@ -169,8 +160,7 @@ impl UnresolvedGoogleAI { } Ok(Self { - allowed_roles, - default_role, + role_selection, api_key, model, base_url, @@ -178,6 +168,7 @@ impl UnresolvedGoogleAI { allowed_metadata, supported_request_modes, properties, + finish_reason_filter, }) } } diff --git a/engine/baml-lib/llm-client/src/clients/helpers.rs b/engine/baml-lib/llm-client/src/clients/helpers.rs index 3f19a9714..3dd5868e8 100644 --- a/engine/baml-lib/llm-client/src/clients/helpers.rs +++ b/engine/baml-lib/llm-client/src/clients/helpers.rs @@ -3,7 +3,10 @@ use std::{borrow::Cow, collections::HashSet}; use baml_types::{GetEnvVar, StringOr, UnresolvedValue}; use indexmap::IndexMap; -use crate::{SupportedRequestModes, UnresolvedAllowedRoleMetadata}; +use crate::{ + SupportedRequestModes, UnresolvedAllowedRoleMetadata, UnresolvedFinishReasonFilter, + UnresolvedRolesSelection, +}; #[derive(Debug, Clone)] pub struct UnresolvedUrl(StringOr); @@ -157,9 +160,13 @@ impl PropertyHandler { result.map(|(key_span, value, meta)| (key_span.clone(), value, meta.clone())) } - pub fn ensure_allowed_roles(&mut self) -> Option> { + fn ensure_allowed_roles(&mut self) -> Option> { self.ensure_array("allowed_roles", false) - .map(|(_, value, _)| { + .map(|(_, value, value_span)| { + if value.is_empty() { + self.push_error("allowed_roles must not be empty", value_span); + } + value .into_iter() .filter_map(|v| match v.as_str() { @@ -179,11 +186,17 @@ impl PropertyHandler { }) } - pub fn ensure_default_role( - &mut self, - allowed_roles: &[StringOr], - default_role_index: usize, - ) -> Option { + pub(crate) fn ensure_roles_selection(&mut self) -> UnresolvedRolesSelection { + let allowed_roles = self.ensure_allowed_roles(); + let default_role = self.ensure_default_role(allowed_roles.as_ref().unwrap_or(&vec![ + StringOr::Value("user".to_string()), + StringOr::Value("assistant".to_string()), + StringOr::Value("system".to_string()), + ])); + 
UnresolvedRolesSelection::new(allowed_roles, default_role) + } + + fn ensure_default_role(&mut self, allowed_roles: &[StringOr]) -> Option { self.ensure_string("default_role", false) .and_then(|(_, value, span)| { if allowed_roles.iter().any(|v| value.maybe_eq(v)) { @@ -196,14 +209,13 @@ impl PropertyHandler { .join(", "); self.push_error( format!( - "default_role must be one of {allowed_roles_str}. Got: {value}" + "default_role must be one of {allowed_roles_str}. Got: {value}. To support different default roles, add allowed_roles [\"user\", \"assistant\", \"system\", ...]" ), span, ); None } }) - .or_else(|| allowed_roles.get(default_role_index).cloned()) } pub fn ensure_api_key(&mut self) -> Option { @@ -232,6 +244,55 @@ impl PropertyHandler { } } + pub fn ensure_finish_reason_filter(&mut self) -> UnresolvedFinishReasonFilter { + let allow_list = self.ensure_array("finish_reason_allow_list", false); + let deny_list = self.ensure_array("finish_reason_deny_list", false); + + match (allow_list, deny_list) { + (Some(allow), Some(deny)) => { + self.push_error( + "finish_reason_allow_list and finish_reason_deny_list cannot be used together", + allow.0, + ); + self.push_error( + "finish_reason_allow_list and finish_reason_deny_list cannot be used together", + deny.0, + ); + UnresolvedFinishReasonFilter::All + } + (Some((_, allow, _)), None) => UnresolvedFinishReasonFilter::AllowList( + allow + .into_iter() + .filter_map(|v| match v.as_str() { + Some(s) => Some(s.clone()), + None => { + self.push_error( + "values in finish_reason_allow_list must be strings.", + v.meta().clone(), + ); + None + } + }) + .collect(), + ), + (None, Some((_, deny, _))) => UnresolvedFinishReasonFilter::DenyList( + deny.into_iter() + .filter_map(|v| match v.into_str() { + Ok((s, _)) => Some(s.clone()), + Err(other) => { + self.push_error( + "values in finish_reason_deny_list must be strings.", + other.meta().clone(), + ); + None + } + }) + .collect(), + ), + (None, None) => 
UnresolvedFinishReasonFilter::All, + } + } + pub fn ensure_any(&mut self, key: &str) -> Option<(Meta, UnresolvedValue)> { self.options.shift_remove(key) } diff --git a/engine/baml-lib/llm-client/src/clients/openai.rs b/engine/baml-lib/llm-client/src/clients/openai.rs index 5da0fda57..80c01996a 100644 --- a/engine/baml-lib/llm-client/src/clients/openai.rs +++ b/engine/baml-lib/llm-client/src/clients/openai.rs @@ -1,6 +1,6 @@ use std::collections::HashSet; -use crate::{AllowedRoleMetadata, SupportedRequestModes, UnresolvedAllowedRoleMetadata}; +use crate::{AllowedRoleMetadata, FinishReasonFilter, RolesSelection, SupportedRequestModes, UnresolvedAllowedRoleMetadata, UnresolvedFinishReasonFilter, UnresolvedRolesSelection}; use anyhow::Result; use baml_types::{GetEnvVar, StringOr, UnresolvedValue}; @@ -12,13 +12,13 @@ use super::helpers::{Error, PropertyHandler, UnresolvedUrl}; pub struct UnresolvedOpenAI { base_url: Option>, api_key: Option, - allowed_roles: Vec, - default_role: Option, + role_selection: UnresolvedRolesSelection, allowed_role_metadata: UnresolvedAllowedRoleMetadata, supported_request_modes: SupportedRequestModes, headers: IndexMap, properties: IndexMap)>, query_params: IndexMap, + finish_reason_filter: UnresolvedFinishReasonFilter, } impl UnresolvedOpenAI { @@ -26,8 +26,7 @@ impl UnresolvedOpenAI { UnresolvedOpenAI { base_url: self.base_url.clone(), api_key: self.api_key.clone(), - allowed_roles: self.allowed_roles.clone(), - default_role: self.default_role.clone(), + role_selection: self.role_selection.clone(), allowed_role_metadata: self.allowed_role_metadata.clone(), supported_request_modes: self.supported_request_modes.clone(), headers: self @@ -45,6 +44,7 @@ impl UnresolvedOpenAI { .iter() .map(|(k, v)| (k.clone(), v.clone())) .collect(), + finish_reason_filter: self.finish_reason_filter.clone(), } } } @@ -52,14 +52,46 @@ impl UnresolvedOpenAI { pub struct ResolvedOpenAI { pub base_url: String, pub api_key: Option, - pub allowed_roles: Vec, - pub 
default_role: String, + role_selection: RolesSelection, pub allowed_metadata: AllowedRoleMetadata, - pub supported_request_modes: SupportedRequestModes, + supported_request_modes: SupportedRequestModes, pub headers: IndexMap, pub properties: IndexMap, pub query_params: IndexMap, pub proxy_url: Option, + pub finish_reason_filter: FinishReasonFilter, +} + +impl ResolvedOpenAI { + fn is_o1_model(&self) -> bool { + self.properties.get("model").is_some_and(|model| model.as_str().map(|s| s.starts_with("o1-")).unwrap_or(false)) + } + + pub fn supports_streaming(&self) -> bool { + match self.supported_request_modes.stream { + Some(v) => v, + None => !self.is_o1_model(), + } + } + + pub fn allowed_roles(&self) -> Vec { + self.role_selection.allowed_or_else(|| { + if self.is_o1_model() { + vec!["user".to_string(), "assistant".to_string()] + } else { + vec!["system".to_string(), "user".to_string(), "assistant".to_string()] + } + }) + } + + pub fn default_role(&self) -> String { + self.role_selection.default_or_else(|| { + // TODO: guard against empty allowed_roles + // The compiler should already guarantee that this is non-empty + self.allowed_roles().remove(0) + + }) + } } impl UnresolvedOpenAI { @@ -80,12 +112,7 @@ impl UnresolvedOpenAI { if let Some(key) = self.api_key.as_ref() { env_vars.extend(key.required_env_vars()) } - self.allowed_roles - .iter() - .for_each(|role| env_vars.extend(role.required_env_vars())); - if let Some(role) = self.default_role.as_ref() { - env_vars.extend(role.required_env_vars()) - } + env_vars.extend(self.role_selection.required_env_vars()); env_vars.extend(self.allowed_role_metadata.required_env_vars()); env_vars.extend(self.supported_request_modes.required_env_vars()); self.headers @@ -131,24 +158,7 @@ impl UnresolvedOpenAI { .map(|key| key.resolve(ctx)) .transpose()?; - let allowed_roles = self - .allowed_roles - .iter() - .map(|role| role.resolve(ctx)) - .collect::>>()?; - - let Some(default_role) = self.default_role.as_ref() else { - 
return Err(anyhow::anyhow!("default_role must be provided")); - }; - let default_role = default_role.resolve(ctx)?; - - if !allowed_roles.contains(&default_role) { - return Err(anyhow::anyhow!( - "default_role must be in allowed_roles: {} not in {:?}", - default_role, - allowed_roles - )); - } + let role_selection = self.role_selection.resolve(ctx)?; let headers = self .headers @@ -184,14 +194,14 @@ impl UnresolvedOpenAI { Ok(ResolvedOpenAI { base_url, api_key, - allowed_roles, - default_role, + role_selection, allowed_metadata: self.allowed_role_metadata.resolve(ctx)?, supported_request_modes: self.supported_request_modes.clone(), headers, properties, query_params, proxy_url: super::helpers::get_proxy_url(ctx), + finish_reason_filter: self.finish_reason_filter.resolve(ctx)?, }) } @@ -255,19 +265,19 @@ impl UnresolvedOpenAI { } }; - let api_key = Some( - properties - .ensure_api_key() - .unwrap_or_else(|| StringOr::EnvVar("AZURE_OPENAI_API_KEY".to_string())), - ); + let api_key = properties + .ensure_api_key() + .map(|v| v.clone()) + .unwrap_or_else(|| StringOr::EnvVar("AZURE_OPENAI_API_KEY".to_string())); let mut query_params = IndexMap::new(); if let Some((_, v, _)) = properties.ensure_string("api_version", false) { query_params.insert("api-version".to_string(), v.clone()); } - let mut instance = Self::create_common(properties, base_url, api_key)?; + let mut instance = Self::create_common(properties, base_url, None)?; instance.query_params = query_params; + instance.headers.entry("api-key".to_string()).or_insert(api_key); Ok(instance) } @@ -290,7 +300,13 @@ impl UnresolvedOpenAI { let api_key = properties.ensure_api_key(); - Self::create_common(properties, Some(either::Either::Left(base_url)), api_key) + let mut instance = Self::create_common(properties, Some(either::Either::Left(base_url)), api_key)?; + // Ollama uses smaller models many of which prefer the user role + if instance.role_selection.default.is_none() { + instance.role_selection.default = 
Some(StringOr::Value("user".to_string())); + } + + Ok(instance) } fn create_common( @@ -298,16 +314,11 @@ impl UnresolvedOpenAI { base_url: Option>, api_key: Option, ) -> Result>> { - let allowed_roles = properties.ensure_allowed_roles().unwrap_or(vec![ - StringOr::Value("system".to_string()), - StringOr::Value("user".to_string()), - StringOr::Value("assistant".to_string()), - ]); - - let default_role = properties.ensure_default_role(allowed_roles.as_slice(), 1); + let role_selection = properties.ensure_roles_selection(); let allowed_metadata = properties.ensure_allowed_metadata(); let supported_request_modes = properties.ensure_supported_request_modes(); let headers = properties.ensure_headers().unwrap_or_default(); + let finish_reason_filter = properties.ensure_finish_reason_filter(); let (properties, errors) = properties.finalize(); if !errors.is_empty() { @@ -317,13 +328,13 @@ impl UnresolvedOpenAI { Ok(Self { base_url, api_key, - allowed_roles, - default_role, + role_selection, allowed_role_metadata: allowed_metadata, supported_request_modes, headers, properties, query_params: IndexMap::new(), + finish_reason_filter, }) } } diff --git a/engine/baml-lib/llm-client/src/clients/vertex.rs b/engine/baml-lib/llm-client/src/clients/vertex.rs index 715df0915..61bfdd07e 100644 --- a/engine/baml-lib/llm-client/src/clients/vertex.rs +++ b/engine/baml-lib/llm-client/src/clients/vertex.rs @@ -1,6 +1,6 @@ use std::collections::HashSet; -use crate::{AllowedRoleMetadata, SupportedRequestModes, UnresolvedAllowedRoleMetadata}; +use crate::{AllowedRoleMetadata, FinishReasonFilter, RolesSelection, SupportedRequestModes, UnresolvedAllowedRoleMetadata, UnresolvedFinishReasonFilter, UnresolvedRolesSelection}; use anyhow::{Context, Result}; use baml_types::{GetEnvVar, StringOr, UnresolvedValue}; @@ -120,10 +120,10 @@ pub struct UnresolvedVertex { authorization: UnresolvedServiceAccountDetails, model: StringOr, headers: IndexMap, - allowed_roles: Vec, - default_role: Option, + 
role_selection: UnresolvedRolesSelection, allowed_role_metadata: UnresolvedAllowedRoleMetadata, supported_request_modes: SupportedRequestModes, + finish_reason_filter: UnresolvedFinishReasonFilter, properties: IndexMap)>, } @@ -132,12 +132,31 @@ pub struct ResolvedVertex { pub authorization: ResolvedServiceAccountDetails, pub model: String, pub headers: IndexMap, - pub allowed_roles: Vec, - pub default_role: String, + role_selection: RolesSelection, pub allowed_metadata: AllowedRoleMetadata, pub supported_request_modes: SupportedRequestModes, pub properties: IndexMap, pub proxy_url: Option, + pub finish_reason_filter: FinishReasonFilter, +} + +impl ResolvedVertex { + pub fn allowed_roles(&self) -> Vec { + self.role_selection.allowed_or_else(|| { + vec!["system".to_string(), "user".to_string(), "assistant".to_string()] + }) + } + + pub fn default_role(&self) -> String { + self.role_selection.default_or_else(|| { + let allowed_roles = self.allowed_roles(); + if allowed_roles.contains(&"user".to_string()) { + "user".to_string() + } else { + allowed_roles.first().cloned().unwrap_or_else(|| "user".to_string()) + } + }) + } } impl UnresolvedVertex { @@ -152,15 +171,8 @@ impl UnresolvedVertex { } env_vars.extend(self.authorization.required_env_vars()); env_vars.extend(self.model.required_env_vars()); - env_vars.extend(self.headers.values().flat_map(|v| v.required_env_vars())); - env_vars.extend( - self.allowed_roles - .iter() - .flat_map(|r| r.required_env_vars()), - ); - if let Some(r) = self.default_role.as_ref() { - env_vars.extend(r.required_env_vars()) - } + env_vars.extend(self.headers.values().flat_map(StringOr::required_env_vars)); + env_vars.extend(self.role_selection.required_env_vars()); env_vars.extend(self.allowed_role_metadata.required_env_vars()); env_vars.extend(self.supported_request_modes.required_env_vars()); env_vars.extend( @@ -179,8 +191,7 @@ impl UnresolvedVertex { authorization: self.authorization.without_meta(), model: self.model.clone(), headers: 
self.headers.clone(), - allowed_roles: self.allowed_roles.clone(), - default_role: self.default_role.clone(), + role_selection: self.role_selection.clone(), allowed_role_metadata: self.allowed_role_metadata.clone(), supported_request_modes: self.supported_request_modes.clone(), properties: self @@ -188,6 +199,7 @@ impl UnresolvedVertex { .iter() .map(|(k, (_, v))| (k.clone(), ((), v.without_meta()))) .collect(), + finish_reason_filter: self.finish_reason_filter.clone(), } } @@ -221,24 +233,7 @@ impl UnresolvedVertex { let model = self.model.resolve(ctx)?; - let allowed_roles = self - .allowed_roles - .iter() - .map(|role| role.resolve(ctx)) - .collect::>>()?; - - let Some(default_role) = self.default_role.as_ref() else { - return Err(anyhow::anyhow!("default_role must be provided")); - }; - let default_role = default_role.resolve(ctx)?; - - if !allowed_roles.contains(&default_role) { - return Err(anyhow::anyhow!( - "default_role must be in allowed_roles: {} not in {:?}", - default_role, - allowed_roles - )); - } + let role_selection = self.role_selection.resolve(ctx)?; let headers = self .headers @@ -251,8 +246,7 @@ impl UnresolvedVertex { authorization, model, headers, - allowed_roles, - default_role, + role_selection, allowed_metadata: self.allowed_role_metadata.resolve(ctx)?, supported_request_modes: self.supported_request_modes.clone(), properties: self @@ -261,6 +255,7 @@ impl UnresolvedVertex { .map(|(k, (_, v))| Ok((k.clone(), v.resolve_serde::(ctx)?))) .collect::>>()?, proxy_url: super::helpers::get_proxy_url(ctx), + finish_reason_filter: self.finish_reason_filter.resolve(ctx)?, }) } @@ -348,16 +343,12 @@ impl UnresolvedVertex { .ensure_string("project_id", false) .map(|(_, v, _)| v); - let allowed_roles = properties.ensure_allowed_roles().unwrap_or(vec![ - StringOr::Value("system".to_string()), - StringOr::Value("user".to_string()), - StringOr::Value("assistant".to_string()), - ]); - - let default_role = 
properties.ensure_default_role(allowed_roles.as_slice(), 1); + let role_selection = properties.ensure_roles_selection(); let allowed_metadata = properties.ensure_allowed_metadata(); let supported_request_modes = properties.ensure_supported_request_modes(); let headers = properties.ensure_headers().unwrap_or_default(); + let finish_reason_filter = properties.ensure_finish_reason_filter(); + let (properties, errors) = properties.finalize(); if !errors.is_empty() { return Err(errors); @@ -373,11 +364,11 @@ impl UnresolvedVertex { authorization, model, headers, - allowed_roles, - default_role, + role_selection, allowed_role_metadata: allowed_metadata, supported_request_modes, properties, + finish_reason_filter, }) } } diff --git a/engine/baml-lib/llm-client/src/clientspec.rs b/engine/baml-lib/llm-client/src/clientspec.rs index a9277f9df..6641db0fa 100644 --- a/engine/baml-lib/llm-client/src/clientspec.rs +++ b/engine/baml-lib/llm-client/src/clientspec.rs @@ -197,6 +197,163 @@ impl SupportedRequestModes { } } +#[derive(Clone, Debug)] +pub enum UnresolvedFinishReasonFilter { + All, + AllowList(HashSet), + DenyList(HashSet), +} + +#[derive(Clone, Debug)] +pub enum FinishReasonFilter { + All, + AllowList(HashSet), + DenyList(HashSet), +} + +impl UnresolvedFinishReasonFilter { + pub fn required_env_vars(&self) -> HashSet { + match self { + Self::AllowList(allow) => allow + .iter() + .map(|s| s.required_env_vars()) + .flatten() + .collect(), + Self::DenyList(deny) => deny + .iter() + .map(|s| s.required_env_vars()) + .flatten() + .collect(), + _ => HashSet::new(), + } + } + + pub fn resolve(&self, ctx: &impl GetEnvVar) -> Result { + match self { + Self::AllowList(allow) => Ok(FinishReasonFilter::AllowList( + allow + .iter() + .map(|s| s.resolve(ctx)) + .collect::>>()?, + )), + Self::DenyList(deny) => Ok(FinishReasonFilter::DenyList( + deny.iter() + .map(|s| s.resolve(ctx)) + .collect::>>()?, + )), + Self::All => Ok(FinishReasonFilter::All), + } + } +} + +impl 
FinishReasonFilter { + pub fn is_allowed(&self, reason: Option>) -> bool { + log::warn!( + "debug is_allowed: {:?} {}", + self, + reason + .as_ref() + .map(|r| r.as_ref().to_string()) + .unwrap_or("".into()) + ); + match self { + Self::AllowList(allow) => { + let Some(reason) = reason.map(|r| r.as_ref().to_string()) else { + return false; + }; + allow.contains(&reason) + } + Self::DenyList(deny) => { + let Some(reason) = reason.map(|r| r.as_ref().to_string()) else { + return true; + }; + !deny.contains(&reason) + } + Self::All => true, + } + } +} + +#[derive(Clone, Debug)] +pub(crate) struct UnresolvedRolesSelection { + pub allowed: Option>, + pub default: Option, +} + +impl UnresolvedRolesSelection { + pub fn new(allowed: Option>, default: Option) -> Self { + Self { allowed, default } + } + + pub fn required_env_vars(&self) -> HashSet { + let mut env_vars = HashSet::new(); + if let Some(allowed) = &self.allowed { + env_vars.extend(allowed.iter().map(|s| s.required_env_vars()).flatten()); + } + if let Some(default) = &self.default { + env_vars.extend(default.required_env_vars()); + } + env_vars + } + + pub fn resolve(&self, ctx: &impl GetEnvVar) -> Result { + let allowed = self + .allowed + .as_ref() + .map(|allowed| { + allowed + .iter() + .map(|s| s.resolve(ctx)) + .collect::>>() + }) + .transpose()?; + + let default = self + .default + .as_ref() + .map(|default| default.resolve(ctx)) + .transpose()?; + + match (&allowed, &default) { + (Some(allowed), Some(default)) => { + if !allowed.contains(&default) { + return Err(anyhow::anyhow!("default_role must be in allowed_roles: {}. Not found in {:?}", default, allowed)); + } + } + (None, Some(default)) => { + match default.as_str() { + "system" | "user" | "assistant" => {} + _ => return Err(anyhow::anyhow!("default_role must be one of 'system', 'user' or 'assistant': {}. 
Please specify \"allowed_roles\" if you want to use other custom default role.", default)), + } + } + _ => {} + } + Ok(RolesSelection { allowed, default }) + } +} + +#[derive(Clone, Debug)] +pub(crate) struct RolesSelection { + allowed: Option>, + default: Option, +} + +impl RolesSelection { + pub fn allowed_or_else(&self, f: impl FnOnce() -> Vec) -> Vec { + match self.allowed.as_ref() { + Some(allowed) => allowed.clone(), + None => f(), + } + } + + pub fn default_or_else(&self, f: impl FnOnce() -> String) -> String { + match self.default.as_ref() { + Some(default) => default.clone(), + None => f(), + } + } +} + #[derive(Clone, Debug)] pub enum UnresolvedAllowedRoleMetadata { Value(StringOr), diff --git a/engine/baml-runtime/Cargo.toml b/engine/baml-runtime/Cargo.toml index 5db9034e6..679d0c272 100644 --- a/engine/baml-runtime/Cargo.toml +++ b/engine/baml-runtime/Cargo.toml @@ -91,6 +91,7 @@ valuable = { version = "0.1.0", features = ["derive"] } tracing = { version = "0.1.40", features = ["valuable"] } tracing-subscriber = { version = "0.3.18", features = ["json", "env-filter","valuable"] } thiserror = "2.0.1" +log-once = "0.4.1" [target.'cfg(target_arch = "wasm32")'.dependencies] @@ -141,7 +142,7 @@ which = "6.0.3" [features] -defaults = [] +defaults = ["skip-integ-tests"] internal = [] skip-integ-tests = [] diff --git a/engine/baml-runtime/src/cli/serve/error.rs b/engine/baml-runtime/src/cli/serve/error.rs index 236e4015f..cb30b7709 100644 --- a/engine/baml-runtime/src/cli/serve/error.rs +++ b/engine/baml-runtime/src/cli/serve/error.rs @@ -27,6 +27,13 @@ pub enum BamlError { raw_output: String, message: String, }, + #[serde(rename_all = "snake_case")] + FinishReasonError { + prompt: String, + raw_output: String, + message: String, + finish_reason: Option, + }, /// This is the only variant not documented at the aforementioned link: /// this is the catch-all for unclassified errors. 
#[serde(rename_all = "snake_case")] @@ -46,6 +53,17 @@ impl BamlError { raw_output: raw_output.to_string(), message: message.to_string(), }, + ExposedError::FinishReasonError { + prompt, + raw_output, + message, + finish_reason, + } => Self::FinishReasonError { + prompt: prompt.to_string(), + raw_output: raw_output.to_string(), + message: message.to_string(), + finish_reason: finish_reason.clone(), + }, } } else if let Some(er) = err.downcast_ref::() { Self::InvalidArgument { @@ -93,6 +111,7 @@ impl IntoResponse for BamlError { match &self { BamlError::InvalidArgument { .. } => StatusCode::BAD_REQUEST, BamlError::ClientError { .. } => StatusCode::BAD_GATEWAY, + BamlError::FinishReasonError { .. } => StatusCode::INTERNAL_SERVER_ERROR, // ??? - FIXME BamlError::ValidationFailure { .. } => StatusCode::INTERNAL_SERVER_ERROR, // ??? - FIXME BamlError::InternalError { .. } => StatusCode::INTERNAL_SERVER_ERROR, }, diff --git a/engine/baml-runtime/src/cli/serve/mod.rs b/engine/baml-runtime/src/cli/serve/mod.rs index a515d782d..1feab0f6d 100644 --- a/engine/baml-runtime/src/cli/serve/mod.rs +++ b/engine/baml-runtime/src/cli/serve/mod.rs @@ -226,7 +226,7 @@ impl Server { baml_api_key: Option<&XBamlApiKey>, ) -> AuthEnforcementMode { let Ok(password) = std::env::var("BAML_PASSWORD") else { - log::warn!("BAML_PASSWORD not set, skipping auth check"); + log_once::warn_once!("BAML_PASSWORD not set, skipping auth check"); return AuthEnforcementMode::NoEnforcement; }; diff --git a/engine/baml-runtime/src/errors.rs b/engine/baml-runtime/src/errors.rs index 1028357eb..130f2f0e0 100644 --- a/engine/baml-runtime/src/errors.rs +++ b/engine/baml-runtime/src/errors.rs @@ -5,6 +5,12 @@ pub enum ExposedError { raw_output: String, message: String, }, + FinishReasonError { + prompt: String, + raw_output: String, + message: String, + finish_reason: Option, + }, } impl std::error::Error for ExposedError {} @@ -23,6 +29,21 @@ impl std::fmt::Display for ExposedError { message, prompt, raw_output 
) } + ExposedError::FinishReasonError { + prompt, + raw_output, + message, + finish_reason, + } => { + write!( + f, + "Finish reason error: {}\nPrompt: {}\nRaw Response: {}\nFinish Reason: {}", + message, + prompt, + raw_output, + finish_reason.as_ref().map_or("", |f| f.as_str()) + ) + } } } } diff --git a/engine/baml-runtime/src/internal/llm_client/orchestrator/call.rs b/engine/baml-runtime/src/internal/llm_client/orchestrator/call.rs index 762c8a577..1156b43e8 100644 --- a/engine/baml-runtime/src/internal/llm_client/orchestrator/call.rs +++ b/engine/baml-runtime/src/internal/llm_client/orchestrator/call.rs @@ -8,7 +8,7 @@ use crate::{ internal::{ llm_client::{ parsed_value_to_response, - traits::{WithPrompt, WithSingleCallable}, + traits::{WithClientProperties, WithPrompt, WithSingleCallable}, LLMResponse, ResponseBamlValue, }, prompt_renderer::PromptRenderer, @@ -52,7 +52,21 @@ pub async fn orchestrate( }; let response = node.single_call(ctx, &prompt).await; let parsed_response = match &response { - LLMResponse::Success(s) => Some(parse_fn(&s.content)), + LLMResponse::Success(s) => { + if !node + .finish_reason_filter() + .is_allowed(s.metadata.finish_reason.as_ref()) + { + Some(Err(anyhow::anyhow!(crate::errors::ExposedError::FinishReasonError { + prompt: prompt.to_string(), + raw_output: s.content.clone(), + message: "Finish reason not allowed".to_string(), + finish_reason: s.metadata.finish_reason.clone(), + }))) + } else { + Some(parse_fn(&s.content)) + } + }, _ => None, }; diff --git a/engine/baml-runtime/src/internal/llm_client/orchestrator/mod.rs b/engine/baml-runtime/src/internal/llm_client/orchestrator/mod.rs index 31c2ebdf2..a5ee1dd71 100644 --- a/engine/baml-runtime/src/internal/llm_client/orchestrator/mod.rs +++ b/engine/baml-runtime/src/internal/llm_client/orchestrator/mod.rs @@ -219,3 +219,25 @@ impl WithStreamable for OrchestratorNode { self.provider.stream(ctx, prompt).await } } + +impl WithClientProperties for OrchestratorNode { + fn 
default_role(&self) -> String { + self.provider.default_role() + } + + fn allowed_metadata(&self) -> &internal_llm_client::AllowedRoleMetadata { + self.provider.allowed_metadata() + } + + fn supports_streaming(&self) -> bool { + self.provider.supports_streaming() + } + + fn finish_reason_filter(&self) -> &internal_llm_client::FinishReasonFilter { + self.provider.finish_reason_filter() + } + + fn allowed_roles(&self) -> Vec { + self.provider.allowed_roles() + } +} diff --git a/engine/baml-runtime/src/internal/llm_client/orchestrator/stream.rs b/engine/baml-runtime/src/internal/llm_client/orchestrator/stream.rs index 4349476e6..c3cf30fa2 100644 --- a/engine/baml-runtime/src/internal/llm_client/orchestrator/stream.rs +++ b/engine/baml-runtime/src/internal/llm_client/orchestrator/stream.rs @@ -9,7 +9,7 @@ use crate::{ internal::{ llm_client::{ parsed_value_to_response, - traits::{WithPrompt, WithStreamable}, + traits::{WithClientProperties, WithPrompt, WithStreamable}, LLMErrorResponse, LLMResponse, ResponseBamlValue, }, prompt_renderer::PromptRenderer, @@ -100,7 +100,21 @@ where }; let parsed_response = match &final_response { - LLMResponse::Success(s) => Some(parse_fn(&s.content)), + LLMResponse::Success(s) => { + if !node + .finish_reason_filter() + .is_allowed(s.metadata.finish_reason.as_ref()) + { + Some(Err(anyhow::anyhow!(crate::errors::ExposedError::FinishReasonError { + prompt: s.prompt.to_string(), + raw_output: s.content.clone(), + message: "Finish reason not allowed".to_string(), + finish_reason: s.metadata.finish_reason.clone(), + }))) + } else { + Some(parse_fn(&s.content)) + } + }, _ => None, }; let (parsed_response, response_value) = match parsed_response { diff --git a/engine/baml-runtime/src/internal/llm_client/primitive/anthropic/anthropic_client.rs b/engine/baml-runtime/src/internal/llm_client/primitive/anthropic/anthropic_client.rs index 20eac7fee..a285630e7 100644 --- 
a/engine/baml-runtime/src/internal/llm_client/primitive/anthropic/anthropic_client.rs +++ b/engine/baml-runtime/src/internal/llm_client/primitive/anthropic/anthropic_client.rs @@ -86,6 +86,15 @@ impl WithClientProperties for AnthropicClient { .stream .unwrap_or(true) } + fn finish_reason_filter(&self) -> &internal_llm_client::FinishReasonFilter { + &self.properties.finish_reason_filter + } + fn default_role(&self) -> String { + self.properties.default_role() + } + fn allowed_roles(&self) -> Vec { + self.properties.allowed_roles() + } } impl WithClient for AnthropicClient { @@ -253,13 +262,13 @@ impl WithStreamChat for AnthropicClient { impl AnthropicClient { pub fn dynamic_new(client: &ClientProperty, ctx: &RuntimeContext) -> Result { let properties = resolve_properties(&client.provider, &client.unresolved_options()?, ctx)?; - let default_role = properties.default_role.clone(); Ok(Self { name: client.name.clone(), context: RenderContext_Client { name: client.name.clone(), provider: client.provider.to_string(), - default_role, + default_role: properties.default_role(), + allowed_roles: properties.allowed_roles(), }, features: ModelFeatures { chat: true, @@ -275,14 +284,14 @@ impl AnthropicClient { } pub fn new(client: &ClientWalker, ctx: &RuntimeContext) -> Result { - let properties = resolve_properties(&client.elem().provider, client.options(), ctx)?; - let default_role = properties.default_role.clone(); + let properties = resolve_properties(&client.elem().provider, &client.options(), ctx)?; Ok(Self { name: client.name().into(), context: RenderContext_Client { name: client.name().into(), provider: client.elem().provider.to_string(), - default_role, + default_role: properties.default_role(), + allowed_roles: properties.allowed_roles(), }, features: ModelFeatures { chat: true, @@ -361,13 +370,6 @@ impl RequestBuilder for AnthropicClient { } impl WithChat for AnthropicClient { - fn chat_options(&self, _ctx: &RuntimeContext) -> Result { - 
Ok(internal_baml_jinja::ChatOptions::new( - self.properties.default_role.clone(), - None, - )) - } - async fn chat(&self, _ctx: &RuntimeContext, prompt: &[RenderedChatMessage]) -> LLMResponse { let (response, system_now, instant_now) = match make_parsed_request::< AnthropicMessageResponse, diff --git a/engine/baml-runtime/src/internal/llm_client/primitive/aws/aws_client.rs b/engine/baml-runtime/src/internal/llm_client/primitive/aws/aws_client.rs index 6d203399f..69c223642 100644 --- a/engine/baml-runtime/src/internal/llm_client/primitive/aws/aws_client.rs +++ b/engine/baml-runtime/src/internal/llm_client/primitive/aws/aws_client.rs @@ -65,14 +65,14 @@ fn resolve_properties( impl AwsClient { pub fn dynamic_new(client: &ClientProperty, ctx: &RuntimeContext) -> Result { let properties = resolve_properties(&client.provider, &client.unresolved_options()?, ctx)?; - let default_role = properties.default_role.clone(); Ok(Self { name: client.name.clone(), context: RenderContext_Client { name: client.name.clone(), provider: client.provider.to_string(), - default_role, + default_role: properties.default_role(), + allowed_roles: properties.allowed_roles(), }, features: ModelFeatures { chat: true, @@ -87,15 +87,15 @@ impl AwsClient { } pub fn new(client: &ClientWalker, ctx: &RuntimeContext) -> Result { - let properties = resolve_properties(&client.elem().provider, client.options(), ctx)?; - let default_role = properties.default_role.clone(); // clone before moving + let properties = resolve_properties(&client.elem().provider, &client.options(), ctx)?; Ok(Self { name: client.name().into(), context: RenderContext_Client { name: client.name().into(), provider: client.elem().provider.to_string(), - default_role, + default_role: properties.default_role(), + allowed_roles: properties.allowed_roles(), }, features: ModelFeatures { chat: true, @@ -282,6 +282,15 @@ impl WithClientProperties for AwsClient { .stream .unwrap_or(true) } + fn finish_reason_filter(&self) -> 
&internal_llm_client::FinishReasonFilter { + &self.properties.finish_reason_filter + } + fn default_role(&self) -> String { + self.properties.default_role() + } + fn allowed_roles(&self) -> Vec { + self.properties.allowed_roles() + } } impl WithClient for AwsClient { @@ -589,13 +598,6 @@ impl AwsClient { } impl WithChat for AwsClient { - fn chat_options(&self, _ctx: &RuntimeContext) -> Result { - Ok(internal_baml_jinja::ChatOptions::new( - self.properties.default_role.clone(), - None, - )) - } - async fn chat( &self, _ctx: &RuntimeContext, diff --git a/engine/baml-runtime/src/internal/llm_client/primitive/google/googleai_client.rs b/engine/baml-runtime/src/internal/llm_client/primitive/google/googleai_client.rs index bf6dd67c9..8ba6c8639 100644 --- a/engine/baml-runtime/src/internal/llm_client/primitive/google/googleai_client.rs +++ b/engine/baml-runtime/src/internal/llm_client/primitive/google/googleai_client.rs @@ -75,6 +75,15 @@ impl WithClientProperties for GoogleAIClient { .stream .unwrap_or(true) } + fn finish_reason_filter(&self) -> &internal_llm_client::FinishReasonFilter { + &self.properties.finish_reason_filter + } + fn default_role(&self) -> String { + self.properties.default_role() + } + fn allowed_roles(&self) -> Vec { + self.properties.allowed_roles() + } } impl WithClient for GoogleAIClient { @@ -157,8 +166,8 @@ impl SseResponseTrait for GoogleAIClient { } }; - if let Some(choice) = event.candidates.first() { - if let Some(content) = choice.content.parts.first() { + if let Some(choice) = event.candidates.get(0) { + if let Some(content) = choice.content.as_ref().and_then(|c| c.parts.get(0)) { inner.content += &content.text; } if let Some(FinishReason::Stop) = choice.finish_reason.as_ref() { @@ -193,14 +202,14 @@ impl WithStreamChat for GoogleAIClient { impl GoogleAIClient { pub fn new(client: &ClientWalker, ctx: &RuntimeContext) -> Result { - let properties = resolve_properties(&client.elem().provider, client.options(), ctx)?; - let default_role = 
properties.default_role.clone(); + let properties = resolve_properties(&client.elem().provider, &client.options(), ctx)?; Ok(Self { name: client.name().into(), context: RenderContext_Client { name: client.name().into(), provider: client.elem().provider.to_string(), - default_role, + default_role: properties.default_role(), + allowed_roles: properties.allowed_roles(), }, features: ModelFeatures { chat: true, @@ -221,14 +230,14 @@ impl GoogleAIClient { pub fn dynamic_new(client: &ClientProperty, ctx: &RuntimeContext) -> Result { let properties = resolve_properties(&client.provider, &client.unresolved_options()?, ctx)?; - let default_role = properties.default_role.clone(); Ok(Self { name: client.name.clone(), context: RenderContext_Client { name: client.name.clone(), provider: client.provider.to_string(), - default_role, + default_role: properties.default_role(), + allowed_roles: properties.allowed_roles(), }, features: ModelFeatures { chat: true, @@ -301,13 +310,6 @@ impl RequestBuilder for GoogleAIClient { } impl WithChat for GoogleAIClient { - fn chat_options(&self, _ctx: &RuntimeContext) -> Result { - Ok(internal_baml_jinja::ChatOptions::new( - self.properties.default_role.clone(), - None, - )) - } - async fn chat(&self, _ctx: &RuntimeContext, prompt: &[RenderedChatMessage]) -> LLMResponse { //non-streaming, complete response is returned let (response, system_now, instant_now) = @@ -334,10 +336,23 @@ impl WithChat for GoogleAIClient { }); } + let Some(content) = response.candidates[0].content.as_ref() else { + return LLMResponse::LLMFailure(LLMErrorResponse { + client: self.context.name.to_string(), + model: None, + prompt: internal_baml_jinja::RenderedPrompt::Chat(prompt.to_vec()), + start_time: system_now, + request_options: self.properties.properties.clone(), + latency: instant_now.elapsed(), + message: "No content returned".to_string(), + code: ErrorCode::Other(200), + }); + }; + LLMResponse::Success(LLMCompleteResponse { client: self.context.name.to_string(), 
prompt: internal_baml_jinja::RenderedPrompt::Chat(prompt.to_vec()), - content: response.candidates[0].content.parts[0].text.clone(), + content: content.parts[0].text.clone(), start_time: system_now, latency: instant_now.elapsed(), request_options: self.properties.properties.clone(), diff --git a/engine/baml-runtime/src/internal/llm_client/primitive/google/types.rs b/engine/baml-runtime/src/internal/llm_client/primitive/google/types.rs index 9edfbae11..976f5b8c9 100644 --- a/engine/baml-runtime/src/internal/llm_client/primitive/google/types.rs +++ b/engine/baml-runtime/src/internal/llm_client/primitive/google/types.rs @@ -221,7 +221,7 @@ pub enum HarmSeverity { #[serde(rename_all = "camelCase")] pub struct Candidate { pub index: Option, - pub content: Content, + pub content: Option, pub finish_reason: Option, pub safety_ratings: Option>, // pub citation_metadata: Option, diff --git a/engine/baml-runtime/src/internal/llm_client/primitive/mod.rs b/engine/baml-runtime/src/internal/llm_client/primitive/mod.rs index 541873236..c6c887f21 100644 --- a/engine/baml-runtime/src/internal/llm_client/primitive/mod.rs +++ b/engine/baml-runtime/src/internal/llm_client/primitive/mod.rs @@ -94,6 +94,15 @@ impl WithClientProperties for LLMPrimitiveProvider { fn supports_streaming(&self) -> bool { match_llm_provider!(self, supports_streaming) } + fn finish_reason_filter(&self) -> &internal_llm_client::FinishReasonFilter { + match_llm_provider!(self, finish_reason_filter) + } + fn default_role(&self) -> String { + match_llm_provider!(self, default_role) + } + fn allowed_roles(&self) -> Vec { + match_llm_provider!(self, allowed_roles) + } } impl TryFrom<(&ClientProperty, &RuntimeContext)> for LLMPrimitiveProvider { diff --git a/engine/baml-runtime/src/internal/llm_client/primitive/openai/openai_client.rs b/engine/baml-runtime/src/internal/llm_client/primitive/openai/openai_client.rs index 834e58c3e..398fa0e52 100644 --- 
a/engine/baml-runtime/src/internal/llm_client/primitive/openai/openai_client.rs +++ b/engine/baml-runtime/src/internal/llm_client/primitive/openai/openai_client.rs @@ -6,7 +6,7 @@ use baml_types::{BamlMap, BamlMedia, BamlMediaContent, BamlMediaType}; use internal_baml_core::ir::ClientWalker; use internal_baml_jinja::{ChatMessagePart, RenderContext_Client, RenderedChatMessage}; use internal_llm_client::openai::ResolvedOpenAI; -use internal_llm_client::AllowedRoleMetadata; +use internal_llm_client::{AllowedRoleMetadata, FinishReasonFilter}; use serde_json::json; use crate::internal::llm_client::{ @@ -14,7 +14,7 @@ use crate::internal::llm_client::{ }; use super::properties; -use super::types::{ChatCompletionResponse, ChatCompletionResponseDelta, FinishReason}; +use super::types::{ChatCompletionResponse, ChatCompletionResponseDelta}; use crate::client_registry::ClientProperty; use crate::internal::llm_client::primitive::request::{ @@ -56,19 +56,21 @@ impl WithClientProperties for OpenAIClient { fn allowed_metadata(&self) -> &AllowedRoleMetadata { &self.properties.allowed_metadata } + + fn finish_reason_filter(&self) -> &FinishReasonFilter { + &self.properties.finish_reason_filter + } + + fn allowed_roles(&self) -> Vec { + self.properties.allowed_roles() + } + + fn default_role(&self) -> String { + self.properties.default_role() + } + fn supports_streaming(&self) -> bool { - match self.properties.supported_request_modes.stream { - Some(v) => v, - None => { - match self.properties.properties.get("model") { - Some(serde_json::Value::String(model)) => { - // OpenAI's streaming is not available for o1-* models - !model.starts_with("o1-") - } - _ => true, - } - } - } + self.properties.supports_streaming() } } @@ -155,13 +157,6 @@ impl WithNoCompletion for OpenAIClient {} // } impl WithChat for OpenAIClient { - fn chat_options(&self, _ctx: &RuntimeContext) -> Result { - Ok(internal_baml_jinja::ChatOptions::new( - self.properties.default_role.clone(), - None, - )) - } - async 
fn chat(&self, _ctx: &RuntimeContext, prompt: &[RenderedChatMessage]) -> LLMResponse { let (response, system_start, instant_start) = match make_parsed_request::( @@ -207,16 +202,12 @@ impl WithChat for OpenAIClient { model: response.model, request_options: self.properties.properties.clone(), metadata: LLMCompleteResponseMetadata { - baml_is_complete: match response.choices.first() { - Some(c) => matches!(c.finish_reason, Some(FinishReason::Stop)), + baml_is_complete: match response.choices.get(0) { + Some(c) => c.finish_reason.as_ref().is_some_and(|f| f == "stop"), None => false, }, - finish_reason: match response.choices.first() { - Some(c) => match c.finish_reason { - Some(FinishReason::Stop) => Some(FinishReason::Stop.to_string()), - Some(other) => Some(other.to_string()), - _ => None, - }, + finish_reason: match response.choices.get(0) { + Some(c) => c.finish_reason.clone(), None => None, }, prompt_tokens: usage.map(|u| u.prompt_tokens), @@ -374,18 +365,8 @@ impl SseResponseTrait for OpenAIClient { inner.content += content.as_str(); } inner.model = event.model; - match choice.finish_reason.as_ref() { - Some(FinishReason::Stop) => { - inner.metadata.baml_is_complete = true; - inner.metadata.finish_reason = - Some(FinishReason::Stop.to_string()); - } - finish_reason => { - inner.metadata.baml_is_complete = false; - inner.metadata.finish_reason = - finish_reason.as_ref().map(|r| r.to_string()); - } - } + inner.metadata.finish_reason = choice.finish_reason.clone(); + inner.metadata.baml_is_complete = choice.finish_reason.as_ref().is_some_and(|s| s == "stop"); } inner.latency = instant_start.elapsed(); if let Some(usage) = event.usage.as_ref() { @@ -424,7 +405,8 @@ macro_rules! 
make_openai_client { context: RenderContext_Client { name: $client.name.clone(), provider: $client.provider.to_string(), - default_role: $properties.default_role.clone(), + default_role: $properties.default_role(), + allowed_roles: $properties.allowed_roles(), }, features: ModelFeatures { chat: true, @@ -445,7 +427,8 @@ macro_rules! make_openai_client { context: RenderContext_Client { name: $client.name().into(), provider: $client.elem().provider.to_string(), - default_role: $properties.default_role.clone(), + default_role: $properties.default_role(), + allowed_roles: $properties.allowed_roles(), }, features: ModelFeatures { chat: true, diff --git a/engine/baml-runtime/src/internal/llm_client/primitive/openai/types.rs b/engine/baml-runtime/src/internal/llm_client/primitive/openai/types.rs index d457b2dea..377ca6402 100644 --- a/engine/baml-runtime/src/internal/llm_client/primitive/openai/types.rs +++ b/engine/baml-runtime/src/internal/llm_client/primitive/openai/types.rs @@ -28,7 +28,7 @@ pub struct ChatCompletionGeneric { #[derive(Debug, Deserialize, Clone, PartialEq)] pub struct CompletionChoice { - pub finish_reason: Option, + pub finish_reason: Option, pub index: u32, pub text: String, } @@ -42,7 +42,7 @@ pub struct ChatCompletionChoice { /// `length` if the maximum number of tokens specified in the request was reached, /// `content_filter` if content was omitted due to a flag from our content filters, /// `tool_calls` if the model called a tool, or `function_call` (deprecated) if the model called a function. - pub finish_reason: Option, + pub finish_reason: Option, /// Log probability information for the choice. 
pub logprobs: Option, } @@ -78,7 +78,7 @@ pub struct ChatCompletionResponseMessage { #[derive(Deserialize, Clone, Debug)] pub struct ChatCompletionChoiceDelta { pub index: u64, - pub finish_reason: Option, + pub finish_reason: Option, pub delta: ChatCompletionMessageDelta, } @@ -99,7 +99,7 @@ pub struct ChatCompletionMessageDelta { // pub function_call: Option, } -#[derive(Debug, Deserialize, Clone, Copy, Default, PartialEq)] +#[derive(Debug, Deserialize, Clone, Default, PartialEq)] #[serde(rename_all = "lowercase")] pub enum ChatCompletionMessageRole { System, @@ -110,19 +110,6 @@ pub enum ChatCompletionMessageRole { Function, } -#[derive(Debug, Deserialize, strum_macros::Display, Clone, Copy, PartialEq, Serialize)] -#[serde(rename_all = "snake_case")] -#[strum(serialize_all = "snake_case")] -pub enum FinishReason { - Stop, - Length, - ToolCalls, - ContentFilter, - FunctionCall, - #[serde(other)] - Unknown, -} - #[derive(Debug, Deserialize, Clone, PartialEq)] pub struct ChatChoiceLogprobs { /// A list of message content tokens with log probability information. 
diff --git a/engine/baml-runtime/src/internal/llm_client/primitive/vertex/vertex_client.rs b/engine/baml-runtime/src/internal/llm_client/primitive/vertex/vertex_client.rs index c5990cdb9..ec376f9c6 100644 --- a/engine/baml-runtime/src/internal/llm_client/primitive/vertex/vertex_client.rs +++ b/engine/baml-runtime/src/internal/llm_client/primitive/vertex/vertex_client.rs @@ -113,6 +113,15 @@ impl WithClientProperties for VertexClient { .stream .unwrap_or(true) } + fn finish_reason_filter(&self) -> &internal_llm_client::FinishReasonFilter { + &self.properties.finish_reason_filter + } + fn default_role(&self) -> String { + self.properties.default_role() + } + fn allowed_roles(&self) -> Vec { + self.properties.allowed_roles() + } } impl WithClient for VertexClient { @@ -236,13 +245,13 @@ impl WithStreamChat for VertexClient { impl VertexClient { pub fn new(client: &ClientWalker, ctx: &RuntimeContext) -> Result { let properties = resolve_properties(&client.elem().provider, client.options(), ctx)?; - let default_role = properties.default_role.clone(); Ok(Self { name: client.name().into(), context: RenderContext_Client { name: client.name().into(), provider: client.elem().provider.to_string(), - default_role, + default_role: properties.default_role(), + allowed_roles: properties.allowed_roles(), }, features: ModelFeatures { chat: true, @@ -263,14 +272,14 @@ impl VertexClient { pub fn dynamic_new(client: &ClientProperty, ctx: &RuntimeContext) -> Result { let properties = resolve_properties(&client.provider, &client.unresolved_options()?, ctx)?; - let default_role = properties.default_role.clone(); Ok(Self { name: client.name.clone(), context: RenderContext_Client { name: client.name.clone(), provider: client.provider.to_string(), - default_role, + default_role: properties.default_role(), + allowed_roles: properties.allowed_roles(), }, features: ModelFeatures { chat: true, @@ -390,13 +399,6 @@ impl RequestBuilder for VertexClient { } impl WithChat for VertexClient { - fn 
chat_options(&self, _ctx: &RuntimeContext) -> Result { - Ok(internal_baml_jinja::ChatOptions::new( - self.properties.default_role.clone(), - None, - )) - } - async fn chat(&self, _ctx: &RuntimeContext, prompt: &[RenderedChatMessage]) -> LLMResponse { //non-streaming, complete response is returned let (response, system_now, instant_now) = diff --git a/engine/baml-runtime/src/internal/llm_client/traits/chat.rs b/engine/baml-runtime/src/internal/llm_client/traits/chat.rs index 2ee26e1ca..a94f09093 100644 --- a/engine/baml-runtime/src/internal/llm_client/traits/chat.rs +++ b/engine/baml-runtime/src/internal/llm_client/traits/chat.rs @@ -5,9 +5,20 @@ use crate::{internal::llm_client::LLMResponse, RuntimeContext}; use super::StreamResponse; -pub trait WithChat: Sync + Send { +pub trait WithChatOptions { fn chat_options(&self, ctx: &RuntimeContext) -> Result; +} +impl WithChatOptions for T +where + T: super::WithClientProperties, +{ + fn chat_options(&self, ctx: &RuntimeContext) -> Result { + Ok(ChatOptions::new(self.default_role(), Some(self.allowed_roles()))) + } +} + +pub trait WithChat: Sync + Send + WithChatOptions { #[allow(async_fn_in_trait)] async fn chat(&self, ctx: &RuntimeContext, prompt: &[RenderedChatMessage]) -> LLMResponse; } @@ -25,12 +36,8 @@ pub trait WithNoChat {} impl WithChat for T where - T: WithNoChat + Send + Sync, + T: WithNoChat + Send + Sync + WithChatOptions, { - fn chat_options(&self, _ctx: &RuntimeContext) -> Result { - anyhow::bail!("Chat prompts are not supported by this provider") - } - #[allow(async_fn_in_trait)] async fn chat(&self, _: &RuntimeContext, _: &[RenderedChatMessage]) -> LLMResponse { LLMResponse::InternalFailure("Chat prompts are not supported by this provider".to_string()) diff --git a/engine/baml-runtime/src/internal/llm_client/traits/mod.rs b/engine/baml-runtime/src/internal/llm_client/traits/mod.rs index ad43007e2..6e5172bb2 100644 --- a/engine/baml-runtime/src/internal/llm_client/traits/mod.rs +++ 
b/engine/baml-runtime/src/internal/llm_client/traits/mod.rs @@ -2,7 +2,7 @@ use std::{collections::HashMap, path::PathBuf, pin::Pin}; use anyhow::{Context, Result}; use aws_smithy_types::byte_stream::error::Error; -use internal_llm_client::AllowedRoleMetadata; +use internal_llm_client::{AllowedRoleMetadata, FinishReasonFilter}; use serde_json::{json, Map}; mod chat; @@ -36,6 +36,9 @@ pub trait WithRetryPolicy { pub trait WithClientProperties { fn allowed_metadata(&self) -> &AllowedRoleMetadata; fn supports_streaming(&self) -> bool; + fn finish_reason_filter(&self) -> &FinishReasonFilter; + fn default_role(&self) -> String; + fn allowed_roles(&self) -> Vec; } pub trait WithSingleCallable { @@ -143,10 +146,11 @@ pub trait WithRenderRawCurl { impl WithSingleCallable for T where - T: WithClient + WithChat + WithCompletion, + T: WithClient + WithChat + WithCompletion + WithClientProperties, { #[allow(async_fn_in_trait)] async fn single_call(&self, ctx: &RuntimeContext, prompt: &RenderedPrompt) -> LLMResponse { + log::warn!("debug single_call start: {:?}", prompt); if let RenderedPrompt::Chat(chat) = &prompt { match process_media_urls( self.model_features().resolve_media_urls, diff --git a/engine/baml-runtime/src/types/response.rs b/engine/baml-runtime/src/types/response.rs index a99270f8b..2c04bee01 100644 --- a/engine/baml-runtime/src/types/response.rs +++ b/engine/baml-runtime/src/types/response.rs @@ -218,6 +218,7 @@ pub enum TestFailReason<'a> { TestUnspecified(anyhow::Error), TestLLMFailure(&'a LLMResponse), TestParseFailure(&'a anyhow::Error), + TestFinishReasonFailed(&'a anyhow::Error), TestConstraintsFailure { checks: Vec<(String, bool)>, failed_assert: Option, @@ -232,6 +233,9 @@ impl PartialEq for TestFailReason<'_> { (Self::TestParseFailure(a), Self::TestParseFailure(b)) => { a.to_string() == b.to_string() } + (Self::TestFinishReasonFailed(a), Self::TestFinishReasonFailed(b)) => { + a.to_string() == b.to_string() + } _ => false, } } @@ -265,9 +269,13 @@ impl 
TestResponse { } } } else { - TestStatus::Fail(TestFailReason::TestParseFailure( - parsed.as_ref().unwrap_err(), - )) + let err = parsed.as_ref().unwrap_err(); + match err.downcast_ref::() { + Some(ExposedError::FinishReasonError { .. }) => { + TestStatus::Fail(TestFailReason::TestFinishReasonFailed(&err)) + } + _ => TestStatus::Fail(TestFailReason::TestParseFailure(&err)), + } } } else { TestStatus::Fail(TestFailReason::TestLLMFailure(func_res.llm_response())) diff --git a/engine/baml-runtime/tests/harness.rs b/engine/baml-runtime/tests/harness.rs index d76899547..a1bf91b13 100644 --- a/engine/baml-runtime/tests/harness.rs +++ b/engine/baml-runtime/tests/harness.rs @@ -45,7 +45,7 @@ impl Harness { cmd.args(args.split_ascii_whitespace()); cmd.current_dir(&self.test_dir); // cmd.env("RUST_BACKTRACE", "1"); - cmd.env("BAML_LOG", "debug,jsonish=info"); + // cmd.env("BAML_LOG", "debug,jsonish=info"); Ok(cmd) } diff --git a/engine/baml-runtime/tests/test_cli.rs b/engine/baml-runtime/tests/test_cli.rs index b550a0c20..857e3903e 100644 --- a/engine/baml-runtime/tests/test_cli.rs +++ b/engine/baml-runtime/tests/test_cli.rs @@ -17,10 +17,7 @@ use serde_json::json; // Run this with cargo test --features internal // run the CLI using debug build using: engine/target/debug/baml-runtime dev -#[cfg(all( - not(feature = "skip-integ-tests"), - any(feature = "OPENAI_API_KEY", env = "OPENAI_API_KEY") -))] +#[cfg(not(feature = "skip-integ-tests"))] mod test_cli { use super::*; use pretty_assertions::assert_eq; diff --git a/engine/baml-runtime/tests/test_runtime.rs b/engine/baml-runtime/tests/test_runtime.rs index 3568bc475..e10a957a4 100644 --- a/engine/baml-runtime/tests/test_runtime.rs +++ b/engine/baml-runtime/tests/test_runtime.rs @@ -9,7 +9,7 @@ mod internal_tests { use baml_runtime::BamlRuntime; use std::sync::Once; - use baml_runtime::internal::llm_client::orchestrator::OrchestrationScope; + // use baml_runtime::internal::llm_client::orchestrator::OrchestrationScope; use 
baml_runtime::InternalRuntimeInterface; use baml_types::BamlValue; @@ -255,8 +255,6 @@ mod internal_tests { )?; log::info!("Runtime:"); - let missing_env_vars = runtime.internal().ir().required_env_vars(); - let ctx = runtime .create_ctx_manager(BamlValue::String("test".to_string()), None) .create_ctx_with_default(); @@ -336,8 +334,6 @@ test ImageReceiptTest { "##, )?; - let missing_env_vars = runtime.internal().ir().required_env_vars(); - let ctx = runtime .create_ctx_manager(BamlValue::String("test".to_string()), None) .create_ctx_with_default(); @@ -419,8 +415,6 @@ test TestName { "##, )?; - let missing_env_vars = runtime.internal().ir().required_env_vars(); - let ctx = runtime .create_ctx_manager(BamlValue::String("test".to_string()), None) .create_ctx_with_default(); @@ -489,8 +483,6 @@ test TestTree { "##, )?; - let missing_env_vars = runtime.internal().ir().required_env_vars(); - let ctx = runtime .create_ctx_manager(BamlValue::String("test".to_string()), None) .create_ctx_with_default(); diff --git a/engine/baml-schema-wasm/src/lib.rs b/engine/baml-schema-wasm/src/lib.rs index dca44e5a7..243f94848 100644 --- a/engine/baml-schema-wasm/src/lib.rs +++ b/engine/baml-schema-wasm/src/lib.rs @@ -1,4 +1,6 @@ +#[cfg(target_arch = "wasm32")] pub mod runtime_wasm; + use std::env; use wasm_bindgen::prelude::*; diff --git a/engine/baml-schema-wasm/src/runtime_wasm/mod.rs b/engine/baml-schema-wasm/src/runtime_wasm/mod.rs index 3553f296a..aafa61af1 100644 --- a/engine/baml-schema-wasm/src/runtime_wasm/mod.rs +++ b/engine/baml-schema-wasm/src/runtime_wasm/mod.rs @@ -422,6 +422,7 @@ pub enum TestStatus { Passed, LLMFailure, ParseFailure, + FinishReasonFailed, ConstraintsFailed, AssertFailed, UnableToRun, @@ -597,6 +598,9 @@ impl WasmTestResponse { baml_runtime::TestFailReason::TestUnspecified(_) => TestStatus::UnableToRun, baml_runtime::TestFailReason::TestLLMFailure(_) => TestStatus::LLMFailure, baml_runtime::TestFailReason::TestParseFailure(_) => TestStatus::ParseFailure, 
+ baml_runtime::TestFailReason::TestFinishReasonFailed(_) => { + TestStatus::FinishReasonFailed + } baml_runtime::TestFailReason::TestConstraintsFailure { failed_assert, .. } => { @@ -784,7 +788,20 @@ impl WithRenderError for baml_runtime::TestFailReason<'_> { match &self { baml_runtime::TestFailReason::TestUnspecified(e) => Some(format!("{e:#}")), baml_runtime::TestFailReason::TestLLMFailure(f) => f.render_error(), - baml_runtime::TestFailReason::TestParseFailure(e) => Some(format!("{e:#}")), + baml_runtime::TestFailReason::TestParseFailure(e) + | baml_runtime::TestFailReason::TestFinishReasonFailed(e) => { + match e.downcast_ref::() { + Some(exposed_error) => match exposed_error { + baml_runtime::errors::ExposedError::ValidationError { message, .. } => { + Some(message.clone()) + } + baml_runtime::errors::ExposedError::FinishReasonError { + message, .. + } => Some(message.clone()), + }, + None => Some(format!("{e:#}")), + } + } baml_runtime::TestFailReason::TestConstraintsFailure { checks, failed_assert, @@ -847,10 +864,10 @@ fn get_dummy_value( TypeValue::Bool => "true".to_string(), TypeValue::Null => "null".to_string(), TypeValue::Media(BamlMediaType::Image) => { - "{ url \"https://imgs.xkcd.com/comics/standards.png\"}".to_string() + "{ url \"https://imgs.xkcd.com/comics/standards.png\" }".to_string() } TypeValue::Media(BamlMediaType::Audio) => { - "{ url \"https://actions.google.com/sounds/v1/emergency/beeper_emergency_call.ogg\"}".to_string() + "{ url \"https://actions.google.com/sounds/v1/emergency/beeper_emergency_call.ogg\" }".to_string() } }; diff --git a/engine/baml-schema-wasm/tests/test_file_manager.rs b/engine/baml-schema-wasm/tests/test_file_manager.rs index 489a93431..dfed4997a 100644 --- a/engine/baml-schema-wasm/tests/test_file_manager.rs +++ b/engine/baml-schema-wasm/tests/test_file_manager.rs @@ -1,5 +1,6 @@ // Run from the baml-schema-wasm folder with: // wasm-pack test --node +#[cfg(target_arch = "wasm32")] #[cfg(test)] mod tests { use 
std::collections::HashMap; diff --git a/engine/language_client_python/python_src/baml_py/internal_monkeypatch.py b/engine/language_client_python/python_src/baml_py/internal_monkeypatch.py index 1e6ae231d..2743ab978 100644 --- a/engine/language_client_python/python_src/baml_py/internal_monkeypatch.py +++ b/engine/language_client_python/python_src/baml_py/internal_monkeypatch.py @@ -1,5 +1,5 @@ from .baml_py import BamlError - +from typing import Optional # Define the BamlValidationError exception with additional fields # note on custom exceptions https://github.com/PyO3/pyo3/issues/295 @@ -16,3 +16,17 @@ def __str__(self): def __repr__(self): return f"BamlValidationError(message={self.message}, raw_output={self.raw_output}, prompt={self.prompt})" + +class BamlClientFinishReasonError(BamlError): + def __init__(self, prompt: str, message: str, raw_output: str, finish_reason: Optional[str]): + super().__init__(message) + self.prompt = prompt + self.message = message + self.raw_output = raw_output + self.finish_reason = finish_reason + + def __str__(self): + return f"BamlClientFinishReasonError(message={self.message}, raw_output={self.raw_output}, prompt={self.prompt}, finish_reason={self.finish_reason})" + + def __repr__(self): + return f"BamlClientFinishReasonError(message={self.message}, raw_output={self.raw_output}, prompt={self.prompt}, finish_reason={self.finish_reason})" diff --git a/engine/language_client_python/src/errors.rs b/engine/language_client_python/src/errors.rs index a62cbeef6..3912da6b0 100644 --- a/engine/language_client_python/src/errors.rs +++ b/engine/language_client_python/src/errors.rs @@ -25,6 +25,17 @@ fn raise_baml_validation_error(prompt: String, message: String, raw_output: Stri }) } +#[allow(non_snake_case)] +fn raise_baml_client_finish_reason_error(prompt: String, raw_output: String, message: String, finish_reason: Option) -> PyErr { + Python::with_gil(|py| { + let internal_monkeypatch = py.import("baml_py.internal_monkeypatch").unwrap(); 
+ let exception = internal_monkeypatch.getattr("BamlClientFinishReasonError").unwrap(); + let args = (prompt, message, raw_output, finish_reason); + let inst = exception.call1(args).unwrap(); + PyErr::from_value(inst) + }) +} + /// Defines the errors module with the BamlValidationError exception. /// IIRC the name of this function is the name of the module that pyo3 generates (errors.py) #[pymodule] @@ -64,6 +75,14 @@ impl BamlError { // If not, you may need to adjust this part based on the actual structure of ValidationError raise_baml_validation_error(prompt.clone(), message.clone(), raw_output.clone()) } + ExposedError::FinishReasonError { + prompt, + raw_output, + message, + finish_reason, + } => { + raise_baml_client_finish_reason_error(prompt.clone(), raw_output.clone(), message.clone(), finish_reason.clone()) + } } } else if let Some(er) = err.downcast_ref::() { PyErr::new::(format!("Invalid argument: {}", er)) diff --git a/engine/language_client_typescript/index.d.ts b/engine/language_client_typescript/index.d.ts index dafbd2ff7..96f7782aa 100644 --- a/engine/language_client_typescript/index.d.ts +++ b/engine/language_client_typescript/index.d.ts @@ -1,12 +1,19 @@ -export { BamlRuntime, FunctionResult, FunctionResultStream, BamlImage as Image, ClientBuilder, BamlAudio as Audio, invoke_runtime_cli, ClientRegistry, BamlLogEvent, } from './native'; -export { BamlStream } from './stream'; -export { BamlCtxManager } from './async_context_vars'; +export { BamlRuntime, FunctionResult, FunctionResultStream, BamlImage as Image, ClientBuilder, BamlAudio as Audio, invoke_runtime_cli, ClientRegistry, BamlLogEvent, } from "./native"; +export { BamlStream } from "./stream"; +export { BamlCtxManager } from "./async_context_vars"; +export declare class BamlClientFinishReasonError extends Error { + prompt: string; + raw_output: string; + constructor(prompt: string, raw_output: string, message: string); + toJSON(): string; + static from(error: Error): 
BamlClientFinishReasonError | undefined; +} export declare class BamlValidationError extends Error { prompt: string; raw_output: string; constructor(prompt: string, raw_output: string, message: string); - static from(error: Error): BamlValidationError | Error; toJSON(): string; + static from(error: Error): BamlValidationError | undefined; } -export declare function createBamlValidationError(error: Error): BamlValidationError | Error; +export declare function createBamlValidationError(error: Error): BamlValidationError | BamlClientFinishReasonError | Error; //# sourceMappingURL=index.d.ts.map \ No newline at end of file diff --git a/engine/language_client_typescript/index.d.ts.map b/engine/language_client_typescript/index.d.ts.map index 678922c63..c1b009247 100644 --- a/engine/language_client_typescript/index.d.ts.map +++ b/engine/language_client_typescript/index.d.ts.map @@ -1 +1 @@ -{"version":3,"file":"index.d.ts","sourceRoot":"","sources":["typescript_src/index.ts"],"names":[],"mappings":"AAAA,OAAO,EACL,WAAW,EACX,cAAc,EACd,oBAAoB,EACpB,SAAS,IAAI,KAAK,EAClB,aAAa,EACb,SAAS,IAAI,KAAK,EAClB,kBAAkB,EAClB,cAAc,EACd,YAAY,GACb,MAAM,UAAU,CAAA;AACjB,OAAO,EAAE,UAAU,EAAE,MAAM,UAAU,CAAA;AACrC,OAAO,EAAE,cAAc,EAAE,MAAM,sBAAsB,CAAA;AAErD,qBAAa,mBAAoB,SAAQ,KAAK;IAC5C,MAAM,EAAE,MAAM,CAAA;IACd,UAAU,EAAE,MAAM,CAAA;gBAEN,MAAM,EAAE,MAAM,EAAE,UAAU,EAAE,MAAM,EAAE,OAAO,EAAE,MAAM;IAS/D,MAAM,CAAC,IAAI,CAAC,KAAK,EAAE,KAAK,GAAG,mBAAmB,GAAG,KAAK;IAuBtD,MAAM,IAAI,MAAM;CAWjB;AAGD,wBAAgB,yBAAyB,CAAC,KAAK,EAAE,KAAK,GAAG,mBAAmB,GAAG,KAAK,CAEnF"} \ No newline at end of file 
+{"version":3,"file":"index.d.ts","sourceRoot":"","sources":["typescript_src/index.ts"],"names":[],"mappings":"AAAA,OAAO,EACL,WAAW,EACX,cAAc,EACd,oBAAoB,EACpB,SAAS,IAAI,KAAK,EAClB,aAAa,EACb,SAAS,IAAI,KAAK,EAClB,kBAAkB,EAClB,cAAc,EACd,YAAY,GACb,MAAM,UAAU,CAAC;AAClB,OAAO,EAAE,UAAU,EAAE,MAAM,UAAU,CAAC;AACtC,OAAO,EAAE,cAAc,EAAE,MAAM,sBAAsB,CAAC;AAEtD,qBAAa,2BAA4B,SAAQ,KAAK;IACpD,MAAM,EAAE,MAAM,CAAC;IACf,UAAU,EAAE,MAAM,CAAC;gBAEP,MAAM,EAAE,MAAM,EAAE,UAAU,EAAE,MAAM,EAAE,OAAO,EAAE,MAAM;IAS/D,MAAM,IAAI,MAAM;IAahB,MAAM,CAAC,IAAI,CAAC,KAAK,EAAE,KAAK,GAAG,2BAA2B,GAAG,SAAS;CAoBnE;AAED,qBAAa,mBAAoB,SAAQ,KAAK;IAC5C,MAAM,EAAE,MAAM,CAAC;IACf,UAAU,EAAE,MAAM,CAAC;gBAEP,MAAM,EAAE,MAAM,EAAE,UAAU,EAAE,MAAM,EAAE,OAAO,EAAE,MAAM;IAS/D,MAAM,IAAI,MAAM;IAahB,MAAM,CAAC,IAAI,CAAC,KAAK,EAAE,KAAK,GAAG,mBAAmB,GAAG,SAAS;CAiB3D;AAGD,wBAAgB,yBAAyB,CACvC,KAAK,EAAE,KAAK,GACX,mBAAmB,GAAG,2BAA2B,GAAG,KAAK,CAa3D"} \ No newline at end of file diff --git a/engine/language_client_typescript/index.js b/engine/language_client_typescript/index.js index c6a62e207..cdef3fe1b 100644 --- a/engine/language_client_typescript/index.js +++ b/engine/language_client_typescript/index.js @@ -1,6 +1,6 @@ "use strict"; Object.defineProperty(exports, "__esModule", { value: true }); -exports.createBamlValidationError = exports.BamlValidationError = exports.BamlCtxManager = exports.BamlStream = exports.BamlLogEvent = exports.ClientRegistry = exports.invoke_runtime_cli = exports.Audio = exports.ClientBuilder = exports.Image = exports.FunctionResultStream = exports.FunctionResult = exports.BamlRuntime = void 0; +exports.createBamlValidationError = exports.BamlValidationError = exports.BamlClientFinishReasonError = exports.BamlCtxManager = exports.BamlStream = exports.BamlLogEvent = exports.ClientRegistry = exports.invoke_runtime_cli = exports.Audio = exports.ClientBuilder = exports.Image = exports.FunctionResultStream = exports.FunctionResult = exports.BamlRuntime = void 0; var native_1 = require("./native"); 
Object.defineProperty(exports, "BamlRuntime", { enumerable: true, get: function () { return native_1.BamlRuntime; } }); Object.defineProperty(exports, "FunctionResult", { enumerable: true, get: function () { return native_1.FunctionResult; } }); @@ -15,47 +15,90 @@ var stream_1 = require("./stream"); Object.defineProperty(exports, "BamlStream", { enumerable: true, get: function () { return stream_1.BamlStream; } }); var async_context_vars_1 = require("./async_context_vars"); Object.defineProperty(exports, "BamlCtxManager", { enumerable: true, get: function () { return async_context_vars_1.BamlCtxManager; } }); -class BamlValidationError extends Error { +class BamlClientFinishReasonError extends Error { prompt; raw_output; constructor(prompt, raw_output, message) { super(message); - this.name = 'BamlValidationError'; + this.name = "BamlClientFinishReasonError"; this.prompt = prompt; this.raw_output = raw_output; - Object.setPrototypeOf(this, BamlValidationError.prototype); + Object.setPrototypeOf(this, BamlClientFinishReasonError.prototype); + } + toJSON() { + return JSON.stringify({ + name: this.name, + message: this.message, + raw_output: this.raw_output, + prompt: this.prompt, + }, null, 2); } static from(error) { - if (error.message.includes('BamlValidationError')) { + if (error.message.includes("BamlClientFinishReasonError")) { try { const errorData = JSON.parse(error.message); - if (errorData.type === 'BamlValidationError') { - return new BamlValidationError(errorData.prompt || '', errorData.raw_output || '', errorData.message || error.message); + if (errorData.type === "BamlClientFinishReasonError") { + return new BamlClientFinishReasonError(errorData.prompt || "", errorData.raw_output || "", errorData.message || error.message); } else { - console.warn('Not a BamlValidationError:', error); + console.warn("Not a BamlClientFinishReasonError:", error); } } catch (parseError) { // If JSON parsing fails, fall back to the original error - console.warn('Failed to 
parse BamlValidationError:', parseError); + console.warn("Failed to parse BamlClientFinishReasonError:", parseError); } } - // If it's not a BamlValidationError or parsing failed, return the original error - return error; + return undefined; + } +} +exports.BamlClientFinishReasonError = BamlClientFinishReasonError; +class BamlValidationError extends Error { + prompt; + raw_output; + constructor(prompt, raw_output, message) { + super(message); + this.name = "BamlValidationError"; + this.prompt = prompt; + this.raw_output = raw_output; + Object.setPrototypeOf(this, BamlValidationError.prototype); } toJSON() { return JSON.stringify({ + name: this.name, message: this.message, raw_output: this.raw_output, prompt: this.prompt, }, null, 2); } + static from(error) { + if (error.message.includes("BamlValidationError")) { + try { + const errorData = JSON.parse(error.message); + if (errorData.type === "BamlValidationError") { + return new BamlValidationError(errorData.prompt || "", errorData.raw_output || "", errorData.message || error.message); + } + } + catch (parseError) { + console.warn("Failed to parse BamlValidationError:", parseError); + } + } + return undefined; + } } exports.BamlValidationError = BamlValidationError; // Helper function to safely create a BamlValidationError function createBamlValidationError(error) { - return BamlValidationError.from(error); + const bamlValidationError = BamlValidationError.from(error); + if (bamlValidationError) { + return bamlValidationError; + } + const bamlClientFinishReasonError = BamlClientFinishReasonError.from(error); + if (bamlClientFinishReasonError) { + return bamlClientFinishReasonError; + } + // otherwise return the original error + return error; } exports.createBamlValidationError = createBamlValidationError; // No need for a separate throwBamlValidationError function in TypeScript diff --git a/engine/language_client_typescript/src/errors.rs b/engine/language_client_typescript/src/errors.rs index 4af13f6ea..8cde59acc 
100644 --- a/engine/language_client_typescript/src/errors.rs +++ b/engine/language_client_typescript/src/errors.rs @@ -20,6 +20,17 @@ pub fn from_anyhow_error(err: anyhow::Error) -> napi::Error { message, raw_output: raw_response, } => throw_baml_validation_error(prompt, raw_response, message), + ExposedError::FinishReasonError { + prompt, + message, + raw_output: raw_response, + finish_reason, + } => throw_baml_client_finish_reason_error( + prompt, + raw_response, + message, + finish_reason.as_ref().map(|f| f.as_str()), + ), } } else if let Some(er) = err.downcast_ref::() { invalid_argument_error(&format!("{}", er)) @@ -79,3 +90,14 @@ pub fn throw_baml_validation_error(prompt: &str, raw_output: &str, message: &str }); napi::Error::new(napi::Status::GenericFailure, error_json.to_string()) } + +pub fn throw_baml_client_finish_reason_error(prompt: &str, raw_output: &str, message: &str, finish_reason: Option<&str>) -> napi::Error { + let error_json = serde_json::json!({ + "type": "BamlClientFinishReasonError", + "prompt": prompt, + "raw_output": raw_output, + "message": format!("BamlClientFinishReasonError: {}", message), + "finish_reason": finish_reason, + }); + napi::Error::new(napi::Status::GenericFailure, error_json.to_string()) +} diff --git a/engine/language_client_typescript/typescript_src/index.ts b/engine/language_client_typescript/typescript_src/index.ts index 27bd94611..d0aff086d 100644 --- a/engine/language_client_typescript/typescript_src/index.ts +++ b/engine/language_client_typescript/typescript_src/index.ts @@ -8,62 +8,119 @@ export { invoke_runtime_cli, ClientRegistry, BamlLogEvent, -} from './native' -export { BamlStream } from './stream' -export { BamlCtxManager } from './async_context_vars' +} from "./native"; +export { BamlStream } from "./stream"; +export { BamlCtxManager } from "./async_context_vars"; -export class BamlValidationError extends Error { - prompt: string - raw_output: string +export class BamlClientFinishReasonError extends Error { 
+ prompt: string; + raw_output: string; constructor(prompt: string, raw_output: string, message: string) { - super(message) - this.name = 'BamlValidationError' - this.prompt = prompt - this.raw_output = raw_output + super(message); + this.name = "BamlClientFinishReasonError"; + this.prompt = prompt; + this.raw_output = raw_output; - Object.setPrototypeOf(this, BamlValidationError.prototype) + Object.setPrototypeOf(this, BamlClientFinishReasonError.prototype); } - static from(error: Error): BamlValidationError | Error { - if (error.message.includes('BamlValidationError')) { + toJSON(): string { + return JSON.stringify( + { + name: this.name, + message: this.message, + raw_output: this.raw_output, + prompt: this.prompt, + }, + null, + 2 + ); + } + + static from(error: Error): BamlClientFinishReasonError | undefined { + if (error.message.includes("BamlClientFinishReasonError")) { try { - const errorData = JSON.parse(error.message) - if (errorData.type === 'BamlValidationError') { - return new BamlValidationError( - errorData.prompt || '', - errorData.raw_output || '', - errorData.message || error.message, - ) + const errorData = JSON.parse(error.message); + if (errorData.type === "BamlClientFinishReasonError") { + return new BamlClientFinishReasonError( + errorData.prompt || "", + errorData.raw_output || "", + errorData.message || error.message + ); } else { - console.warn('Not a BamlValidationError:', error) + console.warn("Not a BamlClientFinishReasonError:", error); } } catch (parseError) { // If JSON parsing fails, fall back to the original error - console.warn('Failed to parse BamlValidationError:', parseError) + console.warn("Failed to parse BamlClientFinishReasonError:", parseError); } } + return undefined; + } +} + +export class BamlValidationError extends Error { + prompt: string; + raw_output: string; + + constructor(prompt: string, raw_output: string, message: string) { + super(message); + this.name = "BamlValidationError"; + this.prompt = prompt; + 
this.raw_output = raw_output; - // If it's not a BamlValidationError or parsing failed, return the original error - return error + Object.setPrototypeOf(this, BamlValidationError.prototype); } toJSON(): string { return JSON.stringify( { + name: this.name, message: this.message, raw_output: this.raw_output, prompt: this.prompt, }, null, - 2, - ) + 2 + ); + } + + static from(error: Error): BamlValidationError | undefined { + if (error.message.includes("BamlValidationError")) { + try { + const errorData = JSON.parse(error.message); + if (errorData.type === "BamlValidationError") { + return new BamlValidationError( + errorData.prompt || "", + errorData.raw_output || "", + errorData.message || error.message + ); + } + } catch (parseError) { + console.warn("Failed to parse BamlValidationError:", parseError); + } + } + return undefined; } } // Helper function to safely create a BamlValidationError -export function createBamlValidationError(error: Error): BamlValidationError | Error { - return BamlValidationError.from(error) +export function createBamlValidationError( + error: Error +): BamlValidationError | BamlClientFinishReasonError | Error { + const bamlValidationError = BamlValidationError.from(error); + if (bamlValidationError) { + return bamlValidationError; + } + + const bamlClientFinishReasonError = BamlClientFinishReasonError.from(error); + if (bamlClientFinishReasonError) { + return bamlClientFinishReasonError; + } + + // otherwise return the original error + return error; } // No need for a separate throwBamlValidationError function in TypeScript diff --git a/integ-tests/typescript/test-report.html b/integ-tests/typescript/test-report.html index 550bec660..d497dffff 100644 --- a/integ-tests/typescript/test-report.html +++ b/integ-tests/typescript/test-report.html @@ -257,9 +257,4 @@ font-size: 1rem; padding: 0 0.5rem; } -

Test Report

Started: 2024-12-01 23:40:27
Suites (1)
0 passed
1 failed
0 pending
Tests (67)
65 passed
2 failed
0 pending
Integ tests > should work for all inputs
single bool
passed
2.601s
Integ tests > should work for all inputs
single string list
passed
1.845s
Integ tests > should work for all inputs
return literal union
passed
0.507s
Integ tests > should work for all inputs
single class
passed
0.713s
Integ tests > should work for all inputs
multiple classes
passed
0.605s
Integ tests > should work for all inputs
single enum list
passed
0.517s
Integ tests > should work for all inputs
single float
passed
0.512s
Integ tests > should work for all inputs
single int
passed
0.606s
Integ tests > should work for all inputs
single literal int
passed
0.502s
Integ tests > should work for all inputs
single literal bool
passed
0.45s
Integ tests > should work for all inputs
single literal string
passed
0.443s
Integ tests > should work for all inputs
single class with literal prop
passed
0.864s
Integ tests > should work for all inputs
single class with literal union prop
passed
0.613s
Integ tests > should work for all inputs
single optional string
passed
0.466s
Integ tests > should work for all inputs
single map string to string
passed
0.594s
Integ tests > should work for all inputs
single map string to class
passed
0.862s
Integ tests > should work for all inputs
single map string to map
passed
0.633s
Integ tests > should work for all inputs
enum key in map
passed
1.007s
Integ tests > should work for all inputs
literal string union key in map
passed
0.899s
Integ tests > should work for all inputs
single literal string key in map
passed
0.64s
Integ tests
should work for all outputs
passed
5.602s
Integ tests
works with retries1
passed
2.254s
Integ tests
works with retries2
passed
2.803s
Integ tests
works with fallbacks
passed
2.568s
Integ tests
should work with image from url
passed
1.367s
Integ tests
should work with image from base 64
passed
9.909s
Integ tests
should work with audio base 64
passed
1.823s
Integ tests
should work with audio from url
passed
2.006s
Integ tests
should support streaming in OpenAI
passed
2.046s
Integ tests
should support streaming in Gemini
passed
9.545s
Integ tests
should support AWS
passed
1.963s
Integ tests
should support streaming in AWS
passed
1.827s
Integ tests
should allow overriding the region
passed
0.13s
Integ tests
should support OpenAI shorthand
passed
9.001s
Integ tests
should support OpenAI shorthand streaming
passed
10.386s
Integ tests
should support anthropic shorthand
passed
2.572s
Integ tests
should support anthropic shorthand streaming
passed
7.519s
Integ tests
should support streaming without iterating
passed
1.855s
Integ tests
should support streaming in Claude
passed
1.131s
Integ tests
should support vertex
failed
0.003s
Error: BamlError: Failed to read service account file: 
-
-Caused by:
-    No such file or directory (os error 2)
Integ tests
supports tracing sync
passed
0.013s
Integ tests
supports tracing async
passed
4.669s
Integ tests
should work with dynamic types single
passed
1.032s
Integ tests
should work with dynamic types enum
passed
1.022s
Integ tests
should work with dynamic literals
passed
1.126s
Integ tests
should work with dynamic types class
passed
1.129s
Integ tests
should work with dynamic inputs class
passed
0.582s
Integ tests
should work with dynamic inputs list
passed
0.699s
Integ tests
should work with dynamic output map
passed
0.81s
Integ tests
should work with dynamic output union
passed
1.702s
Integ tests
should work with nested classes
failed
0.119s
Error: BamlError: BamlClientError: Something went wrong with the LLM client: reqwest::Error { kind: Request, url: Url { scheme: "http", cannot_be_a_base: false, username: "", password: None, host: Some(Domain("localhost")), port: Some(11434), path: "/v1/chat/completions", query: None, fragment: None }, source: hyper_util::client::legacy::Error(Connect, ConnectError("tcp connect error", Os { code: 111, kind: ConnectionRefused, message: "Connection refused" })) }
-    at BamlStream.parsed [as getFinalResponse] (/workspaces/baml/engine/language_client_typescript/stream.js:58:39)
-    at Object.<anonymous> (/workspaces/baml/integ-tests/typescript/tests/integ-tests.test.ts:602:19)
Integ tests
should work with dynamic client
passed
0.793s
Integ tests
should work with 'onLogEvent'
passed
2.354s
Integ tests
should work with a sync client
passed
0.481s
Integ tests
should raise an error when appropriate
passed
1.575s
Integ tests
should raise a BAMLValidationError
passed
0.448s
Integ tests
should reset environment variables correctly
passed
2.103s
Integ tests
should use aliases when serializing input objects - classes
passed
1.127s
Integ tests
should use aliases when serializing, but still have original keys in jinja
passed
1.125s
Integ tests
should use aliases when serializing input objects - enums
passed
0.612s
Integ tests
should use aliases when serializing input objects - lists
passed
0.546s
Integ tests
constraints: should handle checks in return types
passed
0.686s
Integ tests
constraints: should handle checks in returned unions
passed
0.732s
Integ tests
constraints: should handle block-level checks
passed
1.088s
Integ tests
constraints: should handle nested-block-level checks
passed
0.669s
Integ tests
simple recursive type
passed
2.629s
Integ tests
mutually recursive type
passed
1.846s
\ No newline at end of file +

Test Report

Started: 2024-12-03 15:57:41
Suites (1)
1 passed
0 failed
0 pending
Tests (67)
1 passed
0 failed
66 pending
Integ tests > should work for all inputs
single bool
pending
0s
Integ tests > should work for all inputs
single string list
pending
0s
Integ tests > should work for all inputs
return literal union
pending
0s
Integ tests > should work for all inputs
single class
pending
0s
Integ tests > should work for all inputs
multiple classes
pending
0s
Integ tests > should work for all inputs
single enum list
pending
0s
Integ tests > should work for all inputs
single float
pending
0s
Integ tests > should work for all inputs
single int
pending
0s
Integ tests > should work for all inputs
single literal int
pending
0s
Integ tests > should work for all inputs
single literal bool
pending
0s
Integ tests > should work for all inputs
single literal string
pending
0s
Integ tests > should work for all inputs
single class with literal prop
pending
0s
Integ tests > should work for all inputs
single class with literal union prop
pending
0s
Integ tests > should work for all inputs
single optional string
pending
0s
Integ tests > should work for all inputs
single map string to string
pending
0s
Integ tests > should work for all inputs
single map string to class
pending
0s
Integ tests > should work for all inputs
single map string to map
pending
0s
Integ tests > should work for all inputs
enum key in map
pending
0s
Integ tests > should work for all inputs
literal string union key in map
pending
0s
Integ tests > should work for all inputs
single literal string key in map
pending
0s
Integ tests
should work for all outputs
pending
0s
Integ tests
works with retries1
pending
0s
Integ tests
works with retries2
pending
0s
Integ tests
works with fallbacks
pending
0s
Integ tests
should work with image from url
pending
0s
Integ tests
should work with image from base 64
pending
0s
Integ tests
should work with audio base 64
pending
0s
Integ tests
should work with audio from url
pending
0s
Integ tests
should support streaming in OpenAI
pending
0s
Integ tests
should support streaming in Gemini
pending
0s
Integ tests
should support AWS
pending
0s
Integ tests
should support streaming in AWS
pending
0s
Integ tests
should allow overriding the region
pending
0s
Integ tests
should support OpenAI shorthand
pending
0s
Integ tests
should support OpenAI shorthand streaming
pending
0s
Integ tests
should support anthropic shorthand
pending
0s
Integ tests
should support anthropic shorthand streaming
pending
0s
Integ tests
should support streaming without iterating
pending
0s
Integ tests
should support streaming in Claude
pending
0s
Integ tests
should support vertex
pending
0s
Integ tests
supports tracing sync
pending
0s
Integ tests
supports tracing async
pending
0s
Integ tests
should work with dynamic types single
pending
0s
Integ tests
should work with dynamic types enum
pending
0s
Integ tests
should work with dynamic literals
pending
0s
Integ tests
should work with dynamic types class
pending
0s
Integ tests
should work with dynamic inputs class
pending
0s
Integ tests
should work with dynamic inputs list
pending
0s
Integ tests
should work with dynamic output map
pending
0s
Integ tests
should work with dynamic output union
pending
0s
Integ tests
should work with nested classes
pending
0s
Integ tests
should work with dynamic client
pending
0s
Integ tests
should work with 'onLogEvent'
pending
0s
Integ tests
should work with a sync client
pending
0s
Integ tests
should raise an error when appropriate
pending
0s
Integ tests
should raise a BAMLValidationError
passed
0.875s
Integ tests
should reset environment variables correctly
pending
0s
Integ tests
should use aliases when serializing input objects - classes
pending
0s
Integ tests
should use aliases when serializing, but still have original keys in jinja
pending
0s
Integ tests
should use aliases when serializing input objects - enums
pending
0s
Integ tests
should use aliases when serializing input objects - lists
pending
0s
Integ tests
constraints: should handle checks in return types
pending
0s
Integ tests
constraints: should handle checks in returned unions
pending
0s
Integ tests
constraints: should handle block-level checks
pending
0s
Integ tests
constraints: should handle nested-block-level checks
pending
0s
Integ tests
simple recursive type
pending
0s
Integ tests
mutually recursive type
pending
0s
\ No newline at end of file diff --git a/typescript/playground-common/src/baml_wasm_web/test_uis/testHooks.ts b/typescript/playground-common/src/baml_wasm_web/test_uis/testHooks.ts index a534e7888..e9c202661 100644 --- a/typescript/playground-common/src/baml_wasm_web/test_uis/testHooks.ts +++ b/typescript/playground-common/src/baml_wasm_web/test_uis/testHooks.ts @@ -10,7 +10,13 @@ export const showTestsAtom = atom(false) export const showClientGraphAtom = atom(false) export type TestStatusType = 'queued' | 'running' | 'done' | 'error' -export type DoneTestStatusType = 'passed' | 'llm_failed' | 'parse_failed' | 'constraints_failed' | 'error' +export type DoneTestStatusType = + | 'passed' + | 'llm_failed' + | 'finish_reason_failed' + | 'parse_failed' + | 'constraints_failed' + | 'error' export type TestState = | { status: 'queued' @@ -37,12 +43,26 @@ export const testStatusAtom = atomFamily( (a, b) => a === b, ) export const runningTestsAtom = atom([]) + +// Match the Rust enum +// engine/baml-schema-wasm/src/runtime_wasm/mod.rs +enum RustTestStatus { + Passed, + LLMFailure, + ParseFailure, + FinishReasonFailed, + ConstraintsFailed, + AssertFailed, + UnableToRun, +} + export const statusCountAtom = atom({ queued: 0, running: 0, done: { passed: 0, llm_failed: 0, + finish_reason_failed: 0, parse_failed: 0, constraints_failed: 0, error: 0, @@ -135,6 +155,7 @@ export const useRunHooks = () => { done: { passed: 0, llm_failed: 0, + finish_reason_failed: 0, parse_failed: 0, constraints_failed: 0, error: 0, @@ -189,15 +210,17 @@ export const useRunHooks = () => { const { res, elapsed } = result.value // console.log('result', i, result.value.res.llm_response(), 'batch[i]', batch[i]) - let status: Number = res.status() + let status: RustTestStatus = res.status() as unknown as RustTestStatus let response_status: DoneTestStatusType = 'error' - if (status === 0) { + if (status === RustTestStatus.Passed) { response_status = 'passed' - } else if (status === 1) { + } else if (status 
=== RustTestStatus.LLMFailure) { response_status = 'llm_failed' - } else if (status === 2) { + } else if (status === RustTestStatus.ParseFailure) { response_status = 'parse_failed' - } else if (status === 3 || status === 4) { + } else if (status === RustTestStatus.FinishReasonFailed) { + response_status = 'finish_reason_failed' + } else if (status === RustTestStatus.ConstraintsFailed || status === RustTestStatus.AssertFailed) { response_status = 'constraints_failed' } else { response_status = 'error' diff --git a/typescript/playground-common/src/baml_wasm_web/test_uis/test_result.tsx b/typescript/playground-common/src/baml_wasm_web/test_uis/test_result.tsx index bda2d47a8..7570b1ad2 100644 --- a/typescript/playground-common/src/baml_wasm_web/test_uis/test_result.tsx +++ b/typescript/playground-common/src/baml_wasm_web/test_uis/test_result.tsx @@ -58,6 +58,8 @@ const TestStatusMessage: React.FC<{ testStatus: DoneTestStatusType }> = ({ testS return
LLM Failed
case 'parse_failed': return
Parse Failed
+ case 'finish_reason_failed': + return
Finish Reason Failed
case 'constraints_failed': return
Constraints Failed
case 'error': @@ -101,9 +103,25 @@ const TestStatusIcon: React.FC<{ ) } -type FilterValues = 'queued' | 'running' | 'error' | 'llm_failed' | 'parse_failed' | 'constraints_failed' | 'passed' +type FilterValues = + | 'queued' + | 'running' + | 'error' + | 'llm_failed' + | 'parse_failed' + | 'constraints_failed' + | 'passed' + | 'finish_reason_failed' const filterAtom = atom( - new Set(['running', 'error', 'llm_failed', 'parse_failed', 'constraints_failed', 'passed']), + new Set([ + 'running', + 'error', + 'llm_failed', + 'parse_failed', + 'constraints_failed', + 'passed', + 'finish_reason_failed', + ]), ) const checkFilter = (filter: Set, status: TestStatusType, test_status?: DoneTestStatusType) => { @@ -150,6 +168,17 @@ const ParsedTestResult: React.FC<{ doneStatus: string; parsed?: WasmParsedTestRe } }, [parsed, hasClosedIntroToChecksDialog, setShowIntroToChecksDialog]) + if (doneStatus === 'finish_reason_failed') { + return ( +
+
Pre-parse Error
+
+ {failure &&
{failure}
} +
+
+ ) + } + if (doneStatus === 'parse_failed' || parsed !== undefined) { return (