diff --git a/engine/baml-lib/jinja-runtime/src/lib.rs b/engine/baml-lib/jinja-runtime/src/lib.rs
index 82ffa7eea..e94d6c6e1 100644
--- a/engine/baml-lib/jinja-runtime/src/lib.rs
+++ b/engine/baml-lib/jinja-runtime/src/lib.rs
@@ -25,6 +25,7 @@ pub struct RenderContext_Client {
     pub name: String,
     pub provider: String,
     pub default_role: String,
+    pub allowed_roles: Vec<String>,
 }
 
 #[derive(Debug)]
@@ -49,6 +50,7 @@ fn render_minijinja(
     mut ctx: RenderContext,
     template_string_macros: &[TemplateStringMacro],
     default_role: String,
+    allowed_roles: Vec<String>,
 ) -> Result<RenderedPrompt, minijinja::Error> {
     let mut env = get_env();
 
@@ -240,7 +242,11 @@ fn render_minijinja(
         // Only add the message if it contains meaningful content
         if !parts.is_empty() {
             chat_messages.push(RenderedChatMessage {
-                role: role.as_ref().unwrap_or(&default_role).to_string(),
+                role: match role.as_ref() {
+                    Some(r) if allowed_roles.contains(r) => r.clone(),
+                    Some(_) => default_role.clone(),
+                    None => default_role.clone(),
+                },
                 allow_duplicate_role,
                 parts,
             });
@@ -410,12 +416,14 @@ pub fn render_prompt(
     let eval_ctx = EvaluationContext::new(env_vars, false);
     let minijinja_args: minijinja::Value = args.clone().to_minijinja_value(ir, &eval_ctx);
     let default_role = ctx.client.default_role.clone();
+    let allowed_roles = ctx.client.allowed_roles.clone();
     let rendered = render_minijinja(
         template,
         &minijinja_args,
         ctx,
         template_string_macros,
         default_role,
+        allowed_roles,
     );
 
     match rendered {
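The net effect of the renderer hunks above: a role emitted by the template survives only if the client's allowlist contains it; an unlisted or absent role falls back to the client's default role. A minimal standalone sketch of that rule (function name and `main` harness are illustrative, not part of the crate; the test-fixture updates follow below):

```rust
/// Sketch of the role-coercion rule added in render_minijinja: keep a
/// rendered role only if the client allows it, otherwise fall back to
/// the default role. An absent role also takes the default.
fn resolve_role(role: Option<&String>, allowed_roles: &[String], default_role: &str) -> String {
    match role {
        Some(r) if allowed_roles.iter().any(|a| a == r) => r.clone(),
        _ => default_role.to_string(),
    }
}

fn main() {
    let allowed = vec!["system".to_string(), "user".to_string()];
    assert_eq!(resolve_role(Some(&"user".to_string()), &allowed, "system"), "user");
    // An unlisted role such as "tool" is coerced to the default.
    assert_eq!(resolve_role(Some(&"tool".to_string()), &allowed, "system"), "system");
    assert_eq!(resolve_role(None, &allowed, "system"), "system");
}
```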
@@ -505,6 +513,7 @@ mod render_tests {
                     name: "gpt4".to_string(),
                     provider: "openai".to_string(),
                     default_role: "system".to_string(),
+                    allowed_roles: vec!["system".to_string()],
                 },
                 output_format: OutputFormatContent::new_string(),
                 tags: HashMap::from([("ROLE".to_string(), BamlValue::String("john doe".into()))]),
@@ -565,6 +574,7 @@ mod render_tests {
                     name: "gpt4".to_string(),
                     provider: "openai".to_string(),
                     default_role: "system".to_string(),
+                    allowed_roles: vec!["system".to_string()],
                 },
                 output_format: OutputFormatContent::new_string(),
                 tags: HashMap::from([("ROLE".to_string(), BamlValue::String("john doe".into()))]),
@@ -623,6 +633,7 @@ mod render_tests {
                     name: "gpt4".to_string(),
                     provider: "openai".to_string(),
                     default_role: "system".to_string(),
+                    allowed_roles: vec!["system".to_string()],
                 },
                 output_format: OutputFormatContent::new_string(),
                 tags: HashMap::from([("ROLE".to_string(), BamlValue::String("john doe".into()))]),
@@ -690,6 +701,7 @@ mod render_tests {
                     name: "gpt4".to_string(),
                     provider: "openai".to_string(),
                     default_role: "system".to_string(),
+                    allowed_roles: vec!["system".to_string()],
                 },
                 output_format: OutputFormatContent::new_string(),
                 tags: HashMap::from([("ROLE".to_string(), BamlValue::String("john doe".into()))]),
@@ -767,6 +779,7 @@ mod render_tests {
                     name: "gpt4".to_string(),
                     provider: "openai".to_string(),
                     default_role: "system".to_string(),
+                    allowed_roles: vec!["system".to_string()],
                 },
                 output_format: OutputFormatContent::new_string(),
                 tags: HashMap::from([("ROLE".to_string(), BamlValue::String("john doe".into()))]),
@@ -817,6 +830,7 @@ mod render_tests {
                     name: "gpt4".to_string(),
                     provider: "openai".to_string(),
                     default_role: "system".to_string(),
+                    allowed_roles: vec!["system".to_string()],
                 },
                 output_format: OutputFormatContent::new_string(),
                 tags: HashMap::from([("ROLE".to_string(), BamlValue::String("john doe".into()))]),
@@ -856,6 +870,7 @@ mod render_tests {
                     name: "gpt4".to_string(),
                     provider: "openai".to_string(),
                     default_role: "system".to_string(),
+                    allowed_roles: vec!["system".to_string()],
                 },
                 output_format: OutputFormatContent::new_string(),
                 tags: HashMap::from([("ROLE".to_string(), BamlValue::String("john doe".into()))]),
@@ -895,6 +910,7 @@ mod render_tests {
                     name: "gpt4".to_string(),
                     provider: "openai".to_string(),
                     default_role: "system".to_string(),
+                    allowed_roles: vec!["system".to_string()],
                 },
                 output_format: OutputFormatContent::new_string(),
                 tags: HashMap::from([("ROLE".to_string(), BamlValue::String("john doe".into()))]),
@@ -934,6 +950,7 @@ mod render_tests {
                     name: "gpt4".to_string(),
                     provider: "openai".to_string(),
                     default_role: "system".to_string(),
+                    allowed_roles: vec!["system".to_string()],
                 },
                 output_format: OutputFormatContent::new_string(),
                 tags: HashMap::from([("ROLE".to_string(), BamlValue::String("john doe".into()))]),
@@ -995,6 +1012,7 @@ mod render_tests {
                     name: "gpt4".to_string(),
                     provider: "openai".to_string(),
                     default_role: "system".to_string(),
+                    allowed_roles: vec!["system".to_string()],
                 },
                 output_format: OutputFormatContent::new_string(),
                 tags: HashMap::from([("ROLE".to_string(), BamlValue::String("john doe".into()))]),
@@ -1054,6 +1072,7 @@ mod render_tests {
                     name: "gpt4".to_string(),
                     provider: "openai".to_string(),
                     default_role: "system".to_string(),
+                    allowed_roles: vec!["system".to_string()],
                 },
                 output_format: OutputFormatContent::new_string(),
                 tags: HashMap::from([("ROLE".to_string(), BamlValue::String("john doe".into()))]),
@@ -1131,6 +1150,7 @@ mod render_tests {
                     name: "gpt4".to_string(),
                     provider: "openai".to_string(),
                     default_role: "system".to_string(),
+                    allowed_roles: vec!["system".to_string()],
                 },
                 output_format: OutputFormatContent::new_string(),
                 tags: HashMap::from([("ROLE".to_string(), BamlValue::String("john doe".into()))]),
@@ -1185,6 +1205,7 @@ mod render_tests {
                     name: "gpt4".to_string(),
                     provider: "openai".to_string(),
                     default_role: "system".to_string(),
+                    allowed_roles: vec!["system".to_string()],
                 },
                 output_format: OutputFormatContent::new_string(),
                 tags: HashMap::new(),
@@ -1235,6 +1256,7 @@ mod render_tests {
                     name: "gpt4".to_string(),
                     provider: "openai".to_string(),
                     default_role: "system".to_string(),
+                    allowed_roles: vec!["system".to_string()],
                 },
                 output_format: OutputFormatContent::new_string(),
                 tags: HashMap::new(),
@@ -1281,6 +1303,7 @@ mod render_tests {
                     name: "gpt4".to_string(),
                     provider: "openai".to_string(),
                     default_role: "system".to_string(),
+                    allowed_roles: vec!["system".to_string()],
                 },
                 output_format: OutputFormatContent::new_string(),
                 tags: HashMap::new(),
@@ -1300,6 +1323,7 @@ mod render_tests {
                     name: "gpt4".to_string(),
                     provider: "openai".to_string(),
                     default_role: "system".to_string(),
+                    allowed_roles: vec!["system".to_string()],
                 },
                 output_format: OutputFormatContent::new_string(),
                 tags: HashMap::new(),
@@ -1342,6 +1366,7 @@ mod render_tests {
                     name: "gpt4".to_string(),
                     provider: "openai".to_string(),
                     default_role: "system".to_string(),
+                    allowed_roles: vec!["system".to_string()],
                 },
                 output_format: OutputFormatContent::new_string(),
                 tags: HashMap::new(),
@@ -1361,6 +1386,7 @@ mod render_tests {
                     name: "gpt4".to_string(),
                     provider: "openai".to_string(),
                     default_role: "system".to_string(),
+                    allowed_roles: vec!["system".to_string()],
                 },
                 output_format: OutputFormatContent::new_string(),
                 tags: HashMap::new(),
@@ -1403,6 +1429,7 @@ mod render_tests {
                     name: "gpt4".to_string(),
                     provider: "openai".to_string(),
                     default_role: "system".to_string(),
+                    allowed_roles: vec!["system".to_string()],
                 },
                 output_format: OutputFormatContent::new_string(),
                 tags: HashMap::new(),
@@ -1476,6 +1503,7 @@ mod render_tests {
                     name: "gpt4".to_string(),
                     provider: "openai".to_string(),
                     default_role: "system".to_string(),
+                    allowed_roles: vec!["system".to_string()],
                 },
                 output_format: OutputFormatContent::new_string(),
                 tags: HashMap::new(),
"system".to_string(), + allowed_roles: vec!["system".to_string()], }, output_format: OutputFormatContent::new_string(), tags: HashMap::new(), @@ -1569,6 +1597,7 @@ mod render_tests { name: "gpt4".to_string(), provider: "openai".to_string(), default_role: "system".to_string(), + allowed_roles: vec!["system".to_string()], }, output_format: OutputFormatContent::new_string(), tags: HashMap::new(), @@ -1674,6 +1703,7 @@ mod render_tests { name: "gpt4".to_string(), provider: "openai".to_string(), default_role: "system".to_string(), + allowed_roles: vec!["system".to_string()], }, output_format: OutputFormatContent::new_string(), tags: HashMap::new(), @@ -1750,6 +1780,7 @@ mod render_tests { name: "gpt4".to_string(), provider: "openai".to_string(), default_role: "system".to_string(), + allowed_roles: vec!["system".to_string()], }, output_format: OutputFormatContent::new_string(), tags: HashMap::new(), @@ -1809,6 +1840,7 @@ mod render_tests { name: "gpt4".to_string(), provider: "openai".to_string(), default_role: "system".to_string(), + allowed_roles: vec!["system".to_string()], }, output_format: OutputFormatContent::new_string(), tags: HashMap::new(), @@ -1913,6 +1945,7 @@ mod render_tests { name: "gpt4".to_string(), provider: "openai".to_string(), default_role: "system".to_string(), + allowed_roles: vec!["system".to_string()], }, output_format: OutputFormatContent::new_string(), tags: HashMap::new(), diff --git a/engine/baml-lib/llm-client/src/clients/anthropic.rs b/engine/baml-lib/llm-client/src/clients/anthropic.rs index 596b35523..fc9a7403a 100644 --- a/engine/baml-lib/llm-client/src/clients/anthropic.rs +++ b/engine/baml-lib/llm-client/src/clients/anthropic.rs @@ -63,7 +63,15 @@ impl ResolvedAnthropic { } pub fn default_role(&self) -> String { - self.role_selection.default_or_else(|| "user".to_string()) + self.role_selection + .default_or_else(|| { + let allowed_roles = self.allowed_roles(); + if allowed_roles.contains(&"user".to_string()) { + "user".to_string() + } else { + allowed_roles.first().cloned().unwrap_or_else(|| "user".to_string()) + } + }) } } diff --git a/engine/baml-lib/llm-client/src/clients/aws_bedrock.rs b/engine/baml-lib/llm-client/src/clients/aws_bedrock.rs index 208b6f09d..4b5dc8d1a 100644 --- a/engine/baml-lib/llm-client/src/clients/aws_bedrock.rs +++ b/engine/baml-lib/llm-client/src/clients/aws_bedrock.rs @@ -78,7 +78,15 @@ impl ResolvedAwsBedrock { } pub fn default_role(&self) -> String { - self.role_selection.default_or_else(|| "user".to_string()) + self.role_selection + .default_or_else(|| { + let allowed_roles = self.allowed_roles(); + if allowed_roles.contains(&"user".to_string()) { + "user".to_string() + } else { + allowed_roles.first().cloned().unwrap_or_else(|| "user".to_string()) + } + }) } } @@ -201,11 +209,11 @@ impl UnresolvedAwsBedrock { }), "stop_sequences" => inference_config.stop_sequences = match v.into_array() { Ok((stop_sequences, _)) => Some(stop_sequences.into_iter().filter_map(|s| match s.into_str() { - Ok((s, _)) => Some(s), - Err(e) => { - properties.push_error(format!("stop_sequences values must be a string: got {}", e.r#type()), e.meta().clone()); - None - } + Ok((s, _)) => Some(s), + Err(e) => { + properties.push_error(format!("stop_sequences values must be a string: got {}", e.r#type()), e.meta().clone()); + None + } }) .collect::>()), Err(e) => { diff --git a/engine/baml-lib/llm-client/src/clients/google_ai.rs b/engine/baml-lib/llm-client/src/clients/google_ai.rs index e0772c781..3f36fd61a 100644 --- 
diff --git a/engine/baml-lib/llm-client/src/clients/google_ai.rs b/engine/baml-lib/llm-client/src/clients/google_ai.rs
index e0772c781..3f36fd61a 100644
--- a/engine/baml-lib/llm-client/src/clients/google_ai.rs
+++ b/engine/baml-lib/llm-client/src/clients/google_ai.rs
@@ -3,7 +3,7 @@ use std::collections::HashSet;
 
 use crate::{AllowedRoleMetadata, SupportedRequestModes, UnresolvedAllowedRoleMetadata};
 use anyhow::Result;
 use crate::{
-    AllowedRoleMetadata, FinishReasonFilter, RolesSelection, SupportedRequestModes, UnresolvedAllowedRoleMetadata, UnresolvedFinishReasonFilter, UnresolvedRolesSelection
+    FinishReasonFilter, RolesSelection, UnresolvedFinishReasonFilter, UnresolvedRolesSelection
 };
 use baml_types::{EvaluationContext, StringOr, UnresolvedValue};
@@ -69,7 +69,14 @@ impl ResolvedGoogleAI {
     }
 
     pub fn default_role(&self) -> String {
-        self.role_selection.default_or_else(|| "user".to_string())
+        self.role_selection.default_or_else(|| {
+            let allowed_roles = self.allowed_roles();
+            if allowed_roles.contains(&"user".to_string()) {
+                "user".to_string()
+            } else {
+                allowed_roles.first().cloned().unwrap_or_else(|| "user".to_string())
+            }
+        })
     }
 }
 
diff --git a/engine/baml-lib/llm-client/src/clients/helpers.rs b/engine/baml-lib/llm-client/src/clients/helpers.rs
index 589aea051..1abc9866e 100644
--- a/engine/baml-lib/llm-client/src/clients/helpers.rs
+++ b/engine/baml-lib/llm-client/src/clients/helpers.rs
@@ -183,7 +183,7 @@ impl PropertyHandler {
         })
     }
 
-    pub fn ensure_roles_selection(&mut self) -> UnresolvedRolesSelection {
+    pub(crate) fn ensure_roles_selection(&mut self) -> UnresolvedRolesSelection {
         let allowed_roles = self.ensure_allowed_roles();
         let default_role = self.ensure_default_role(allowed_roles.as_ref().unwrap_or(&vec![
             StringOr::Value("user".to_string()),
@@ -276,8 +276,8 @@ impl PropertyHandler {
                 )
             }
             (None, Some((_, deny, _))) => {
-                UnresolvedFinishReasonFilter::DenyList(deny.into_iter().filter_map(|v| match v.to_str() {
-                    Ok(s) => Some(s.0),
+                UnresolvedFinishReasonFilter::DenyList(deny.into_iter().filter_map(|v| match v.into_str() {
+                    Ok((s, _)) => Some(s.clone()),
                     Err(other) => {
                         self.push_error(
                             "values in finish_reason_deny_list must be strings.",
diff --git a/engine/baml-lib/llm-client/src/clients/vertex.rs b/engine/baml-lib/llm-client/src/clients/vertex.rs
index 1a5b217c6..e49da79ca 100644
--- a/engine/baml-lib/llm-client/src/clients/vertex.rs
+++ b/engine/baml-lib/llm-client/src/clients/vertex.rs
@@ -149,7 +149,14 @@ impl ResolvedVertex {
     }
 
     pub fn default_role(&self) -> String {
-        self.role_selection.default_or_else(|| "user".to_string())
+        self.role_selection.default_or_else(|| {
+            let allowed_roles = self.allowed_roles();
+            if allowed_roles.contains(&"user".to_string()) {
+                "user".to_string()
+            } else {
+                allowed_roles.first().cloned().unwrap_or_else(|| "user".to_string())
+            }
+        })
     }
 }
 
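The helpers.rs hunk above swaps the borrowing `to_str()` for the consuming `into_str()` in the deny-list branch, matching the `Ok((s, _))` tuple pattern already used for `stop_sequences` in aws_bedrock.rs. A simplified sketch of the consuming-accessor pattern, with a hypothetical stand-in enum rather than the real `UnresolvedValue` type:

```rust
// Stand-in for the real value type; `into_str` consumes the value and
// returns it unchanged on failure so the error path can still inspect it.
enum Value {
    Str(String),
    Number(f64),
}

impl Value {
    fn into_str(self) -> Result<String, Value> {
        match self {
            Value::Str(s) => Ok(s),
            other => Err(other),
        }
    }
}

// Deny-list construction: keep the strings, drop (and in the real code,
// report) anything that is not a string.
fn deny_list(values: Vec<Value>) -> Vec<String> {
    values
        .into_iter()
        .filter_map(|v| match v.into_str() {
            Ok(s) => Some(s),
            Err(_) => None, // the real code calls push_error here
        })
        .collect()
}

fn main() {
    let out = deny_list(vec![Value::Str("stop".into()), Value::Number(1.0)]);
    assert_eq!(out, vec!["stop".to_string()]);
}
```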
diff --git a/engine/baml-runtime/src/internal/llm_client/primitive/anthropic/anthropic_client.rs b/engine/baml-runtime/src/internal/llm_client/primitive/anthropic/anthropic_client.rs
index 38f5ba7ed..a285630e7 100644
--- a/engine/baml-runtime/src/internal/llm_client/primitive/anthropic/anthropic_client.rs
+++ b/engine/baml-runtime/src/internal/llm_client/primitive/anthropic/anthropic_client.rs
@@ -268,6 +268,7 @@ impl AnthropicClient {
                 name: client.name.clone(),
                 provider: client.provider.to_string(),
                 default_role: properties.default_role(),
+                allowed_roles: properties.allowed_roles(),
             },
             features: ModelFeatures {
                 chat: true,
@@ -290,6 +291,7 @@ impl AnthropicClient {
                 name: client.name().into(),
                 provider: client.elem().provider.to_string(),
                 default_role: properties.default_role(),
+                allowed_roles: properties.allowed_roles(),
             },
             features: ModelFeatures {
                 chat: true,
@@ -368,7 +370,7 @@ impl RequestBuilder for AnthropicClient {
 }
 
 impl WithChat for AnthropicClient {
-    async fn chat(&self, _ctx: &RuntimeContext, prompt: &Vec<RenderedChatMessage>) -> LLMResponse {
+    async fn chat(&self, _ctx: &RuntimeContext, prompt: &[RenderedChatMessage]) -> LLMResponse {
         let (response, system_now, instant_now) = match make_parsed_request::<
             AnthropicMessageResponse,
         >(
diff --git a/engine/baml-runtime/src/internal/llm_client/primitive/aws/aws_client.rs b/engine/baml-runtime/src/internal/llm_client/primitive/aws/aws_client.rs
index 66a26757b..69c223642 100644
--- a/engine/baml-runtime/src/internal/llm_client/primitive/aws/aws_client.rs
+++ b/engine/baml-runtime/src/internal/llm_client/primitive/aws/aws_client.rs
@@ -72,6 +72,7 @@ impl AwsClient {
                 name: client.name.clone(),
                 provider: client.provider.to_string(),
                 default_role: properties.default_role(),
+                allowed_roles: properties.allowed_roles(),
             },
             features: ModelFeatures {
                 chat: true,
@@ -94,6 +95,7 @@ impl AwsClient {
                 name: client.name().into(),
                 provider: client.elem().provider.to_string(),
                 default_role: properties.default_role(),
+                allowed_roles: properties.allowed_roles(),
             },
             features: ModelFeatures {
                 chat: true,
diff --git a/engine/baml-runtime/src/internal/llm_client/primitive/google/googleai_client.rs b/engine/baml-runtime/src/internal/llm_client/primitive/google/googleai_client.rs
index 6e235c68a..8ba6c8639 100644
--- a/engine/baml-runtime/src/internal/llm_client/primitive/google/googleai_client.rs
+++ b/engine/baml-runtime/src/internal/llm_client/primitive/google/googleai_client.rs
@@ -209,6 +209,7 @@ impl GoogleAIClient {
                 name: client.name().into(),
                 provider: client.elem().provider.to_string(),
                 default_role: properties.default_role(),
+                allowed_roles: properties.allowed_roles(),
             },
             features: ModelFeatures {
                 chat: true,
@@ -236,6 +237,7 @@ impl GoogleAIClient {
                 name: client.name.clone(),
                 provider: client.provider.to_string(),
                 default_role: properties.default_role(),
+                allowed_roles: properties.allowed_roles(),
             },
             features: ModelFeatures {
                 chat: true,
@@ -308,7 +310,7 @@ impl RequestBuilder for GoogleAIClient {
 }
 
 impl WithChat for GoogleAIClient {
-    async fn chat(&self, _ctx: &RuntimeContext, prompt: &Vec<RenderedChatMessage>) -> LLMResponse {
+    async fn chat(&self, _ctx: &RuntimeContext, prompt: &[RenderedChatMessage]) -> LLMResponse {
         //non-streaming, complete response is returned
         let (response, system_now, instant_now) =
             match make_parsed_request::(self, either::Either::Right(prompt), false)
@@ -338,7 +340,7 @@ impl WithChat for GoogleAIClient {
             return LLMResponse::LLMFailure(LLMErrorResponse {
                 client: self.context.name.to_string(),
                 model: None,
-                prompt: internal_baml_jinja::RenderedPrompt::Chat(prompt.clone()),
+                prompt: internal_baml_jinja::RenderedPrompt::Chat(prompt.to_vec()),
                 start_time: system_now,
                 request_options: self.properties.properties.clone(),
                 latency: instant_now.elapsed(),
@@ -349,7 +351,7 @@ impl WithChat for GoogleAIClient {
 
         LLMResponse::Success(LLMCompleteResponse {
             client: self.context.name.to_string(),
-            prompt: internal_baml_jinja::RenderedPrompt::Chat(prompt.clone()),
+            prompt: internal_baml_jinja::RenderedPrompt::Chat(prompt.to_vec()),
             content: content.parts[0].text.clone(),
             start_time: system_now,
             latency: instant_now.elapsed(),
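The `chat` signatures here and in the OpenAI and Vertex clients below change from `&Vec<RenderedChatMessage>` to the idiomatic borrowed slice `&[RenderedChatMessage]`; existing callers keep compiling because `&Vec<T>` deref-coerces to `&[T]`, and `prompt.to_vec()` supplies the owned `Vec` that `prompt.clone()` previously produced. A small illustration with types simplified to `String`:

```rust
// A &[T] parameter accepts both &Vec<T> (via deref coercion) and plain
// slices, and to_vec() yields the owned Vec that clone() used to.
fn chat(prompt: &[String]) -> Vec<String> {
    prompt.to_vec() // owned copy, equivalent to the old prompt.clone()
}

fn main() {
    let messages = vec!["system: hi".to_string(), "user: hello".to_string()];
    let owned = chat(&messages); // &Vec<String> coerces to &[String]
    assert_eq!(owned.len(), 2);
}
```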
diff --git a/engine/baml-runtime/src/internal/llm_client/primitive/openai/openai_client.rs b/engine/baml-runtime/src/internal/llm_client/primitive/openai/openai_client.rs
index 64003d758..398fa0e52 100644
--- a/engine/baml-runtime/src/internal/llm_client/primitive/openai/openai_client.rs
+++ b/engine/baml-runtime/src/internal/llm_client/primitive/openai/openai_client.rs
@@ -157,7 +157,7 @@ impl WithNoCompletion for OpenAIClient {}
 // }
 
 impl WithChat for OpenAIClient {
-    async fn chat(&self, _ctx: &RuntimeContext, prompt: &Vec<RenderedChatMessage>) -> LLMResponse {
+    async fn chat(&self, _ctx: &RuntimeContext, prompt: &[RenderedChatMessage]) -> LLMResponse {
         let (response, system_start, instant_start) =
             match make_parsed_request::(
                 self,
@@ -406,6 +406,7 @@ macro_rules! make_openai_client {
                 name: $client.name.clone(),
                 provider: $client.provider.to_string(),
                 default_role: $properties.default_role(),
+                allowed_roles: $properties.allowed_roles(),
             },
             features: ModelFeatures {
                 chat: true,
@@ -427,6 +428,7 @@ macro_rules! make_openai_client {
                 name: $client.name().into(),
                 provider: $client.elem().provider.to_string(),
                 default_role: $properties.default_role(),
+                allowed_roles: $properties.allowed_roles(),
             },
             features: ModelFeatures {
                 chat: true,
diff --git a/engine/baml-runtime/src/internal/llm_client/primitive/vertex/vertex_client.rs b/engine/baml-runtime/src/internal/llm_client/primitive/vertex/vertex_client.rs
index 611842901..ec376f9c6 100644
--- a/engine/baml-runtime/src/internal/llm_client/primitive/vertex/vertex_client.rs
+++ b/engine/baml-runtime/src/internal/llm_client/primitive/vertex/vertex_client.rs
@@ -251,6 +251,7 @@ impl VertexClient {
                 name: client.name().into(),
                 provider: client.elem().provider.to_string(),
                 default_role: properties.default_role(),
+                allowed_roles: properties.allowed_roles(),
             },
             features: ModelFeatures {
                 chat: true,
@@ -278,6 +279,7 @@ impl VertexClient {
                 name: client.name.clone(),
                 provider: client.provider.to_string(),
                 default_role: properties.default_role(),
+                allowed_roles: properties.allowed_roles(),
             },
             features: ModelFeatures {
                 chat: true,
@@ -397,7 +399,7 @@ impl RequestBuilder for VertexClient {
 }
 
 impl WithChat for VertexClient {
-    async fn chat(&self, _ctx: &RuntimeContext, prompt: &Vec<RenderedChatMessage>) -> LLMResponse {
+    async fn chat(&self, _ctx: &RuntimeContext, prompt: &[RenderedChatMessage]) -> LLMResponse {
         //non-streaming, complete response is returned
         let (response, system_now, instant_now) =
             match make_parsed_request::(self, either::Either::Right(prompt), false)