Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bug fixes to cURL rendering and mime_type overriding #763

Merged
merged 12 commits into from
Jul 8, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -335,26 +335,21 @@ impl RequestBuilder for AnthropicClient {
async fn build_request(
&self,
prompt: either::Either<&String, &Vec<RenderedChatMessage>>,
should_proxy: bool,
stream: bool,
) -> Result<reqwest::RequestBuilder> {
let destiniation_url = if should_proxy {
self.properties
.proxy_url
.as_ref()
.unwrap_or(&self.properties.base_url)
} else {
&self.properties.base_url
};
let mut req = self.client.post(if prompt.is_left() {
format!(
"{}/v1/complete",
self.properties
.proxy_url
.as_ref()
.unwrap_or(&self.properties.base_url)
.clone()
)
format!("{}/v1/complete", destiniation_url)
} else {
format!(
"{}/v1/messages",
self.properties
.proxy_url
.as_ref()
.unwrap_or(&self.properties.base_url)
.clone()
)
format!("{}/v1/messages", destiniation_url)
});

for (key, value) in &self.properties.headers {
Expand All @@ -364,12 +359,9 @@ impl RequestBuilder for AnthropicClient {
req = req.header("x-api-key", key);
}

req = req.header("baml-original-url", self.properties.base_url.as_str());
req = req.header(
"baml-render-url",
format!("{}/v1/messages", self.properties.base_url),
);

if should_proxy {
req = req.header("baml-original-url", self.properties.base_url.as_str());
}
let mut body = json!(self.properties.properties);
let body_obj = body.as_object_mut().unwrap();
match prompt {
Expand All @@ -384,7 +376,6 @@ impl RequestBuilder for AnthropicClient {
if stream {
body_obj.insert("stream".into(), true.into());
}
log::debug!("Request body: {:#?}", body);

Ok(req.json(&body))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ impl AwsClient {
fn build_request(
&self,
ctx: &RuntimeContext,

chat_messages: &Vec<RenderedChatMessage>,
) -> Result<bedrock::operation::converse::ConverseInput> {
let mut system_message = None;
Expand Down Expand Up @@ -511,7 +512,7 @@ impl TryInto<bedrock::types::Message> for AwsChatMessage<'_> {
ChatMessagePart::Image(media) | ChatMessagePart::Audio(media) => match media {
BamlMedia::Url(_, _) => {
anyhow::bail!(
"BAML internal error: media URL should have been resolved to base64"
"BAML internal error (AWSBedrock): media URL should have been resolved to base64"
)
}
BamlMedia::Base64(BamlMediaType::Image, media) => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -276,34 +276,46 @@ impl RequestBuilder for GoogleClient {
async fn build_request(
&self,
prompt: either::Either<&String, &Vec<RenderedChatMessage>>,
should_proxy: bool,
stream: bool,
) -> Result<reqwest::RequestBuilder> {
let mut should_stream = "generateContent";
if stream {
should_stream = "streamGenerateContent?alt=sse";
}

let baml_original_url = format!(
"{}/models/{}:{}",
self.properties.base_url,
self.properties.model_id.as_ref().unwrap_or(&"".to_string()),
should_stream
);

let mut req = self.client.post(
let destination_url = if should_proxy {
self.properties
.proxy_url
.as_ref()
.unwrap_or(&baml_original_url)
.clone(),
);
.unwrap_or(&"".to_string())
.clone()
} else {
format!(
"{}/models/{}:{}",
self.properties.base_url,
self.properties.model_id.as_ref().unwrap_or(&"".to_string()),
should_stream
)
};

let mut req = self.client.post(destination_url);

for (key, value) in &self.properties.headers {
req = req.header(key, value);
}

req = req.header("baml-original-url", baml_original_url.clone());
req = req.header("baml-render-url", baml_original_url);
if should_proxy {
let baml_original_url = format!(
"{}/models/{}:{}",
self.properties.base_url,
self.properties.model_id.as_ref().unwrap_or(&"".to_string()),
should_stream
);

req = req.header("baml-original-url", baml_original_url.clone());
}

req = req.header(
"x-goog-api-key",
self.properties
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,26 +214,22 @@ impl RequestBuilder for OpenAIClient {
async fn build_request(
&self,
prompt: either::Either<&String, &Vec<RenderedChatMessage>>,
should_proxy: bool,

stream: bool,
) -> Result<reqwest::RequestBuilder> {
let destination_url = if should_proxy {
self.properties
.proxy_url
.as_ref()
.unwrap_or(&self.properties.base_url)
} else {
&self.properties.base_url
};
let mut req = self.client.post(if prompt.is_left() {
format!(
"{}/completions",
self.properties
.proxy_url
.as_ref()
.unwrap_or(&self.properties.base_url)
.clone()
)
format!("{}/completions", destination_url)
} else {
format!(
"{}/chat/completions",
self.properties
.proxy_url
.as_ref()
.unwrap_or(&self.properties.base_url)
.clone()
)
format!("{}/chat/completions", destination_url)
});

if !self.properties.query_params.is_empty() {
Expand All @@ -246,12 +242,11 @@ impl RequestBuilder for OpenAIClient {
if let Some(key) = &self.properties.api_key {
req = req.bearer_auth(key)
}
req = req.header("baml-original-url", self.properties.base_url.as_str());

req = req.header(
"baml-render-url",
format!("{}/chat/completions", self.properties.base_url),
);
if should_proxy {
req = req.header("baml-original-url", self.properties.base_url.as_str());
}

let mut body = json!(self.properties.properties);
let body_obj = body.as_object_mut().unwrap();
match prompt {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub trait RequestBuilder {
/// Builds the HTTP request for `prompt` and returns it without sending it.
///
/// * `prompt` - `Left` carries a plain completion prompt, `Right` carries the rendered
///   chat messages. The Anthropic and OpenAI implementations in this change choose the
///   endpoint path based on which side is present.
/// * `should_proxy` - when `true`, the implementations in this change target the
///   configured `proxy_url` (Anthropic and OpenAI fall back to `base_url` when none is
///   set) and attach a `baml-original-url` header; when `false` they target the
///   provider URL directly and do not add that header. `make_request` passes `true`;
///   the raw cURL rendering path passes `false`.
/// * `stream` - asks for the streaming variant of the request (e.g. Anthropic inserts
///   `"stream": true` into the body, Google switches to `streamGenerateContent`).
///
/// Returns the `reqwest::RequestBuilder`, or an error if the request cannot be built.
async fn build_request(
&self,
prompt: either::Either<&String, &Vec<RenderedChatMessage>>,
should_proxy: bool,
stream: bool,
) -> Result<reqwest::RequestBuilder>;

Expand All @@ -38,7 +39,7 @@ pub async fn make_request(
log::debug!("Making request using client {}", client.context().name);

let req = match client
.build_request(prompt, stream)
.build_request(prompt, true, stream)
.await
.context("Failed to build request")
{
Expand Down
40 changes: 19 additions & 21 deletions engine/baml-runtime/src/internal/llm_client/traits/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,15 @@ where
match part {
ChatMessagePart::Image(BamlMedia::Url(_, media_url))
| ChatMessagePart::Audio(BamlMedia::Url(_, media_url)) => {
let (base64, mime_type) = if media_url.url.starts_with("data:") {



let (base64, mut mime_type) = if media_url.url.starts_with("data:") {
let parts: Vec<&str> =
media_url.url.splitn(2, ',').collect();
let base64 = parts.get(1).unwrap().to_string();
let prefix = parts.get(0).unwrap();
let mime_type =
let mut mime_type =
prefix.splitn(2, ':').next().unwrap().to_string()
.split('/').last().unwrap().to_string();

Expand Down Expand Up @@ -125,13 +128,17 @@ where
};
let base64 = BASE64_STANDARD.encode(&bytes);
let inferred_type = infer::get(&bytes);
let mime_type = inferred_type.map_or_else(
let mut mime_type = inferred_type.map_or_else(
|| "application/octet-stream".into(),
|t| t.extension().into(),
);
(base64, mime_type)
};

if let Some(media_type) = &media_url.media_type {
mime_type = media_type.clone().split('/').last().unwrap().to_string();
}

Ok(if matches!(part, ChatMessagePart::Image(_)) {
ChatMessagePart::Image(BamlMedia::Base64(
BamlMediaType::Image,
Expand Down Expand Up @@ -403,37 +410,28 @@ where
stream: bool,
) -> Result<String> {
let rendered_prompt = RenderedPrompt::Chat(prompt.clone());

let chat_messages = self.curl_call(ctx, &rendered_prompt).await?;
let request_builder = self
.build_request(either::Right(&chat_messages), stream)
.build_request(either::Right(&chat_messages),false, stream)
.await?;
let mut request = request_builder.build()?;
let request: reqwest::Request = request_builder.build()?;
let url_header_value = {
let headers = request.headers_mut();
let url_header_value = headers
.get("baml-render-url")
.ok_or(anyhow::anyhow!("Missing header 'baml-render-url'"))?;

let url_header_value = request.url();
url_header_value.to_owned()
};

let url_str = url_header_value
.to_str()
.map_err(|_| anyhow::anyhow!("Invalid header 'baml-render-url'"))?;
let new_url = Url::from_str(url_str)?;
*request.url_mut() = new_url;

{
let headers = request.headers_mut();
headers.remove("baml-original-url");
headers.remove("baml-render-url");
}
.to_string();



let body = request
.body()
.map(|b| b.as_bytes().unwrap_or_default().to_vec())
.unwrap_or_default(); // Add this line to handle the Option
let request_str = to_curl_command(url_str, "POST", request.headers(), body);
let request_str = to_curl_command(url_str.as_str(), "POST", request.headers(), body);

Ok(request_str)
}
Expand Down
4 changes: 0 additions & 4 deletions engine/baml-schema-wasm/src/runtime_wasm/runtime_prompt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,6 @@ impl WasmPrompt {
#[wasm_bindgen]
pub fn as_chat(&self) -> Option<Vec<WasmChatMessage>> {
if let RenderedPrompt::Chat(s) = &self.prompt {
log::info!(
"Chat role: {:?}",
s.iter().map(|m| m.role.clone()).collect::<Vec<_>>()
);
Some(
s.iter()
.map(|m| WasmChatMessage {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,8 @@ export const availableFunctionsAtom = atom((get) => {
return runtime.list_functions()
})

// Toggle for the cURL preview: its value is passed as the last argument to
// `render_raw_curl`, and it is bound to the "View Stream Request" switch.
// Starts out as `true`.
export const streamCurl = atom(true)

const asyncCurlAtom = atom(async (get) => {
const runtime = get(selectedRuntimeAtom)
const func = get(selectedFunctionAtom)
Expand All @@ -360,10 +362,10 @@ const asyncCurlAtom = atom(async (get) => {
.map((input) => [input.name, JSON.parse(input.value)]),
)
try {
return await func.render_raw_curl(runtime, params, false)
return await func.render_raw_curl(runtime, params, get(streamCurl))
} catch (e) {
console.error(e)
return 'Error rendering curl command'
return `${e}`
}
})

Expand Down
15 changes: 12 additions & 3 deletions typescript/playground-common/src/shared/FunctionPanel.tsx
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/// Content once a function has been selected.
import { useAppState } from './AppStateContext'
import { useAtomValue, useSetAtom } from 'jotai'
import { renderPromptAtom, selectedFunctionAtom, curlAtom } from '../baml_wasm_web/EventListener'
import { renderPromptAtom, selectedFunctionAtom, curlAtom, streamCurl } from '../baml_wasm_web/EventListener'
import TestResults from '../baml_wasm_web/test_uis/test_result'
import { ResizableHandle, ResizablePanel, ResizablePanelGroup } from '../components/ui/resizable'
import { TooltipProvider } from '../components/ui/tooltip'
Expand All @@ -10,6 +10,7 @@ import FunctionTestSnippet from './TestSnippet'
import { Copy } from 'lucide-react'
import { Button } from '../components/ui/button'
import { CheckboxHeader } from './CheckboxHeader'
import { Switch } from '../components/ui/switch'
import CustomErrorBoundary from '../utils/ErrorFallback'
const handleCopy = (text: string) => () => {
navigator.clipboard.writeText(text)
Expand All @@ -20,10 +21,18 @@ const CurlSnippet: React.FC = () => {

return (
<div>
<div className='flex justify-end'>
<div className='flex justify-end items-center space-x-2 p-2 rounded-md shadow-sm'>
<label className='flex items-center space-x-1 mr-2'>
<Switch
className='data-[state=checked]:bg-vscode-button-background'
checked={useAtomValue(streamCurl)}
onCheckedChange={useSetAtom(streamCurl)}
/>
<span>View Stream Request</span>
</label>
<Button
onClick={handleCopy(rawCurl)}
className='py-0 m-0 text-xs text-white bg-transparent copy-button hover:bg-indigo-500'
className='py-1 px-3 text-xs text-white bg-indigo-600 rounded-md hover:bg-indigo-500'
>
<Copy size={16} />
</Button>
Expand Down
2 changes: 1 addition & 1 deletion typescript/vscode-ext/packages/vscode/src/extension.ts
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ export function activate(context: vscode.ExtensionContext) {
let originalUrl = req.headers['baml-original-url']
if (typeof originalUrl === 'string') {
delete req.headers['baml-original-url']
delete req.headers['baml-render-url']

req.headers['origin'] = `http://localhost:${port}`

// Ensure the URL does not end with a slash
Expand Down
Loading