Skip to content

Commit

Permalink
Add support for absolute file paths in the file specifier (#881)
Browse files Browse the repository at this point in the history
  • Loading branch information
hellovai authored Aug 15, 2024
1 parent a8d3479 commit fcd189e
Show file tree
Hide file tree
Showing 10 changed files with 181 additions and 98 deletions.
2 changes: 1 addition & 1 deletion engine/baml-runtime/src/internal/llm_client/traits/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ where
RenderedPrompt::Chat(chat) => {
// We never need to resolve media URLs here: webview rendering understands how to handle URLs and file refs
let chat =
process_media_urls(ResolveMediaUrls::Never, false, None, ctx, &chat).await?;
process_media_urls(ResolveMediaUrls::Never, true, None, ctx, &chat).await?;
RenderedPrompt::Chat(chat)
}
};
Expand Down
6 changes: 5 additions & 1 deletion engine/baml-schema-wasm/src/runtime_wasm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1371,11 +1371,15 @@ impl WasmFunction {
rt: &WasmRuntime,
test_name: String,
wasm_call_context: &WasmCallContext,
get_baml_src_cb: js_sys::Function,
) -> JsResult<WasmPrompt> {
let missing_env_vars = rt.runtime.internal().ir().required_env_vars();
let ctx = rt
.runtime
.create_ctx_manager(BamlValue::String("wasm".to_string()), None)
.create_ctx_manager(
BamlValue::String("wasm".to_string()),
js_fn_to_baml_src_reader(get_baml_src_cb),
)
.create_ctx_with_default(missing_env_vars.iter());

let params = rt
Expand Down
63 changes: 38 additions & 25 deletions engine/baml-schema-wasm/src/runtime_wasm/runtime_prompt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,20 @@ impl From<ChatMessagePart> for WasmChatMessagePart {
}
}

#[wasm_bindgen]
#[derive(Clone, Copy)]
pub enum WasmChatMessagePartMediaType {
Url,
File,
Error,
}

#[wasm_bindgen(getter_with_clone)]
pub struct WasmChatMessagePartMedia {
pub r#type: WasmChatMessagePartMediaType,
pub content: String,
}

#[wasm_bindgen]
impl WasmChatMessagePart {
#[wasm_bindgen]
Expand Down Expand Up @@ -94,35 +108,34 @@ impl WasmChatMessagePart {
}

#[wasm_bindgen]
// TODO: this needs to signal to TS how it should be rendered
// currently we're only rendering file paths, but also need to support url & b64
pub fn as_media(&self) -> JsValue {
pub fn as_media(&self) -> Option<WasmChatMessagePartMedia> {
let ChatMessagePart::Media(m) = &self.part else {
return JsValue::NULL;
return None;
};
match &m.content {
BamlMediaContent::Url(u) => json!({
"type": "url",
"url": u.url.clone(),

}),
BamlMediaContent::Base64(MediaBase64 { base64 }) => json!({
"type": "url",
"url": format!("data:{};base64,{}", m.mime_type.as_deref().unwrap_or(""), base64.clone())
}),
Some(match &m.content {
BamlMediaContent::Url(u) => WasmChatMessagePartMedia {
r#type: WasmChatMessagePartMediaType::Url,
content: u.url.clone(),
},
BamlMediaContent::Base64(MediaBase64 { base64 }) => WasmChatMessagePartMedia {
r#type: WasmChatMessagePartMediaType::Url,
content: format!(
"data:{};base64,{}",
m.mime_type.as_deref().unwrap_or("type/unknown"),
base64.clone()
),
},
BamlMediaContent::File(f) => match f.path() {
Ok(path) => json!({
"type": "path",
"path": path.to_string_lossy().into_owned(),
}),
Err(e) => json!({
"type": "error",
"error": format!("Error resolving file '{}': {:#}", f.relpath.display(), e),
}),
Ok(path) => WasmChatMessagePartMedia {
r#type: WasmChatMessagePartMediaType::File,
content: path.to_string_lossy().into_owned(),
},
Err(e) => WasmChatMessagePartMedia {
r#type: WasmChatMessagePartMediaType::Error,
content: format!("Error resolving file '{}': {:#}", f.relpath.display(), e),
},
},
}
.serialize(&serde_wasm_bindgen::Serializer::json_compatible())
.unwrap_or(JsValue::NULL)
})
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,9 @@ const renderPromptAtomAsync = atom(async (get) => {
wasmCallContext.node_index = orch_index

try {
return await func.render_prompt_for_test(runtime, test_case.name, wasmCallContext)
return await func.render_prompt_for_test(runtime, test_case.name, wasmCallContext, async (path: string) => {
return await vscode.readFile(path)
})
} catch (e) {
if (e instanceof Error) {
return e.message
Expand Down
57 changes: 57 additions & 0 deletions typescript/playground-common/src/baml_wasm_web/rpc.ts
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,13 @@ export interface GetWebviewUriRequest {
vscodeCommand: 'GET_WEBVIEW_URI'
bamlSrc: string
path: string
contents?: true
}

export interface GetWebviewUriResponse {
uri: string
contents?: string
readError?: string
}

type ApiPairs = [
Expand All @@ -98,4 +101,58 @@ type ApiPairs = [
[GetWebviewUriRequest, GetWebviewUriResponse],
]

// Serialization for binary data (like images)
function serializeBinaryData(uint8Array: Uint8Array): string {
return uint8Array.reduce((data, byte) => data + String.fromCharCode(byte), '')
}

// Deserialization for binary data
function deserializeBinaryData(serialized: string): Uint8Array {
return new Uint8Array(serialized.split('').map((char) => char.charCodeAt(0)))
}

// Base64 encoding
function base64Encode(str: string): string {
const base64chars: string = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/'
let result: string = ''
let i: number
for (i = 0; i < str.length; i += 3) {
const chunk: number = (str.charCodeAt(i) << 16) | (str.charCodeAt(i + 1) << 8) | str.charCodeAt(i + 2)
result +=
base64chars.charAt((chunk & 16515072) >> 18) +
base64chars.charAt((chunk & 258048) >> 12) +
base64chars.charAt((chunk & 4032) >> 6) +
base64chars.charAt(chunk & 63)
}
if (str.length % 3 === 1) result = result.slice(0, -2) + '=='
if (str.length % 3 === 2) result = result.slice(0, -1) + '='
return result
}

// Base64 decoding
function base64Decode(str: string): string {
const base64chars: string = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/'
while (str[str.length - 1] === '=') {
str = str.slice(0, -1)
}
let result: string = ''
for (let i = 0; i < str.length; i += 4) {
const chunk: number =
(base64chars.indexOf(str[i]) << 18) |
(base64chars.indexOf(str[i + 1]) << 12) |
(base64chars.indexOf(str[i + 2]) << 6) |
base64chars.indexOf(str[i + 3])
result += String.fromCharCode((chunk & 16711680) >> 16, (chunk & 65280) >> 8, chunk & 255)
}
return result.slice(0, result.length - (str.length % 4 ? 4 - (str.length % 4) : 0))
}

export function encodeBuffer(arr: Uint8Array): string {
return serializeBinaryData(arr)
}

export function decodeBuffer(str: string): Uint8Array {
return deserializeBinaryData(str)
}

export type WebviewToVscodeRpc = RequestUnion<ApiPairs>
57 changes: 29 additions & 28 deletions typescript/playground-common/src/shared/FunctionPanel.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ import {
streamCurlAtom,
rawCurlLoadable,
} from '../baml_wasm_web/EventListener'
import {
// We _deliberately_ only import types from wasm, instead of importing the module: wasm load is async,
// so we can only load wasm symbols through wasmAtom, not directly by importing wasm-schema-web
type WasmChatMessagePartMedia,
WasmChatMessagePartMediaType,
} from '@gloo-ai/baml-schema-wasm-web/baml_schema_build'
import TestResults from '../baml_wasm_web/test_uis/test_result'
import { ResizableHandle, ResizablePanel, ResizablePanelGroup } from '../components/ui/resizable'
import { TooltipProvider } from '../components/ui/tooltip'
Expand Down Expand Up @@ -124,31 +130,25 @@ const CurlSnippet: React.FC = () => {
)
}

type WasmChatMessagePartMedia =
| {
type: 'url'
url: string
}
| {
type: 'path'
path: string
}

const WebviewMedia: React.FC<{ bamlMediaType: 'image' | 'audio'; media: WasmChatMessagePartMedia }> = ({
bamlMediaType,
media,
}) => {
const pathAsUri = useSWR({ swr: 'WebviewMedia', ...media }, async () => {
const pathAsUri = useSWR({ swr: 'WebviewMedia', type: media.type, content: media.content }, async () => {
switch (media.type) {
case 'path':
const uri = await vscode.asWebviewUri('', media.path)
// Do a manual check to assert that the image exists
if ((await fetch(uri, { method: 'HEAD' })).status !== 200) {
throw new Error('file not found')
}
return uri
case 'url':
return media.url
case WasmChatMessagePartMediaType.File:
// const uri = await vscode.readFile('', media.content)
// // Do a manual check to assert that the image exists
// if ((await fetch(uri, { method: 'HEAD' })).status !== 200) {
// throw new Error('file not found')
// }
return `file://${media.content}`
case WasmChatMessagePartMediaType.Url:
return media.content
case WasmChatMessagePartMediaType.Error:
return { error: media.content }
default:
return { error: 'unknown media type' }
}
})

Expand All @@ -159,7 +159,6 @@ const WebviewMedia: React.FC<{ bamlMediaType: 'image' | 'audio'; media: WasmChat
<div>
Error loading {bamlMediaType}: {error}
</div>
<div>{media.type === 'path' ? media.path.replace('file://', '') : media.url}</div>
</div>
)
}
Expand All @@ -168,7 +167,7 @@ const WebviewMedia: React.FC<{ bamlMediaType: 'image' | 'audio'; media: WasmChat
return <div>Loading {bamlMediaType}...</div>
}

const mediaUrl = pathAsUri.data
const mediaUrl = pathAsUri.data as unknown as string

return (
<div className='p-1'>
Expand Down Expand Up @@ -254,15 +253,17 @@ const PromptPreview: React.FC = () => {
)
if (part.is_image()) {
const media = part.as_media()
if (!media) return <div>Error loading image: this chat message part is not media</div>
if (media.type === 'error') return <div>Error loading image: {media.error}</div>
return <WebviewMedia key={idx} bamlMediaType='image' media={part.as_media()} />
if (!media) return <div key={idx}>Error loading image: this chat message part is not media</div>
if (media.type === WasmChatMessagePartMediaType.Error)
return <div key={idx}>Error loading image 1: {media.content}</div>
return <WebviewMedia key={idx} bamlMediaType='image' media={media} />
}
if (part.is_audio()) {
const media = part.as_media()
if (!media) return <div>Error loading audio: this chat message part is not media</div>
if (media.type === 'error') return <div>Error loading audio: {media.error}</div>
return <WebviewMedia key={idx} bamlMediaType='audio' media={part.as_media()} />
if (!media) return <div key={idx}>Error loading audio: this chat message part is not media</div>
if (media.type === WasmChatMessagePartMediaType.Error)
return <div key={idx}>Error loading audio 1: {media.content}</div>
return <WebviewMedia key={idx} bamlMediaType='audio' media={media} />
}
return null
})}
Expand Down
34 changes: 22 additions & 12 deletions typescript/playground-common/src/utils/vscode.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { GetWebviewUriRequest, GetWebviewUriResponse } from '../baml_wasm_web/rpc'
import { decodeBuffer, GetWebviewUriRequest, GetWebviewUriResponse } from '../baml_wasm_web/rpc'
import type { WebviewApi } from 'vscode-webview'

const RPC_TIMEOUT_MS = 5000
Expand Down Expand Up @@ -47,20 +47,30 @@ class VSCodeAPIWrapper {
}

public async readFile(path: string): Promise<Uint8Array> {
const uri = await this.asWebviewUri('', path)
const resp = await fetch(uri)
const uri = await this.readLocalFile('', path)
console.log('read file', uri)

if (!resp.ok) {
if (resp.status === 404) {
throw new Error(`File does not exist: '${path}'`)
}
throw new Error(`Fetch via vscode resulted in status=${resp.status} (see network logs for more details)`)
if (uri.readError) {
throw new Error(`Failed to read file: ${path}\n${uri.readError}`)
}
if (uri.contents) {
const contents = uri.contents
// throw new Error(`not implemented: ${Array.isArray(contents)}: \n ${JSON.stringify(contents)}`)
return decodeBuffer(contents)
}

const blob = await resp.blob()
const arrayBuffer = await blob.arrayBuffer()
throw new Error(`Unknown error: '${path}'`)
}

async readLocalFile(bamlSrc: string, path: string): Promise<GetWebviewUriResponse> {
const resp = await this.rpc<GetWebviewUriRequest, GetWebviewUriResponse>({
vscodeCommand: 'GET_WEBVIEW_URI',
bamlSrc,
path,
contents: true,
})

return new Uint8Array(arrayBuffer)
return resp
}

public async asWebviewUri(bamlSrc: string, path: string): Promise<string> {
Expand All @@ -79,7 +89,7 @@ class VSCodeAPIWrapper {
this.rpcTable.set(rpcId, { resolve: resolve as (resp: unknown) => void })

this.postMessage({
rpcMethod: (data as any).vscodeCommand,
rpcMethod: (data as unknown as { vscodeCommand: string }).vscodeCommand,
rpcId,
data,
})
Expand Down
3 changes: 2 additions & 1 deletion typescript/playground-common/tsconfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@
"lib": [
"dom",
"dom.iterable",
"esnext"
"esnext",
"ES2015"
],
"useDefineForClassFields": true,
"allowJs": true,
Expand Down
Loading

0 comments on commit fcd189e

Please sign in to comment.