Skip to content

Commit

Permalink
Add support for dynamic clients
Browse files Browse the repository at this point in the history
BAML functions now support passing in which model / configuration they use at runtime.
* Enables callers to dynamically pick which client configuration a function runs with.
* Tracing captures dynamic properties
* Retry policies cannot be created at runtime; they must still be defined in BAML source.
  • Loading branch information
hellovai committed Jun 16, 2024
1 parent 54547a3 commit 2fdbb1e
Show file tree
Hide file tree
Showing 51 changed files with 1,298 additions and 480 deletions.
61 changes: 61 additions & 0 deletions engine/baml-runtime/src/client_builder/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// This is designed to build any type of client, not just primitives
use anyhow::{Context, Result};
use std::collections::HashMap;
use std::sync::Arc;

use baml_types::{BamlMap, BamlValue};
use serde::Serialize;

use crate::{internal::llm_client::llm_provider::LLMProvider, RuntimeContext};

/// The primitive (non-strategy) LLM provider families a dynamic client can be
/// backed by.
///
/// NOTE(review): not referenced elsewhere in this file — presumably consumed by
/// callers in other modules; confirm before removing.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PrimitiveClient {
    OpenAI,
    Anthropic,
    Google,
}

/// A runtime-supplied client definition: everything needed to construct an
/// `LLMProvider` without the client being declared in BAML source.
#[derive(Serialize, Clone)]
pub struct ClientProperty {
/// Unique name this client is registered and looked up under.
pub name: String,
/// Provider identifier string (e.g. "openai", "anthropic") — dispatched on
/// by `LLMProvider::try_from`; see that impl for the accepted set.
pub provider: String,
/// Optional name of a retry policy. NOTE(review): policies themselves appear
/// to be defined in BAML, not at runtime — confirm against the runtime docs.
pub retry_policy: Option<String>,
/// Provider-specific options, resolved to concrete values at build time.
pub options: BamlMap<String, BamlValue>,
}

/// Accumulates dynamically-registered clients and (optionally) which one is
/// the primary, before they are materialized into `LLMProvider`s.
#[derive(Clone)]
pub struct ClientBuilder {
// Registered clients, keyed by `ClientProperty::name`.
clients: HashMap<String, ClientProperty>,
// Name of the default client, if one was chosen via `set_primary`.
primary: Option<String>,
}

impl ClientBuilder {
    /// Creates an empty builder: no registered clients, no primary.
    pub fn new() -> Self {
        Self {
            clients: Default::default(),
            primary: None,
        }
    }

    /// Registers `client` under its own name, replacing any client
    /// previously registered under the same name.
    pub fn add_client(&mut self, client: ClientProperty) {
        self.clients.insert(client.name.clone(), client);
    }

    /// Marks the named client as the default one to use.
    pub fn set_primary(&mut self, primary: String) {
        self.primary = Some(primary);
    }

    /// Materializes every registered client into an `LLMProvider`.
    ///
    /// Returns the primary client's name (if one was set) together with the
    /// constructed providers, keyed by client name.
    ///
    /// # Errors
    /// Fails on the first client whose properties cannot be parsed into a
    /// provider, naming that client in the error context.
    pub fn to_clients(
        &self,
        ctx: &RuntimeContext,
    ) -> Result<(Option<String>, HashMap<String, Arc<LLMProvider>>)> {
        let mut clients = HashMap::with_capacity(self.clients.len());
        for (name, client) in &self.clients {
            // with_context: only build the error message on the failure path.
            let provider = LLMProvider::try_from((client, ctx))
                .with_context(|| format!("Failed to parse client: {}", name))?;
            clients.insert(name.into(), Arc::new(provider));
        }
        // TODO: Also do validation here (e.g. that `primary` names a registered client)
        Ok((self.primary.clone(), clients))
    }
}

impl Default for ClientBuilder {
    fn default() -> Self {
        Self::new()
    }
}
30 changes: 28 additions & 2 deletions engine/baml-runtime/src/internal/llm_client/llm_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ use std::sync::Arc;
use anyhow::Result;
use internal_baml_core::ir::ClientWalker;

use crate::{runtime_interface::InternalClientLookup, RuntimeContext};
use crate::{
client_builder::ClientProperty, runtime_interface::InternalClientLookup, RuntimeContext,
};

use super::{
orchestrator::{
Expand All @@ -20,6 +22,15 @@ pub enum LLMProvider {
Strategy(LLMStrategyProvider),
}

impl std::fmt::Debug for LLMProvider {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
LLMProvider::Primitive(provider) => write!(f, "Primitive({})", provider),
LLMProvider::Strategy(provider) => write!(f, "Strategy({})", provider),
}
}
}

impl WithRetryPolicy for LLMProvider {
fn retry_policy_name(&self) -> Option<&str> {
match self {
Expand All @@ -37,7 +48,22 @@ impl TryFrom<(&ClientWalker<'_>, &RuntimeContext)> for LLMProvider {
"baml-fallback" | "fallback" | "baml-round-robin" | "round-robin" => {
LLMStrategyProvider::try_from((client, ctx)).map(LLMProvider::Strategy)
}
name => LLMPrimitiveProvider::try_from((client, ctx))
_ => LLMPrimitiveProvider::try_from((client, ctx))
.map(Arc::new)
.map(LLMProvider::Primitive),
}
}
}

impl TryFrom<(&ClientProperty, &RuntimeContext)> for LLMProvider {
type Error = anyhow::Error;

fn try_from(value: (&ClientProperty, &RuntimeContext)) -> Result<Self> {
match value.0.provider.as_str() {
"baml-fallback" | "fallback" | "baml-round-robin" | "round-robin" => {
LLMStrategyProvider::try_from(value).map(LLMProvider::Strategy)
}
_ => LLMPrimitiveProvider::try_from(value)
.map(Arc::new)
.map(LLMProvider::Primitive),
}
Expand Down
23 changes: 23 additions & 0 deletions engine/baml-runtime/src/internal/llm_client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub mod traits;

use anyhow::Result;

use internal_baml_core::ir::ClientWalker;
use internal_baml_jinja::RenderedPrompt;
use serde::Serialize;

Expand Down Expand Up @@ -166,3 +167,25 @@ impl std::fmt::Display for LLMCompleteResponse {
write!(f, "{}", self.content.dimmed())
}
}

// For parsing args
/// Resolves every option expression declared on `client` into a concrete
/// `serde_json::Value` via the runtime context.
///
/// # Errors
/// Returns an error naming the client and the offending option key as soon as
/// one expression fails to resolve.
fn resolve_properties_walker(
    client: &ClientWalker,
    ctx: &crate::RuntimeContext,
) -> Result<std::collections::HashMap<String, serde_json::Value>> {
    use anyhow::Context;
    (&client.item.elem.options)
        .iter()
        .map(|(k, v)| {
            // with_context: the error string is only formatted on failure.
            let value = ctx
                .resolve_expression::<serde_json::Value>(v)
                .with_context(|| {
                    format!("client {} could not resolve options.{}", client.name(), k)
                })?;
            Ok((k.into(), value))
        })
        .collect::<Result<std::collections::HashMap<_, _>>>()
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use internal_baml_jinja::{
use reqwest::Response;

use crate::{
client_builder::ClientProperty,
internal::llm_client::{
primitive::{
anthropic::types::{AnthropicMessageResponse, StopReason},
Expand Down Expand Up @@ -61,23 +62,9 @@ pub struct AnthropicClient {
// resolves/constructs PostRequestProperties from the client's options and runtime context, fleshing out the needed headers and parameters
// basically just reads the client's options and matches them to needed properties or defaults them
fn resolve_properties(
client: &ClientWalker,
mut properties: HashMap<String, serde_json::Value>,
ctx: &RuntimeContext,
) -> Result<PostRequestProperities> {
let mut properties = (&client.item.elem.options)
.iter()
.map(|(k, v)| {
Ok((
k.into(),
ctx.resolve_expression::<serde_json::Value>(v)
.context(format!(
"client {} could not resolve options.{}",
client.name(),
k
))?,
))
})
.collect::<Result<HashMap<_, _>>>()?;
// this is a required field
properties
.entry("max_tokens".into())
Expand Down Expand Up @@ -309,10 +296,36 @@ impl WithStreamChat for AnthropicClient {

// constructs base client and resolves properties based on context
impl AnthropicClient {
/// Builds an `AnthropicClient` from a runtime-supplied `ClientProperty`
/// instead of a BAML-declared client.
pub fn dynamic_new(client: &ClientProperty, ctx: &RuntimeContext) -> Result<Self> {
    // Convert the BAML option values into raw JSON before resolving them.
    let raw_options = client
        .options
        .iter()
        .map(|(key, value)| Ok((key.clone(), json!(value))))
        .collect::<Result<HashMap<_, _>>>()?;
    let properties = resolve_properties(raw_options, ctx)?;
    Ok(Self {
        name: client.name.clone(),
        properties,
        context: RenderContext_Client {
            name: client.name.clone(),
            provider: client.provider.clone(),
        },
        features: ModelFeatures {
            chat: true,
            completion: false,
            anthropic_system_constraints: true,
        },
        retry_policy: client.retry_policy.clone(),
        client: create_client()?,
    })
}

pub fn new(client: &ClientWalker, ctx: &RuntimeContext) -> Result<AnthropicClient> {
let properties = super::super::resolve_properties_walker(client, ctx)?;
Ok(Self {
name: client.name().into(),
properties: resolve_properties(client, ctx)?,
properties: resolve_properties(properties, ctx)?,
context: RenderContext_Client {
name: client.name().into(),
provider: client.elem().provider.clone(),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::client_builder::ClientProperty;
use crate::RuntimeContext;
use crate::{
internal::llm_client::{
Expand Down Expand Up @@ -42,25 +43,9 @@ pub struct GoogleClient {
}

fn resolve_properties(
client: &ClientWalker,
mut properties: HashMap<String, serde_json::Value>,
ctx: &RuntimeContext,
) -> Result<PostRequestProperities, anyhow::Error> {
let mut properties = (&client.item.elem.options)
.iter()
.map(|(k, v)| {
Ok((
k.into(),
ctx.resolve_expression::<serde_json::Value>(v)
.context(format!(
"client {} could not resolve options.{}",
client.name(),
k
))?,
))
})
.collect::<Result<HashMap<_, _>>>()?;
// this is a required field

let default_role = properties
.remove("default_role")
.and_then(|v| v.as_str().map(|s| s.to_string()))
Expand Down Expand Up @@ -234,10 +219,11 @@ impl WithStreamChat for GoogleClient {
}

impl GoogleClient {
pub fn new(client: &ClientWalker, ctx: &RuntimeContext) -> Result<GoogleClient> {
pub fn new(client: &ClientWalker, ctx: &RuntimeContext) -> Result<Self> {
let properties = super::super::resolve_properties_walker(client, ctx)?;
Ok(Self {
name: client.name().into(),
properties: resolve_properties(client, ctx)?,
properties: resolve_properties(properties, ctx)?,
context: RenderContext_Client {
name: client.name().into(),
provider: client.elem().provider.clone(),
Expand All @@ -255,6 +241,31 @@ impl GoogleClient {
client: create_client()?,
})
}

/// Builds a `GoogleClient` from a runtime-supplied `ClientProperty`
/// instead of a BAML-declared client.
pub fn dynamic_new(client: &ClientProperty, ctx: &RuntimeContext) -> Result<Self> {
    // Convert the BAML option values into raw JSON before resolving them.
    let raw_options = client
        .options
        .iter()
        .map(|(key, value)| Ok((key.clone(), json!(value))))
        .collect::<Result<HashMap<_, _>>>()?;
    let properties = resolve_properties(raw_options, ctx)?;
    Ok(Self {
        name: client.name.clone(),
        properties,
        context: RenderContext_Client {
            name: client.name.clone(),
            provider: client.provider.clone(),
        },
        features: ModelFeatures {
            chat: true,
            completion: false,
            anthropic_system_constraints: false,
        },
        retry_policy: client.retry_policy.clone(),
        client: create_client()?,
    })
}
}

impl RequestBuilder for GoogleClient {
Expand Down
43 changes: 41 additions & 2 deletions engine/baml-runtime/src/internal/llm_client/primitive/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ use baml_types::BamlValue;
use internal_baml_core::ir::{repr::IntermediateRepr, ClientWalker};

use crate::{
internal::prompt_renderer::PromptRenderer, runtime_interface::InternalClientLookup,
RuntimeContext,
client_builder::ClientProperty, internal::prompt_renderer::PromptRenderer,
runtime_interface::InternalClientLookup, RuntimeContext,
};

use self::{
Expand Down Expand Up @@ -53,6 +53,42 @@ macro_rules! match_llm_provider {
};
}

/// Dispatches a runtime-supplied client definition to the matching concrete
/// provider constructor, based on its `provider` string.
impl TryFrom<(&ClientProperty, &RuntimeContext)> for LLMPrimitiveProvider {
    type Error = anyhow::Error;

    fn try_from((value, ctx): (&ClientProperty, &RuntimeContext)) -> Result<Self> {
        match value.provider.as_str() {
            "openai" => OpenAIClient::dynamic_new(value, ctx).map(Self::OpenAI),
            "azure-openai" => OpenAIClient::dynamic_new_azure(value, ctx).map(Self::OpenAI),
            "ollama" => OpenAIClient::dynamic_new_ollama(value, ctx).map(Self::OpenAI),
            "anthropic" => AnthropicClient::dynamic_new(value, ctx).map(Self::Anthropic),
            "google-ai" => GoogleClient::dynamic_new(value, ctx).map(Self::Google),
            other => {
                // Note: the strategy providers ("fallback", "round-robin") are
                // listed for the user but handled by LLMStrategyProvider, not here.
                let options = [
                    "openai",
                    "anthropic",
                    "ollama",
                    "google-ai",
                    "azure-openai",
                    "fallback",
                    "round-robin",
                ];
                anyhow::bail!(
                    "Unsupported provider: {}. Available ones are: {}",
                    other,
                    options.join(", ")
                )
            }
        }
    }
}

impl TryFrom<(&ClientWalker<'_>, &RuntimeContext)> for LLMPrimitiveProvider {
type Error = anyhow::Error;

Expand All @@ -76,6 +112,7 @@ impl TryFrom<(&ClientWalker<'_>, &RuntimeContext)> for LLMPrimitiveProvider {
"openai",
"anthropic",
"ollama",
"google-ai",
"azure-openai",
"fallback",
"round-robin",
Expand Down Expand Up @@ -176,3 +213,5 @@ impl RequestBuilder for LLMPrimitiveProvider {
match_llm_provider!(self, build_request, prompt, stream)
}
}

use super::resolve_properties_walker;
Loading

0 comments on commit 2fdbb1e

Please sign in to comment.