Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bug fixes to cURL rendering and mime_type overriding #763

Merged
merged 12 commits into from
Jul 8, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -335,26 +335,21 @@ impl RequestBuilder for AnthropicClient {
async fn build_request(
&self,
prompt: either::Either<&String, &Vec<RenderedChatMessage>>,
allow_proxy: bool,
stream: bool,
) -> Result<reqwest::RequestBuilder> {
let destination_url = if allow_proxy {
self.properties
.proxy_url
.as_ref()
.unwrap_or(&self.properties.base_url)
} else {
&self.properties.base_url
};
let mut req = self.client.post(if prompt.is_left() {
format!(
"{}/v1/complete",
self.properties
.proxy_url
.as_ref()
.unwrap_or(&self.properties.base_url)
.clone()
)
format!("{}/v1/complete", destination_url)
} else {
format!(
"{}/v1/messages",
self.properties
.proxy_url
.as_ref()
.unwrap_or(&self.properties.base_url)
.clone()
)
format!("{}/v1/messages", destination_url)
});

for (key, value) in &self.properties.headers {
Expand All @@ -364,12 +359,9 @@ impl RequestBuilder for AnthropicClient {
req = req.header("x-api-key", key);
}

req = req.header("baml-original-url", self.properties.base_url.as_str());
req = req.header(
"baml-render-url",
format!("{}/v1/messages", self.properties.base_url),
);

if allow_proxy {
req = req.header("baml-original-url", self.properties.base_url.as_str());
}
let mut body = json!(self.properties.properties);
let body_obj = body.as_object_mut().unwrap();
match prompt {
Expand All @@ -384,7 +376,6 @@ impl RequestBuilder for AnthropicClient {
if stream {
body_obj.insert("stream".into(), true.into());
}
log::debug!("Request body: {:#?}", body);

Ok(req.json(&body))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ impl AwsClient {
fn build_request(
&self,
ctx: &RuntimeContext,

chat_messages: &Vec<RenderedChatMessage>,
) -> Result<bedrock::operation::converse::ConverseInput> {
let mut system_message = None;
Expand Down Expand Up @@ -511,7 +512,7 @@ impl TryInto<bedrock::types::Message> for AwsChatMessage<'_> {
ChatMessagePart::Image(media) | ChatMessagePart::Audio(media) => match media {
BamlMedia::Url(_, _) => {
anyhow::bail!(
"BAML internal error: media URL should have been resolved to base64"
"BAML internal error (AWSBedrock): media URL should have been resolved to base64"
)
}
BamlMedia::Base64(BamlMediaType::Image, media) => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -276,34 +276,46 @@ impl RequestBuilder for GoogleClient {
async fn build_request(
&self,
prompt: either::Either<&String, &Vec<RenderedChatMessage>>,
allow_proxy: bool,
stream: bool,
) -> Result<reqwest::RequestBuilder> {
let mut should_stream = "generateContent";
if stream {
should_stream = "streamGenerateContent?alt=sse";
}

let baml_original_url = format!(
"{}/models/{}:{}",
self.properties.base_url,
self.properties.model_id.as_ref().unwrap_or(&"".to_string()),
should_stream
);

let mut req = self.client.post(
let destination_url = if allow_proxy {
self.properties
.proxy_url
.as_ref()
.unwrap_or(&baml_original_url)
.clone(),
);
.unwrap_or(&"".to_string())
.clone()
} else {
format!(
"{}/models/{}:{}",
self.properties.base_url,
self.properties.model_id.as_ref().unwrap_or(&"".to_string()),
should_stream
)
};

let mut req = self.client.post(destination_url);

for (key, value) in &self.properties.headers {
req = req.header(key, value);
}

req = req.header("baml-original-url", baml_original_url.clone());
req = req.header("baml-render-url", baml_original_url);
if allow_proxy {
let baml_original_url = format!(
"{}/models/{}:{}",
self.properties.base_url,
self.properties.model_id.as_ref().unwrap_or(&"".to_string()),
should_stream
);

req = req.header("baml-original-url", baml_original_url.clone());
}

req = req.header(
"x-goog-api-key",
self.properties
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,26 +214,22 @@ impl RequestBuilder for OpenAIClient {
async fn build_request(
&self,
prompt: either::Either<&String, &Vec<RenderedChatMessage>>,
allow_proxy: bool,

stream: bool,
) -> Result<reqwest::RequestBuilder> {
let destination_url = if allow_proxy {
self.properties
.proxy_url
.as_ref()
.unwrap_or(&self.properties.base_url)
} else {
&self.properties.base_url
};
let mut req = self.client.post(if prompt.is_left() {
format!(
"{}/completions",
self.properties
.proxy_url
.as_ref()
.unwrap_or(&self.properties.base_url)
.clone()
)
format!("{}/completions", destination_url)
} else {
format!(
"{}/chat/completions",
self.properties
.proxy_url
.as_ref()
.unwrap_or(&self.properties.base_url)
.clone()
)
format!("{}/chat/completions", destination_url)
});

if !self.properties.query_params.is_empty() {
Expand All @@ -246,12 +242,11 @@ impl RequestBuilder for OpenAIClient {
if let Some(key) = &self.properties.api_key {
req = req.bearer_auth(key)
}
req = req.header("baml-original-url", self.properties.base_url.as_str());

req = req.header(
"baml-render-url",
format!("{}/chat/completions", self.properties.base_url),
);
if allow_proxy {
req = req.header("baml-original-url", self.properties.base_url.as_str());
}

let mut body = json!(self.properties.properties);
let body_obj = body.as_object_mut().unwrap();
match prompt {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub trait RequestBuilder {
async fn build_request(
&self,
prompt: either::Either<&String, &Vec<RenderedChatMessage>>,
allow_proxy: bool,
stream: bool,
) -> Result<reqwest::RequestBuilder>;

Expand All @@ -38,7 +39,7 @@ pub async fn make_request(
log::debug!("Making request using client {}", client.context().name);

let req = match client
.build_request(prompt, stream)
.build_request(prompt, true, stream)
.await
.context("Failed to build request")
{
Expand Down
Loading
Loading