Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bug fixes to cURL rendering and mime_type overriding #763

Merged
merged 12 commits into from
Jul 8, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ impl TryInto<bedrock::types::Message> for AwsChatMessage<'_> {
ChatMessagePart::Image(media) | ChatMessagePart::Audio(media) => match media {
BamlMedia::Url(_, _) => {
anyhow::bail!(
"BAML internal error: media URL should have been resolved to base64"
"BAML internal error (AWSBedrock): media URL should have been resolved to base64"
)
}
BamlMedia::Base64(BamlMediaType::Image, media) => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,9 @@ impl RequestBuilder for OpenAIClient {
)
});

log::info!("Base url: {}", &self.properties.base_url);
log::info!("Proxy url: {:?}", &self.properties.proxy_url);

if !self.properties.query_params.is_empty() {
req = req.query(&self.properties.query_params);
}
Expand All @@ -248,10 +251,18 @@ impl RequestBuilder for OpenAIClient {
}
req = req.header("baml-original-url", self.properties.base_url.as_str());

req = req.header(
"baml-render-url",
format!("{}/chat/completions", self.properties.base_url),
);
let mut url =
reqwest::Url::parse(&format!("{}/chat/completions", self.properties.base_url))?;
{
let mut pairs = url.query_pairs_mut();
for (key, value) in &self.properties.query_params {
pairs.append_pair(key, value);
}
}
let full_url_with_query = url.to_string();

req = req.header("baml-render-url", full_url_with_query);

let mut body = json!(self.properties.properties);
let body_obj = body.as_object_mut().unwrap();
match prompt {
Expand Down
29 changes: 19 additions & 10 deletions engine/baml-runtime/src/internal/llm_client/traits/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,15 @@ where
match part {
ChatMessagePart::Image(BamlMedia::Url(_, media_url))
| ChatMessagePart::Audio(BamlMedia::Url(_, media_url)) => {
let (base64, mime_type) = if media_url.url.starts_with("data:") {



let (base64, mut mime_type) = if media_url.url.starts_with("data:") {
let parts: Vec<&str> =
media_url.url.splitn(2, ',').collect();
let base64 = parts.get(1).unwrap().to_string();
let prefix = parts.get(0).unwrap();
let mime_type =
let mut mime_type =
prefix.splitn(2, ':').next().unwrap().to_string()
.split('/').last().unwrap().to_string();

Expand Down Expand Up @@ -125,13 +128,17 @@ where
};
let base64 = BASE64_STANDARD.encode(&bytes);
let inferred_type = infer::get(&bytes);
let mime_type = inferred_type.map_or_else(
let mut mime_type = inferred_type.map_or_else(
|| "application/octet-stream".into(),
|t| t.extension().into(),
);
(base64, mime_type)
};

if let Some(media_type) = &media_url.media_type {
mime_type = media_type.clone().split('/').last().unwrap().to_string();
}

Ok(if matches!(part, ChatMessagePart::Image(_)) {
ChatMessagePart::Image(BamlMedia::Base64(
BamlMediaType::Image,
Expand Down Expand Up @@ -403,7 +410,7 @@ where
stream: bool,
) -> Result<String> {
let rendered_prompt = RenderedPrompt::Chat(prompt.clone());

log::info!("Stream is {}", stream);
let chat_messages = self.curl_call(ctx, &rendered_prompt).await?;
let request_builder = self
.build_request(either::Right(&chat_messages), stream)
Expand All @@ -420,14 +427,16 @@ where
let url_str = url_header_value
.to_str()
.map_err(|_| anyhow::anyhow!("Invalid header 'baml-render-url'"))?;
let new_url = Url::from_str(url_str)?;
let mut new_url = Url::from_str(url_str)?;
new_url.set_query(request.url().query()); // Preserve query parameters

*request.url_mut() = new_url;

{
let headers = request.headers_mut();
headers.remove("baml-original-url");
headers.remove("baml-render-url");
}

let headers = request.headers_mut();
headers.remove("baml-original-url");
headers.remove("baml-render-url");


let body = request
.body()
Expand Down
2 changes: 1 addition & 1 deletion engine/baml-schema-wasm/src/runtime_wasm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use wasm_bindgen::prelude::*;

#[wasm_bindgen(start)]
pub fn on_wasm_init() {
match console_log::init_with_level(log::Level::Warn) {
match console_log::init_with_level(log::Level::Info) {
Ok(_) => web_sys::console::log_1(&"Initialized BAML runtime logging".into()),
Err(e) => web_sys::console::log_1(
&format!("Failed to initialize BAML runtime logging: {:?}", e).into(),
Expand Down
4 changes: 0 additions & 4 deletions engine/baml-schema-wasm/src/runtime_wasm/runtime_prompt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,6 @@ impl WasmPrompt {
#[wasm_bindgen]
pub fn as_chat(&self) -> Option<Vec<WasmChatMessage>> {
if let RenderedPrompt::Chat(s) = &self.prompt {
log::info!(
"Chat role: {:?}",
s.iter().map(|m| m.role.clone()).collect::<Vec<_>>()
);
Some(
s.iter()
.map(|m| WasmChatMessage {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,8 @@ export const availableFunctionsAtom = atom((get) => {
return runtime.list_functions()
})

export const streamCurl = atom(true)

const asyncCurlAtom = atom(async (get) => {
const runtime = get(selectedRuntimeAtom)
const func = get(selectedFunctionAtom)
Expand All @@ -360,10 +362,10 @@ const asyncCurlAtom = atom(async (get) => {
.map((input) => [input.name, JSON.parse(input.value)]),
)
try {
return await func.render_raw_curl(runtime, params, false)
return await func.render_raw_curl(runtime, params, get(streamCurl))
} catch (e) {
console.error(e)
return 'Error rendering curl command'
return `${e}`
}
})

Expand Down
15 changes: 12 additions & 3 deletions typescript/playground-common/src/shared/FunctionPanel.tsx
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/// Content once a function has been selected.
import { useAppState } from './AppStateContext'
import { useAtomValue, useSetAtom } from 'jotai'
import { renderPromptAtom, selectedFunctionAtom, curlAtom } from '../baml_wasm_web/EventListener'
import { renderPromptAtom, selectedFunctionAtom, curlAtom, streamCurl } from '../baml_wasm_web/EventListener'
import TestResults from '../baml_wasm_web/test_uis/test_result'
import { ResizableHandle, ResizablePanel, ResizablePanelGroup } from '../components/ui/resizable'
import { TooltipProvider } from '../components/ui/tooltip'
Expand All @@ -10,6 +10,7 @@ import FunctionTestSnippet from './TestSnippet'
import { Copy } from 'lucide-react'
import { Button } from '../components/ui/button'
import { CheckboxHeader } from './CheckboxHeader'
import { Switch } from '../components/ui/switch'
import CustomErrorBoundary from '../utils/ErrorFallback'
const handleCopy = (text: string) => () => {
navigator.clipboard.writeText(text)
Expand All @@ -20,10 +21,18 @@ const CurlSnippet: React.FC = () => {

return (
<div>
<div className='flex justify-end'>
<div className='flex justify-end items-center space-x-2 p-2 rounded-md shadow-sm'>
<label className='flex items-center space-x-1 mr-2'>
<Switch
className='data-[state=checked]:bg-vscode-button-background'
checked={useAtomValue(streamCurl)}
onCheckedChange={useSetAtom(streamCurl)}
/>
<span>View Stream Request</span>
</label>
<Button
onClick={handleCopy(rawCurl)}
className='py-0 m-0 text-xs text-white bg-transparent copy-button hover:bg-indigo-500'
className='py-1 px-3 text-xs text-white bg-indigo-600 rounded-md hover:bg-indigo-500'
>
<Copy size={16} />
</Button>
Expand Down
Loading