Add support for Client Registry (#683)
BAML functions now support specifying which model / configuration to use at runtime.
* Callers can dynamically pick which client and options a function runs with.
* Tracing captures the dynamically supplied client properties.
* Retry policies must still be defined in BAML.
hellovai authored Jul 10, 2024
1 parent 67f9c6a commit c0fb454
Showing 59 changed files with 1,536 additions and 573 deletions.
96 changes: 96 additions & 0 deletions docs/docs/calling-baml/client-registry.mdx
@@ -0,0 +1,96 @@
---
title: "Client Registry"
---

If you need to modify the model / parameters of an LLM client at runtime, you can create a `ClientRegistry` and pass it to any BAML function.

<CodeGroup>

```python Python
from baml_py import ClientRegistry
from baml_client import b  # generated BAML client; the import path may differ in your project

async def run():
    cr = ClientRegistry()
    # Creates a new client
    cr.add_llm_client(name='MyAmazingClient', provider='openai', options={
        "model": "gpt-4o",
        "temperature": 0.7,
        "api_key": "sk-..."
    })
    # Sets MyAmazingClient as the primary client
    cr.set_primary('MyAmazingClient')

    # ExtractResume will now use MyAmazingClient as the calling client
    res = await b.ExtractResume("...", { "client_registry": cr })
```

```typescript TypeScript
import { ClientRegistry } from '@boundaryml/baml'
import { b } from './baml_client'  // generated BAML client; the path may differ in your project

async function run() {
  const cr = new ClientRegistry()
  // Creates a new client
  cr.addLlmClient({ name: 'MyAmazingClient', provider: 'openai', options: {
    model: "gpt-4o",
    temperature: 0.7,
    api_key: "sk-..."
  }})
  // Sets MyAmazingClient as the primary client
  cr.setPrimary('MyAmazingClient')

  // ExtractResume will now use MyAmazingClient as the calling client
  const res = await b.ExtractResume("...", { clientRegistry: cr })
}
```

```ruby Ruby
Not available yet
```

</CodeGroup>

## ClientRegistry Interface
import ClientConstructorParams from '/snippets/client-params.mdx'


<Tip>
Note: `ClientRegistry` is imported from `baml_py` in Python and `@boundaryml/baml` in TypeScript, not `baml_client`.

As we mature `ClientRegistry`, we will add a more type-safe and ergonomic interface directly in `baml_client`. See [Github issue #766](https://github.com/BoundaryML/baml/issues/766).
</Tip>

Methods use `snake_case` in Python and `camelCase` in TypeScript.

### add_llm_client / addLlmClient
A function to add an LLM client to the registry.

<ParamField
path="name"
type="string"
required
>
The name of the client.

<Warning>
If this name matches a client already defined in your .baml files, it overrides that client whenever the `ClientRegistry` is used.
</Warning>
</ParamField>

<ClientConstructorParams />

<ParamField path="retry_policy" type="string">
The name of a retry policy that is already defined in a .baml file. See [Retry Policies](/docs/snippets/clients/retry.mdx).
</ParamField>
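
For example, a registered client can reference a retry policy that already exists in your .baml files. The sketch below assumes a policy named `MyRetryPolicy` is defined in BAML and that the API key comes from an environment variable:

```python
import os
from baml_py import ClientRegistry

cr = ClientRegistry()
# Reuse a retry policy already defined in a .baml file (hypothetical name)
cr.add_llm_client(
    name='ResilientClient',
    provider='openai',
    options={"model": "gpt-4o", "api_key": os.environ["OPENAI_API_KEY"]},
    retry_policy='MyRetryPolicy',
)
cr.set_primary('ResilientClient')
```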

### set_primary / setPrimary
Sets the client for the function to use (i.e. it replaces the `client` property declared in the BAML function).

<ParamField
path="name"
type="string"
required
>
The name of the client to use.

This can be a new client that was added with `add_llm_client` or an existing client that is already in a .baml file.
</ParamField>
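
As a sketch of the second case, a call can be pointed at a client that is already declared in a .baml file without registering anything new. Here `GPT4` is a hypothetical client name from your BAML project:

```python
from baml_py import ClientRegistry
from baml_client import b  # generated BAML client; the import path may differ in your project

async def use_existing_client():
    cr = ClientRegistry()
    # 'GPT4' is assumed to be a client already declared in a .baml file
    cr.set_primary('GPT4')
    return await b.ExtractResume("...", { "client_registry": cr })
```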
Empty file.
2 changes: 1 addition & 1 deletion docs/docs/get-started/what-is-baml.mdx
@@ -52,7 +52,7 @@ Share your creations and ask questions in our [Discord](https://discord.gg/BTNBe
## Starter projects

- [BAML + NextJS 14 + Streaming](https://github.com/BoundaryML/baml-examples/tree/main/nextjs-starter)
- [BAML + FastAPI + Streaming](https://github.com/BoundaryML/baml-examples/tree/main/fastapi-starter)
- [BAML + FastAPI + Streaming](https://github.com/BoundaryML/baml-examples/tree/main/python-fastapi-starter)

## First steps
We recommend checking the examples in [PromptFiddle.com](https://promptfiddle.com). Once you're ready to start, [install the toolchain](/docs/get-started/quickstart/python) and read the [guides](/docs/calling-baml/calling-functions).
23 changes: 2 additions & 21 deletions docs/docs/snippets/clients/overview.mdx
@@ -25,30 +25,11 @@ function MakeHaiku(topic: string) -> string {

## Fields

<ParamField path="provider" required>
This configures which provider to use. The provider is responsible for handling the actual API calls to the LLM service. The provider is a required field.
import ClientConstructorParams from '/snippets/client-constructor.mdx'

The configuration modifies the URL request BAML runtime makes.

| Provider Name | Docs | Notes |
| -------------- | -------------------------------- | ---------------------------------------------------------- |
| `openai` | [OpenAI](providers/openai) | Anything that follows openai's API exactly |
| `ollama` | [Ollama](providers/ollama) | Alias for an openai client but with default ollama options |
| `azure-openai` | [Azure OpenAI](providers/azure) | |
| `anthropic` | [Anthropic](providers/anthropic) | |
| `google-ai` | [Google AI](providers/gemini) | |
| `fallback` | [Fallback](fallback) | Used to chain models conditional on failures |
| `round-robin` | [Round Robin](round-robin) | Used to load balance |

</ParamField>
<ClientConstructorParams />

<ParamField path="retry_policy">
The name of the retry policy. See [Retry
Policy](/docs/snippets/clients/retry).
</ParamField>

<ParamField path="options">
These vary per provider. Please see provider specific documentation for more
information. Generally they are pass through options to the POST request made
to the LLM.
</ParamField>
38 changes: 37 additions & 1 deletion docs/docs/snippets/clients/providers/other.mdx
@@ -1,5 +1,5 @@
---
title: Others (e.g. openrouter)
title: Others (e.g. groq, openrouter)
---

Since many model providers are converging on the OpenAI Chat API spec, the recommended way to use them is through the `openai` provider.
@@ -26,3 +26,39 @@ client<llm> MyClient {
}
}
```

### Groq

https://groq.com - Fast AI Inference

You can use Groq's OpenAI-compatible interface with BAML.

See https://console.groq.com/docs/openai for more information.

```rust BAML
client<llm> MyClient {
provider openai
options {
base_url "https://api.groq.com/openai/v1"
api_key env.GROQ_API_KEY
model "llama3-70b-8192"
}
}
```

### Together AI

https://www.together.ai/ - The fastest cloud platform for building and running generative AI.

See https://docs.together.ai/docs/openai-api-compatibility for more information.

```rust BAML
client<llm> MyClient {
provider openai
options {
base_url "https://api.together.ai/v1"
api_key env.TOGETHER_API_KEY
model "meta-llama/Llama-3-70b-chat-hf"
}
}
```
13 changes: 9 additions & 4 deletions docs/mint.json
@@ -137,16 +137,21 @@
]
},
{
"group": "Calling BAML Functions",
"group": "Advanced BAML Snippets",
"pages": [
"docs/calling-baml/dynamic-types",
"docs/calling-baml/client-registry"
]
},
{
"group": "BAML with Python/TS/Ruby",
"pages": [
"docs/calling-baml/generate-baml-client",
"docs/calling-baml/set-env-vars",
"docs/calling-baml/calling-functions",
"docs/calling-baml/streaming",
"docs/calling-baml/concurrent-calls",
"docs/calling-baml/multi-modal",
"docs/calling-baml/dynamic-types",
"docs/calling-baml/dynamic-clients"
"docs/calling-baml/multi-modal"
]
},
{
23 changes: 23 additions & 0 deletions docs/snippets/client-constructor.mdx
@@ -0,0 +1,23 @@
<ParamField path="provider" type="string" required>
This configures which provider to use. The provider is responsible for handling the actual API calls to the LLM service. The provider is a required field.

The configuration modifies the URL request BAML runtime makes.

| Provider Name | Docs | Notes |
| -------------- | -------------------------------- | ---------------------------------------------------------- |
| `openai` | [OpenAI](/docs/snippets/clients/providers/openai) | Anything that follows openai's API exactly |
| `ollama` | [Ollama](/docs/snippets/clients/providers/ollama) | Alias for an openai client but with default ollama options |
| `azure-openai` | [Azure OpenAI](/docs/snippets/clients/providers/azure) | |
| `anthropic` | [Anthropic](/docs/snippets/clients/providers/anthropic) | |
| `google-ai` | [Google AI](/docs/snippets/clients/providers/gemini) | |
| `fallback` | [Fallback](/docs/snippets/clients/fallback) | Used to chain models conditional on failures |
| `round-robin` | [Round Robin](/docs/snippets/clients/round-robin) | Used to load balance |

</ParamField>

<ParamField path="options" type="dict[str, Any]" required>
These vary per provider. Please see provider specific documentation for more
information. Generally they are pass through options to the POST request made
to the LLM.
</ParamField>
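
Putting these fields together, a minimal client declaration might look like the sketch below (the model name and environment variable are assumptions):

```rust BAML
client<llm> MyGpt4Client {
  provider openai
  options {
    model "gpt-4o"
    api_key env.OPENAI_API_KEY
  }
}
```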

File renamed without changes.
61 changes: 61 additions & 0 deletions engine/baml-runtime/src/client_registry/mod.rs
@@ -0,0 +1,61 @@
// This is designed to build any type of client, not just primitives
use anyhow::{Context, Result};
use std::collections::HashMap;
use std::sync::Arc;

use baml_types::{BamlMap, BamlValue};
use serde::Serialize;

use crate::{internal::llm_client::llm_provider::LLMProvider, RuntimeContext};

#[derive(Clone)]
pub enum PrimitiveClient {
OpenAI,
Anthropic,
Google,
}

#[derive(Serialize, Clone)]
pub struct ClientProperty {
pub name: String,
pub provider: String,
pub retry_policy: Option<String>,
pub options: BamlMap<String, BamlValue>,
}

#[derive(Clone)]
pub struct ClientRegistry {
clients: HashMap<String, ClientProperty>,
primary: Option<String>,
}

impl ClientRegistry {
pub fn new() -> Self {
Self {
clients: Default::default(),
primary: None,
}
}

pub fn add_client(&mut self, client: ClientProperty) {
self.clients.insert(client.name.clone(), client);
}

pub fn set_primary(&mut self, primary: String) {
self.primary = Some(primary);
}

pub fn to_clients(
&self,
ctx: &RuntimeContext,
) -> Result<(Option<String>, HashMap<String, Arc<LLMProvider>>)> {
let mut clients = HashMap::new();
for (name, client) in &self.clients {
let provider = LLMProvider::try_from((client, ctx))
.context(format!("Failed to parse client: {}", name))?;
clients.insert(name.into(), Arc::new(provider));
}
// TODO: Also do validation here
Ok((self.primary.clone(), clients))
}
}
30 changes: 28 additions & 2 deletions engine/baml-runtime/src/internal/llm_client/llm_provider.rs
@@ -3,7 +3,9 @@ use std::sync::Arc;
use anyhow::Result;
use internal_baml_core::ir::ClientWalker;

use crate::{runtime_interface::InternalClientLookup, RuntimeContext};
use crate::{
client_registry::ClientProperty, runtime_interface::InternalClientLookup, RuntimeContext,
};

use super::{
orchestrator::{
@@ -20,6 +22,15 @@ pub enum LLMProvider {
Strategy(LLMStrategyProvider),
}

impl std::fmt::Debug for LLMProvider {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
LLMProvider::Primitive(provider) => write!(f, "Primitive({})", provider),
LLMProvider::Strategy(provider) => write!(f, "Strategy({})", provider),
}
}
}

impl WithRetryPolicy for LLMProvider {
fn retry_policy_name(&self) -> Option<&str> {
match self {
@@ -37,7 +48,22 @@
"baml-fallback" | "fallback" | "baml-round-robin" | "round-robin" => {
LLMStrategyProvider::try_from((client, ctx)).map(LLMProvider::Strategy)
}
_name => LLMPrimitiveProvider::try_from((client, ctx))
_ => LLMPrimitiveProvider::try_from((client, ctx))
.map(Arc::new)
.map(LLMProvider::Primitive),
}
}
}

impl TryFrom<(&ClientProperty, &RuntimeContext)> for LLMProvider {
type Error = anyhow::Error;

fn try_from(value: (&ClientProperty, &RuntimeContext)) -> Result<Self> {
match value.0.provider.as_str() {
"baml-fallback" | "fallback" | "baml-round-robin" | "round-robin" => {
LLMStrategyProvider::try_from(value).map(LLMProvider::Strategy)
}
_ => LLMPrimitiveProvider::try_from(value)
.map(Arc::new)
.map(LLMProvider::Primitive),
}
23 changes: 23 additions & 0 deletions engine/baml-runtime/src/internal/llm_client/mod.rs
@@ -11,6 +11,7 @@ pub mod traits;

use anyhow::Result;

use internal_baml_core::ir::ClientWalker;
use internal_baml_jinja::RenderedPrompt;
use serde::Serialize;
use std::error::Error;
@@ -169,3 +170,25 @@ impl std::fmt::Display for LLMCompleteResponse {
write!(f, "{}", self.content.dimmed())
}
}

// For parsing args
fn resolve_properties_walker(
client: &ClientWalker,
ctx: &crate::RuntimeContext,
) -> Result<std::collections::HashMap<String, serde_json::Value>> {
use anyhow::Context;
(&client.item.elem.options)
.iter()
.map(|(k, v)| {
Ok((
k.into(),
ctx.resolve_expression::<serde_json::Value>(v)
.context(format!(
"client {} could not resolve options.{}",
client.name(),
k
))?,
))
})
.collect::<Result<std::collections::HashMap<_, _>>>()
}