Skip to content

Commit

Permalink
Fix minor bugs (#728)
Browse files Browse the repository at this point in the history
* await in streaming 
* bedrock curl request 
* expose BamlLogEvent
  • Loading branch information
hellovai authored Jun 28, 2024
1 parent 8b2a25b commit be74999
Show file tree
Hide file tree
Showing 14 changed files with 82 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ impl WithRenderRawCurl for AwsClient {
// TODO(sam): this is fucked up. The SDK actually hides all the serializers inside the crate and doesn't let the user access them.

Ok(format!(
"aws bedrock converse --model-id {} --messages {} {}",
"Note, this is not yet complete!\n\nSee: https://docs.aws.amazon.com/cli/latest/reference/bedrock-runtime/converse.html\n\naws bedrock converse --model-id {} --messages {} {}",
converse_input.model_id.unwrap_or("<model_id>".to_string()),
"<messages>",
"TODO"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import { BamlLogEvent } from '@boundaryml/baml';
import { DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX } from './globals';

const traceAsync = DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX.traceFnAync.bind(DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX)
const traceSync = DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX.traceFnSync.bind(DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX)
const setTags = DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX.upsertTags.bind(DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX)
const flush = DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX.flush.bind(DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX)
const onLogEvent = (...args: Parameters<typeof DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX.onLogEvent>) =>
DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX.onLogEvent(...args);
const traceAsync =
DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX.traceFnAync.bind(DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX)
const traceSync =
DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX.traceFnSync.bind(DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX)
const setTags =
DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX.upsertTags.bind(DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX)
const flush =
DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX.flush.bind(DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX)
const onLogEvent = (callback: (event: BamlLogEvent) => void) =>
DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX.onLogEvent(callback)

export { traceAsync, traceSync, setTags, flush, onLogEvent }
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import functools
import inspect
import typing
from .baml_py import RuntimeContextManager, BamlRuntime, BamlSpan
from .baml_py import BamlLogEvent, RuntimeContextManager, BamlRuntime, BamlSpan
import atexit
import threading

Expand Down Expand Up @@ -65,7 +65,7 @@ def end_trace(self, span: BamlSpan, response: typing.Any) -> None:
def flush(self) -> None:
self.rt.flush()

def on_log_event(self, handler: typing.Callable[[str], None]) -> None:
def on_log_event(self, handler: typing.Callable[[BamlLogEvent], None]) -> None:
self.rt.set_log_event_callback(handler)

def trace_fn(self, func: F) -> F:
Expand Down
16 changes: 8 additions & 8 deletions engine/language_client_python/python_src/baml_py/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ def __init__(
final_coerce: Callable[[FunctionResult], FinalOutputType],
ctx_manager: RuntimeContextManager,
tb: Optional[TypeBuilder],

):
self.__ffi_stream = ffi_stream.on_event(self.__enqueue)
self.__partial_coerce = partial_coerce
Expand All @@ -48,10 +47,10 @@ def __enqueue(self, data: FunctionResult) -> None:
self.__event_queue.put_nowait(data)

async def __drive_to_completion(self) -> FunctionResult:

try:
retval = await self.__ffi_stream.done(self.__ctx_manager)

self.__future.set_result(retval)
return retval
except Exception as e:
Expand All @@ -62,23 +61,24 @@ async def __drive_to_completion(self) -> FunctionResult:

def __drive_to_completion_in_bg(self) -> concurrent.futures.Future[FunctionResult]:
if self.__task is None:
self.__task = threading.Thread(target = self.threading_target, daemon=True)
self.__task = threading.Thread(target=self.threading_target, daemon=True)
self.__task.start()
return self.__future

def threading_target(self):
asyncio.run(self.__drive_to_completion(), debug=True)


async def __aiter__(self):
# TODO: This is deliberately __aiter__ and not __iter__ because we want to
# ensure that the caller is using an async for loop.
# Eventually we do not want to create a new thread for each stream.
self.__drive_to_completion_in_bg()
while True:
event = await self.__event_queue.get()
event = self.__event_queue.get()
if event is None:

break
yield self.__partial_coerce(event.parsed())

async def get_final_response(self):
final = self.__drive_to_completion_in_bg()
return self.__final_coerce((await asyncio.wrap_future(final)).parsed())
return self.__final_coerce((await asyncio.wrap_future(final)).parsed())
2 changes: 2 additions & 0 deletions engine/language_client_python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ fn baml_py(_: Python<'_>, m: Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<types::EnumValueBuilder>()?;
m.add_class::<types::ClassPropertyBuilder>()?;
m.add_class::<types::FieldType>()?;
m.add_class::<types::BamlLogEvent>()?;
m.add_class::<types::LogEventMetadata>()?;

m.add_wrapped(wrap_pyfunction!(invoke_runtime_cli))?;

Expand Down
2 changes: 1 addition & 1 deletion engine/language_client_python/src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ pub use function_result_stream::FunctionResultStream;
pub use function_results::FunctionResult;
pub use image::BamlImagePy;

pub use runtime::BamlRuntime;
pub use runtime::{BamlLogEvent, BamlRuntime, LogEventMetadata};
pub use runtime_ctx_manager::RuntimeContextManager;
pub use span::BamlSpan;
pub use type_builder::*;
5 changes: 2 additions & 3 deletions engine/language_client_typescript/async_context_vars.d.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import { BamlLogEvent } from '../native';
import { BamlSpan, RuntimeContextManager, BamlRuntime } from './native';
import { BamlSpan, RuntimeContextManager, BamlRuntime, BamlLogEvent } from './native';
export declare class CtxManager {
private rt;
private ctx;
Expand All @@ -10,7 +9,7 @@ export declare class CtxManager {
startTraceAsync(name: string, args: Record<string, any>): BamlSpan;
endTrace(span: BamlSpan, response: any): void;
flush(): void;
onLogEvent(callback: (error: any, event: BamlLogEvent) => void): void;
onLogEvent(callback: (event: BamlLogEvent) => void): void;
traceFnSync<ReturnType, F extends (...args: any[]) => ReturnType>(name: string, func: F): F;
traceFnAync<ReturnType, F extends (...args: any[]) => Promise<ReturnType>>(name: string, func: F): F;
}
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 5 additions & 1 deletion engine/language_client_typescript/async_context_vars.js
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,11 @@ class CtxManager {
this.rt.flush();
}
onLogEvent(callback) {
this.rt.setLogEventCallback(callback);
this.rt.setLogEventCallback((error, param) => {
if (!error) {
callback(param);
}
});
}
traceFnSync(name, func) {
return ((...args) => {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import { BamlLogEvent } from '../native'
import { BamlSpan, RuntimeContextManager, BamlRuntime } from './native'
import { BamlSpan, RuntimeContextManager, BamlRuntime, BamlLogEvent } from './native'
import { AsyncLocalStorage } from 'async_hooks'

export class CtxManager {
Expand Down Expand Up @@ -56,8 +55,12 @@ export class CtxManager {
this.rt.flush()
}

onLogEvent(callback: (error: any, event: BamlLogEvent) => void): void {
this.rt.setLogEventCallback(callback)
onLogEvent(callback: (event: BamlLogEvent) => void): void {
this.rt.setLogEventCallback((error: any, param: BamlLogEvent) => {
if (!error) {
callback(param)
}
})
}

traceFnSync<ReturnType, F extends (...args: any[]) => ReturnType>(name: string, func: F): F {
Expand Down
41 changes: 27 additions & 14 deletions integ-tests/python/test_functions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import List
import pytest
from dotenv import load_dotenv
from base64_test_data import image_b64, audio_b64
Expand Down Expand Up @@ -62,7 +63,7 @@ class MyCustomClass(NamedArgsSingleClass):
@pytest.mark.asyncio
async def accepts_subclass_of_baml_type():
print("calling with class")
res = await b.TestFnNamedArgsSingleClass(
_ = await b.TestFnNamedArgsSingleClass(
myArg=MyCustomClass(
key="key", key_two=True, key_three=52, date=datetime.datetime.now()
)
Expand Down Expand Up @@ -107,7 +108,7 @@ async def test_should_work_with_image_url():
@pytest.mark.asyncio
async def test_should_work_with_image_base64():
res = await b.TestImageInput(img=baml_py.Image.from_base64("image/png", image_b64))
assert "green" in res.lower()
assert "green" in res.lower() or "orge" in res.lower()


@pytest.mark.asyncio
Expand Down Expand Up @@ -153,18 +154,21 @@ async def test_gemini():
print(f"LLM output from Gemini: {geminiRes}")
assert len(geminiRes) > 0, "Expected non-empty result but got empty."


@pytest.mark.asyncio
async def test_gemini_streaming():
geminiRes = await b.stream.TestGemini(input="Dr. Pepper").get_final_response()
print(f"LLM output from Gemini: {geminiRes}")

assert len(geminiRes) > 0, "Expected non-empty result but got empty."


@pytest.mark.asyncio
async def test_aws():
res = await b.TestAws(input="Mt Rainier is tall")
assert len(res) > 0, "Expected non-empty result but got empty."


@pytest.mark.asyncio
async def test_aws_streaming():
res = await b.stream.TestAws(input="Mt Rainier is tall").get_final_response()
Expand All @@ -173,7 +177,9 @@ async def test_aws_streaming():

@pytest.mark.asyncio
async def test_streaming():
stream = b.stream.PromptTestStreaming(input="Programming languages are fun to create")
stream = b.stream.PromptTestStreaming(
input="Programming languages are fun to create"
)
msgs = []

start_time = asyncio.get_event_loop().time()
Expand All @@ -182,14 +188,17 @@ async def test_streaming():
msgs.append(msg)
if len(msgs) == 1:
first_msg_time = asyncio.get_event_loop().time()

last_msg_time = asyncio.get_event_loop().time()


final = await stream.get_final_response()

assert first_msg_time - start_time <= 1.5, "Expected first message within 1 second but it took longer."
assert last_msg_time - start_time >= 1, "Expected last message after 1.5 seconds but it was earlier."
assert (
first_msg_time - start_time <= 1.5
), "Expected first message within 1 second but it took longer."
assert (
last_msg_time - start_time >= 1
), "Expected last message after 1.5 seconds but it was earlier."
assert len(final) > 0, "Expected non-empty final but got empty."
assert len(msgs) > 0, "Expected at least one streamed response but got none."
for prev_msg, msg in zip(msgs, msgs[1:]):
Expand All @@ -201,9 +210,10 @@ async def test_streaming():
)
assert msgs[-1] == final, "Expected last stream message to match final response."


@pytest.mark.asyncio
async def test_streaming_uniterated():
final = await b.stream.PromptTestOpenAI(
final = await b.stream.PromptTestStreaming(
input="The color blue makes me sad"
).get_final_response()
assert len(final) > 0, "Expected non-empty final but got empty."
Expand Down Expand Up @@ -236,9 +246,10 @@ async def test_streaming_claude():
@pytest.mark.asyncio
async def test_streaming_gemini():
stream = b.stream.TestGemini(input="Dr.Pepper")
msgs = []
msgs: List[str] = []
async for msg in stream:
msgs.append(msg)
if msg is not None:
msgs.append(msg)
final = await stream.get_final_response()

assert len(final) > 0, "Expected non-empty final but got empty."
Expand Down Expand Up @@ -305,7 +316,7 @@ async def trace_thread_pool_async():
# Create 10 tasks and execute them
futures = [executor.submit(trace_async_gather) for _ in range(10)]
for future in concurrent.futures.as_completed(futures):
res = await future.result()
_ = await future.result()


@trace
Expand Down Expand Up @@ -504,28 +515,30 @@ async def test_nested_class_streaming():
assert len(msgs) > 0, "Expected at least one streamed response but got none."
print("final ", final.model_dump(mode="json"))


@pytest.mark.asyncio
async def test_event_log_hook():
def event_log_hook(event):
def event_log_hook(event: baml_py.baml_py.BamlLogEvent):
print("Event log hook1: ")
print("Event log event ", event)

on_log_event(event_log_hook)
res = await b.TestFnNamedArgsSingleStringList(["a", "b", "c"])
assert res


@pytest.mark.asyncio
async def test_aws_bedrock():
## unstreamed
# res = await b.TestAws("lightning in a rock")
# print("unstreamed", res)


## streamed
stream = b.stream.TestAws("lightning in a rock")

async for msg in stream:
print("streamed ", repr(msg[-100:]))
if msg:
print("streamed ", repr(msg[-100:]))

res = await stream.get_final_response()
print("streamed final", res)
Expand Down
17 changes: 11 additions & 6 deletions integ-tests/typescript/baml_client/tracing.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,18 @@ $ pnpm add @boundaryml/baml
// @ts-nocheck
// biome-ignore format: autogenerated code
/* eslint-disable */
import { BamlLogEvent } from '@boundaryml/baml';
import { DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX } from './globals';

const traceAsync = DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX.traceFnAync.bind(DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX)
const traceSync = DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX.traceFnSync.bind(DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX)
const setTags = DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX.upsertTags.bind(DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX)
const flush = DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX.flush.bind(DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX)
const onLogEvent = (...args: Parameters<typeof DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX.onLogEvent>) =>
DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX.onLogEvent(...args);
const traceAsync =
DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX.traceFnAync.bind(DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX)
const traceSync =
DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX.traceFnSync.bind(DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX)
const setTags =
DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX.upsertTags.bind(DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX)
const flush =
DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX.flush.bind(DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX)
const onLogEvent = (callback: (event: BamlLogEvent) => void) =>
DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX.onLogEvent(callback)

export { traceAsync, traceSync, setTags, flush, onLogEvent }
2 changes: 1 addition & 1 deletion integ-tests/typescript/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"test": "jest",
"build:debug": "cd ../../engine/language_client_typescript && pnpm run build:debug && cd - && pnpm i",
"build": "cd ../../engine/language_client_typescript && npm run build && cd - && pnpm i",
"integ-tests": "BAML_LOG=info infisical run --env=dev -- pnpm test -- --silent false --testTimeout 30000",
"integ-tests": "BAML_LOG=info infisical run --env=test -- pnpm test -- --silent false --testTimeout 30000",
"integ-tests:dotenv": "BAML_LOG=info dotenv -e ../.env -- pnpm test -- --silent false --testTimeout 30000",
"generate": "baml-cli generate --from ../baml_src"
},
Expand Down
5 changes: 3 additions & 2 deletions integ-tests/typescript/tests/integ-tests.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import {
import TypeBuilder from '../baml_client/type_builder'
import { RecursivePartialNull } from '../baml_client/client'
import { config } from 'dotenv'
import { BamlLogEvent } from '@boundaryml/baml/native'
config()

describe('Integ tests', () => {
Expand Down Expand Up @@ -353,8 +354,8 @@ describe('Integ tests', () => {
})

it("should work with 'onLogEvent'", async () => {
onLogEvent((error: any, param2) => {
console.log('msg', error, 'param2', param2)
onLogEvent((param2) => {
console.log('onLogEvent', param2)
})
const res = await b.TestFnNamedArgsSingleStringList(['a', 'b', 'c'])
expect(res).toContain('a')
Expand Down

0 comments on commit be74999

Please sign in to comment.