Skip to content

Commit

Permalink
fix issues with clients
Browse files Browse the repository at this point in the history
  • Loading branch information
hellovai committed Dec 3, 2024
1 parent aa6d118 commit 7732f8b
Show file tree
Hide file tree
Showing 11 changed files with 93 additions and 20 deletions.
35 changes: 34 additions & 1 deletion engine/baml-lib/jinja-runtime/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ pub struct RenderContext_Client {
pub name: String,
pub provider: String,
pub default_role: String,
pub allowed_roles: Vec<String>,
}

#[derive(Debug)]
Expand All @@ -49,6 +50,7 @@ fn render_minijinja(
mut ctx: RenderContext,
template_string_macros: &[TemplateStringMacro],
default_role: String,
allowed_roles: Vec<String>,
) -> Result<RenderedPrompt, minijinja::Error> {
let mut env = get_env();

Expand Down Expand Up @@ -240,7 +242,11 @@ fn render_minijinja(
// Only add the message if it contains meaningful content
if !parts.is_empty() {
chat_messages.push(RenderedChatMessage {
role: role.as_ref().unwrap_or(&default_role).to_string(),
role: match role.as_ref() {
Some(r) if allowed_roles.contains(r) => r.clone(),
Some(_) => default_role.clone(),
None => default_role.clone(),
},
allow_duplicate_role,
parts,
});
Expand Down Expand Up @@ -410,12 +416,14 @@ pub fn render_prompt(
let eval_ctx = EvaluationContext::new(env_vars, false);
let minijinja_args: minijinja::Value = args.clone().to_minijinja_value(ir, &eval_ctx);
let default_role = ctx.client.default_role.clone();
let allowed_roles = ctx.client.allowed_roles.clone();
let rendered = render_minijinja(
template,
&minijinja_args,
ctx,
template_string_macros,
default_role,
allowed_roles,
);

match rendered {
Expand Down Expand Up @@ -505,6 +513,7 @@ mod render_tests {
name: "gpt4".to_string(),
provider: "openai".to_string(),
default_role: "system".to_string(),
allowed_roles: vec!["system".to_string()],
},
output_format: OutputFormatContent::new_string(),
tags: HashMap::from([("ROLE".to_string(), BamlValue::String("john doe".into()))]),
Expand Down Expand Up @@ -565,6 +574,7 @@ mod render_tests {
name: "gpt4".to_string(),
provider: "openai".to_string(),
default_role: "system".to_string(),
allowed_roles: vec!["system".to_string()],
},
output_format: OutputFormatContent::new_string(),
tags: HashMap::from([("ROLE".to_string(), BamlValue::String("john doe".into()))]),
Expand Down Expand Up @@ -623,6 +633,7 @@ mod render_tests {
name: "gpt4".to_string(),
provider: "openai".to_string(),
default_role: "system".to_string(),
allowed_roles: vec!["system".to_string()],
},
output_format: OutputFormatContent::new_string(),
tags: HashMap::from([("ROLE".to_string(), BamlValue::String("john doe".into()))]),
Expand Down Expand Up @@ -690,6 +701,7 @@ mod render_tests {
name: "gpt4".to_string(),
provider: "openai".to_string(),
default_role: "system".to_string(),
allowed_roles: vec!["system".to_string()],
},
output_format: OutputFormatContent::new_string(),
tags: HashMap::from([("ROLE".to_string(), BamlValue::String("john doe".into()))]),
Expand Down Expand Up @@ -767,6 +779,7 @@ mod render_tests {
name: "gpt4".to_string(),
provider: "openai".to_string(),
default_role: "system".to_string(),
allowed_roles: vec!["system".to_string()],
},
output_format: OutputFormatContent::new_string(),
tags: HashMap::from([("ROLE".to_string(), BamlValue::String("john doe".into()))]),
Expand Down Expand Up @@ -817,6 +830,7 @@ mod render_tests {
name: "gpt4".to_string(),
provider: "openai".to_string(),
default_role: "system".to_string(),
allowed_roles: vec!["system".to_string()],
},
output_format: OutputFormatContent::new_string(),
tags: HashMap::from([("ROLE".to_string(), BamlValue::String("john doe".into()))]),
Expand Down Expand Up @@ -856,6 +870,7 @@ mod render_tests {
name: "gpt4".to_string(),
provider: "openai".to_string(),
default_role: "system".to_string(),
allowed_roles: vec!["system".to_string()],
},
output_format: OutputFormatContent::new_string(),
tags: HashMap::from([("ROLE".to_string(), BamlValue::String("john doe".into()))]),
Expand Down Expand Up @@ -895,6 +910,7 @@ mod render_tests {
name: "gpt4".to_string(),
provider: "openai".to_string(),
default_role: "system".to_string(),
allowed_roles: vec!["system".to_string()],
},
output_format: OutputFormatContent::new_string(),
tags: HashMap::from([("ROLE".to_string(), BamlValue::String("john doe".into()))]),
Expand Down Expand Up @@ -934,6 +950,7 @@ mod render_tests {
name: "gpt4".to_string(),
provider: "openai".to_string(),
default_role: "system".to_string(),
allowed_roles: vec!["system".to_string()],
},
output_format: OutputFormatContent::new_string(),
tags: HashMap::from([("ROLE".to_string(), BamlValue::String("john doe".into()))]),
Expand Down Expand Up @@ -995,6 +1012,7 @@ mod render_tests {
name: "gpt4".to_string(),
provider: "openai".to_string(),
default_role: "system".to_string(),
allowed_roles: vec!["system".to_string()],
},
output_format: OutputFormatContent::new_string(),
tags: HashMap::from([("ROLE".to_string(), BamlValue::String("john doe".into()))]),
Expand Down Expand Up @@ -1054,6 +1072,7 @@ mod render_tests {
name: "gpt4".to_string(),
provider: "openai".to_string(),
default_role: "system".to_string(),
allowed_roles: vec!["system".to_string()],
},
output_format: OutputFormatContent::new_string(),
tags: HashMap::from([("ROLE".to_string(), BamlValue::String("john doe".into()))]),
Expand Down Expand Up @@ -1131,6 +1150,7 @@ mod render_tests {
name: "gpt4".to_string(),
provider: "openai".to_string(),
default_role: "system".to_string(),
allowed_roles: vec!["system".to_string()],
},
output_format: OutputFormatContent::new_string(),
tags: HashMap::from([("ROLE".to_string(), BamlValue::String("john doe".into()))]),
Expand Down Expand Up @@ -1185,6 +1205,7 @@ mod render_tests {
name: "gpt4".to_string(),
provider: "openai".to_string(),
default_role: "system".to_string(),
allowed_roles: vec!["system".to_string()],
},
output_format: OutputFormatContent::new_string(),
tags: HashMap::new(),
Expand Down Expand Up @@ -1235,6 +1256,7 @@ mod render_tests {
name: "gpt4".to_string(),
provider: "openai".to_string(),
default_role: "system".to_string(),
allowed_roles: vec!["system".to_string()],
},
output_format: OutputFormatContent::new_string(),
tags: HashMap::new(),
Expand Down Expand Up @@ -1281,6 +1303,7 @@ mod render_tests {
name: "gpt4".to_string(),
provider: "openai".to_string(),
default_role: "system".to_string(),
allowed_roles: vec!["system".to_string()],
},
output_format: OutputFormatContent::new_string(),
tags: HashMap::new(),
Expand All @@ -1300,6 +1323,7 @@ mod render_tests {
name: "gpt4".to_string(),
provider: "openai".to_string(),
default_role: "system".to_string(),
allowed_roles: vec!["system".to_string()],
},
output_format: OutputFormatContent::new_string(),
tags: HashMap::new(),
Expand Down Expand Up @@ -1342,6 +1366,7 @@ mod render_tests {
name: "gpt4".to_string(),
provider: "openai".to_string(),
default_role: "system".to_string(),
allowed_roles: vec!["system".to_string()],
},
output_format: OutputFormatContent::new_string(),
tags: HashMap::new(),
Expand All @@ -1361,6 +1386,7 @@ mod render_tests {
name: "gpt4".to_string(),
provider: "openai".to_string(),
default_role: "system".to_string(),
allowed_roles: vec!["system".to_string()],
},
output_format: OutputFormatContent::new_string(),
tags: HashMap::new(),
Expand Down Expand Up @@ -1403,6 +1429,7 @@ mod render_tests {
name: "gpt4".to_string(),
provider: "openai".to_string(),
default_role: "system".to_string(),
allowed_roles: vec!["system".to_string()],
},
output_format: OutputFormatContent::new_string(),
tags: HashMap::new(),
Expand Down Expand Up @@ -1476,6 +1503,7 @@ mod render_tests {
name: "gpt4".to_string(),
provider: "openai".to_string(),
default_role: "system".to_string(),
allowed_roles: vec!["system".to_string()],
},
output_format: OutputFormatContent::new_string(),
tags: HashMap::new(),
Expand Down Expand Up @@ -1569,6 +1597,7 @@ mod render_tests {
name: "gpt4".to_string(),
provider: "openai".to_string(),
default_role: "system".to_string(),
allowed_roles: vec!["system".to_string()],
},
output_format: OutputFormatContent::new_string(),
tags: HashMap::new(),
Expand Down Expand Up @@ -1674,6 +1703,7 @@ mod render_tests {
name: "gpt4".to_string(),
provider: "openai".to_string(),
default_role: "system".to_string(),
allowed_roles: vec!["system".to_string()],
},
output_format: OutputFormatContent::new_string(),
tags: HashMap::new(),
Expand Down Expand Up @@ -1750,6 +1780,7 @@ mod render_tests {
name: "gpt4".to_string(),
provider: "openai".to_string(),
default_role: "system".to_string(),
allowed_roles: vec!["system".to_string()],
},
output_format: OutputFormatContent::new_string(),
tags: HashMap::new(),
Expand Down Expand Up @@ -1809,6 +1840,7 @@ mod render_tests {
name: "gpt4".to_string(),
provider: "openai".to_string(),
default_role: "system".to_string(),
allowed_roles: vec!["system".to_string()],
},
output_format: OutputFormatContent::new_string(),
tags: HashMap::new(),
Expand Down Expand Up @@ -1913,6 +1945,7 @@ mod render_tests {
name: "gpt4".to_string(),
provider: "openai".to_string(),
default_role: "system".to_string(),
allowed_roles: vec!["system".to_string()],
},
output_format: OutputFormatContent::new_string(),
tags: HashMap::new(),
Expand Down
10 changes: 9 additions & 1 deletion engine/baml-lib/llm-client/src/clients/anthropic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,15 @@ impl ResolvedAnthropic {
}

pub fn default_role(&self) -> String {
self.role_selection.default_or_else(|| "user".to_string())
self.role_selection
.default_or_else(|| {
let allowed_roles = self.allowed_roles();
if allowed_roles.contains(&"user".to_string()) {
"user".to_string()
} else {
allowed_roles.first().cloned().unwrap_or_else(|| "user".to_string())
}
})
}
}

Expand Down
20 changes: 14 additions & 6 deletions engine/baml-lib/llm-client/src/clients/aws_bedrock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,15 @@ impl ResolvedAwsBedrock {
}

pub fn default_role(&self) -> String {
self.role_selection.default_or_else(|| "user".to_string())
self.role_selection
.default_or_else(|| {
let allowed_roles = self.allowed_roles();
if allowed_roles.contains(&"user".to_string()) {
"user".to_string()
} else {
allowed_roles.first().cloned().unwrap_or_else(|| "user".to_string())
}
})
}
}

Expand Down Expand Up @@ -201,11 +209,11 @@ impl UnresolvedAwsBedrock {
}),
"stop_sequences" => inference_config.stop_sequences = match v.into_array() {
Ok((stop_sequences, _)) => Some(stop_sequences.into_iter().filter_map(|s| match s.into_str() {
Ok((s, _)) => Some(s),
Err(e) => {
properties.push_error(format!("stop_sequences values must be a string: got {}", e.r#type()), e.meta().clone());
None
}
Ok((s, _)) => Some(s),
Err(e) => {
properties.push_error(format!("stop_sequences values must be a string: got {}", e.r#type()), e.meta().clone());
None
}
})
.collect::<Vec<_>>()),
Err(e) => {
Expand Down
11 changes: 9 additions & 2 deletions engine/baml-lib/llm-client/src/clients/google_ai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::collections::HashSet;
use crate::{AllowedRoleMetadata, SupportedRequestModes, UnresolvedAllowedRoleMetadata};
use anyhow::Result;
use crate::{
AllowedRoleMetadata, FinishReasonFilter, RolesSelection, SupportedRequestModes, UnresolvedAllowedRoleMetadata, UnresolvedFinishReasonFilter, UnresolvedRolesSelection
FinishReasonFilter, RolesSelection, UnresolvedFinishReasonFilter, UnresolvedRolesSelection
};

use baml_types::{EvaluationContext, StringOr, UnresolvedValue};
Expand Down Expand Up @@ -69,7 +69,14 @@ impl ResolvedGoogleAI {
}

pub fn default_role(&self) -> String {
self.role_selection.default_or_else(|| "user".to_string())
self.role_selection.default_or_else(|| {
let allowed_roles = self.allowed_roles();
if allowed_roles.contains(&"user".to_string()) {
"user".to_string()
} else {
allowed_roles.first().cloned().unwrap_or_else(|| "user".to_string())
}
})
}
}

Expand Down
6 changes: 3 additions & 3 deletions engine/baml-lib/llm-client/src/clients/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ impl<Meta: Clone> PropertyHandler<Meta> {
})
}

pub fn ensure_roles_selection(&mut self) -> UnresolvedRolesSelection {
pub(crate) fn ensure_roles_selection(&mut self) -> UnresolvedRolesSelection {
let allowed_roles = self.ensure_allowed_roles();
let default_role = self.ensure_default_role(allowed_roles.as_ref().unwrap_or(&vec![
StringOr::Value("user".to_string()),
Expand Down Expand Up @@ -276,8 +276,8 @@ impl<Meta: Clone> PropertyHandler<Meta> {
)
}
(None, Some((_, deny, _))) => {
UnresolvedFinishReasonFilter::DenyList(deny.into_iter().filter_map(|v| match v.to_str() {
Ok(s) => Some(s.0),
UnresolvedFinishReasonFilter::DenyList(deny.into_iter().filter_map(|v| match v.into_str() {
Ok((s, _)) => Some(s.clone()),
Err(other) => {
self.push_error(
"values in finish_reason_deny_list must be strings.",
Expand Down
9 changes: 8 additions & 1 deletion engine/baml-lib/llm-client/src/clients/vertex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,14 @@ impl ResolvedVertex {
}

pub fn default_role(&self) -> String {
self.role_selection.default_or_else(|| "user".to_string())
self.role_selection.default_or_else(|| {
let allowed_roles = self.allowed_roles();
if allowed_roles.contains(&"user".to_string()) {
"user".to_string()
} else {
allowed_roles.first().cloned().unwrap_or_else(|| "user".to_string())
}
})
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ impl AnthropicClient {
name: client.name.clone(),
provider: client.provider.to_string(),
default_role: properties.default_role(),
allowed_roles: properties.allowed_roles(),
},
features: ModelFeatures {
chat: true,
Expand All @@ -290,6 +291,7 @@ impl AnthropicClient {
name: client.name().into(),
provider: client.elem().provider.to_string(),
default_role: properties.default_role(),
allowed_roles: properties.allowed_roles(),
},
features: ModelFeatures {
chat: true,
Expand Down Expand Up @@ -368,7 +370,7 @@ impl RequestBuilder for AnthropicClient {
}

impl WithChat for AnthropicClient {
async fn chat(&self, _ctx: &RuntimeContext, prompt: &Vec<RenderedChatMessage>) -> LLMResponse {
async fn chat(&self, _ctx: &RuntimeContext, prompt: &[RenderedChatMessage]) -> LLMResponse {
let (response, system_now, instant_now) = match make_parsed_request::<
AnthropicMessageResponse,
>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ impl AwsClient {
name: client.name.clone(),
provider: client.provider.to_string(),
default_role: properties.default_role(),
allowed_roles: properties.allowed_roles(),
},
features: ModelFeatures {
chat: true,
Expand All @@ -94,6 +95,7 @@ impl AwsClient {
name: client.name().into(),
provider: client.elem().provider.to_string(),
default_role: properties.default_role(),
allowed_roles: properties.allowed_roles(),
},
features: ModelFeatures {
chat: true,
Expand Down
Loading

0 comments on commit 7732f8b

Please sign in to comment.