Skip to content

Commit

Permalink
Add ability for certain models to disable streaming. (#1157)
Browse files Browse the repository at this point in the history
openai/o1-* models by default don't support streaming
*/* by default do support streaming

Users can add the `supports_streaming <bool>` option to their client configuration to manually control this behavior.

✅ Raw Curl works
✅ Shorthand clients work well for `o1-*` models without any twiddling
✅ Docs updated
<img width="620" alt="Screenshot 2024-11-11 at 11 46 39 AM"
src="https://github.com/user-attachments/assets/7d1fdd7d-a8ce-4049-87a9-8ef6a19cf759">



<!-- ELLIPSIS_HIDDEN -->

----

> [!IMPORTANT]
> Add `supports_streaming` option to configure streaming for models,
defaulting `openai/o1-*` to non-streaming, with documentation updates.
> 
>   - **Behavior**:
> - Adds `supports_streaming` option to client configuration to manually
enable/disable streaming.
> - `openai/o1-*` models default to non-streaming; other models default
to streaming.
> - Updates `supports_streaming()` in `WithClientProperties` to reflect
new behavior.
>   - **Code Changes**:
> - Modifies `resolve_properties()` in `anthropic_client.rs`,
`aws_client.rs`, `googleai_client.rs`, `openai_client.rs`, and
`vertex_client.rs` to handle `SupportedRequestModes`.
>     - Adds `SupportedRequestModes` struct in `llm_client/mod.rs`.
>     - Updates `WithStreamable` trait to handle non-streaming fallback.
>   - **Documentation**:
> - Updates various `.mdx` files to document `supports_streaming`
option.
> - Adds new snippets `supports-streaming.mdx` and
`supports-streaming-openai.mdx`.
> 
> <sup>This description was created by </sup>[<img alt="Ellipsis"
src="https://img.shields.io/badge/Ellipsis-blue?color=175173">](https://www.ellipsis.dev?ref=BoundaryML%2Fbaml&utm_source=github&utm_medium=referral)<sup>
for 4e8efd1. It will automatically
update as commits are pushed.</sup>

<!-- ELLIPSIS_HIDDEN -->
  • Loading branch information
hellovai authored Nov 11, 2024
1 parent 1c96c14 commit 09c6549
Show file tree
Hide file tree
Showing 34 changed files with 390 additions and 78 deletions.
12 changes: 8 additions & 4 deletions engine/baml-lib/schema-ast/src/parser/parse_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,15 @@ fn parse_string_literal(token: Pair<'_>, diagnostics: &mut Diagnostics) -> Expre
if content.contains(' ') {
Expression::StringValue(content, span)
} else {
match Identifier::from((content.as_str(), span.clone())) {
Identifier::Invalid(..) | Identifier::String(..) => {
Expression::StringValue(content, span)
if content.eq("true") || content.eq("false") {
Expression::BoolValue(content.eq("true"), span)
} else {
match Identifier::from((content.as_str(), span.clone())) {
Identifier::Invalid(..) | Identifier::String(..) => {
Expression::StringValue(content, span)
}
identifier => Expression::Identifier(identifier),
}
identifier => Expression::Identifier(identifier),
}
}
}
Expand Down
6 changes: 6 additions & 0 deletions engine/baml-runtime/src/internal/llm_client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,12 @@ pub enum AllowedMetadata {
Only(HashSet<String>),
}

/// Which request modes a client supports.
///
/// Currently only tracks streaming. `None` means "auto": the provider
/// decides based on its own defaults (e.g. OpenAI disables streaming for
/// `o1-*` models; everything else defaults to streaming-capable).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SupportedRequestModes {
    // If unset, treat as auto
    pub stream: Option<bool>,
}

impl AllowedMetadata {
pub fn is_allowed(&self, key: &str) -> bool {
match self {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::internal::llm_client::{
properties_hander::PropertiesHandler,
traits::{ToProviderMessage, ToProviderMessageExt, WithClientProperties},
AllowedMetadata, ResolveMediaUrls,
AllowedMetadata, ResolveMediaUrls, SupportedRequestModes,
};
use std::collections::HashMap;

Expand Down Expand Up @@ -46,6 +46,7 @@ struct PostRequestProperities {
allowed_metadata: AllowedMetadata,
// These are passed directly to the Anthropic API.
properties: HashMap<String, serde_json::Value>,
supported_request_modes: SupportedRequestModes,
}

// represents client that interacts with the Anthropic API
Expand Down Expand Up @@ -82,13 +83,16 @@ fn resolve_properties(
.entry("anthropic-version".to_string())
.or_insert("2023-06-01".to_string());

let supported_request_modes = properties.pull_supported_request_modes()?;

let mut properties = properties.finalize();
// Anthropic has a very low max_tokens by default, so we increase it to 4096.
properties
.entry("max_tokens".into())
.or_insert_with(|| 4096.into());
let properties = properties;


Ok(PostRequestProperities {
default_role,
base_url,
Expand All @@ -97,6 +101,7 @@ fn resolve_properties(
allowed_metadata,
properties,
proxy_url: ctx.env.get("BOUNDARY_PROXY_URL").map(|s| s.to_string()),
supported_request_modes,
})
}

Expand All @@ -114,6 +119,9 @@ impl WithClientProperties for AnthropicClient {
fn client_properties(&self) -> &HashMap<String, serde_json::Value> {
&self.properties.properties
}
/// Whether this Anthropic client should issue streaming requests.
/// Honors an explicit `supports_streaming` client option; defaults to `true`.
fn supports_streaming(&self) -> bool {
self.properties.supported_request_modes.stream.unwrap_or(true)
}
}

impl WithClient for AnthropicClient {
Expand Down Expand Up @@ -351,7 +359,7 @@ impl RequestBuilder for AnthropicClient {
self.properties
.proxy_url
.as_ref()
.unwrap_or(&self.properties.base_url)
.unwrap_or_else(|| &self.properties.base_url)
} else {
&self.properties.base_url
};
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
use std::collections::HashMap;
use anyhow::Result;
use crate::{
internal::llm_client::{properties_hander::PropertiesHandler, SharedProperties},
RuntimeContext,
};
use super::PostRequestProperties;

/// Resolves the post-request configuration for the Anthropic shorthand client.
///
/// Pulls the shared client properties (default role "system"), then layers in
/// Anthropic-specific defaults:
/// - `base_url` falls back to "https://api.anthropic.com"
/// - `api_key` falls back to the `ANTHROPIC_API_KEY` environment variable
/// - the "anthropic-version" header defaults to "2023-06-01"
///
/// The proxy URL is read from `BOUNDARY_PROXY_URL` if set.
pub fn resolve_properties(
mut properties: PropertiesHandler,
ctx: &RuntimeContext,
) -> Result<PostRequestProperties> {
let shared = properties.pull_shared_properties("system");

// Override defaults in shared
let shared = SharedProperties {
base_url: shared.base_url
.map(|url| url.unwrap_or_else(|| "https://api.anthropic.com".into())),
api_key: shared.api_key
.map(|key| key.or_else(|| ctx.env.get("ANTHROPIC_API_KEY").map(|s| s.to_string()))),
headers: shared.headers.map(|mut h| {
h.entry("anthropic-version".to_string())
.or_insert("2023-06-01".to_string());
h
}),
..shared
};

// NOTE(review): this finalized map — including the 4096 `max_tokens`
// default inserted below — is never stored anywhere: `PostRequestProperties`
// has no field for it, so the default appears to be silently discarded.
// Verify this is intentional (other providers pass the finalized map
// through to the request).
let mut properties = properties.finalize();
properties.entry("max_tokens".into())
.or_insert_with(|| 4096.into());

Ok(PostRequestProperties {
shared,
proxy_url: ctx.env.get("BOUNDARY_PROXY_URL").map(|s| s.to_string()),
})
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
use crate::internal::llm_client::properties_hander::SharedProperties;

/// Request configuration for the Anthropic shorthand provider.
pub struct PostRequestProperties {
// Properties common across providers (base URL, API key, headers, ...).
pub shared: SharedProperties,
// Proxy target read from BOUNDARY_PROXY_URL, if set.
pub proxy_url: Option<String>,
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use web_time::Instant;
use web_time::SystemTime;

use crate::internal::llm_client::traits::{ToProviderMessageExt, WithClientProperties};
use crate::internal::llm_client::AllowedMetadata;
use crate::internal::llm_client::{AllowedMetadata, SupportedRequestModes};
use crate::internal::llm_client::{
primitive::request::RequestBuilder,
traits::{
Expand All @@ -44,6 +44,7 @@ struct RequestProperties {

request_options: HashMap<String, serde_json::Value>,
ctx_env: HashMap<String, String>,
supported_request_modes: SupportedRequestModes,
}

// represents client that interacts with the Anthropic API
Expand Down Expand Up @@ -88,6 +89,8 @@ fn resolve_properties(client: &ClientWalker, ctx: &RuntimeContext) -> Result<Req
.remove_str("region")
.unwrap_or_else(|_| ctx.env.get("AWS_REGION").map(|s| s.to_string()));

let supported_request_modes = properties.pull_supported_request_modes()?;

Ok(RequestProperties {
model_id,
aws_region,
Expand All @@ -96,6 +99,7 @@ fn resolve_properties(client: &ClientWalker, ctx: &RuntimeContext) -> Result<Req
allowed_metadata,
request_options: properties.finalize(),
ctx_env: ctx.env.clone(),
supported_request_modes,
})
}

Expand Down Expand Up @@ -301,6 +305,9 @@ impl WithClientProperties for AwsClient {
fn allowed_metadata(&self) -> &crate::internal::llm_client::AllowedMetadata {
&self.properties.allowed_metadata
}
/// Whether this AWS Bedrock client should issue streaming requests.
/// Honors an explicit `supports_streaming` client option; defaults to `true`.
fn supports_streaming(&self) -> bool {
self.properties.supported_request_modes.stream.unwrap_or(true)
}
}

impl WithClient for AwsClient {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::internal::llm_client::properties_hander::{PropertiesHandler};
use crate::internal::llm_client::traits::{
ToProviderMessage, ToProviderMessageExt, WithClientProperties,
};
use crate::internal::llm_client::{AllowedMetadata, ResolveMediaUrls};
use crate::internal::llm_client::{AllowedMetadata, ResolveMediaUrls, SupportedRequestModes};
use crate::RuntimeContext;
use crate::{
internal::llm_client::{
Expand Down Expand Up @@ -38,6 +38,7 @@ struct PostRequestProperities {
model_id: Option<String>,
properties: HashMap<String, serde_json::Value>,
allowed_metadata: AllowedMetadata,
supported_request_modes: SupportedRequestModes,
}

pub struct GoogleAIClient {
Expand Down Expand Up @@ -69,15 +70,18 @@ fn resolve_properties(
let allowed_metadata = properties.pull_allowed_role_metadata()?;
let headers = properties.pull_headers()?;

let supported_request_modes = properties.pull_supported_request_modes()?;

Ok(PostRequestProperities {
default_role,
api_key,
headers,
properties: properties.finalize(),
base_url,
model_id: Some(model_id),
proxy_url: ctx.env.get("BOUNDARY_PROXY_URL").map(|s| s.to_string()),
model_id: Some(model_id),
properties: properties.finalize(),
allowed_metadata,
supported_request_modes,
})
}

Expand All @@ -94,6 +98,9 @@ impl WithClientProperties for GoogleAIClient {
fn allowed_metadata(&self) -> &crate::internal::llm_client::AllowedMetadata {
&self.properties.allowed_metadata
}
/// Whether this Google AI client should issue streaming requests.
/// Honors an explicit `supports_streaming` client option; defaults to `true`.
fn supports_streaming(&self) -> bool {
self.properties.supported_request_modes.stream.unwrap_or(true)
}
}

impl WithClient for GoogleAIClient {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
use std::collections::HashMap;
use anyhow::Result;
use crate::{
internal::llm_client::{properties_hander::PropertiesHandler, SharedProperties},
RuntimeContext,
};
use super::PostRequestProperties;

/// Resolves the post-request configuration for the Google AI shorthand client.
///
/// Pulls the shared client properties (default role "user"), then layers in
/// Google-specific defaults:
/// - `base_url` falls back to "https://generativelanguage.googleapis.com/v1beta"
/// - `api_key` falls back to the `GOOGLE_API_KEY` environment variable
/// - `model` falls back to "gemini-1.5-flash"
///
/// The proxy URL is read from `BOUNDARY_PROXY_URL` if set.
pub fn resolve_properties(
    mut properties: PropertiesHandler,
    ctx: &RuntimeContext,
) -> Result<PostRequestProperties> {
    let shared = properties.pull_shared_properties("user");

    // Layer Google-specific defaults over the user-supplied shared config.
    let base_url = shared.base_url.map(|url| {
        url.unwrap_or_else(|| "https://generativelanguage.googleapis.com/v1beta".to_string())
    });
    let api_key = shared
        .api_key
        .map(|key| key.or_else(|| ctx.env.get("GOOGLE_API_KEY").map(|s| s.to_string())));
    let shared = SharedProperties {
        base_url,
        api_key,
        ..shared
    };

    // Fall back to a sensible default model when none was configured.
    let model_id = properties
        .remove_str("model")?
        .unwrap_or_else(|| "gemini-1.5-flash".to_string());

    Ok(PostRequestProperties {
        shared,
        proxy_url: ctx.env.get("BOUNDARY_PROXY_URL").map(|s| s.to_string()),
        model_id: Some(model_id),
    })
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
use crate::internal::llm_client::properties_hander::SharedProperties;

/// Request configuration for the Google AI shorthand provider.
pub struct PostRequestProperties {
// Properties common across providers (base URL, API key, ...).
pub shared: SharedProperties,
// Proxy target read from BOUNDARY_PROXY_URL, if set.
pub proxy_url: Option<String>,
// Model to use; defaults to "gemini-1.5-flash" when not configured.
pub model_id: Option<String>,
}
3 changes: 3 additions & 0 deletions engine/baml-runtime/src/internal/llm_client/primitive/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ impl WithClientProperties for LLMPrimitiveProvider {
fn allowed_metadata(&self) -> &super::AllowedMetadata {
match_llm_provider!(self, allowed_metadata)
}
/// Delegates `supports_streaming` to whichever concrete provider this
/// enum variant wraps (via the provider-dispatch macro).
fn supports_streaming(&self) -> bool {
match_llm_provider!(self, supports_streaming)
}
}

impl TryFrom<(&ClientProperty, &RuntimeContext)> for LLMPrimitiveProvider {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,20 @@ impl WithClientProperties for OpenAIClient {
fn allowed_metadata(&self) -> &crate::internal::llm_client::AllowedMetadata {
&self.properties.allowed_metadata
}
/// Whether this OpenAI client should issue streaming requests.
///
/// An explicit `supports_streaming` client option always wins. Otherwise the
/// answer is inferred from the configured model name: `o1-*` models do not
/// support streaming, and every other model is assumed to.
fn supports_streaming(&self) -> bool {
    // Explicit user configuration takes precedence over any inference.
    if let Some(explicit) = self.properties.supported_request_modes.stream {
        return explicit;
    }
    match self.properties.properties.get("model") {
        // OpenAI's streaming is not available for o1-* models
        Some(serde_json::Value::String(model)) => !model.starts_with("o1-"),
        // No model configured (or non-string value): assume streaming works.
        _ => true,
    }
}
}

impl WithClient for OpenAIClient {
Expand Down Expand Up @@ -228,16 +242,11 @@ impl RequestBuilder for OpenAIClient {
allow_proxy: bool,
stream: bool,
) -> Result<reqwest::RequestBuilder> {
// Never proxy requests to Ollama
let allow_proxy = allow_proxy
&& self.properties.proxy_url.is_some()
&& !self.properties.base_url.starts_with("http://localhost");

let destination_url = if allow_proxy {
self.properties
.proxy_url
.as_ref()
.unwrap_or(&self.properties.base_url)
.unwrap_or_else(|| &self.properties.base_url)
} else {
&self.properties.base_url
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,18 @@ pub fn resolve_properties(
query_params.insert("api-version".to_string(), v.to_string());
};

let mut properties = properties.finalize();
properties
.entry("max_tokens".into())
let supported_request_modes = properties.pull_supported_request_modes()?;


let properties = {
let mut properties = properties.finalize();
// Azure has very low default max_tokens, so we set it to 4096
properties
.entry("max_tokens".into())
.or_insert_with(|| 4096.into());
let properties = properties;
properties
};


Ok(PostRequestProperties {
default_role,
Expand All @@ -61,9 +68,8 @@ pub fn resolve_properties(
headers,
properties,
allowed_metadata,
// Replace proxy_url with code below to disable proxying
// proxy_url: None,
proxy_url: ctx.env.get("BOUNDARY_PROXY_URL").map(|s| s.to_string()),
query_params,
supported_request_modes,
})
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ pub fn resolve_properties(
Some(api_key) if !api_key.is_empty() => Some(api_key),
_ => None,
};
let supported_request_modes = properties.pull_supported_request_modes()?;

let properties = properties.finalize();

Expand All @@ -36,12 +37,9 @@ pub fn resolve_properties(
api_key,
headers,
properties,
proxy_url: ctx
.env
.get("BOUNDARY_PROXY_URL")
.map(|s| Some(s.to_string()))
.unwrap_or(None),
proxy_url: ctx.env.get("BOUNDARY_PROXY_URL").map(|s| s.to_string()),
query_params: Default::default(),
allowed_metadata,
supported_request_modes,
})
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ pub(crate) mod generic;
pub(crate) mod ollama;
pub(crate) mod openai;

use crate::internal::llm_client::AllowedMetadata;
use crate::internal::llm_client::{AllowedMetadata, SupportedRequestModes};
use std::collections::HashMap;

pub struct PostRequestProperties {
Expand All @@ -16,4 +16,5 @@ pub struct PostRequestProperties {
// These are passed directly to the OpenAI API.
pub properties: HashMap<String, serde_json::Value>,
pub allowed_metadata: AllowedMetadata,
pub supported_request_modes: SupportedRequestModes,
}
Loading

0 comments on commit 09c6549

Please sign in to comment.